import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from transformers.activations import ACT2FN
from transformers.utils import logging

# Local process rank set by distributed launchers such as torchrun; -1 when not set.
# Not referenced elsewhere in this module.
local_rank = int(os.environ.get("LOCAL_RANK", -1))

logger = logging.get_logger(__name__)



class VisionMlp(nn.Module):
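    """Two-layer feed-forward network (Linear -> activation -> Linear) used inside each vision block."""
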
    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class VisionSdpaAttentionSimple(nn.Module):
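    """Multi-head self-attention backed by torch.nn.functional.scaled_dot_product_attention.

    Simplified variant: `cu_seqlens`, `rotary_pos_emb`, and `position_embeddings` are accepted,
    presumably to mirror the call signature of the full Qwen2-VL vision attention, but they are
    ignored here.
    """
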
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        B, L, C = hidden_states.shape
        # Fused qkv projection, then split heads:
        # (B, L, 3*C) -> (3, B, num_heads, L, head_dim) -> three (B, num_heads, L, head_dim) tensors.
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(B, L, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        # cu_seqlens, rotary_pos_emb, and position_embeddings are not used in this
        # simplified attention, so all L positions attend to each other without masking.
        attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        attn_output = attn_output.transpose(1, 2).reshape(B, L, -1)  # back to (B, L, C)
        attn_output = self.proj(attn_output)
        return attn_output


class Qwen2VLVisionBlockSimple(nn.Module):
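    """Pre-norm transformer block: LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual.

    `attn_implementation` is accepted for interface compatibility but ignored; attention always uses SDPA.
    """
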
    def __init__(self, embed_dim=1280, num_heads=16, mlp_ratio=4, hidden_act="quick_gelu", attn_implementation="sdpa"):
        super().__init__()
        self.norm1 = LayerNorm(embed_dim, eps=1e-6)
        self.norm2 = LayerNorm(embed_dim, eps=1e-6)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.attn = VisionSdpaAttentionSimple(embed_dim, num_heads)
        self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings=None):
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class Qwen2VLVisionConnectorSimple(nn.Module):
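    """Stack of `depth` vision blocks preceded by a learned absolute position embedding.

    The position embedding has shape (1, seq_len, embed_dim), so each input sample must
    contain exactly `seq_len` tokens of width `embed_dim`.
    """
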
    def __init__(self, depth, seq_len, embed_dim=1280, num_heads=16, mlp_ratio=4, hidden_act="quick_gelu", attn_implementation="sdpa"):
        super().__init__()
        self.blocks = nn.ModuleList([Qwen2VLVisionBlockSimple(embed_dim, num_heads, mlp_ratio, hidden_act, attn_implementation) for _ in range(depth)])

        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

    def forward(self, hidden_states, cu_seqlens=None, rotary_pos_emb=None):
        B, L, C = hidden_states.shape
        # Add the learned absolute position embedding (broadcast over the batch).
        hidden_states = hidden_states + self.pos_embed

        if cu_seqlens is None:
            # Default: treat each sample as one full-length sequence, i.e. cumulative
            # lengths [0, L, 2L, ..., B*L], built on the input's device and dtype-safe.
            actual_lengths = torch.full((B,), L, dtype=torch.long, device=hidden_states.device)
            cu_seqlens = torch.cat([actual_lengths.new_zeros(1), actual_lengths.cumsum(dim=0)])
        for block in self.blocks:
            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb)
        return hidden_states
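

# ---------------------------------------------------------------------------
# Minimal smoke test (not part of the original module). The depth, sequence
# length, and batch size below are arbitrary values chosen for illustration;
# the defaults (embed_dim=1280, num_heads=16) come from the class signature.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    connector = Qwen2VLVisionConnectorSimple(depth=2, seq_len=16)
    patches = torch.randn(2, 16, 1280)  # (batch, seq_len, embed_dim); seq_len must match pos_embed
    out = connector(patches)
    print(out.shape)  # torch.Size([2, 16, 1280])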