Spaces:
Sleeping
Sleeping
Update app.py
Browse filesRussian Translation Added
app.py
CHANGED
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
| 9 |
from PIL import Image
|
| 10 |
model_name = "./best_model"
|
| 11 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
| 12 |
-
labels = ['
|
| 13 |
|
| 14 |
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
| 15 |
def forward(self, pixel_values):
|
|
@@ -44,14 +44,14 @@ def classify_image(image):
|
|
| 44 |
# Create a bar chart
|
| 45 |
fig_bar = go.Figure(
|
| 46 |
data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
|
| 47 |
-
fig_bar.update_layout(title="
|
| 48 |
-
yaxis_title="
|
| 49 |
|
| 50 |
# Create a heatmap
|
| 51 |
if attention is not None:
|
| 52 |
fig_heatmap = go.Figure(
|
| 53 |
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
|
| 54 |
-
fig_heatmap.update_layout(title="
|
| 55 |
else:
|
| 56 |
fig_heatmap = go.Figure() # Return an empty plot
|
| 57 |
|
|
@@ -66,7 +66,7 @@ def classify_image(image):
|
|
| 66 |
heatmap_color[:, :, 0] = heatmap # Red channel
|
| 67 |
heatmap_color[:, :, 1] = heatmap # Green channel
|
| 68 |
heatmap_color[:, :, 2] = 0 # Blue channel
|
| 69 |
-
attention_overlay = (img_array * 0.
|
| 70 |
attention_overlay = Image.fromarray(attention_overlay)
|
| 71 |
attention_overlay.save("attention_overlay.png")
|
| 72 |
attention_overlay = gr.Image("attention_overlay.png")
|
|
@@ -102,7 +102,7 @@ def update_model(image, label):
|
|
| 102 |
# Save the updated model
|
| 103 |
torch.save(model.state_dict(), "best_model.pth")
|
| 104 |
|
| 105 |
-
return "
|
| 106 |
|
| 107 |
|
| 108 |
demo = gr.TabbedInterface(
|
|
@@ -113,11 +113,11 @@ demo = gr.TabbedInterface(
|
|
| 113 |
gr.Image(type="pil", label="Image")
|
| 114 |
],
|
| 115 |
outputs=[
|
| 116 |
-
gr.Label(label="
|
| 117 |
-
gr.Plot(label="
|
| 118 |
],
|
| 119 |
-
title="
|
| 120 |
-
description="
|
| 121 |
allow_flagging=False
|
| 122 |
),
|
| 123 |
gr.Interface(
|
|
@@ -132,14 +132,14 @@ demo = gr.TabbedInterface(
|
|
| 132 |
)
|
| 133 |
],
|
| 134 |
outputs=[
|
| 135 |
-
gr.Textbox(label="
|
| 136 |
],
|
| 137 |
-
title="
|
| 138 |
-
description="
|
| 139 |
allow_flagging=False
|
| 140 |
)
|
| 141 |
],
|
| 142 |
-
title="
|
| 143 |
)
|
| 144 |
|
| 145 |
if __name__ == "__main__":
|
|
|
|
| 9 |
from PIL import Image
|
| 10 |
model_name = "./best_model"
|
| 11 |
processor = ViTImageProcessor.from_pretrained(model_name)
|
| 12 |
+
labels = ['Акне или розацеа', 'Актинический кератоз, базальноклеточная карцинома и другие злокачественные поражения', 'Атопический дерматит', 'Буллезное заболевание', 'Целлюлит, импетиго и другие бактериальные инфекции', 'Контактный дерматит', 'Экзема', 'Экзантемы и лекарственные высыпания', 'Фотографии потери волос, алопеция и другие заболевания волос', 'Герпес, ВПЧ и другие ЗППП', 'Легкие заболевания и нарушения пигментации', 'Волчанка и другие заболевания соединительной ткани', 'Меланома, рак кожи, невусы и родинки', 'Грибок ногтей и другие заболевания ногтей', 'Фотографии псориаза, красный плоский лишай и связанные с ним заболевания', 'Чесотка, болезнь Лайма и другие инвазии и укусы', 'Себорейный кератоз и другие Доброкачественные опухоли', 'Системные заболевания', 'Опоясывающий лишай, кандидоз и другие грибковые инфекции', 'Крапивница', 'Сосудистые опухоли', 'Васкулит', 'Бородавки, моллюск и другие вирусные инфекции']
|
| 13 |
|
| 14 |
class ViTForImageClassificationWithAttention(ViTForImageClassification):
|
| 15 |
def forward(self, pixel_values):
|
|
|
|
| 44 |
# Create a bar chart
|
| 45 |
fig_bar = go.Figure(
|
| 46 |
data=[go.Bar(x=[label for label, prob in top_k_label_probs], y=[prob for label, prob in top_k_label_probs])])
|
| 47 |
+
fig_bar.update_layout(title="Топ 5 диагнозов в порядке убывания вероятности", xaxis_title="Диагноз",
|
| 48 |
+
yaxis_title="Вероятность")
|
| 49 |
|
| 50 |
# Create a heatmap
|
| 51 |
if attention is not None:
|
| 52 |
fig_heatmap = go.Figure(
|
| 53 |
data=[go.Heatmap(z=attention[0][0, 0, :, :].detach().numpy(), colorscale='Viridis', showscale=False)])
|
| 54 |
+
fig_heatmap.update_layout(title="Карта внимания системы")
|
| 55 |
else:
|
| 56 |
fig_heatmap = go.Figure() # Return an empty plot
|
| 57 |
|
|
|
|
| 66 |
heatmap_color[:, :, 0] = heatmap # Red channel
|
| 67 |
heatmap_color[:, :, 1] = heatmap # Green channel
|
| 68 |
heatmap_color[:, :, 2] = 0 # Blue channel
|
| 69 |
+
attention_overlay = (img_array * 0.35 + heatmap_color * 0.75).astype(np.uint8)
|
| 70 |
attention_overlay = Image.fromarray(attention_overlay)
|
| 71 |
attention_overlay.save("attention_overlay.png")
|
| 72 |
attention_overlay = gr.Image("attention_overlay.png")
|
|
|
|
| 102 |
# Save the updated model
|
| 103 |
torch.save(model.state_dict(), "best_model.pth")
|
| 104 |
|
| 105 |
+
return "Модель успешно обновлена"
|
| 106 |
|
| 107 |
|
| 108 |
demo = gr.TabbedInterface(
|
|
|
|
| 113 |
gr.Image(type="pil", label="Image")
|
| 114 |
],
|
| 115 |
outputs=[
|
| 116 |
+
gr.Label(label="Предсказанный диагноз"),
|
| 117 |
+
gr.Plot(label="Топ 5 диагнозов в порядке убывания вероятности")
|
| 118 |
],
|
| 119 |
+
title="DermaScan Demo",
|
| 120 |
+
description="Загрузите изображение, чтобы увидеть прогнозируемую метку класса, 5 лучших прогнозируемых меток с вероятностями и тепловую карту внимания.",
|
| 121 |
allow_flagging=False
|
| 122 |
),
|
| 123 |
gr.Interface(
|
|
|
|
| 132 |
)
|
| 133 |
],
|
| 134 |
outputs=[
|
| 135 |
+
gr.Textbox(label="Обновление модели")
|
| 136 |
],
|
| 137 |
+
title="Обучить модель",
|
| 138 |
+
description="Загрузите изображение и метку для обновления модели.",
|
| 139 |
allow_flagging=False
|
| 140 |
)
|
| 141 |
],
|
| 142 |
+
title="DermaScan Demo"
|
| 143 |
)
|
| 144 |
|
| 145 |
if __name__ == "__main__":
|