import {
  AutoModelForCausalLM,
  AutoTokenizer,
  InterruptableStoppingCriteria,
  PreTrainedModel,
  PreTrainedTokenizer,
  Tensor,
  TextStreamer,
} from "@huggingface/transformers";

import { calculateDownloadProgress } from "../../utils/calculateDownloadProgress.ts";
import { MODELS } from "../../utils/models.ts";
import {
  type Request,
  RequestType,
  type Response,
  ResponseType,
} from "../types.ts";

interface Pipeline {
  tokenizer: PreTrainedTokenizer;
  model: PreTrainedModel;
}

let pipeline: Pipeline | null = null;
let initializedModelKey: keyof typeof MODELS | null = null;
let cache: { pastKeyValues: any | null; key: string } = {
  pastKeyValues: null,
  key: "",
};
let stoppingCriteria: any | null = null;

const getTextGenerationPipeline = async (
  modelKey: keyof typeof MODELS,
  onDownloadProgress: (percentage: number) => void = () => {}
): Promise<Pipeline> => {
  // Reuse the existing pipeline if the requested model is already loaded.
  if (pipeline && modelKey === initializedModelKey) return pipeline;

  // A different model was requested: free the previous one before loading.
  if (pipeline) {
    await pipeline.model.dispose();
  }

  const MODEL = MODELS[modelKey];

  // Track per-file download progress so an overall percentage can be reported.
  const MODEL_FILES = new Map();
  for (const [key, value] of Object.entries(MODEL.files)) {
    MODEL_FILES.set(key, { loaded: 0, total: value });
  }

  try {
    const tokenizer = await AutoTokenizer.from_pretrained(MODEL.modelId);
    const model = await AutoModelForCausalLM.from_pretrained(MODEL.modelId, {
      dtype: MODEL.dtype,
      device: MODEL.device,
      progress_callback: calculateDownloadProgress(
        ({ percentage }) => onDownloadProgress(percentage),
        MODEL_FILES
      ),
    });

    pipeline = { tokenizer, model };
    initializedModelKey = modelKey;

    return pipeline;
  } catch (error) {
    console.error("Failed to initialize text generation pipeline:", error);
    throw error;
  }
};

const postMessage = (message: Response) => self.postMessage(message);

self.onmessage = async ({ data }: MessageEvent<Request>) => {
  if (data.type === RequestType.INITIALIZE_MODEL) {
    let lastPercentage = 0;

    await getTextGenerationPipeline(data.modelKey, (percentage) => {
      // Only report progress when the percentage actually changes.
      if (lastPercentage === percentage) return;
      lastPercentage = percentage;

      postMessage({
        type: ResponseType.INITIALIZE_MODEL,
        progress: percentage,
        done: false,
        requestId: data.requestId,
      });
    });

    postMessage({
      type: ResponseType.INITIALIZE_MODEL,
      progress: 100,
      done: true,
      requestId: data.requestId,
    });
  }

  if (data.type === RequestType.GENERATE_MESSAGE_ABORT) {
    stoppingCriteria?.interrupt();

    postMessage({
      type: ResponseType.GENERATE_TEXT_ABORTED,
      requestId: data.requestId,
    });
  }

  if (data.type === RequestType.GENERATE_MESSAGE) {
    const MODEL = MODELS[data.modelKey];

    stoppingCriteria = new InterruptableStoppingCriteria();

    const { messages, tools, requestId } = data;

    const { tokenizer, model } = await getTextGenerationPipeline(data.modelKey);

    if (!stoppingCriteria) {
      stoppingCriteria = new InterruptableStoppingCriteria();
    }

    const input = tokenizer.apply_chat_template(messages, {
      tools,
      add_generation_prompt: true,
      return_dict: true,
      // @ts-expect-error
      enable_thinking: data.enableThinking,
    }) as {
      input_ids: Tensor;
      attention_mask: number[] | number[][] | Tensor;
    };

    const started = performance.now();
    let firstTokenTime: DOMHighResTimeStamp | null = null;
    let numTokens = 0;
    let tps: number = 0;

    const removeEosToken = (content: string): string =>
      content.replace(tokenizer.eos_token, "");

    // Record time-to-first-token and a running tokens-per-second estimate.
    const tokenCallbackFunction = () => {
      firstTokenTime ??= performance.now();

      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - firstTokenTime)) * 1000;
      }
    };

    // Stream each decoded chunk back to the main thread as it is generated.
    const callbackFunction = (chunk: string) => {
      postMessage({
        type: ResponseType.GENERATE_TEXT_CHUNK,
        chunk: removeEosToken(chunk),
        requestId,
      });
    };

    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: false,
      token_callback_function: tokenCallbackFunction,
      callback_function: callbackFunction,
    });

    // Reuse the KV cache only when the conversation prefix matches the cached one.
    const cacheKey = MODEL.modelId + JSON.stringify(messages.slice(0, -1));
    const useCache = cacheKey === cache.key;

    const { sequences, past_key_values } = (await model.generate({
      ...input,
      max_new_tokens: 1024,
      past_key_values: useCache ? cache.pastKeyValues : null,
      return_dict_in_generate: true,
      temperature: data.temperature,
      stopping_criteria: stoppingCriteria,
      streamer,
    })) as { sequences: Tensor; past_key_values: any };

    const ended = performance.now();

    const lengthOfInput = input.input_ids.dims[1];

    const response = removeEosToken(
      tokenizer.batch_decode(
        /**
         * First argument (null): Don't slice dimension 0 (the batch dimension) - keep all batches
         * Second argument ([lengthOfInput, Number.MAX_SAFE_INTEGER]): For dimension 1 (the sequence/token dimension), slice from index lengthOfInput to the end
         */
        sequences.slice(null, [lengthOfInput, Number.MAX_SAFE_INTEGER]),
        {
          skip_special_tokens: false,
        }
      )[0]
    );

    const template = tokenizer.batch_decode(sequences, {
      skip_special_tokens: false,
    })[0];

    // Cache the KV state keyed by the full conversation including this response,
    // so the next turn can skip re-processing the shared prefix.
    cache = {
      pastKeyValues: past_key_values,
      key:
        MODEL.modelId +
        JSON.stringify([
          ...messages,
          {
            role: "assistant",
            content: response,
          },
        ]),
    };

    postMessage({
      type: ResponseType.GENERATE_TEXT_DONE,
      response,
      metadata: {
        inputDurationMs: firstTokenTime - started,
        outputTokens: numTokens,
        outputDurationMs: ended - firstTokenTime,
        outputTps: tps,
        doneMs: ended - started,
        modelKey: MODEL.modelId,
        model: MODEL.title,
        template,
        useKvCache: useCache,
        temperature: data.temperature,
      },
      interrupted: stoppingCriteria.interrupted,
      requestId,
    });
  }
};