Spaces:
Sleeping
Sleeping
| import operator | |
| from helpers import image_to_base64 | |
| import torch | |
| from langgraph.graph import END, StateGraph | |
| from langgraph.types import Send | |
| from typing import Annotated, TypedDict, Any | |
| from transformers import ( | |
| AutoProcessor, | |
| BitsAndBytesConfig, | |
| Gemma3ForConditionalGeneration, | |
| ) | |
def get_quantization_config():
    """Build the bitsandbytes settings for loading the model in 4-bit.

    Returns:
        A ``BitsAndBytesConfig`` requesting 4-bit NF4 weights with double
        quantization and ``torch.float16`` as the compute dtype.
    """
    four_bit_options = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
    }
    return BitsAndBytesConfig(**four_bit_options)
# State schema shared by every node of the workflow graph
class State(TypedDict):
    """Data passed between the graph nodes.

    ``caption_image`` reads ``image`` and writes ``caption``;
    ``map_describe`` reads ``voices`` and ``caption``;
    ``describe_with_voice`` contributes entries to ``descriptions``.
    """

    # Input picture; ``caption_image`` treats it as a PIL image.
    image: Any
    # Voices selected for this run (one description task is created per entry).
    voices: list
    # Plain caption produced by ``caption_image``.
    caption: str
    # Styled descriptions. The ``operator.add`` annotation acts as the
    # LangGraph reducer, so lists returned by separate node runs are
    # concatenated instead of overwriting each other.
    descriptions: Annotated[list, operator.add]
# Assemble and compile the caption -> per-voice description workflow
def build_graph():
    """Create the compiled LangGraph workflow.

    The graph enters at ``caption_image``. ``map_describe`` is the
    conditional-edge function that routes from there to
    ``describe_with_voice``, which in turn leads to ``END``.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(State)

    # Register the processing nodes
    for node_name, node_fn in (
        ("caption_image", caption_image),
        ("describe_with_voice", describe_with_voice),
    ):
        graph.add_node(node_name, node_fn)

    # Wire the nodes together
    graph.set_entry_point("caption_image")
    graph.add_conditional_edges(
        "caption_image",
        map_describe,
        ["describe_with_voice"],
    )
    graph.add_edge("describe_with_voice", END)

    return graph.compile()
# Checkpoint identifier handed to both ``from_pretrained`` calls below
model_id = "google/gemma-3-4b-it"

# The processor and model are created once at import time and shared by
# ``caption_image`` and ``describe_with_voice``.
processor = AutoProcessor.from_pretrained(model_id)

# 4-bit quantization is currently switched off; to enable it, pass
# ``quantization_config=get_quantization_config()`` to ``from_pretrained``.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# ``eval()`` returns the module itself, so rebinding keeps the same object
model = model.eval()
# Persona instructions keyed by voice name
_VOICE_PROMPTS = {
    "scurvy-ridden pirate": "You are a scurvy-ridden pirate, angry and drunk.",
    "forgetful wizard": "You are a forgetful and easily distracted wizard.",
    "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
    "private investigator": "You are a Victorian-age detective. Suave and intellectual.",
    "shakespearian": "Talk like one of Shakespeare's characters. ",
}
# Voice used when the state names none
_DEFAULT_VOICE = "shakespearian"


def _select_voice(state):
    """Return ``state["voice"]``, else the first of ``state["voices"]``, else the default."""
    voice = state.get("voice")
    if voice is not None:
        return voice
    # Evaluated only when no explicit voice was given, and tolerant of a
    # missing, ``None`` or empty ``voices`` entry (the previous eager lookup
    # raised IndexError/TypeError in those cases).
    voices = state.get("voices") or [_DEFAULT_VOICE]
    return voices[0]


def describe_with_voice(state: State):
    """Rewrite the caption in one voice and return it as a markdown section.

    Args:
        state: Mapping with a ``"caption"`` entry and, optionally, ``"voice"``
            (the payload built by ``map_describe``) or ``"voices"``. When no
            voice is given, ``"shakespearian"`` is used.

    Returns:
        ``{"descriptions": [text]}`` where ``text`` is the generated
        description under a ``## <Voice>`` heading. The value is a
        one-element list so it can be merged by the ``operator.add`` reducer
        declared on ``State.descriptions``.

    Raises:
        KeyError: If ``state`` has no ``"caption"`` entry.
    """
    caption = state["caption"]
    voice = _select_voice(state)

    # Unknown voice names fall back to a generic pirate persona.
    system_prompt = (
        _VOICE_PROMPTS.get(voice, "You are a pirate.")
        + " Output 5-10 sentences. Utilize markdown for dramatic text formatting."
    )

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Describe the following:\n\n{caption}"}
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **inputs, max_new_tokens=1000, do_sample=True, temperature=0.9
        )
        # Drop the prompt tokens; keep only what the model generated.
        generation = generation[0][input_len:]

    description = processor.decode(generation, skip_special_tokens=True)
    formatted_description = f"## {voice.title()}\n\n{description}"
    print(formatted_description)

    # note that the return value is a list
    return {"descriptions": [formatted_description]}
def map_describe(state: State) -> list:
    """Fan the caption out to one ``describe_with_voice`` task per voice.

    Args:
        state: Graph state providing ``"voices"`` and ``"caption"``.

    Returns:
        A list with one ``Send`` per entry of ``state["voices"]``, each
        targeting ``"describe_with_voice"`` with a payload holding the
        caption and that voice. Empty when no voices are selected.
    """
    return [
        Send("describe_with_voice", {"caption": state["caption"], "voice": voice_name})
        for voice_name in state["voices"]
    ]
def caption_image(state: State):
    """Generate a short plain caption for the image held in the state.

    Args:
        state: Graph state whose ``"image"`` entry is the picture to
            describe (expected to be a PIL image).

    Returns:
        ``{"caption": text}`` with the decoded model output.
    """
    # The image is passed to the chat template in base64 form
    encoded_image = image_to_base64(state["image"])

    system_turn = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a helpful assistant that will describe images in 3-5 sentences.",
            }
        ],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": encoded_image},
            {"type": "text", "text": "Describe this image."},
        ],
    }

    model_inputs = processor.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(model.device, dtype=torch.bfloat16)
    prompt_length = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        # Sampling is disabled here (do_sample=False)
        output_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=False)
        # Strip the prompt so only newly generated tokens are decoded
        new_token_ids = output_ids[0][prompt_length:]

    caption = processor.decode(new_token_ids, skip_special_tokens=True)
    print(caption)
    return {"caption": caption}