| """Compute and update an optimal query adapter.""" | |
| import numpy as np | |
| from sqlmodel import Session, col, select | |
| from tqdm.auto import tqdm | |
| from raglite._config import RAGLiteConfig | |
| from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine | |
| from raglite._embed import embed_sentences | |
| from raglite._search import vector_search | |


def update_query_adapter(  # noqa: PLR0915, C901
    *,
    max_triplets: int = 4096,
    max_triplets_per_eval: int = 64,
    optimize_top_k: int = 40,
    config: RAGLiteConfig | None = None,
) -> None:
| """Compute an optimal query adapter and update the database with it. | |
| This function computes an optimal linear transform A, called a 'query adapter', that is used to | |
| transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in | |
| order to improve the quality of the search results. | |
| Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the | |
| score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ. | |
| If the nearest neighbour search uses the dot product as its relevance score, we can find the | |
| optimal query adapter by solving the following relaxed Procrustes optimisation problem with a | |
| bound on the Frobenius norm of A: | |

        A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
           = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
           = argmax trace[ (P - N) A Q' ]    where Q := [q₁'; ...; qₖ']
                                                   P := [p₁'; ...; pₖ']
                                                   N := [n₁'; ...; nₖ']
           = argmax trace[ Q' (P - N) A ]
           = argmax trace[ M A ]             where M := Q' (P - N)
             s.t. ||A||_F == 1
           = M' / ||M||_F

    If the nearest neighbour search uses the cosine similarity as its relevance score, we can
    find the optimal query adapter by solving the following orthogonal Procrustes optimisation
    problem with an orthogonality constraint on A:

        A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
           = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
           = argmax trace[ (P - N) A Q' ]
           = argmax trace[ Q' (P - N) A ]
           = argmax trace[ M A ]
           = argmax trace[ U Σ V' A ]        where U Σ V' := M is the SVD of M
           = argmax trace[ Σ V' A U ]
             s.t. A' A == 𝕀
           = V U'
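
    As a quick numerical sanity check of these two closed-form solutions (a sketch with a
    random matrix, not part of the algorithm itself):

        import numpy as np

        rng = np.random.default_rng(42)
        M = rng.standard_normal((8, 8))
        # Relaxed Procrustes: A* = M' / ||M||_F maximises trace[M A] s.t. ||A||_F == 1.
        A_relaxed = M.T / np.linalg.norm(M, ord="fro")
        # Orthogonal Procrustes: A* = V U' maximises trace[M A] s.t. A'A == 𝕀.
        U, _, VT = np.linalg.svd(M)
        A_orthogonal = VT.T @ U.T
        # The maxima are ||M||_F and Σᵢ σᵢ respectively, and ||M||_F <= Σᵢ σᵢ always holds.
        assert np.trace(M @ A_relaxed) <= np.trace(M @ A_orthogonal)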

    Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
    incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered
    ones. To achieve this, we rewrite M as α (M / s) + (1 - α) 𝕀, where s scales M to the same
    Frobenius norm as 𝕀, and choose the smallest α that ranks each (q, p, n) triplet correctly.
    If α = 0, the relevance score gap of an incorrectly ordered (p, n) pair is
    B := (p - n)' q < 0. If α = 1, the relevance score gap would be
    A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of, say,
    C := 5% × A, the optimal α is given by α A + (1 - α) B = C, so that α = (B - C) / (B - A).
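
    A minimal usage sketch, assuming documents and evals have already been inserted with
    RAGLite's `insert_document()` and `insert_evals()` (this module's import path is assumed
    here):

        from raglite._config import RAGLiteConfig
        from raglite._query_adapter import update_query_adapter

        my_config = RAGLiteConfig(vector_search_query_adapter=True)
        update_query_adapter(max_triplets=2048, config=my_config)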
| """ | |
    config = config or RAGLiteConfig()
    config_no_query_adapter = RAGLiteConfig(
        **{**config.__dict__, "vector_search_query_adapter": False}
    )
    engine = create_database_engine(config)
    with Session(engine) as session:
        # Get a sample chunk embedding to determine the embedding dimension.
        chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
        if chunk_embedding is None:
            error_message = "First run `insert_document()` to insert documents."
            raise ValueError(error_message)
        # Get evals from the database.
        evals = session.exec(
            select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
        ).all()
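        # Note: this check requires the potential triplet budget to reach the embedding
        # dimension d, presumably because fewer than d triplets would leave the d × d
        # transform A badly underdetermined.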
        if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
            error_message = "First run `insert_evals()` to generate sufficient evals."
            raise ValueError(error_message)
        # Loop over the evals to generate (q, p, n) triplets.
        Q = np.zeros((0, len(chunk_embedding.embedding)))  # noqa: N806
        P = np.zeros_like(Q)  # noqa: N806
        N = np.zeros_like(Q)  # noqa: N806
        for eval_ in tqdm(
            evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
        ):
            # Embed the question.
            question_embedding = embed_sentences([eval_.question], config=config)
            # Retrieve chunks that would be used to answer the question.
            chunk_ids, _ = vector_search(
                question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
            )
            retrieved_chunks = session.exec(
                select(Chunk).where(col(Chunk.id).in_(chunk_ids))
            ).all()
            # Restore the search ranking, which the unordered IN query does not preserve.
            retrieved_chunks = sorted(retrieved_chunks, key=lambda chunk: chunk_ids.index(chunk.id))
            # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
            num_triplets = 0
            for i, retrieved_chunk in enumerate(retrieved_chunks):
                # Select irrelevant chunks.
                if retrieved_chunk.id not in eval_.chunk_ids:
                    # Look up all positive chunks (each represented by the mean of its
                    # multi-vector embedding) that are ranked lower than this negative one
                    # (represented by the embedding in its multi-vector embedding that best
                    # matches the query).
                    p_mean = [
                        np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
                        for chunk in retrieved_chunks[i + 1 :]
                        if chunk is not None and chunk.id in eval_.chunk_ids
                    ]
                    n_top = retrieved_chunk.embedding_matrix[
                        np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
                        np.newaxis,
                        :,
                    ]
                    # Filter out any (q, p, n) triplets for which the mean positive embedding
                    # ranks higher than the top negative one.
                    p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
                    if not p_mean:
                        continue
                    # Stack the (q, p, n) triplets.
                    p = np.vstack(p_mean)
                    n = np.repeat(n_top, p.shape[0], axis=0)
                    q = np.repeat(question_embedding, p.shape[0], axis=0)
                    num_triplets += p.shape[0]
                    # Append the (query, positive, negative) tuples to the Q, P, N matrices.
                    Q = np.vstack([Q, q])  # noqa: N806
                    P = np.vstack([P, p])  # noqa: N806
                    N = np.vstack([N, n])  # noqa: N806
                    # Check if we have sufficient triplets for this eval.
                    if num_triplets >= max_triplets_per_eval:
                        break
            # Check if we have sufficient triplets to compute the query adapter.
            if Q.shape[0] > max_triplets:
                Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :]  # noqa: N806
                break
        # Normalise the rows of Q, P, N.
        Q /= np.linalg.norm(Q, axis=1, keepdims=True)  # noqa: N806
        P /= np.linalg.norm(P, axis=1, keepdims=True)  # noqa: N806
        N /= np.linalg.norm(N, axis=1, keepdims=True)  # noqa: N806
        # Compute the optimal weighted query adapter A*.
        # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
        gap_before = np.sum((P - N) * Q, axis=1)
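        # For unit vectors, the fully adapted gap (p - n)' (p - n) / ||p - n|| equals
        # ||p - n||, and ||p - n||² == 2 (1 - p'n), hence the expression below.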
        gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
        gap_target = 0.05 * gap_after
        α = (gap_before - gap_target) / (gap_before - gap_after)  # noqa: PLC2401
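        # Assemble M' == (P - N)' diag(α) Q directly, i.e. the transpose of the docstring's
        # M := Q' (P - N), so the closed-form solutions below need no final transpose. Then
        # blend it with 𝕀 as α (M / s) + (1 - α) 𝕀, where s normalises M to ||𝕀||_F == √d
        # and α is averaged over all triplets.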
        MT = (α[:, np.newaxis] * (P - N)).T @ Q  # noqa: N806
        s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
        MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1])  # noqa: N806
        if config.vector_search_index_metric == "dot":
            # Use the relaxed Procrustes solution.
            A_star = MT / np.linalg.norm(MT, ord="fro")  # noqa: N806
        elif config.vector_search_index_metric == "cosine":
            # Use the orthogonal Procrustes solution.
            U, _, VT = np.linalg.svd(MT, full_matrices=False)  # noqa: N806
            A_star = U @ VT  # noqa: N806
        else:
            error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}"
            raise ValueError(error_message)
        # Store the optimal query adapter in the database.
        index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
        index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
        session.add(index_metadata)
        session.commit()
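

# Note: at query time, the stored adapter is meant to be applied as q ↦ A* @ q before the
# nearest neighbour search (see the docstring above). A minimal sketch of that step, assuming
# the adapter has been read back from the "default" IndexMetadata row and that the query
# embedding is a (1, d) row vector as in this module:
#
#     A_star = index_metadata.metadata_["query_adapter"]
#     adapted_query = question_embedding @ A_star.T  # Row-vector form of A* @ q.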