Spaces:
Paused
Paused
| """RAGLite typing.""" | |
| import io | |
| import pickle | |
| from collections.abc import Callable | |
| from typing import Any, Protocol | |
| import numpy as np | |
| from sqlalchemy.engine import Dialect | |
| from sqlalchemy.sql.operators import Operators | |
| from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType | |
| from raglite._config import RAGLiteConfig | |
# NumPy array type aliases, each parameterized as ndarray[<shape tuple>, <dtype>].
# Array whose shape is typed as a pair of ints, holding floating-point values.
FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
# Array whose shape is typed as a single int, holding floating-point values.
FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
# Array whose shape is typed as a single int, holding `np.intp` integers.
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
class SearchMethod(Protocol):
    """Structural type for callables that run a search for a query string."""

    def __call__(
        self,
        query: str,
        *,
        num_results: int = 3,
        config: RAGLiteConfig | None = None,
    ) -> tuple[list[str], list[float]]:
        """Run the search.

        Parameters
        ----------
        query
            The query string to search with.
        num_results
            Keyword-only; defaults to 3.
        config
            Keyword-only optional RAGLite configuration; defaults to ``None``.

        Returns
        -------
        A pair of lists: strings and floats (presumably result identifiers and their
        scores -- confirm against the concrete implementations).
        """
        ...
class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
    """A NumPy array column type for SQLAlchemy.

    Arrays are written with ``np.save`` into an in-memory buffer and stored as a
    ``LargeBinary`` value; they are read back with ``np.load``. Pickle support is
    disabled in both directions.
    """

    impl = LargeBinary
    # This type carries no constructor state, so it is safe to take part in SQLAlchemy's
    # compiled-statement cache key. Leaving `cache_ok` unset makes SQLAlchemy emit a
    # warning and skip caching for statements that use this type.
    cache_ok = True

    def process_bind_param(
        self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
    ) -> bytes | None:
        """Convert a NumPy array to bytes.

        Returns ``None`` unchanged so that NULL column values round-trip.
        """
        if value is None:
            return None
        buffer = io.BytesIO()
        np.save(buffer, value, allow_pickle=False, fix_imports=False)
        return buffer.getvalue()

    def process_result_value(
        self, value: bytes | None, dialect: Dialect
    ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
        """Convert bytes to a NumPy array.

        Returns ``None`` unchanged so that NULL column values round-trip.
        """
        if value is None:
            return None
        return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)  # type: ignore[no-any-return]
class PickledObject(TypeDecorator[object]):
    """A pickled object column type for SQLAlchemy.

    Values are serialized with ``pickle.dumps`` and stored as a ``LargeBinary`` value.
    """

    impl = LargeBinary
    # This type carries no constructor state, so it is safe to take part in SQLAlchemy's
    # compiled-statement cache key. Leaving `cache_ok` unset makes SQLAlchemy emit a
    # warning and skip caching for statements that use this type.
    cache_ok = True

    def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
        """Convert a Python object to bytes.

        Returns ``None`` unchanged so that NULL column values round-trip.
        """
        if value is None:
            return None
        return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)

    def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
        """Convert bytes to a Python object.

        Returns ``None`` unchanged so that NULL column values round-trip.
        """
        if value is None:
            return None
        # SECURITY NOTE: unpickling can execute arbitrary code. This is only safe while
        # the database contents are trusted (i.e. written exclusively by this application).
        return pickle.loads(value, fix_imports=False)  # type: ignore[no-any-return]  # noqa: S301
class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]):
    """Comparator mixin exposing distance operators for halfvec columns.

    Every method builds a custom SQL operator expression that is typed as ``Float``
    and applies it to ``other``.
    """

    def cosine_distance(self, other: FloatVector) -> Operators:
        """Build the cosine distance expression (``<=>`` operator)."""
        cosine_op = self.op("<=>", return_type=Float)
        return cosine_op(other)

    def dot_distance(self, other: FloatVector) -> Operators:
        """Build the dot product distance expression (``<#>`` operator)."""
        dot_op = self.op("<#>", return_type=Float)
        return dot_op(other)

    def euclidean_distance(self, other: FloatVector) -> Operators:
        """Build the Euclidean distance expression (``<->`` operator)."""
        euclidean_op = self.op("<->", return_type=Float)
        return euclidean_op(other)

    def l1_distance(self, other: FloatVector) -> Operators:
        """Build the L1 distance expression (``<+>`` operator)."""
        l1_op = self.op("<+>", return_type=Float)
        return l1_op(other)

    def l2_distance(self, other: FloatVector) -> Operators:
        """Build the L2 distance expression (``<->``, the same operator as `euclidean_distance`)."""
        l2_op = self.op("<->", return_type=Float)
        return l2_op(other)
class HalfVec(UserDefinedType[FloatVector]):
    """A PostgreSQL half-precision vector column type for SQLAlchemy.

    Values are sent to the database as a bracketed, comma-separated string and parsed
    back into a ``float16`` NumPy array.
    """

    cache_ok = True  # HalfVec is immutable.

    def __init__(self, dim: int | None = None) -> None:
        """Create the type with an optional fixed dimension ``dim``."""
        super().__init__()
        self.dim = dim

    def get_col_spec(self, **kwargs: Any) -> str:
        """Return the DDL column type.

        When no dimension was given, emit a dimensionless ``halfvec`` rather than the
        invalid ``halfvec(None)``.
        """
        if self.dim is None:
            return "halfvec"
        return f"halfvec({self.dim})"

    def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
        """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""

        def process(value: FloatVector | None) -> str | None:
            if value is None:
                return None
            # Flatten the array and render it as "[x1,x2,...]".
            return f"[{','.join(str(x) for x in np.ravel(value))}]"

        return process

    def result_processor(
        self, dialect: Dialect, coltype: Any
    ) -> Callable[[str | None], FloatVector | None]:
        """Process PostgreSQL halfvec format to NumPy ndarray."""

        def process(value: str | None) -> FloatVector | None:
            if value is None:
                return None
            # Drop the surrounding brackets, then parse the comma-separated numbers.
            return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)

        return process

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        ...
class Embedding(TypeDecorator[FloatVector]):
    """An embedding column type for SQLAlchemy.

    The concrete storage type depends on the dialect: ``HalfVec`` on PostgreSQL and
    ``NumpyArray`` everywhere else.
    """

    cache_ok = True  # Embedding is immutable.
    impl = NumpyArray

    def __init__(self, dim: int = -1):
        """Create the type; ``dim`` is forwarded to ``HalfVec`` on PostgreSQL."""
        super().__init__()
        self.dim = dim

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
        """Select the dialect-specific implementation type."""
        if dialect.name != "postgresql":
            # Every non-PostgreSQL dialect stores the embedding as a serialized NumPy array.
            return dialect.type_descriptor(NumpyArray())
        return dialect.type_descriptor(HalfVec(self.dim))

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        ...