Spaces:
Runtime error
Runtime error
| from dataclasses import asdict, dataclass | |
| from functools import partial | |
| from math import ceil | |
| from typing import Optional | |
| import tiktoken | |
| import yaml | |
| from jinja2 import Environment, StrictUndefined, Template | |
| from PIL import Image | |
| from torch import Tensor, cosine_similarity | |
| from pptagent.llms import LLM, AsyncLLM | |
| from pptagent.utils import get_json_from_response, package_join | |
| ENCODING = tiktoken.encoding_for_model("gpt-4o") | |
| class Turn: | |
| """ | |
| A class to represent a turn in a conversation. | |
| """ | |
| id: int | |
| prompt: str | |
| response: str | |
| message: list | |
| retry: int = -1 | |
| images: list[str] = None | |
| input_tokens: int = 0 | |
| output_tokens: int = 0 | |
| embedding: Tensor = None | |
| def to_dict(self): | |
| return {k: v for k, v in asdict(self).items() if k != "embedding"} | |
| def calc_token(self): | |
| """ | |
| Calculate the number of tokens for the turn. | |
| """ | |
| if self.images is not None: | |
| self.input_tokens += calc_image_tokens(self.images) | |
| self.input_tokens += len(ENCODING.encode(self.prompt)) | |
| self.output_tokens = len(ENCODING.encode(self.response)) | |
| def __eq__(self, other): | |
| return self is other | |
| class Agent: | |
| """ | |
| An agent, defined by its instruction template and model. | |
| """ | |
| def __init__( | |
| self, | |
| name: str, | |
| llm_mapping: dict[str, LLM | AsyncLLM], | |
| text_model: Optional[LLM | AsyncLLM] = None, | |
| record_cost: bool = False, | |
| config: Optional[dict] = None, | |
| env: Optional[Environment] = None, | |
| ): | |
| """ | |
| Initialize the Agent. | |
| Args: | |
| name (str): The name of the role. | |
| env (Environment): The Jinja2 environment. | |
| record_cost (bool): Whether to record the token cost. | |
| llm (LLM): The language model. | |
| config (dict): The configuration. | |
| text_model (LLM): The text embedding model. | |
| """ | |
| self.name = name | |
| self.config = config | |
| if self.config is None: | |
| with open(package_join("roles", f"{name}.yaml"), encoding="utf-8") as f: | |
| self.config = yaml.safe_load(f) | |
| assert isinstance(self.config, dict), "Agent config must be a dict" | |
| self.llm_mapping = llm_mapping | |
| self.llm = self.llm_mapping[self.config["use_model"]] | |
| self.model = self.llm.model | |
| self.record_cost = record_cost | |
| self.text_model = text_model | |
| self.return_json = self.config.get("return_json", False) | |
| self.system_message = self.config["system_prompt"] | |
| self.prompt_args = set(self.config["jinja_args"]) | |
| self.env = env | |
| if self.env is None: | |
| self.env = Environment(undefined=StrictUndefined) | |
| self.template = self.env.from_string(self.config["template"]) | |
| self.retry_template = Template( | |
| """The previous output is invalid, please carefully analyze the traceback and feedback information, correct errors happened before. | |
| feedback: | |
| {{feedback}} | |
| traceback: | |
| {{traceback}} | |
| Give your corrected output in the same format without including the previous output: | |
| """ | |
| ) | |
| self.input_tokens = 0 | |
| self.output_tokens = 0 | |
| self._history: list[Turn] = [] | |
| run_args = self.config.get("run_args", {}) | |
| self.llm.__call__ = partial(self.llm.__call__, **run_args) | |
| self.system_tokens = len(ENCODING.encode(self.system_message)) | |
| def calc_cost(self, turns: list[Turn]): | |
| """ | |
| Calculate the cost of a list of turns. | |
| """ | |
| for turn in turns[:-1]: | |
| self.input_tokens += turn.input_tokens | |
| self.input_tokens += turn.output_tokens | |
| self.input_tokens += turns[-1].input_tokens | |
| self.output_tokens += turns[-1].output_tokens | |
| self.input_tokens += self.system_tokens | |
| def get_history(self, similar: int, recent: int, prompt: str): | |
| """ | |
| Get the conversation history. | |
| """ | |
| history = self._history[-recent:] if recent > 0 else [] | |
| if similar > 0: | |
| assert isinstance(self.text_model, LLM), "text_model must be a LLM" | |
| embedding = self.text_model.get_embedding(prompt) | |
| history.sort(key=lambda x: cosine_similarity(embedding, x.embedding)) | |
| for turn in history: | |
| if len(history) > similar + recent: | |
| break | |
| if turn not in history: | |
| history.append(turn) | |
| history.sort(key=lambda x: x.id) | |
| return history | |
| def retry(self, feedback: str, traceback: str, turn_id: int, error_idx: int): | |
| """ | |
| Retry a failed turn with feedback and traceback. | |
| """ | |
| assert error_idx > 0, "error_idx must be greater than 0" | |
| prompt = self.retry_template.render(feedback=feedback, traceback=traceback) | |
| history = [t for t in self._history if t.id == turn_id] | |
| history_msg = [] | |
| for turn in history: | |
| history_msg.extend(turn.message) | |
| response, message = self.llm( | |
| prompt, | |
| history=history_msg, | |
| return_message=True, | |
| ) | |
| turn = Turn( | |
| id=turn_id, | |
| prompt=prompt, | |
| response=response, | |
| message=message, | |
| retry=error_idx, | |
| ) | |
| return self.__post_process__(response, history, turn) | |
| def to_sync(self): | |
| """ | |
| Convert the agent to a synchronous agent. | |
| """ | |
| return Agent( | |
| self.name, | |
| self.llm_mapping, | |
| self.text_model, | |
| self.record_cost, | |
| self.config, | |
| self.env, | |
| ) | |
| def to_async(self): | |
| """ | |
| Convert the agent to an asynchronous agent. | |
| """ | |
| return AsyncAgent( | |
| self.name, | |
| self.llm_mapping, | |
| self.text_model, | |
| self.record_cost, | |
| self.config, | |
| self.env, | |
| ) | |
| def next_turn_id(self): | |
| if len(self._history) == 0: | |
| return 0 | |
| return max(t.id for t in self._history) + 1 | |
| def history(self): | |
| return sorted(self._history, key=lambda x: (x.id, x.retry)) | |
| def __repr__(self) -> str: | |
| return f"{self.__class__.__name__}(name={self.name}, model={self.model})" | |
| def __call__( | |
| self, | |
| images: list[str] = None, | |
| recent: int = 0, | |
| similar: int = 0, | |
| **jinja_args, | |
| ): | |
| """ | |
| Call the agent with prompt arguments. | |
| Args: | |
| images (list[str]): A list of image file paths. | |
| recent (int): The number of recent turns to include. | |
| similar (int): The number of similar turns to include. | |
| **jinja_args: Additional arguments for the Jinja2 template. | |
| Returns: | |
| The response from the role. | |
| """ | |
| if isinstance(images, str): | |
| images = [images] | |
| assert self.prompt_args == set( | |
| jinja_args.keys() | |
| ), f"Invalid arguments, expected: {self.prompt_args}, got: {jinja_args.keys()}" | |
| prompt = self.template.render(**jinja_args) | |
| history = self.get_history(similar, recent, prompt) | |
| history_msg = [] | |
| for turn in history: | |
| history_msg.extend(turn.message) | |
| response, message = self.llm( | |
| prompt, | |
| system_message=self.system_message, | |
| history=history_msg, | |
| images=images, | |
| return_message=True, | |
| ) | |
| turn = Turn( | |
| id=self.next_turn_id, | |
| prompt=prompt, | |
| response=response, | |
| message=message, | |
| images=images, | |
| ) | |
| return turn.id, self.__post_process__(response, history, turn, similar) | |
| def __post_process__( | |
| self, response: str, history: list[Turn], turn: Turn, similar: int = 0 | |
| ) -> str | dict: | |
| """ | |
| Post-process the response from the agent. | |
| """ | |
| self._history.append(turn) | |
| if similar > 0: | |
| turn.embedding = self.text_model.get_embedding(turn.prompt) | |
| if self.record_cost: | |
| turn.calc_token() | |
| self.calc_cost(history + [turn]) | |
| if self.return_json: | |
| response = get_json_from_response(response) | |
| return response | |
| class AsyncAgent(Agent): | |
| """ | |
| An agent, defined by its instruction template and model. | |
| """ | |
| def __init__( | |
| self, | |
| name: str, | |
| llm_mapping: dict[str, AsyncLLM], | |
| text_model: Optional[AsyncLLM] = None, | |
| record_cost: bool = False, | |
| config: Optional[dict] = None, | |
| env: Optional[Environment] = None, | |
| ): | |
| super().__init__(name, llm_mapping, text_model, record_cost, config, env) | |
| self.llm = self.llm.to_async() | |
| async def retry(self, feedback: str, traceback: str, turn_id: int, error_idx: int): | |
| """ | |
| Retry a failed turn with feedback and traceback. | |
| """ | |
| assert error_idx > 0, "error_idx must be greater than 0" | |
| prompt = self.retry_template.render(feedback=feedback, traceback=traceback) | |
| history = [t for t in self._history if t.id == turn_id] | |
| history_msg = [] | |
| for turn in history: | |
| history_msg.extend(turn.message) | |
| response, message = await self.llm( | |
| prompt, | |
| history=history_msg, | |
| return_message=True, | |
| ) | |
| turn = Turn( | |
| id=turn_id, | |
| prompt=prompt, | |
| response=response, | |
| message=message, | |
| retry=error_idx, | |
| ) | |
| return await self.__post_process__(response, history, turn) | |
| async def __call__( | |
| self, | |
| images: list[str] = None, | |
| recent: int = 0, | |
| similar: int = 0, | |
| **jinja_args, | |
| ): | |
| """ | |
| Call the agent with prompt arguments. | |
| Args: | |
| images (list[str]): A list of image file paths. | |
| recent (int): The number of recent turns to include. | |
| similar (int): The number of similar turns to include. | |
| **jinja_args: Additional arguments for the Jinja2 template. | |
| Returns: | |
| The response from the role. | |
| """ | |
| if isinstance(images, str): | |
| images = [images] | |
| assert self.prompt_args == set( | |
| jinja_args.keys() | |
| ), f"Invalid arguments, expected: {self.prompt_args}, got: {jinja_args.keys()}" | |
| prompt = self.template.render(**jinja_args) | |
| history = await self.get_history(similar, recent, prompt) | |
| history_msg = [] | |
| for turn in history: | |
| history_msg.extend(turn.message) | |
| response, message = await self.llm( | |
| prompt, | |
| system_message=self.system_message, | |
| history=history_msg, | |
| images=images, | |
| return_message=True, | |
| ) | |
| turn = Turn( | |
| id=self.next_turn_id, | |
| prompt=prompt, | |
| response=response, | |
| message=message, | |
| images=images, | |
| ) | |
| return turn.id, await self.__post_process__(response, history, turn, similar) | |
| async def get_history(self, similar: int, recent: int, prompt: str): | |
| """ | |
| Get the conversation history. | |
| """ | |
| history = self._history[-recent:] if recent > 0 else [] | |
| if similar > 0: | |
| embedding = await self.text_model.get_embedding(prompt) | |
| history.sort(key=lambda x: cosine_similarity(embedding, x.embedding)) | |
| for turn in history: | |
| if len(history) > similar + recent: | |
| break | |
| if turn not in history: | |
| history.append(turn) | |
| history.sort(key=lambda x: x.id) | |
| return history | |
| async def __post_process__( | |
| self, response: str, history: list[Turn], turn: Turn, similar: int = 0 | |
| ): | |
| """ | |
| Post-process the response from the agent. | |
| """ | |
| self._history.append(turn) | |
| if similar > 0: | |
| turn.embedding = await self.text_model.get_embedding(turn.prompt) | |
| if self.record_cost: | |
| turn.calc_token() | |
| self.calc_cost(history + [turn]) | |
| if self.return_json: | |
| response = get_json_from_response(response) | |
| return response | |
| def calc_image_tokens(images: list[str]): | |
| """ | |
| Calculate the number of tokens for a list of images. | |
| """ | |
| tokens = 0 | |
| for image in images: | |
| with open(image, "rb") as f: | |
| width, height = Image.open(f).size | |
| if width > 1024 or height > 1024: | |
| if width > height: | |
| height = int(height * 1024 / width) | |
| width = 1024 | |
| else: | |
| width = int(width * 1024 / height) | |
| height = 1024 | |
| h = ceil(height / 512) | |
| w = ceil(width / 512) | |
| tokens += 85 + 170 * h * w | |
| return tokens | |