from collections.abc import Callable
import torch
import torchvision.transforms.functional as F
import io
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageColor, ImageFont
import random
import numpy as np
import re
from pathlib import Path
from packaging import version

# Workaround for the unnecessary flash_attn requirement in the Florence-2 remote code
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import transformers
from safetensors.torch import load_file, save_file

def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Drop 'flash_attn' from the imports reported for modeling_florence2.py."""
    imports = get_imports(filename)
    if str(filename).endswith("modeling_florence2.py"):
        try:
            imports.remove("flash_attn")
        except ValueError:
            print("No flash_attn import to remove")
    return imports
def create_path_dict(paths: list[str], predicate: Callable[[Path], bool] = lambda _: True) -> dict[str, str]:
    """
    Creates a flat dictionary of the contents of all given paths: ``{name: absolute_path}``.
    Non-recursive. Optionally takes a predicate to filter items. Duplicate names overwrite (the last one wins).

    Args:
        paths (list[str]):
            The paths to search for items.
        predicate (Callable[[Path], bool]):
            (Optional) If provided, each path is tested against this filter.
            Returns ``True`` to include a path.
            Default: include everything.
    """
    flattened_paths = [item for path in paths if Path(path).exists() for item in Path(path).iterdir() if predicate(item)]
    return {item.name: str(item.absolute()) for item in flattened_paths}
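
# Illustrative sketch (not executed, paths are made up): given a hypothetical layout
#   /data/models/LLM/Florence-2-base/  and  /data/extra/LLM/Florence-2-large/,
# calling
#   create_path_dict(["/data/models/LLM", "/data/extra/LLM"], lambda p: p.is_dir())
# would return something like
#   {"Florence-2-base": "/data/models/LLM/Florence-2-base",
#    "Florence-2-large": "/data/extra/LLM/Florence-2-large"}
# Only the item names become keys, so a name that appears in more than one search path
# keeps the last match.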

import comfy.model_management as mm
from comfy.utils import ProgressBar
import folder_paths

script_directory = os.path.dirname(os.path.abspath(__file__))
model_directory = os.path.join(folder_paths.models_dir, "LLM")
os.makedirs(model_directory, exist_ok=True)
# Ensure ComfyUI knows about the LLM model path
folder_paths.add_model_folder_path("LLM", model_directory)

from transformers import AutoModelForCausalLM, AutoProcessor, set_seed

class DownloadAndLoadFlorence2Model:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": (
                [
                    'microsoft/Florence-2-base',
                    'microsoft/Florence-2-base-ft',
                    'microsoft/Florence-2-large',
                    'microsoft/Florence-2-large-ft',
                    'HuggingFaceM4/Florence-2-DocVQA',
                    'thwri/CogFlorence-2.1-Large',
                    'thwri/CogFlorence-2.2-Large',
                    'gokaygokay/Florence-2-SD3-Captioner',
                    'gokaygokay/Florence-2-Flux-Large',
                    'MiaoshouAI/Florence-2-base-PromptGen-v1.5',
                    'MiaoshouAI/Florence-2-large-PromptGen-v1.5',
                    'MiaoshouAI/Florence-2-base-PromptGen-v2.0',
                    'MiaoshouAI/Florence-2-large-PromptGen-v2.0'
                ],
                {
                    "default": 'microsoft/Florence-2-base'
                }),
            "precision": (['fp16', 'bf16', 'fp32'],
                {
                    "default": 'fp16'
                }),
            "attention": (
                ['flash_attention_2', 'sdpa', 'eager'],
                {
                    "default": 'sdpa'
                }),
            },
            "optional": {
                "lora": ("PEFTLORA",),
                "convert_to_safetensors": ("BOOLEAN", {"default": False, "tooltip": "Some older model weights are not saved in .safetensors format, which seems to cause longer loading times; this option converts the .bin weights to .safetensors"}),
            }
        }

    RETURN_TYPES = ("FL2MODEL",)
    RETURN_NAMES = ("florence2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "Florence2"

    def loadmodel(self, model, precision, attention, lora=None, convert_to_safetensors=False):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        model_name = model.rsplit('/', 1)[-1]
        model_path = os.path.join(model_directory, model_name)

        if not os.path.exists(model_path):
            print(f"Downloading Florence2 model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=model,
                              local_dir=model_path,
                              local_dir_use_symlinks=False)

        print(f"Florence2 using {attention} for attention")
        from transformers import AutoConfig

        # Manually load the state dict to CPU to avoid issues with ZeroGPU patching
        print("Manually loading weights to CPU...")
        # Prefer safetensors if present, otherwise fall back to the .bin weights
        weights_path = os.path.join(model_path, "model.safetensors")
        if os.path.exists(weights_path):
            state_dict = load_file(weights_path, device="cpu")
        else:
            weights_path = os.path.join(model_path, "pytorch_model.bin")
            state_dict = torch.load(weights_path, map_location="cpu")
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        if version.parse(transformers.__version__) < version.parse('4.51.0'):
            # Workaround for the unnecessary flash_attn requirement
            with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
                model = AutoModelForCausalLM.from_pretrained(
                    None, config=config, state_dict=state_dict, attn_implementation=attention,
                    torch_dtype=dtype, trust_remote_code=True
                )
        else:
            from .modeling_florence2 import Florence2ForConditionalGeneration
            model = Florence2ForConditionalGeneration.from_pretrained(
                None, config=config, state_dict=state_dict, attn_implementation=attention, torch_dtype=dtype
            )

        # The model is already on CPU; the run node moves it to the GPU when needed.
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        if lora is not None:
            from peft import PeftModel
            adapter_name = lora
            model = PeftModel.from_pretrained(model, adapter_name, trust_remote_code=True)

        # Note: convert_to_safetensors is accepted for workflow compatibility, but the
        # conversion itself is not performed in this loader (see the sketch below).
        florence2_model = {
            'model': model,
            'processor': processor,
            'dtype': dtype
        }

        return (florence2_model,)
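
# The convert_to_safetensors option above is currently a no-op in this file, and
# safetensors.torch.save_file is imported but never used. The sketch below is an
# illustration only of what the conversion could look like; the helper name and the
# shared-tensor handling are assumptions, not part of the original node.
def _convert_bin_to_safetensors_sketch(model_path: str) -> None:
    """Sketch: convert pytorch_model.bin in model_path to model.safetensors."""
    bin_path = os.path.join(model_path, "pytorch_model.bin")
    st_path = os.path.join(model_path, "model.safetensors")
    if not os.path.exists(bin_path) or os.path.exists(st_path):
        return
    sd = torch.load(bin_path, map_location="cpu")
    # save_file rejects tensors that share storage; cloning contiguous copies avoids that.
    sd = {k: v.contiguous().clone() for k, v in sd.items() if isinstance(v, torch.Tensor)}
    save_file(sd, st_path)
    print(f"Converted {bin_path} -> {st_path}")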

class DownloadAndLoadFlorence2Lora:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": (
                [
                    'NikshepShetty/Florence-2-pixelprose',
                ],
            ),
            },
        }

    RETURN_TYPES = ("PEFTLORA",)
    RETURN_NAMES = ("lora",)
    FUNCTION = "loadmodel"
    CATEGORY = "Florence2"

    def loadmodel(self, model):
        model_name = model.rsplit('/', 1)[-1]
        model_path = os.path.join(model_directory, model_name)

        if not os.path.exists(model_path):
            print(f"Downloading Florence2 lora model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=model,
                              local_dir=model_path,
                              local_dir_use_symlinks=False)
        return (model_path,)

class Florence2ModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        all_llm_paths = folder_paths.get_folder_paths("LLM")
        s.model_paths = create_path_dict(all_llm_paths, lambda x: x.is_dir())
        return {"required": {
            "model": ([*s.model_paths], {"tooltip": "Models are expected to be in the ComfyUI/models/LLM folder"}),
            "precision": (['fp16', 'bf16', 'fp32'],),
            "attention": (
                ['flash_attention_2', 'sdpa', 'eager'],
                {
                    "default": 'sdpa'
                }),
            },
            "optional": {
                "lora": ("PEFTLORA",),
                "convert_to_safetensors": ("BOOLEAN", {"default": False, "tooltip": "Some older model weights are not saved in .safetensors format, which seems to cause longer loading times; this option converts the .bin weights to .safetensors"}),
            }
        }

    RETURN_TYPES = ("FL2MODEL",)
    RETURN_NAMES = ("florence2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "Florence2"

    def loadmodel(self, model, precision, attention, lora=None, convert_to_safetensors=False):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        model_path = Florence2ModelLoader.model_paths.get(model)
        print(f"Loading model from {model_path}")
        print(f"Florence2 using {attention} for attention")
        from transformers import AutoConfig

        # Manually load the state dict to CPU to avoid issues with ZeroGPU patching
        print("Manually loading weights to CPU...")
        # Prefer safetensors if they exist (potentially after conversion)
        weights_path = os.path.join(model_path, "model.safetensors")
        if os.path.exists(weights_path):
            state_dict = load_file(weights_path, device="cpu")
        else:
            weights_path = os.path.join(model_path, "pytorch_model.bin")
            state_dict = torch.load(weights_path, map_location="cpu")
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

        if version.parse(transformers.__version__) < version.parse('4.51.0'):
            # Workaround for the unnecessary flash_attn requirement
            with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
                model = AutoModelForCausalLM.from_pretrained(
                    None, config=config, state_dict=state_dict, attn_implementation=attention,
                    torch_dtype=dtype, trust_remote_code=True
                )
        else:
            from .modeling_florence2 import Florence2ForConditionalGeneration
            model = Florence2ForConditionalGeneration.from_pretrained(
                None, config=config, state_dict=state_dict, attn_implementation=attention, torch_dtype=dtype
            )

        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

        if lora is not None:
            from peft import PeftModel
            adapter_name = lora
            model = PeftModel.from_pretrained(model, adapter_name, trust_remote_code=True)

        florence2_model = {
            'model': model,
            'processor': processor,
            'dtype': dtype
        }

        return (florence2_model,)

class Florence2Run:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", ),
                "florence2_model": ("FL2MODEL", ),
                "text_input": ("STRING", {"default": "", "multiline": True}),
                "task": (
                    [
                        'region_caption',
                        'dense_region_caption',
                        'region_proposal',
                        'caption',
                        'detailed_caption',
                        'more_detailed_caption',
                        'caption_to_phrase_grounding',
                        'referring_expression_segmentation',
                        'ocr',
                        'ocr_with_region',
                        'docvqa',
                        'prompt_gen_tags',
                        'prompt_gen_mixed_caption',
                        'prompt_gen_analyze',
                        'prompt_gen_mixed_caption_plus',
                    ],
                ),
                "fill_mask": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "keep_model_loaded": ("BOOLEAN", {"default": False}),
                "max_new_tokens": ("INT", {"default": 1024, "min": 1, "max": 4096}),
                "num_beams": ("INT", {"default": 3, "min": 1, "max": 64}),
                "do_sample": ("BOOLEAN", {"default": True}),
                "output_mask_select": ("STRING", {"default": ""}),
                "seed": ("INT", {"default": 1, "min": 1, "max": 0xffffffffffffffff}),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", "STRING", "JSON")
    RETURN_NAMES = ("image", "mask", "caption", "data")
    FUNCTION = "encode"
    CATEGORY = "Florence2"

    def hash_seed(self, seed):
        import hashlib
        # Convert the seed to a string and then to bytes
        seed_bytes = str(seed).encode('utf-8')
        # Create a SHA-256 hash of the seed bytes
        hash_object = hashlib.sha256(seed_bytes)
        # Convert the hash to an integer
        hashed_seed = int(hash_object.hexdigest(), 16)
        # Ensure the hashed seed is within the acceptable range for set_seed
        return hashed_seed % (2**32)
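
    # Illustrative note (values not computed here): hash_seed deterministically maps an
    # arbitrary integer seed into [0, 2**32), the range transformers.set_seed accepts.
    # A fixed "seed" input therefore always produces the same 32-bit value across runs,
    # which keeps sampling reproducible when do_sample is enabled.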

    def encode(self, image, text_input, florence2_model, task, fill_mask, keep_model_loaded=False,
               num_beams=3, max_new_tokens=1024, do_sample=True, output_mask_select="", seed=None):
        device = mm.get_torch_device()
        _, height, width, _ = image.shape
        offload_device = mm.unet_offload_device()
        annotated_image_tensor = None
        mask_tensor = None
        processor = florence2_model['processor']
        model = florence2_model['model']
        dtype = florence2_model['dtype']
        model.to(device)

        if seed:
            set_seed(self.hash_seed(seed))

        colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'olive', 'cyan', 'red',
                    'lime', 'indigo', 'violet', 'aqua', 'magenta', 'gold', 'tan', 'skyblue']

        prompts = {
            'region_caption': '<OD>',
            'dense_region_caption': '<DENSE_REGION_CAPTION>',
            'region_proposal': '<REGION_PROPOSAL>',
            'caption': '<CAPTION>',
            'detailed_caption': '<DETAILED_CAPTION>',
            'more_detailed_caption': '<MORE_DETAILED_CAPTION>',
            'caption_to_phrase_grounding': '<CAPTION_TO_PHRASE_GROUNDING>',
            'referring_expression_segmentation': '<REFERRING_EXPRESSION_SEGMENTATION>',
            'ocr': '<OCR>',
            'ocr_with_region': '<OCR_WITH_REGION>',
            'docvqa': '<DocVQA>',
            'prompt_gen_tags': '<GENERATE_TAGS>',
            'prompt_gen_mixed_caption': '<MIXED_CAPTION>',
            'prompt_gen_analyze': '<ANALYZE>',
            'prompt_gen_mixed_caption_plus': '<MIXED_CAPTION_PLUS>',
        }
        task_prompt = prompts.get(task, '<OD>')

        if (task not in ['referring_expression_segmentation', 'caption_to_phrase_grounding', 'docvqa']) and text_input:
            raise ValueError("Text input (prompt) is only supported for 'referring_expression_segmentation', 'caption_to_phrase_grounding', and 'docvqa'")

        if text_input != "":
            prompt = task_prompt + " " + text_input
        else:
            prompt = task_prompt

        image = image.permute(0, 3, 1, 2)

        out = []
        out_masks = []
        out_results = []
        out_data = []
        pbar = ProgressBar(len(image))
        for img in image:
            image_pil = F.to_pil_image(img)
            inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device)

            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                num_beams=num_beams,
            )

            results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            print(results)

            # Clean up the special tokens from the decoded output
            if task == 'ocr_with_region':
                clean_results = str(results)
                cleaned_string = re.sub(r'</?s>|<[^>]*>', '\n', clean_results)
                clean_results = re.sub(r'\n+', '\n', cleaned_string)
            else:
                clean_results = str(results)
                clean_results = clean_results.replace('</s>', '')
                clean_results = clean_results.replace('<s>', '')

            # Return a single string if there is only one image, for compatibility with nodes that can't handle string lists
            if len(image) == 1:
                out_results = clean_results
            else:
                out_results.append(clean_results)

            W, H = image_pil.size
            parsed_answer = processor.post_process_generation(results, task=task_prompt, image_size=(W, H))

            if task in ['region_caption', 'dense_region_caption', 'caption_to_phrase_grounding', 'region_proposal']:
                fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100)
                fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
                ax.imshow(image_pil)
                bboxes = parsed_answer[task_prompt]['bboxes']
                labels = parsed_answer[task_prompt]['labels']

                mask_indexes = []
                # Determine mask indexes outside the loop
                if output_mask_select != "":
                    mask_indexes = [n for n in output_mask_select.split(",")]
                    print(mask_indexes)
                else:
                    mask_indexes = [str(i) for i in range(len(bboxes))]

                # Initialize mask_layer only if needed
                if fill_mask:
                    mask_layer = Image.new('RGB', image_pil.size, (0, 0, 0))
                    mask_draw = ImageDraw.Draw(mask_layer)

                for index, (bbox, label) in enumerate(zip(bboxes, labels)):
                    # Modify the label to include the index
                    indexed_label = f"{index}.{label}"

                    if fill_mask:
                        if str(index) in mask_indexes:
                            print("match index:", str(index), "in mask_indexes:", mask_indexes)
                            mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255))
                        if label in mask_indexes:
                            print("match label")
                            mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255))

                    # Create a Rectangle patch
                    rect = patches.Rectangle(
                        (bbox[0], bbox[1]),  # (x, y) anchor - top-left corner in image coordinates
                        bbox[2] - bbox[0],   # Width
                        bbox[3] - bbox[1],   # Height
                        linewidth=1,
                        edgecolor='r',
                        facecolor='none',
                        label=indexed_label
                    )

                    # Estimate the rendered text size
                    text_width = len(label) * 6  # Adjust multiplier based on the font size
                    text_height = 12  # Adjust based on the font size

                    # Initial text position
                    text_x = bbox[0]
                    text_y = bbox[1] - text_height  # Position text above the top-left of the bbox

                    # Adjust text_x if the text would run off the left or right edge
                    if text_x < 0:
                        text_x = 0
                    elif text_x + text_width > W:
                        text_x = W - text_width

                    # Adjust text_y if the text would run off the top edge
                    if text_y < 0:
                        text_y = bbox[3]  # Move the text below the bottom of the bbox instead

                    # Add the rectangle to the plot
                    ax.add_patch(rect)
                    facecolor = random.choice(colormap) if len(image) == 1 else 'red'
                    # Add the label
                    plt.text(
                        text_x,
                        text_y,
                        indexed_label,
                        color='white',
                        fontsize=12,
                        bbox=dict(facecolor=facecolor, alpha=0.5)
                    )

                if fill_mask:
                    mask_tensor = F.to_tensor(mask_layer)
                    mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                    mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                    mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                    mask_tensor = mask_tensor[:, :, :, 0]
                    out_masks.append(mask_tensor)

                # Remove axis and padding around the image
                ax.axis('off')
                ax.margins(0, 0)
                ax.get_xaxis().set_major_locator(plt.NullLocator())
                ax.get_yaxis().set_major_locator(plt.NullLocator())
                fig.canvas.draw()
                buf = io.BytesIO()
                plt.savefig(buf, format='png', pad_inches=0)
                buf.seek(0)
                annotated_image_pil = Image.open(buf)

                annotated_image_tensor = F.to_tensor(annotated_image_pil)
                out_tensor = annotated_image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(out_tensor)

                if task == 'caption_to_phrase_grounding':
                    out_data.append(parsed_answer[task_prompt])
                else:
                    out_data.append(bboxes)

                pbar.update(1)
                plt.close(fig)

            elif task == 'referring_expression_segmentation':
                # Create a new black image for the mask
                mask_image = Image.new('RGB', (W, H), 'black')
                mask_draw = ImageDraw.Draw(mask_image)

                predictions = parsed_answer[task_prompt]

                # Iterate over polygons and labels
                for polygons, label in zip(predictions['polygons'], predictions['labels']):
                    color = random.choice(colormap)
                    for _polygon in polygons:
                        _polygon = np.array(_polygon).reshape(-1, 2)
                        # Clamp polygon points to the image boundaries
                        _polygon = np.clip(_polygon, [0, 0], [W - 1, H - 1])
                        if len(_polygon) < 3:
                            print('Invalid polygon:', _polygon)
                            continue

                        _polygon = _polygon.reshape(-1).tolist()

                        # Draw the polygon
                        if fill_mask:
                            overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0))
                            image_pil = image_pil.convert('RGBA')
                            draw = ImageDraw.Draw(overlay)
                            color_with_opacity = ImageColor.getrgb(color) + (180,)
                            draw.polygon(_polygon, outline=color, fill=color_with_opacity, width=3)
                            image_pil = Image.alpha_composite(image_pil, overlay)
                        else:
                            draw = ImageDraw.Draw(image_pil)
                            draw.polygon(_polygon, outline=color, width=3)

                        # Draw the mask
                        mask_draw.polygon(_polygon, outline="white", fill="white")

                image_tensor = F.to_tensor(image_pil)
                image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(image_tensor)

                mask_tensor = F.to_tensor(mask_image)
                mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                mask_tensor = mask_tensor[:, :, :, 0]
                out_masks.append(mask_tensor)
                pbar.update(1)

            elif task == 'ocr_with_region':
                try:
                    font = ImageFont.load_default().font_variant(size=24)
                except Exception:
                    font = ImageFont.load_default()
                predictions = parsed_answer[task_prompt]
                scale = 1
                image_pil = image_pil.convert('RGBA')
                overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0))
                draw = ImageDraw.Draw(overlay)
                bboxes, labels = predictions['quad_boxes'], predictions['labels']

                # Create a new black image for the mask
                mask_image = Image.new('RGB', (W, H), 'black')
                mask_draw = ImageDraw.Draw(mask_image)

                for box, label in zip(bboxes, labels):
                    # Normalize the quad box coordinates to [0, 1] for the JSON output
                    scaled_box = [v / (width if idx % 2 == 0 else height) for idx, v in enumerate(box)]
                    out_data.append({"label": label, "box": scaled_box})

                    color = random.choice(colormap)
                    new_box = (np.array(box) * scale).tolist()

                    if fill_mask:
                        color_with_opacity = ImageColor.getrgb(color) + (180,)
                        draw.polygon(new_box, outline=color, fill=color_with_opacity, width=3)
                    else:
                        draw.polygon(new_box, outline=color, width=3)

                    draw.text((new_box[0] + 8, new_box[1] + 2),
                              "{}".format(label),
                              align="right",
                              font=font,
                              fill=color)

                    # Draw the mask
                    mask_draw.polygon(new_box, outline="white", fill="white")

                image_pil = Image.alpha_composite(image_pil, overlay)
                image_pil = image_pil.convert('RGB')

                image_tensor = F.to_tensor(image_pil)
                image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                out.append(image_tensor)

                # Process the mask
                mask_tensor = F.to_tensor(mask_image)
                mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float()
                mask_tensor = mask_tensor.mean(dim=0, keepdim=True)
                mask_tensor = mask_tensor.repeat(1, 1, 1, 3)
                mask_tensor = mask_tensor[:, :, :, 0]
                out_masks.append(mask_tensor)
                pbar.update(1)

            elif task == 'docvqa':
                if text_input == "":
                    raise ValueError("Text input (prompt) is required for 'docvqa'")
                prompt = "<DocVQA> " + text_input

                inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device)
                generated_ids = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    num_beams=num_beams,
                )
                results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                clean_results = results.replace('</s>', '').replace('<s>', '')

                if len(image) == 1:
                    out_results = clean_results
                else:
                    out_results.append(clean_results)

                out.append(F.to_tensor(image_pil).unsqueeze(0).permute(0, 2, 3, 1).cpu().float())
                pbar.update(1)

        if len(out) > 0:
            out_tensor = torch.cat(out, dim=0)
        else:
            out_tensor = torch.zeros((1, 64, 64, 3), dtype=torch.float32, device="cpu")
        if len(out_masks) > 0:
            out_mask_tensor = torch.cat(out_masks, dim=0)
        else:
            out_mask_tensor = torch.zeros((1, 64, 64), dtype=torch.float32, device="cpu")

        if not keep_model_loaded:
            print("Offloading model...")
            model.to(offload_device)
            mm.soft_empty_cache()

        return (out_tensor, out_mask_tensor, out_results, out_data)

NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadFlorence2Model": DownloadAndLoadFlorence2Model,
    "DownloadAndLoadFlorence2Lora": DownloadAndLoadFlorence2Lora,
    "Florence2ModelLoader": Florence2ModelLoader,
    "Florence2Run": Florence2Run,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadFlorence2Model": "DownloadAndLoadFlorence2Model",
    "DownloadAndLoadFlorence2Lora": "DownloadAndLoadFlorence2Lora",
    "Florence2ModelLoader": "Florence2ModelLoader",
    "Florence2Run": "Florence2Run",
}
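
# ComfyUI discovers these custom nodes through the two mappings above. A minimal sketch
# of a package __init__.py that re-exports them (the filename nodes.py and this layout
# are assumptions for illustration, not part of this file):
#
#     from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
#     __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]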