Suchith-nj commited on
Commit
c1f097f
Β·
1 Parent(s): b3a9019

using pretrained model

Browse files
Files changed (1) hide show
  1. pages/1_Image_Classifier.py +43 -34
pages/1_Image_Classifier.py CHANGED
@@ -1,54 +1,63 @@
1
  import streamlit as st
2
  from PIL import Image
3
- import requests
 
4
 
5
- st.set_page_config(
6
- page_title="Food Classifier",
7
- page_icon="πŸ•",
8
- layout="wide"
9
- )
10
 
11
  st.title("Food Classification")
12
- st.markdown("AI-powered food recognition")
13
-
14
- API_URL = "https://api-inference.huggingface.co/models/nateraw/vit-base-beans"
15
-
16
- def classify_image(image_bytes):
17
- try:
18
- response = requests.post(API_URL, data=image_bytes, timeout=30)
19
- return response.json()
20
- except:
21
- return None
22
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  col1, col2 = st.columns(2)
24
 
25
  with col1:
26
  st.subheader("Upload Image")
27
- uploaded_file = st.file_uploader("Choose a food image", type=['jpg', 'jpeg', 'png'])
28
-
29
  if uploaded_file:
30
  image = Image.open(uploaded_file)
31
- st.image(image, use_column_width=True)
32
 
33
  with col2:
34
  st.subheader("Results")
35
-
36
  if uploaded_file:
37
  with st.spinner("Analyzing..."):
38
- results = classify_image(uploaded_file.getvalue())
39
-
40
- if results and isinstance(results, list):
41
- for i, result in enumerate(results[:5], 1):
42
- label = result.get('label', 'Unknown')
43
- score = result.get('score', 0)
44
-
45
- st.write(f"**{i}. {label}**")
46
- st.progress(score)
47
- st.caption(f"{score*100:.1f}%")
48
  else:
49
- st.info("Model loading. Wait 20 seconds and retry.")
50
  else:
51
- st.info("Upload an image to classify")
52
 
53
  st.markdown("---")
54
- st.caption("Week 1 Project - Image Classification Pipeline")
 
1
  import streamlit as st
2
  from PIL import Image
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
4
+ import torch
5
 
6
+ # Streamlit page config
7
+ st.set_page_config(page_title="Food Classifier", layout="wide")
 
 
 
8
 
9
  st.title("Food Classification")
10
+ st.markdown("AI-powered food recognition using a pretrained Food101 model")
11
+
12
+ # Load processor and model
13
+ processor = AutoImageProcessor.from_pretrained("nateraw/food")
14
+ model = AutoModelForImageClassification.from_pretrained("nateraw/food")
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ model.to(device)
17
+ model.eval() # set model to evaluation mode
18
+
19
+ # Function to preprocess and classify image
20
+ def classify_image(image: Image.Image):
21
+ inputs = processor(images=image.convert("RGB"), return_tensors="pt")
22
+ # Move to device
23
+ for k, v in inputs.items():
24
+ inputs[k] = v.to(device)
25
+ # Forward pass
26
+ with torch.no_grad():
27
+ outputs = model(**inputs)
28
+ # Softmax probabilities
29
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
30
+ # Get top 5 predictions
31
+ top_prob, top_class = torch.topk(probs, k=5)
32
+ id2label = model.config.id2label
33
+ results = []
34
+ for prob, idx in zip(top_prob[0], top_class[0]):
35
+ results.append({"label": id2label[idx.item()], "score": prob.item()})
36
+ return results
37
+
38
+ # Streamlit layout
39
  col1, col2 = st.columns(2)
40
 
41
  with col1:
42
  st.subheader("Upload Image")
43
+ uploaded_file = st.file_uploader("Choose a food image", type=["jpg", "jpeg", "png"])
 
44
  if uploaded_file:
45
  image = Image.open(uploaded_file)
46
+ st.image(image, use_container_width=True)
47
 
48
  with col2:
49
  st.subheader("Results")
 
50
  if uploaded_file:
51
  with st.spinner("Analyzing..."):
52
+ results = classify_image(image)
53
+ if results:
54
+ for i, r in enumerate(results, 1):
55
+ st.write(f"{i}. {r['label']} β€” {r['score']*100:.2f}%")
56
+ st.progress(r['score'])
 
 
 
 
 
57
  else:
58
+ st.info("Could not classify the image. Try again.")
59
  else:
60
+ st.info("Upload an image to classify.")
61
 
62
  st.markdown("---")
63
+ st.caption("Week 1 Project - Image Classification Pipeline")