| """Query documents.""" | |
| import re | |
| import string | |
| from collections import defaultdict | |
| from collections.abc import Sequence | |
| from itertools import groupby | |
| from typing import cast | |
| import numpy as np | |
| from langdetect import detect | |
| from sqlalchemy.engine import make_url | |
| from sqlmodel import Session, and_, col, or_, select, text | |
| from raglite._config import RAGLiteConfig | |
| from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine | |
| from raglite._embed import embed_sentences | |
| from raglite._typing import FloatMatrix | |
| def vector_search( | |
| query: str | FloatMatrix, | |
| *, | |
| num_results: int = 3, | |
| config: RAGLiteConfig | None = None, | |
| ) -> tuple[list[str], list[float]]: | |
| """Search chunks using ANN vector search.""" | |
| # Read the config. | |
| config = config or RAGLiteConfig() | |
| db_backend = make_url(config.db_url).get_backend_name() | |
| # Get the index metadata (including the query adapter, and in the case of SQLite, the index). | |
| index_metadata = IndexMetadata.get("default", config=config) | |
| # Embed the query. | |
| query_embedding = ( | |
| embed_sentences([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query) | |
| ) | |
| # Apply the query adapter to the query embedding. | |
| Q = index_metadata.get("query_adapter") # noqa: N806 | |
| if config.vector_search_query_adapter and Q is not None: | |
| query_embedding = (Q @ query_embedding).astype(query_embedding.dtype) | |
| # Search for the multi-vector chunk embeddings that are most similar to the query embedding. | |
| if db_backend == "postgresql": | |
| # Check that the selected metric is supported by pgvector. | |
| metrics = {"cosine": "<=>", "dot": "<#>", "euclidean": "<->", "l1": "<+>", "l2": "<->"} | |
| if config.vector_search_index_metric not in metrics: | |
| error_message = f"Unsupported metric {config.vector_search_index_metric}." | |
| raise ValueError(error_message) | |
| # With pgvector, we can obtain the nearest neighbours and similarities with a single query. | |
| engine = create_database_engine(config) | |
| with Session(engine) as session: | |
| distance_func = getattr( | |
| ChunkEmbedding.embedding, f"{config.vector_search_index_metric}_distance" | |
| ) | |
| distance = distance_func(query_embedding).label("distance") | |
| results = session.exec( | |
| select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results) | |
| ) | |
| chunk_ids_, distance = zip(*results, strict=True) | |
| chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance) | |
| elif db_backend == "sqlite": | |
| # Load the NNDescent index. | |
| index = index_metadata.get("index") | |
| ids = np.asarray(index_metadata.get("chunk_ids")) | |
| cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes"))) | |
| # Find the neighbouring multi-vector indices. | |
| from pynndescent import NNDescent | |
| multi_vector_indices, distance = cast(NNDescent, index).query( | |
| query_embedding[np.newaxis, :], k=8 * num_results | |
| ) | |
| similarity = 1 - distance[0, :] | |
| # Transform the multi-vector indices into chunk indices, and then to chunk ids. | |
| chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1 | |
| chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices]) | |
| # Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with | |
| # fewer hits are padded with the minimum similarity of the result set. | |
| unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True) | |
| score = np.full( | |
| (len(unique_chunk_ids), np.max(counts)), np.min(similarity), dtype=similarity.dtype | |
| ) | |
| for i, (unique_chunk_id, count) in enumerate(zip(unique_chunk_ids, counts, strict=True)): | |
| score[i, :count] = similarity[chunk_ids == unique_chunk_id] | |
| pooled_similarity = np.mean(score, axis=1) | |
| # Sort the chunk ids by their adjusted similarity. | |
| sorted_indices = np.argsort(pooled_similarity)[::-1] | |
| unique_chunk_ids = unique_chunk_ids[sorted_indices][:num_results] | |
| pooled_similarity = pooled_similarity[sorted_indices][:num_results] | |
| return unique_chunk_ids.tolist(), pooled_similarity.tolist() | |
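
# Example usage (hypothetical query; assumes documents have already been embedded and stored in the
# configured database). A string query is embedded first; a precomputed query embedding is used as-is:
#   chunk_ids, similarities = vector_search("How does RAGLite work?", num_results=5)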


def keyword_search(
    query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
    """Search chunks using BM25 keyword search."""
    # Read the config.
    config = config or RAGLiteConfig()
    db_backend = make_url(config.db_url).get_backend_name()
    # Connect to the database.
    engine = create_database_engine(config)
    with Session(engine) as session:
        if db_backend == "postgresql":
            # Convert the query to a tsquery [1].
            # [1] https://www.postgresql.org/docs/current/textsearch-controls.html
            query_escaped = re.sub(r"[&|!():<>\"]", " ", query)
            tsv_query = " | ".join(query_escaped.split())
            # Perform keyword search with tsvector.
            statement = text("""
                SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
                FROM chunk
                WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
                ORDER BY score DESC
                LIMIT :limit;
                """)
            results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
        elif db_backend == "sqlite":
            # Convert the query to an FTS5 query [1].
            # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax
            query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", query)
            fts5_query = " OR ".join(query_escaped.split())
            # Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we
            # negate them to make them positive.
            # [1] https://www.sqlite.org/fts5.html#the_bm25_function
            statement = text("""
                SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score
                FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid
                WHERE keyword_search_chunk_index MATCH :match
                ORDER BY score DESC
                LIMIT :limit;
                """)
            results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
        # Unpack the results.
        chunk_ids, keyword_score = zip(*results, strict=True)
        chunk_ids, keyword_score = list(chunk_ids), list(keyword_score)  # type: ignore[assignment]
    return chunk_ids, keyword_score  # type: ignore[return-value]
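
# Example usage (hypothetical query; assumes an indexed database). On PostgreSQL, tsquery operator
# characters are replaced with spaces and the remaining terms are OR-ed with "|"; on SQLite,
# punctuation is stripped and the terms are OR-ed with "OR", so the query below matches chunks
# containing any of the three words:
#   chunk_ids, scores = keyword_search("reciprocal rank fusion", num_results=5)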


def reciprocal_rank_fusion(
    rankings: list[list[str]], *, k: int = 60
) -> tuple[list[str], list[float]]:
    """Reciprocal Rank Fusion."""
    # Compute the RRF score.
    chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
    chunk_id_score: defaultdict[str, float] = defaultdict(float)
    for ranking in rankings:
        chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
        for chunk_id in chunk_ids:
            chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
    # Rank RRF results according to descending RRF score.
    rrf_chunk_ids, rrf_score = zip(
        *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
    )
    return list(rrf_chunk_ids), list(rrf_score)
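
# Worked example (hypothetical chunk ids): each chunk scores sum(1 / (k + rank)) over the rankings,
# where a chunk missing from a ranking is assigned rank len(ranking). With k=60 and
#   rankings = [["a", "b", "c"], ["b", "a"]]
# chunk "a" scores 1/60 + 1/61 ≈ 0.0331, chunk "b" scores 1/61 + 1/60 ≈ 0.0331, and
# chunk "c" scores 1/62 + 1/62 ≈ 0.0323, so "c" ranks below "a" and "b".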


def hybrid_search(
    query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
    """Search chunks by combining ANN vector search with BM25 keyword search."""
    # Run both searches.
    vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config)
    ks_chunk_ids, _ = keyword_search(query, num_results=num_rerank, config=config)
    # Combine the results with Reciprocal Rank Fusion (RRF).
    chunk_ids, hybrid_score = reciprocal_rank_fusion([vs_chunk_ids, ks_chunk_ids])
    chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results]
    return chunk_ids, hybrid_score
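
# Example usage (hypothetical query): fuse the top-100 vector search hits with the top-100 keyword
# search hits via RRF and keep the top 5:
#   chunk_ids, scores = hybrid_search("What is reciprocal rank fusion?", num_results=5, num_rerank=100)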


def retrieve_chunks(
    chunk_ids: list[str],
    *,
    config: RAGLiteConfig | None = None,
) -> list[Chunk]:
    """Retrieve chunks by their ids."""
    config = config or RAGLiteConfig()
    engine = create_database_engine(config)
    with Session(engine) as session:
        chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all())
    chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
    return chunks


def retrieve_segments(
    chunk_ids: list[str] | list[Chunk],
    *,
    neighbors: tuple[int, ...] | None = (-1, 1),
    config: RAGLiteConfig | None = None,
) -> list[str]:
    """Group chunks into contiguous segments and retrieve them."""
    # Retrieve the chunks.
    config = config or RAGLiteConfig()
    chunks: list[Chunk] = (
        retrieve_chunks(chunk_ids, config=config)  # type: ignore[arg-type,assignment]
        if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
        else chunk_ids
    )
    # Extend the chunks with their neighbouring chunks.
    if neighbors:
        engine = create_database_engine(config)
        with Session(engine) as session:
            neighbor_conditions = [
                and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
                for chunk in chunks
                for offset in neighbors
            ]
            chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all())
    # Keep only the unique chunks.
    chunks = list(set(chunks))
    # Sort the chunks by document_id and index (needed for groupby).
    chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index))
    # Group the chunks into contiguous segments.
    segments: list[list[Chunk]] = []
    for _, group in groupby(chunks, key=lambda chunk: chunk.document_id):
        segment: list[Chunk] = []
        for chunk in group:
            if not segment or chunk.index == segment[-1].index + 1:
                segment.append(chunk)
            else:
                segments.append(segment)
                segment = [chunk]
        segments.append(segment)
    # Rank segments according to the aggregate relevance of their chunks.
    chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
    segments.sort(
        key=lambda segment: sum(chunk_id_to_score.get(chunk.id, 0.0) for chunk in segment),
        reverse=True,
    )
    # Convert the segments into strings.
    segments = [
        segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip()  # type: ignore[misc]
        for segment in segments
    ]
    return segments  # type: ignore[return-value]


def rerank_chunks(
    query: str,
    chunk_ids: list[str] | list[Chunk],
    *,
    config: RAGLiteConfig | None = None,
) -> list[Chunk]:
    """Rerank chunks according to their relevance to a given query."""
    # Retrieve the chunks.
    config = config or RAGLiteConfig()
    chunks: list[Chunk] = (
        retrieve_chunks(chunk_ids, config=config)  # type: ignore[arg-type,assignment]
        if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
        else chunk_ids
    )
    # Early exit if no reranker is configured.
    if not config.reranker:
        return chunks
    # Select the reranker.
    if isinstance(config.reranker, Sequence):
        # Detect the languages of the chunks and queries.
        langs = {detect(str(chunk)) for chunk in chunks}
        langs.add(detect(query))
        # If all chunks and the query are in the same language, use a language-specific reranker.
        rerankers = dict(config.reranker)
        if len(langs) == 1 and (lang := next(iter(langs))) in rerankers:
            reranker = rerankers[lang]
        else:
            reranker = rerankers.get("other")
    else:
        # A specific reranker was configured.
        reranker = config.reranker
    # Rerank the chunks.
    if reranker:
        results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks])
        chunks = [chunks[result.doc_id] for result in results.results]
    return chunks
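

# Minimal end-to-end sketch (not part of the library API; assumes a populated RAGLite database and
# a hypothetical db_url). Retrieve candidates with hybrid search, rerank them, and group them into
# segments:
if __name__ == "__main__":
    demo_config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")
    question = "What does RAGLite do?"
    candidate_chunk_ids, _ = hybrid_search(question, num_results=20, config=demo_config)
    reranked_chunks = rerank_chunks(question, candidate_chunk_ids, config=demo_config)
    segments = retrieve_segments(reranked_chunks, neighbors=(-1, 1), config=demo_config)
    print(segments[0] if segments else "No results.")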