import os, json, re, random, time, shutil, threading
import gradio as gr
from datasets import load_dataset, Dataset, concatenate_datasets
from huggingface_hub import HfApi, create_repo, upload_folder, whoami
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig, TrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer
# ---------- Defaults & env ----------
BASE = os.getenv("BASE", "meta-llama/Llama-3.2-3B-Instruct")
OUT_REPO = os.getenv("OUT_REPO", "your-username/llama32-3b-thinking")
HF_TOKEN = os.getenv("HF_TOKEN", None)
random.seed(17)
# ---------- Helpers ----------
def _ok(s): return gr.update(value=s, visible=True)
def try_load(options, **kw):
for dsid in options:
try:
return load_dataset(dsid, **kw)
except Exception:
continue
raise RuntimeError(f"Failed loading any of: {options}")
def trim_text(txt, max_words=220):
w = (txt or "").split()
return " ".join(w[:max_words])
def pack_record(instruction, rationale, final, inp=""):
rationale = trim_text(rationale, 220)
if len(rationale.split()) < 3: # drop trivial
return None
return {
"instruction": instruction.strip(),
"input": (inp or "").strip(),
"rationale": rationale.strip(),
"final": (final or "").strip()
}
def build_hotpot_rationale(supporting_facts, context, answer):
m = {title: sents for title, sents in context}
bits = []
for title, idx in supporting_facts[:3]:
try:
s = m[title][idx]
bits.append(f"[{title}] {s}")
except Exception:
pass
if not bits: return None
return " ".join(bits) + f" ⇒ {answer}"
# ---------- Dataset loaders (blend) ----------
def load_cose():
ds = try_load(["Salesforce/cos_e", "cos_e"], name="v1.11")["train"]
rows=[]
for ex in ds:
choices = ex.get("choices") or ex.get("options") or []
rec = pack_record(
instruction=f"Q: {ex['question']}\nOptions: {', '.join(choices)}",
rationale=ex.get("abstractive_explanation") or ex.get("rationale",""),
final=ex["answer"]
)
if rec: rows.append(rec)
return Dataset.from_list(rows)
def load_esnli(limit=60000):
ds = try_load(["esnli","esnli/esnli"])["train"].select(range(limit))
rows=[]
for ex in ds:
rat = ex.get("explanation_1") or ex.get("explanation_2") or ex.get("explanation_3") or ""
rec = pack_record(
instruction=f"Premise: {ex['premise']}\nHypothesis: {ex['hypothesis']}\n"
f"Label (entailment/contradiction/neutral) and justify briefly.",
            rationale=rat, final=["entailment","neutral","contradiction"][ex["label"]]  # map e-SNLI integer labels to names
)
if rec: rows.append(rec)
return Dataset.from_list(rows)
def load_ecqa():
ds = try_load(["yangdong/ecqa","ecqa","google-research-datasets/ecqa"])["train"]
rows=[]
for ex in ds:
opts = [ex.get(k) for k in ["opa","opb","opc","opd","ope"] if ex.get(k)]
ans = ex.get("correct_ans","") or ex.get("label","")
exp = ex.get("explanation","") or ex.get("rationale","")
rec = pack_record(
instruction=f"Q: {ex.get('question','')}\nOptions: {', '.join(opts)}",
rationale=exp, final=str(ans)
)
if rec: rows.append(rec)
return Dataset.from_list(rows)
def load_strategyqa(limit=6000):
ds = try_load(["voidful/StrategyQA","allenai/strategyqa","strategy_qa"])["train"]
rows=[]; i=0
for ex in ds:
if limit and i>=limit: break
i+=1
q = ex.get("question") or ex.get("q","")
ans = str(ex.get("answer","")).lower()
rat = ex.get("decomposition","") or " ".join(ex.get("facts",[])) or ex.get("evidence","")
if not rat: rat = "Reason step by step to reach yes/no."
rec = pack_record(instruction=q, rationale=rat,
final="yes" if ans in ["1","true","yes"] else "no")
if rec: rows.append(rec)
return Dataset.from_list(rows)
def load_hotpot(sample=15000):
ds = try_load(["hotpotqa/hotpot_qa","hotpot_qa"], name="distractor")["train"]
idx = list(range(len(ds))); random.shuffle(idx); idx = idx[:sample]
rows=[]
for i in idx:
        ex = ds[i]
        # the HF hotpot_qa schema stores these as column dicts; zip back into (title, value) pairs
        ctx = list(zip(ex["context"]["title"], ex["context"]["sentences"]))
        sf = list(zip(ex["supporting_facts"]["title"], ex["supporting_facts"]["sent_id"]))
        rat = build_hotpot_rationale(sf, ctx, ex["answer"])
if not rat: continue
rec = pack_record(instruction=ex["question"], rationale=rat, final=ex["answer"])
if rec: rows.append(rec)
return Dataset.from_list(rows)
def load_gsm8k_train():
ds = try_load(["openai/gsm8k","gsm8k"], name="main")["train"]
rows=[]
for ex in ds:
sol = ex.get("solution","")
m = re.findall(r"(-?\d+(?:\.\d+)?)", sol)
final = m[-1] if m else ex.get("answer","")
rec = pack_record(instruction=ex["question"], rationale=sol, final=str(final))
if rec: rows.append(rec)
return Dataset.from_list(rows)
def load_openthoughts(limit=100000):
try:
ds = try_load(["open-thoughts/OpenThoughts-114k","OpenThoughts-114k"])["train"]
if limit: ds = ds.select(range(min(limit, len(ds))))
rows=[]
for ex in ds:
q = ex.get("question") or ex.get("instruction") or ""
rat = ex.get("cot") or ex.get("rationale") or ""
ans = ex.get("answer") or ex.get("final") or ""
rec = pack_record(instruction=q, rationale=rat, final=ans)
if rec: rows.append(rec)
return Dataset.from_list(rows)
except Exception:
return Dataset.from_list([])
def load_bespoke():
try:
ds = try_load(["HuggingFaceH4/Bespoke-Stratos-17k","Bespoke-Stratos-17k"])["train"]
rows=[]
for ex in ds:
q = ex.get("prompt") or ex.get("question") or ""
rat = ex.get("reasoning") or ex.get("rationale") or ""
ans = ex.get("output") or ex.get("final") or ""
rec = pack_record(instruction=q, rationale=rat, final=ans)
if rec: rows.append(rec)
return Dataset.from_list(rows)
except Exception:
return Dataset.from_list([])
# ---------- Build blend ----------
def build_blend():
parts = [
load_openthoughts(limit=100000),
load_bespoke(),
load_gsm8k_train(),
load_cose(),
load_esnli(limit=60000),
load_ecqa(),
load_strategyqa(limit=6000),
load_hotpot(sample=15000),
]
parts = [p for p in parts if len(p)>0]
mix = concatenate_datasets(parts).shuffle(seed=17)
    n_total = len(mix)
    # split off a small held-out eval set and keep it disjoint from the training file
    eval_size = min(3000, max(1000, int(0.01*n_total)))
    eval_ds = mix.select(range(eval_size))
    train_ds = mix.select(range(eval_size, n_total))
    train_ds.to_json("blend_train.jsonl", orient="records", lines=True)
    eval_ds.to_json("blend_eval.jsonl", orient="records", lines=True)
    return f"Blend built. Train: {len(train_ds)} rows. Eval: {len(eval_ds)} rows. Files: blend_train.jsonl, blend_eval.jsonl"
# ---------- Formatter for SFT ----------
def to_chat_formatter(tokenizer):
def _fmt(ex):
msgs = [
{"role":"system","content":"Think privately in <THINK>...</THINK>. Answer ONLY in <FINAL>...</FINAL>."},
{"role":"user","content": ex["instruction"] + (("\n\n"+ex["input"]) if ex.get("input") else "")},
{"role":"assistant","content": f"<THINK>{ex['rationale']}</THINK>\n<FINAL>{ex['final']}</FINAL>"}
]
return {"text": tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)}
return _fmt
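# Sketch of one formatted training sample (the literal special tokens come from the
# base model's chat template, so this is illustrative only):
#   system    -> Think privately in <THINK>...</THINK>. Answer ONLY in <FINAL>...</FINAL>.
#   user      -> instruction (+ optional input)
#   assistant -> <THINK>rationale</THINK>\n<FINAL>final answer</FINAL>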
# ---------- Train LoRA ----------
def train_lora(base=BASE, out_dir="thinking3b-lora", epochs=2, lr=2e-4, r=32, alpha=16, dropout=0.05, max_len=3072):
assert HF_TOKEN, "HF_TOKEN not found (Space Secret)."
tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
tok.pad_token = tok.eos_token
train = load_dataset("json", data_files="blend_train.jsonl")["train"].map(to_chat_formatter(tok), remove_columns=["instruction","input","rationale","final"])
evald = load_dataset("json", data_files="blend_eval.jsonl")["train"].map(to_chat_formatter(tok), remove_columns=["instruction","input","rationale","final"])
    # load the base model in 4-bit (QLoRA) via an explicit quantization config
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="bfloat16")
    model = AutoModelForCausalLM.from_pretrained(base, quantization_config=bnb, torch_dtype="auto", device_map="auto", token=HF_TOKEN)
    model = prepare_model_for_kbit_training(model)
    lora = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout, task_type="CAUSAL_LM",
                      target_modules=["q_proj","k_proj","v_proj","o_proj"])
model = get_peft_model(model, lora)
args = TrainingArguments(
output_dir=out_dir,
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
learning_rate=lr,
num_train_epochs=epochs,
logging_steps=25,
save_strategy="epoch",
evaluation_strategy="epoch",
bf16=True,
lr_scheduler_type="cosine",
warmup_ratio=0.03,
weight_decay=0.0,
max_grad_norm=1.0
)
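    # note: the SFTTrainer call below uses the pre-0.12 TRL signature (tokenizer=,
    # dataset_text_field=, packing=, max_seq_length=); newer TRL moves these into SFTConfig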
trainer = SFTTrainer(
model=model, tokenizer=tok,
train_dataset=train, eval_dataset=evald,
dataset_text_field="text",
packing=True, max_seq_length=max_len,
args=args
)
trainer.train()
model.save_pretrained(out_dir)
tok.save_pretrained(out_dir)
return f"LoRA saved to {out_dir}"
# ---------- Merge LoRA ----------
def merge_lora(base=BASE, adapter_dir="thinking3b-lora", out_dir="thinking3b-merged"):
tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
base_m = AutoModelForCausalLM.from_pretrained(base, torch_dtype="bfloat16", device_map="auto", token=HF_TOKEN)
merged = PeftModel.from_pretrained(base_m, adapter_dir).merge_and_unload()
merged.save_pretrained(out_dir, safe_serialization=True)
tok.save_pretrained(out_dir)
return f"Merged weights saved to {out_dir}"
# ---------- Push to Hub ----------
def push_to_hub(repo_id=OUT_REPO, folder="thinking3b-merged"):
assert HF_TOKEN, "HF_TOKEN not found."
api = HfApi(token=HF_TOKEN)
# create repo if needed
try:
create_repo(repo_id, repo_type="model", token=HF_TOKEN, exist_ok=True)
except Exception:
pass
# add a sane generation config
with open(os.path.join(folder, "generation_config.json"), "w", encoding="utf-8") as f:
json.dump({"temperature":0.2, "top_p":0.9, "max_new_tokens":512}, f)
upload_folder(repo_id=repo_id, folder_path=folder, repo_type="model", token=HF_TOKEN)
return f"Pushed {folder} to https://huggingface.co/{repo_id}"
# ---------- Small smoke test ----------
def smoke_run(local_model_dir="thinking3b-merged", prompt="Give 3 crisp bullets explaining CRDTs."):
tok = AutoTokenizer.from_pretrained(local_model_dir, use_fast=True)
m = AutoModelForCausalLM.from_pretrained(local_model_dir, torch_dtype="bfloat16", device_map="auto")
msgs = [
{"role":"system","content":"Think privately in <THINK>...</THINK>. Respond to the user ONLY in <FINAL>...</FINAL>."},
{"role":"user","content":prompt}
]
text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
ids = tok(text, return_tensors="pt").to(m.device)
out = m.generate(**ids, do_sample=True, temperature=0.2, top_p=0.9, max_new_tokens=256)
return tok.decode(out[0], skip_special_tokens=False)
# ---------- Long-context helpers ----------
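# The helpers below expect a module-level tokenizer `tok` and a generator `llm` with a
# vLLM-style generate(prompts, SamplingParams) interface; neither is defined elsewhere in
# this file. Minimal sketch, assuming the Space ships vLLM and serves the merged checkpoint;
# swap in your own backend if that assumption does not hold.
try:
    from vllm import LLM, SamplingParams
except ImportError:
    LLM = SamplingParams = None
tok = None
llm = None
def _ensure_lc_runtime(model_dir="thinking3b-merged"):
    global tok, llm
    if tok is None:
        tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if llm is None:
        if LLM is None:
            raise RuntimeError("vLLM is not installed; the Long-Context tab expects it.")
        llm = LLM(model=model_dir, dtype="bfloat16")
    return tok, llm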
def token_chunks(text: str, max_tokens=1600, overlap=200):
ids = tok.encode(text)
n = len(ids)
chunks = []
i = 0
k = 0
while i < n:
j = min(i + max_tokens, n)
piece = tok.decode(ids[i:j])
chunks.append((k, piece))
if j == n: break
i = j - overlap
k += 1
return chunks
# Prompts specialized for long-context reading
LC_SYS = (
"You are a careful researcher. Never reveal private thinking. "
"Use <THINK>..</THINK> for private notes and finish with <FINAL>..</FINAL>."
)
LC_PLAN = (
"We have a long document. In <THINK>, make a *very brief* reading plan: "
"key sections to scan and 3–6 questions to answer. Keep under 120 tokens.\n<THINK>\n"
)
LC_EXTRACT = """You are reading chunk #[{cid}] of a long document.
<CHUNK>
{chunk}
</CHUNK>
In <THINK> (≤150 tokens), extract only high-signal facts, numbers, names, dates, definitions
that help answer: "{query}". Prefix each item with [#{cid}] for citation.
Avoid repetition and opinions. Then stop.
<THINK>
"""
LC_MERGE = """You have private notes collected from multiple chunks:
<NOTES>
{notes}
</NOTES>
In <THINK> (≤{memo_budget} tokens), merge, deduplicate, and compress into a GLOBAL MEMO.
Keep only essential facts helpful to answer "{query}". Preserve [#chunk] citations on each fact.
Return ONLY the memo inside <THINK>..</THINK>.
<THINK>
"""
LC_FINAL = """Using the GLOBAL MEMO below, produce a final answer to: "{query}".
Keep it concise, and include bracketed citations like [#3,#5] on claims.
<GLOBAL_MEMO>
{memo}
</GLOBAL_MEMO>
Return ONLY inside <FINAL>..</FINAL>.
<FINAL>
"""
def _gen_llm(prompt, temperature=0.2, top_p=0.9, max_tokens=256, stop=None):
sp = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens,
stop=stop or ["</THINK>", "</FINAL>"])
return llm.generate([prompt], sp)[0].outputs[0].text.strip()
def lc_apply_chat(system, user):
return tok.apply_chat_template(
[{"role":"system","content":system},{"role":"user","content":user}],
tokenize=False, add_generation_prompt=True
)
def longcontext_answer(query: str, doc_text: str,
chunk_tokens=1600, overlap=200,
n_plan_samples=2,
extract_temp=0.2, merge_temp=0.2, final_temp=0.2,
memo_budget=400):
    _ensure_lc_runtime()  # lazily load the tokenizer/generator the helpers below rely on
    # 0) Plan (optionally pick best of a few)
plan_samples = []
for _ in range(n_plan_samples):
plan_prompt = lc_apply_chat(LC_SYS, LC_PLAN)
plan_samples.append(_gen_llm(plan_prompt, temperature=0.7, top_p=0.95, max_tokens=160, stop=["</THINK>"]))
plan = max(plan_samples, key=len)
# 1) Chunk the document
chunks = token_chunks(doc_text, max_tokens=chunk_tokens, overlap=overlap)
# 2) Per-chunk extraction (low temperature, short think)
notes = []
for cid, chunk in chunks:
user = LC_EXTRACT.format(cid=cid, chunk=chunk, query=query)
prompt = lc_apply_chat(LC_SYS, user)
note = _gen_llm(prompt, temperature=extract_temp, top_p=0.9, max_tokens=180, stop=["</THINK>"])
if note:
notes.append(note)
# 3) Merge into a GLOBAL MEMO (bounded)
merged_prompt = lc_apply_chat(LC_SYS, LC_MERGE.format(notes="\n".join(notes),
query=query, memo_budget=memo_budget))
memo = _gen_llm(merged_prompt, temperature=merge_temp, top_p=0.9, max_tokens=memo_budget, stop=["</THINK>"])
# 4) Finalize with citations
final_prompt = lc_apply_chat(LC_SYS, LC_FINAL.format(memo=memo, query=query))
final_answer = _gen_llm(final_prompt, temperature=final_temp, top_p=0.9, max_tokens=512, stop=["</FINAL>"])
# Debug payload (optional)
debug = {
"plan": plan,
"n_chunks": len(chunks),
"first_3_notes": notes[:3],
"memo_tokens": len(tok.encode(memo)),
}
return final_answer, debug
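# Example call (assumes a merged checkpoint exists for _ensure_lc_runtime to load):
#   answer, dbg = longcontext_answer("What were the key findings?", long_report_text)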
# ---------- Gradio UI ----------
with gr.Blocks() as demo:
gr.Markdown("## 3B Thinking — Train • Merge • Push (Space)")
with gr.Row():
base_inp = gr.Textbox(label="BASE", value=BASE)
out_repo_inp = gr.Textbox(label="OUT_REPO (your-username/repo)", value=OUT_REPO)
log = gr.Markdown(visible=True, value="Ready.")
with gr.Tab("1) Build Dataset"):
build_btn = gr.Button("Build blend (train/eval)")
build_btn.click(lambda: _ok(build_blend()), outputs=log)
with gr.Tab("2) Train LoRA (QLoRA)"):
epochs = gr.Slider(1, 3, step=1, value=2, label="epochs")
lr = gr.Slider(1e-5, 5e-4, step=1e-5, value=2e-4, label="learning_rate")
lora_r = gr.Slider(8, 64, step=8, value=32, label="LoRA r")
lora_alpha = gr.Slider(8, 64, step=2, value=16, label="LoRA alpha")
lora_dropout = gr.Slider(0.0, 0.2, step=0.01, value=0.05, label="LoRA dropout")
max_len = gr.Slider(1024, 4096, step=128, value=3072, label="max_seq_length")
train_btn = gr.Button("Train LoRA")
train_btn.click(
        lambda b,e,l,rr,aa,dd,ml: _ok(train_lora(b, "thinking3b-lora", int(e), l, int(rr), int(aa), dd, int(ml))),  # cast slider floats where ints are expected
inputs=[base_inp, epochs, lr, lora_r, lora_alpha, lora_dropout, max_len],
outputs=log
)
with gr.Tab("3) Merge Weights"):
merge_btn = gr.Button("Merge LoRA → full")
merge_btn.click(lambda b: _ok(merge_lora(b, "thinking3b-lora", "thinking3b-merged")),
inputs=[base_inp], outputs=log)
with gr.Tab("4) Push to Hub"):
push_btn = gr.Button("Push merged to OUT_REPO")
push_btn.click(lambda r: _ok(push_to_hub(r, "thinking3b-merged")),
inputs=[out_repo_inp], outputs=log)
with gr.Tab("Smoke Test"):
prompt = gr.Textbox(value="Give 3 crisp bullets explaining CRDTs.", label="Prompt")
test_btn = gr.Button("Run on merged model")
out_text = gr.Textbox(label="Raw decode")
test_btn.click(lambda p: smoke_run("thinking3b-merged", p), inputs=[prompt], outputs=[out_text])
with gr.Tab("Long-Context QA"):
q_lc = gr.Textbox(label="Question / Task", lines=3, placeholder="Your question…")
doc = gr.Textbox(label="Long document / context", lines=18, placeholder="Paste long text here…")
with gr.Row():
max_tok = gr.Slider(800, 2400, value=1600, step=100, label="chunk_tokens")
overlap = gr.Slider(100, 400, value=200, step=50, label="overlap")
memo = gr.Slider(200, 800, value=400, step=50, label="memo_budget")
with gr.Row():
nplan = gr.Slider(1, 3, value=2, step=1, label="plan samples")
t_ext = gr.Slider(0.1, 0.6, value=0.2, step=0.05, label="extract temp")
t_fin = gr.Slider(0.1, 0.5, value=0.2, step=0.05, label="final temp")
run_lc = gr.Button("Run Long-Context")
out_lc = gr.Textbox(label="Answer (with citations)", lines=10)
dbg_lc = gr.JSON(label="Debug (plan, memo size, #chunks)")
def _lc_run(query, text, ct, ov, mb, np, te, tf):
ans, info = longcontext_answer(
query, text, chunk_tokens=int(ct), overlap=int(ov),
n_plan_samples=int(np), extract_temp=float(te), final_temp=float(tf),
memo_budget=int(mb)
)
return ans, info
run_lc.click(_lc_run,
inputs=[q_lc, doc, max_tok, overlap, memo, nplan, t_ext, t_fin],
outputs=[out_lc, dbg_lc])
if __name__ == "__main__":
demo.launch()