File size: 4,500 Bytes
b05cd55
 
 
 
519e6e8
 
 
 
 
 
 
 
 
 
a0f500d
 
 
 
 
 
1f23750
 
27200b9
1f23750
bc1efbe
22e528d
a0f500d
1f23750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24966f
1f23750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27200b9
1f23750
 
 
 
 
 
 
 
 
 
 
 
 
7467021
1f23750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0f500d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
---
license: mit
language:
- ru
tags:
- Prompt
- Prompt_Classification
- Classification
- LinearModel
- MLP
- class
- Prompt Classes
- Classificator
- Prompt Classification
- AI
- ML
- Class
- Classify
- Text
- Context
---

# 🔁 SimplePromptClassifier — классификатор промптов (русский)

![Model banner](./AI_scheme.png)


**Кратко:** модель классифицирует входные промпты/вопросы на три действия:
- **0 — Поиск в локальной базе знаний (RAG)**: сначала ищем релевантные документы в локальном индексе и формируем контекст для генерации.  
- **1 — Поиск в сети**: триггер запуска обхода внешних поисковых систем/скрейпинга.  
- **2 — Прямой запрос**: сразу посылаем промпт в генеративную модель (например, LLM) для синтеза ответа.

---

## Где используется
Подходит для систем, где нужно автоматически решать стратегию обработки пользовательского промпта:
- чат-боты со связкой Retrieval-Augmented Generation (RAG),
- голосовые ассистенты,
- интерфейсы поддержки, где часть запросов решается поиском, часть — генерацией.

---

## Файлы в репозитории
- `pytorch_model.bin` — веса модели (state_dict).  
- `config.json` — конфигурация (input_dim, num_classes, p_dropout, classes).  
- `modeling_simple_classifier.py` — определение архитектуры.  
- `vectorizer.pkl` — sklearn-векторизатор (TF-IDF/Count).  
- `svd.pkl` — TruncatedSVD (опционально).  
- `label_encoder.pkl` — sklearn.LabelEncoder (для декодирования метки).  
- `README.md` — эта карточка.

---

## Пример загрузки и инференса (без AutoModel)

```python
# Пример: загрузка напрямую из репозитория HF (не требует локальной копии)
from huggingface_hub import hf_hub_download
import json, pickle, torch
import numpy as np
from types import SimpleNamespace

REPO = "Neweret/SimplePromptClassifier-85k"

config_path = hf_hub_download(REPO, "config.json")
weights_path = hf_hub_download(REPO, "pytorch_model.bin")
vec_path = hf_hub_download(REPO, "vectorizer.pkl")
svd_path = None
try:
    svd_path = hf_hub_download(REPO, "svd.pkl")
except Exception:
    svd_path = None
le_path = hf_hub_download(REPO, "label_encoder.pkl")

cfg = SimpleNamespace(**json.load(open(config_path, "r", encoding="utf-8")))

# --- Динамическая модель ---
class SimpleClassifier(torch.nn.Module):
    def __init__(self, input_dim, num_classes, p_dropout=0.3):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, 256)
        self.ln1 = torch.nn.LayerNorm(256)
        self.dropout = torch.nn.Dropout(p_dropout)
        self.linear2 = torch.nn.Linear(256, 128)
        self.ln2 = torch.nn.LayerNorm(128)
        self.linear_out = torch.nn.Linear(128, num_classes)
    def forward(self, x):
        x = torch.nn.functional.gelu(self.ln1(self.linear1(x)))
        x = self.dropout(x)
        x = torch.nn.functional.gelu(self.ln2(self.linear2(x)))
        x = self.dropout(x)
        return self.linear_out(x)

model = SimpleClassifier(cfg.input_dim, cfg.num_classes, cfg.p_dropout)
state = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state)
model.eval()

# препроцессинг
vectorizer = pickle.load(open(vec_path, "rb"))
svd = pickle.load(open(svd_path, "rb")) if svd_path else None
le = pickle.load(open(le_path, "rb"))

def preprocess(text):
    X = vectorizer.transform([text])
    if svd is not None:
        X = svd.transform(X)
    return X.astype(np.float32)

def predict(text):
    x = preprocess(text)
    xb = torch.from_numpy(x).float()
    with torch.inference_mode():
        logits = model(xb)
        pred = int(torch.argmax(logits, dim=1).cpu().numpy()[0])
    return pred, le.inverse_transform([pred])[0]

# пример
print(predict("Как мне найти документацию по нашей компании?"))