Spaces:
Runtime error
Runtime error
Commit
·
06b1c65
1
Parent(s):
7ca342d
Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import pipeline
|
| 3 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 4 |
+
import openpyxl
|
| 5 |
+
|
| 6 |
+
#Function to predict the food from the image using the pre-trained model "nateraw/food"
|
| 7 |
+
def predict(image):
|
| 8 |
+
extractor = AutoFeatureExtractor.from_pretrained("nateraw/food")
|
| 9 |
+
model = AutoModelForImageClassification.from_pretrained("nateraw/food")
|
| 10 |
+
|
| 11 |
+
input = extractor(images=image, return_tensors='pt')
|
| 12 |
+
output = model(**input)
|
| 13 |
+
logits = output.logits
|
| 14 |
+
|
| 15 |
+
pred_class = logits.argmax(-1).item()
|
| 16 |
+
return(model.config.id2label[pred_class])
|
| 17 |
+
|
| 18 |
+
#Function to retrieve the Nutritional Value from database.xlsx which is downloaded from USDA
|
| 19 |
+
def check_food(food, counter):
|
| 20 |
+
path = './database.xlsx'
|
| 21 |
+
wb_obj = openpyxl.load_workbook(path)
|
| 22 |
+
sheet_obj = wb_obj.active
|
| 23 |
+
|
| 24 |
+
foodPred, cal, carb, prot, fat = None, None, None, None, None
|
| 25 |
+
|
| 26 |
+
#Filter to prioritize the most probable match between the prediction and the entries in the database
|
| 27 |
+
for i in range(3, sheet_obj.max_row+1):
|
| 28 |
+
cell_obj = sheet_obj.cell(row = i, column = 2)
|
| 29 |
+
if counter == 0:
|
| 30 |
+
if len(food) >= 3:
|
| 31 |
+
foodName = food[0].capitalize() + " " + food[1] + " " + food[2] + ","
|
| 32 |
+
elif len(food) == 2:
|
| 33 |
+
foodName = food[0].capitalize() + " " + food[1] + ","
|
| 34 |
+
elif len(food) == 1:
|
| 35 |
+
foodName = food[0].capitalize() + ","
|
| 36 |
+
condition = foodName == cell_obj.value[0:len(foodName):]
|
| 37 |
+
elif counter == 1:
|
| 38 |
+
if len(food) >= 3:
|
| 39 |
+
foodName = food[0].capitalize() + " " + food[1] + " " + food[2]
|
| 40 |
+
elif len(food) == 2:
|
| 41 |
+
foodName = food[0].capitalize() + " " + food[1]
|
| 42 |
+
elif len(food) == 1:
|
| 43 |
+
foodName = food[0].capitalize()
|
| 44 |
+
condition = foodName == cell_obj.value[0:len(foodName):]
|
| 45 |
+
elif counter == 2:
|
| 46 |
+
if len(food) >= 3:
|
| 47 |
+
foodName = food[0] + " " + food[1] + " " + food[2]
|
| 48 |
+
elif len(food) == 2:
|
| 49 |
+
foodName = food[0] + " " + food[1]
|
| 50 |
+
elif len(food) == 1:
|
| 51 |
+
foodName = food[0]
|
| 52 |
+
condition = foodName in cell_obj.value
|
| 53 |
+
elif (counter == 3) & (len(food) > 1):
|
| 54 |
+
condition = food[0] in cell_obj.value
|
| 55 |
+
else:
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
#Update values if conditions are met
|
| 59 |
+
if condition:
|
| 60 |
+
foodPred = cell_obj.value
|
| 61 |
+
cal = sheet_obj.cell(row = i, column = 5).value
|
| 62 |
+
carb = sheet_obj.cell(row = i, column = 7).value
|
| 63 |
+
prot = sheet_obj.cell(row = i, column = 6).value
|
| 64 |
+
fat = sheet_obj.cell(row = i, column = 10).value
|
| 65 |
+
break
|
| 66 |
+
|
| 67 |
+
return foodPred, cal, carb, prot, fat
|
| 68 |
+
|
| 69 |
+
#Function to prepare the output
|
| 70 |
+
def get_cc(food, weight):
|
| 71 |
+
|
| 72 |
+
#Configure the food string to match the entries in the database
|
| 73 |
+
food = food.split("_")
|
| 74 |
+
if food[-1][-1] == "s":
|
| 75 |
+
food[-1] = food[-1][:-1]
|
| 76 |
+
|
| 77 |
+
foodPred, cal, carb, prot, fat = None, None, None, None, None
|
| 78 |
+
counter = 0
|
| 79 |
+
|
| 80 |
+
#Try for the most probable match between the prediction and the entries in the database
|
| 81 |
+
while (not foodPred) & (counter <= 3):
|
| 82 |
+
foodPred, cal, carb, prot, fat = check_food(food,counter)
|
| 83 |
+
counter += 1
|
| 84 |
+
|
| 85 |
+
#Check if there is a match
|
| 86 |
+
if food:
|
| 87 |
+
output = foodPred + "\nCalories: " + str(round(cal * weight)/100) + " kJ\nCarbohydrate: " + str(round(carb * weight)/100) + " g\nProtein: " + str(round(prot * weight)/100) + " g\nTotal Fat: " + str(round(fat * weight)/100) + " g"
|
| 88 |
+
elif not food:
|
| 89 |
+
output = "No data for food"
|
| 90 |
+
|
| 91 |
+
return(output)
|
| 92 |
+
|
| 93 |
+
#Main function
|
| 94 |
+
def CC(image, weight):
|
| 95 |
+
pred = predict(image)
|
| 96 |
+
cc = get_cc(pred, weight)
|
| 97 |
+
return(pred, cc)
|
| 98 |
+
|
| 99 |
+
interface = gr.Interface(
|
| 100 |
+
fn = CC,
|
| 101 |
+
inputs = [gr.inputs.Image(shape=(224,224)), gr.inputs.Number(default = 100, label = "Weight in grams (g):")],
|
| 102 |
+
outputs = [gr.outputs.Textbox(label='Food Prediction:'), gr.outputs.Textbox(label='Nutritional Value:')],
|
| 103 |
+
examples = [["pizza.jpg", 107], ["spaghetti.jpg",205]])
|
| 104 |
+
|
| 105 |
+
interface.launch()
|