| import { type Message } from "@huggingface/transformers"; | |
| import { | |
| executeToolCall, | |
| splitResponse, | |
| webMCPToolToChatTemplateTool, | |
| } from "@utils/webMcp"; | |
| import type { WebMCPTool } from "@utils/webMcp/types.ts"; | |
| import { | |
| type ChatMessage, | |
| type ChatMessageAssistant, | |
| type ChatMessageAssistantResponse, | |
| type ChatMessageAssistantTool, | |
| type GenerationMetadata, | |
| type Request, | |
| RequestType, | |
| type Response, | |
| ResponseType, | |
| } from "./types.ts"; | |
/**
 * Main-thread client for the text-generation Web Worker
 * (`./worker/textGenerationWorker.ts`).
 *
 * Every request sent to the worker carries a numeric `requestId`; each
 * operation installs a `message` listener that only reacts to responses
 * with its own id and removes itself once the operation settles.
 *
 * Two histories are kept in parallel:
 * - `messages`: the model-facing chat history sent to the worker on every
 *   generation (system / user / tool / assistant turns).
 * - `chatMessages`: the UI-facing history, including parsed tool calls and
 *   their results; every assignment notifies `onChatMessageUpdate` listeners.
 */
export default class TextGeneration {
  private worker: Worker;
  private requestId: number = 0;
  private modelKey: string | null = null;
  private tools: Array<WebMCPTool> | null = null;
  private temperature: number | null = null;
  private enableThinking: boolean | null = null;
  private messages: Array<Message> = [];
  private _chatMessages: Array<ChatMessage> = [];
  private chatMessagesListener: Array<
    (chatMessages: Array<ChatMessage>) => void
  > = [];

  /** Spawns the module worker that performs model loading and generation. */
  constructor() {
    this.worker = new Worker(
      new URL("./worker/textGenerationWorker.ts", import.meta.url),
      {
        type: "module",
      }
    );
  }

  /** Current UI-facing chat history. */
  get chatMessages() {
    return this._chatMessages;
  }

  /** Replaces the UI-facing chat history and notifies every subscriber. */
  set chatMessages(chatMessages: Array<ChatMessage>) {
    this._chatMessages = chatMessages;
    this.chatMessagesListener.forEach((listener) => listener(chatMessages));
  }

  /**
   * Subscribes to changes of `chatMessages`.
   *
   * @param callback - Invoked with the new history on every update.
   * @returns A function that removes this subscription.
   */
  public onChatMessageUpdate = (
    callback: (messages: Array<ChatMessage>) => void
  ) => {
    this.chatMessagesListener.push(callback);
    return () => {
      this.chatMessagesListener = this.chatMessagesListener.filter(
        (listener) => listener !== callback
      );
    };
  };

  private postWorkerMessage = (request: Request) =>
    this.worker.postMessage(request);

  private addWorkerEventListener = (
    listener: (ev: MessageEvent<Response>) => void
  ) => this.worker.addEventListener("message", listener);

  private removeWorkerEventListener = (
    listener: (ev: MessageEvent<Response>) => void
  ) => this.worker.removeEventListener("message", listener);

