|
|
"""
|
|
|
HuggingFace Sentiment Provider - AI-powered text analysis
|
|
|
|
|
|
Provides:
|
|
|
- Sentiment analysis using transformer models
|
|
|
- Text summarization
|
|
|
- Named entity recognition
|
|
|
- Zero-shot classification
|
|
|
|
|
|
Uses HuggingFace Inference API for model inference.
|
|
|
API Documentation: https://huggingface.co/docs/api-inference/
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
import os
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
from .base import BaseProvider, create_success_response, create_error_response
|
|
|
|
|
|
|
|
|
class HFSentimentProvider(BaseProvider):
    """HuggingFace Inference API provider for AI-powered analysis.

    Exposes sentiment analysis, summarization, named entity recognition and
    zero-shot classification.  Each public coroutine returns one of:

    * the failed response from ``self.post`` unchanged (``success`` falsy),
    * a ``create_error_response(...)`` for invalid input or a model error
      reported inside the payload,
    * a ``create_success_response(...)`` carrying the parsed result.
    """

    # Read once, when the class body executes.  The first non-empty variable
    # wins; an empty string means that no token is configured.
    API_KEY = (
        os.getenv("HF_API_TOKEN")
        or os.getenv("HF_TOKEN")
        or os.getenv("HUGGINGFACE_TOKEN")
        or ""
    )

    # Default model id per task; a caller-supplied ``model`` overrides these.
    MODELS = {
        "sentiment": "distilbert-base-uncased-finetuned-sst-2-english",
        "sentiment_financial": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
        "summarization": "sshleifer/distilbart-cnn-12-6",
        "ner": "dslim/bert-base-NER",
        "classification": "facebook/bart-large-mnli",
        "text_generation": "gpt2",
    }

    # Lower-cased model label -> normalized sentiment label.
    # NOTE(review): the ``label_N`` entries assume a three-class model ordered
    # negative/neutral/positive -- confirm for any custom model that emits
    # generic ``LABEL_N`` names.
    _LABEL_ALIASES = {
        "label_0": "negative",
        "label_1": "neutral",
        "label_2": "positive",
        "negative": "negative",
        "neutral": "neutral",
        "positive": "positive",
        "pos": "positive",
        "neg": "negative",
        "neu": "neutral",
    }

    def __init__(self, api_key: Optional[str] = None):
        """Configure the provider.

        Args:
            api_key: HuggingFace token.  When omitted or empty, the
                class-level ``API_KEY`` (taken from the environment) is used.
        """
        super().__init__(
            name="huggingface",
            base_url="https://router.huggingface.co/hf-inference/models",
            api_key=api_key or self.API_KEY,
            timeout=15.0,
            cache_ttl=60.0,
        )

    def _get_default_headers(self) -> Dict[str, str]:
        """Get JSON headers, plus HuggingFace authorization when a token is set."""
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        # An empty token would produce the malformed header "Bearer ", so the
        # header is only added when there is something to send.
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_score(value: Any) -> float:
        """Return ``value`` as a float, or 0.0 when it is not a real number."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return 0.0
        return float(value)

    @staticmethod
    def _preview(text: str) -> str:
        """Return ``text`` cut to 100 characters plus "..." when it is longer."""
        return text[:100] + "..." if len(text) > 100 else text

    def _model_error_response(self, data: Any) -> Optional[Dict[str, Any]]:
        """Turn an ``{"error": ...}`` payload into an error response.

        Args:
            data: The ``data`` field of a successful ``self.post`` response.

        Returns:
            ``None`` when ``data`` is not a dict with a truthy ``error`` entry,
            otherwise an error response.  Messages containing "loading"
            (case-insensitive) are reported as a retryable "Model is loading".
        """
        if not isinstance(data, dict) or not data.get("error"):
            return None

        raw_error = data["error"]
        # The error value is not guaranteed to be a string; join sequences and
        # stringify anything else so ``.lower()`` below cannot fail.
        if isinstance(raw_error, (list, tuple)):
            error_msg = "; ".join(str(part) for part in raw_error)
        else:
            error_msg = str(raw_error)

        if "loading" in error_msg.lower():
            return create_error_response(
                self.name,
                "Model is loading",
                "Please retry in a few seconds",
            )
        return create_error_response(self.name, error_msg)

    # ------------------------------------------------------------------
    # Sentiment analysis
    # ------------------------------------------------------------------

    async def analyze_sentiment(
        self,
        text: str,
        model: Optional[str] = None,
        use_financial_model: bool = False
    ) -> Dict[str, Any]:
        """
        Analyze sentiment of text using HuggingFace models.

        Args:
            text: Text to analyze; only the first 1000 characters are sent
            model: Custom model to use (optional); takes precedence over
                use_financial_model
            use_financial_model: Use MODELS["sentiment_financial"] instead of
                the default MODELS["sentiment"]

        Returns:
            Standardized response with sentiment analysis: a 100-character
            preview of the text, the model id and the parsed sentiment
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters",
            )

        text = text[:1000]

        if model:
            model_id = model
        elif use_financial_model:
            model_id = self.MODELS["sentiment_financial"]
        else:
            model_id = self.MODELS["sentiment"]

        response = await self.post(model_id, json_data={"inputs": text})
        if not response.get("success"):
            return response

        data = response.get("data", [])
        failure = self._model_error_response(data)
        if failure is not None:
            return failure

        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "sentiment": self._parse_sentiment_results(data, model_id),
            },
        )

    def _parse_sentiment_results(self, data: Any, model_id: str) -> Dict[str, Any]:
        """Parse sentiment results from different model formats.

        Accepts a flat list of ``{"label", "score"}`` dicts or the same list
        wrapped in an outer one-element list.

        Args:
            data: Payload returned by the model.
            model_id: Model that produced ``data`` (currently unused; kept for
                interface compatibility).

        Returns:
            ``{"label", "score", "allScores"}`` for the highest-scoring entry,
            or ``{"label": "unknown", "score": 0.0}`` when no usable entry
            exists.
        """
        unknown = {"label": "unknown", "score": 0.0}

        if not isinstance(data, list) or not data:
            return unknown

        if isinstance(data[0], list):
            data = data[0]

        # Ignore anything that is not a dict so a malformed or empty inner
        # list yields "unknown" instead of raising from max()/.get().
        candidates = [item for item in data if isinstance(item, dict)]
        if not candidates:
            return unknown

        best = max(candidates, key=lambda item: self._to_score(item.get("score")))
        label = str(best.get("label") or "unknown").lower()

        return {
            "label": self._LABEL_ALIASES.get(label, label),
            "score": round(self._to_score(best.get("score")), 4),
            "allScores": [
                {
                    "label": item.get("label"),
                    "score": round(self._to_score(item.get("score")), 4),
                }
                for item in candidates
            ],
        }

    # ------------------------------------------------------------------
    # Summarization
    # ------------------------------------------------------------------

    async def summarize_text(
        self,
        text: str,
        max_length: int = 150,
        min_length: int = 30,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Summarize text using HuggingFace summarization model.

        Args:
            text: Text to summarize; only the first 3000 characters are sent
            max_length: Maximum summary length (forwarded as a model parameter)
            min_length: Minimum summary length (forwarded as a model parameter)
            model: Custom model to use

        Returns:
            Standardized response with the summary, its length, the length of
            the (truncated) input and the model id
        """
        if not text or len(text.strip()) < 50:
            return create_error_response(
                self.name,
                "Text too short",
                "Text must be at least 50 characters for summarization",
            )

        text = text[:3000]
        model_id = model or self.MODELS["summarization"]

        payload = {
            "inputs": text,
            "parameters": {
                "max_length": max_length,
                "min_length": min_length,
                "do_sample": False,
            },
        }

        response = await self.post(model_id, json_data=payload)
        if not response.get("success"):
            return response

        data = response.get("data", [])
        failure = self._model_error_response(data)
        if failure is not None:
            return failure

        # The payload is either a dict or a list whose first item is the dict.
        first = data[0] if isinstance(data, list) and data else data
        summary = ""
        if isinstance(first, dict):
            summary = first.get("summary_text") or ""

        return create_success_response(
            self.name,
            {
                "originalLength": len(text),
                "summaryLength": len(summary),
                "model": model_id,
                "summary": summary,
            },
        )

    # ------------------------------------------------------------------
    # Named entity recognition
    # ------------------------------------------------------------------

    async def extract_entities(
        self,
        text: str,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Extract named entities from text.

        Args:
            text: Text to analyze; only the first 1000 characters are sent
            model: Custom NER model to use

        Returns:
            Standardized response with a 100-character preview of the text,
            the model id, the entity list and its length
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters",
            )

        # The text is truncated but deliberately not stripped, so the
        # start/end values returned below refer to the caller's own string.
        text = text[:1000]
        model_id = model or self.MODELS["ner"]

        response = await self.post(model_id, json_data={"inputs": text})
        if not response.get("success"):
            return response

        data = response.get("data", [])
        failure = self._model_error_response(data)
        if failure is not None:
            return failure

        raw_entities = data if isinstance(data, list) else []
        entities = [
            {
                "word": entity.get("word"),
                # Aggregated output uses "entity_group", per-token output "entity".
                "entity": entity.get("entity_group") or entity.get("entity"),
                "score": round(self._to_score(entity.get("score")), 4),
                "start": entity.get("start"),
                "end": entity.get("end"),
            }
            for entity in raw_entities
            if isinstance(entity, dict)
        ]

        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "entities": entities,
                "count": len(entities),
            },
        )

    # ------------------------------------------------------------------
    # Zero-shot classification
    # ------------------------------------------------------------------

    def _parse_classifications(self, data: Any) -> List[Dict[str, Any]]:
        """Normalize a zero-shot payload to ``[{"label", "score"}, ...]``.

        Handles three shapes, keeping the order the API returned:

        * ``{"labels": [...], "scores": [...]}``
        * a one-element list wrapping that dict
        * a list of ``{"label": ..., "score": ...}`` dicts

        Any other shape yields an empty list.
        """
        if (
            isinstance(data, list)
            and len(data) == 1
            and isinstance(data[0], dict)
            and "labels" in data[0]
        ):
            data = data[0]

        if isinstance(data, dict):
            labels = data.get("labels") or []
            scores = data.get("scores") or []
            return [
                {"label": label, "score": round(self._to_score(score), 4)}
                for label, score in zip(labels, scores)
            ]

        if isinstance(data, list):
            return [
                {
                    "label": item.get("label"),
                    "score": round(self._to_score(item.get("score")), 4),
                }
                for item in data
                if isinstance(item, dict) and "label" in item
            ]

        return []

    async def classify_text(
        self,
        text: str,
        candidate_labels: List[str],
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Zero-shot text classification.

        Args:
            text: Text to classify; only the first 500 characters are sent
            candidate_labels: List of possible labels (at least 2; only the
                first 10 are sent)
            model: Custom classification model

        Returns:
            Standardized response with every label/score pair plus the
            highest-scoring pair as bestLabel/bestScore
        """
        if not text or len(text.strip()) < 3:
            return create_error_response(
                self.name,
                "Invalid text",
                "Text must be at least 3 characters",
            )

        # A bare string would pass the length check and then be sliced into
        # single characters, so it is rejected explicitly.
        if (
            not candidate_labels
            or isinstance(candidate_labels, str)
            or len(candidate_labels) < 2
        ):
            return create_error_response(
                self.name,
                "Invalid labels",
                "At least 2 candidate labels required",
            )

        text = text[:500]
        model_id = model or self.MODELS["classification"]

        payload = {
            "inputs": text,
            "parameters": {
                "candidate_labels": list(candidate_labels[:10]),
            },
        }

        response = await self.post(model_id, json_data=payload)
        if not response.get("success"):
            return response

        data = response.get("data", {})
        failure = self._model_error_response(data)
        if failure is not None:
            return failure

        classifications = self._parse_classifications(data)
        # max() keeps the first entry on ties, so for a payload already sorted
        # by descending score this is simply the first element.
        best = max(classifications, key=lambda entry: entry["score"], default=None)

        return create_success_response(
            self.name,
            {
                "text": self._preview(text),
                "model": model_id,
                "classifications": classifications,
                "bestLabel": best["label"] if best else None,
                "bestScore": best["score"] if best else 0.0,
            },
        )

    async def get_available_models(self) -> Dict[str, Any]:
        """Get list of available models for each task.

        Returns:
            Standardized response with a copy of ``MODELS`` and its task names.
        """
        # Hand out a copy so callers cannot mutate the class-level mapping.
        models = dict(self.MODELS)
        return create_success_response(
            self.name,
            {
                "models": models,
                "tasks": list(models),
            },
        )
|
|
|
|