Commit
·
4a5d667
1
Parent(s):
291f55b
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
from doctest import OutputChecker
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
-
|
| 6 |
import re
|
| 7 |
import os
|
| 8 |
import gradio as gr
|
|
@@ -19,7 +19,7 @@ import requests
|
|
| 19 |
#from sklearn.metrics.pairwise import cosine_similarity
|
| 20 |
|
| 21 |
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 22 |
-
|
| 23 |
|
| 24 |
#SentenceTransformer('stsb-distilbert-base', device=device)
|
| 25 |
|
|
@@ -108,8 +108,8 @@ def Visual_re_ranker(caption, visual_context_label, visual_context_prob):
|
|
| 108 |
caption = caption
|
| 109 |
visual_context_label= visual_context_label
|
| 110 |
visual_context_prob = visual_context_prob
|
| 111 |
-
caption_emb =
|
| 112 |
-
visual_context_label_emb =
|
| 113 |
|
| 114 |
|
| 115 |
sim = cosine_scores = util.pytorch_cos_sim(caption_emb, visual_context_label_emb)
|
|
|
|
| 2 |
from doctest import OutputChecker
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
+
import torch
|
| 6 |
import re
|
| 7 |
import os
|
| 8 |
import gradio as gr
|
|
|
|
| 19 |
#from sklearn.metrics.pairwise import cosine_similarity
|
| 20 |
|
| 21 |
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 22 |
+
model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
|
| 23 |
|
| 24 |
#SentenceTransformer('stsb-distilbert-base', device=device)
|
| 25 |
|
|
|
|
| 108 |
caption = caption
|
| 109 |
visual_context_label= visual_context_label
|
| 110 |
visual_context_prob = visual_context_prob
|
| 111 |
+
caption_emb = model_sts.encode(caption, convert_to_tensor=True)
|
| 112 |
+
visual_context_label_emb = model_sts.encode(visual_context_label, convert_to_tensor=True)
|
| 113 |
|
| 114 |
|
| 115 |
sim = cosine_scores = util.pytorch_cos_sim(caption_emb, visual_context_label_emb)
|