import os, json, re, random, time, shutil, threading

import gradio as gr
from datasets import load_dataset, Dataset, concatenate_datasets
from huggingface_hub import HfApi, create_repo, upload_folder, whoami
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer

# ---------- Defaults & env ----------
BASE = os.getenv("BASE", "meta-llama/Llama-3.2-3B-Instruct")
OUT_REPO = os.getenv("OUT_REPO", "your-username/llama32-3b-thinking")
HF_TOKEN = os.getenv("HF_TOKEN", None)
random.seed(17)

# ---------- Helpers ----------
def _ok(s):
    return gr.update(value=s, visible=True)

def try_load(options, **kw):
    # Try dataset ids in order; the first one that loads wins.
    for dsid in options:
        try:
            return load_dataset(dsid, **kw)
        except Exception:
            continue
    raise RuntimeError(f"Failed loading any of: {options}")

def trim_text(txt, max_words=220):
    w = (txt or "").split()
    return " ".join(w[:max_words])

def pack_record(instruction, rationale, final, inp=""):
    rationale = trim_text(rationale, 220)
    if len(rationale.split()) < 3:  # drop trivial rationales
        return None
    return {
        "instruction": instruction.strip(),
        "input": (inp or "").strip(),
        "rationale": rationale.strip(),
        "final": (final or "").strip(),
    }

def build_hotpot_rationale(supporting_facts, context, answer):
    # The HF `hotpot_qa` "distractor" config stores context and supporting_facts
    # as columnar dicts ({"title": [...], "sentences"/"sent_id": [...]}).
    m = dict(zip(context["title"], context["sentences"]))
    bits = []
    for title, idx in list(zip(supporting_facts["title"], supporting_facts["sent_id"]))[:3]:
        try:
            bits.append(f"[{title}] {m[title][idx]}")
        except Exception:
            pass
    if not bits:
        return None
    return " ".join(bits) + f" ⇒ {answer}"

# ---------- Dataset loaders (blend) ----------
def load_cose():
    ds = try_load(["Salesforce/cos_e", "cos_e"], name="v1.11")["train"]
    rows = []
    for ex in ds:
        choices = ex.get("choices") or ex.get("options") or []
        rec = pack_record(
            instruction=f"Q: {ex['question']}\nOptions: {', '.join(choices)}",
            rationale=ex.get("abstractive_explanation") or ex.get("rationale", ""),
            final=ex["answer"],
        )
        if rec:
            rows.append(rec)
    return Dataset.from_list(rows)

def load_esnli(limit=60000):
    ds = try_load(["esnli", "esnli/esnli"])["train"].select(range(limit))
    # e-SNLI labels are ints; map them to the names the instruction asks for.
    label_names = {0: "entailment", 1: "neutral", 2: "contradiction"}
    rows = []
    for ex in ds:
        rat = ex.get("explanation_1") or ex.get("explanation_2") or ex.get("explanation_3") or ""
        rec = pack_record(
            instruction=f"Premise: {ex['premise']}\nHypothesis: {ex['hypothesis']}\n"
                        f"Label (entailment/contradiction/neutral) and justify briefly.",
            rationale=rat,
            final=label_names.get(ex["label"], str(ex["label"])),
        )
        if rec:
            rows.append(rec)
    return Dataset.from_list(rows)

def load_ecqa():
    ds = try_load(["yangdong/ecqa", "ecqa", "google-research-datasets/ecqa"])["train"]
    rows = []
    for ex in ds:
        opts = [ex.get(k) for k in ["opa", "opb", "opc", "opd", "ope"] if ex.get(k)]
        ans = ex.get("correct_ans", "") or ex.get("label", "")
        exp = ex.get("explanation", "") or ex.get("rationale", "")
        rec = pack_record(
            instruction=f"Q: {ex.get('question','')}\nOptions: {', '.join(opts)}",
            rationale=exp,
            final=str(ans),
        )
        if rec:
            rows.append(rec)
    return Dataset.from_list(rows)

def load_strategyqa(limit=6000):
    ds = try_load(["voidful/StrategyQA", "allenai/strategyqa", "strategy_qa"])["train"]
    rows = []
    i = 0
    for ex in ds:
        if limit and i >= limit:
            break
        i += 1
        q = ex.get("question") or ex.get("q", "")
        ans = str(ex.get("answer", "")).lower()
        rat = ex.get("decomposition", "") or " ".join(ex.get("facts", [])) or ex.get("evidence", "")
        if isinstance(rat, list):  # decomposition/evidence may arrive as lists
            rat = " ".join(map(str, rat))
        if not rat:
            rat = "Reason step by step to reach yes/no."
        rec = pack_record(instruction=q, rationale=rat,
                          final="yes" if ans in ["1", "true", "yes"] else "no")
        if rec:
            rows.append(rec)
    return Dataset.from_list(rows)
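# Illustrative only (not wired into the UI): the record schema every loader above
# normalizes into. The strings here are made-up placeholders; only the shape matters.
def _example_record():
    rec = pack_record(
        instruction="Q: Why is the sky blue?\nOptions: scattering, reflection",
        rationale="Shorter wavelengths scatter more strongly in the atmosphere.",
        final="scattering",
    )
    # -> {"instruction": "...", "input": "", "rationale": "...", "final": "scattering"}
    return rec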
def load_hotpot(sample=15000):
    ds = try_load(["hotpotqa/hotpot_qa", "hotpot_qa"], name="distractor")["train"]
    idx = list(range(len(ds)))
    random.shuffle(idx)
    idx = idx[:sample]
    rows = []
    for i in idx:
        ex = ds[i]
        rat = build_hotpot_rationale(ex["supporting_facts"], ex["context"], ex["answer"])
        if not rat:
            continue
        rec = pack_record(instruction=ex["question"], rationale=rat, final=ex["answer"])
        if rec:
            rows.append(rec)
    return Dataset.from_list(rows)

def load_gsm8k_train():
    ds = try_load(["openai/gsm8k", "gsm8k"], name="main")["train"]
    rows = []
    for ex in ds:
        # GSM8K keeps the worked solution and the final answer together in the
        # `answer` field, separated by "####".
        sol, _, tail = ex["answer"].partition("####")
        m = re.findall(r"(-?\d+(?:\.\d+)?)", sol)
        final = tail.strip() or (m[-1] if m else "")
        rec = pack_record(instruction=ex["question"], rationale=sol.strip(), final=str(final))
        if rec:
            rows.append(rec)
    return Dataset.from_list(rows)

def load_openthoughts(limit=100000):
    try:
        ds = try_load(["open-thoughts/OpenThoughts-114k", "OpenThoughts-114k"])["train"]
        if limit:
            ds = ds.select(range(min(limit, len(ds))))
        rows = []
        for ex in ds:
            q = ex.get("question") or ex.get("instruction") or ""
            rat = ex.get("cot") or ex.get("rationale") or ""
            ans = ex.get("answer") or ex.get("final") or ""
            rec = pack_record(instruction=q, rationale=rat, final=ans)
            if rec:
                rows.append(rec)
        return Dataset.from_list(rows)
    except Exception:
        return Dataset.from_list([])

def load_bespoke():
    try:
        ds = try_load(["bespokelabs/Bespoke-Stratos-17k",
                       "HuggingFaceH4/Bespoke-Stratos-17k",
                       "Bespoke-Stratos-17k"])["train"]
        rows = []
        for ex in ds:
            q = ex.get("prompt") or ex.get("question") or ""
            rat = ex.get("reasoning") or ex.get("rationale") or ""
            ans = ex.get("output") or ex.get("final") or ""
            rec = pack_record(instruction=q, rationale=rat, final=ans)
            if rec:
                rows.append(rec)
        return Dataset.from_list(rows)
    except Exception:
        return Dataset.from_list([])

# ---------- Build blend ----------
def build_blend():
    parts = [
        load_openthoughts(limit=100000),
        load_bespoke(),
        load_gsm8k_train(),
        load_cose(),
        load_esnli(limit=60000),
        load_ecqa(),
        load_strategyqa(limit=6000),
        load_hotpot(sample=15000),
    ]
    parts = [p for p in parts if len(p) > 0]
    mix = concatenate_datasets(parts).shuffle(seed=17)
    n_total = len(mix)
    # split off a tiny eval set, kept disjoint from train
    eval_size = min(3000, max(1000, int(0.01 * n_total)))
    eval_ds = mix.select(range(eval_size))
    train_ds = mix.select(range(eval_size, n_total))
    train_ds.to_json("blend_train.jsonl", orient="records", lines=True)
    eval_ds.to_json("blend_eval.jsonl", orient="records", lines=True)
    return (f"Blend built. Train: {len(train_ds)} rows. Eval: {len(eval_ds)} rows. "
            f"Files: blend_train.jsonl, blend_eval.jsonl")

# ---------- Formatter for SFT ----------
# NOTE: the <think>/<answer> tag names below are reconstructed; the original angle-
# bracket markup was stripped out of this file. Keep them consistent with the
# inference prompts further down.
def to_chat_formatter(tokenizer):
    def _fmt(ex):
        msgs = [
            {"role": "system",
             "content": "Think privately in <think>...</think>. Answer ONLY in <answer>...</answer>."},
            {"role": "user",
             "content": ex["instruction"] + (("\n\n" + ex["input"]) if ex.get("input") else "")},
            {"role": "assistant",
             "content": f"<think>{ex['rationale']}</think>\n<answer>{ex['final']}</answer>"},
        ]
        return {"text": tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)}
    return _fmt
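# Optional sanity check of the formatter (assumes blend_train.jsonl exists from
# step 1 and the base tokenizer is downloadable with HF_TOKEN). Prints one
# rendered training string so you can eyeball the chat template and tags.
def _preview_formatted_row(base=BASE):
    tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
    ds = load_dataset("json", data_files="blend_train.jsonl")["train"]
    print(to_chat_formatter(tok)(ds[0])["text"][:800])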
# ---------- Train LoRA ----------
def train_lora(base=BASE, out_dir="thinking3b-lora", epochs=2, lr=2e-4,
               r=32, alpha=16, dropout=0.05, max_len=3072):
    assert HF_TOKEN, "HF_TOKEN not found (Space Secret)."
    tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
    tok.pad_token = tok.eos_token
    cols = ["instruction", "input", "rationale", "final"]
    train = load_dataset("json", data_files="blend_train.jsonl")["train"].map(
        to_chat_formatter(tok), remove_columns=cols)
    evald = load_dataset("json", data_files="blend_eval.jsonl")["train"].map(
        to_chat_formatter(tok), remove_columns=cols)
    model = AutoModelForCausalLM.from_pretrained(
        base, load_in_4bit=True, torch_dtype="auto", device_map="auto", token=HF_TOKEN)
    model = prepare_model_for_kbit_training(model)
    lora = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout,
                      target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                      task_type="CAUSAL_LM")
    model = get_peft_model(model, lora)
    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=lr,
        num_train_epochs=epochs,
        logging_steps=25,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        bf16=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.0,
        max_grad_norm=1.0,
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=train,
        eval_dataset=evald,
        dataset_text_field="text",
        packing=True,
        max_seq_length=max_len,
        args=args,
    )
    trainer.train()
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    return f"LoRA saved to {out_dir}"

# ---------- Merge LoRA ----------
def merge_lora(base=BASE, adapter_dir="thinking3b-lora", out_dir="thinking3b-merged"):
    tok = AutoTokenizer.from_pretrained(base, use_fast=True, token=HF_TOKEN)
    base_m = AutoModelForCausalLM.from_pretrained(
        base, torch_dtype="bfloat16", device_map="auto", token=HF_TOKEN)
    merged = PeftModel.from_pretrained(base_m, adapter_dir).merge_and_unload()
    merged.save_pretrained(out_dir, safe_serialization=True)
    tok.save_pretrained(out_dir)
    return f"Merged weights saved to {out_dir}"

# ---------- Push to Hub ----------
def push_to_hub(repo_id=OUT_REPO, folder="thinking3b-merged"):
    assert HF_TOKEN, "HF_TOKEN not found."
    # create repo if needed
    try:
        create_repo(repo_id, repo_type="model", token=HF_TOKEN, exist_ok=True)
    except Exception:
        pass
    # add a sane generation config
    with open(os.path.join(folder, "generation_config.json"), "w", encoding="utf-8") as f:
        json.dump({"temperature": 0.2, "top_p": 0.9, "max_new_tokens": 512}, f)
    upload_folder(repo_id=repo_id, folder_path=folder, repo_type="model", token=HF_TOKEN)
    return f"Pushed {folder} to https://huggingface.co/{repo_id}"
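# Downstream-usage sketch for the pushed checkpoint (runs anywhere, not just in
# this Space). The dtype and device_map are assumptions; sampling defaults come
# from the generation_config.json written in push_to_hub above.
def _load_pushed(repo_id=OUT_REPO):
    tok = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="bfloat16", device_map="auto")
    return tok, model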
# ---------- Small smoke test ----------
def smoke_run(local_model_dir="thinking3b-merged", prompt="Give 3 crisp bullets explaining CRDTs."):
    tok = AutoTokenizer.from_pretrained(local_model_dir, use_fast=True)
    m = AutoModelForCausalLM.from_pretrained(local_model_dir, torch_dtype="bfloat16", device_map="auto")
    msgs = [
        {"role": "system",
         "content": "Think privately in <think>...</think>. Respond to the user ONLY in <answer>...</answer>."},
        {"role": "user", "content": prompt},
    ]
    text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    ids = tok(text, return_tensors="pt").to(m.device)
    out = m.generate(**ids, do_sample=True, temperature=0.2, top_p=0.9, max_new_tokens=256)
    return tok.decode(out[0], skip_special_tokens=False)

# ---------- Long-context helpers ----------
# The long-context pipeline generates with vLLM against the merged model. The
# module globals `tok` and `llm` were referenced but never defined in the original
# file; the lazy initializer below is a reconstruction (model dir and dtype are
# assumptions).
tok = None
llm = None
SamplingParams = None

def _ensure_engine(model_dir="thinking3b-merged"):
    global tok, llm, SamplingParams
    if llm is None:
        from vllm import LLM, SamplingParams as _SP  # lazy import: the rest of the app runs without vLLM
        SamplingParams = _SP
        tok = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        llm = LLM(model=model_dir, dtype="bfloat16")

def token_chunks(text: str, max_tokens=1600, overlap=200):
    ids = tok.encode(text)
    n = len(ids)
    chunks = []
    i = 0
    k = 0
    while i < n:
        j = min(i + max_tokens, n)
        piece = tok.decode(ids[i:j])
        chunks.append((k, piece))
        if j == n:
            break
        i = j - overlap
        k += 1
    return chunks

# Prompts specialized for long-context reading.
# NOTE: the <think>/<answer> tags (and the matching stop strings below) are
# reconstructed; the original angle-bracket markup was stripped from this file.
LC_SYS = (
    "You are a careful researcher. Never reveal private thinking. "
    "Use <think>..</think> for private notes and finish with <answer>...</answer>."
)
LC_PLAN = (
    "We have a long document. In <think>, make a *very brief* reading plan: "
    "key sections to scan and 3–6 questions to answer. Keep under 120 tokens.\n\n"
)
LC_EXTRACT = """You are reading chunk #[{cid}] of a long document.

{chunk}

In <think> (≤150 tokens), extract only high-signal facts, numbers, names, dates,
definitions that help answer: "{query}". Prefix each item with [#{cid}] for citation.
Avoid repetition and opinions. Then stop.
"""
LC_MERGE = """You have private notes collected from multiple chunks:

{notes}

In <think> (≤{memo_budget} tokens), merge, deduplicate, and compress into a GLOBAL MEMO.
Keep only essential facts helpful to answer "{query}". Preserve [#chunk] citations on each fact.
Return ONLY the memo inside <answer>...</answer>.
"""
LC_FINAL = """Using the GLOBAL MEMO below, produce a final answer to: "{query}".
Keep it concise, and include bracketed citations like [#3,#5] on claims.

{memo}

Return ONLY inside <answer>...</answer>.
"""

def _gen_llm(prompt, temperature=0.2, top_p=0.9, max_tokens=256, stop=None):
    sp = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens,
                        stop=stop or ["</think>", "</answer>"])
    return llm.generate([prompt], sp)[0].outputs[0].text.strip()

def lc_apply_chat(system, user):
    return tok.apply_chat_template(
        [{"role": "system", "content": system}, {"role": "user", "content": user}],
        tokenize=False, add_generation_prompt=True,
    )

def longcontext_answer(query: str, doc_text: str, chunk_tokens=1600, overlap=200,
                       n_plan_samples=2, extract_temp=0.2, merge_temp=0.2,
                       final_temp=0.2, memo_budget=400):
    _ensure_engine()
    # 0) Plan (optionally pick the best of a few samples)
    plan_samples = []
    for _ in range(n_plan_samples):
        plan_prompt = lc_apply_chat(LC_SYS, LC_PLAN)
        plan_samples.append(_gen_llm(plan_prompt, temperature=0.7, top_p=0.95,
                                     max_tokens=160, stop=["</think>"]))
    plan = max(plan_samples, key=len)
    # 1) Chunk the document
    chunks = token_chunks(doc_text, max_tokens=chunk_tokens, overlap=overlap)
    # 2) Per-chunk extraction (low temperature, short think)
    notes = []
    for cid, chunk in chunks:
        user = LC_EXTRACT.format(cid=cid, chunk=chunk, query=query)
        prompt = lc_apply_chat(LC_SYS, user)
        note = _gen_llm(prompt, temperature=extract_temp, top_p=0.9,
                        max_tokens=180, stop=["</think>"])
        if note:
            notes.append(note)
    # 3) Merge into a GLOBAL MEMO (bounded)
    merged_prompt = lc_apply_chat(LC_SYS, LC_MERGE.format(
        notes="\n".join(notes), query=query, memo_budget=memo_budget))
    memo = _gen_llm(merged_prompt, temperature=merge_temp, top_p=0.9,
                    max_tokens=memo_budget, stop=["</answer>"])
    # 4) Finalize with citations
    final_prompt = lc_apply_chat(LC_SYS, LC_FINAL.format(memo=memo, query=query))
    final_answer = _gen_llm(final_prompt, temperature=final_temp, top_p=0.9,
                            max_tokens=512, stop=["</answer>"])
    # Debug payload (optional)
    debug = {
        "plan": plan,
        "n_chunks": len(chunks),
        "first_3_notes": notes[:3],
        "memo_tokens": len(tok.encode(memo)),
    }
    return final_answer, debug
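# Illustrative helper (an assumption, not part of the original pipeline): pull the
# final text out of <answer>...</answer> when the UI should hide the markup. The
# tag names match the reconstructed prompts above; adjust if you change the markup.
def extract_answer(text: str) -> str:
    m = re.search(r"<answer>(.*?)(?:</answer>|$)", text, flags=re.S)
    return (m.group(1) if m else text).strip()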
"first_3_notes": notes[:3], "memo_tokens": len(tok.encode(memo)), } return final_answer, debug # ---------- Gradio UI ---------- with gr.Blocks() as demo: gr.Markdown("## 3B Thinking — Train • Merge • Push (Space)") with gr.Row(): base_inp = gr.Textbox(label="BASE", value=BASE) out_repo_inp = gr.Textbox(label="OUT_REPO (your-username/repo)", value=OUT_REPO) log = gr.Markdown(visible=True, value="Ready.") with gr.Tab("1) Build Dataset"): build_btn = gr.Button("Build blend (train/eval)") build_btn.click(lambda: _ok(build_blend()), outputs=log) with gr.Tab("2) Train LoRA (QLoRA)"): epochs = gr.Slider(1, 3, step=1, value=2, label="epochs") lr = gr.Slider(1e-5, 5e-4, step=1e-5, value=2e-4, label="learning_rate") lora_r = gr.Slider(8, 64, step=8, value=32, label="LoRA r") lora_alpha = gr.Slider(8, 64, step=2, value=16, label="LoRA alpha") lora_dropout = gr.Slider(0.0, 0.2, step=0.01, value=0.05, label="LoRA dropout") max_len = gr.Slider(1024, 4096, step=128, value=3072, label="max_seq_length") train_btn = gr.Button("Train LoRA") train_btn.click( lambda b,e,l,rr,aa,dd,ml: _ok(train_lora(b, "thinking3b-lora", e, l, rr, aa, dd, ml)), inputs=[base_inp, epochs, lr, lora_r, lora_alpha, lora_dropout, max_len], outputs=log ) with gr.Tab("3) Merge Weights"): merge_btn = gr.Button("Merge LoRA → full") merge_btn.click(lambda b: _ok(merge_lora(b, "thinking3b-lora", "thinking3b-merged")), inputs=[base_inp], outputs=log) with gr.Tab("4) Push to Hub"): push_btn = gr.Button("Push merged to OUT_REPO") push_btn.click(lambda r: _ok(push_to_hub(r, "thinking3b-merged")), inputs=[out_repo_inp], outputs=log) with gr.Tab("Smoke Test"): prompt = gr.Textbox(value="Give 3 crisp bullets explaining CRDTs.", label="Prompt") test_btn = gr.Button("Run on merged model") out_text = gr.Textbox(label="Raw decode") test_btn.click(lambda p: smoke_run("thinking3b-merged", p), inputs=[prompt], outputs=[out_text]) with gr.Tab("Long-Context QA"): q_lc = gr.Textbox(label="Question / Task", lines=3, placeholder="Your question…") doc = gr.Textbox(label="Long document / context", lines=18, placeholder="Paste long text here…") with gr.Row(): max_tok = gr.Slider(800, 2400, value=1600, step=100, label="chunk_tokens") overlap = gr.Slider(100, 400, value=200, step=50, label="overlap") memo = gr.Slider(200, 800, value=400, step=50, label="memo_budget") with gr.Row(): nplan = gr.Slider(1, 3, value=2, step=1, label="plan samples") t_ext = gr.Slider(0.1, 0.6, value=0.2, step=0.05, label="extract temp") t_fin = gr.Slider(0.1, 0.5, value=0.2, step=0.05, label="final temp") run_lc = gr.Button("Run Long-Context") out_lc = gr.Textbox(label="Answer (with citations)", lines=10) dbg_lc = gr.JSON(label="Debug (plan, memo size, #chunks)") def _lc_run(query, text, ct, ov, mb, np, te, tf): ans, info = longcontext_answer( query, text, chunk_tokens=int(ct), overlap=int(ov), n_plan_samples=int(np), extract_temp=float(te), final_temp=float(tf), memo_budget=int(mb) ) return ans, info run_lc.click(_lc_run, inputs=[q_lc, doc, max_tok, overlap, memo, nplan, t_ext, t_fin], outputs=[out_lc, dbg_lc]) if __name__ == "__main__": demo.launch()