Spaces:
Runtime error
Runtime error
| import asyncio | |
| import re | |
| import traceback | |
| from dataclasses import asdict, dataclass | |
| from datetime import datetime | |
| from typing import Any, Optional | |
| from jinja2 import Environment, StrictUndefined | |
| from torch import cosine_similarity | |
| from pptagent.agent import Agent, AsyncAgent | |
| from pptagent.llms import LLM, AsyncLLM | |
| from pptagent.utils import edit_distance, get_logger, package_join, pexists | |
| from .element import Section, SubSection, Table, link_medias | |
logger = get_logger(__name__)

# Jinja environment that raises on undefined template variables, so template
# bugs surface immediately instead of silently rendering empty strings.
env = Environment(undefined=StrictUndefined)


def _load_prompt(filename: str):
    """Load a prompt template from the package's prompts directory.

    Uses a context manager so the file handle is closed deterministically
    (the previous `open(...).read()` form leaked the handle).
    """
    with open(package_join("prompts", filename)) as f:
        return env.from_string(f.read())


MERGE_METADATA_PROMPT = _load_prompt("merge_metadata.txt")
HEADING_EXTRACT_PROMPT = _load_prompt("heading_extract.txt")
SECTION_SUMMARY_PROMPT = _load_prompt("section_summary.txt")

# Matches a markdown image reference, e.g. ![alt](path).
MARKDOWN_IMAGE_REGEX = re.compile(r"!\[.*\]\(.*\)")
# Matches a markdown table row, e.g. | a | b |.
MARKDOWN_TABLE_REGEX = re.compile(r"\|.*\|")
def split_markdown_by_headings(
    markdown_content: str,
    headings: list[str],
    adjusted_headings: list[str],
    min_chunk_size: int = 64,
) -> list[str]:
    """
    Split markdown content using headings as separators without regex.

    Args:
        markdown_content (str): The markdown content to split.
        headings (list[str]): Headings found in the content; each adjusted
            heading is mapped back to its most similar entry here.
        adjusted_headings (list[str]): Headings chosen as split points.
        min_chunk_size (int): Chunks shorter than this are merged into a
            neighboring chunk.

    Returns:
        list[str]: List of content sections (empty for empty input).
    """
    # Map each adjusted heading back to the most similar original heading so
    # that the line matching below compares against text that actually occurs
    # in the document.
    adjusted_headings = [
        max(headings, key=lambda x: edit_distance(x, ah)) for ah in adjusted_headings
    ]
    sections: list[str] = []
    current_section: list[str] = []
    for line in markdown_content.splitlines():
        if any(line.strip().startswith(h) for h in adjusted_headings):
            if current_section:
                sections.append("\n".join(current_section).strip())
            current_section = [line]
        else:
            current_section.append(line)
    if current_section:
        sections.append("\n".join(current_section).strip())
    # If a chunk is too small, merge it into the previous chunk.
    for i in reversed(range(1, len(sections))):
        if len(sections[i]) < min_chunk_size:
            sections[i - 1] += sections[i]
            sections.pop(i)
    # The first chunk has no previous neighbor, so merge it forward instead.
    # Guard the indexing: the original unconditionally touched sections[0]
    # and sections[1], raising IndexError for empty or single-chunk results.
    if len(sections) > 1 and len(sections[0]) < min_chunk_size:
        sections[0] += sections[1]
        sections.pop(1)
    return sections
def to_paragraphs(original_text: str, max_chunk_size: int = 256):
    """
    Collect the media paragraphs (tables and images) of a markdown text.

    The text is split on blank lines; each paragraph that looks like a table
    or an image is recorded together with the text chunks surrounding it
    (capped at roughly `max_chunk_size` characters on each side), stored
    under the "near_chunks" key as a (before, after) pair.

    Args:
        original_text (str): Markdown text to scan.
        max_chunk_size (int): Soft limit for each context chunk.

    Returns:
        list[dict]: One dict per media paragraph, carrying its markdown
        content, paragraph index, type ("table"/"image"), and near_chunks.
    """
    text_chunks = []
    media_chunks = []
    for position, block in enumerate(original_text.split("\n\n")):
        block = block.strip()
        if not block:
            continue
        entry = {"markdown_content": block, "index": position}
        if MARKDOWN_TABLE_REGEX.match(block):
            entry["type"] = "table"
            media_chunks.append(entry)
        elif MARKDOWN_IMAGE_REGEX.match(block):
            entry["type"] = "image"
            media_chunks.append(entry)
        else:
            text_chunks.append(entry)
    for entry in media_chunks:
        # Gather text immediately before the media, walking backwards so the
        # nearest paragraphs are accumulated first.
        before = ""
        for chunk in reversed(text_chunks):
            if chunk["index"] >= entry["index"]:
                continue
            before += chunk["markdown_content"] + "\n\n"
            if len(before) > max_chunk_size:
                break
        # Gather text immediately after the media, in document order.
        after = ""
        for chunk in text_chunks:
            if chunk["index"] <= entry["index"]:
                continue
            after += chunk["markdown_content"] + "\n\n"
            if len(after) > max_chunk_size:
                break
        entry["near_chunks"] = (before, after)
    return media_chunks
