import os
import re
import threading
import tempfile
import html as _html
from typing import Tuple, Optional

import torch
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from transformers import AutoProcessor, logging, TextIteratorStreamer
from torchvision.transforms.functional import resize as resize_api
from qwen_vl_utils import smart_resize

from samr1 import SAMR1ForConditionalGeneration_qwen2p5

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.set_verbosity_error()

# ---------------- Config ----------------
MODEL_PATH = 'OuyBin/LENS'
SAM_IMG_SIZE = 1024
MAX_NEW_TOKENS = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

QWEN2_SYS = "You are a helpful assistant. "
DEFAULT_TEMPLATE = "Output the bounding box of the {question} in the image."
COT_TEMPLATE = (
    "Locate \"{question}\", report the bbox coordinates in JSON format. "
    "Compare the difference between objects and find the most closely matched one. "
    "Output the thinking process in <think> </think> and final answer in <answer> </answer> tags. "
    "Output the one bbox inside the interested object in JSON format. "
    "i.e., <think> thinking process here </think> "
    "<answer> answer here </answer>"
)

# --------------- Load model & processor ---------------
model = SAMR1ForConditionalGeneration_qwen2p5.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    ignore_mismatched_sizes=True,
).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_PATH)


# --------------- Helpers ---------------
def parse_float_sequence_within(input_str):
    """Extract the first 4-number box from a model response; fall back to [0, 0, 0, 0]."""
    # [x1, y1, x2, y2]
    pattern = r"\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]"
    match = re.search(pattern, input_str)
    if match:
        return [float(match.group(i)) for i in range(1, 5)]
    # (x1, y1, x2, y2)
    pattern = r"\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\)"
    match = re.search(pattern, input_str)
    if match:
        return [float(match.group(i)) for i in range(1, 5)]
    # (x1, y1), (x2, y2)
    pattern = r"\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\),\s*\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\)"
    match = re.search(pattern, input_str)
    if match:
        return [float(match.group(i)) for i in range(1, 5)]
    return [0, 0, 0, 0]
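
# Illustrative examples (not part of the original script) of the formats the parser accepts;
# anything without a recognizable box falls back to [0, 0, 0, 0]:
#   parse_float_sequence_within('{"bbox_2d": [12, 34, 56, 78]}')  ->  [12.0, 34.0, 56.0, 78.0]
#   parse_float_sequence_within('box at (12, 34), (56, 78)')      ->  [12.0, 34.0, 56.0, 78.0]
#   parse_float_sequence_within('no coordinates here')            ->  [0, 0, 0, 0]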

def preprocess(image_path: str, instruction: str):
    """Build the chat message plus the resized inputs for both the LLM and the SAM branch."""
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    question_template = DEFAULT_TEMPLATE
    system_template = QWEN2_SYS

    image = Image.open(image_path).convert(mode="RGB")
    width, height = image.size
    resized_height, resized_width = smart_resize(height, width, 28, max_pixels=1_000_000)
    llm_image = image.resize((resized_width, resized_height))

    sam_image = resize_api(image, (SAM_IMG_SIZE, SAM_IMG_SIZE))
    sam_image = torch.from_numpy(np.array(sam_image)).permute(2, 0, 1).float()
    sam_image = (sam_image - pixel_mean) / pixel_std

    message = [
        {"role": "system", "content": system_template},
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": question_template.format(question=instruction)}
        ]},
    ]
    return {
        "image": llm_image,
        "message": message,
        "sam_image": sam_image,  # float tensor on CPU; will be moved later
        "ori_hw": (height, width),
        "hw": (resized_height, resized_width),
    }


def rescale_box(pred_box, from_hw: Tuple[int, int], to_hw: Tuple[int, int]):
    # pred_box: [x1, y1, x2, y2] in from_hw coordinates
    from_h, from_w = from_hw
    to_h, to_w = to_hw
    scale_w = to_w / from_w
    scale_h = to_h / from_h
    x1, y1, x2, y2 = pred_box
    return [x1 * scale_w, y1 * scale_h, x2 * scale_w, y2 * scale_h]


def is_valid_box(box, image_hw: Optional[Tuple[int, int]] = None) -> bool:
    if not isinstance(box, (list, tuple)) or len(box) != 4:
        return False
    x1, y1, x2, y2 = box
    if not (np.isfinite(x1) and np.isfinite(y1) and np.isfinite(x2) and np.isfinite(y2)):
        return False
    if x1 >= x2 or y1 >= y2 or min(x1, y1, x2, y2) < 0:
        return False
    if image_hw:
        h, w = image_hw
        if x2 > w or y2 > h:
            return False
    return True
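
# Worked example (illustrative numbers): mapping a box from a 100x200 (h, w) resized image
# back to a 50x100 original scales each coordinate by to/from along its axis:
#   rescale_box([10, 20, 30, 40], from_hw=(100, 200), to_hw=(50, 100))  ->  [5.0, 10.0, 15.0, 20.0]
#   is_valid_box([5.0, 10.0, 15.0, 20.0], image_hw=(50, 100))           ->  True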

# --------------- Generation functions ---------------
def evaluate_single(
    input_data: dict,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    message = input_data["message"]
    texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
    image_inputs = [input_data["image"]]
    inputs = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move tensors to the target device
    inputs = {k: v.to(device=DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    with torch.inference_mode():
        llm_out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=False,
        )

    # Build an attention mask over the full (prompt + generated) sequence, zeroing padding positions
    new_attention_mask = torch.ones_like(llm_out, dtype=torch.int64, device=llm_out.device)
    pos = torch.where(llm_out == processor.tokenizer.pad_token_id)
    if pos[0].numel() > 0:
        new_attention_mask[pos] = 0

    # Attach the SAM inputs and run the segmentation forward pass
    inputs.update({"input_ids": llm_out, "attention_mask": new_attention_mask})
    sam_imgs = input_data["sam_image"].unsqueeze(0).to(device=DEVICE, dtype=torch.float32)
    inputs.update({"sam_images": sam_imgs})

    with torch.inference_mode():
        _, low_res_mask = model(output_hidden_states=True, use_learnable_query=True, **inputs)

    pred_mask = model.postprocess_masks(low_res_mask[0], orig_hw=input_data["ori_hw"])
    pred_mask = (pred_mask[:, 0] > 0).int()

    # Final text
    final_text = processor.tokenizer.decode(llm_out[0], skip_special_tokens=True)
    final_text = final_text.strip()
    return pred_mask.cpu(), final_text


def evaluate_single_stream(input_data: dict, max_new_tokens: int = MAX_NEW_TOKENS):
    """
    Generator: yields the generated text incrementally (token by token, generated text only),
    then returns (pred_mask, final_text) as the generator's return value.
    """
    message = input_data["message"]
    texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
    image_inputs = [input_data["image"]]
    inputs = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device=DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_output = {"sequences": None, "exception": None}

    def gen_thread_fn():
        try:
            with torch.inference_mode():
                generation_output["sequences"] = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                    do_sample=False,
                    streamer=streamer,
                )
        except Exception as e:
            generation_output["exception"] = e

    gen_thread = threading.Thread(target=gen_thread_fn, daemon=True)
    gen_thread.start()

    # Stream tokens as they arrive
    current_text = ""
    for text_piece in streamer:
        current_text += text_piece
        yield current_text

    gen_thread.join()
    if generation_output["exception"] is not None:
        raise generation_output["exception"]
    llm_out = generation_output["sequences"]

    new_attention_mask = torch.ones_like(llm_out, dtype=torch.int64, device=llm_out.device)
    pos = torch.where(llm_out == processor.tokenizer.pad_token_id)
    if pos[0].numel() > 0:
        new_attention_mask[pos] = 0

    inputs.update({"input_ids": llm_out, "attention_mask": new_attention_mask})
    sam_imgs = input_data["sam_image"].unsqueeze(0).to(device=DEVICE, dtype=torch.float32)
    inputs.update({"sam_images": sam_imgs})

    with torch.inference_mode():
        _, low_res_mask = model(output_hidden_states=True, use_learnable_query=True, **inputs)

    pred_mask = model.postprocess_masks(low_res_mask[0], orig_hw=input_data["ori_hw"])
    pred_mask = (pred_mask[:, 0] > 0).int().cpu()
    final_text = current_text.strip()
    return pred_mask, final_text
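
# Usage sketch (mirrors run_segmentation below): because evaluate_single_stream *returns*
# its final result, a caller must read it from StopIteration.value:
#   it = evaluate_single_stream(input_data)
#   try:
#       while True:
#           partial = next(it)  # incrementally growing generated text
#   except StopIteration as stop:
#       pred_mask, final_text = stop.value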

def text_to_token_html(text: str) -> str:
    """Convert plain text into HTML where each token gets a background and border; line breaks are preserved."""
    if text is None:
        return ""
    escaped = _html.escape(text)
    # Process line by line, keeping empty lines
    lines = escaped.split('\n')
    html_lines = []
    for line in lines:
        if line == "":
            html_lines.append('&nbsp;')
            continue
        # Split on single spaces so runs of consecutive spaces are preserved
        parts = line.split(' ')
        span_parts = []
        for part in parts:
            if part == "":
                # consecutive spaces -> invisible spacer placeholder
                span_parts.append('<span class="token-space">&nbsp;</span>')
            else:
                span_parts.append(f'<span class="reason-token">{part}</span>')
        html_lines.append('<div>' + ' '.join(span_parts) + '</div>')
    return '<div class="reasoning-output">' + '<br/>'.join(html_lines) + '</div>'


# --------------- Gradio generator function ---------------
def run_segmentation(image: Image.Image, instruction: str):
    image = image.convert("RGB")
    # Use a temporary file to avoid races between concurrent requests
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
        tmp_path = f.name
    image.save(tmp_path)
    input_data = preprocess(tmp_path, instruction)
    try:
        os.remove(tmp_path)  # preprocess has fully loaded the image; clean up the temp file
    except OSError:
        pass

    # 1) Stream the LLM response token by token
    token_streamer = evaluate_single_stream(input_data)
    iterator = iter(token_streamer)
    final_mask = None
    final_text = ""
    # Iterate manually so the generator's return value (StopIteration.value) can be captured
    while True:
        try:
            partial_text = next(iterator)
            # partial_text is plain text; format it as token-level HTML
            partial_html = text_to_token_html(partial_text)
            # Update the text panel; keep the image output as None for now
            yield None, partial_html
        except StopIteration as e:
            ret = e.value
            if ret is not None:
                final_mask, final_text = ret
            break

    # If streaming did not return a usable result, fall back to the synchronous path
    # (at the cost of one extra inference pass)
    if final_mask is None or final_text == "":
        final_mask, final_text = evaluate_single(input_data)

    # Visualization overlay: red mask at 50% alpha
    mask = final_mask.squeeze(0).numpy().astype('uint8')
    overlay = Image.new('RGBA', (mask.shape[1], mask.shape[0]), (0, 0, 0, 0))
    overlay_arr = np.array(overlay)
    overlay_arr[..., 0] = 255
    overlay_arr[..., 3] = (mask > 0) * 128
    overlay = Image.fromarray(overlay_arr, mode='RGBA')
    base = image.convert('RGBA')
    composite = Image.alpha_composite(base, overlay)

    # Parse the predicted box from the final text (defensive check: the parser always
    # returns a 4-element list, with [0, 0, 0, 0] as the fallback)
    pred_box = parse_float_sequence_within(final_text)
    if not pred_box or len(pred_box) != 4:
        pred_box = None
    if pred_box is not None:
        resized_hw = input_data["hw"]
        ori_hw = input_data["ori_hw"]
        box = rescale_box(pred_box, resized_hw, ori_hw)
        if is_valid_box(box, ori_hw):
            draw = ImageDraw.Draw(composite)
            draw.rectangle(box, outline="green", width=4)

    # Final yield: the composited image plus the complete text (as HTML)
    final_html = text_to_token_html(final_text)
    yield composite, final_html
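
# Offline usage sketch (illustrative; assumes a local "example.jpg" and a writable working
# directory) for exercising run_segmentation without the Gradio UI; the last yield carries
# the final overlay:
#   img = Image.open("example.jpg")
#   composite, html_out = None, ""
#   for composite, html_out in run_segmentation(img, "the person on the left"):
#       pass
#   if composite is not None:
#       composite.convert("RGB").save("overlay.jpg")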

# --------------- Launch UI ---------------
# Global CSS: enlarge textarea fonts and define the token styles
CSS = """
body { background-color: #f5f7fa; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
h1 { text-align: center; }
textarea { font-size: 18px !important; line-height: 1.4 !important; }
input[type="text"] { font-size: 18px !important; }
.reason-token {
    display: inline-block;
    border: 1px solid rgba(0,0,0,0.12);
    padding: 2px 6px;
    margin: 2px;
    border-radius: 6px;
    background: rgba(0,0,0,0.04);
}
.reasoning-output { white-space: pre-wrap; font-size: 16px; }
.token-space { display: inline-block; width: 0.5em; }
"""

demo = gr.Interface(
    fn=run_segmentation,
    inputs=[
        gr.Image(type="pil", label="Upload image", height=512),
        gr.Textbox(lines=1, label="Instruction (e.g. left person / man in green)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Output", height=512),
        gr.HTML(label="Reasoning Output"),
    ],
    title="LENS: Learning to Segment Anything with Unified Reinforced Reasoning",
    css=CSS,
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)