# Hugging Face Space header residue ("Spaces: Sleeping") — extraction artifact, not code.
# Standard library
import multiprocessing
import os
import re

# Third-party
import huggingface_hub
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

# Use every available core for both OpenMP and the OpenVINO CPU plugin.
_cpu_count = str(multiprocessing.cpu_count())
os.environ["OMP_NUM_THREADS"] = _cpu_count
os.environ["OV_CPU_THREADS_NUM"] = _cpu_count
app = FastAPI()

model_name = "google/functiongemma-270m-it"
OV_MODEL_DIR = "functiongemma_ov"

# Log in to the HuggingFace Hub when a token is supplied (needed for gated models).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    huggingface_hub.login(token=hf_token)

# One-time export to OpenVINO IR on first run. The Python API (export=True) is
# used instead of the CLI because it produces a model with dynamic shapes,
# avoiding static sequence-length constants baked in by tracing.
if not os.path.isdir(OV_MODEL_DIR):
    print(f"OpenVINO model not found at '{OV_MODEL_DIR}', exporting now...")
    exported = OVModelForCausalLM.from_pretrained(model_name, export=True, compile=False)
    exported.save_pretrained(OV_MODEL_DIR)
    del exported  # release the export-time copy before loading for inference
    print("Export complete.")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load uncompiled, force a dynamic sequence dimension, then compile — this
# ensures long prompts (e.g. many tools) don't hit static-shape errors.
model = OVModelForCausalLM.from_pretrained(OV_MODEL_DIR, compile=False)
model.reshape(1, -1)  # batch=1, sequence_length=dynamic
model.compile()
# Marker FunctionGemma uses to delimit literal string values.
ESCAPE = "<escape>"
# Default developer/system message when the request does not supply one.
SYSTEM_PROMPT = "You are a model that can do function calling with the following functions"
# Control tokens removed from decoded model output.
STRIP_TOKENS = ["<eos>", "<end_of_turn>", "<bos>", "<pad>"]


# --- Prompt builder ---
def serialize_value(v, key=None):
    """Serialize a JSON-like value into FunctionGemma's declaration syntax."""
    if isinstance(v, str):
        # Schema "type" fields are upper-cased (e.g. "string" -> "STRING").
        text = v.upper() if key == "type" else v
        return f"{ESCAPE}{text}{ESCAPE}"
    if isinstance(v, bool):
        # Must be tested before int: bool is an int subclass.
        return "true" if v else "false"
    if isinstance(v, (int, float)):
        return str(v)
    if isinstance(v, list):
        return "[" + ",".join(serialize_value(item) for item in v) + "]"
    if isinstance(v, dict):
        inner = ",".join(f"{k}:{serialize_value(item, key=k)}" for k, item in v.items())
        return "{" + inner + "}"
    # Fallback: stringify anything else as an escaped literal.
    return f"{ESCAPE}{v}{ESCAPE}"
def build_declaration(tool):
    """Render one tool spec (name/description/parameters) as a FunctionGemma declaration."""
    desc = tool.get("description", "")
    params = serialize_value(tool.get("parameters", {}))
    body = f"declaration:{tool['name']}{{description:{ESCAPE}{desc}{ESCAPE},parameters:{params}}}"
    return f"<start_function_declaration>{body}<end_function_declaration>"
def build_prompt(messages, tools):
    """Assemble the FunctionGemma chat prompt from messages and optional tools.

    A developer message from the request overrides the default system prompt;
    tool declarations are appended to the developer turn when tools are given.
    """
    out = []
    dev = next((m for m in messages if m["role"] == "developer"), None)
    system_text = SYSTEM_PROMPT if dev is None else dev["content"]
    if tools:
        decls = "".join(build_declaration(t) for t in tools)
        out.append(f"<start_of_turn>developer\n{system_text}{decls}<end_of_turn>\n")
    elif dev is not None:
        out.append(f"<start_of_turn>developer\n{system_text}<end_of_turn>\n")
    for m in messages:
        role = m["role"]
        if role == "developer":
            continue  # emitted above as the developer turn
        if role == "tool":
            # Tool content arrives pre-wrapped in <start_function_response>...<end_function_response>.
            out.append(m["content"] + "\n")
        else:
            out.append(f"<start_of_turn>{role}\n{m['content']}<end_of_turn>\n")
    out.append("<start_of_turn>model\n")
    return "".join(out)
