import os, json, re, random, time, shutil, threading
import torch
import gradio as gr
from datasets import load_dataset, Dataset, concatenate_datasets
from huggingface_hub import HfApi, create_repo, upload_folder, whoami
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig, TrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer
# ---------- Defaults & env ----------
BASE = os.getenv("BASE", "meta-llama/Llama-3.2-3B-Instruct")
OUT_REPO = os.getenv("OUT_REPO", "your-username/llama32-3b-thinking")
HF_TOKEN = os.getenv("HF_TOKEN", None)
random.seed(17)
# ---------- Helpers ----------
def _ok(s): return gr.update(value=s, visible=True)
def try_load(options, **kw):
    # Try a list of candidate dataset IDs and return the first one that loads.
    for dsid in options:
        try:
            return load_dataset(dsid, **kw)
        except Exception:
            continue
    raise RuntimeError(f"Failed loading any of: {options}")
def trim_text(txt, max_words=220):
    w = (txt or "").split()
    return " ".join(w[:max_words])
def pack_record(instruction, rationale, final, inp=""):
    rationale = trim_text(rationale, 220)
    if len(rationale.split()) < 3:  # drop trivial rationales
        return None
    return {
        "instruction": instruction.strip(),
        "input": (inp or "").strip(),
        "rationale": rationale.strip(),
        "final": str(final or "").strip()  # coerce non-string labels defensively
    }
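# Illustrative example of the record shape pack_record returns (values are made up):
# {"instruction": "Q: Why is the sky blue?\nOptions: scattering, mirrors, ...",
#  "input": "", "rationale": "Shorter wavelengths scatter more ...", "final": "scattering"}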
def build_hotpot_rationale(supporting_facts, context, answer):
    # The HF hotpot_qa dataset yields dict-of-lists ({"title": [...], "sentences"/"sent_id": [...]});
    # other mirrors yield lists of (title, value) pairs. Normalize both shapes.
    if isinstance(context, dict):
        context = list(zip(context["title"], context["sentences"]))
    if isinstance(supporting_facts, dict):
        supporting_facts = list(zip(supporting_facts["title"], supporting_facts["sent_id"]))
    m = {title: sents for title, sents in context}
    bits = []
    for title, idx in supporting_facts[:3]:
        try:
            s = m[title][idx]
            bits.append(f"[{title}] {s}")
        except Exception:
            pass
    if not bits: return None
    return " ".join(bits) + f" ⇒ {answer}"
# ---------- Dataset loaders (blend) ----------
def load_cose():
    ds = try_load(["Salesforce/cos_e", "cos_e"], name="v1.11")["train"]
    rows = []
    for ex in ds:
        choices = ex.get("choices") or ex.get("options") or []
        rec = pack_record(
            instruction=f"Q: {ex['question']}\nOptions: {', '.join(choices)}",
            rationale=ex.get("abstractive_explanation") or ex.get("rationale", ""),
            final=ex["answer"]
        )
        if rec: rows.append(rec)
    return Dataset.from_list(rows)
def load_esnli(limit=60000):
    ds = try_load(["esnli", "esnli/esnli"])["train"].select(range(limit))
    label_names = {0: "entailment", 1: "neutral", 2: "contradiction"}  # SNLI/e-SNLI label ids
    rows = []
    for ex in ds:
        rat = ex.get("explanation_1") or ex.get("explanation_2") or ex.get("explanation_3") or ""
        rec = pack_record(
            instruction=f"Premise: {ex['premise']}\nHypothesis: {ex['hypothesis']}\n"
                        f"Label (entailment/contradiction/neutral) and justify briefly.",
            rationale=rat, final=label_names.get(ex["label"], str(ex["label"]))
        )
        if rec: rows.append(rec)
    return Dataset.from_list(rows)
def load_ecqa():
    ds = try_load(["yangdong/ecqa", "ecqa", "google-research-datasets/ecqa"])["train"]
    rows = []
    for ex in ds:
        opts = [ex.get(k) for k in ["opa", "opb", "opc", "opd", "ope"] if ex.get(k)]
        ans = ex.get("correct_ans", "") or ex.get("label", "")
        exp = ex.get("explanation", "") or ex.get("rationale", "")
        rec = pack_record(
            instruction=f"Q: {ex.get('question','')}\nOptions: {', '.join(opts)}",
            rationale=exp, final=str(ans)
        )
        if rec: rows.append(rec)
    return Dataset.from_list(rows)
def load_strategyqa(limit=6000):
    ds = try_load(["voidful/StrategyQA", "allenai/strategyqa", "strategy_qa"])["train"]
    rows = []
    for i, ex in enumerate(ds):
        if limit and i >= limit: break
        q = ex.get("question") or ex.get("q", "")
        ans = str(ex.get("answer", "")).lower()
        # decomposition/facts/evidence may be lists depending on the mirror; flatten to text
        rat = ex.get("decomposition") or ex.get("facts") or ex.get("evidence") or ""
        if isinstance(rat, (list, tuple)):
            rat = " ".join(str(x) for x in rat)
        if not rat: rat = "Reason step by step to reach yes/no."
        rec = pack_record(instruction=q, rationale=rat,
                          final="yes" if ans in ["1", "true", "yes"] else "no")
        if rec: rows.append(rec)
    return Dataset.from_list(rows)
def load_hotpot(sample=15000):
    ds = try_load(["hotpotqa/hotpot_qa", "hotpot_qa"], name="distractor")["train"]
    idx = list(range(len(ds))); random.shuffle(idx); idx = idx[:sample]
    rows = []
    for i in idx:
        ex = ds[i]
        rat = build_hotpot_rationale(ex["supporting_facts"], ex["context"], ex["answer"])
        if not rat: continue
        rec = pack_record(instruction=ex["question"], rationale=rat, final=ex["answer"])
        if rec: rows.append(rec)
    return Dataset.from_list(rows)
def load_gsm8k_train():
    ds = try_load(["openai/gsm8k", "gsm8k"], name="main")["train"]
    rows = []
    for ex in ds:
        # GSM8K stores the worked solution in "answer", ending with "#### <final number>"
        sol = ex.get("answer", "")
        rationale, _, tail = sol.partition("####")
        m = re.findall(r"(-?\d+(?:\.\d+)?)", tail or sol)
        final = m[-1] if m else ""
        rec = pack_record(instruction=ex["question"], rationale=rationale or sol, final=str(final))
        if rec: rows.append(rec)
    return Dataset.from_list(rows)
def load_openthoughts(limit=100000):
    try:
        ds = try_load(["open-thoughts/OpenThoughts-114k", "OpenThoughts-114k"])["train"]
        if limit: ds = ds.select(range(min(limit, len(ds))))
        rows = []
        for ex in ds:
            # Field names vary across releases; unmatched rows are dropped by pack_record.
            q = ex.get("question") or ex.get("instruction") or ""
            rat = ex.get("cot") or ex.get("rationale") or ""
            ans = ex.get("answer") or ex.get("final") or ""
            rec = pack_record(instruction=q, rationale=rat, final=ans)
            if rec: rows.append(rec)
        return Dataset.from_list(rows)
    except Exception:
        return Dataset.from_list([])
def load_bespoke():
    try:
        ds = try_load(["bespokelabs/Bespoke-Stratos-17k", "HuggingFaceH4/Bespoke-Stratos-17k",
                       "Bespoke-Stratos-17k"])["train"]
        rows = []
        for ex in ds:
            q = ex.get("prompt") or ex.get("question") or ""
            rat = ex.get("reasoning") or ex.get("rationale") or ""
            ans = ex.get("output") or ex.get("final") or ""
            rec = pack_record(instruction=q, rationale=rat, final=ans)
            if rec: rows.append(rec)
        return Dataset.from_list(rows)
    except Exception:
        return Dataset.from_list([])
# ---------- Build blend ----------
def build_blend():
    parts = [
        load_openthoughts(limit=100000),
        load_bespoke(),
        load_gsm8k_train(),
        load_cose(),
        load_esnli(limit=60000),
        load_ecqa(),
        load_strategyqa(limit=6000),
        load_hotpot(sample=15000),
    ]
    parts = [p for p in parts if len(p) > 0]
    mix = concatenate_datasets(parts).shuffle(seed=17)
    n_total = len(mix)
    # split off a tiny eval set, held out from the training rows
    eval_size = min(3000, max(1000, int(0.01 * n_total)))
    eval_ds = mix.select(range(eval_size))
    train_ds = mix.select(range(eval_size, n_total))
    train_ds.to_json("blend_train.jsonl", orient="records", lines=True)
    eval_ds.to_json("blend_eval.jsonl", orient="records", lines=True)
    return f"Blend built. Train: {len(train_ds)} rows. Eval: {len(eval_ds)} rows. Files: blend_train.jsonl, blend_eval.jsonl"
# ---------- Formatter for SFT ----------
def to_chat_formatter(tokenizer):
    def _fmt(ex):
        msgs = [
            {"role": "system", "content": "Think privately in <THINK>...</THINK>. Answer ONLY in <FINAL>...</FINAL>."},
            {"role": "user", "content": ex["instruction"] + (("\n\n" + ex["input"]) if ex.get("input") else "")},
            {"role": "assistant", "content": f"<THINK>{ex['rationale']}</THINK>\n<FINAL>{ex['final']}</FINAL>"}
        ]
        return {"text": tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)}
    return _fmt
# ---------- Train LoRA ----------
def train_lora(base=BASE, out_dir="thinking3b-lora", epochs=2, lr=2e-4, r=32, alpha=16, dropout=0.05, max_len=3072):
    assert HF_TOKEN, "HF_TOKEN not found (Space Secret)."
    tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
    tok.pad_token = tok.eos_token
    cols = ["instruction", "input", "rationale", "final"]
    train = load_dataset("json", data_files="blend_train.jsonl")["train"].map(to_chat_formatter(tok), remove_columns=cols)
    evald = load_dataset("json", data_files="blend_eval.jsonl")["train"].map(to_chat_formatter(tok), remove_columns=cols)
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(base, quantization_config=bnb, torch_dtype="auto",
                                                 device_map="auto", token=HF_TOKEN)
    model = prepare_model_for_kbit_training(model)
    lora = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout, task_type="CAUSAL_LM",
                      target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])
    model = get_peft_model(model, lora)
    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=lr,
        num_train_epochs=epochs,
        logging_steps=25,
        save_strategy="epoch",
        evaluation_strategy="epoch",  # renamed to "eval_strategy" in recent transformers releases
        bf16=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.0,
        max_grad_norm=1.0
    )
    # NOTE: tokenizer / dataset_text_field / packing / max_seq_length follow the older TRL
    # SFTTrainer signature; newer TRL versions move these into SFTConfig.
    trainer = SFTTrainer(
        model=model, tokenizer=tok,
        train_dataset=train, eval_dataset=evald,
        dataset_text_field="text",
        packing=True, max_seq_length=max_len,
        args=args
    )
    trainer.train()
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    return f"LoRA saved to {out_dir}"
# ---------- Merge LoRA ----------
def merge_lora(base=BASE, adapter_dir="thinking3b-lora", out_dir="thinking3b-merged"):
    tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
    base_m = AutoModelForCausalLM.from_pretrained(base, torch_dtype="bfloat16", device_map="auto", token=HF_TOKEN)
    merged = PeftModel.from_pretrained(base_m, adapter_dir).merge_and_unload()
    merged.save_pretrained(out_dir, safe_serialization=True)
    tok.save_pretrained(out_dir)
    return f"Merged weights saved to {out_dir}"
# ---------- Push to Hub ----------
def push_to_hub(repo_id=OUT_REPO, folder="thinking3b-merged"):
    assert HF_TOKEN, "HF_TOKEN not found."
    api = HfApi(token=HF_TOKEN)
    # create repo if needed
    try:
        create_repo(repo_id, repo_type="model", token=HF_TOKEN, exist_ok=True)
    except Exception:
        pass
    # add a sane generation config
    with open(os.path.join(folder, "generation_config.json"), "w", encoding="utf-8") as f:
        json.dump({"temperature": 0.2, "top_p": 0.9, "max_new_tokens": 512}, f)
    upload_folder(repo_id=repo_id, folder_path=folder, repo_type="model", token=HF_TOKEN)
    return f"Pushed {folder} to https://huggingface.co/{repo_id}"
# ---------- Small smoke test ----------
def smoke_run(local_model_dir="thinking3b-merged", prompt="Give 3 crisp bullets explaining CRDTs."):
    tok = AutoTokenizer.from_pretrained(local_model_dir, use_fast=True)
    m = AutoModelForCausalLM.from_pretrained(local_model_dir, torch_dtype="bfloat16", device_map="auto")
    msgs = [
        {"role": "system", "content": "Think privately in <THINK>...</THINK>. Respond to the user ONLY in <FINAL>...</FINAL>."},
        {"role": "user", "content": prompt}
    ]
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    ids = tok(text, return_tensors="pt").to(m.device)
    out = m.generate(**ids, do_sample=True, temperature=0.2, top_p=0.9, max_new_tokens=256)
    return tok.decode(out[0], skip_special_tokens=False)
# ---------- Long-context helpers ----------
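# The helpers below rely on module-level `tok` and `llm` objects that the original file never
# defines. A minimal lazy-loader sketch, assuming the merged model in "thinking3b-merged" and an
# optional vLLM backend (the function name `_ensure_lc_engine` is illustrative, not original code):
tok = None   # tokenizer shared by token_chunks / lc_apply_chat
llm = None   # vLLM engine used by _gen_llm
def _ensure_lc_engine(model_dir="thinking3b-merged"):
    """Lazily load the shared tokenizer and vLLM engine for the Long-Context tab."""
    global tok, llm
    if tok is None:
        tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if llm is None:
        from vllm import LLM  # optional dependency; only needed for long-context QA
        llm = LLM(model=model_dir, dtype="bfloat16")
    return tok, llm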
def token_chunks(text: str, max_tokens=1600, overlap=200):
    # Split `text` into overlapping windows of `max_tokens` tokens, numbered for citation.
    ids = tok.encode(text)
    n = len(ids)
    chunks = []
    i = 0
    k = 0
    while i < n:
        j = min(i + max_tokens, n)
        piece = tok.decode(ids[i:j])
        chunks.append((k, piece))
        if j == n: break
        i = j - overlap
        k += 1
    return chunks
# Prompts specialized for long-context reading
LC_SYS = (
    "You are a careful researcher. Never reveal private thinking. "
    "Use <THINK>..</THINK> for private notes and finish with <FINAL>..</FINAL>."
)
LC_PLAN = (
    "We have a long document. In <THINK>, make a *very brief* reading plan: "
    "key sections to scan and 3–6 questions to answer. Keep under 120 tokens.\n<THINK>\n"
)
LC_EXTRACT = """You are reading chunk #[{cid}] of a long document.
<CHUNK>
{chunk}
</CHUNK>
In <THINK> (≤150 tokens), extract only high-signal facts, numbers, names, dates, definitions
that help answer: "{query}". Prefix each item with [#{cid}] for citation.
Avoid repetition and opinions. Then stop.
<THINK>
"""
LC_MERGE = """You have private notes collected from multiple chunks:
<NOTES>
{notes}
</NOTES>
In <THINK> (≤{memo_budget} tokens), merge, deduplicate, and compress into a GLOBAL MEMO.
Keep only essential facts helpful to answer "{query}". Preserve [#chunk] citations on each fact.
Return ONLY the memo inside <THINK>..</THINK>.
<THINK>
"""
LC_FINAL = """Using the GLOBAL MEMO below, produce a final answer to: "{query}".
Keep it concise, and include bracketed citations like [#3,#5] on claims.
<GLOBAL_MEMO>
{memo}
</GLOBAL_MEMO>
Return ONLY inside <FINAL>..</FINAL>.
<FINAL>
"""
def _gen_llm(prompt, temperature=0.2, top_p=0.9, max_tokens=256, stop=None):
    from vllm import SamplingParams  # optional dependency; `llm` is set by _ensure_lc_engine()
    sp = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens,
                        stop=stop or ["</THINK>", "</FINAL>"])
    return llm.generate([prompt], sp)[0].outputs[0].text.strip()
def lc_apply_chat(system, user):
    return tok.apply_chat_template(
        [{"role": "system", "content": system}, {"role": "user", "content": user}],
        tokenize=False, add_generation_prompt=True
    )
def longcontext_answer(query: str, doc_text: str,
                       chunk_tokens=1600, overlap=200,
                       n_plan_samples=2,
                       extract_temp=0.2, merge_temp=0.2, final_temp=0.2,
                       memo_budget=400):
    _ensure_lc_engine()  # make sure the shared tokenizer/engine globals exist (see sketch above)
    # 0) Plan (optionally pick best of a few)
    plan_samples = []
    for _ in range(n_plan_samples):
        plan_prompt = lc_apply_chat(LC_SYS, LC_PLAN)
        plan_samples.append(_gen_llm(plan_prompt, temperature=0.7, top_p=0.95, max_tokens=160, stop=["</THINK>"]))
    plan = max(plan_samples, key=len)
    # 1) Chunk the document
    chunks = token_chunks(doc_text, max_tokens=chunk_tokens, overlap=overlap)
    # 2) Per-chunk extraction (low temperature, short think)
    notes = []
    for cid, chunk in chunks:
        user = LC_EXTRACT.format(cid=cid, chunk=chunk, query=query)
        prompt = lc_apply_chat(LC_SYS, user)
        note = _gen_llm(prompt, temperature=extract_temp, top_p=0.9, max_tokens=180, stop=["</THINK>"])
        if note:
            notes.append(note)
    # 3) Merge into a GLOBAL MEMO (bounded)
    merged_prompt = lc_apply_chat(LC_SYS, LC_MERGE.format(notes="\n".join(notes),
                                                          query=query, memo_budget=memo_budget))
    memo = _gen_llm(merged_prompt, temperature=merge_temp, top_p=0.9, max_tokens=memo_budget, stop=["</THINK>"])
    # 4) Finalize with citations
    final_prompt = lc_apply_chat(LC_SYS, LC_FINAL.format(memo=memo, query=query))
    final_answer = _gen_llm(final_prompt, temperature=final_temp, top_p=0.9, max_tokens=512, stop=["</FINAL>"])
    # Debug payload (optional)
    debug = {
        "plan": plan,
        "n_chunks": len(chunks),
        "first_3_notes": notes[:3],
        "memo_tokens": len(tok.encode(memo)),
    }
    return final_answer, debug
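# Illustrative usage (assumes a merged model and the optional vLLM backend are available):
#   answer, dbg = longcontext_answer("Summarize the key findings.", long_report_text)
#   print(answer); print(dbg["n_chunks"], "chunks,", dbg["memo_tokens"], "memo tokens")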
# ---------- Gradio UI ----------
with gr.Blocks() as demo:
    gr.Markdown("## 3B Thinking — Train • Merge • Push (Space)")
    with gr.Row():
        base_inp = gr.Textbox(label="BASE", value=BASE)
        out_repo_inp = gr.Textbox(label="OUT_REPO (your-username/repo)", value=OUT_REPO)
    log = gr.Markdown(visible=True, value="Ready.")
    with gr.Tab("1) Build Dataset"):
        build_btn = gr.Button("Build blend (train/eval)")
        build_btn.click(lambda: _ok(build_blend()), outputs=log)
| with gr.Tab("2) Train LoRA (QLoRA)"): | |
| epochs = gr.Slider(1, 3, step=1, value=2, label="epochs") | |
| lr = gr.Slider(1e-5, 5e-4, step=1e-5, value=2e-4, label="learning_rate") | |
| lora_r = gr.Slider(8, 64, step=8, value=32, label="LoRA r") | |
| lora_alpha = gr.Slider(8, 64, step=2, value=16, label="LoRA alpha") | |
| lora_dropout = gr.Slider(0.0, 0.2, step=0.01, value=0.05, label="LoRA dropout") | |
| max_len = gr.Slider(1024, 4096, step=128, value=3072, label="max_seq_length") | |
| train_btn = gr.Button("Train LoRA") | |
| train_btn.click( | |
| lambda b,e,l,rr,aa,dd,ml: _ok(train_lora(b, "thinking3b-lora", e, l, rr, aa, dd, ml)), | |
| inputs=[base_inp, epochs, lr, lora_r, lora_alpha, lora_dropout, max_len], | |
| outputs=log | |
| ) | |
| with gr.Tab("3) Merge Weights"): | |
| merge_btn = gr.Button("Merge LoRA → full") | |
| merge_btn.click(lambda b: _ok(merge_lora(b, "thinking3b-lora", "thinking3b-merged")), | |
| inputs=[base_inp], outputs=log) | |
| with gr.Tab("4) Push to Hub"): | |
| push_btn = gr.Button("Push merged to OUT_REPO") | |
| push_btn.click(lambda r: _ok(push_to_hub(r, "thinking3b-merged")), | |
| inputs=[out_repo_inp], outputs=log) | |
| with gr.Tab("Smoke Test"): | |
| prompt = gr.Textbox(value="Give 3 crisp bullets explaining CRDTs.", label="Prompt") | |
| test_btn = gr.Button("Run on merged model") | |
| out_text = gr.Textbox(label="Raw decode") | |
| test_btn.click(lambda p: smoke_run("thinking3b-merged", p), inputs=[prompt], outputs=[out_text]) | |
| with gr.Tab("Long-Context QA"): | |
| q_lc = gr.Textbox(label="Question / Task", lines=3, placeholder="Your question…") | |
| doc = gr.Textbox(label="Long document / context", lines=18, placeholder="Paste long text here…") | |
| with gr.Row(): | |
| max_tok = gr.Slider(800, 2400, value=1600, step=100, label="chunk_tokens") | |
| overlap = gr.Slider(100, 400, value=200, step=50, label="overlap") | |
| memo = gr.Slider(200, 800, value=400, step=50, label="memo_budget") | |
| with gr.Row(): | |
| nplan = gr.Slider(1, 3, value=2, step=1, label="plan samples") | |
| t_ext = gr.Slider(0.1, 0.6, value=0.2, step=0.05, label="extract temp") | |
| t_fin = gr.Slider(0.1, 0.5, value=0.2, step=0.05, label="final temp") | |
| run_lc = gr.Button("Run Long-Context") | |
| out_lc = gr.Textbox(label="Answer (with citations)", lines=10) | |
| dbg_lc = gr.JSON(label="Debug (plan, memo size, #chunks)") | |
| def _lc_run(query, text, ct, ov, mb, np, te, tf): | |
| ans, info = longcontext_answer( | |
| query, text, chunk_tokens=int(ct), overlap=int(ov), | |
| n_plan_samples=int(np), extract_temp=float(te), final_temp=float(tf), | |
| memo_budget=int(mb) | |
| ) | |
| return ans, info | |
| run_lc.click(_lc_run, | |
| inputs=[q_lc, doc, max_tok, overlap, memo, nplan, t_ext, t_fin], | |
| outputs=[out_lc, dbg_lc]) | |
| if __name__ == "__main__": | |
| demo.launch() | |