Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| model_name = "./best_model" | |
| processor = ViTImageProcessor.from_pretrained(model_name) | |
| labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции'] | |
| class ViTForImageClassificationWithAttention(ViTForImageClassification): | |
| def forward(self, pixel_values): | |
| outputs = super().forward(pixel_values) | |
| attention = self.vit.encoder.layers[0].attention.attention_weights | |
| return outputs, attention | |
| model = ViTForImageClassificationWithAttention.from_pretrained(model_name) | |
| class ViTForImageClassificationWithAttention(ViTForImageClassification): | |
| def forward(self, pixel_values, output_attentions=True): | |
| outputs = super().forward(pixel_values, output_attentions=output_attentions) | |
| attention = outputs.attentions | |
| return outputs, attention | |
| model = ViTForImageClassificationWithAttention.from_pretrained(model_name,attn_implementation="eager") | |
| i_count = 0 | |
| def classify_image(image): | |
| model_name = "best_model.pth" | |
| model.load_state_dict(torch.load(model_name)) | |
| inputs = processor(images=image, return_tensors="pt") | |
| outputs, attention = model(**inputs, output_attentions=True) | |
| logits = outputs.logits | |
| probs = torch.nn.functional.softmax(logits, dim=1) | |
| top_k_probs, top_k_indices = torch.topk(probs, k=5) # show top 5 predicted labels | |
| predicted_class_idx = torch.argmax(logits) | |
| predicted_class_label = labels[predicted_class_idx] | |
| top_k_labels = [labels[idx] for idx in top_k_indices[0]] | |
| top_k_label_probs = [(label, prob.item()) for label, prob in zip(top_k_labels, top_k_probs[0])] | |
| # Create a bar chart | |
| fig_bar = go.Figure( | |
| data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])]) | |
| fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз", | |
| yaxis_title="Вероятность") | |
| # Create a heatmap | |
| if attention is not None: | |
| fig_heatmap = go.Figure( | |
| data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)]) | |
| fig_heatmap.update_layout(title="Карта внимания системы") | |
| else: | |
| fig_heatmap = go.Figure() # Return an empty plot | |
| # Overlay the attention heatmap on the input image | |
| if attention is not None: | |
| img_array = np.array(image) | |
| heatmap = np.array(attention[0][0, 0, :, :].detach().numpy()) | |
| heatmap = np.resize(heatmap, (img_array.shape[0], img_array.shape[1])) | |
| heatmap = heatmap / heatmap.max() * 255 # Normalize heatmap to [0, 255] | |
| heatmap = heatmap.astype(np.uint8) | |
| heatmap_color = np.zeros((img_array.shape[0], img_array.shape[1], 3), dtype=np.uint8) | |
| heatmap_color[:, :, 0] = heatmap # Red channel | |
| heatmap_color[:, :, 1] = heatmap # Green channel | |
| heatmap_color[:, :, 2] = 0 # Blue channel | |
| attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8) | |
| attention_overlay = Image.fromarray(attention_overlay) | |
| attention_overlay.save("attention_overlay.png") | |
| attention_overlay = gr.Image("attention_overlay.png") | |
| else: | |
| attention_overlay = gr.Image() # Return an empty image | |
| # Return the predicted label, the bar chart, and the heatmap | |
| return predicted_class_label, fig_bar, fig_heatmap, attention_overlay | |
| def update_model(image, label): | |
| # Convert the label to an integer | |
| label_idx = labels.index(label) | |
| labels_tensor = torch.tensor([label_idx]) | |
| inputs = processor(images=image, return_tensors="pt") | |
| loss_fn = torch.nn.CrossEntropyLoss() | |
| optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | |
| # Zero the gradients | |
| optimizer.zero_grad() | |
| # Forward pass | |
| outputs, attention = model(**inputs) | |
| loss = loss_fn(outputs.logits, labels_tensor) | |
| # Backward pass | |
| loss.backward() | |
| # Update the model parameters | |
| optimizer.step() | |
| # Save the updated model | |
| torch.save(model.state_dict(), "best_model.pth") | |
| return "Модель успешно обновлена" | |
| demo = gr.TabbedInterface( | |
| [ | |
| gr.Interface( | |
| fn=classify_image, | |
| inputs=[ | |
| gr.Image(type="pil", label="Image") | |
| ], | |
| outputs=[ | |
| gr.Label(label="Предсказанный диагноз"), | |
| gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности") | |
| ], | |
| title="DermaScan Demo", | |
| description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.", | |
| allow_flagging=False | |
| ), | |
| gr.Interface( | |
| fn=update_model, | |
| inputs=[ | |
| gr.Image(type="pil", label="Image"), | |
| gr.Radio( | |
| choices=labels, | |
| type="value", | |
| label="Label", | |
| value=labels[0] | |
| ) | |
| ], | |
| outputs=[ | |
| gr.Textbox(label="Обновление модели") | |
| ], | |
| title="Обучить модель", | |
| description="Загрузите изображение и метку для обновления модели.", | |
| allow_flagging=False | |
| ) | |
| ], | |
| title="DermaScan Demo" | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |