import threading
import tempfile
from typing import Tuple, Optional
import torch
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from transformers import AutoProcessor, logging, TextIteratorStreamer
from torchvision.transforms.functional import resize as resize_api
import html as _html
from qwen_vl_utils import smart_resize
from samr1 import SAMR1ForConditionalGeneration_qwen2p5
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.set_verbosity_error()
# ---------------- Config ----------------
MODEL_PATH = 'OuyBin/LENS'
SAM_IMG_SIZE = 1024
MAX_NEW_TOKENS = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
QWEN2_SYS = "You are a helpful assistant. "
DEFAULT_TEMPLATE = "Output the bounding box of the {question} in the image."
COT_TEMPLATE = \
    "Locate \"{question}\", report the bbox coordinates in JSON format. " \
    "Compare the difference between objects and find the most closely matched one. " \
    "Output the thinking process in <think> </think> and final answer in <answer> </answer> tags. " \
    "Output the one bbox inside the interested object in JSON format. " \
    "i.e., <think>thinking process here</think>" \
    "<answer>answer here</answer>"
# --------------- Load model & processor ---------------
model = SAMR1ForConditionalGeneration_qwen2p5.from_pretrained(
MODEL_PATH,
torch_dtype=torch.bfloat16,
attn_implementation="eager",
ignore_mismatched_sizes=True,
).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# --------------- Helpers ---------------
import re
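# Extract the first bounding box found in the model's text output. Three layouts are
# recognized: [x1, y1, x2, y2], (x1, y1, x2, y2), and (x1, y1), (x2, y2); if none
# matches, a degenerate [0, 0, 0, 0] box is returned (it is later rejected by is_valid_box).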
def parse_float_sequence_within(input_str):
pattern = r"\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]"
match = re.search(pattern, input_str)
if match:
return [float(match.group(i)) for i in range(1, 5)]
pattern = r"\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\)"
match = re.search(pattern, input_str)
if match:
return [float(match.group(i)) for i in range(1, 5)]
pattern = r"\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\),\s*\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\)"
match = re.search(pattern, input_str)
if match:
return [float(match.group(i)) for i in range(1, 5)]
return [0, 0, 0, 0]
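# Prepare one sample: a smart-resized RGB image for the LLM branch (patch factor 28,
# capped at ~1M pixels), a 1024x1024 mean/std-normalized tensor for the SAM branch,
# and the chat message built from DEFAULT_TEMPLATE.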
def preprocess(image_path: str, instruction: str):
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
question_template = DEFAULT_TEMPLATE
system_template = QWEN2_SYS
image = Image.open(image_path).convert(mode="RGB")
width, height = image.size
resized_height, resized_width = smart_resize(height, width, 28, max_pixels=1_000_000)
llm_image = image.resize((resized_width, resized_height))
sam_image = resize_api(image, (SAM_IMG_SIZE, SAM_IMG_SIZE))
sam_image = torch.from_numpy(np.array(sam_image)).permute(2, 0, 1).float()
sam_image = (sam_image - pixel_mean) / pixel_std
message = [
{"role": "system", "content": system_template},
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": question_template.format(question=instruction)}
]},
]
return {
"image": llm_image,
"message": message,
"sam_image": sam_image, # float tensor on CPU; will be moved later
"ori_hw": (height, width),
"hw": (resized_height, resized_width),
}
def rescale_box(pred_box, from_hw: Tuple[int, int], to_hw: Tuple[int, int]):
# pred_box: [x1, y1, x2, y2] in from_hw coordinate
from_h, from_w = from_hw
to_h, to_w = to_hw
scale_w = to_w / from_w
scale_h = to_h / from_h
x1, y1, x2, y2 = pred_box
return [x1 * scale_w, y1 * scale_h, x2 * scale_w, y2 * scale_h]
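# Sanity-check a predicted box: four finite, non-negative coordinates with x1 < x2 and
# y1 < y2, optionally bounded by the image size.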
def is_valid_box(box, image_hw: Optional[Tuple[int, int]] = None) -> bool:
if not isinstance(box, (list, tuple)) or len(box) != 4:
return False
x1, y1, x2, y2 = box
if not (np.isfinite(x1) and np.isfinite(y1) and np.isfinite(x2) and np.isfinite(y2)):
return False
if x1 >= x2 or y1 >= y2 or min(x1, y1, x2, y2) < 0:
return False
if image_hw:
h, w = image_hw
if x2 > w or y2 > h:
return False
return True
# --------------- Generation functions ---------------
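# Non-streaming path: (1) greedy-decode the answer text, (2) feed the full generated
# sequence plus the SAM image back through the model with use_learnable_query=True to
# obtain a low-resolution mask, which is then resized to the original image size and
# thresholded at 0. The exact role of the learnable query is defined inside SAMR1;
# this code only relies on the (logits, low_res_mask) return signature used below.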
def evaluate_single(
input_data: dict,
max_new_tokens: int = MAX_NEW_TOKENS,
):
message = input_data["message"]
texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
image_inputs = [input_data["image"]]
inputs = processor(
text=texts,
images=image_inputs,
padding=True,
return_tensors="pt",
)
    # Move all tensors onto the target device
inputs = {k: v.to(device=DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
with torch.inference_mode():
llm_out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
use_cache=True,
do_sample=False,
)
    # Build an attention mask over the full generated sequence (zero out pad tokens)
new_attention_mask = torch.ones_like(llm_out, dtype=torch.int64, device=llm_out.device)
pos = torch.where(llm_out == processor.tokenizer.pad_token_id)
if pos[0].numel() > 0:
new_attention_mask[pos] = 0
    # Attach the SAM inputs for the mask-prediction forward pass
inputs.update({"input_ids": llm_out, "attention_mask": new_attention_mask})
sam_imgs = input_data["sam_image"].unsqueeze(0).to(device=DEVICE, dtype=torch.float32)
inputs.update({"sam_images": sam_imgs})
with torch.inference_mode():
_, low_res_mask = model(output_hidden_states=True, use_learnable_query=True, **inputs)
pred_mask = model.postprocess_masks(low_res_mask[0], orig_hw=input_data["ori_hw"])
pred_mask = (pred_mask[:, 0] > 0).int()
# final text
final_text = processor.tokenizer.decode(llm_out[0], skip_special_tokens=True)
final_text = final_text.strip()
return pred_mask.cpu(), final_text
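# Streaming variant of evaluate_single: generation runs in a background thread and a
# TextIteratorStreamer yields the decoded text incrementally; once generation finishes,
# the same second forward pass as above produces the mask. The final (mask, text) pair
# is delivered through the generator's return value (StopIteration.value).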
def evaluate_single_stream(input_data: dict, max_new_tokens: int = MAX_NEW_TOKENS):
"""
Generator: 按 token 增量 yield 文本(仅生成的文本),最终 return (pred_mask, final_text)
"""
message = input_data["message"]
texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
image_inputs = [input_data["image"]]
inputs = processor(
text=texts,
images=image_inputs,
padding=True,
return_tensors="pt",
)
inputs = {k: v.to(device=DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
streamer = TextIteratorStreamer(
processor.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
generation_output = {"sequences": None, "exception": None}
def gen_thread_fn():
try:
with torch.inference_mode():
generation_output["sequences"] = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
use_cache=True,
do_sample=False,
streamer=streamer,
)
except Exception as e:
generation_output["exception"] = e
gen_thread = threading.Thread(target=gen_thread_fn, daemon=True)
gen_thread.start()
# streaming tokens
current_text = ""
for text_piece in streamer:
current_text += text_piece
yield current_text
gen_thread.join()
if generation_output["exception"] is not None:
raise generation_output["exception"]
llm_out = generation_output["sequences"]
new_attention_mask = torch.ones_like(llm_out, dtype=torch.int64, device=llm_out.device)
pos = torch.where(llm_out == processor.tokenizer.pad_token_id)
if pos[0].numel() > 0:
new_attention_mask[pos] = 0
inputs.update({"input_ids": llm_out, "attention_mask": new_attention_mask})
sam_imgs = input_data["sam_image"].unsqueeze(0).to(device=DEVICE, dtype=torch.float32)
inputs.update({"sam_images": sam_imgs})
with torch.inference_mode():
_, low_res_mask = model(output_hidden_states=True, use_learnable_query=True, **inputs)
pred_mask = model.postprocess_masks(low_res_mask[0], orig_hw=input_data["ori_hw"])
pred_mask = (pred_mask[:, 0] > 0).int().cpu()
final_text = current_text.strip()
return pred_mask, final_text
def text_to_token_html(text: str) -> str:
"""把纯文本转换成每个 token 带背景和边框的 HTML,保留换行。"""
if text is None:
return ""
escaped = _html.escape(text)
    # Process line by line, preserving empty lines
lines = escaped.split('\n')
html_lines = []
for line in lines:
if line == "":
html_lines.append('<div class="token-line">&nbsp;</div>')
continue
        # Split into tokens on spaces (runs of consecutive spaces are preserved)
parts = line.split(' ')
span_parts = []
for part in parts:
if part == "":
                # Consecutive spaces -> keep the gap with a non-breaking-space placeholder
span_parts.append('<span class="token-space">&nbsp;</span>')
else:
span_parts.append(f'<span class="reason-token">{part}</span>')
html_lines.append('<span class="token-line">' + ' '.join(span_parts) + '</span>')
return '<div class="reasoning-output">' + '<br>'.join(html_lines) + '</div>'
# --------------- Gradio generator function ---------------
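# Gradio treats a generator fn as a streaming endpoint: every yield updates the outputs.
# While tokens stream in we yield (None, partial_html); the last yield carries the
# composited segmentation image and the full reasoning text.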
def run_segmentation(image: Image.Image, instruction: str):
image = image.convert("RGB")
    # Save to a temporary file to avoid races between concurrent requests
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
        tmp_path = f.name
    image.save(tmp_path)
    try:
        input_data = preprocess(tmp_path, instruction)
    finally:
        os.remove(tmp_path)
    # 1) Stream tokens from the streaming generator
token_streamer = evaluate_single_stream(input_data)
iterator = iter(token_streamer)
final_mask = None
final_text = ""
    # Iterate manually so the generator's return value can be captured (StopIteration.value)
while True:
try:
partial_text = next(iterator)
            # partial_text is plain text; format it as token-styled HTML
partial_html = text_to_token_html(partial_text)
            # Update the text output; keep the image output as None for now
yield None, partial_html
except StopIteration as e:
ret = e.value
if ret is not None:
final_mask, final_text = ret
break
    # If streaming did not yield a usable result, fall back to the synchronous path (one extra inference pass)
if final_mask is None or final_text == "":
final_mask, final_text = evaluate_single(input_data)
    # Visualization: semi-transparent red overlay on the predicted mask
mask = final_mask.squeeze(0).numpy().astype('uint8')
overlay = Image.new('RGBA', (mask.shape[1], mask.shape[0]), (0, 0, 0, 0))
overlay_arr = np.array(overlay)
overlay_arr[..., 0] = 255
overlay_arr[..., 3] = (mask > 0) * 128
overlay = Image.fromarray(overlay_arr, mode='RGBA')
base = image.convert('RGBA')
composite = Image.alpha_composite(base, overlay)
# parse box
pred_box = parse_float_sequence_within(final_text)
    # parse_float_sequence_within falls back to [0, 0, 0, 0]; treat anything malformed as missing
if not pred_box or len(pred_box) != 4:
pred_box = None
if pred_box is not None:
resized_hw = input_data["hw"]
ori_hw = input_data["ori_hw"]
box = rescale_box(pred_box, resized_hw, ori_hw)
if is_valid_box(box, ori_hw):
draw = ImageDraw.Draw(composite)
draw.rectangle(box, outline="green", width=4)
    # Final yield: return the composited image together with the complete text (as HTML)
final_html = text_to_token_html(final_text)
yield composite, final_html
# --------------- Launch UI ---------------
# Global CSS: enlarge the font of all textareas and define the token styles
CSS = """
body { background-color: #f5f7fa; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
h1 {text-align: center; }
textarea { font-size: 18px !important; line-height: 1.4 !important; }
input[type="text"] { font-size: 18px !important; }
.reason-token { display: inline-block; border: 1px solid rgba(0,0,0,0.12); padding: 2px 6px; margin: 2px; border-radius: 6px; background: rgba(0,0,0,0.04); }
.reasoning-output { white-space: pre-wrap; font-size: 16px; }
.token-space { display: inline-block; width: 0.5em; }
"""
demo = gr.Interface(
fn=run_segmentation,
inputs=[
        gr.Image(type="pil", label="Upload Image", height=512),
        gr.Textbox(lines=1, label="Instruction (e.g. left person / man in green)"),
],
outputs=[
gr.Image(type="pil", label="Segmentation Output", height=512),
gr.HTML(label="Reasoning Output"),
],
title="<h1>LENS: Learning to Segment Anything with Unified Reinforced Reasoning</h1>",
css=CSS,
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)