nico-martin's picture
nico-martin HF Staff
added pythonic tool calls
ec2237a
raw
history blame
8.56 kB
import { type Message } from "@huggingface/transformers";
import {
executeToolCall,
splitResponse,
webMCPToolToChatTemplateTool,
} from "@utils/webMcp";
import type { WebMCPTool } from "@utils/webMcp/types.ts";
import {
type ChatMessage,
type ChatMessageAssistant,
type ChatMessageAssistantResponse,
type ChatMessageAssistantTool,
type GenerationMetadata,
type Request,
RequestType,
type Response,
ResponseType,
} from "./types.ts";
/**
 * Client-side facade over the text-generation Web Worker.
 *
 * Keeps two parallel histories:
 * - `messages`: the raw chat messages sent to the worker with each request.
 * - `chatMessages`: the UI-facing representation (assistant turns split into
 *   response text and tool-call parts). Every assignment is broadcast to the
 *   subscribers registered via `onChatMessageUpdate`.
 *
 * Each worker request is tagged with a fresh `requestId`; every listener
 * ignores worker responses addressed to other requests.
 */
export default class TextGeneration {
  private worker: Worker;
  private requestId: number = 0;
  private modelKey: string | null = null;
  private tools: Array<WebMCPTool> | null = null;
  private temperature: number | null = null;
  private enableThinking: boolean | null = null;
  private messages: Array<Message> = [];
  private _chatMessages: Array<ChatMessage> = [];
  private chatMessagesListener: Array<
    (chatMessages: Array<ChatMessage>) => void
  > = [];

  constructor() {
    this.worker = new Worker(
      new URL("./worker/textGenerationWorker.ts", import.meta.url),
      {
        type: "module",
      }
    );
  }

  get chatMessages() {
    return this._chatMessages;
  }

  /** Replaces the UI history and notifies every registered listener. */
  set chatMessages(chatMessages: Array<ChatMessage>) {
    this._chatMessages = chatMessages;
    this.chatMessagesListener.forEach((listener) => listener(chatMessages));
  }

  /**
   * Subscribes to UI-history updates.
   *
   * @param callback - Invoked with the new `chatMessages` array on every update.
   * @returns A function that removes this subscription.
   */
  public onChatMessageUpdate = (
    callback: (messages: Array<ChatMessage>) => void
  ) => {
    this.chatMessagesListener.push(callback);
    return () => {
      this.chatMessagesListener = this.chatMessagesListener.filter(
        (listener) => listener !== callback
      );
    };
  };

  private postWorkerMessage = (request: Request) =>
    this.worker.postMessage(request);

  private addWorkerEventListener = (
    listener: (ev: MessageEvent<Response>) => void
  ) => this.worker.addEventListener("message", listener);

  private removeWorkerEventListener = (
    listener: (ev: MessageEvent<Response>) => void
  ) => this.worker.removeEventListener("message", listener);

