# LENS/qwen2_vl.py
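# Simplified Qwen2-VL-style vision modules: a feed-forward MLP, SDPA self-attention,
# a pre-norm transformer block, and a connector that stacks the blocks with a learned
# positional embedding.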
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from transformers.activations import ACT2FN
from transformers.utils import logging
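# LOCAL_RANK is set by distributed launchers such as torchrun; -1 falls back to single-process use.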
local_rank = int(os.environ.get("LOCAL_RANK", -1))
logger = logging.get_logger(__name__)
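# Dimension-preserving two-layer feed-forward block: Linear -> activation -> Linear.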
class VisionMlp(nn.Module):
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.act = ACT2FN[hidden_act]
self.fc2 = nn.Linear(hidden_dim, dim)
def forward(self, x) -> torch.Tensor:
return self.fc2(self.act(self.fc1(x)))
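# Multi-head self-attention over batched (B, L, C) features using PyTorch's
# scaled_dot_product_attention. cu_seqlens, rotary_pos_emb and position_embeddings are
# kept for interface compatibility but are unused in this simplified variant.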
class VisionSdpaAttentionSimple(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
B, L, C = hidden_states.shape
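        # One qkv projection: (B, L, 3*C) -> q, k, v each of shape (B, num_heads, L, head_dim).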
q, k, v = self.qkv(hidden_states).reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
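        # Full (unmasked) attention over all L tokens; cu_seqlens is not applied here.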
attn_output = F.scaled_dot_product_attention(
q, k, v, dropout_p=0.0
)
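        # Merge heads back: (B, num_heads, L, head_dim) -> (B, L, C).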
attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)
attn_output = self.proj(attn_output)
return attn_output
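# Pre-norm transformer block: x + attn(norm1(x)), then x + mlp(norm2(x)).
# attn_implementation is accepted for config compatibility; only the SDPA path is used.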
class Qwen2VLVisionBlockSimple(nn.Module):
def __init__(self, embed_dim=1280, num_heads=16, mlp_ratio=4, hidden_act="quick_gelu", attn_implementation="sdpa"):
super().__init__()
self.norm1 = LayerNorm(embed_dim, eps=1e-6)
self.norm2 = LayerNorm(embed_dim, eps=1e-6)
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.attn = VisionSdpaAttentionSimple(embed_dim, num_heads)
self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings=None):
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
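# Stacks `depth` vision blocks and adds a learnable positional embedding of shape
# (1, seq_len, embed_dim) to the incoming (B, L, C) features.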
class Qwen2VLVisionConnectorSimple(nn.Module):
def __init__(self, depth, seq_len, embed_dim=1280, num_heads=16, mlp_ratio=4, hidden_act="quick_gelu", attn_implementation="sdpa"):
super().__init__()
self.blocks = nn.ModuleList([Qwen2VLVisionBlockSimple(embed_dim, num_heads, mlp_ratio, hidden_act, attn_implementation) for _ in range(depth)])
self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
def forward(self, hidden_states, cu_seqlens=None, rotary_pos_emb=None):
B, L, C = hidden_states.shape
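        # The learned positional embedding broadcasts over the batch dimension.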
hidden_states = hidden_states + self.pos_embed
        if cu_seqlens is None:
            # All samples share length L, so the cumulative lengths are [0, L, 2L, ..., B*L].
            actual_lengths = torch.full((B,), L, dtype=torch.long, device=hidden_states.device)
            cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long, device=hidden_states.device), actual_lengths.cumsum(dim=0)])
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
return hidden_states
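# Minimal usage sketch (assumed shapes for illustration; depth and seq_len depend on the
# calling code, while the 1280-dim / 16-head defaults follow the class signatures above):
#   connector = Qwen2VLVisionConnectorSimple(depth=2, seq_len=256)
#   feats = torch.randn(4, 256, 1280)   # (B, L, C) visual features
#   out = connector(feats)              # same shape: (4, 256, 1280)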