#!/usr/bin/env python3
"""
Multi-LLM provider support with fallback logic.
Includes SambaNova free tier as primary fallback option.
"""
import os
import json
import logging
import asyncio
from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple

import httpx
from anthropic import AsyncAnthropic
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)


class SambaNovaProvider:
    """
    SambaNova Cloud provider - requires an API key for access.
    Get your API key at https://cloud.sambanova.ai/
    Includes $5-30 free credits for new accounts.
    """

    BASE_URL = "https://api.sambanova.ai/v1"

    # Available models
    MODELS = {
        "llama-3.3-70b": "Meta-Llama-3.3-70B-Instruct",  # Latest and best!
        "llama-3.1-405b": "Meta-Llama-3.1-405B-Instruct",
        "llama-3.1-70b": "Meta-Llama-3.1-70B-Instruct",
        "llama-3.1-8b": "Meta-Llama-3.1-8B-Instruct",
        "llama-3.2-11b": "Llama-3.2-11B-Vision-Instruct",
        "llama-3.2-3b": "Llama-3.2-3B-Instruct",
        "llama-3.2-1b": "Llama-3.2-1B-Instruct",
    }

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the SambaNova provider.

        An API key is REQUIRED - get yours at https://cloud.sambanova.ai/
        """
        self.api_key = api_key or os.getenv("SAMBANOVA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "SAMBANOVA_API_KEY is required for SambaNova API access.\n"
                "Get your API key at: https://cloud.sambanova.ai/\n"
                "Then set it in your .env file: SAMBANOVA_API_KEY=your_key_here"
            )
        self.client = httpx.AsyncClient(timeout=60.0)

    async def stream(
        self,
        messages: List[Dict],
        system: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        model: str = "llama-3.1-70b",
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> AsyncGenerator[Tuple[str, List[Dict]], None]:
        """
        Stream responses from the SambaNova API.

        Yields (accumulated_text, tool_calls) tuples, mirroring the shape of the
        Anthropic streaming interface used elsewhere in this module.
        """
        # Select the full model name
        full_model = self.MODELS.get(model, self.MODELS["llama-3.1-70b"])

        # Convert messages to OpenAI format (SambaNova exposes an OpenAI-compatible API)
        formatted_messages = []

        # Add system message if provided
        if system:
            formatted_messages.append({
                "role": "system",
                "content": system,
            })

        # Convert Anthropic message format to OpenAI format
        for msg in messages:
            if msg["role"] == "user":
                formatted_messages.append({
                    "role": "user",
                    "content": msg.get("content", ""),
                })
            elif msg["role"] == "assistant":
                # Handle assistant messages with potential tool calls
                content = msg.get("content", "")
                if isinstance(content, list):
                    # Extract text from content blocks
                    text_parts = []
                    for block in content:
                        if block.get("type") == "text":
                            text_parts.append(block.get("text", ""))
                    content = "\n".join(text_parts)
                formatted_messages.append({
                    "role": "assistant",
                    "content": content,
                })

        # Prepare the request payload
        payload = {
            "model": full_model,
            "messages": formatted_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
            "stream_options": {"include_usage": True},
        }

        # Add tools if provided (only for models that support them)
        if tools and model in ["llama-3.3-70b", "llama-3.1-405b", "llama-3.1-70b"]:
            # Convert Anthropic tool format to OpenAI format
            openai_tools = []
            for tool in tools:
                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool.get("input_schema", {}),
                    },
                })
            payload["tools"] = openai_tools
            payload["tool_choice"] = "auto"

        # Headers - the API key is always required
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
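        # Illustrative note (assumption, not from SambaNova docs): the response is
        # expected to be an OpenAI-compatible SSE stream, where event lines look
        # roughly like:
        #   data: {"choices": [{"delta": {"content": "Hel"}, "index": 0}]}
        #   data: {"choices": [], "usage": {"prompt_tokens": 12, "completion_tokens": 2}}
        #   data: [DONE]
        # Exact fields may differ; the parser below relies only on "choices",
        # "delta", "content", "tool_calls", and "usage".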
f"{self.BASE_URL}/chat/completions", json=payload, headers=headers ) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): data = line[6:] # Remove "data: " prefix if data == "[DONE]": break try: chunk = json.loads(data) # Handle usage-only chunks (sent at end of stream) if "usage" in chunk and ("choices" not in chunk or len(chunk.get("choices", [])) == 0): # This is a usage statistics chunk, skip it logger.debug(f"Received usage chunk: {chunk.get('usage', {})}") continue # Extract content from chunk if "choices" in chunk and len(chunk["choices"]) > 0: choice = chunk["choices"][0] delta = choice.get("delta", {}) # Handle text content if "content" in delta and delta["content"]: accumulated_text += delta["content"] yield (accumulated_text, tool_calls) # Handle tool calls (if supported) if "tool_calls" in delta: for tc in delta["tool_calls"]: # Convert OpenAI tool call format to Anthropic format tool_calls.append({ "id": tc.get("id", f"tool_{len(tool_calls)}"), "name": tc.get("function", {}).get("name", ""), "input": json.loads(tc.get("function", {}).get("arguments", "{}")) }) except json.JSONDecodeError: logger.warning(f"Failed to parse SSE data: {data}") continue # Final yield with complete results yield (accumulated_text, tool_calls) except httpx.HTTPStatusError as e: if e.response.status_code == 410: logger.error("SambaNova API endpoint has been discontinued (410 GONE)") raise RuntimeError( "SambaNova API endpoint no longer exists. " "Make sure you have a valid API key set in SAMBANOVA_API_KEY." ) elif e.response.status_code == 401: logger.error("SambaNova API authentication failed") raise RuntimeError( "SambaNova authentication failed. Please check your API key." ) else: logger.error(f"SambaNova API error: {e}") raise except httpx.HTTPError as e: logger.error(f"SambaNova API error: {e}") raise async def close(self): """Close the HTTP client""" await self.client.aclose() class LLMRouter: """ Routes LLM requests to appropriate providers with fallback logic. """ def __init__(self): self.providers = {} self._setup_providers() def _setup_providers(self): """Initialize available providers""" # Primary: Anthropic (if API key available) anthropic_key = os.getenv("ANTHROPIC_API_KEY") if anthropic_key: self.providers["anthropic"] = AsyncAnthropic(api_key=anthropic_key) logger.info("Anthropic provider initialized") # Fallback: SambaNova (always available, free!) self.providers["sambanova"] = SambaNovaProvider() logger.info("SambaNova provider initialized (free tier)") async def stream_with_fallback( self, messages: List[Dict], tools: List[Dict], system_prompt: str, model: str = None, max_tokens: int = 4096, provider_preference: str = "auto" ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]: """ Stream from LLM with automatic fallback. Returns (text, tool_calls, provider_used) tuples. 
""" # Determine provider order based on preference if provider_preference == "cost_optimize": # Prefer free SambaNova first provider_order = ["sambanova", "anthropic"] elif provider_preference == "quality_first": # Prefer Anthropic first provider_order = ["anthropic", "sambanova"] else: # auto # Use Anthropic if available, fall back to SambaNova provider_order = ["anthropic", "sambanova"] if "anthropic" in self.providers else ["sambanova"] last_error = None for provider_name in provider_order: if provider_name not in self.providers: continue try: logger.info(f"Attempting to use {provider_name} provider...") if provider_name == "anthropic": # Use existing Anthropic streaming provider = self.providers["anthropic"] # Stream from Anthropic accumulated_text = "" tool_calls = [] async with provider.messages.stream( model=model or "claude-sonnet-4-5-20250929", max_tokens=max_tokens, messages=messages, system=system_prompt, tools=tools ) as stream: async for event in stream: if event.type == "content_block_start": if event.content_block.type == "tool_use": tool_calls.append({ "id": event.content_block.id, "name": event.content_block.name, "input": {} }) elif event.type == "content_block_delta": if event.delta.type == "text_delta": accumulated_text += event.delta.text yield (accumulated_text, tool_calls, "Anthropic Claude") # Get final message final_message = await stream.get_final_message() # Rebuild tool calls from final message tool_calls.clear() for block in final_message.content: if block.type == "tool_use": tool_calls.append({ "id": block.id, "name": block.name, "input": block.input }) yield (accumulated_text, tool_calls, "Anthropic Claude") return # Success! elif provider_name == "sambanova": # Use SambaNova streaming provider = self.providers["sambanova"] # Determine which Llama model to use if max_tokens > 8192: samba_model = "llama-3.1-405b" # Largest model for complex tasks else: # Default to Llama 3.3 70B - newest and best for most tasks samba_model = "llama-3.3-70b" async for text, tool_calls in provider.stream( messages=messages, system=system_prompt, tools=tools, model=samba_model, max_tokens=max_tokens ): yield (text, tool_calls, f"SambaNova {samba_model}") return # Success! except Exception as e: logger.warning(f"Provider {provider_name} failed: {e}") last_error = e continue # All providers failed error_msg = f"All LLM providers failed. Last error: {last_error}" logger.error(error_msg) raise Exception(error_msg) async def cleanup(self): """Clean up provider resources""" if "sambanova" in self.providers: await self.providers["sambanova"].close() # Global router instance llm_router = LLMRouter()