from typing import Optional, Tuple
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from transformers.activations import ACT2FN
from transformers.utils import logging

local_rank = int(os.environ.get("LOCAL_RANK", -1))
logger = logging.get_logger(__name__)


class VisionMlp(nn.Module):
    """Two-layer feed-forward block used inside each vision transformer block."""

    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class VisionSdpaAttentionSimple(nn.Module):
    """Multi-head self-attention built on F.scaled_dot_product_attention.

    This simplified variant attends over the full sequence; `cu_seqlens`,
    `rotary_pos_emb`, and `position_embeddings` are accepted only to keep the
    signature compatible with the full Qwen2-VL attention and are ignored here.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        B, L, C = hidden_states.shape
        # (B, L, 3C) -> (B, L, 3, H, D) -> (3, B, H, L, D) -> three (B, H, L, D) tensors
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(B, L, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        # (B, H, L, D) -> (B, L, H*D) == (B, L, C)
        attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)
        return self.proj(attn_output)


class Qwen2VLVisionBlockSimple(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each wrapped in a residual connection."""

    def __init__(
        self,
        embed_dim: int = 1280,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        hidden_act: str = "quick_gelu",
        attn_implementation: str = "sdpa",  # kept for API compatibility; only the SDPA path is implemented
    ) -> None:
        super().__init__()
        self.norm1 = LayerNorm(embed_dim, eps=1e-6)
        self.norm2 = LayerNorm(embed_dim, eps=1e-6)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.attn = VisionSdpaAttentionSimple(embed_dim, num_heads)
        self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings=None):
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class Qwen2VLVisionConnectorSimple(nn.Module):
    """Stack of vision blocks with a learned absolute position embedding."""

    def __init__(
        self,
        depth: int,
        seq_len: int,
        embed_dim: int = 1280,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        hidden_act: str = "quick_gelu",
        attn_implementation: str = "sdpa",
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                Qwen2VLVisionBlockSimple(embed_dim, num_heads, mlp_ratio, hidden_act, attn_implementation)
                for _ in range(depth)
            ]
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

    def forward(self, hidden_states, cu_seqlens=None, rotary_pos_emb=None):
        B, L, C = hidden_states.shape
        # The position embedding has a fixed length, so L must equal the
        # seq_len this module was constructed with.
        hidden_states = hidden_states + self.pos_embed
        if cu_seqlens is None:
            # Cumulative sequence lengths [0, L, 2L, ...]; built on the input's
            # device so downstream consumers can use them without a transfer.
            actual_lengths = torch.full((B,), L, dtype=torch.long, device=hidden_states.device)
            cu_seqlens = torch.cat(
                [
                    torch.zeros(1, dtype=torch.long, device=hidden_states.device),
                    actual_lengths.cumsum(dim=0),
                ]
            )
        for block in self.blocks:
            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
        return hidden_states
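

if __name__ == "__main__":
    # Minimal smoke test, kept as a sketch: the depth, seq_len, and batch size
    # below are illustrative assumptions, not values from any released config.
    connector = Qwen2VLVisionConnectorSimple(depth=2, seq_len=64)
    x = torch.randn(4, 64, 1280)  # (batch, seq_len, embed_dim); seq_len must match pos_embed
    out = connector(x)            # cu_seqlens is derived internally when omitted
    print(out.shape)              # torch.Size([4, 64, 1280])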