#!/usr/bin/env python3
"""
Multi-LLM provider support with fallback logic.
Includes SambaNova Cloud (free credits for new accounts) as the primary fallback option.
"""
import os
import json
import logging
import httpx
import asyncio
from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
from anthropic import AsyncAnthropic
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)


class SambaNovaProvider:
    """
    SambaNova Cloud provider - requires API key for access.
    Get your API key at https://cloud.sambanova.ai/
    Includes $5-30 free credits for new accounts.
    """

    BASE_URL = "https://api.sambanova.ai/v1"

    # Available models
    MODELS = {
        "llama-3.3-70b": "Meta-Llama-3.3-70B-Instruct",  # Latest and best!
        "llama-3.1-405b": "Meta-Llama-3.1-405B-Instruct",
        "llama-3.1-70b": "Meta-Llama-3.1-70B-Instruct",
        "llama-3.1-8b": "Meta-Llama-3.1-8B-Instruct",
        "llama-3.2-11b": "Llama-3.2-11B-Vision-Instruct",
        "llama-3.2-3b": "Llama-3.2-3B-Instruct",
        "llama-3.2-1b": "Llama-3.2-1B-Instruct"
    }
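
    # For example, the short alias "llama-3.3-70b" resolves to the full model id
    # "Meta-Llama-3.3-70B-Instruct" that stream() sends in the request payload.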

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize SambaNova provider.
        API key is REQUIRED - get yours at https://cloud.sambanova.ai/
        """
        self.api_key = api_key or os.getenv("SAMBANOVA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "SAMBANOVA_API_KEY is required for SambaNova API access.\n"
                "Get your API key at: https://cloud.sambanova.ai/\n"
                "Then set it in your .env file: SAMBANOVA_API_KEY=your_key_here"
            )
        self.client = httpx.AsyncClient(timeout=60.0)
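
    # Construction sketch (illustrative; the key value is a placeholder):
    #
    #   provider = SambaNovaProvider()                          # reads SAMBANOVA_API_KEY from .env
    #   provider = SambaNovaProvider(api_key="your_key_here")   # or pass the key explicitly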

    async def stream(
        self,
        messages: List[Dict],
        system: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        model: str = "llama-3.1-70b",
        max_tokens: int = 4096,
        temperature: float = 0.7
    ) -> AsyncGenerator[Tuple[str, List[Dict]], None]:
        """
        Stream responses from SambaNova API.
        Compatible interface with Anthropic streaming.
        """
        # Select the full model name
        full_model = self.MODELS.get(model, self.MODELS["llama-3.1-70b"])

        # Convert messages to OpenAI format (SambaNova uses OpenAI-compatible API)
        formatted_messages = []

        # Add system message if provided
        if system:
            formatted_messages.append({
                "role": "system",
                "content": system
            })

        # Convert Anthropic message format to OpenAI format
        for msg in messages:
            if msg["role"] == "user":
                formatted_messages.append({
                    "role": "user",
                    "content": msg.get("content", "")
                })
            elif msg["role"] == "assistant":
                # Handle assistant messages with potential tool calls
                content = msg.get("content", "")
                if isinstance(content, list):
                    # Extract text from content blocks
                    text_parts = []
                    for block in content:
                        if block.get("type") == "text":
                            text_parts.append(block.get("text", ""))
                    content = "\n".join(text_parts)
                formatted_messages.append({
                    "role": "assistant",
                    "content": content
                })

        # Prepare request payload
        payload = {
            "model": full_model,
            "messages": formatted_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
            "stream_options": {"include_usage": True}
        }

        # Add tools if provided (for models that support it)
        if tools and model in ["llama-3.3-70b", "llama-3.1-405b", "llama-3.1-70b"]:
            # Convert Anthropic tool format to OpenAI format
            openai_tools = []
            for tool in tools:
                openai_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool.get("input_schema", {})
                    }
                })
            payload["tools"] = openai_tools
            payload["tool_choice"] = "auto"

        # Headers - API key is always required now
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        try:
            # Make streaming request
            accumulated_text = ""
            tool_calls = []
            # Tool-call argument JSON arrives in fragments; accumulate by index
            partial_tool_calls: Dict[int, Dict[str, str]] = {}

            async with self.client.stream(
                "POST",
                f"{self.BASE_URL}/chat/completions",
                json=payload,
                headers=headers
            ) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data = line[6:]  # Remove "data: " prefix
                        if data == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data)

                            # Handle usage-only chunks (sent at end of stream)
                            if "usage" in chunk and ("choices" not in chunk or len(chunk.get("choices", [])) == 0):
                                # This is a usage statistics chunk, skip it
                                logger.debug(f"Received usage chunk: {chunk.get('usage', {})}")
                                continue

                            # Extract content from chunk
                            if "choices" in chunk and len(chunk["choices"]) > 0:
                                choice = chunk["choices"][0]
                                delta = choice.get("delta", {})

                                # Handle text content
                                if "content" in delta and delta["content"]:
                                    accumulated_text += delta["content"]
                                    yield (accumulated_text, tool_calls)

                                # Handle tool calls (if supported)
                                if "tool_calls" in delta and delta["tool_calls"]:
                                    for tc in delta["tool_calls"]:
                                        idx = tc.get("index", 0)
                                        partial = partial_tool_calls.setdefault(
                                            idx, {"id": "", "name": "", "arguments": ""}
                                        )
                                        fn = tc.get("function", {}) or {}
                                        partial["id"] = tc.get("id") or partial["id"]
                                        partial["name"] = fn.get("name") or partial["name"]
                                        partial["arguments"] += fn.get("arguments") or ""
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse SSE data: {data}")
                            continue

            # Convert OpenAI tool call format to Anthropic format once the stream ends
            for idx in sorted(partial_tool_calls):
                partial = partial_tool_calls[idx]
                try:
                    arguments = json.loads(partial["arguments"]) if partial["arguments"] else {}
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse tool call arguments: {partial['arguments']}")
                    arguments = {}
                tool_calls.append({
                    "id": partial["id"] or f"tool_{idx}",
                    "name": partial["name"],
                    "input": arguments
                })

            # Final yield with complete results
            yield (accumulated_text, tool_calls)

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 410:
                logger.error("SambaNova API endpoint has been discontinued (410 GONE)")
                raise RuntimeError(
                    "SambaNova API endpoint no longer exists. "
                    "Make sure you have a valid API key set in SAMBANOVA_API_KEY."
                )
            elif e.response.status_code == 401:
                logger.error("SambaNova API authentication failed")
                raise RuntimeError(
                    "SambaNova authentication failed. Please check your API key."
                )
            else:
                logger.error(f"SambaNova API error: {e}")
                raise
        except httpx.HTTPError as e:
            logger.error(f"SambaNova API error: {e}")
            raise
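
    # Consumption sketch (illustrative; the prompt below is an assumption):
    # stream() yields cumulative (text, tool_calls) tuples, mirroring the
    # Anthropic-style streaming used by LLMRouter below.
    #
    #   provider = SambaNovaProvider()
    #   async for text, tool_calls in provider.stream(
    #       messages=[{"role": "user", "content": "Hello"}],
    #       system="You are a helpful assistant.",
    #       model="llama-3.3-70b",
    #   ):
    #       print(text)
    #   await provider.close()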

    async def close(self):
        """Close the HTTP client"""
        await self.client.aclose()


class LLMRouter:
    """
    Routes LLM requests to appropriate providers with fallback logic.
    """

    def __init__(self):
        self.providers = {}
        self._setup_providers()

    def _setup_providers(self):
        """Initialize available providers"""
        # Primary: Anthropic (if API key available)
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if anthropic_key:
            self.providers["anthropic"] = AsyncAnthropic(api_key=anthropic_key)
            logger.info("Anthropic provider initialized")

        # Fallback: SambaNova (requires SAMBANOVA_API_KEY)
        if os.getenv("SAMBANOVA_API_KEY"):
            self.providers["sambanova"] = SambaNovaProvider()
            logger.info("SambaNova provider initialized")
        else:
            logger.warning("SAMBANOVA_API_KEY not set; SambaNova fallback unavailable")

    async def stream_with_fallback(
        self,
        messages: List[Dict],
        tools: List[Dict],
        system_prompt: str,
        model: Optional[str] = None,
        max_tokens: int = 4096,
        provider_preference: str = "auto"
    ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
        """
        Stream from LLM with automatic fallback.
        Yields (text, tool_calls, provider_used) tuples.
        """
        # Determine provider order based on preference
        if provider_preference == "cost_optimize":
            # Prefer SambaNova first
            provider_order = ["sambanova", "anthropic"]
        elif provider_preference == "quality_first":
            # Prefer Anthropic first
            provider_order = ["anthropic", "sambanova"]
        else:  # auto
            # Use Anthropic if available, fall back to SambaNova
            provider_order = ["anthropic", "sambanova"] if "anthropic" in self.providers else ["sambanova"]

        last_error = None
        for provider_name in provider_order:
            if provider_name not in self.providers:
                continue
            try:
                logger.info(f"Attempting to use {provider_name} provider...")

                if provider_name == "anthropic":
                    # Use existing Anthropic streaming
                    provider = self.providers["anthropic"]

                    # Stream from Anthropic
                    accumulated_text = ""
                    tool_calls = []
                    async with provider.messages.stream(
                        model=model or "claude-sonnet-4-5-20250929",
                        max_tokens=max_tokens,
                        messages=messages,
                        system=system_prompt,
                        tools=tools
                    ) as stream:
                        async for event in stream:
                            if event.type == "content_block_start":
                                if event.content_block.type == "tool_use":
                                    tool_calls.append({
                                        "id": event.content_block.id,
                                        "name": event.content_block.name,
                                        "input": {}
                                    })
                            elif event.type == "content_block_delta":
                                if event.delta.type == "text_delta":
                                    accumulated_text += event.delta.text
                                    yield (accumulated_text, tool_calls, "Anthropic Claude")

                        # Get final message
                        final_message = await stream.get_final_message()

                    # Rebuild tool calls from final message
                    tool_calls.clear()
                    for block in final_message.content:
                        if block.type == "tool_use":
                            tool_calls.append({
                                "id": block.id,
                                "name": block.name,
                                "input": block.input
                            })

                    yield (accumulated_text, tool_calls, "Anthropic Claude")
                    return  # Success!

                elif provider_name == "sambanova":
                    # Use SambaNova streaming
                    provider = self.providers["sambanova"]

                    # Determine which Llama model to use
                    if max_tokens > 8192:
                        samba_model = "llama-3.1-405b"  # Largest model for complex tasks
                    else:
                        # Default to Llama 3.3 70B - newest and best for most tasks
                        samba_model = "llama-3.3-70b"

                    async for text, tool_calls in provider.stream(
                        messages=messages,
                        system=system_prompt,
                        tools=tools,
                        model=samba_model,
                        max_tokens=max_tokens
                    ):
                        yield (text, tool_calls, f"SambaNova {samba_model}")
                    return  # Success!

            except Exception as e:
                logger.warning(f"Provider {provider_name} failed: {e}")
                last_error = e
                continue

        # All providers failed
        error_msg = f"All LLM providers failed. Last error: {last_error}"
        logger.error(error_msg)
        raise Exception(error_msg)

    async def cleanup(self):
        """Clean up provider resources"""
        if "sambanova" in self.providers:
            await self.providers["sambanova"].close()


# Global router instance
llm_router = LLMRouter()
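

if __name__ == "__main__":
    # Minimal manual smoke test (an illustrative sketch, not part of the ALSARA
    # app wiring): it assumes ANTHROPIC_API_KEY and/or SAMBANOVA_API_KEY are set
    # and makes a real network call.
    async def _demo() -> None:
        text, tool_calls, provider = "", [], "none"
        try:
            async for text, tool_calls, provider in llm_router.stream_with_fallback(
                messages=[{"role": "user", "content": "Say hello in one short sentence."}],
                tools=[],
                system_prompt="You are a concise assistant.",
            ):
                pass  # keep only the final cumulative text
        finally:
            await llm_router.cleanup()
        print(f"[{provider}] {text}")

    asyncio.run(_demo())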