| """Add support for llama-cpp-python models to LiteLLM.""" | |
| import asyncio | |
| import logging | |
| import warnings | |
| from collections.abc import AsyncIterator, Callable, Iterator | |
| from functools import cache | |
| from typing import Any, ClassVar, cast | |
| import httpx | |
| import litellm | |
| from litellm import ( # type: ignore[attr-defined] | |
| CustomLLM, | |
| GenericStreamingChunk, | |
| ModelResponse, | |
| convert_to_model_response_object, | |
| ) | |
| from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler | |
| from llama_cpp import ( # type: ignore[attr-defined] | |
| ChatCompletionRequestMessage, | |
| CreateChatCompletionResponse, | |
| CreateChatCompletionStreamResponse, | |
| Llama, | |
| LlamaRAMCache, | |
| ) | |
| # Reduce the logging level for LiteLLM and flashrank. | |
| logging.getLogger("litellm").setLevel(logging.WARNING) | |
| logging.getLogger("flashrank").setLevel(logging.WARNING) | |

class LlamaCppPythonLLM(CustomLLM):
    """A llama-cpp-python provider for LiteLLM.

    This provider enables using llama-cpp-python models with LiteLLM. The LiteLLM model
    specification is "llama-cpp-python/<hugging_face_repo_id>/<filename>@<n_ctx>", where n_ctx
    is an optional parameter that specifies the context size of the model. If n_ctx is not
    provided or is set to 0, the model's default context size is used.

    Example usage:

    ```python
    from litellm import completion

    response = completion(
        model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
        messages=[{"role": "user", "content": "Hello world!"}],
        # stream=True
    )
    ```
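
    Streaming works the same way: pass `stream=True` and iterate over the returned chunks.
    A minimal sketch (the chunks follow LiteLLM's standard OpenAI-compatible streaming format):

    ```python
    from litellm import completion

    stream = completion(
        model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")
    ```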
| """ | |

    # Create a lock to prevent concurrent access to llama-cpp-python models.
    streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included:
    # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers.
    # [1] https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion
    # [2] https://docs.litellm.ai/docs/completion/input
    supported_openai_params: ClassVar[list[str]] = [
        "functions",  # Deprecated
        "function_call",  # Deprecated
        "tools",
        "tool_choice",
        "temperature",
        "top_p",
        "top_k",
        "min_p",
        "typical_p",
        "stop",
        "seed",
        "response_format",
        "max_tokens",
        "presence_penalty",
        "frequency_penalty",
        "repeat_penalty",
        "tfs_z",
        "mirostat_mode",
        "mirostat_tau",
        "mirostat_eta",
        "logit_bias",
    ]

    @staticmethod
    @cache
    def llm(model: str, **kwargs: Any) -> Llama:
        """Load and cache a llama-cpp-python model given a LiteLLM model string."""
        # Drop the llama-cpp-python prefix from the model.
        repo_id_filename = model.replace("llama-cpp-python/", "")
        # Convert the LiteLLM model string to repo_id, filename, and n_ctx.
        repo_id, filename = repo_id_filename.rsplit("/", maxsplit=1)
        n_ctx = 0
        if len(filename_n_ctx := filename.rsplit("@", maxsplit=1)) == 2:  # noqa: PLR2004
            filename, n_ctx_str = filename_n_ctx
            n_ctx = int(n_ctx_str)
        # Load the LLM.
        with warnings.catch_warnings():  # Filter huggingface_hub warning about HF_TOKEN.
            warnings.filterwarnings("ignore", category=UserWarning)
            llm = Llama.from_pretrained(
                repo_id=repo_id,
                filename=filename,
                n_ctx=n_ctx,
                n_gpu_layers=-1,
                verbose=False,
                **kwargs,
            )
        # Enable caching.
        llm.set_cache(LlamaRAMCache())
        # Register the model info with LiteLLM.
        litellm.register_model(  # type: ignore[attr-defined]
            {
                model: {
                    "max_tokens": llm.n_ctx(),
                    "max_input_tokens": llm.n_ctx(),
                    "max_output_tokens": None,
                    "input_cost_per_token": 0.0,
                    "output_cost_per_token": 0.0,
                    "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
                    "litellm_provider": "llama-cpp-python",
                    "mode": "embedding" if kwargs.get("embedding") else "completion",
                    "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
                    "supports_function_calling": True,
                    "supports_parallel_function_calling": True,
                    "supports_vision": False,
                }
            }
        )
        return llm

    def completion(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> ModelResponse:
        """Generate a non-streaming chat completion with a llama-cpp-python model."""
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        response = cast(
            CreateChatCompletionResponse,
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
        )
        litellm_model_response: ModelResponse = convert_to_model_response_object(
            response_object=response,
            model_response_object=model_response,
            response_type="completion",
            stream=False,
        )
        return litellm_model_response

    def streaming(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> Iterator[GenericStreamingChunk]:
        """Stream a chat completion with a llama-cpp-python model."""
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        stream = cast(
            Iterator[CreateChatCompletionStreamResponse],
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
        )
        # Convert each llama-cpp-python chunk to a LiteLLM GenericStreamingChunk.
        for chunk in stream:
            choices = chunk.get("choices", [])
            for choice in choices:
                text = choice.get("delta", {}).get("content", None)
                finish_reason = choice.get("finish_reason")
                litellm_generic_streaming_chunk = GenericStreamingChunk(
                    text=text,  # type: ignore[typeddict-item]
                    is_finished=bool(finish_reason),
                    finish_reason=finish_reason,  # type: ignore[typeddict-item]
                    usage=None,
                    index=choice.get("index"),  # type: ignore[typeddict-item]
                    provider_specific_fields={
                        "id": chunk.get("id"),
                        "model": chunk.get("model"),
                        "created": chunk.get("created"),
                        "object": chunk.get("object"),
                    },
                )
                yield litellm_generic_streaming_chunk

    async def astreaming(  # type: ignore[misc,override]  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,  # noqa: ASYNC109
        client: AsyncHTTPHandler | None = None,
    ) -> AsyncIterator[GenericStreamingChunk]:
        """Stream a chat completion asynchronously by wrapping the synchronous stream."""
        # Start a synchronous stream.
        stream = self.streaming(
            model,
            messages,
            api_base,
            custom_prompt_dict,
            model_response,
            print_verbose,
            encoding,
            api_key,
            logging_obj,
            optional_params,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            timeout,
        )
        await asyncio.sleep(0)  # Yield control to the event loop after initialising the context.
        # Wrap the synchronous stream in an asynchronous stream.
        async with LlamaCppPythonLLM.streaming_lock:
            for litellm_generic_streaming_chunk in stream:
                yield litellm_generic_streaming_chunk
                await asyncio.sleep(0)  # Yield control to the event loop after each token.

# Register the LlamaCppPythonLLM provider.
if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map):
    litellm.custom_provider_map.append(
        {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
    )
litellm.suppress_debug_info = True
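
# A minimal usage sketch of the provider registered above. The model string and max_tokens value
# are illustrative (any GGUF repo/filename on the Hugging Face Hub should work); running this
# downloads the model, so it requires network access and enough RAM/VRAM to load it.
if __name__ == "__main__":
    demo_response = litellm.completion(
        model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
        messages=[{"role": "user", "content": "Hello world!"}],
        max_tokens=32,
    )
    print(demo_response.choices[0].message.content)  # type: ignore[union-attr]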