Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,8 @@ import os
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import seaborn as sns
|
| 8 |
import numpy as np
|
| 9 |
-
import
|
|
|
|
| 10 |
|
| 11 |
# Authentification
|
| 12 |
login(token=os.environ["HF_TOKEN"])
|
|
@@ -17,12 +18,10 @@ models_info = {
|
|
| 17 |
"Llama 2": {
|
| 18 |
"7B": {"name": "meta-llama/Llama-2-7b-hf", "languages": ["en"]},
|
| 19 |
"13B": {"name": "meta-llama/Llama-2-13b-hf", "languages": ["en"]},
|
| 20 |
-
"70B": {"name": "meta-llama/Llama-2-70b-hf", "languages": ["en"]},
|
| 21 |
},
|
| 22 |
"Llama 3": {
|
| 23 |
-
"8B": {"name": "meta-llama/
|
| 24 |
"3.2-3B": {"name": "meta-llama/Llama-3.2-3B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
|
| 25 |
-
"3.1-8B": {"name": "meta-llama/Llama-3.1-8B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
|
| 26 |
},
|
| 27 |
},
|
| 28 |
"Mistral AI": {
|
|
@@ -37,8 +36,7 @@ models_info = {
|
|
| 37 |
"Google": {
|
| 38 |
"Gemma": {
|
| 39 |
"2B": {"name": "google/gemma-2-2b", "languages": ["en"]},
|
| 40 |
-
"
|
| 41 |
-
"27B": {"name": "google/gemma-2-27b", "languages": ["en"]},
|
| 42 |
},
|
| 43 |
},
|
| 44 |
"CroissantLLM": {
|
|
@@ -50,31 +48,29 @@ models_info = {
|
|
| 50 |
|
| 51 |
# Paramètres recommandés pour chaque modèle
|
| 52 |
model_parameters = {
|
| 53 |
-
"meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
|
| 54 |
"meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
|
| 55 |
-
"meta-llama/Llama-2-
|
| 56 |
-
"meta-llama/
|
| 57 |
"meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
|
| 58 |
-
"meta-llama/Llama-3.1-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
|
| 59 |
"mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
|
| 60 |
-
"mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
|
| 61 |
"mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
|
|
|
|
| 62 |
"google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
|
| 63 |
-
"google/gemma-2-
|
| 64 |
-
"google/gemma-2-27b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
|
| 65 |
"croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
|
| 66 |
}
|
| 67 |
|
| 68 |
# Variables globales
|
| 69 |
-
|
| 70 |
-
tokenizer = None
|
| 71 |
-
selected_language = None
|
| 72 |
|
|
|
|
| 73 |
def update_model_type(family):
|
| 74 |
return gr.Dropdown(choices=list(models_info[family].keys()), value=None, interactive=True)
|
| 75 |
|
| 76 |
def update_model_variation(family, model_type):
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
|
| 79 |
def update_selected_model(family, model_type, variation):
|
| 80 |
if family and model_type and variation:
|
|
@@ -82,83 +78,48 @@ def update_selected_model(family, model_type, variation):
|
|
| 82 |
return model_name, gr.Dropdown(choices=models_info[family][model_type][variation]["languages"], value=models_info[family][model_type][variation]["languages"][0], visible=True, interactive=True)
|
| 83 |
return "", gr.Dropdown(visible=False)
|
| 84 |
|
| 85 |
-
def
|
| 86 |
-
global model, tokenizer
|
| 87 |
try:
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
)
|
| 100 |
-
else:
|
| 101 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 102 |
-
model_name,
|
| 103 |
-
torch_dtype=torch.float16,
|
| 104 |
-
device_map="auto"
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
if tokenizer.pad_token is None:
|
| 108 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 109 |
-
|
| 110 |
-
progress(1.0, desc="Modèle chargé")
|
| 111 |
-
|
| 112 |
-
# Recherche des langues disponibles pour le modèle sélectionné
|
| 113 |
-
available_languages = next(
|
| 114 |
-
(info["languages"] for family in models_info.values()
|
| 115 |
-
for model_type in family.values()
|
| 116 |
-
for variation in model_type.values()
|
| 117 |
-
if variation["name"] == model_name),
|
| 118 |
-
["en"] # Défaut à l'anglais si non trouvé
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
# Mise à jour des sliders avec les valeurs recommandées
|
| 122 |
-
params = model_parameters[model_name]
|
| 123 |
-
return (
|
| 124 |
-
f"Modèle {model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
|
| 125 |
-
gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
|
| 126 |
-
params["temperature"],
|
| 127 |
-
params["top_p"],
|
| 128 |
-
params["top_k"]
|
| 129 |
-
)
|
| 130 |
except Exception as e:
|
| 131 |
-
return f"Erreur lors du chargement du modèle : {str(e)}"
|
| 132 |
|
| 133 |
def set_language(lang):
|
| 134 |
-
global selected_language
|
| 135 |
-
selected_language = lang
|
| 136 |
return f"Langue sélectionnée : {lang}"
|
| 137 |
|
| 138 |
-
def ensure_token_display(token):
|
| 139 |
-
"""Assure que le token est affiché correctement."""
|
| 140 |
if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
|
| 141 |
return tokenizer.decode([int(token)])
|
| 142 |
return token
|
| 143 |
|
| 144 |
-
def analyze_next_token(input_text, temperature, top_p, top_k):
|
| 145 |
-
|
|
|
|
| 146 |
|
| 147 |
-
|
| 148 |
-
return "Veuillez d'abord charger un modèle.", None, None
|
| 149 |
-
|
| 150 |
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
|
| 151 |
|
| 152 |
try:
|
|
|
|
| 153 |
with torch.no_grad():
|
| 154 |
outputs = model(**inputs)
|
| 155 |
|
| 156 |
last_token_logits = outputs.logits[0, -1, :]
|
| 157 |
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
| 158 |
|
| 159 |
-
top_k = 10
|
| 160 |
top_probs, top_indices = torch.topk(probabilities, top_k)
|
| 161 |
-
top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
|
| 162 |
prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
|
| 163 |
|
| 164 |
prob_text = "Prochains tokens les plus probables :\n\n"
|
|
@@ -166,80 +127,92 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
|
|
| 166 |
prob_text += f"{word}: {prob:.2%}\n"
|
| 167 |
|
| 168 |
prob_plot = plot_probabilities(prob_data)
|
| 169 |
-
attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu())
|
| 170 |
|
|
|
|
| 171 |
return prob_text, attention_plot, prob_plot
|
| 172 |
except Exception as e:
|
| 173 |
return f"Erreur lors de l'analyse : {str(e)}", None, None
|
| 174 |
|
| 175 |
-
def generate_text(input_text, temperature, top_p, top_k):
|
| 176 |
-
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
-
return "Veuillez d'abord charger un modèle."
|
| 180 |
-
|
| 181 |
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
|
| 182 |
|
| 183 |
try:
|
|
|
|
| 184 |
with torch.no_grad():
|
| 185 |
outputs = model.generate(
|
| 186 |
**inputs,
|
| 187 |
-
max_new_tokens=
|
| 188 |
temperature=temperature,
|
| 189 |
top_p=top_p,
|
| 190 |
top_k=top_k
|
| 191 |
)
|
| 192 |
|
| 193 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 194 |
return generated_text
|
| 195 |
except Exception as e:
|
| 196 |
return f"Erreur lors de la génération : {str(e)}"
|
| 197 |
|
| 198 |
def plot_probabilities(prob_data):
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
-
def plot_attention(input_ids, last_token_logits):
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
def reset():
|
| 239 |
-
global
|
| 240 |
-
model
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
| 243 |
return (
|
| 244 |
"", 1.0, 1.0, 50, None, None, None, None,
|
| 245 |
gr.Dropdown(choices=list(models_info.keys()), value=None, interactive=True),
|
|
@@ -248,92 +221,139 @@ def reset():
|
|
| 248 |
"", gr.Dropdown(visible=False), ""
|
| 249 |
)
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
with gr.Blocks() as demo:
|
| 252 |
gr.Markdown("# LLM&BIAS")
|
| 253 |
|
| 254 |
-
with gr.
|
| 255 |
-
with gr.
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
)
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
outputs=[
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
model_family, model_type, model_variation,
|
| 334 |
-
selected_model, language_dropdown, language_output
|
| 335 |
-
]
|
| 336 |
)
|
| 337 |
|
| 338 |
if __name__ == "__main__":
|
| 339 |
-
demo.launch()
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import seaborn as sns
|
| 8 |
import numpy as np
|
| 9 |
+
import asyncio
|
| 10 |
+
import gc
|
| 11 |
|
| 12 |
# Authentification
|
| 13 |
login(token=os.environ["HF_TOKEN"])
|
|
|
|
| 18 |
"Llama 2": {
|
| 19 |
"7B": {"name": "meta-llama/Llama-2-7b-hf", "languages": ["en"]},
|
| 20 |
"13B": {"name": "meta-llama/Llama-2-13b-hf", "languages": ["en"]},
|
|
|
|
| 21 |
},
|
| 22 |
"Llama 3": {
|
| 23 |
+
"8B": {"name": "meta-llama/Llama-3-8B", "languages": ["en"]},
|
| 24 |
"3.2-3B": {"name": "meta-llama/Llama-3.2-3B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
|
|
|
|
| 25 |
},
|
| 26 |
},
|
| 27 |
"Mistral AI": {
|
|
|
|
| 36 |
"Google": {
|
| 37 |
"Gemma": {
|
| 38 |
"2B": {"name": "google/gemma-2-2b", "languages": ["en"]},
|
| 39 |
+
"7B": {"name": "google/gemma-2-7b", "languages": ["en"]},
|
|
|
|
| 40 |
},
|
| 41 |
},
|
| 42 |
"CroissantLLM": {
|
|
|
|
| 48 |
|
| 49 |
# Paramètres recommandés pour chaque modèle
|
| 50 |
model_parameters = {
|
|
|
|
| 51 |
"meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
|
| 52 |
+
"meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
|
| 53 |
+
"meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
|
| 54 |
"meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
|
|
|
|
| 55 |
"mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
|
|
|
|
| 56 |
"mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
|
| 57 |
+
"mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
|
| 58 |
"google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
|
| 59 |
+
"google/gemma-2-7b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
|
|
|
|
| 60 |
"croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
|
| 61 |
}
|
| 62 |
|
| 63 |
# Variables globales
|
| 64 |
+
model_cache = {}
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
# Fonctions utilitaires
|
| 67 |
def update_model_type(family):
|
| 68 |
return gr.Dropdown(choices=list(models_info[family].keys()), value=None, interactive=True)
|
| 69 |
|
| 70 |
def update_model_variation(family, model_type):
|
| 71 |
+
if family and model_type:
|
| 72 |
+
return gr.Dropdown(choices=list(models_info[family][model_type].keys()), value=None, interactive=True)
|
| 73 |
+
return gr.Dropdown(choices=[], value=None, interactive=False)
|
| 74 |
|
| 75 |
def update_selected_model(family, model_type, variation):
|
| 76 |
if family and model_type and variation:
|
|
|
|
| 78 |
return model_name, gr.Dropdown(choices=models_info[family][model_type][variation]["languages"], value=models_info[family][model_type][variation]["languages"][0], visible=True, interactive=True)
|
| 79 |
return "", gr.Dropdown(visible=False)
|
| 80 |
|
| 81 |
+
async def load_model_async(model_name, progress=gr.Progress()):
|
|
|
|
| 82 |
try:
|
| 83 |
+
if model_name not in model_cache:
|
| 84 |
+
progress(0.1, f"Chargement du tokenizer pour {model_name}...")
|
| 85 |
+
tokenizer = await asyncio.to_thread(AutoTokenizer.from_pretrained, model_name)
|
| 86 |
+
progress(0.4, f"Chargement du modèle {model_name}...")
|
| 87 |
+
model = await asyncio.to_thread(AutoModelForCausalLM.from_pretrained, model_name,
|
| 88 |
+
torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True)
|
| 89 |
+
if tokenizer.pad_token is None:
|
| 90 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 91 |
+
model_cache[model_name] = (model, tokenizer)
|
| 92 |
+
progress(1.0, f"Modèle {model_name} chargé avec succès")
|
| 93 |
+
return f"Modèle {model_name} chargé avec succès"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
except Exception as e:
|
| 95 |
+
return f"Erreur lors du chargement du modèle {model_name} : {str(e)}"
|
| 96 |
|
| 97 |
def set_language(lang):
|
|
|
|
|
|
|
| 98 |
return f"Langue sélectionnée : {lang}"
|
| 99 |
|
| 100 |
+
def ensure_token_display(token, tokenizer):
|
|
|
|
| 101 |
if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
|
| 102 |
return tokenizer.decode([int(token)])
|
| 103 |
return token
|
| 104 |
|
| 105 |
+
async def analyze_next_token(model_name, input_text, temperature, top_p, top_k, progress=gr.Progress()):
|
| 106 |
+
if model_name not in model_cache:
|
| 107 |
+
return "Veuillez d'abord charger le modèle", None, None
|
| 108 |
|
| 109 |
+
model, tokenizer = model_cache[model_name]
|
|
|
|
|
|
|
| 110 |
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
|
| 111 |
|
| 112 |
try:
|
| 113 |
+
progress(0.5, "Analyse en cours...")
|
| 114 |
with torch.no_grad():
|
| 115 |
outputs = model(**inputs)
|
| 116 |
|
| 117 |
last_token_logits = outputs.logits[0, -1, :]
|
| 118 |
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
| 119 |
|
| 120 |
+
top_k = min(10, top_k)
|
| 121 |
top_probs, top_indices = torch.topk(probabilities, top_k)
|
| 122 |
+
top_words = [ensure_token_display(tokenizer.decode([idx.item()]), tokenizer) for idx in top_indices]
|
| 123 |
prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
|
| 124 |
|
| 125 |
prob_text = "Prochains tokens les plus probables :\n\n"
|
|
|
|
| 127 |
prob_text += f"{word}: {prob:.2%}\n"
|
| 128 |
|
| 129 |
prob_plot = plot_probabilities(prob_data)
|
| 130 |
+
attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu(), tokenizer)
|
| 131 |
|
| 132 |
+
progress(1.0, "Analyse terminée")
|
| 133 |
return prob_text, attention_plot, prob_plot
|
| 134 |
except Exception as e:
|
| 135 |
return f"Erreur lors de l'analyse : {str(e)}", None, None
|
| 136 |
|
| 137 |
+
async def generate_text(model_name, input_text, temperature, top_p, top_k, progress=gr.Progress()):
|
| 138 |
+
if model_name not in model_cache:
|
| 139 |
+
return "Veuillez d'abord charger le modèle"
|
| 140 |
|
| 141 |
+
model, tokenizer = model_cache[model_name]
|
|
|
|
|
|
|
| 142 |
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
|
| 143 |
|
| 144 |
try:
|
| 145 |
+
progress(0.5, "Génération en cours...")
|
| 146 |
with torch.no_grad():
|
| 147 |
outputs = model.generate(
|
| 148 |
**inputs,
|
| 149 |
+
max_new_tokens=50,
|
| 150 |
temperature=temperature,
|
| 151 |
top_p=top_p,
|
| 152 |
top_k=top_k
|
| 153 |
)
|
| 154 |
|
| 155 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 156 |
+
progress(1.0, "Génération terminée")
|
| 157 |
return generated_text
|
| 158 |
except Exception as e:
|
| 159 |
return f"Erreur lors de la génération : {str(e)}"
|
| 160 |
|
| 161 |
def plot_probabilities(prob_data):
|
| 162 |
+
try:
|
| 163 |
+
words = list(prob_data.keys())
|
| 164 |
+
probs = list(prob_data.values())
|
| 165 |
+
|
| 166 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 167 |
+
bars = ax.bar(range(len(words)), probs, color='lightgreen')
|
| 168 |
+
ax.set_title("Probabilités des tokens suivants les plus probables")
|
| 169 |
+
ax.set_xlabel("Tokens")
|
| 170 |
+
ax.set_ylabel("Probabilité")
|
| 171 |
+
|
| 172 |
+
ax.set_xticks(range(len(words)))
|
| 173 |
+
ax.set_xticklabels(words, rotation=45, ha='right')
|
| 174 |
+
|
| 175 |
+
for i, (bar, word) in enumerate(zip(bars, words)):
|
| 176 |
+
height = bar.get_height()
|
| 177 |
+
ax.text(i, height, f'{height:.2%}',
|
| 178 |
+
ha='center', va='bottom', rotation=0)
|
| 179 |
+
|
| 180 |
+
plt.tight_layout()
|
| 181 |
+
return fig
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(f"Erreur lors de la création du graphique : {str(e)}")
|
| 184 |
+
return None
|
| 185 |
|
| 186 |
+
def plot_attention(input_ids, last_token_logits, tokenizer):
|
| 187 |
+
try:
|
| 188 |
+
input_tokens = [ensure_token_display(tokenizer.decode([id]), tokenizer) for id in input_ids]
|
| 189 |
+
attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
|
| 190 |
+
top_k = min(len(input_tokens), 10)
|
| 191 |
+
top_attention_scores, _ = torch.topk(attention_scores, top_k)
|
| 192 |
+
|
| 193 |
+
fig, ax = plt.subplots(figsize=(14, 7))
|
| 194 |
+
sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
|
| 195 |
+
ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
|
| 196 |
+
ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
|
| 197 |
+
ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
|
| 198 |
+
|
| 199 |
+
cbar = ax.collections[0].colorbar
|
| 200 |
+
cbar.set_label("Score d'attention", fontsize=12)
|
| 201 |
+
cbar.ax.tick_params(labelsize=10)
|
| 202 |
+
|
| 203 |
+
plt.tight_layout()
|
| 204 |
+
return fig
|
| 205 |
+
except Exception as e:
|
| 206 |
+
print(f"Erreur lors de la création du graphique d'attention : {str(e)}")
|
| 207 |
+
return None
|
| 208 |
|
| 209 |
def reset():
|
| 210 |
+
global model_cache
|
| 211 |
+
for model in model_cache.values():
|
| 212 |
+
del model
|
| 213 |
+
model_cache.clear()
|
| 214 |
+
torch.cuda.empty_cache()
|
| 215 |
+
gc.collect()
|
| 216 |
return (
|
| 217 |
"", 1.0, 1.0, 50, None, None, None, None,
|
| 218 |
gr.Dropdown(choices=list(models_info.keys()), value=None, interactive=True),
|
|
|
|
| 221 |
"", gr.Dropdown(visible=False), ""
|
| 222 |
)
|
| 223 |
|
| 224 |
+
def reset_comparison():
|
| 225 |
+
return [gr.Dropdown(choices=[], value=None) for _ in range(4)] + ["", "", gr.Dropdown(choices=[], value=None), 1.0, 1.0, 50, "", "", None, None, None, None]
|
| 226 |
+
|
| 227 |
+
async def compare_models(model1, model2, input_text, temp, top_p, top_k, progress=gr.Progress()):
|
| 228 |
+
if model1 not in model_cache or model2 not in model_cache:
|
| 229 |
+
return "Veuillez d'abord charger les deux modèles", "", None, None, None, None
|
| 230 |
+
|
| 231 |
+
progress(0.1, "Analyse du premier modèle...")
|
| 232 |
+
results1 = await analyze_next_token(model1, input_text, temp, top_p, top_k)
|
| 233 |
+
progress(0.4, "Analyse du second modèle...")
|
| 234 |
+
results2 = await analyze_next_token(model2, input_text, temp, top_p, top_k)
|
| 235 |
+
progress(1.0, "Comparaison terminée")
|
| 236 |
+
return (
|
| 237 |
+
results1[0], results2[0], # Probabilités du prochain token
|
| 238 |
+
results1[2], results2[2], # Graphiques de probabilités
|
| 239 |
+
results1[1], results2[1] # Graphiques d'attention
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
with gr.Blocks() as demo:
|
| 243 |
gr.Markdown("# LLM&BIAS")
|
| 244 |
|
| 245 |
+
with gr.Tabs():
|
| 246 |
+
with gr.Tab("Analyse individuelle"):
|
| 247 |
+
with gr.Accordion("Sélection du modèle", open=True):
|
| 248 |
+
with gr.Row():
|
| 249 |
+
model_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille de modèle", interactive=True)
|
| 250 |
+
model_type = gr.Dropdown(choices=[], label="Type de modèle", interactive=False)
|
| 251 |
+
model_variation = gr.Dropdown(choices=[], label="Variation du modèle", interactive=False)
|
| 252 |
+
|
| 253 |
+
selected_model = gr.Textbox(label="Modèle sélectionné", interactive=False)
|
| 254 |
+
load_button = gr.Button("Charger le modèle")
|
| 255 |
+
load_output = gr.Textbox(label="Statut du chargement")
|
| 256 |
+
language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
|
| 257 |
+
language_output = gr.Textbox(label="Langue sélectionnée")
|
| 258 |
+
|
| 259 |
+
with gr.Row():
|
| 260 |
+
temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
|
| 261 |
+
top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
|
| 262 |
+
top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
|
| 263 |
+
|
| 264 |
+
input_text = gr.Textbox(label="Texte d'entrée", lines=3)
|
| 265 |
+
analyze_button = gr.Button("Analyser le prochain token")
|
| 266 |
+
|
| 267 |
+
next_token_probs = gr.Textbox(label="Probabilités du prochain token")
|
| 268 |
+
|
| 269 |
+
with gr.Row():
|
| 270 |
+
attention_plot = gr.Plot(label="Visualisation de l'attention")
|
| 271 |
+
prob_plot = gr.Plot(label="Probabilités des tokens suivants")
|
| 272 |
+
|
| 273 |
+
generate_button = gr.Button("Générer le texte")
|
| 274 |
+
generated_text = gr.Textbox(label="Texte généré")
|
| 275 |
+
|
| 276 |
+
reset_button = gr.Button("Réinitialiser")
|
| 277 |
+
|
| 278 |
+
with gr.Tab("Comparaison de modèles"):
|
| 279 |
+
with gr.Row():
|
| 280 |
+
model1_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille du modèle 1", interactive=True)
|
| 281 |
+
model1_type = gr.Dropdown(choices=[], label="Type du modèle 1", interactive=False)
|
| 282 |
+
model1_variation = gr.Dropdown(choices=[], label="Variation du modèle 1", interactive=False)
|
| 283 |
+
|
| 284 |
+
with gr.Row():
|
| 285 |
+
model2_family = gr.Dropdown(choices=list(models_info.keys()), label="Famille du modèle 2", interactive=True)
|
| 286 |
+
model2_type = gr.Dropdown(choices=[], label="Type du modèle 2", interactive=False)
|
| 287 |
+
model2_variation = gr.Dropdown(choices=[], label="Variation du modèle 2", interactive=False)
|
| 288 |
+
|
| 289 |
+
model1_selected = gr.Textbox(label="Modèle 1 sélectionné", interactive=False)
|
| 290 |
+
model2_selected = gr.Textbox(label="Modèle 2 sélectionné", interactive=False)
|
| 291 |
+
|
| 292 |
+
load_models_button = gr.Button("Charger les modèles")
|
| 293 |
+
load_models_output = gr.Textbox(label="Statut du chargement des modèles")
|
| 294 |
+
|
| 295 |
+
comparison_language = gr.Dropdown(label="Langue pour la comparaison", choices=[], interactive=False)
|
| 296 |
+
|
| 297 |
+
with gr.Row():
|
| 298 |
+
comp_temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
|
| 299 |
+
comp_top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
|
| 300 |
+
comp_top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
|
| 301 |
+
|
| 302 |
+
comp_input_text = gr.Textbox(label="Texte d'entrée pour la comparaison", lines=3)
|
| 303 |
+
compare_button = gr.Button("Comparer les modèles")
|
| 304 |
+
|
| 305 |
+
with gr.Row():
|
| 306 |
+
model1_output = gr.Textbox(label="Probabilités du Modèle 1", lines=10)
|
| 307 |
+
model2_output = gr.Textbox(label="Probabilités du Modèle 2", lines=10)
|
| 308 |
+
|
| 309 |
+
with gr.Row():
|
| 310 |
+
model1_prob_plot = gr.Plot(label="Probabilités des tokens (Modèle 1)")
|
| 311 |
+
model2_prob_plot = gr.Plot(label="Probabilités des tokens (Modèle 2)")
|
| 312 |
+
|
| 313 |
+
with gr.Row():
|
| 314 |
+
model1_attention_plot = gr.Plot(label="Attention (Modèle 1)")
|
| 315 |
+
model2_attention_plot = gr.Plot(label="Attention (Modèle 2)")
|
| 316 |
+
|
| 317 |
+
comp_reset_button = gr.Button("Réinitialiser la comparaison")
|
| 318 |
+
|
| 319 |
+
# Événements pour l'onglet d'analyse individuelle
|
| 320 |
+
model_family.change(update_model_type, inputs=[model_family], outputs=[model_type])
|
| 321 |
+
model_type.change(update_model_variation, inputs=[model_family, model_type], outputs=[model_variation])
|
| 322 |
+
model_variation.change(update_selected_model, inputs=[model_family, model_type, model_variation], outputs=[selected_model, language_dropdown])
|
| 323 |
+
load_button.click(load_model_async, inputs=[selected_model], outputs=[load_output])
|
| 324 |
+
language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
|
| 325 |
+
analyze_button.click(analyze_next_token, inputs=[selected_model, input_text, temperature, top_p, top_k], outputs=[next_token_probs, attention_plot, prob_plot])
|
| 326 |
+
generate_button.click(generate_text, inputs=[selected_model, input_text, temperature, top_p, top_k], outputs=[generated_text])
|
| 327 |
+
reset_button.click(reset, outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, model_family, model_type, model_variation, selected_model, language_dropdown, language_output])
|
| 328 |
+
|
| 329 |
+
# Événements pour l'onglet de comparaison
|
| 330 |
+
model1_family.change(update_model_type, inputs=[model1_family], outputs=[model1_type])
|
| 331 |
+
model1_type.change(update_model_variation, inputs=[model1_family, model1_type], outputs=[model1_variation])
|
| 332 |
+
model1_variation.change(update_selected_model, inputs=[model1_family, model1_type, model1_variation], outputs=[model1_selected, comparison_language])
|
| 333 |
+
|
| 334 |
+
model2_family.change(update_model_type, inputs=[model2_family], outputs=[model2_type])
|
| 335 |
+
model2_type.change(update_model_variation, inputs=[model2_family, model2_type], outputs=[model2_variation])
|
| 336 |
+
model2_variation.change(update_selected_model, inputs=[model2_family, model2_type, model2_variation], outputs=[model2_selected, comparison_language])
|
| 337 |
+
|
| 338 |
+
async def load_both_models(model1, model2):
|
| 339 |
+
result1 = await load_model_async(model1)
|
| 340 |
+
result2 = await load_model_async(model2)
|
| 341 |
+
return f"Modèle 1: {result1}\nModèle 2: {result2}"
|
| 342 |
+
|
| 343 |
+
load_models_button.click(load_both_models, inputs=[model1_selected, model2_selected], outputs=[load_models_output])
|
| 344 |
+
|
| 345 |
+
compare_button.click(
|
| 346 |
+
compare_models,
|
| 347 |
+
inputs=[model1_selected, model2_selected, comp_input_text, comp_temperature, comp_top_p, comp_top_k],
|
| 348 |
+
outputs=[model1_output, model2_output, model1_prob_plot, model2_prob_plot, model1_attention_plot, model2_attention_plot]
|
| 349 |
)
|
| 350 |
+
|
| 351 |
+
comp_reset_button.click(
|
| 352 |
+
reset_comparison,
|
| 353 |
+
outputs=[model1_type, model1_variation, model2_type, model2_variation, model1_selected, model2_selected, comparison_language,
|
| 354 |
+
comp_temperature, comp_top_p, comp_top_k, comp_input_text, model1_output, model2_output,
|
| 355 |
+
model1_prob_plot, model2_prob_plot, model1_attention_plot, model2_attention_plot]
|
|
|
|
|
|
|
|
|
|
| 356 |
)
|
| 357 |
|
| 358 |
if __name__ == "__main__":
|
| 359 |
+
demo.launch()
|