Suchith-nj commited on
Commit
c03853d
·
1 Parent(s): e1ccb1b

Load model directly from Hub (no Inference API)

Browse files
Files changed (1) hide show
  1. pages/1_Image_Classifier.py +49 -112
pages/1_Image_Classifier.py CHANGED
@@ -1,8 +1,7 @@
1
  import streamlit as st
2
  from PIL import Image
3
- import requests
4
- import json
5
- import time
6
 
7
  st.set_page_config(
8
  page_title="Food Classifier",
@@ -11,49 +10,31 @@ st.set_page_config(
11
  )
12
 
13
  st.title("Food-101 Image Classifier")
14
- st.markdown("### AI-powered food recognition using ResNet-50")
15
 
16
- # Model API
17
- API_URL = "https://api-inference.huggingface.co/models/nateraw/food"
 
 
 
 
 
18
 
19
- # Try to get token, ignore if not found
20
  try:
21
- headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
22
- except:
23
- headers = {}
 
 
 
24
 
25
- def query(image_bytes, max_retries=3):
26
- """Query HuggingFace Inference API with retry logic"""
27
- for attempt in range(max_retries):
28
- try:
29
- response = requests.post(API_URL, headers=headers, data=image_bytes, timeout=30)
30
-
31
- if response.status_code == 503:
32
- return {"error": "loading", "message": "Model is loading. Please wait."}
33
-
34
- if response.status_code == 200:
35
- return response.json()
36
-
37
- return {"error": "api_error", "status": response.status_code, "message": response.text}
38
-
39
- except requests.exceptions.Timeout:
40
- if attempt < max_retries - 1:
41
- time.sleep(5)
42
- continue
43
- return {"error": "timeout", "message": "Request timed out. Please try again."}
44
- except Exception as e:
45
- return {"error": "exception", "message": str(e)}
46
-
47
- return {"error": "failed", "message": "Failed after multiple attempts"}
48
-
49
- # Model info
50
  with st.expander("ℹ️ Model Information"):
51
  st.markdown("""
52
- **Model**: ResNet-50 (Fine-tuned)
53
- **Dataset**: Food-101 (101 food categories)
54
- **Training**: 5 epochs, 75K images
55
- **Accuracy**: 41% (vs 1% random baseline)
56
- **HuggingFace**: [View Model](https://huggingface.co/suchithnj12/food101-resnet50)
57
  """)
58
 
59
  # Main interface
@@ -64,8 +45,7 @@ with col1:
64
 
65
  uploaded_file = st.file_uploader(
66
  "Choose a food image",
67
- type=['jpg', 'jpeg', 'png'],
68
- help="Upload an image of food"
69
  )
70
 
71
  if uploaded_file:
@@ -76,86 +56,43 @@ with col2:
76
  st.markdown("### Prediction Results")
77
 
78
  if uploaded_file:
79
- with st.spinner("Analyzing image..."):
80
- # Get image bytes
81
- image_bytes = uploaded_file.getvalue()
82
-
83
- # Query API
84
- results = query(image_bytes)
85
-
86
- # Handle errors
87
- if isinstance(results, dict) and "error" in results:
88
- if results["error"] == "loading":
89
- st.warning("⏳ Model is loading on HuggingFace servers...")
90
- st.info("First request takes 20-30 seconds. Please wait and try again.")
91
- if st.button("Retry"):
92
- st.rerun()
93
- else:
94
- st.error(f"Error: {results.get('message', 'Unknown error')}")
95
- st.info("Please try again in a few seconds.")
96
-
97
- # Handle successful results
98
- elif isinstance(results, list) and len(results) > 0:
99
- # Display top prediction
100
- top_result = results[0]
101
- top_class = top_result['label'].replace('_', ' ').title()
102
- top_confidence = top_result['score']
103
 
104
- st.markdown(f"### 🎯 {top_class}")
105
- st.progress(top_confidence)
106
- st.metric("Confidence", f"{top_confidence*100:.1f}%")
 
 
107
 
108
- st.markdown("---")
 
109
 
110
- # Display top 5
111
- st.markdown("#### Top 5 Predictions")
112
- for i, result in enumerate(results[:5]):
113
- class_name = result['label'].replace('_', ' ').title()
114
- confidence = result['score']
 
115
 
116
- st.markdown(f"**{i+1}. {class_name}**")
117
- st.progress(confidence)
118
- st.caption(f"{confidence*100:.1f}%")
119
- else:
120
- st.error("Unexpected response from model.")
121
- st.info("The model might be initializing. Please wait 30 seconds and try again.")
122
  else:
123
  st.info("👈 Upload an image to get started")
124
 
125
- # Sample images section - REMOVED broken image URLs
126
- st.markdown("---")
127
- st.markdown("### 💡 Tips for Best Results")
128
- st.info("""
129
- - Use clear, well-lit food images
130
- - Ensure food is the main subject
131
- - Works best with common dishes
132
- - First prediction may take 20-30 seconds (model loading)
133
- """)
134
-
135
- # Technical details
136
  st.markdown("---")
137
  with st.expander("🔧 Technical Details"):
138
  st.markdown("""
139
- **Architecture**: ResNet-50 with modified classification head
140
-
141
- **Inference**: HuggingFace Inference API (serverless)
142
-
143
- **Training Details**:
144
- - Optimizer: AdamW
145
- - Learning Rate: 2e-5
146
- - Batch Size: 64
147
- - Mixed Precision: FP16
148
-
149
- **Performance**:
150
- - Test Accuracy: 40.8%
151
- - F1 Score: 38.0%
152
- - Inference Time: 1-3 seconds (after initial load)
153
-
154
- **Sample Categories**:
155
- pizza, sushi, hamburger, pasta, steak, salad, ice cream, cake, and 93 more...
156
-
157
- **Note**: First request takes 20-30 seconds as the model loads on HuggingFace servers.
158
  """)
159
 
160
  st.markdown("---")
161
- st.caption("Week 1 Complete - Using HuggingFace Inference API")
 
1
  import streamlit as st
2
  from PIL import Image
3
+ import torch
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
 
5
 
6
  st.set_page_config(
7
  page_title="Food Classifier",
 
10
  )
11
 
12
  st.title("Food-101 Image Classifier")
13
+ st.markdown("### ResNet-50 trained on 75K food images")
14
 
15
+ # Load model (cached)
16
+ @st.cache_resource
17
+ def load_model():
18
+ processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
19
+ model = AutoModelForImageClassification.from_pretrained("suchithnj12/food101-resnet50")
20
+ model.eval()
21
+ return processor, model
22
 
 
23
  try:
24
+ with st.spinner("Loading model (first time takes 30 seconds)..."):
25
+ processor, model = load_model()
26
+ st.success("✅ Model loaded!")
27
+ except Exception as e:
28
+ st.error(f"Error loading model: {str(e)}")
29
+ st.stop()
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  with st.expander("ℹ️ Model Information"):
32
  st.markdown("""
33
+ **Your Model**: suchithnj12/food101-resnet50
34
+ **Base**: ResNet-50
35
+ **Training**: 5 epochs on Food-101
36
+ **Accuracy**: 40.8%
37
+ **Categories**: 101 food types
38
  """)
39
 
40
  # Main interface
 
45
 
46
  uploaded_file = st.file_uploader(
47
  "Choose a food image",
48
+ type=['jpg', 'jpeg', 'png']
 
49
  )
50
 
51
  if uploaded_file:
 
56
  st.markdown("### Prediction Results")
57
 
58
  if uploaded_file:
59
+ with st.spinner("Analyzing..."):
60
+ try:
61
+ # Preprocess
62
+ inputs = processor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # Predict
65
+ with torch.no_grad():
66
+ outputs = model(**inputs)
67
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
68
+ top5_probs, top5_indices = torch.topk(probs, 5)
69
 
70
+ # Display results
71
+ st.success("✅ Analysis Complete")
72
 
73
+ for i in range(5):
74
+ label = model.config.id2label[top5_indices[0][i].item()]
75
+ score = top5_probs[0][i].item()
76
+
77
+ # Format label
78
+ label = label.replace('_', ' ').title()
79
 
80
+ st.markdown(f"**{i+1}. {label}**")
81
+ st.progress(score)
82
+ st.caption(f"{score*100:.1f}%")
83
+
84
+ except Exception as e:
85
+ st.error(f"Prediction failed: {str(e)}")
86
  else:
87
  st.info("👈 Upload an image to get started")
88
 
 
 
 
 
 
 
 
 
 
 
 
89
  st.markdown("---")
90
  with st.expander("🔧 Technical Details"):
91
  st.markdown("""
92
+ **Model Loading**: Direct from HuggingFace Hub
93
+ **Inference**: On HuggingFace Spaces hardware
94
+ **Test Accuracy**: 40.8%
95
+ **Categories**: apple pie, sushi, pizza, pasta, and 97 more
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  """)
97
 
98
  st.markdown("---")