@dataclass
class Document:
    """
    A markdown document parsed into sections/subsections with media
    (images and tables), plus document-level metadata.

    Attributes:
        image_dir: Directory holding the images referenced by the document.
        sections: Parsed sections, in document order.
        metadata: Document-level metadata; a "presentation-date" entry is
            stamped automatically on construction.

    Note: the `@dataclass` and `@classmethod` decorators are required by the
    code itself (`__post_init__`, `asdict(self)`, keyword construction via
    `cls(...)`) and were missing from the mangled paste.
    """

    image_dir: str
    sections: list[Section]
    metadata: dict[str, str]

    def __post_init__(self):
        # Record when the document was parsed; prompts can reference this.
        self.metadata["presentation-date"] = datetime.now().strftime("%Y-%m-%d")

    def iter_medias(self):
        """Yield every media element (image/table) across all sections."""
        for section in self.sections:
            yield from section.iter_medias()

    def get_table(self, image_path: str):
        """Return the Table whose path equals `image_path`.

        Raises:
            ValueError: If no table with that path exists.
        """
        for media in self.iter_medias():
            if media.path == image_path and isinstance(media, Table):
                return media
        raise ValueError(f"table not found: {image_path}")

    @classmethod
    def from_dict(
        cls, data: dict[str, Any], image_dir: str, require_caption: bool = True
    ):
        """
        Create a Document from an already-parsed dict.

        Args:
            data: Must contain "sections" and "metadata" keys.
            image_dir: Directory containing the referenced images.
            require_caption: Whether media elements must carry a caption.

        Returns:
            Document: The validated document.
        """
        assert (
            "sections" in data
        ), f"'sections' key is required in data dictionary but was not found. Input keys: {list(data.keys())}"
        assert (
            "metadata" in data
        ), f"'metadata' key is required in data dictionary but was not found. Input keys: {list(data.keys())}"
        assert pexists(image_dir), f"image directory is not found: {image_dir}"
        document = cls(
            image_dir=image_dir,
            sections=[Section.from_dict(section) for section in data["sections"]],
            metadata=data["metadata"],
        )
        for section in document.sections:
            section.validate_medias(image_dir, require_caption)
        return document

    @classmethod
    def _parse_chunk(
        cls,
        extractor: Agent,
        language_model: LLM,
        vision_model: LLM,
        table_model: LLM,
        metadata: Optional[dict[str, Any]],
        section: Optional[dict[str, Any]],
        image_dir: str,
        turn_id: Optional[int] = None,
        retry: int = 0,
        medias: Optional[list[dict]] = None,
    ):
        """
        Extract one markdown chunk into a Section via the extractor agent,
        retrying up to 3 times by feeding the error back to the agent.

        On the first call `section` is the raw markdown chunk; on retries it
        is the agent's corrected dict and `metadata`/`turn_id`/`medias` carry
        over from the first attempt.

        Returns:
            tuple: (metadata dict popped from the extraction, Section).
        """
        if retry == 0:
            medias = to_paragraphs(section)
            turn_id, section = extractor(markdown_document=section)
            metadata = section.pop("metadata", {})
        try:
            section["subsections"] = link_medias(medias, section["subsections"])
            section = Section.from_dict(section)
            for media in section.iter_medias():
                media.parse(table_model, image_dir)
                # Tables are captioned from their text; other media need vision.
                if isinstance(media, Table):
                    media.get_caption(language_model)
                else:
                    media.get_caption(vision_model)
            section.validate_medias(image_dir, False)
        except Exception as e:
            if retry < 3:
                logger.info("Retry section with error: %s", str(e))
                new_section = extractor.retry(
                    str(e), traceback.format_exc(), turn_id, retry + 1
                )
                return cls._parse_chunk(
                    extractor,
                    language_model,
                    vision_model,
                    table_model,
                    metadata,
                    new_section,
                    image_dir,
                    turn_id,
                    retry + 1,
                    medias,
                )
            else:
                logger.error(
                    "Failed to extract section, tried %d times",
                    retry,
                    exc_info=e,
                )
                raise e
        return metadata, section

    @classmethod
    async def _parse_chunk_async(
        cls,
        extractor: AsyncAgent,
        language_model: AsyncLLM,
        vision_model: AsyncLLM,
        table_model: Optional[AsyncLLM],
        metadata: Optional[dict[str, Any]],
        section: Optional[dict[str, Any]],
        image_dir: str,
        turn_id: Optional[int] = None,
        retry: int = 0,
        medias: Optional[list[dict]] = None,
    ):
        """Async counterpart of `_parse_chunk`; same contract and retry flow."""
        if retry == 0:
            medias = to_paragraphs(section)
            turn_id, section = await extractor(markdown_document=section)
            metadata = section.pop("metadata", {})
        try:
            section["subsections"] = link_medias(medias, section["subsections"])
            section = Section.from_dict(section)
            for media in section.iter_medias():
                await media.parse_async(table_model, image_dir)
                if isinstance(media, Table):
                    await media.get_caption_async(language_model)
                else:
                    await media.get_caption_async(vision_model)
            section.validate_medias(image_dir, False)
        except Exception as e:
            if retry < 3:
                logger.info("Retry section with error: %s", str(e))
                new_section = await extractor.retry(
                    str(e), traceback.format_exc(), turn_id, retry + 1
                )
                return await cls._parse_chunk_async(
                    extractor,
                    language_model,
                    vision_model,
                    table_model,
                    metadata,
                    new_section,
                    image_dir,
                    turn_id,
                    retry + 1,
                    medias,
                )
            else:
                logger.error(
                    "Failed to extract section, tried %d times",
                    retry,
                    exc_info=e,
                )
                raise e
        return metadata, section

    @classmethod
    def from_markdown(
        cls,
        markdown_content: str,
        language_model: LLM,
        vision_model: LLM,
        image_dir: str,
        table_model: Optional[LLM] = None,
    ):
        """
        Create a Document from markdown content.

        Args:
            markdown_content (str): The markdown content.
            language_model (LLM): The language model.
            vision_model (LLM): The vision model.
            image_dir (str): The directory containing images.
            table_model (Optional[LLM]): Model used to parse tables, if any.

        Returns:
            Document: The created document.
        """
        doc_extractor = Agent(
            "doc_extractor",
            llm_mapping={"language": language_model, "vision": vision_model},
        )
        metadata_list = []
        sections = []
        headings = re.findall(r"^#+\s+.*", markdown_content, re.MULTILINE)
        # Let the language model pick which headings to split on.
        adjusted_headings = language_model(
            HEADING_EXTRACT_PROMPT.render(headings=headings), return_json=True
        )
        for chunk in split_markdown_by_headings(
            markdown_content, headings, adjusted_headings
        ):
            metadata, section = cls._parse_chunk(
                doc_extractor,
                language_model,
                vision_model,
                table_model,
                None,
                chunk,
                image_dir,
            )
            section.summary = language_model(
                SECTION_SUMMARY_PROMPT.render(section_content=chunk),
            )
            metadata_list.append(metadata)
            sections.append(section)
        merged_metadata = language_model(
            MERGE_METADATA_PROMPT.render(metadata=metadata_list), return_json=True
        )
        return Document(
            image_dir=image_dir, metadata=merged_metadata, sections=sections
        )

    @classmethod
    async def from_markdown_async(
        cls,
        markdown_content: str,
        language_model: AsyncLLM,
        vision_model: AsyncLLM,
        image_dir: str,
        table_model: Optional[AsyncLLM] = None,
    ):
        """
        Async counterpart of `from_markdown`: chunk extraction and chunk
        summarization run concurrently, results assembled in document order.
        """
        doc_extractor = AsyncAgent(
            "doc_extractor",
            llm_mapping={"language": language_model, "vision": vision_model},
        )
        headings = re.findall(r"^#+\s+.*", markdown_content, re.MULTILINE)
        adjusted_headings = await language_model(
            HEADING_EXTRACT_PROMPT.render(headings=headings), return_json=True
        )
        metadata = []
        sections = []
        tasks = []
        async with asyncio.TaskGroup() as tg:
            for chunk in split_markdown_by_headings(
                markdown_content, headings, adjusted_headings
            ):
                parse_task = tg.create_task(
                    cls._parse_chunk_async(
                        doc_extractor,
                        language_model,
                        vision_model,
                        table_model,
                        None,
                        chunk,
                        image_dir,
                    )
                )
                summary_task = tg.create_task(
                    language_model(
                        SECTION_SUMMARY_PROMPT.render(section_content=chunk),
                    )
                )
                tasks.append((parse_task, summary_task))
        # Process results in order. BUGFIX: each section now gets its own
        # chunk's summary; the original assigned the last chunk's summary
        # task result to every section in a separate trailing loop.
        for parse_task, summary_task in tasks:
            meta, section = parse_task.result()
            section.summary = summary_task.result()
            metadata.append(meta)
            sections.append(section)
        merged_metadata = await language_model(
            MERGE_METADATA_PROMPT.render(metadata=metadata), return_json=True
        )
        return Document(
            image_dir=image_dir, metadata=merged_metadata, sections=sections
        )

    def __contains__(self, key: str):
        """True if a section with title `key` exists."""
        for section in self.sections:
            if section.title == key:
                return True
        return False

    def __getitem__(self, key: str):
        """Return the section titled `key`; raise KeyError listing options."""
        for section in self.sections:
            if section.title == key:
                return section
        raise KeyError(
            f"section not found: {key}, available sections: {[section.title for section in self.sections]}"
        )

    def to_dict(self):
        """Serialize the document (recursively) to plain dicts/lists."""
        return asdict(self)

    def retrieve(
        self,
        indexs: dict[str, list[str]],
    ) -> list[SubSection]:
        """
        Resolve a two-level index {section title: [subsection titles]} into
        the corresponding SubSection objects, in index order.
        """
        assert isinstance(
            indexs, dict
        ), "subsection_keys for index must be a dict, follow a two-level structure"
        subsecs = []
        for sec_key, subsec_keys in indexs.items():
            section = self[sec_key]
            for subsec_key in subsec_keys:
                subsecs.append(section[subsec_key])
        return subsecs

    def find_caption(self, caption: str):
        """Return the path of the media whose caption matches exactly.

        Raises:
            ValueError: If no media has that caption.
        """
        for media in self.iter_medias():
            if media.caption == caption:
                return media.path
        raise ValueError(f"Image caption not found: {caption}")

    def get_overview(self, include_summary: bool = False):
        """Render a plain-text outline of sections/subsections/media captions."""
        overview = ""
        for section in self.sections:
            overview += f"Section: {section.title}\n"
            if include_summary:
                overview += f"\tSummary: {section.summary}\n"
            for subsection in section.subsections:
                overview += f"\tSubsection: {subsection.title}\n"
                for media in subsection.medias:
                    overview += f"\t\tMedia: {media.caption}\n"
            overview += "\n"
        return overview

    def metainfo(self):
        """Render metadata as `key: value` lines."""
        return "\n".join([f"{k}: {v}" for k, v in self.metadata.items()])

    def subsections(self):
        """Return all subsections of all sections, flattened in order."""
        return [subsec for section in self.sections for subsec in section.subsections]
@dataclass
class OutlineItem:
    """
    One slide's entry in a presentation outline.

    The `@dataclass` decorator is required: `from_dict` constructs instances
    with keyword arguments (`cls(purpose=..., ...)`), which needs the
    generated `__init__`; it was lost in the mangled paste.

    Attributes:
        purpose: What the slide is meant to convey.
        section: Title of the document section this slide draws from.
        indexs: Two-level retrieval index {section title: [subsection titles]}
            (or a raw string before normalization).
        images: Captions of images referenced by the slide.
    """

    purpose: str
    section: str
    indexs: dict[str, list[str]] | str
    images: list[str]
| def from_dict(cls, data: dict[str, Any]): | |
| assert ( | |
| "purpose" in data and "section" in data | |
| ), "purpose and section of outline item are required" | |
| return cls( | |
| purpose=data["purpose"], | |
| section=data["section"], | |
| indexs=data.get("indexs", {}), | |
| images=data.get("images", []), | |
| ) | |
| def retrieve(self, slide_idx: int, document: Document): | |
| subsections = document.retrieve(self.indexs) | |
| header = f"Slide-{slide_idx+1}: {self.purpose}\n" | |
| content = "" | |
| for subsection in subsections: | |
| content += f"Paragraph: {subsection.title}\nContent: {subsection.content}\n" | |
| images = [ | |
| f"Image: {document.find_caption(caption)}\nCaption: {caption}" | |
| for caption in self.images | |
| ] | |
| return header, content, images | |
    def check_retrieve(self, document: Document, sim_bound: float):
        """
        Rewrite `self.indexs` in place so every section/subsection key refers
        to a real title in `document`, fuzzily matched; raise ValueError when
        even the best match falls below `sim_bound`.

        NOTE(review): edit_distance appears to return a similarity score
        (higher = closer), given the `< sim_bound` failure condition here and
        the `> sim_bound` acceptance in check_images — confirm in utils.
        NOTE(review): keys are renamed before the similarity check, so a
        failing lookup still leaves `self.indexs` partially rewritten.
        """
        # Snapshot items(): the dict is mutated (pop/insert) while iterating.
        for sec_key, subsec_keys in list(self.indexs.items()):
            # Most similar existing section title for this key.
            section = max(
                document.sections, key=lambda x: edit_distance(x.title, sec_key)
            )
            # Canonicalize the key to the matched section title.
            self.indexs[section.title] = self.indexs.pop(sec_key)
            if edit_distance(section.title, sec_key) < sim_bound:
                logger.warning(
                    f"section not found: {sec_key}, available sections: {[section.title for section in document.sections]}.",
                )
                raise ValueError(
                    f"section not found: {sec_key}, available sections: {[section.title for section in document.sections]}."
                )
            for idx in range(len(subsec_keys)):
                # Most similar subsection title within the matched section.
                subsection = max(
                    section.subsections,
                    key=lambda x: edit_distance(x.title, subsec_keys[idx]),
                )
                # Canonicalize the subsection entry likewise.
                self.indexs[section.title][idx] = subsection.title
                if edit_distance(subsection.title, subsec_keys[idx]) < sim_bound:
                    raise ValueError(
                        f"subsection {subsec_keys[idx]} not found in section {section.title}, available subsections: {[subsection.title for subsection in section.subsections]}."
                    )
