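// Text-generation web worker: loads a causal LM with @huggingface/transformers,
// streams generated chunks back to the main thread, and reuses the KV cache
// across turns of the same conversation.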
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  InterruptableStoppingCriteria,
  PreTrainedModel,
  PreTrainedTokenizer,
  Tensor,
  TextStreamer,
} from "@huggingface/transformers";

import { calculateDownloadProgress } from "../../utils/calculateDownloadProgress.ts";
import { MODELS } from "../../utils/models.ts";
import {
  type Request,
  RequestType,
  type Response,
  ResponseType,
} from "../types.ts";

interface Pipeline {
  tokenizer: PreTrainedTokenizer;
  model: PreTrainedModel;
}

let pipeline: Pipeline | null = null;
let initializedModelKey: keyof typeof MODELS | null = null;
let cache: { pastKeyValues: any | null; key: string } = {
  pastKeyValues: null,
  key: "",
};
let stoppingCriteria: any | null = null;

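// Lazily initializes the tokenizer/model pair for the requested model and
// reports download progress; any previously loaded model is disposed first.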
const getTextGenerationPipeline = async (
  modelKey: keyof typeof MODELS,
  onDownloadProgress: (percentage: number) => void = () => {}
): Promise<Pipeline> => {
  if (pipeline && modelKey === initializedModelKey) return pipeline;
  if (pipeline) {
    await pipeline.model.dispose();
  }

  const MODEL = MODELS[modelKey];
  const MODEL_FILES = new Map();
  for (const [key, value] of Object.entries(MODEL.files)) {
    MODEL_FILES.set(key, { loaded: 0, total: value });
  }

  try {
    const tokenizer = await AutoTokenizer.from_pretrained(MODEL.modelId);
    const model = await AutoModelForCausalLM.from_pretrained(MODEL.modelId, {
      dtype: MODEL.dtype,
      device: MODEL.device,
      progress_callback: calculateDownloadProgress(
        ({ percentage }) => onDownloadProgress(percentage),
        MODEL_FILES
      ),
    });
    pipeline = { tokenizer, model };
    initializedModelKey = modelKey;
    return pipeline;
  } catch (error) {
    console.error("Failed to initialize text generation pipeline:", error);
    throw error;
  }
};

const postMessage = (message: Response) => self.postMessage(message);

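// Route incoming requests from the main thread: model initialization,
// message generation, and generation aborts.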
self.onmessage = async ({ data }: MessageEvent<Request>) => {
  if (data.type === RequestType.INITIALIZE_MODEL) {
    let lastPercentage = 0;
    await getTextGenerationPipeline(data.modelKey, (percentage) => {
      if (lastPercentage === percentage) return;
      lastPercentage = percentage;
      postMessage({
        type: ResponseType.INITIALIZE_MODEL,
        progress: percentage,
        done: false,
        requestId: data.requestId,
      });
    });
    postMessage({
      type: ResponseType.INITIALIZE_MODEL,
      progress: 100,
      done: true,
      requestId: data.requestId,
    });
  }

  if (data.type === RequestType.GENERATE_MESSAGE_ABORT) {
    // Guard against aborts that arrive before any generation has started.
    stoppingCriteria?.interrupt();
    postMessage({
      type: ResponseType.GENERATE_TEXT_ABORTED,
      requestId: data.requestId,
    });
  }

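  // Generate a reply: apply the chat template, stream tokens back as they are
  // produced, then send the full response together with timing metadata.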
  if (data.type === RequestType.GENERATE_MESSAGE) {
    const MODEL = MODELS[data.modelKey];
    // Fresh stopping criteria per request so earlier interrupts don't leak in.
    stoppingCriteria = new InterruptableStoppingCriteria();

    const { messages, tools, requestId } = data;
    const { tokenizer, model } = await getTextGenerationPipeline(data.modelKey);

    const input = tokenizer.apply_chat_template(messages, {
      tools,
      add_generation_prompt: true,
      return_dict: true,
      enable_thinking: data.enableThinking,
    }) as {
      input_ids: Tensor;
      attention_mask: number[] | number[][] | Tensor;
    };

    const started = performance.now();
    let firstTokenTime: DOMHighResTimeStamp | null = null;
    let numTokens = 0;
    let tps: number = 0;

    const removeEosToken = (content: string): string =>
      content.replace(tokenizer.eos_token, "");

    // Track time-to-first-token and tokens per second while streaming.
    const tokenCallbackFunction = () => {
      firstTokenTime ??= performance.now();
      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - firstTokenTime)) * 1000;
      }
    };

    const callbackFunction = (chunk: string) => {
      postMessage({
        type: ResponseType.GENERATE_TEXT_CHUNK,
        chunk: removeEosToken(chunk),
        requestId,
      });
    };

    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: false,
      token_callback_function: tokenCallbackFunction,
      callback_function: callbackFunction,
    });

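    // Reuse the previous turn's KV cache when this request extends the same
    // conversation, i.e. every message except the newest matches the cached key.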
    const cacheKey = MODEL.modelId + JSON.stringify(messages.slice(0, -1));
    const useCache = cacheKey === cache.key;

    const { sequences, past_key_values } = (await model.generate({
      ...input,
      max_new_tokens: 1024,
      past_key_values: useCache ? cache.pastKeyValues : null,
      return_dict_in_generate: true,
      temperature: data.temperature,
      stopping_criteria: stoppingCriteria,
      streamer,
    })) as { sequences: Tensor; past_key_values: any };
    const ended = performance.now();

    // Decode only the newly generated tokens (everything after the prompt).
    const lengthOfInput = input.input_ids.dims[1];
    const response = removeEosToken(
      tokenizer.batch_decode(
        sequences.slice(null, [lengthOfInput, Number.MAX_SAFE_INTEGER]),
        {
          skip_special_tokens: false,
        }
      )[0]
    );

    const template = tokenizer.batch_decode(sequences, {
      skip_special_tokens: false,
    })[0];

    cache = {
      pastKeyValues: past_key_values,
      key:
        MODEL.modelId +
        JSON.stringify([
          ...messages,
          {
            role: "assistant",
            content: response,
          },
        ]),
    };

    postMessage({
      type: ResponseType.GENERATE_TEXT_DONE,
      response,
      metadata: {
        inputDurationMs: firstTokenTime - started,
        outputTokens: numTokens,
        outputDurationMs: ended - firstTokenTime,
        outputTps: tps,
        doneMs: ended - started,
        modelKey: MODEL.modelId,
        model: MODEL.title,
        template,
        useKvCache: useCache,
        temperature: data.temperature,
      },
      interrupted: stoppingCriteria.interrupted,
      requestId,
    });
  }
};