Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -717,11 +717,14 @@ tokenizer = AutoTokenizer.from_pretrained(main_model_name)
|
|
| 717 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 718 |
model.to(device)
|
| 719 |
|
| 720 |
-
# Load fallback multilingual model
|
| 721 |
-
multi_model_name = "
|
| 722 |
multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
|
| 723 |
multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
|
| 724 |
|
|
|
|
|
|
|
|
|
|
| 725 |
# Reddit API setup
|
| 726 |
reddit = praw.Reddit(
|
| 727 |
client_id=os.getenv("REDDIT_CLIENT_ID"),
|
|
@@ -741,14 +744,7 @@ def multilingual_classifier(text):
|
|
| 741 |
with torch.no_grad():
|
| 742 |
output = multi_model(**encoded_input)
|
| 743 |
scores = softmax(output.logits.cpu().numpy()[0])
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
if stars in [1, 2]:
|
| 747 |
-
return "Prediction: Negative"
|
| 748 |
-
elif stars == 3:
|
| 749 |
-
return "Prediction: Neutral"
|
| 750 |
-
else:
|
| 751 |
-
return "Prediction: Positive"
|
| 752 |
|
| 753 |
def clean_ocr_text(text):
|
| 754 |
text = text.strip()
|
|
@@ -867,6 +863,13 @@ demo.launch()
|
|
| 867 |
|
| 868 |
|
| 869 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
|
| 871 |
|
| 872 |
|
|
|
|
| 717 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 718 |
model.to(device)
|
| 719 |
|
| 720 |
+
# Load fallback multilingual model (direct sentiment labels)
|
| 721 |
+
multi_model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
|
| 722 |
multi_tokenizer = AutoTokenizer.from_pretrained(multi_model_name)
|
| 723 |
multi_model = AutoModelForSequenceClassification.from_pretrained(multi_model_name).to(device)
|
| 724 |
|
| 725 |
+
# Labels for multilingual model
|
| 726 |
+
multi_labels = ['Negative', 'Neutral', 'Positive']
|
| 727 |
+
|
| 728 |
# Reddit API setup
|
| 729 |
reddit = praw.Reddit(
|
| 730 |
client_id=os.getenv("REDDIT_CLIENT_ID"),
|
|
|
|
| 744 |
with torch.no_grad():
|
| 745 |
output = multi_model(**encoded_input)
|
| 746 |
scores = softmax(output.logits.cpu().numpy()[0])
|
| 747 |
+
return f"Prediction: {multi_labels[np.argmax(scores)]}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
|
| 749 |
def clean_ocr_text(text):
|
| 750 |
text = text.strip()
|
|
|
|
| 863 |
|
| 864 |
|
| 865 |
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
|
| 873 |
|
| 874 |
|
| 875 |
|