|
|
--- |
|
|
license: mit |
|
|
language: |
|
|
- ru |
|
|
tags: |
|
|
- Prompt |
|
|
- Prompt_Classification |
|
|
- Classification |
|
|
- LinearModel |
|
|
- MLP |
|
|
- class |
|
|
- Prompt Classes |
|
|
- Classificator |
|
|
- Prompt Classification |
|
|
- AI |
|
|
- ML |
|
|
- Class |
|
|
- Classify |
|
|
- Text |
|
|
- Context |
|
|
--- |
|
|
|
|
|
# 🔁 SimplePromptClassifier — классификатор промптов (русский) |
|
|
|
|
|
 |
|
|
|
|
|
|
|
|
**Кратко:** модель классифицирует входные промпты/вопросы на три действия: |
|
|
- **0 — Поиск в локальной базе знаний (RAG)**: сначала ищем релевантные документы в локальном индексе и формируем контекст для генерации. |
|
|
- **1 — Поиск в сети**: триггер запуска обхода внешних поисковых систем/скрейпинга. |
|
|
- **2 — Прямой запрос**: сразу посылаем промпт в генеративную модель (например, LLM) для синтеза ответа. |
|
|
|
|
|
--- |
|
|
|
|
|
## Где используется |
|
|
Подходит для систем, где нужно автоматически решать стратегию обработки пользовательского промпта: |
|
|
- чат-боты со связкой Retrieval-Augmented Generation (RAG), |
|
|
- голосовые ассистенты, |
|
|
- интерфейсы поддержки, где часть запросов решается поиском, часть — генерацией. |
|
|
|
|
|
--- |
|
|
|
|
|
## Файлы в репозитории |
|
|
- `pytorch_model.bin` — веса модели (state_dict). |
|
|
- `config.json` — конфигурация (input_dim, num_classes, p_dropout, classes). |
|
|
- `modeling_simple_classifier.py` — определение архитектуры. |
|
|
- `vectorizer.pkl` — sklearn-векторизатор (TF-IDF/Count). |
|
|
- `svd.pkl` — TruncatedSVD (опционально). |
|
|
- `label_encoder.pkl` — sklearn.LabelEncoder (для декодирования метки). |
|
|
- `README.md` — эта карточка. |
|
|
|
|
|
--- |
|
|
|
|
|
## Пример загрузки и инференса (без AutoModel) |
|
|
|
|
|
```python |
|
|
# Пример: загрузка напрямую из репозитория HF (не требует локальной копии) |
|
|
from huggingface_hub import hf_hub_download |
|
|
import json, pickle, torch |
|
|
import numpy as np |
|
|
from types import SimpleNamespace |
|
|
|
|
|
REPO = "Neweret/SimplePromptClassifier-85k" |
|
|
|
|
|
config_path = hf_hub_download(REPO, "config.json") |
|
|
weights_path = hf_hub_download(REPO, "pytorch_model.bin") |
|
|
vec_path = hf_hub_download(REPO, "vectorizer.pkl") |
|
|
svd_path = None |
|
|
try: |
|
|
svd_path = hf_hub_download(REPO, "svd.pkl") |
|
|
except Exception: |
|
|
svd_path = None |
|
|
le_path = hf_hub_download(REPO, "label_encoder.pkl") |
|
|
|
|
|
cfg = SimpleNamespace(**json.load(open(config_path, "r", encoding="utf-8"))) |
|
|
|
|
|
# --- Динамическая модель --- |
|
|
class SimpleClassifier(torch.nn.Module): |
|
|
def __init__(self, input_dim, num_classes, p_dropout=0.3): |
|
|
super().__init__() |
|
|
self.linear1 = torch.nn.Linear(input_dim, 256) |
|
|
self.ln1 = torch.nn.LayerNorm(256) |
|
|
self.dropout = torch.nn.Dropout(p_dropout) |
|
|
self.linear2 = torch.nn.Linear(256, 128) |
|
|
self.ln2 = torch.nn.LayerNorm(128) |
|
|
self.linear_out = torch.nn.Linear(128, num_classes) |
|
|
def forward(self, x): |
|
|
x = torch.nn.functional.gelu(self.ln1(self.linear1(x))) |
|
|
x = self.dropout(x) |
|
|
x = torch.nn.functional.gelu(self.ln2(self.linear2(x))) |
|
|
x = self.dropout(x) |
|
|
return self.linear_out(x) |
|
|
|
|
|
model = SimpleClassifier(cfg.input_dim, cfg.num_classes, cfg.p_dropout) |
|
|
state = torch.load(weights_path, map_location="cpu") |
|
|
model.load_state_dict(state) |
|
|
model.eval() |
|
|
|
|
|
# препроцессинг |
|
|
vectorizer = pickle.load(open(vec_path, "rb")) |
|
|
svd = pickle.load(open(svd_path, "rb")) if svd_path else None |
|
|
le = pickle.load(open(le_path, "rb")) |
|
|
|
|
|
def preprocess(text): |
|
|
X = vectorizer.transform([text]) |
|
|
if svd is not None: |
|
|
X = svd.transform(X) |
|
|
return X.astype(np.float32) |
|
|
|
|
|
def predict(text): |
|
|
x = preprocess(text) |
|
|
xb = torch.from_numpy(x).float() |
|
|
with torch.inference_mode(): |
|
|
logits = model(xb) |
|
|
pred = int(torch.argmax(logits, dim=1).cpu().numpy()[0]) |
|
|
return pred, le.inverse_transform([pred])[0] |
|
|
|
|
|
# пример |
|
|
print(predict("Как мне найти документацию по нашей компании?")) |