# --- Output parser ---
def parse_tool_call_args(body):
    """Parse a FunctionGemma argument body into a dict.

    The body looks like ``key:<escape>val<escape>,key2:123,key3:{...}``:
    string values are wrapped in ESCAPE markers, nested objects/arrays are
    kept as raw balanced text, and bare tokens are coerced to int, then
    float, else left as a stripped string.

    Fixes over the previous version:
    - a trailing ``key:`` with no value no longer raises IndexError;
    - an unterminated ESCAPE marker (truncated generation) now consumes the
      remainder of the body instead of slicing off the last character and
      rewinding the cursor.
    """
    result = {}
    i = 0
    n = len(body)
    while i < n:
        colon = body.find(":", i)
        if colon == -1:
            break
        key = body[i:colon]
        i = colon + 1
        if body[i:i + len(ESCAPE)] == ESCAPE:
            # Escaped string value: everything up to the closing marker.
            end = body.find(ESCAPE, i + len(ESCAPE))
            if end == -1:
                # Unterminated escape — take the rest and stop.
                result[key] = body[i + len(ESCAPE):]
                break
            result[key] = body[i + len(ESCAPE):end]
            i = end + len(ESCAPE)
        elif i < n and body[i] in ("{", "["):
            # Nested object/array: capture the balanced span as raw text.
            open_c = body[i]
            close_c = "}" if open_c == "{" else "]"
            depth, j = 0, i
            while j < n:
                if body[j] == open_c:
                    depth += 1
                elif body[j] == close_c:
                    depth -= 1
                    if depth == 0:
                        break
                j += 1
            result[key] = body[i:j + 1]
            i = j + 1
        else:
            # Bare token up to the next comma (or end of body).
            comma = body.find(",", i)
            end = comma if comma != -1 else n
            raw = body[i:end].strip()
            try:
                result[key] = int(raw)
            except ValueError:
                try:
                    result[key] = float(raw)
                except ValueError:
                    result[key] = raw
            i = end
        if i < n and body[i] == ",":
            i += 1
    return result
def parse_tool_calls(text):
    """Extract every FunctionGemma tool-call block from generated text."""
    pattern = r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>"
    return [
        {"name": fn_name, "arguments": parse_tool_call_args(arg_body)}
        for fn_name, arg_body in re.findall(pattern, text, re.DOTALL)
    ]
# --- Routes ---
# NOTE(review): no route decorator is visible on this handler — presumably
# lost in extraction (HTMLResponse is imported but otherwise unused).
# Expected something like: @app.get("/", response_class=HTMLResponse) — confirm.
def root():
    """Return the contents of index.html from the working directory."""
    with open("index.html") as page:
        return page.read()
# NOTE(review): no route decorator is visible on this handler — presumably
# lost in extraction. Expected something like @app.post("/generate") — confirm.
async def generate(request: Request):
    """Run one greedy chat completion over the OpenVINO model.

    Expects a JSON body with "messages" (chat history), optional "tools"
    (function declarations) and optional "max_new_tokens" (default 150).
    Returns the decoded text, any parsed tool calls, and token usage counts.
    """
    body = await request.json()
    messages = body.get("messages", [])
    tools = body.get("tools", None)
    max_new_tokens = body.get("max_new_tokens", 150)

    prompt = build_prompt(messages, tools)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Truncate from the left if the prompt exceeds the model's context window
    # (8192 tokens). Clamp the budget to at least 1 token: a client-supplied
    # max_new_tokens >= 8192 would otherwise make it negative and turn the
    # slice below into a silently wrong left-index slice.
    max_input_tokens = max(1, 8192 - max_new_tokens)
    if inputs["input_ids"].shape[-1] > max_input_tokens:
        inputs = {k: v[:, -max_input_tokens:] for k, v in inputs.items()}

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)

    # generate() returns prompt + completion; slice off the prompt part.
    prompt_tokens = inputs["input_ids"].shape[-1]
    completion_tokens = outputs.shape[-1] - prompt_tokens
    new_tokens = outputs[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
    # Remove Gemma control tokens that the decoder leaves in place.
    for t in STRIP_TOKENS:
        response_text = response_text.replace(t, "")
    response_text = response_text.strip()

    return {
        "content": response_text,
        "tool_calls": parse_tool_calls(response_text),
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }