# LENS / samr1.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
from transformers.utils import logging
from qwen2_vl import Qwen2VLVisionConnectorSimple
from segment_anything_2.sam2.build_sam import build_sam2
logger = logging.get_logger(__name__)
local_rank = int(os.getenv("LOCAL_RANK", -1))
class SAMR1Config(Qwen2VLConfig):
def __init__(self, num_of_query=None, if_use_qwen_connector=None, if_include_sam=None, **kwargs):
super().__init__(**kwargs)
self.num_of_query = num_of_query
        self.if_use_qwen_connector = if_use_qwen_connector
        self.if_include_sam = if_include_sam
class SAMR1ForConditionalGeneration_qwen2p5(Qwen2_5_VLForConditionalGeneration):
"""
SAM-R1 model for conditional generation based on Qwen2VL.
Integrates a learnable query parameter and projection to SAM for joint vision-language tasks.
"""
config_class = SAMR1Config
def __init__(self, config, num_of_query=64, if_use_qwen_connector=True, **kwargs):
super().__init__(config)
        # Explicit None checks so an intentional False/0 stored in the config is respected.
        model_num_of_query = num_of_query if config.num_of_query is None else config.num_of_query
        model_if_use_qwen_connector = if_use_qwen_connector if config.if_use_qwen_connector is None else config.if_use_qwen_connector
self.if_detach_res_loss = False
# Learnable context queries
self.learnable_query = nn.Parameter(torch.randn(1, model_num_of_query, config.hidden_size), requires_grad=True)
self.learnable_query.ds_full_param = True # Keep full param in DeepSpeed ZeRO
self.learnable_query.ds_persist = True
self.model_num_of_query = model_num_of_query
self.model_if_use_qwen_connector = model_if_use_qwen_connector
self.conv_1d = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=model_num_of_query)
if model_if_use_qwen_connector:
self.connector = Qwen2VLVisionConnectorSimple(depth=4, seq_len=model_num_of_query, embed_dim=config.hidden_size)
# Projection to SAM feature space
self.proj_to_sam = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.GELU(),
nn.Linear(config.hidden_size, 256)
)
# Build SAM backbone
self.sam = build_sam2("sam2_hiera_l.yaml", device=self.model.device)
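        # The modules deleted below belong to SAM2's video-memory machinery; this model
        # only performs single-image segmentation, so they are dropped to save parameters.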
del self.sam.maskmem_tpos_enc
del self.sam.memory_attention
del self.sam.memory_encoder
input_size = 1024
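        # Spatial sizes of the SAM2 backbone feature maps (strides 4, 8 and 16)
        # for the fixed 1024x1024 SAM input.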
self._bb_feat_sizes = [
(input_size // 4, input_size // 4),
(input_size // 8, input_size // 8),
(input_size // 16, input_size // 16),
]
self._init_custom_params()
self.post_init()
def _init_custom_params(self):
"""Initialize custom parameters."""
nn.init.normal_(self.learnable_query, mean=0.0, std=0.02)
nn.init.normal_(self.conv_1d.weight, mean=0.0, std=0.02)
nn.init.zeros_(self.conv_1d.bias)
def set_if_detach_res_loss(self, if_detach_res_loss):
self.if_detach_res_loss = if_detach_res_loss
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_learnable_query: bool = False,
**kwargs,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
"""
Extended forward method to support learnable query injection.
"""
if use_learnable_query:
attention_mask, inputs_embeds = self.process_llm_input(input_ids, pixel_values, image_grid_thw, attention_mask)
input_ids = None
sam_images = kwargs.pop("sam_images", None)
mllm_pred_bboxes = kwargs.pop("pred_bboxes", None)
mllm_pred_points = kwargs.pop("pred_points", None)
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
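        # SAM2 branch: project the learnable-query hidden states into prompt embeddings,
        # encode the image with the SAM2 backbone, and decode one mask per sample.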
if sam_images is not None:
            assert output_hidden_states, "passing sam_images requires output_hidden_states=True"
box_end_embedding = self.get_sam_embedding(outputs.hidden_states[-1], if_detach_res_loss=self.if_detach_res_loss)
sam_images = sam_images.to(box_end_embedding)
backbone_out = self.sam.forward_image(sam_images)
_, image_embeddings, _, _ = self.sam._prepare_backbone_features(backbone_out)
            image_embeddings = [feat.to(sam_images.dtype) for feat in image_embeddings]
batch_size = sam_images.shape[0]
if self.sam.directly_add_no_mem_embed:
image_embeddings[-1] = image_embeddings[-1] + self.sam.no_mem_embed
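            # _prepare_backbone_features returns flattened (HW, B, C) features; reshape them
            # back to (B, C, H, W) at each pyramid level, as in SAM2's image predictor.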
feats = [
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
for feat, feat_size in zip(image_embeddings[::-1], self._bb_feat_sizes[::-1])
][::-1]
_features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
pred_masks = []
for i in range(len(box_end_embedding)):
if mllm_pred_bboxes is not None and mllm_pred_points is not None:
pred_box = mllm_pred_bboxes[i]
pred_point = mllm_pred_points[i]
boxes = pred_box.unsqueeze(0).to(box_end_embedding)
coords = pred_point.unsqueeze(0).unsqueeze(0).to(box_end_embedding)
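                    # Points are labelled as foreground (1); a (0, 0) point is treated as
                    # padding and labelled -1 so the prompt encoder ignores it.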
labels = torch.ones((1, 1), device=box_end_embedding.device, dtype=torch.long)
labels[(coords[..., 0] == 0) & (coords[..., 1] == 0)] = -1
points = (coords, labels)
else:
boxes = None
points = None
sparse_embeddings, dense_embeddings = self.sam.sam_prompt_encoder(
points=points,
boxes=boxes,
masks=None,
text_embeds=box_end_embedding[i].unsqueeze(0),
)
sparse_embeddings = sparse_embeddings.to(box_end_embedding[i].dtype)
high_res_features = [feat_level[i].unsqueeze(0) for feat_level in _features["high_res_feats"]]
low_res_masks, _, _, _ = self.sam.sam_mask_decoder(
image_embeddings=_features["image_embed"][i].unsqueeze(0),
image_pe=self.sam.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
repeat_image=True,
high_res_features=high_res_features,
)
pred_masks.append(low_res_masks)
return outputs, pred_masks
return outputs
def process_llm_input(self, input_ids, pixel_values, image_grid_thw, attention_mask):
"""
Convert input_ids to embeddings and append learnable queries at the end.
"""
        if input_ids.dtype != torch.long:
            input_ids = input_ids.to(torch.long)
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
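        # Append the learnable query tokens after the text/image embeddings; their final
        # hidden states are later read back by get_sam_embedding.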
inputs_embeds = torch.cat(
[inputs_embeds, self.learnable_query.repeat(inputs_embeds.size(0), 1, 1)], dim=1
)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
attention_mask = torch.cat(
[attention_mask, torch.ones(attention_mask.size(0), self.model_num_of_query).to(attention_mask)], dim=1
)
else:
attention_mask = torch.ones(inputs_embeds.size(0), inputs_embeds.size(1)).to(inputs_embeds.device)
return attention_mask, inputs_embeds
def get_sam_embedding(self, hidden_states, if_detach_res_loss=False):
"""
Extract and project SAM embedding from the last learnable queries in hidden states.
"""
query_hidden_state = hidden_states[:, -self.model_num_of_query:]
if if_detach_res_loss:
query_hidden_state = query_hidden_state.detach()
if self.model_if_use_qwen_connector:
query_hidden_state = self.connector(query_hidden_state)
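        # A Conv1d whose kernel spans the whole query sequence collapses the
        # num_of_query hidden states into a single token embedding.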
query_hidden_state = self.conv_1d(query_hidden_state.transpose(1, 2)).transpose(1, 2).contiguous()
sam_embedding = self.proj_to_sam(query_hidden_state)
return sam_embedding
    def postprocess_masks(self, masks, orig_hw):
        """Upsample low-resolution mask logits to the original image size."""
        masks = masks.float()
        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks
__all__ = ["SAMR1ForConditionalGeneration_qwen2p5"]
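# A minimal, standalone sketch of the shape flow in `get_sam_embedding`: the last
# `num_of_query` hidden states are collapsed by a Conv1d whose kernel spans the full
# query sequence, then projected to SAM2's 256-dim prompt space. The dimensions below
# are small dummies chosen for illustration, not the real model's sizes.
if __name__ == "__main__":
    hidden_size, num_of_query, batch = 32, 8, 2
    conv_1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=num_of_query)
    proj_to_sam = nn.Sequential(
        nn.Linear(hidden_size, hidden_size),
        nn.GELU(),
        nn.Linear(hidden_size, 256),
    )
    query_hidden_state = torch.randn(batch, num_of_query, hidden_size)
    # Conv1d expects (batch, channels, length); kernel_size == length reduces the
    # query dimension to a single position.
    contracted = conv_1d(query_hidden_state.transpose(1, 2)).transpose(1, 2)
    sam_embedding = proj_to_sam(contracted)
    print(sam_embedding.shape)  # torch.Size([2, 1, 256])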