  /**
   * Asks the worker to load `modelKey` and forwards download progress.
   *
   * @param modelKey - Identifier of the model the worker should load.
   * @param onDownload - Called with every progress value the worker reports,
   *   including the final one.
   * @returns Resolves with the final progress value once the worker reports
   *   `done`; rejects with the worker's error message on failure.
   */
  public async initializeModel(
    modelKey: string,
    onDownload: (percentage: number) => void
  ) {
    return new Promise<number>((resolve, reject) => {
      const requestId = this.requestId++;
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type !== ResponseType.INITIALIZE_MODEL) return;
        if (data.done) {
          this.removeWorkerEventListener(listener);
          this.modelKey = modelKey;
          resolve(data.progress);
        }
        onDownload(data.progress);
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.INITIALIZE_MODEL,
        modelKey,
        requestId,
      });
    });
  }

  /**
   * Resets both histories to a single system message and stores the
   * generation settings used by subsequent requests.
   */
  public initializeConversation(
    tools: Array<WebMCPTool> = [],
    temperature: number,
    enableThinking: boolean,
    systemPrompt: string
  ) {
    this.tools = tools;
    this.temperature = temperature;
    this.enableThinking = enableThinking;
    this.messages = [{ role: "system", content: systemPrompt }];
    this.chatMessages = [{ role: "system", content: systemPrompt }];
  }

  /**
   * Requests that the worker abort the current generation.
   *
   * @returns Resolves when the worker confirms the abort; rejects with the
   *   worker's error message on failure.
   */
  public async abort() {
    return new Promise<void>((resolve, reject) => {
      const requestId = this.requestId++;
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_ABORTED) {
          this.removeWorkerEventListener(listener);
          resolve();
        }
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.GENERATE_MESSAGE_ABORT,
        requestId,
      });
    });
  }

  /**
   * Appends `prompt` to the model history, asks the worker to generate a reply
   * and appends that reply as an assistant message.
   *
   * @param prompt - Message content to send.
   * @param role - Role of the appended message (`user` or `tool`).
   * @param onResponseUpdate - Called with the accumulated response text after
   *   each streamed chunk.
   * @returns The final response, its metadata and whether it was interrupted.
   *   Rejects if the model or conversation is not initialized, or with the
   *   worker's error message.
   */
  private generateText = (
    prompt: string,
    role: "user" | "tool",
    onResponseUpdate: (response: string) => void = () => {}
  ) => {
    return new Promise<{
      response: string;
      metadata: GenerationMetadata;
      interrupted: boolean;
    }>((resolve, reject) => {
      if (this.modelKey === null) {
        reject("Model not initialized");
        return;
      }
      if (
        this.tools === null ||
        this.temperature === null ||
        this.enableThinking === null
      ) {
        reject("Conversation not initialized");
        return;
      }
      const requestId = this.requestId++;
      this.messages = [...this.messages, { role, content: prompt }];
      let response = "";
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_DONE) {
          this.removeWorkerEventListener(listener);
          this.messages.push({ role: "assistant", content: data.response });
          resolve({
            response: data.response,
            metadata: data.metadata,
            interrupted: data.interrupted,
          });
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_CHUNK) {
          response = response + data.chunk;
          onResponseUpdate(response);
        }
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.GENERATE_MESSAGE,
        modelKey: this.modelKey,
        messages: this.messages,
        tools: this.tools.map(webMCPToolToChatTemplateTool),
        requestId,
        temperature: this.temperature,
        enableThinking: this.enableThinking,
      });
    });
  };

  /**
   * Runs one agent turn. It sends `prompt` as a user message and, while the
   * model's reply contains tool calls, executes them and sends their joined
   * results back as a `tool` message.
   *
   * `chatMessages` is updated while tokens stream in and again once tool
   * results arrive. The loop ends when a reply contains no tool calls, or when
   * the generation reports `interrupted` (e.g. after `abort()`). In the
   * interrupted case, tool calls already parsed from the partial reply are
   * not executed.
   *
   * An empty `prompt` is recorded in `chatMessages` but not sent to the model.
   */
  public async runAgent(prompt: string): Promise<void> {
    this.chatMessages = [
      ...this.chatMessages,
      { role: "user", content: prompt },
    ];
    if (!prompt) return;

    let content = prompt;
    let role: "user" | "tool" = "user";

    for (;;) {
      const prevChatMessages = this.chatMessages;
      const assistantMessage: ChatMessageAssistant = {
        role: "assistant",
        content: [],
        interrupted: false,
      };
      this.chatMessages = [...prevChatMessages, assistantMessage];

      const { interrupted, metadata } = await this.generateText(
        content,
        role,
        (partialResponse) => {
          // Re-parse the whole partial response on every chunk so the UI
          // shows text and tool-call parts as they become recognisable.
          assistantMessage.content = splitResponse(partialResponse).map(
            (part) =>
              typeof part === "string"
                ? ({
                    type: "response",
                    content: part,
                  } as ChatMessageAssistantResponse)
                : ({
                    type: "tool",
                    result: null,
                    time: null,
                    functionSignature: `${part.name}(${JSON.stringify(
                      part.arguments
                    )})`,
                    ...part,
                  } as ChatMessageAssistantTool)
          );
          this.chatMessages = [...prevChatMessages, assistantMessage];
        }
      );

      assistantMessage.metadata = metadata;
      assistantMessage.interrupted = interrupted;
      this.chatMessages = [...prevChatMessages, assistantMessage];

      const toolCalls = assistantMessage.content.filter(
        (part): part is ChatMessageAssistantTool => part.type === "tool"
      );
      // Nothing left to do, or the caller aborted: running the tools and
      // starting another generation would silently undo the abort.
      if (toolCalls.length === 0 || interrupted) return;

      const toolResponses = await Promise.all(
        toolCalls.map((tool) =>
          executeToolCall(
            {
              name: tool.name,
              arguments: tool.arguments,
              id: tool.id,
            },
            this.tools ?? []
          )
        )
      );

      assistantMessage.content = assistantMessage.content.map((part) => {
        if (part.type !== "tool") return part;
        const toolResponse = toolResponses.find(
          (response) => response.id === part.id
        );
        return toolResponse
          ? { ...part, result: toolResponse.result, time: toolResponse.time }
          : part;
      });
      this.chatMessages = [...prevChatMessages, assistantMessage];

      // Always send the results back, even if they join to an empty string;
      // otherwise the history would end on an unanswered tool call.
      content = toolResponses.map(({ result }) => result).join("\n");
      role = "tool";
    }
  }
}