Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from model import model, classes | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
# Load the trained checkpoint onto the CPU (no GPU is assumed), copy its
# stored weights into the imported model, and switch the model to eval mode.
# NOTE: torch.load unpickles the file — only ever point it at a trusted checkpoint.
checkpoint = torch.load("model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference behaviour for layers such as dropout/batch-norm, if the model has any
# Preprocessing applied to every incoming image before it reaches the model:
# convert to a PIL image, resize to 32x32, convert to a float tensor, then
# shift/scale each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
| import torch | |
| import torch.nn.functional as F | |
def classify(img):
    """Classify a single image and report the top class with its confidence.

    Args:
        img: The value supplied by the Gradio ``"image"`` input, or ``None``
            when the input is empty/cleared. It is passed straight to the
            module-level ``transform`` (which begins with ``ToPILImage``), so
            it must be something that transform accepts — presumably an
            H x W x 3 uint8 NumPy array, Gradio's default; TODO confirm.

    Returns:
        ``"<class name> (<pct>%)"`` with the percentage to two decimals,
        e.g. ``"cat (87.12%)"``, or ``""`` when no image is supplied.
    """
    # With live=True the callback can also fire when the image is cleared;
    # transform() would raise on None, so return an empty result instead.
    if img is None:
        return ""

    # transform() yields one (C, H, W) tensor; the model takes a batch, so
    # add a leading batch dimension of size 1.
    x = transform(img).unsqueeze(0)

    # Pure inference: don't record the autograd graph.
    with torch.no_grad():
        logits = model(x)

    # Softmax over the class dimension of the only batch row, as percentages.
    perc = F.softmax(logits, dim=1)[0] * 100
    # Take the argmax of that same row and convert it to a plain int so
    # `classes` can be any int-indexed container (a 0-d tensor is not a
    # usable dict key, for instance).
    idx = int(torch.argmax(logits[0]))
    return f"{classes[idx]} ({perc[idx].item():.2f}%)"
# Build the web UI: one image input, one text output. live=True re-runs
# `classify` whenever the input image changes, without a submit button.
iface = gr.Interface(
    classify,
    "image",
    "text",
    # NOTE(review): "huggingface" is a legacy Gradio 2.x theme name; newer
    # Gradio releases may warn and fall back to the default — confirm against
    # the Gradio version this Space pins.
    theme="huggingface",
    # The model is a CIFAR-10 object classifier, not a digit recogniser.
    title="CIFAR-10 Image Classification",
    description=(
        "Upload an image of an airplane, automobile, bird, cat, deer, dog, "
        "frog, horse, ship or truck and the model will classify it in real "
        "time! This is a CNN trained on the CIFAR-10 dataset."
    ),
    article="<p style='text-align: center'>CIFAR10 Classification | Demo Model by Jugal</p>",
    live=True,
)
iface.launch(debug=True)