# Author: Jugal-sheth
# Last change: "Update app.py" (commit 99bfc21)
import gradio as gr
import torch
from model import model, classes
from torchvision import transforms
import torch.nn.functional as F
# Restore the trained weights. Tensors are mapped onto the CPU so the demo can
# run on a machine without a GPU. The saved object is a dict whose
# 'model_state_dict' entry holds the parameters for `model`.
# NOTE: torch.load unpickles the file; that is acceptable here only because
# model.pth ships with the app and is not user-supplied.
checkpoint = torch.load("model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# train(False) is exactly what eval() does: it switches any training-only
# layer behaviour (e.g. dropout / batch-norm updates, if present) off.
model.train(False)
# Preprocessing applied to every incoming image before it reaches the model:
#   1. convert the raw input to a PIL image,
#   2. resize to 32x32 (the CIFAR10 image size; presumably the training size),
#   3. convert to a float tensor with values in [0, 1],
#   4. map each of the three channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
import torch
import torch.nn.functional as F
def classify(img):
    """Classify one image and report the most likely class with its confidence.

    Args:
        img: The value produced by the Gradio ``"image"`` input. It is passed
            directly to ``transform``, whose first step is ``ToPILImage``, so it
            is presumably an HxWxC numpy array (TODO confirm for the installed
            Gradio version). ``None`` is accepted: with ``live=True`` Gradio
            re-runs this function when the image is cleared.

    Returns:
        ``"<class name> (<percent>%)"`` for the top class, with the softmax
        probability shown to two decimals, or ``""`` when no image is given.
    """
    # A cleared input in a live interface arrives as None; there is nothing to
    # classify, and transform(None) would raise.
    if img is None:
        return ""
    # Preprocess and add a leading batch dimension of size 1.
    x = transform(img).unsqueeze(0)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model(x)
    # Per-class probabilities (as percentages) for the single sample.
    perc = F.softmax(logits, dim=1)[0] * 100
    # Plain int index so `classes` can be any indexable (list, tuple, dict).
    idx = int(torch.argmax(logits[0]))
    top_perc = perc[idx].item()
    return f"{classes[idx]} ({top_perc:.2f}%)"
# Build the web UI: an uploaded image is passed to `classify` and its string
# result is shown as text. live=True re-runs the prediction whenever the input
# changes (including when it is cleared).
iface = gr.Interface(
    classify,
    "image",
    "text",
    theme="huggingface",
    # The model classifies CIFAR10 object categories, not digits.
    title="CIFAR10 Image Classification",
    description=(
        "Upload an image of an airplane, automobile, bird, cat, deer, dog, "
        "frog, horse, ship or truck and the algorithm will classify it in "
        "real time! This is a CNN trained on the CIFAR10 dataset."
    ),
    article="<p style='text-align: center'>CIFAR10 Classification | Demo Model by Jugal</p>",
    live=True,
)
# Start the server at import time, as before; debug=True keeps the call
# blocking and surfaces errors in the console (see Gradio's launch() docs).
iface.launch(debug=True)