Spaces: Build error
| import cv2 | |
| import json | |
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from backbone import create_name_vit | |
| from backbone import ClassificationModel | |
# Configuration for the Kakao Brain ViT-L/16 classifier with a 512-pixel input.
# Read below by the backbone factory (backbone_name / backbone_params), the
# classification head (dropout_rate) and the weight loader (pretrained).
vit_l16_512 = dict(
    backbone_name="vit-l/16",
    backbone_params=dict(
        image_size=512,
        representation_size=0,
        attention_dropout_rate=0.0,
        dropout_rate=0.0,
        channels=3,
    ),
    dropout_rate=0.0,
    pretrained="./weights/vit_l16_512/model-weights",
)
# Build the ViT backbone from the config, wrap it in a 1000-way classification
# head (the ImageNet label space loaded below), restore the pretrained weights,
# and freeze the model since this app only runs inference.
backbone = create_name_vit(
    vit_l16_512["backbone_name"],
    **vit_l16_512["backbone_params"],
)
model = ClassificationModel(
    backbone=backbone,
    num_classes=1000,
    dropout_rate=vit_l16_512["dropout_rate"],
)
model.load_weights(vit_l16_512["pretrained"])
model.trainable = False
# Load the ImageNet class-index -> human-readable label mapping.
# JSON object keys are strings (callers look entries up with str(idx)).
# Decode explicitly as UTF-8: JSON text is UTF-8 by spec, and relying on the
# platform default encoding (e.g. cp1252 on Windows) can mis-decode labels.
with open("assets/imagenet_1000_idx2labels.json", encoding="utf-8") as f:
    idx_to_label = json.load(f)
def resize_with_normalization(image, size=(512, 512)):
    """Resize an image, scale its pixels to [-1, 1], and add a batch axis.

    Args:
        image: H x W x C image (array-like or tensor). Pixel values are
            assumed to be in [0, 255] (e.g. uint8 RGB from the Gradio image
            input) -- confirm for other callers.
        size: Target (height, width). Defaults to the model's 512 x 512 input.
            A tuple is used instead of a list to avoid a shared mutable default.

    Returns:
        float32 tensor of shape (1, height, width, C). Values are in [-1, 1]
        when the input is in [0, 255].
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, size)
    # (x - 127.5) / 127.5 maps [0, 255] onto [-1, 1]. A scalar broadcasts over
    # any channel count, unlike the former hard-coded (1, 1, 3) constants.
    image = (image - 127.5) / 127.5
    # Leading batch dimension of size 1 for model.predict.
    return tf.expand_dims(image, axis=0)
def softmax_stable(x, axis=None):
    """Return the softmax of *x*, shifted by its max for numerical stability.

    Args:
        x: Array-like of logits.
        axis: Axis to normalise over. ``None`` (the default, matching the
            original behaviour) normalises over all elements. Pass ``-1`` for
            a per-row softmax over a batch of logits.

    Returns:
        ndarray with the same shape as *x* whose entries along *axis* sum to 1.
    """
    x = np.asarray(x)
    # Subtracting the max does not change the result mathematically, but it
    # keeps np.exp from overflowing on large logits.
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    # The exponentials are computed once and reused for the normaliser.
    return exps / exps.sum(axis=axis, keepdims=True)
def classify_image(img, top_k):
    """Classify an image and return its ``top_k`` most likely labels.

    Args:
        img: Image from the Gradio ``gr.Image`` input (presumably an
            H x W x 3 uint8 numpy array -- confirm), or ``None`` when the user
            submits without an image.
        top_k: How many labels to return. It may arrive as a float from the
            Gradio slider, so it is coerced to int and clamped to
            [0, number of classes].

    Returns:
        Dict mapping label name to probability (rounded to 4 decimals),
        ordered from most to least likely. The dict is empty when there is
        no image or ``top_k`` <= 0.
    """
    if img is None:
        return {}
    img = resize_with_normalization(tf.convert_to_tensor(img))
    # `workers` only matters for generator/Sequence inputs, so it is dropped
    # for this single in-memory tensor. Keras 3's predict does not accept it.
    pred_logits = model.predict(img, batch_size=1)[0]
    pred_probs = softmax_stable(pred_logits)
    # Slicing with [-0:] would return *every* class, and a float bound would
    # raise TypeError, so normalise k before selecting.
    k = max(0, min(int(top_k), pred_probs.shape[0]))
    if k == 0:
        return {}
    # argsort is ascending: take the last k indices, then reverse so the most
    # likely class comes first.
    top_k_labels = pred_probs.argsort()[-k:][::-1]
    # Label-map keys are strings, hence str(idx).
    return {idx_to_label[str(idx)]: round(float(pred_probs[idx]), 4) for idx in top_k_labels}
# Gradio UI: an image plus a "top k" slider as inputs, a label/confidence
# widget as output, with a few bundled example images.
demo = gr.Interface(
    classify_image,
    inputs=[
        gr.Image(),
        # step=1 keeps top_k integral; minimum=1 because k=0 yields no labels.
        gr.Slider(minimum=1, maximum=1000, value=5, step=1, label="Top k"),
    ],
    # gr.outputs.* is the deprecated Gradio 2 namespace (removed in Gradio 4);
    # gr.Label is the current equivalent component.
    outputs=gr.Label(),
    title="Image Classification with Kakao Brain ViT",
    examples=[
        ["assets/halloween-gaf8ad7ebc_1920.jpeg", 5],
        ["assets/IMG_4484.jpeg", 5],
        ["assets/IMG_4737.jpeg", 5],
        ["assets/IMG_4740.jpeg", 5],
    ],
)
demo.launch()