| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import einops | |
# Loaded networks keyed by device string, so the checkpoint is read from disk
# once per process instead of once per request.
_MODEL_CACHE = {}


def _load_model(device):
    """Return the clock-reading ResNet-50 on *device*, building it on first use."""
    if device not in _MODEL_CACHE:
        model = models.resnet50()
        # 720 output classes = 12 hours x 60 minutes (decoded in predict()).
        model.fc = nn.Linear(2048, 720)
        resume_path = 'full+++++.pth'
        model.load_state_dict(
            torch.load(resume_path, map_location=torch.device(device))
        )
        model.to(device)
        model.eval()
        _MODEL_CACHE[device] = model
    return _MODEL_CACHE[device]


def predict(img):
    """Read the time shown on a cropped analog-clock image.

    Args:
        img: Image as an ``h x w x c`` array (the layout the rearrange
            pattern below expects); presumably RGB as delivered by the
            Gradio image input -- TODO confirm channel order with training.

    Returns:
        The predicted time as ``'H:MM'``: the hour is the class index
        divided by 60 (so 0-11, not zero-padded) and the minute is the
        remainder, zero-padded to two digits.

    Raises:
        ValueError: If ``img`` is ``None`` (no image was supplied).
    """
    if img is None:
        raise ValueError('No image provided.')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = _load_model(device)

    # Resize to the 224x224 network input and scale pixel values by 1/255.
    img = cv2.resize(img, (224, 224)) / 255.
    # Channels-first layout plus a leading batch axis of size 1.
    img = np.stack([einops.rearrange(img, 'h w c -> c h w')], 0)
    img = torch.as_tensor(img, dtype=torch.float32, device=device)

    with torch.no_grad():
        pred = model(img)

    # Only the top-scoring class is needed, so argmax replaces a full sort.
    best = torch.argmax(pred, dim=1)[0].item()
    max_h, max_m = divmod(best, 60)
    return f'{max_h}:{max_m:02d}'
# --- Gradio front end -------------------------------------------------------
# A single image upload goes to predict(); its string result is shown as text.
inputs = gr.inputs.Image()

io = gr.Interface(
    title="It's About Time: Analog Clock Reading in the Wild",
    description=(
        'Note that this model ingests clocks that are already cropped, '
        'i.e. we do not run object detection.'
    ),
    fn=predict,
    inputs=inputs,
    outputs="text",
    # Sample images offered in the UI; paths are relative to the working dir.
    examples=['d1.png', 'd2.png'],
)

# Start the web server; share=True asks Gradio for a public share link as well.
io.launch(share=True)