| """String embedder.""" | |
| from functools import partial | |
| from typing import Literal | |
| import numpy as np | |
| from litellm import embedding | |
| from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama | |
| from tqdm.auto import tqdm, trange | |
| from raglite._config import RAGLiteConfig | |
| from raglite._litellm import LlamaCppPythonLLM | |
| from raglite._typing import FloatMatrix, IntVector | |


def _embed_sentences_with_late_chunking(  # noqa: PLR0915
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with late chunking."""

    def _count_tokens(
        sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
    ) -> list[int]:
        # Join the sentences with the sentinel character and tokenise the result.
        sentences_tokens = np.asarray(
            embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
        )
        # Map all sentinel token variants to the first one.
        for sentinel_token in sentinel_tokens[1:]:
            sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
        # Count the tokens between consecutive sentinel tokens to recover the per-sentence counts.
        sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
        num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
        assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
        num_tokens_list: list[int] = num_tokens.tolist()
        return num_tokens_list
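    # Example of the sentinel trick (with made-up token ids): joining ["Hello world.", "Hi."]
    # with "⊕" gives "Hello world.⊕Hi.", which might tokenise to [15, 27, 9, 4096, 31, 9] with
    # 4096 the sentinel token. np.diff over the sentinel indices [3], with prepend=0 and
    # append=6, recovers the counts [3, 3], where a sentence's count includes the sentinel
    # token that precedes it.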

    def _create_segment(
        content_start_index: int,
        max_tokens_preamble: int,
        max_tokens_content: int,
        num_tokens: IntVector,
    ) -> tuple[int, int]:
        # Compute the segment sentence start index so that the segment preamble has no more than
        # max_tokens_preamble tokens between [segment_start_index, content_start_index).
        cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
        offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
        segment_start_index = content_start_index - int(offset_preamble)
        # Allow a larger segment content if we didn't use all of the allowed preamble tokens.
        max_tokens_content = max_tokens_content + (
            max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index])
        )
        # Compute the segment sentence end index so that the segment content has no more than
        # max_tokens_content tokens between [content_start_index, segment_end_index).
        cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
        offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right")
        segment_end_index = content_start_index + int(offset_segment)
        return segment_start_index, segment_end_index
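    # Worked example (illustrative numbers): with num_tokens = [4, 6, 5, 7, 3],
    # content_start_index = 2, max_tokens_preamble = 8, and max_tokens_content = 10, the
    # backwards cumsum over the preamble candidates [6, 4] is [6, 10], so only one sentence
    # (6 tokens) fits and segment_start_index = 1. The 2 unused preamble tokens raise
    # max_tokens_content to 12, and the forwards cumsum [5, 12, 15] admits two content
    # sentences, so segment_end_index = 4.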
    # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
    # support outputting token-level embeddings.
    config = config or RAGLiteConfig()
    assert config.embedder.startswith("llama-cpp-python")
    embedder = LlamaCppPythonLLM.llm(
        config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
    )
    n_ctx = embedder.n_ctx()
    n_batch = embedder.n_batch
    # Identify the tokens corresponding to a sentinel character.
    sentinel_char = "⊕"
    sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
    sentinel_tokens = [
        token
        for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
        if sentinel_char in embedder.detokenize([token]).decode()
    ]
    assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder"
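    # Note: sentinel_test surrounds the sentinel with varied contexts (mid-word, spaces, and a
    # newline) because BPE-style tokenizers may merge the sentinel with its neighbours
    # differently per context, and we need to catch every token variant that contains the
    # sentinel character.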
    # Compute the number of tokens per sentence. We use a method based on a sentinel token to
    # minimise the number of calls to embedder.tokenize, which incurs a significant overhead
    # (presumably to load the tokenizer) [1].
    # TODO: Make token counting faster and more robust once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1763
    num_tokens_list: list[int] = []
    sentence_batch, sentence_batch_len = [], 0
    for i, sentence in enumerate(sentences):
        sentence_batch.append(sentence)
        sentence_batch_len += len(sentence)
        if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
            num_tokens_list.extend(
                _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
            )
            sentence_batch, sentence_batch_len = [], 0
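    # The character-count threshold above is a cheap proxy for the token count: a token spans at
    # least one byte, so a batch of at most n_ctx // 2 characters of mostly-ASCII text should
    # stay comfortably within the embedder's context window.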
    num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
    # Compute the maximum number of tokens for each segment's preamble and content.
    # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you
    # try to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
    # TODO: Improve the context window size once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1762
    max_tokens = min(n_ctx, n_batch) - 16
    max_tokens_preamble = round(0.382 * max_tokens)  # Golden ratio.
    max_tokens_content = max_tokens - max_tokens_preamble
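    # Worked example (assuming n_ctx = n_batch = 512): max_tokens = 512 - 16 = 496, so the
    # preamble gets round(0.382 * 496) = 189 tokens and the content the remaining 307. The 16
    # tokens of headroom presumably absorb special tokens and tokenisation drift, and
    # 0.382 ≈ 1 - 1/φ splits the budget by the golden ratio in favour of the content.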
    # Compute a list of segments, each consisting of a preamble and content.
    segments = []
    content_start_index = 0
    while content_start_index < len(sentences):
        segment_start_index, segment_end_index = _create_segment(
            content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
        )
        segments.append((segment_start_index, content_start_index, segment_end_index))
        content_start_index = segment_end_index
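    # Because each segment's content starts where the previous segment's content ended, the
    # content spans tile the document exactly once, while the preambles may overlap with the
    # content of earlier segments to provide context.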
    # Embed the segments and apply late chunking.
    sentence_embeddings_list: list[FloatMatrix] = []
    if len(segments) > 1 or segments[0][2] > 128:  # noqa: PLR2004
        segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True)
    for segment in segments:
        # Get the token embeddings of the entire segment, including preamble and content.
        segment_start_index, content_start_index, segment_end_index = segment
        segment_sentences = sentences[segment_start_index:segment_end_index]
        segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
        # Split the segment's token embeddings into embedding matrices per sentence. Because the
        # joined segment may tokenise to a slightly different number of tokens than the sum of
        # the sentinel-based per-sentence counts, we allocate token embeddings to sentences in
        # proportion to those counts rather than using them verbatim.
        segment_tokens = num_tokens[segment_start_index:segment_end_index]
        sentence_size = np.round(
            len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
        ).astype(np.intp)
        sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
        # Compute the segment's content sentence embeddings by averaging the token embeddings,
        # skipping the sentence matrices that belong to the preamble.
        content_sentence_embeddings = [
            np.mean(sentence_matrix, axis=0, keepdims=True)
            for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
        ]
        sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
    sentence_embeddings = np.vstack(sentence_embeddings_list)
    # Normalise the sentence embeddings to unit norm and cast to half precision.
    if config.embedder_normalize:
        sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
    sentence_embeddings = sentence_embeddings.astype(np.float16)
    return sentence_embeddings


def _embed_sentences_with_windowing(
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with windowing."""

    def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
        # Embed the batch of strings.
        if config.embedder.startswith("llama-cpp-python"):
            # LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
            # Additionally, we explicitly pool the token embeddings into sentence embeddings
            # ourselves because token embeddings are universally supported, while sequence
            # embeddings are only supported by some models.
            embedder = LlamaCppPythonLLM.llm(
                config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
            )
            embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
        else:
            # Use LiteLLM's API to embed the batch of strings.
            response = embedding(config.embedder, string_batch)
            embeddings = np.asarray([item["embedding"] for item in response["data"]])
        # Normalise the embeddings to unit norm and cast to half precision.
        if config.embedder_normalize:
            embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings.astype(np.float16)
        return embeddings

    # Window the sentences with a lookback of `config.embedder_sentence_window_size - 1` sentences.
    config = config or RAGLiteConfig()
    sentence_windows = [
        "".join(sentences[max(0, i - (config.embedder_sentence_window_size - 1)) : i + 1])
        for i in range(len(sentences))
    ]
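    # Example (assuming embedder_sentence_window_size = 3): for sentences ["A. ", "B. ", "C. ",
    # "D. "] the windows are ["A. ", "A. B. ", "A. B. C. ", "B. C. D. "], i.e. each sentence is
    # embedded together with up to two preceding sentences.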
    # Embed the sentence windows in batches.
    batch_size = 64
    batch_range = (
        partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True)
        if len(sentence_windows) > batch_size
        else range
    )
    batch_embeddings = [
        _embed_string_batch(sentence_windows[i : i + batch_size], config=config)
        for i in batch_range(0, len(sentence_windows), batch_size)  # type: ignore[operator]
    ]
    sentence_embeddings = np.vstack(batch_embeddings)
    return sentence_embeddings


def sentence_embedding_type(
    *,
    config: RAGLiteConfig | None = None,
) -> Literal["late_chunking", "windowing"]:
    """Return the type of sentence embeddings."""
    config = config or RAGLiteConfig()
    return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "windowing"


def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
    """Embed the sentences of a document as a NumPy matrix with one row per sentence."""
    config = config or RAGLiteConfig()
    if sentence_embedding_type(config=config) == "late_chunking":
        sentence_embeddings = _embed_sentences_with_late_chunking(sentences, config=config)
    else:
        sentence_embeddings = _embed_sentences_with_windowing(sentences, config=config)
    return sentence_embeddings
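

# A minimal usage sketch (the model identifier below is hypothetical; see RAGLiteConfig for the
# expected embedder format):
#
#     from raglite._config import RAGLiteConfig
#     from raglite._embed import embed_sentences
#
#     config = RAGLiteConfig(embedder="llama-cpp-python/<embedding-model>")  # Hypothetical id.
#     embeddings = embed_sentences(["First sentence.", "Second sentence."], config=config)
#     assert embeddings.shape[0] == 2  # One row per sentence.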