  /**
   * Asks the worker to load `modelKey` and reports download progress.
   *
   * @param modelKey - Identifier of the model to load; remembered for later
   *   generations once loading completes.
   * @param onDownload - Called with each progress value the worker reports,
   *   including the final one.
   * @returns Resolves with the final progress value when the worker reports
   *   `done`; rejects with the worker's error message.
   */
  public async initializeModel(
    modelKey: string,
    onDownload: (percentage: number) => void
  ): Promise<number> {
    return new Promise<number>((resolve, reject) => {
      const requestId = this.requestId++;
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type !== ResponseType.INITIALIZE_MODEL) return;
        if (data.done) {
          this.removeWorkerEventListener(listener);
          this.modelKey = modelKey;
          resolve(data.progress);
        }
        onDownload(data.progress);
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.INITIALIZE_MODEL,
        modelKey,
        requestId,
      });
    });
  }

  /**
   * Starts a fresh conversation: both histories are reset to a single system
   * message, and the generation settings used by later turns are stored.
   *
   * @param tools - WebMCP tools offered to the model and executed by `runAgent`.
   * @param temperature - Sampling temperature forwarded to the worker.
   * @param enableThinking - Flag forwarded to the worker with each request.
   * @param systemPrompt - Content of the initial system message.
   */
  public initializeConversation(
    tools: Array<WebMCPTool> = [],
    temperature: number,
    enableThinking: boolean,
    systemPrompt: string
  ) {
    this.tools = tools;
    this.temperature = temperature;
    this.enableThinking = enableThinking;
    this.messages = [{ role: "system", content: systemPrompt }];
    this.chatMessages = [{ role: "system", content: systemPrompt }];
  }

  /**
   * Requests that the worker abort the current generation.
   *
   * @returns Resolves once the worker answers with `GENERATE_TEXT_ABORTED`;
   *   rejects with the worker's error message.
   */
  public async abort() {
    return new Promise<void>((resolve, reject) => {
      const requestId = this.requestId++;
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_ABORTED) {
          this.removeWorkerEventListener(listener);
          resolve();
        }
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.GENERATE_MESSAGE_ABORT,
        requestId,
      });
    });
  }

  /**
   * Appends one turn to the model history and asks the worker to answer it.
   *
   * On success the assistant reply is appended to `messages`. On a worker
   * error the appended turn is removed again, so a retry does not leave two
   * consecutive unanswered turns in the history.
   *
   * @param prompt - Content of the new turn.
   * @param role - `"user"` for the user's prompt, `"tool"` for tool output.
   * @param onResponseUpdate - Called with the accumulated text after every
   *   streamed chunk.
   * @returns The full response, generation metadata, and the worker's
   *   `interrupted` flag. Rejects if the model or conversation has not been
   *   initialized, or with the worker's error message.
   */
  private generateText = (
    prompt: string,
    role: "user" | "tool",
    onResponseUpdate: (response: string) => void = () => {}
  ) => {
    return new Promise<{
      response: string;
      metadata: GenerationMetadata;
      interrupted: boolean;
    }>((resolve, reject) => {
      if (this.modelKey === null) {
        reject("Model not initialized");
        return;
      }
      if (
        this.tools === null ||
        this.temperature === null ||
        this.enableThinking === null
      ) {
        reject("Conversation not initialized");
        return;
      }
      const requestId = this.requestId++;
      const turnMessage: Message = { role, content: prompt };
      this.messages = [...this.messages, turnMessage];
      let streamed = "";
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          // Drop the unanswered turn (by identity, so a conversation reset in
          // the meantime is left untouched).
          this.messages = this.messages.filter(
            (message) => message !== turnMessage
          );
          reject(data.message);
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_DONE) {
          this.removeWorkerEventListener(listener);
          this.messages = [
            ...this.messages,
            { role: "assistant", content: data.response },
          ];
          resolve({
            response: data.response,
            metadata: data.metadata,
            interrupted: data.interrupted,
          });
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_CHUNK) {
          streamed = streamed + data.chunk;
          onResponseUpdate(streamed);
        }
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.GENERATE_MESSAGE,
        modelKey: this.modelKey,
        messages: this.messages,
        tools: this.tools.map(webMCPToolToChatTemplateTool),
        requestId,
        temperature: this.temperature,
        enableThinking: this.enableThinking,
      });
    });
  };

  /**
   * Converts raw (possibly partial) model output into assistant content
   * parts: plain text becomes a `response` part, each tool call found by
   * `splitResponse` becomes a `tool` part with no result yet.
   */
  private toAssistantContent(response: string): ChatMessageAssistant["content"] {
    return splitResponse(response).map((part) =>
      typeof part === "string"
        ? ({
            type: "response",
            content: part,
          } as ChatMessageAssistantResponse)
        : ({
            type: "tool",
            result: null,
            time: null,
            functionSignature: `${part.name}(${JSON.stringify(
              part.arguments
            )})`,
            ...part,
          } as ChatMessageAssistantTool)
    );
  }

  /**
   * Runs the tool-calling agent loop for one user prompt.
   *
   * Each round streams an assistant turn into `chatMessages`. If the final
   * turn contains tool calls, they are executed in parallel, their results
   * are attached to the turn, and the joined results are fed back to the
   * model as a `"tool"` turn. The loop stops when a turn has no tool calls
   * or the generation reports `interrupted`
   * (NOTE(review): presumably set when `abort()` stops generation — confirm
   * against the worker).
   *
   * An empty prompt is added to the chat history but triggers no generation.
   *
   * @param prompt - The user's message.
   */
  public async runAgent(prompt: string): Promise<void> {
    this.chatMessages = [
      ...this.chatMessages,
      { role: "user", content: prompt },
    ];
    if (!prompt) return;

    let turnPrompt = prompt;
    let turnRole: "user" | "tool" = "user";

    for (;;) {
      const previousChatMessages = this.chatMessages;
      let assistantMessage: ChatMessageAssistant = {
        role: "assistant",
        content: [],
        interrupted: false,
      };
      // Publish a fresh message object on every update so subscribers that
      // compare by identity see the change.
      const publish = (update: Partial<ChatMessageAssistant>) => {
        assistantMessage = { ...assistantMessage, ...update };
        this.chatMessages = [...previousChatMessages, assistantMessage];
      };
      publish({});

      const { response, metadata, interrupted } = await this.generateText(
        turnPrompt,
        turnRole,
        (partialResponse) =>
          publish({ content: this.toAssistantContent(partialResponse) })
      );
      // Rebuild from the final response so the UI matches the model history
      // even if the stream was incomplete.
      publish({
        content: this.toAssistantContent(response),
        metadata,
        interrupted,
      });

      const toolCalls = assistantMessage.content.filter(
        (c) => c.type === "tool"
      );
      // Do not run tools (or continue the loop) after an interrupted turn.
      if (interrupted || toolCalls.length === 0) return;

      const toolResponses = await Promise.all(
        toolCalls.map((tool) =>
          executeToolCall(
            {
              name: tool.name,
              arguments: tool.arguments,
              id: tool.id,
            },
            this.tools ?? []
          )
        )
      );

      publish({
        content: assistantMessage.content.map((part) => {
          if (part.type !== "tool") return part;
          const toolResponse = toolResponses.find(
            (candidate) => candidate.id === part.id
          );
          return toolResponse
            ? {
                ...part,
                result: toolResponse.result,
                time: toolResponse.time,
              }
            : part;
        }),
      });

      // Always answer the tool calls, even when the joined output is empty;
      // otherwise the model history would end on an unanswered tool call.
      turnPrompt = toolResponses.map(({ result }) => result).join("\n");
      turnRole = "tool";
    }
  }
}