# Really-amin's picture
# Upload 577 files
# b190b45 verified
"""
HuggingFace Sentiment Provider - AI-powered text analysis
Provides:
- Sentiment analysis using transformer models
- Text summarization
- Named entity recognition
- Zero-shot classification
Uses HuggingFace Inference API for model inference.
API Documentation: https://huggingface.co/docs/api-inference/
"""
from __future__ import annotations
import os
from typing import Any, Dict, List, Optional
from .base import BaseProvider, create_success_response, create_error_response
class HFSentimentProvider(BaseProvider):
    """HuggingFace Inference API provider for AI-powered text analysis.

    Wraps the HF Inference router for four tasks: sentiment analysis,
    text summarization, named entity recognition and zero-shot
    classification. Every public method returns the project's
    standardized response dict produced by ``create_success_response`` /
    ``create_error_response``.
    """

    # API key resolved from the common HF env var names, first hit wins.
    # NOTE(review): read once at class-definition (import) time — env
    # changes afterwards are not picked up unless a key is passed to
    # __init__ explicitly.
    API_KEY = os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN") or ""

    # Default model checkpoint per task (stable, widely available models)
    MODELS = {
        "sentiment": "distilbert-base-uncased-finetuned-sst-2-english",
        "sentiment_financial": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
        "summarization": "sshleifer/distilbart-cnn-12-6",
        "ner": "dslim/bert-base-NER",
        "classification": "facebook/bart-large-mnli",
        "text_generation": "gpt2"
    }

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the provider.

        Args:
            api_key: Explicit HF token; falls back to the API_KEY
                resolved from the environment at import time.
        """
        super().__init__(
            name="huggingface",
            base_url="https://router.huggingface.co/hf-inference/models",
            api_key=api_key or self.API_KEY,
            timeout=15.0,  # HF inference can be slower than plain REST APIs
            cache_ttl=60.0  # Cache AI results for 60 seconds
        )

    def _get_default_headers(self) -> Dict[str, str]:
        """Return JSON request headers with the HuggingFace bearer token."""
        return {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

    @staticmethod
    def _preview(text: str) -> str:
        """Return *text* truncated to 100 chars with a '...' marker for display."""
        return text[:100] + "..." if len(text) > 100 else text

    def _model_error_response(self, data: Any) -> Optional[Dict[str, Any]]:
        """Map an HF error payload to a standardized error response.

        The Inference API reports problems as ``{"error": ...}`` dicts.
        A cold (still-loading) model is reported with an error message
        containing "loading" and should simply be retried by the caller.

        Args:
            data: Parsed response body from the Inference API.

        Returns:
            An error-response dict when *data* is an HF error payload,
            otherwise None (data is a normal result).
        """
        if isinstance(data, dict) and data.get("error"):
            error_msg = data.get("error", "Model error")
            if "loading" in error_msg.lower():
                return create_error_response(
                    self.name,
                    "Model is loading",
                    "Please retry in a few seconds"
                )
            return create_error_response(self.name, error_msg)
        return None

    async def analyze_sentiment(
        self,
        text: str,
        model: Optional[str] = None,
        use_financial_model: bool = False
    ) -> Dict[str, Any]:
        """
        Analyze sentiment of text using HuggingFace models.

        Args:
            text: Text to analyze
            model: Custom model to use (optional; overrides the defaults)
            use_financial_model: Use the financial-news model instead of
                the general-purpose sentiment model
        Returns:
            Standardized response with sentiment analysis
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters"
            )
        # Truncate text if too long (HF has input-size limits)
        text = text[:1000]
        # Model precedence: explicit override > financial > general default
        if model:
            model_id = model
        elif use_financial_model:
            model_id = self.MODELS["sentiment_financial"]
        else:
            model_id = self.MODELS["sentiment"]
        # The model id is the endpoint path relative to base_url
        response = await self.post(model_id, json_data={"inputs": text})
        if not response.get("success"):
            return response
        data = response.get("data", [])
        # Handle model-loading / error payloads uniformly
        error = self._model_error_response(data)
        if error:
            return error
        # Parse sentiment results into a normalized shape
        results = self._parse_sentiment_results(data, model_id)
        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "sentiment": results
            }
        )

    def _parse_sentiment_results(self, data: Any, model_id: str) -> Dict[str, Any]:
        """Parse sentiment results from the different HF model output formats.

        Args:
            data: Raw model output — typically ``[[{label, score}, ...]]``
                or ``[{label, score}, ...]``.
            model_id: Model that produced the output (kept for parity
                with callers; not used in parsing).

        Returns:
            Dict with the normalized best ``label``, its ``score`` and
            the full ``allScores`` list; a placeholder when unparseable.
        """
        if not data:
            return {"label": "unknown", "score": 0.0}
        # Handle nested list format [[{label, score}, ...]]
        if isinstance(data, list) and len(data) > 0:
            if isinstance(data[0], list):
                data = data[0]
            # Pick the highest-scoring label as the headline result
            best = max(data, key=lambda x: x.get("score", 0))
            label = best.get("label", "unknown").lower()
            score = best.get("score", 0.0)
            # Map model-specific label names to a common vocabulary
            # (e.g. roberta-style "label_0/1/2" and short "pos/neg/neu")
            label_map = {
                "label_0": "negative",
                "label_1": "neutral",
                "label_2": "positive",
                "negative": "negative",
                "neutral": "neutral",
                "positive": "positive",
                "pos": "positive",
                "neg": "negative",
                "neu": "neutral"
            }
            normalized_label = label_map.get(label, label)
            return {
                "label": normalized_label,
                "score": round(score, 4),
                "allScores": [
                    {"label": item.get("label"), "score": round(item.get("score", 0), 4)}
                    for item in data
                ]
            }
        return {"label": "unknown", "score": 0.0}

    async def summarize_text(
        self,
        text: str,
        max_length: int = 150,
        min_length: int = 30,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Summarize text using a HuggingFace summarization model.

        Args:
            text: Text to summarize (must be at least 50 characters)
            max_length: Maximum summary length (tokens, per HF semantics)
            min_length: Minimum summary length
            model: Custom model to use
        Returns:
            Standardized response with the summary and length metadata
        """
        if not text or len(text.strip()) < 50:
            return create_error_response(
                self.name,
                "Text too short",
                "Text must be at least 50 characters for summarization"
            )
        # Truncate very long text to stay within model input limits
        text = text[:3000]
        model_id = model or self.MODELS["summarization"]
        payload = {
            "inputs": text,
            "parameters": {
                "max_length": max_length,
                "min_length": min_length,
                "do_sample": False  # deterministic summaries
            }
        }
        response = await self.post(model_id, json_data=payload)
        if not response.get("success"):
            return response
        data = response.get("data", [])
        # Handle model-loading / error payloads uniformly
        error = self._model_error_response(data)
        if error:
            return error
        # Summary may arrive as [{"summary_text": ...}] or a bare dict
        summary = ""
        if isinstance(data, list) and len(data) > 0:
            summary = data[0].get("summary_text", "")
        elif isinstance(data, dict):
            summary = data.get("summary_text", "")
        return create_success_response(
            self.name,
            {
                "originalLength": len(text),
                "summaryLength": len(summary),
                "model": model_id,
                "summary": summary
            }
        )

    async def extract_entities(
        self,
        text: str,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Extract named entities from text.

        Args:
            text: Text to analyze (at least 3 characters)
            model: Custom NER model to use
        Returns:
            Standardized response with the entity list and count
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters"
            )
        text = text[:1000]
        model_id = model or self.MODELS["ner"]
        response = await self.post(model_id, json_data={"inputs": text})
        if not response.get("success"):
            return response
        data = response.get("data", [])
        # Handle model-loading / error payloads uniformly
        error = self._model_error_response(data)
        if error:
            return error
        # Normalize entities; aggregated outputs use "entity_group",
        # token-level outputs use "entity"
        entities = []
        if isinstance(data, list):
            for entity in data:
                entities.append({
                    "word": entity.get("word"),
                    "entity": entity.get("entity_group") or entity.get("entity"),
                    "score": round(entity.get("score", 0), 4),
                    "start": entity.get("start"),
                    "end": entity.get("end")
                })
        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "entities": entities,
                "count": len(entities)
            }
        )

    async def classify_text(
        self,
        text: str,
        candidate_labels: List[str],
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Zero-shot text classification.

        Args:
            text: Text to classify (at least 3 characters)
            candidate_labels: List of possible labels (at least 2; only
                the first 10 are sent)
            model: Custom classification model
        Returns:
            Standardized response with per-label scores and best label
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters"
            )
        if not candidate_labels or len(candidate_labels) < 2:
            return create_error_response(
                self.name,
                "Invalid labels",
                "At least 2 candidate labels required"
            )
        text = text[:500]
        model_id = model or self.MODELS["classification"]
        payload = {
            "inputs": text,
            "parameters": {
                "candidate_labels": candidate_labels[:10]  # Limit labels
            }
        }
        response = await self.post(model_id, json_data=payload)
        if not response.get("success"):
            return response
        data = response.get("data", {})
        # Handle model-loading / error payloads uniformly
        error = self._model_error_response(data)
        if error:
            return error
        # HF returns parallel, score-sorted "labels" and "scores" lists
        labels = data.get("labels", [])
        scores = data.get("scores", [])
        classifications = [
            {"label": label, "score": round(score, 4)}
            for label, score in zip(labels, scores)
        ]
        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "classifications": classifications,
                "bestLabel": labels[0] if labels else None,
                "bestScore": round(scores[0], 4) if scores else 0.0
            }
        )

    async def get_available_models(self) -> Dict[str, Any]:
        """Return the configured default model per task and the task names."""
        return create_success_response(
            self.name,
            {
                "models": self.MODELS,
                "tasks": list(self.MODELS.keys())
            }
        )