# functiongemma / app.py
# (Hugging Face Space page header: jerinaj — "updates", commit 76afdbd)
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer
import huggingface_hub
import multiprocessing
import os
import re
os.environ["OMP_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["OV_CPU_THREADS_NUM"] = str(multiprocessing.cpu_count())
app = FastAPI()
model_name = "google/functiongemma-270m-it"
OV_MODEL_DIR = "functiongemma_ov"
# Authenticate with HuggingFace if token is provided
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
huggingface_hub.login(token=hf_token)
# Export model to OpenVINO format on first run if not already done.
# Using the Python API (export=True) instead of the CLI produces a model with
# dynamic shapes, avoiding static sequence-length constants baked in by tracing.
if not os.path.isdir(OV_MODEL_DIR):
print(f"OpenVINO model not found at '{OV_MODEL_DIR}', exporting now...")
_export_model = OVModelForCausalLM.from_pretrained(model_name, export=True, compile=False)
_export_model.save_pretrained(OV_MODEL_DIR)
del _export_model
print("Export complete.")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load without compiling, reshape to dynamic sequence length, then compile.
# This ensures long prompts (e.g. many tools) don't hit static-shape errors.
model = OVModelForCausalLM.from_pretrained(OV_MODEL_DIR, compile=False)
model.reshape(1, -1) # batch=1, sequence_length=dynamic
model.compile()
ESCAPE = "<escape>"
SYSTEM_PROMPT = "You are a model that can do function calling with the following functions"
STRIP_TOKENS = ["<eos>", "<end_of_turn>", "<bos>", "<pad>"]
# --- Prompt builder ---
def serialize_value(v, key=None):
if isinstance(v, str):
val = v.upper() if key == "type" else v
return f"{ESCAPE}{val}{ESCAPE}"
elif isinstance(v, bool):
return str(v).lower()
elif isinstance(v, (int, float)):
return str(v)
elif isinstance(v, list):
return "[" + ",".join(serialize_value(i) for i in v) + "]"
elif isinstance(v, dict):
pairs = ",".join(f"{k}:{serialize_value(val, key=k)}" for k, val in v.items())
return "{" + pairs + "}"
return f"{ESCAPE}{v}{ESCAPE}"
def build_declaration(tool):
name = tool["name"]
desc = tool.get("description", "")
params = serialize_value(tool.get("parameters", {}))
decl = f"declaration:{name}{{description:{ESCAPE}{desc}{ESCAPE},parameters:{params}}}"
return f"<start_function_declaration>{decl}<end_function_declaration>"
def build_prompt(messages, tools):
parts = []
# Use developer message from the request if provided, else fall back to default.
developer_msg = next((m for m in messages if m["role"] == "developer"), None)
system_content = developer_msg["content"] if developer_msg else SYSTEM_PROMPT
if tools:
declarations = "".join(build_declaration(t) for t in tools)
parts.append(
f"<start_of_turn>developer\n"
f"{system_content}{declarations}"
f"<end_of_turn>\n"
)
elif developer_msg:
parts.append(f"<start_of_turn>developer\n{system_content}<end_of_turn>\n")
for msg in messages:
role = msg["role"]
if role == "developer":
continue # already emitted above
content = msg["content"]
if role == "tool":
# content should already be formatted as <start_function_response>...<end_function_response>
parts.append(f"{content}\n")
else:
parts.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")
parts.append("<start_of_turn>model\n")
return "".join(parts)
# --- Output parser ---
def parse_tool_call_args(body):
"""Parse FunctionGemma arg body: key:<escape>val<escape>,key2:123"""
result = {}
i = 0
while i < len(body):
colon = body.find(":", i)
if colon == -1:
break
key = body[i:colon]
i = colon + 1
if body[i:i + len(ESCAPE)] == ESCAPE:
end = body.find(ESCAPE, i + len(ESCAPE))
result[key] = body[i + len(ESCAPE):end]
i = end + len(ESCAPE)
elif body[i] in ("{", "["):
open_c, close_c = body[i], "}" if body[i] == "{" else "]"
depth, j = 0, i
while j < len(body):
if body[j] == open_c:
depth += 1
elif body[j] == close_c:
depth -= 1
if depth == 0:
break
j += 1
result[key] = body[i:j + 1]
i = j + 1
else:
comma = body.find(",", i)
end = comma if comma != -1 else len(body)
raw = body[i:end].strip()
try:
result[key] = int(raw)
except ValueError:
try:
result[key] = float(raw)
except ValueError:
result[key] = raw
i = end
if i < len(body) and body[i] == ",":
i += 1
return result
def parse_tool_calls(text):
tool_calls = []
for m in re.finditer(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", text, re.DOTALL):
tool_calls.append({
"name": m.group(1),
"arguments": parse_tool_call_args(m.group(2)),
})
return tool_calls
# --- Routes ---
@app.get("/", response_class=HTMLResponse)
def root():
with open("index.html") as f:
return f.read()
@app.post("/generate")
async def generate(request: Request):
body = await request.json()
messages = body.get("messages", [])
tools = body.get("tools", None)
max_new_tokens = body.get("max_new_tokens", 150)
prompt = build_prompt(messages, tools)
inputs = tokenizer(prompt, return_tensors="pt")
# Truncate from the left if prompt exceeds model's context window (8192 tokens).
MAX_INPUT_TOKENS = 8192 - max_new_tokens
if inputs["input_ids"].shape[-1] > MAX_INPUT_TOKENS:
inputs = {k: v[:, -MAX_INPUT_TOKENS:] for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, use_cache=True)
prompt_tokens = inputs["input_ids"].shape[-1]
completion_tokens = outputs.shape[-1] - prompt_tokens
new_tokens = outputs[0][prompt_tokens:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
for t in STRIP_TOKENS:
response_text = response_text.replace(t, "")
response_text = response_text.strip()
return {
"content": response_text,
"tool_calls": parse_tool_calls(response_text),
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}