| def check_images(self, document: Document, text_model: LLM, sim_bound: float): | |
| doc_images = list(document.iter_medias()) | |
| image_embeddings = [] | |
| for idx, image in enumerate(self.images): | |
| if len(doc_images) == 0: | |
| raise ValueError("Document does not contain any images.") | |
| similar = max(doc_images, key=lambda x: edit_distance(x.caption, image)) | |
| if edit_distance(similar.caption, image) > sim_bound: | |
| self.images[idx] = similar.caption | |
| continue | |
| if len(image_embeddings) == 0: | |
| image_embeddings.extend( | |
| [text_model.get_embedding(image) for image in self.images] | |
| ) | |
| embedding = text_model.get_embedding(image) | |
| similar = max( | |
| range(len(image_embeddings)), | |
| key=lambda x: cosine_similarity(embedding, image_embeddings[x]), | |
| ) | |
| if cosine_similarity(embedding, image_embeddings[similar]) > sim_bound: | |
| self.images[idx] = doc_images[similar].caption | |
| else: | |
| logger.warning( | |
| f"image not found: {image}, available images: {[image.caption for image in doc_images]}.", | |
| ) | |
| raise ValueError( | |
| f"image not found: {image}, available images: \n{[image.caption for image in doc_images]}\nPlease ensure the caption is exactly matched." | |
| ) | |
    async def check_images_async(
        self, document: Document, text_model: AsyncLLM, sim_bound: float
    ):
        """
        Async variant of `check_images`: canonicalize each entry of
        `self.images` to a caption present in `document`, first by
        edit-distance similarity, then by embedding cosine similarity.

        NOTE(review): the embeddings are computed from `self.images`, yet the
        winning index is applied to `doc_images`; the two lists can differ in
        length and order, so this looks like a list mix-up — compare the sync
        `check_images`. Also, the sync version ends with an else-branch that
        raises when no embedding clears `sim_bound`; this method may continue
        beyond the visible excerpt.
        """
        doc_images = list(document.iter_medias())
        image_embeddings = []
        for idx, image in enumerate(self.images):
            if len(doc_images) == 0:
                raise ValueError("Document does not contain any images.")
            # Best caption by edit-distance similarity (higher = closer).
            similar = max(doc_images, key=lambda x: edit_distance(x.caption, image))
            if edit_distance(similar.caption, image) > sim_bound:
                self.images[idx] = similar.caption
                continue
            # Lazily compute embeddings once, concurrently, on first fallback.
            if len(image_embeddings) == 0:
                image_embeddings = await asyncio.gather(
                    *[text_model.get_embedding(image) for image in self.images]
                )
            embedding = await text_model.get_embedding(image)
            similar = max(
                range(len(image_embeddings)),
                key=lambda x: cosine_similarity(embedding, image_embeddings[x]),
            )
            if cosine_similarity(embedding, image_embeddings[similar]) > sim_bound:
                self.images[idx] = doc_images[similar].caption