import {
AutoModelForCausalLM,
AutoTokenizer,
InterruptableStoppingCriteria,
PreTrainedModel,
PreTrainedTokenizer,
Tensor,
TextStreamer,
} from "@huggingface/transformers";
import { calculateDownloadProgress } from "../../utils/calculateDownloadProgress.ts";
import { MODELS } from "../../utils/models.ts";
import {
type Request,
RequestType,
type Response,
ResponseType,
} from "../types.ts";
interface Pipeline {
tokenizer: PreTrainedTokenizer;
model: PreTrainedModel;
}
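// The worker keeps a single tokenizer/model pair alive and reuses it across requests.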
let pipeline: Pipeline | null = null;
let initializedModelKey: keyof typeof MODELS | null = null;
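// KV cache from the previous turn, keyed by model id plus the serialized conversation,
// so the shared prompt prefix does not have to be re-processed on follow-up requests.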
let cache: { pastKeyValues: any | null; key: string } = {
pastKeyValues: null,
key: "",
};
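// Lets a GENERATE_MESSAGE_ABORT request interrupt an in-flight generation.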
let stoppingCriteria: InterruptableStoppingCriteria | null = null;
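// Loads (or reuses) the tokenizer and model for the requested model key,
// reporting download progress while the weights are fetched.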
const getTextGenerationPipeline = async (
modelKey: keyof typeof MODELS,
onDownloadProgress: (percentage: number) => void = () => {}
): Promise<Pipeline> => {
if (pipeline && modelKey === initializedModelKey) return pipeline;
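// A different model was requested: free the previous model before loading the new weights.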
if (pipeline) {
await pipeline.model.dispose();
pipeline = null;
initializedModelKey = null;
}
const MODEL = MODELS[modelKey];
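// Track the download state of each model file so the progress callback can report an overall percentage.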
const MODEL_FILES = new Map();
for (const [key, value] of Object.entries(MODEL.files)) {
MODEL_FILES.set(key, { loaded: 0, total: value });
}
try {
const tokenizer = await AutoTokenizer.from_pretrained(MODEL.modelId);
const model = await AutoModelForCausalLM.from_pretrained(MODEL.modelId, {
dtype: MODEL.dtype,
device: MODEL.device,
progress_callback: calculateDownloadProgress(
({ percentage }) => onDownloadProgress(percentage),
MODEL_FILES
),
});
pipeline = { tokenizer, model };
initializedModelKey = modelKey;
return pipeline;
} catch (error) {
console.error("Failed to initialize feature extraction pipeline:", error);
throw error;
}
};
const postMessage = (message: Response) => self.postMessage(message);
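// Dispatch incoming requests from the main thread based on their RequestType.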
self.onmessage = async ({ data }: MessageEvent<Request>) => {
if (data.type === RequestType.INITIALIZE_MODEL) {
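// Forward download progress, skipping updates where the percentage has not changed.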
let lastPercentage = 0;
await getTextGenerationPipeline(data.modelKey, (percentage) => {
if (lastPercentage === percentage) return;
lastPercentage = percentage;
postMessage({
type: ResponseType.INITIALIZE_MODEL,
progress: percentage,
done: false,
requestId: data.requestId,
});
});
postMessage({
type: ResponseType.INITIALIZE_MODEL,
progress: 100,
done: true,
requestId: data.requestId,
});
}
if (data.type === RequestType.GENERATE_MESSAGE_ABORT) {
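// Tell the in-flight generate() call to stop at the next token.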
stoppingCriteria?.interrupt();
postMessage({
type: ResponseType.GENERATE_TEXT_ABORTED,
requestId: data.requestId,
});
}
if (data.type === RequestType.GENERATE_MESSAGE) {
const MODEL = MODELS[data.modelKey];
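// Use a fresh stopping criteria for every generation so an earlier abort cannot cancel this request.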
stoppingCriteria = new InterruptableStoppingCriteria();
const { messages, tools, requestId } = data;
const { tokenizer, model } = await getTextGenerationPipeline(data.modelKey);
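// Render the conversation (and tool definitions) through the model's chat template and tokenize it in one step.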
const input = tokenizer.apply_chat_template(messages, {
tools,
add_generation_prompt: true,
return_dict: true,
// @ts-expect-error -- enable_thinking is forwarded to the chat template but is not part of the typed options
enable_thinking: data.enableThinking,
}) as {
input_ids: Tensor;
attention_mask: number[] | number[][] | Tensor;
};
const started = performance.now();
let firstTokenTime: DOMHighResTimeStamp | null = null;
let numTokens = 0;
let tps: number = 0;
const removeEosToken = (content: string): string =>
content.replace(tokenizer.eos_token, "");
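// Token callback: record the time of the first token and keep a running tokens-per-second estimate.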
const tokenCallbackFunction = () => {
firstTokenTime ??= performance.now();
if (numTokens++ > 0) {
tps = (numTokens / (performance.now() - firstTokenTime)) * 1000;
}
};
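// Chunk callback: strip the EOS token from each decoded chunk and stream it to the main thread.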
const callbackFunction = (chunk: string) => {
postMessage({
type: ResponseType.GENERATE_TEXT_CHUNK,
chunk: removeEosToken(chunk),
requestId,
});
};
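// Stream decoded text as it is generated; skip_prompt keeps the prompt out of the output,
// and special tokens are kept so the EOS token can be stripped manually above.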
const streamer = new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: false,
token_callback_function: tokenCallbackFunction,
callback_function: callbackFunction,
});
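// The previous KV cache is reusable only when this request extends the exact conversation
// cached after the last turn (all messages except the newest one match).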
const cacheKey = MODEL.modelId + JSON.stringify(messages.slice(0, -1));
const useCache = cacheKey === cache.key;
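// Generate with streaming output; return_dict_in_generate exposes both the token sequences and the updated KV cache.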
const { sequences, past_key_values } = (await model.generate({
...input,
max_new_tokens: 1024,
past_key_values: useCache ? cache.pastKeyValues : null,
return_dict_in_generate: true,
temperature: data.temperature,
stopping_criteria: stoppingCriteria,
streamer,
})) as { sequences: Tensor; past_key_values: any };
const ended = performance.now();
const lengthOfInput = input.input_ids.dims[1];
const response = removeEosToken(
tokenizer.batch_decode(
/**
* First argument (null): Don't slice dimension 0 (the batch dimension) - keep all batches
* Second argument ([lengthOfInput, Number.MAX_SAFE_INTEGER]): For dimension 1 (the sequence/token dimension), slice from index lengthOfInput to the end
*/
sequences.slice(null, [lengthOfInput, Number.MAX_SAFE_INTEGER]),
{
skip_special_tokens: false,
}
)[0]
);
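// Decode the full sequence (prompt + completion) so the rendered chat template can be inspected in the metadata.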
const template = tokenizer.batch_decode(sequences, {
skip_special_tokens: false,
})[0];
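// Store the KV cache under a key that already includes the new assistant reply, so the next turn can reuse it.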
cache = {
pastKeyValues: past_key_values,
key:
MODEL.modelId +
JSON.stringify([
...messages,
{
role: "assistant",
content: response,
},
]),
};
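// Send the final response along with timing metadata (time to first token, tokens per second, total duration).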
postMessage({
type: ResponseType.GENERATE_TEXT_DONE,
response,
metadata: {
inputDurationMs: (firstTokenTime ?? ended) - started,
outputTokens: numTokens,
outputDurationMs: ended - (firstTokenTime ?? ended),
outputTps: tps,
doneMs: ended - started,
modelKey: data.modelKey,
model: MODEL.title,
template,
useKvCache: useCache,
temperature: data.temperature,
},
interrupted: stoppingCriteria.interrupted,
requestId,
});
}
};