import threading
import tempfile
from typing import Tuple, Optional

import torch
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from transformers import AutoProcessor, logging, TextIteratorStreamer
from torchvision.transforms.functional import resize as resize_api
import html as _html

from qwen_vl_utils import smart_resize
from samr1 import SAMR1ForConditionalGeneration_qwen2p5

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.set_verbosity_error()

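
# Inference configuration: model checkpoint, SAM input size, generation budget, device, and prompt templates.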
MODEL_PATH = 'OuyBin/LENS'
SAM_IMG_SIZE = 1024
MAX_NEW_TOKENS = 256
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

QWEN2_SYS = "You are a helpful assistant. "
DEFAULT_TEMPLATE = "Output the bounding box of the {question} in the image."
COT_TEMPLATE = (
    "Locate \"{question}\", report the bbox coordinates in JSON format. "
    "Compare the difference between objects and find the most closely matched one. "
    "Output the thinking process in <think> </think> and final answer in <answer> </answer> tags. "
    "Output the one bbox inside the interested object in JSON format. "
    "i.e., <think>thinking process here</think>"
    "<answer>answer here</answer>"
)

model = SAMR1ForConditionalGeneration_qwen2p5.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    ignore_mismatched_sizes=True,
).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_PATH)


import re

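# Parse the first bounding box found in the model output, accepting [x1, y1, x2, y2],
# (x1, y1, x2, y2), or (x1, y1), (x2, y2) formats; fall back to a degenerate box.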
def parse_float_sequence_within(input_str):
    pattern = r"\[\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\]"
    match = re.search(pattern, input_str)
    if match:
        return [float(match.group(i)) for i in range(1, 5)]
    pattern = r"\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\s*\)"
    match = re.search(pattern, input_str)
    if match:
        return [float(match.group(i)) for i in range(1, 5)]
    pattern = r"\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\),\s*\(\s*(-?\d+(?:\.\d+)?),\s*(-?\d+(?:\.\d+)?)\)"
    match = re.search(pattern, input_str)
    if match:
        return [float(match.group(i)) for i in range(1, 5)]
    return [0, 0, 0, 0]


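# Build the model inputs for one image/instruction pair: a resized copy for the LLM,
# a normalized fixed-size copy for the SAM branch, and the chat message.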
def preprocess(image_path: str, instruction: str):
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

    question_template = DEFAULT_TEMPLATE
    system_template = QWEN2_SYS

    image = Image.open(image_path).convert(mode="RGB")
    width, height = image.size

    # Resize for the LLM vision tower (multiples of 28, capped at ~1M pixels).
    resized_height, resized_width = smart_resize(height, width, 28, max_pixels=1_000_000)
    llm_image = image.resize((resized_width, resized_height))

    # Resize and normalize for the SAM image encoder.
    sam_image = resize_api(image, (SAM_IMG_SIZE, SAM_IMG_SIZE))
    sam_image = torch.from_numpy(np.array(sam_image)).permute(2, 0, 1).float()
    sam_image = (sam_image - pixel_mean) / pixel_std

    message = [
        {"role": "system", "content": system_template},
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": question_template.format(question=instruction)},
        ]},
    ]

    return {
        "image": llm_image,
        "message": message,
        "sam_image": sam_image,
        "ori_hw": (height, width),
        "hw": (resized_height, resized_width),
    }


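# Map a box predicted in resized-image coordinates back to the original resolution.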
def rescale_box(pred_box, from_hw: Tuple[int, int], to_hw: Tuple[int, int]):
    from_h, from_w = from_hw
    to_h, to_w = to_hw
    scale_w = to_w / from_w
    scale_h = to_h / from_h
    x1, y1, x2, y2 = pred_box
    return [x1 * scale_w, y1 * scale_h, x2 * scale_w, y2 * scale_h]


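# Reject malformed, degenerate, or out-of-image boxes before drawing them.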
def is_valid_box(box, image_hw: Optional[Tuple[int, int]] = None) -> bool:
    if not isinstance(box, (list, tuple)) or len(box) != 4:
        return False
    x1, y1, x2, y2 = box
    if not (np.isfinite(x1) and np.isfinite(y1) and np.isfinite(x2) and np.isfinite(y2)):
        return False
    if x1 >= x2 or y1 >= y2 or min(x1, y1, x2, y2) < 0:
        return False
    if image_hw:
        h, w = image_hw
        if x2 > w or y2 > h:
            return False
    return True


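# Non-streaming inference: generate the grounding text, then run a second forward pass
# with the full sequence and the SAM image to decode a segmentation mask.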
def evaluate_single(
    input_data: dict,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    message = input_data["message"]
    texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
    image_inputs = [input_data["image"]]

    inputs = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device=DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    with torch.inference_mode():
        llm_out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=False,
        )

    # Build an attention mask over the full generated sequence, zeroing out pad tokens.
    new_attention_mask = torch.ones_like(llm_out, dtype=torch.int64, device=llm_out.device)
    pos = torch.where(llm_out == processor.tokenizer.pad_token_id)
    if pos[0].numel() > 0:
        new_attention_mask[pos] = 0

    # Second pass: feed the generated sequence together with the SAM image so the
    # mask head can decode a segmentation mask from the learnable query.
    inputs.update({"input_ids": llm_out, "attention_mask": new_attention_mask})
    sam_imgs = input_data["sam_image"].unsqueeze(0).to(device=DEVICE, dtype=torch.float32)
    inputs.update({"sam_images": sam_imgs})

    with torch.inference_mode():
        _, low_res_mask = model(output_hidden_states=True, use_learnable_query=True, **inputs)

    # Upsample the low-resolution mask to the original image size and binarize it.
    pred_mask = model.postprocess_masks(low_res_mask[0], orig_hw=input_data["ori_hw"])
    pred_mask = (pred_mask[:, 0] > 0).int()

    final_text = processor.tokenizer.decode(llm_out[0], skip_special_tokens=True)
    final_text = final_text.strip()

    return pred_mask.cpu(), final_text


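# Streaming variant: generation runs in a background thread and partial text is consumed
# from a TextIteratorStreamer; the mask is computed once decoding has finished.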
def evaluate_single_stream(input_data: dict, max_new_tokens: int = MAX_NEW_TOKENS):
    """
    Generator: yields the incrementally decoded text (generated tokens only);
    finally returns (pred_mask, final_text).
    """
    message = input_data["message"]
    texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
    image_inputs = [input_data["image"]]

    inputs = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device=DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_output = {"sequences": None, "exception": None}

    def gen_thread_fn():
        try:
            with torch.inference_mode():
                generation_output["sequences"] = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                    do_sample=False,
                    streamer=streamer,
                )
        except Exception as e:
            generation_output["exception"] = e

    gen_thread = threading.Thread(target=gen_thread_fn, daemon=True)
    gen_thread.start()

    # Relay partial text to the caller as the streamer produces it.
    current_text = ""
    for text_piece in streamer:
        current_text += text_piece
        yield current_text

    gen_thread.join()
    if generation_output["exception"] is not None:
        raise generation_output["exception"]

    # Same post-processing as evaluate_single: mask pad tokens, attach the SAM image,
    # and run a second forward pass to decode the segmentation mask.
    llm_out = generation_output["sequences"]
    new_attention_mask = torch.ones_like(llm_out, dtype=torch.int64, device=llm_out.device)
    pos = torch.where(llm_out == processor.tokenizer.pad_token_id)
    if pos[0].numel() > 0:
        new_attention_mask[pos] = 0

    inputs.update({"input_ids": llm_out, "attention_mask": new_attention_mask})
    sam_imgs = input_data["sam_image"].unsqueeze(0).to(device=DEVICE, dtype=torch.float32)
    inputs.update({"sam_images": sam_imgs})

    with torch.inference_mode():
        _, low_res_mask = model(output_hidden_states=True, use_learnable_query=True, **inputs)

    pred_mask = model.postprocess_masks(low_res_mask[0], orig_hw=input_data["ori_hw"])
    pred_mask = (pred_mask[:, 0] > 0).int().cpu()

    final_text = current_text.strip()
    return pred_mask, final_text


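# HTML rendering helper for the Gradio HTML output (used for both streamed partial text
# and the final answer).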
def text_to_token_html(text: str) -> str:
    """Convert plain text into HTML where every token gets a background and border, preserving line breaks."""
    if text is None:
        return ""
    escaped = _html.escape(text)

    lines = escaped.split('\n')
    html_lines = []
    for line in lines:
        if line == "":
            html_lines.append('<div class="token-line"> </div>')
            continue

        parts = line.split(' ')
        span_parts = []
        for part in parts:
            if part == "":
                # Empty parts come from runs of spaces; render them as fixed-width spacers.
                span_parts.append('<span class="token-space"> </span>')
            else:
                span_parts.append(f'<span class="reason-token">{part}</span>')
        html_lines.append('<span class="token-line">' + ' '.join(span_parts) + '</span>')
    return '<div class="reasoning-output">' + '<br>'.join(html_lines) + '</div>'


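# Gradio callback: streams the reasoning text while it is generated, then overlays the
# predicted mask (semi-transparent red) and, if a valid box is parsed, a green rectangle.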
def run_segmentation(image: Image.Image, instruction: str):
    image = image.convert("RGB")

    # preprocess() expects a file path, so write the uploaded image to a temporary file
    # and remove it once preprocessing is done.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
        tmp_path = f.name
        image.save(tmp_path)

    input_data = preprocess(tmp_path, instruction)
    os.remove(tmp_path)

    token_streamer = evaluate_single_stream(input_data)
    iterator = iter(token_streamer)

    final_mask = None
    final_text = ""

    # Drain the streaming generator, yielding partial HTML to Gradio; the generator's
    # return value (delivered via StopIteration) carries the mask and final text.
    while True:
        try:
            partial_text = next(iterator)
            partial_html = text_to_token_html(partial_text)
            yield None, partial_html
        except StopIteration as e:
            ret = e.value
            if ret is not None:
                final_mask, final_text = ret
            break

    # Fall back to the non-streaming path if streaming produced no result.
    if final_mask is None or final_text == "":
        final_mask, final_text = evaluate_single(input_data)

    # Overlay the predicted mask as semi-transparent red.
    mask = final_mask.squeeze(0).numpy().astype('uint8')
    overlay = Image.new('RGBA', (mask.shape[1], mask.shape[0]), (0, 0, 0, 0))
    overlay_arr = np.array(overlay)
    overlay_arr[..., 0] = 255
    overlay_arr[..., 3] = (mask > 0) * 128
    overlay = Image.fromarray(overlay_arr, mode='RGBA')

    base = image.convert('RGBA')
    composite = Image.alpha_composite(base, overlay)

    # Draw the predicted bounding box, rescaled to the original resolution, if it is valid.
    pred_box = parse_float_sequence_within(final_text)
    if not pred_box or len(pred_box) != 4:
        pred_box = None

    if pred_box is not None:
        resized_hw = input_data["hw"]
        ori_hw = input_data["ori_hw"]
        box = rescale_box(pred_box, resized_hw, ori_hw)
        if is_valid_box(box, ori_hw):
            draw = ImageDraw.Draw(composite)
            draw.rectangle(box, outline="green", width=4)

    final_html = text_to_token_html(final_text)
    yield composite, final_html


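# Page styling: larger input fonts plus the chip-style boxes used for reasoning tokens.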
CSS = """ |
|
|
body { background-color: #f5f7fa; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } |
|
|
h1 {text-align: center; } |
|
|
textarea { font-size: 18px !important; line-height: 1.4 !important; } |
|
|
input[type="text"] { font-size: 18px !important; } |
|
|
.reason-token { display: inline-block; border: 1px solid rgba(0,0,0,0.12); padding: 2px 6px; margin: 2px; border-radius: 6px; background: rgba(0,0,0,0.04); } |
|
|
.reasoning-output { white-space: pre-wrap; font-size: 16px; } |
|
|
.token-space { display: inline-block; width: 0.5em; } |
|
|
""" |
|
|
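# Gradio UI: image + instruction in, segmented image + token-styled reasoning text out.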
demo = gr.Interface(
    fn=run_segmentation,
    inputs=[
        gr.Image(type="pil", label="Upload Image", height=512),
        gr.Textbox(lines=1, label="Instruction (e.g. left person / man in green)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Output", height=512),
        gr.HTML(label="Reasoning Output"),
    ],
    title="<h1>LENS: Learning to Segment Anything with Unified Reinforced Reasoning</h1>",
    css=CSS,
)

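# Launch the demo on all network interfaces; adjust server_name/server_port here if the
# defaults do not fit your deployment.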
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)