Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -138,7 +138,7 @@ demo = gr.Interface(
|
|
| 138 |
|
| 139 |
demo.launch()
|
| 140 |
'''
|
| 141 |
-
import gradio as gr
|
| 142 |
from transformers import TFBertForSequenceClassification, BertTokenizer
|
| 143 |
import tensorflow as tf
|
| 144 |
import praw
|
|
@@ -213,5 +213,91 @@ demo = gr.Interface(
|
|
| 213 |
)
|
| 214 |
|
| 215 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
|
|
|
|
| 138 |
|
| 139 |
demo.launch()
|
| 140 |
'''
|
| 141 |
+
'''import gradio as gr
|
| 142 |
from transformers import TFBertForSequenceClassification, BertTokenizer
|
| 143 |
import tensorflow as tf
|
| 144 |
import praw
|
|
|
|
| 213 |
)
|
| 214 |
|
| 215 |
demo.launch()
|
| 216 |
+
'''
|
| 217 |
+
import gradio as gr
|
| 218 |
+
from transformers import TFBertForSequenceClassification, BertTokenizer, pipeline
|
| 219 |
+
import tensorflow as tf
|
| 220 |
+
import praw
|
| 221 |
+
import os
|
| 222 |
+
|
| 223 |
+
# Load main BERT model and tokenizer
|
| 224 |
+
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
|
| 225 |
+
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
|
| 226 |
+
|
| 227 |
+
# Load fallback sentiment pipeline model
|
| 228 |
+
fallback_classifier = pipeline("text-classification", model="VinMir/GordonAI-sentiment_analysis")
|
| 229 |
+
|
| 230 |
+
# Label mapping for main model
|
| 231 |
+
LABELS = {
|
| 232 |
+
0: "Neutral",
|
| 233 |
+
1: "Positive",
|
| 234 |
+
2: "Negative"
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
# Reddit API setup (secure credentials from Hugging Face secrets)
|
| 238 |
+
reddit = praw.Reddit(
|
| 239 |
+
client_id=os.getenv("REDDIT_CLIENT_ID"),
|
| 240 |
+
client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
|
| 241 |
+
user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-script")
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Fetch content from Reddit URL
|
| 245 |
+
def fetch_reddit_text(reddit_url):
|
| 246 |
+
try:
|
| 247 |
+
submission = reddit.submission(url=reddit_url)
|
| 248 |
+
return f"{submission.title}\n\n{submission.selftext}"
|
| 249 |
+
except Exception as e:
|
| 250 |
+
return f"Error fetching Reddit post: {str(e)}"
|
| 251 |
+
|
| 252 |
+
# Sentiment classification function
|
| 253 |
+
def classify_sentiment(text_input, reddit_url):
|
| 254 |
+
if reddit_url.strip():
|
| 255 |
+
text = fetch_reddit_text(reddit_url)
|
| 256 |
+
elif text_input.strip():
|
| 257 |
+
text = text_input
|
| 258 |
+
else:
|
| 259 |
+
return "[!] Please enter some text or a Reddit post URL."
|
| 260 |
+
|
| 261 |
+
if text.lower().startswith("error") or "Unable to extract" in text:
|
| 262 |
+
return f"[!] {text}"
|
| 263 |
+
|
| 264 |
+
try:
|
| 265 |
+
# Main BERT model prediction
|
| 266 |
+
inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
|
| 267 |
+
outputs = model(inputs)
|
| 268 |
+
probs = tf.nn.softmax(outputs.logits, axis=1)
|
| 269 |
+
confidence = float(tf.reduce_max(probs).numpy())
|
| 270 |
+
pred_label = tf.argmax(probs, axis=1).numpy()[0]
|
| 271 |
+
|
| 272 |
+
if confidence < 0.5:
|
| 273 |
+
# Use fallback model silently
|
| 274 |
+
fallback = fallback_classifier(text)[0]['label']
|
| 275 |
+
return f"Prediction: {fallback}"
|
| 276 |
+
|
| 277 |
+
return f"Prediction: {LABELS[pred_label]}"
|
| 278 |
+
except Exception as e:
|
| 279 |
+
return f"[!] Prediction error: {str(e)}"
|
| 280 |
+
|
| 281 |
+
# Gradio interface
|
| 282 |
+
demo = gr.Interface(
|
| 283 |
+
fn=classify_sentiment,
|
| 284 |
+
inputs=[
|
| 285 |
+
gr.Textbox(
|
| 286 |
+
label="Text Input (can be tweet or any content)",
|
| 287 |
+
placeholder="Paste tweet or type any content here...",
|
| 288 |
+
lines=4
|
| 289 |
+
),
|
| 290 |
+
gr.Textbox(
|
| 291 |
+
label="Reddit Post URL",
|
| 292 |
+
placeholder="Paste a Reddit post URL (optional)",
|
| 293 |
+
lines=1
|
| 294 |
+
),
|
| 295 |
+
],
|
| 296 |
+
outputs="text",
|
| 297 |
+
title="Sentiment Analyzer",
|
| 298 |
+
description="🔍 Paste any text (including tweet content) OR a Reddit post URL to analyze sentiment.\n\n💡 Tweet URLs are not supported directly due to platform restrictions. Please paste tweet content manually."
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
demo.launch()
|
| 302 |
|
| 303 |
|