#!/usr/bin/env python3
"""
Multi-LLM provider support with fallback logic.
Includes SambaNova (free starter credits for new accounts) as the primary fallback option.
"""
import os
import json
import logging
import asyncio
from typing import AsyncGenerator, List, Dict, Optional, Tuple

import httpx
from anthropic import AsyncAnthropic
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)
class SambaNovaProvider:
    """
    SambaNova Cloud provider - requires an API key for access.
    Get your API key at https://cloud.sambanova.ai/
    Includes $5-30 free credits for new accounts.
    """

    BASE_URL = "https://api.sambanova.ai/v1"

    # Available models (short alias -> full model name)
    MODELS = {
        "llama-3.3-70b": "Meta-Llama-3.3-70B-Instruct",  # Latest and best!
        "llama-3.1-405b": "Meta-Llama-3.1-405B-Instruct",
        "llama-3.1-70b": "Meta-Llama-3.1-70B-Instruct",
        "llama-3.1-8b": "Meta-Llama-3.1-8B-Instruct",
        "llama-3.2-11b": "Llama-3.2-11B-Vision-Instruct",
        "llama-3.2-3b": "Llama-3.2-3B-Instruct",
        "llama-3.2-1b": "Llama-3.2-1B-Instruct"
    }

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the SambaNova provider.
        An API key is REQUIRED - get yours at https://cloud.sambanova.ai/
        """
        self.api_key = api_key or os.getenv("SAMBANOVA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "SAMBANOVA_API_KEY is required for SambaNova API access.\n"
                "Get your API key at: https://cloud.sambanova.ai/\n"
                "Then set it in your .env file: SAMBANOVA_API_KEY=your_key_here"
            )
        self.client = httpx.AsyncClient(timeout=60.0)
    async def stream(
        self,
        messages: List[Dict],
        system: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        model: str = "llama-3.1-70b",
        max_tokens: int = 4096,
        temperature: float = 0.7
    ) -> AsyncGenerator[Tuple[str, List[Dict]], None]:
        """
        Stream responses from the SambaNova API.
        Yields (accumulated_text, tool_calls) tuples, mirroring the
        Anthropic streaming interface used elsewhere in this module.
        """
        # Resolve the short alias to the full model name
        full_model = self.MODELS.get(model, self.MODELS["llama-3.1-70b"])

        # Convert messages to OpenAI format (SambaNova exposes an
        # OpenAI-compatible API)
        formatted_messages = []

        # Add system message if provided
        if system:
            formatted_messages.append({
                "role": "system",
                "content": system
            })

        # Convert Anthropic message format to OpenAI format. Anthropic
        # content may be a list of blocks; flatten it to plain text so
        # the OpenAI-style endpoint accepts it.
        for msg in messages:
            content = msg.get("content", "")
            if isinstance(content, list):
                text_parts = []
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                content = "\n".join(text_parts)
            if msg["role"] in ("user", "assistant"):
                formatted_messages.append({
                    "role": msg["role"],
                    "content": content
                })
        # Prepare request payload
        payload = {
            "model": full_model,
            "messages": formatted_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
            "stream_options": {"include_usage": True}
        }

        # Add tools if provided (only for models that support tool calling)
        if tools and model in ["llama-3.3-70b", "llama-3.1-405b", "llama-3.1-70b"]:
            # Convert Anthropic tool format to OpenAI format
            openai_tools = []
            for tool in tools:
                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool.get("input_schema", {})
                    }
                })
            payload["tools"] = openai_tools
            payload["tool_choice"] = "auto"

        # Headers - the API key is always required
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
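
        # Each SSE line from the OpenAI-compatible endpoint looks roughly like
        # this (illustrative values, not captured output):
        #   data: {"choices":[{"delta":{"content":"Hel"},"index":0}]}
        #   data: [DONE]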
        try:
            # Make streaming request
            accumulated_text = ""
            tool_calls: List[Dict] = []
            # Partial tool-call fragments keyed by stream index: OpenAI-style
            # deltas split a call's id/name/arguments across several chunks
            raw_tool_calls: Dict[int, Dict] = {}

            async with self.client.stream(
                "POST",
                f"{self.BASE_URL}/chat/completions",
                json=payload,
                headers=headers
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data = line[6:]  # Remove "data: " prefix
                        if data == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data)

                            # Usage-only chunks arrive at the end of the
                            # stream; log and skip them
                            if "usage" in chunk and not chunk.get("choices"):
                                logger.debug(f"Received usage chunk: {chunk.get('usage', {})}")
                                continue

                            # Extract content from the chunk
                            if chunk.get("choices"):
                                choice = chunk["choices"][0]
                                delta = choice.get("delta", {})

                                # Handle text content
                                if delta.get("content"):
                                    accumulated_text += delta["content"]
                                    yield (accumulated_text, tool_calls)
                                # Handle tool calls (if supported). Arguments
                                # arrive as JSON fragments spread across deltas,
                                # so accumulate by index and parse at the end.
                                for tc in delta.get("tool_calls") or []:
                                    idx = tc.get("index", 0)
                                    frag = raw_tool_calls.setdefault(
                                        idx, {"id": "", "name": "", "arguments": ""}
                                    )
                                    frag["id"] = tc.get("id") or frag["id"]
                                    fn = tc.get("function", {})
                                    frag["name"] = fn.get("name") or frag["name"]
                                    frag["arguments"] += fn.get("arguments") or ""
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse SSE data: {data}")
                            continue

            # Convert accumulated fragments to Anthropic-style tool calls
            for idx in sorted(raw_tool_calls):
                frag = raw_tool_calls[idx]
                try:
                    arguments = json.loads(frag["arguments"] or "{}")
                except json.JSONDecodeError:
                    logger.warning(f"Unparseable tool arguments: {frag['arguments']}")
                    arguments = {}
                tool_calls.append({
                    "id": frag["id"] or f"tool_{idx}",
                    "name": frag["name"],
                    "input": arguments
                })

            # Final yield with complete results
            yield (accumulated_text, tool_calls)
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 410:
                logger.error("SambaNova API endpoint has been discontinued (410 Gone)")
                raise RuntimeError(
                    "The SambaNova API endpoint no longer exists. "
                    "Check https://cloud.sambanova.ai/ for the current base URL."
                )
            elif e.response.status_code == 401:
                logger.error("SambaNova API authentication failed")
                raise RuntimeError(
                    "SambaNova authentication failed. Please check your API key."
                )
            else:
                logger.error(f"SambaNova API error: {e}")
                raise
        except httpx.HTTPError as e:
            logger.error(f"SambaNova API error: {e}")
            raise

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()
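
# Illustrative usage sketch (not part of the provider API): streams one
# prompt through SambaNovaProvider and prints the growing response.
# Assumes SAMBANOVA_API_KEY is set; the prompt and system text are placeholders.
async def _demo_sambanova() -> None:
    provider = SambaNovaProvider()
    try:
        async for text, _tool_calls in provider.stream(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            system="You are a concise assistant.",
            model="llama-3.3-70b",
        ):
            print(f"\r{text}", end="")  # Overwrite the line as text accumulates
        print()
    finally:
        await provider.close()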
class LLMRouter:
    """
    Routes LLM requests to appropriate providers with fallback logic.
    """

    def __init__(self):
        self.providers = {}
        self._setup_providers()

    def _setup_providers(self):
        """Initialize available providers."""
        # Primary: Anthropic (if an API key is available)
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if anthropic_key:
            self.providers["anthropic"] = AsyncAnthropic(api_key=anthropic_key)
            logger.info("Anthropic provider initialized")

        # Fallback: SambaNova (requires SAMBANOVA_API_KEY; new accounts get
        # free credits). Skip it rather than crash when the key is missing,
        # so the router can still be constructed at import time.
        if os.getenv("SAMBANOVA_API_KEY"):
            self.providers["sambanova"] = SambaNovaProvider()
            logger.info("SambaNova provider initialized")
        else:
            logger.warning("SAMBANOVA_API_KEY not set; SambaNova fallback disabled")
    async def stream_with_fallback(
        self,
        messages: List[Dict],
        tools: List[Dict],
        system_prompt: str,
        model: Optional[str] = None,
        max_tokens: int = 4096,
        provider_preference: str = "auto"
    ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
        """
        Stream from an LLM with automatic fallback.
        Yields (text, tool_calls, provider_used) tuples.
        """
        # Determine provider order based on preference
        if provider_preference == "cost_optimize":
            # Prefer the cheaper SambaNova models first
            provider_order = ["sambanova", "anthropic"]
        elif provider_preference == "quality_first":
            # Prefer Anthropic first
            provider_order = ["anthropic", "sambanova"]
        else:  # auto
            # Use Anthropic if available, fall back to SambaNova
            provider_order = ["anthropic", "sambanova"] if "anthropic" in self.providers else ["sambanova"]

        last_error = None
        for provider_name in provider_order:
            if provider_name not in self.providers:
                continue
            try:
                logger.info(f"Attempting to use {provider_name} provider...")

                if provider_name == "anthropic":
                    # Use the Anthropic streaming client
                    provider = self.providers["anthropic"]

                    accumulated_text = ""
                    tool_calls = []

                    async with provider.messages.stream(
                        model=model or "claude-sonnet-4-5-20250929",
                        max_tokens=max_tokens,
                        messages=messages,
                        system=system_prompt,
                        tools=tools
                    ) as stream:
                        async for event in stream:
                            if event.type == "content_block_start":
                                if event.content_block.type == "tool_use":
                                    tool_calls.append({
                                        "id": event.content_block.id,
                                        "name": event.content_block.name,
                                        "input": {}
                                    })
                            elif event.type == "content_block_delta":
                                if event.delta.type == "text_delta":
                                    accumulated_text += event.delta.text
                                    yield (accumulated_text, tool_calls, "Anthropic Claude")

                        # Rebuild tool calls from the final message, which
                        # carries the fully parsed tool inputs
                        final_message = await stream.get_final_message()
                        tool_calls.clear()
                        for block in final_message.content:
                            if block.type == "tool_use":
                                tool_calls.append({
                                    "id": block.id,
                                    "name": block.name,
                                    "input": block.input
                                })

                    yield (accumulated_text, tool_calls, "Anthropic Claude")
                    return  # Success!
                elif provider_name == "sambanova":
                    # Use SambaNova streaming
                    provider = self.providers["sambanova"]

                    # Determine which Llama model to use
                    if max_tokens > 8192:
                        samba_model = "llama-3.1-405b"  # Largest model for complex tasks
                    else:
                        # Default to Llama 3.3 70B - newest and best for most tasks
                        samba_model = "llama-3.3-70b"

                    async for text, tool_calls in provider.stream(
                        messages=messages,
                        system=system_prompt,
                        tools=tools,
                        model=samba_model,
                        max_tokens=max_tokens
                    ):
                        yield (text, tool_calls, f"SambaNova {samba_model}")

                    return  # Success!
            except Exception as e:
                logger.warning(f"Provider {provider_name} failed: {e}")
                last_error = e
                continue

        # All providers failed (or none were configured)
        error_msg = f"All LLM providers failed. Last error: {last_error}"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    async def cleanup(self):
        """Clean up provider resources."""
        if "sambanova" in self.providers:
            await self.providers["sambanova"].close()


# Global router instance
llm_router = LLMRouter()
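
# Minimal smoke test (sketch): exercises the router's fallback path from the
# command line. Assumes at least one of ANTHROPIC_API_KEY / SAMBANOVA_API_KEY
# is set; the prompt and system text below are placeholders.
if __name__ == "__main__":
    async def _main() -> None:
        try:
            async for text, _tool_calls, provider in llm_router.stream_with_fallback(
                messages=[{"role": "user", "content": "What is 2 + 2?"}],
                tools=[],
                system_prompt="Answer briefly.",
            ):
                print(f"\r[{provider}] {text}", end="")
            print()
        finally:
            await llm_router.cleanup()

    asyncio.run(_main())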