import multiprocessing
import os
import re
import shutil

# Thread-count environment variables are typically read when the native
# runtimes (OpenMP, OpenVINO) are loaded, which happens while transformers /
# optimum are imported. Set them *before* those imports so they take effect.
os.environ["OMP_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["OV_CPU_THREADS_NUM"] = str(multiprocessing.cpu_count())

import huggingface_hub  # noqa: E402
from fastapi import FastAPI, Request  # noqa: E402
from fastapi.responses import HTMLResponse  # noqa: E402
from optimum.intel import OVModelForCausalLM  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

app = FastAPI()

model_name = "google/functiongemma-270m-it"
OV_MODEL_DIR = "functiongemma_ov"

# Authenticate with HuggingFace if token is provided
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    huggingface_hub.login(token=hf_token)


def _export_openvino_model(src_model, dest_dir):
    """Export *src_model* to OpenVINO format and publish it at *dest_dir*.

    The export is written to a sibling ``<dest_dir>.partial`` directory and
    only renamed to *dest_dir* once ``save_pretrained`` has finished. An
    interrupted export therefore never leaves a half-written *dest_dir* that
    the ``os.path.isdir`` check below would mistake for a finished model.
    A stale ``.partial`` directory from an earlier crash is removed first.

    Args:
        src_model: HuggingFace model id (or local path) to export.
        dest_dir: Directory the exported model should end up in.
    """
    dest_abs = os.path.abspath(dest_dir)
    partial_dir = dest_abs + ".partial"
    shutil.rmtree(partial_dir, ignore_errors=True)  # leftover from an interrupted run
    exported = OVModelForCausalLM.from_pretrained(src_model, export=True, compile=False)
    exported.save_pretrained(partial_dir)
    del exported
    os.replace(partial_dir, dest_abs)


# Export model to OpenVINO format on first run if not already done.
# Using the Python API (export=True) instead of the CLI produces a model with
# dynamic shapes, avoiding static sequence-length constants baked in by tracing.
if not os.path.isdir(OV_MODEL_DIR):
    print(f"OpenVINO model not found at '{OV_MODEL_DIR}', exporting now...")
    _export_openvino_model(model_name, OV_MODEL_DIR)
    print("Export complete.")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load without compiling, reshape to dynamic sequence length, then compile.
# This ensures long prompts (e.g. many tools) don't hit static-shape errors.
import threading

from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool

model = OVModelForCausalLM.from_pretrained(OV_MODEL_DIR, compile=False)
model.reshape(1, -1)  # batch=1, sequence_length=dynamic
model.compile()

# --- FunctionGemma control tokens ---
# NOTE(review): in the reviewed copy of this file every control-token literal
# was an empty string (ESCAPE = "", STRIP_TOKENS = ["", "", "", ""], and the
# turn / declaration wrappers inside the prompt f-strings were missing). With
# an empty ESCAPE, string values are serialized without delimiters and the
# argument parser returns "" for every string argument. The values below are
# restored from the published FunctionGemma prompt format -- confirm them
# against the model's chat template / tokenizer special tokens.
ESCAPE = "<escape>"
START_OF_TURN = "<start_of_turn>"
END_OF_TURN = "<end_of_turn>"
START_DECLARATION = "<start_function_declaration>"
END_DECLARATION = "<end_function_declaration>"
SYSTEM_PROMPT = "You are a model that can do function calling with the following functions"
# Markers removed from the decoded completion before it is returned. ESCAPE is
# deliberately not in this list: the tool-call parser needs it.
STRIP_TOKENS = [
    "<start_function_call>",
    "<end_function_call>",
    "<start_function_response>",
    "<end_of_turn>",
    "<eos>",
]

# Token budget shared by prompt + completion. The value comes from the original
# truncation comment ("model's context window (8192 tokens)") -- TODO confirm
# against the model config.
CONTEXT_WINDOW = 8192


# --- Prompt builder ---
def serialize_value(v, key=None):
    """Serialize a JSON-like value into FunctionGemma declaration syntax.

    Strings are wrapped in ESCAPE delimiters (upper-cased when they are the
    value of a ``type`` key), booleans become ``true``/``false``, ints and
    floats are emitted verbatim, lists become ``[a,b]`` and dicts
    ``{k:v,...}``. Any other value is stringified and wrapped in ESCAPE.

    Args:
        v: Value to serialize.
        key: Dict key *v* was stored under, if any; only used to upper-case
            ``type`` values.
    """
    if isinstance(v, str):
        val = v.upper() if key == "type" else v
        return f"{ESCAPE}{val}{ESCAPE}"
    elif isinstance(v, bool):  # must precede the int check: bool is an int subclass
        return str(v).lower()
    elif isinstance(v, (int, float)):
        return str(v)
    elif isinstance(v, list):
        return "[" + ",".join(serialize_value(i) for i in v) + "]"
    elif isinstance(v, dict):
        pairs = ",".join(f"{k}:{serialize_value(val, key=k)}" for k, val in v.items())
        return "{" + pairs + "}"
    return f"{ESCAPE}{v}{ESCAPE}"


def build_declaration(tool):
    """Render one tool (``name``, optional ``description``/``parameters``) as a
    FunctionGemma function declaration wrapped in declaration markers."""
    name = tool["name"]
    desc = tool.get("description", "")
    params = serialize_value(tool.get("parameters", {}))
    decl = f"declaration:{name}{{description:{ESCAPE}{desc}{ESCAPE},parameters:{params}}}"
    return f"{START_DECLARATION}{decl}{END_DECLARATION}"


def build_prompt(messages, tools):
    """Build the raw FunctionGemma prompt string for *messages* and *tools*.

    The first ``developer`` message (or SYSTEM_PROMPT when there is none)
    opens a developer turn, followed by the tool declarations when *tools* is
    non-empty. The remaining messages follow in order, and the prompt ends
    with an open ``model`` turn for generation.

    Raises:
        KeyError / TypeError / AttributeError: on malformed messages or tools.
    """
    parts = []
    # Use developer message from the request if provided, else fall back to default.
    developer_msg = next((m for m in messages if m["role"] == "developer"), None)
    if developer_msg:
        system_content = developer_msg.get("content") or ""
    else:
        system_content = SYSTEM_PROMPT

    if tools:
        declarations = "".join(build_declaration(t) for t in tools)
        parts.append(
            f"{START_OF_TURN}developer\n"
            f"{system_content}{declarations}"
            f"{END_OF_TURN}\n"
        )
    elif developer_msg:
        parts.append(f"{START_OF_TURN}developer\n{system_content}{END_OF_TURN}\n")

    for msg in messages:
        role = msg["role"]
        if role == "developer":
            continue  # already emitted above
        # Assistant messages may carry content=None; don't render it as "None".
        content = msg.get("content") or ""
        if role == "tool":
            # content should already be formatted as
            # <start_function_response>...<end_function_response>.
            # NOTE(review): appended verbatim with no turn markers, as in the
            # reviewed copy -- confirm this matches the chat template.
            parts.append(f"{content}\n")
        else:
            parts.append(f"{START_OF_TURN}{role}\n{content}{END_OF_TURN}\n")

    parts.append(f"{START_OF_TURN}model\n")
    return "".join(parts)


# --- Output parser ---
_CALL_START_RE = re.compile(r"call:(\w+)\{")


def _find_closing(text, start):
    """Return the index of the bracket that closes ``text[start]``, or -1.

    ``{``/``[`` count as openers and ``}``/``]`` as closers. Anything inside
    ESCAPE-delimited strings is skipped, so brackets inside string values do
    not affect nesting. Returns -1 if the brackets (or an escaped string) are
    still open at the end of *text*, e.g. because generation was truncated.
    """
    esc = len(ESCAPE)
    depth = 0
    i = start
    while i < len(text):
        if esc and text.startswith(ESCAPE, i):
            close = text.find(ESCAPE, i + esc)
            if close == -1:
                return -1
            i = close + esc
            continue
        ch = text[i]
        if ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
            if depth == 0:
                return i
        i += 1
    return -1


def _parse_scalar(raw):
    """Convert an unescaped argument value: true/false -> bool, then int, then
    float, else the raw string (mirrors what serialize_value emits)."""
    if raw in ("true", "false"):
        return raw == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def parse_tool_call_args(body):
    """Parse a FunctionGemma argument body such as
    ``city:<escape>Paris<escape>,days:3`` into a dict.

    ESCAPE-delimited values are returned as str. Nested ``{...}``/``[...]``
    values are returned as their raw text. Other values go through
    _parse_scalar. An unterminated escaped string (e.g. truncated output)
    keeps the remaining text as its value and ends parsing; the previous
    version could loop forever on such input.
    """
    result = {}
    esc = len(ESCAPE)
    n = len(body)
    i = 0
    while i < n:
        colon = body.find(":", i)
        if colon == -1:
            break
        key = body[i:colon]
        i = colon + 1
        if esc and body.startswith(ESCAPE, i):
            end = body.find(ESCAPE, i + esc)
            if end == -1:
                result[key] = body[i + esc:]
                break
            result[key] = body[i + esc:end]
            i = end + esc
        elif i < n and body[i] in "{[":  # i < n guards a trailing "key:"
            close = _find_closing(body, i)
            end = n - 1 if close == -1 else close
            result[key] = body[i:end + 1]
            i = end + 1
        else:
            comma = body.find(",", i)
            end = comma if comma != -1 else n
            result[key] = _parse_scalar(body[i:end].strip())
            i = end
        if i < n and body[i] == ",":
            i += 1
    return result


def parse_tool_calls(text):
    """Extract every ``call:name{...}`` from *text*.

    The closing brace is found with bracket matching (ignoring escaped
    strings), so nested objects and ``}`` inside string values are kept
    intact. A call whose braces never close is skipped, and scanning resumes
    right after its opening brace.

    Returns:
        A list of ``{"name": str, "arguments": dict}`` in order of appearance.
    """
    tool_calls = []
    pos = 0
    while True:
        m = _CALL_START_RE.search(text, pos)
        if m is None:
            break
        open_idx = m.end() - 1  # index of the "{" matched by the regex
        close_idx = _find_closing(text, open_idx)
        if close_idx == -1:
            pos = m.end()
            continue
        tool_calls.append({
            "name": m.group(1),
            "arguments": parse_tool_call_args(text[open_idx + 1:close_idx]),
        })
        pos = close_idx + 1
    return tool_calls


# --- Routes ---
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve index.html from the current working directory."""
    with open("index.html", encoding="utf-8") as f:
        return f.read()


# The original handler ran tokenization and generation directly on the event
# loop, which serialized all requests. This lock keeps that serialization for
# the shared model/tokenizer now that the work runs in the threadpool.
_generate_lock = threading.Lock()


def _parse_max_new_tokens(value):
    """Coerce *value* to an int in [1, CONTEXT_WINDOW - 1] or raise HTTP 400.

    The upper bound leaves room for at least one prompt token. Previously,
    values >= CONTEXT_WINDOW produced a zero or negative truncation bound.
    """
    try:
        requested = int(value)
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="max_new_tokens must be an integer") from None
    return max(1, min(requested, CONTEXT_WINDOW - 1))


def _generate_sync(prompt, max_new_tokens):
    """Tokenize *prompt*, run greedy generation and build the response dict.

    Blocking; runs in a worker thread under _generate_lock.
    """
    with _generate_lock:
        inputs = tokenizer(prompt, return_tensors="pt")
        # Truncate from the left if prompt exceeds the context window.
        # NOTE(review): left truncation drops the start of the prompt, including
        # the developer turn with the tool declarations (and any BOS token the
        # tokenizer prepended) -- consider rejecting over-long prompts instead.
        max_input_tokens = CONTEXT_WINDOW - max_new_tokens
        if inputs["input_ids"].shape[-1] > max_input_tokens:
            inputs = {k: v[:, -max_input_tokens:] for k, v in inputs.items()}

        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True
        )

        prompt_tokens = inputs["input_ids"].shape[-1]
        completion_tokens = outputs.shape[-1] - prompt_tokens
        new_tokens = outputs[0][prompt_tokens:]
        # Keep special tokens so ESCAPE delimiters survive for the parser;
        # only the markers in STRIP_TOKENS are removed.
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=False)

    for t in STRIP_TOKENS:
        response_text = response_text.replace(t, "")
    response_text = response_text.strip()

    return {
        "content": response_text,
        "tool_calls": parse_tool_calls(response_text),
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


@app.post("/generate")
async def generate(request: Request):
    """Generate a completion (and parsed tool calls) for a chat request.

    Expects a JSON object with ``messages`` (list of ``{"role", "content"}``),
    optional ``tools`` and optional ``max_new_tokens`` (default 150).
    Malformed input is rejected with HTTP 400.
    """
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    messages = body.get("messages", [])
    tools = body.get("tools", None)
    max_new_tokens = _parse_max_new_tokens(body.get("max_new_tokens", 150))

    try:
        prompt = build_prompt(messages, tools)
    except (KeyError, TypeError, AttributeError) as err:
        raise HTTPException(
            status_code=400, detail=f"malformed messages or tools: {err!r}"
        ) from err

    # Generation is blocking and CPU-bound; running it off the event loop keeps
    # the server responsive (e.g. GET /) while a request is being processed.
    return await run_in_threadpool(_generate_sync, prompt, max_new_tokens)