Upload folder using huggingface_hub
- chat_template.jinja +5 -2
- config.json +3 -1
- configuration_moondream3.py +1 -1
- image_processing_moondream3.py +7 -1
- modeling_moondream3.py +320 -101
- processing_moondream3.py +31 -31
chat_template.jinja
CHANGED
@@ -57,15 +57,18 @@
 {{ raise_exception("caption length must be one of: short, normal, long.") }}
 {%- endif -%}
 <|md_reserved_0|>describe<|md_reserved_1|>{{ length }}<|md_reserved_2|>
+{%- elif lower.startswith('reason:') -%}
+{% set q = text[7:] | trim -%}
+<|md_reserved_0|>query<|md_reserved_1|>{{ q }}<|md_reserved_2|><|md_reserved_3|>
 {%- elif lower.startswith('query:') -%}
 {% set q = text[6:] | trim -%}
 <|md_reserved_0|>query<|md_reserved_1|>{{ q }}<|md_reserved_2|>
 {%- elif lower.startswith('detect:') -%}
 {% set q = text[7:] | trim -%}
-<|md_reserved_0|>det<|md_reserved_1|>{{ q }}<|md_reserved_2|>
+<|md_reserved_0|>det<|md_reserved_1|> {{ q }}<|md_reserved_2|>
 {%- elif lower.startswith('point:') -%}
 {% set q = text[6:] | trim -%}
-<|md_reserved_0|>point<|md_reserved_1|>{{ q }}<|md_reserved_2|>
+<|md_reserved_0|>point<|md_reserved_1|> {{ q }}<|md_reserved_2|>
 {%- else -%}
 {% set q = text -%}
 <|md_reserved_0|>query<|md_reserved_1|>{{ q }}<|md_reserved_2|>
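A quick illustration of what the updated template emits for each prompt prefix; the strings come straight from the hunk above, the dict itself is illustrative and not part of the repo:

# Illustrative mapping from user text to the rendered prompt, per the template above.
rendered = {
    "reason: why is the light red?":
        "<|md_reserved_0|>query<|md_reserved_1|>why is the light red?<|md_reserved_2|><|md_reserved_3|>",
    "query: what color is the car?":
        "<|md_reserved_0|>query<|md_reserved_1|>what color is the car?<|md_reserved_2|>",
    "detect: person":
        "<|md_reserved_0|>det<|md_reserved_1|> person<|md_reserved_2|>",    # note the new leading space
    "point: red car":
        "<|md_reserved_0|>point<|md_reserved_1|> red car<|md_reserved_2|>",
}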
config.json
CHANGED
@@ -37,7 +37,9 @@
     "output_router_logits": false,
     "prefix_attn": 730,
     "bos_token_id": 0,
-    "
+    "eos_token_id": 0,
+    "coord_token_id": 5,
+    "rms_norm_eps": 1e-05,
     "rope_parameters": {
       "rope_theta": 1500000.0,
       "rope_type": "default"
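A small sketch of how the new keys surface on the loaded config. The checkpoint path is a placeholder; the key names and values are the ones added above, and coord_token_id is the id the modeling code watches for when it switches into coordinate decoding:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/moondream3-checkpoint", trust_remote_code=True)  # placeholder path
print(config.text_config.eos_token_id)    # 0
print(config.text_config.coord_token_id)  # 5
print(config.text_config.rms_norm_eps)    # 1e-05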
configuration_moondream3.py
CHANGED
@@ -54,7 +54,7 @@ class Moondream3TextConfig(PretrainedConfig):
             The non-linear activation function.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-
+        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
             The epsilon used by the rms normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions.
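For reference, the epsilon documented above enters RMS normalization as the stabilizer under the square root. A minimal standalone sketch of that formula (not the model's own norm class):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), followed by a learned per-channel scale
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

x = torch.randn(1, 4, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([1, 4, 8])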
image_processing_moondream3.py
CHANGED
@@ -204,7 +204,13 @@ def prepare_crops(image, max_crops=12, overlap_margin=4):
     )
     all_crops = overlap_crops["crops"]
     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
-    all_crops =
+    all_crops = (
+        torch.from_numpy(all_crops)
+        .to(device="cpu", dtype=torch.bfloat16)
+        .div_(255.0)
+        .sub_(0.5)
+        .div_(0.5)
+    )
     return all_crops.tolist(), overlap_crops["tiling"]
 
 class Moondream3ImageProcessor(BaseImageProcessor):
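The new tensor pipeline is the usual mean-0.5/std-0.5 scaling of uint8 pixels written as in-place ops. A standalone check of the arithmetic (the toy array shape is arbitrary, not the processor's real crop size):

import numpy as np
import torch

crops = np.array([[[[0, 128, 255]]]], dtype=np.uint8)   # toy uint8 pixel values
x = torch.from_numpy(crops).to(dtype=torch.bfloat16)
x = x.div_(255.0).sub_(0.5).div_(0.5)                   # (v/255 - 0.5) / 0.5
print(x)                                                # 0 -> -1.0, 128 -> ~0.004, 255 -> 1.0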
modeling_moondream3.py
CHANGED
@@ -26,6 +26,7 @@ from PIL import Image
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.masking_utils import create_causal_mask
+from dataclasses import dataclass
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -34,6 +35,7 @@ from transformers.processing_utils import Unpack
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.generation import GenerationMixin
+from transformers.generation.utils import GenerateDecoderOnlyOutput
 from transformers.utils import logging, TransformersKwargs
 from .configuration_moondream3 import Moondream3Config, Moondream3TextConfig, Moondream3VisionConfig, Moondream3RegionConfig
 
@@ -41,48 +43,49 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Moondream3Config"
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q
-        k
-        cos
-        sin
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-    """
-    rot_dim = cos.shape[-1]
-    sin = sin.unsqueeze(unsqueeze_dim)
+import torch
+
+DEBUG=True
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,  # [B, H, L, D]
+    k: torch.Tensor,  # [B, H, L, D]
+    cos: torch.Tensor,  # [B, L, rot_dim]
+    sin: torch.Tensor,  # [B, L, rot_dim]
+    rot_dim: int = 32,
+):
+    """
+    Apply rotary position embeddings to query and key tensors.
+
+    Args:
+        q: Query tensor [batch, num_heads, seq_len, head_dim]
+        k: Key tensor [batch, num_heads, seq_len, head_dim]
+        cos: Cosine frequencies [batch, seq_len, rot_dim]
+        sin: Sine frequencies [batch, seq_len, rot_dim]
+        rot_dim: Number of dimensions to apply rotation to (default: 32)
+
+    Returns:
+        Tuple of (rotated_q, rotated_k)
+    """
+
+    def apply_rope(x):
+        dtype = x.dtype
+        x = x.to(torch.float64)
+        x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+
+        d_q = x_rot.shape[-1] // 2
+        xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
+
+        xq_out_r = xq_r * cos - xq_i * sin
+        xq_out_i = xq_r * sin + xq_i * cos
+
+        xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
+
+        return torch.cat([xq_out, x_pass], dim=-1)
+    return apply_rope(q), apply_rope(k)
 
 class Moondream3RotaryEmbedding(nn.Module):
-    inv_freq: torch.Tensor
+    inv_freq: torch.Tensor
 
     def __init__(self, config: Moondream3Config, device=None):
         super().__init__()
@@ -108,43 +111,34 @@ class Moondream3RotaryEmbedding(nn.Module):
     ) -> tuple["torch.Tensor", float]:
         """
         Computes the inverse frequencies according to the original RoPE implementation
-        Args:
-            config ([`~transformers.PreTrainedConfig`]):
-                The model configuration.
-            device (`torch.device`):
-                The device to use for initialization of the inverse frequencies.
-            seq_len (`int`, *optional*):
-                The current sequence length. Unused for this type of RoPE.
-        Returns:
-            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
-            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
-        base = config.rope_parameters["rope_theta"]
+        base = config.rope_parameters["rope_theta"]  # Should be 1500000.0 to match original
         dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
-        attention_factor = 1.0
+        dim //= 2
+
+        attention_factor = 1.0
 
-        # Compute the inverse frequencies
+        # Compute the inverse frequencies - matches your original formula
         inv_freq = 1.0 / (
-            base ** (torch.arange(0, dim, 2, dtype=torch.
+            base ** (torch.arange(0, dim, 2, dtype=torch.float64)[: (dim // 2)] / dim)
         )
+        if device is not None:
+            inv_freq = inv_freq.to(device=device)
         return inv_freq, attention_factor
 
     @torch.no_grad()
-    @dynamic_rope_update
+    @dynamic_rope_update
     def forward(self, x, position_ids):
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * self.attention_scaling
-        sin = emb.sin() * self.attention_scaling
-        return
+        # inv_freq shape: [dim//2]
+        # position_ids shape: [batch_size, seq_len]
+
+        inv_freq_expanded = self.inv_freq[None, :, None].to(torch.float64).expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].to(torch.float64)
+
+        freqs = (inv_freq_expanded.to(torch.float64) @ position_ids_expanded.to(torch.float64)).transpose(1, 2)
+        cfreqs = torch.exp(1j * freqs).unsqueeze(1).expand(-1, self.config.num_attention_heads, -1, -1)
+
+        return cfreqs.real, cfreqs.imag
 
 
 class Moondream3Attention(nn.Module):
@@ -202,6 +196,8 @@ class Moondream3Attention(nn.Module):
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
         input_shape = hidden_states.shape[:-1]
+        if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+            torch.save(hidden_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_input_states")
         bsz, q_len, _ = hidden_states.size()
 
         # Get qkv combined for tau (before splitting)
@@ -223,7 +219,9 @@ class Moondream3Attention(nn.Module):
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+            torch.save(value_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_pre_tau_value")
 
         if self.use_tau:
             query_states = query_states * tau_q
@@ -234,20 +232,38 @@ class Moondream3Attention(nn.Module):
             tau_v_repeated = tau_v
             value_states = value_states * tau_v_repeated
 
+        if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+            torch.save(value_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_tau_value")
+            torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_pre_rope_key")
+
         cos, sin = None, None
         if position_embeddings is not None:
             cos, sin = position_embeddings
+            if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+                torch.save(cos, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_cos")
+                torch.save(sin, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_sin")
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        if
+        if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+            torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_rope_key")
+        query_states, key_states = query_states.to(value_states.dtype), key_states.to(value_states.dtype)
+
+        if past_key_values is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+            torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_cache_key")
+            torch.save(attention_mask, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_attn_mask")
+
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
         attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS["sdpa"](
             self,
             query_states,
-            key_states,
+            key_states,
             value_states,
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
@@ -256,6 +272,9 @@ class Moondream3Attention(nn.Module):
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 
+        if isinstance(self.config, Moondream3TextConfig) and DEBUG:
+            torch.save(attn_output, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_attn_out")
+
         return attn_output, attn_weights
 
 class Moondream3MLP(nn.Module):
@@ -266,7 +285,6 @@ class Moondream3MLP(nn.Module):
         self.out_size = self.hidden_size if out_size is None else out_size
         self.hidden_act = hidden_act
         self.gated = gated
-        # Ungated MLP: up_proj and down_proj following HF conventions
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
         self.down_proj = nn.Linear(self.intermediate_size, self.out_size, bias=bias)
         self.gate_proj = None
@@ -276,17 +294,20 @@ class Moondream3MLP(nn.Module):
 
     def forward(self, x) -> torch.Tensor:
         if self.gated:
+            # separate up and gate causes precision issues
+            combined_weight = torch.cat([self.up_proj.weight, self.gate_proj.weight], dim=0)
+            h_full = F.linear(x, combined_weight)
+            h, g = h_full.chunk(2, dim=-1)
+            x = self.act_fn(h) * (g + 1)
        else:
            x = self.act_fn(self.up_proj(x))
        return self.down_proj(x)
 
 
 class Moondream3SparseMoeBlock(nn.Module):
-    def __init__(self, config: Moondream3TextConfig):
+    def __init__(self, config: Moondream3TextConfig, layer_idx = None):
         super().__init__()
+        self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.moe_intermediate_size = config.moe_intermediate_size
         self.num_experts = config.num_experts
@@ -295,32 +316,29 @@ class Moondream3SparseMoeBlock(nn.Module):
         self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=True)
         self.experts = nn.ModuleList([Moondream3MLP(hidden_size=self.hidden_size, intermediate_size=self.moe_intermediate_size, hidden_act="gelu", gated=True, bias=False) for _ in range(self.num_experts)])
 
-    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor, cache_position=None) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         router_logits: torch.Tensor = self.gate(hidden_states)
-        routing_weights = F.softmax(
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
+        routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float32)
         routing_weights = routing_weights.to(hidden_states.dtype)
 
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
 
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
         for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
+            top_x, idx = (selected_experts == expert_idx).nonzero(as_tuple=True)
 
             if top_x.shape[0] == 0:
                 continue
 
             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            # torch.save(current_state, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_e{expert_idx}")
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+            # torch.save(current_hidden_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_e{expert_idx}")
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
 
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@@ -330,6 +348,7 @@ class Moondream3SparseMoeBlock(nn.Module):
 class Moondream3DecoderLayer(nn.Module):
     def __init__(self, config: Moondream3TextConfig, layer_idx: int):
         super().__init__()
+        self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.self_attn = Moondream3Attention(config, layer_idx, use_tau=True)
@@ -337,7 +356,7 @@ class Moondream3DecoderLayer(nn.Module):
 
         self.is_moe_layer = layer_idx >= config.moe_start_layer
         if self.is_moe_layer:
-            self.mlp = Moondream3SparseMoeBlock(config)
+            self.mlp = Moondream3SparseMoeBlock(config, layer_idx=layer_idx)
         else:
             self.mlp = Moondream3MLP(self.hidden_size, self.intermediate_size)
 
@@ -358,6 +377,8 @@ class Moondream3DecoderLayer(nn.Module):
 
         # Apply layer norm like original
         l_in = self.input_layernorm(hidden_states)
+        if DEBUG:
+            torch.save(l_in, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_ln_out")
 
         # Attention
         hidden_states_attn, self_attn_weights = self.self_attn(
@@ -374,10 +395,12 @@ class Moondream3DecoderLayer(nn.Module):
 
         # MLP
         if self.is_moe_layer:
-            hidden_states_mlp, router_logits = self.mlp(l_in)
+            hidden_states_mlp, router_logits = self.mlp(l_in, cache_position=cache_position)
         else:
             hidden_states_mlp = self.mlp(l_in)
             router_logits = None
+        if DEBUG:
+            torch.save(hidden_states_mlp, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_mlp_out")
 
         # Add both attention and MLP to residual like original
         hidden_states = residual + hidden_states_attn + hidden_states_mlp
@@ -538,8 +561,6 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
             if output_router_logits and layer_outputs[-1] is not None:
                 all_router_logits += (layer_outputs[-1],)
 
-        hidden_states = self.norm(hidden_states)
-
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
@@ -794,8 +815,8 @@ class Moondream3RegionEncoder(nn.Module):
         self.register_buffer("size_freq", size_freq.T)
 
     def fourier_features(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-        x_proj = torch.
-        return torch.cat([
+        x_proj = 2 * torch.pi * x @ w
+        return torch.cat([x_proj.cos(), x_proj.sin()], dim=-1)
 
     def encode_coordinate(self, coord: torch.Tensor) -> torch.Tensor:
         fourier_features = self.fourier_features(coord, self.coord_freq)
@@ -815,7 +836,7 @@ class Moondream3RegionDecoder(nn.Module):
         return self.coord_decoder(hidden_state)
 
     def decode_size(self, hidden_state: torch.Tensor) -> torch.Tensor:
-        return self.size_decoder(hidden_state)
+        return self.size_decoder(hidden_state).view(hidden_state.shape[0], 2, -1)
 
 class Moondream3Model(Moondream3PreTrainedModel):
     def __init__(self, config: Moondream3Config):
@@ -885,38 +906,59 @@ class Moondream3Model(Moondream3PreTrainedModel):
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position: torch.Tensor = torch.arange(
-                past_seen_tokens, past_seen_tokens
+                past_seen_tokens, past_seen_tokens, device=inputs_embeds.device
             )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = create_causal_mask(
-            config=self.config,
-            input_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            past_key_values=past_key_values,
-            position_ids=position_ids,
-        )
-
-        if pixel_values is not None and input_ids.shape[-1] > 1:
-            # Vision embeds
-            pixel_values = pixel_values.to(dtype=self.vision_model.embeddings.projection.weight.dtype)
-            image_embeds = self.vision_model(pixel_values, tiling=tiling)["last_hidden_state"] # [B,P,D]
-            prefix = inputs_embeds[:, :1, :] # keep the first token
-            suffix = inputs_embeds[:, 1 + image_embeds.shape[1] :, :] # keep the rest after the image span
-            inputs_embeds = torch.cat([prefix, image_embeds, suffix], dim=1)
-
-            # N/A when doing BSZ 1 since create_causal_mask returns None in the case since theres no padding tokens
-            if causal_mask is not None:
-                img_len = image_embeds.shape[1]
-                causal_mask[:, :, :1 + img_len, :1 + img_len] = True
-                causal_mask[:, :, :1 + img_len, 1 + img_len:] = False
-
+        def image_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int):
+            # set all up to `self.config.vision_config.prefix_len` to true
+            return kv_idx <= q_idx
+
+        if pixel_values is not None:
+            # Vision embeds
+            pixel_values = pixel_values.to(dtype=self.vision_model.embeddings.projection.weight.dtype)
+            image_embeds = self.vision_model(pixel_values, tiling=tiling)["last_hidden_state"] # [B,P,D]
+            prefix = self.text_model.embed_tokens(
+                torch.full((input_ids.shape[0], 1), self.config.text_config.bos_token_id, dtype=input_ids.dtype, device=input_ids.device)
+            )
+            embeds = torch.cat([prefix, image_embeds], dim=1)
+            cache_pos = torch.arange(embeds.shape[-2], device=embeds.device)
+            pos = cache_pos.unsqueeze(0).expand(embeds.shape[0], -1)
+            attn_mask = torch.full(
+                (embeds.shape[0], 1, embeds.shape[-2], pos.shape[-1]),
+                True,
+                dtype=torch.bool,
+                device=embeds.device,
+            )
+
+            outputs = self.text_model(
+                input_ids=None,
+                attention_mask=attn_mask,
+                position_ids=pos,
+                past_key_values=past_key_values,
+                inputs_embeds=embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=True,
+                cache_position=cache_pos,
+            )
+
+        attn_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=torch.cat([torch.ones(attention_mask.shape[0], cache_position[-1] + 1 - attention_mask.shape[-1], device=attention_mask.device, dtype=attention_mask.dtype), attention_mask], dim=-1),
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            and_mask_function=image_mask_function
+        )
+
         outputs = self.text_model(
             input_ids=None,
-            attention_mask=causal_mask,
+            attention_mask=attn_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
@@ -942,11 +984,17 @@ class Moondream3Model(Moondream3PreTrainedModel):
             attentions=getattr(outputs, "attentions", None),
         )
 
+@dataclass
+class Moondream3GenerateOutput(GenerateDecoderOnlyOutput):
+    objects: Optional[list[dict[str, float]]] = None
+
+
 class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: Moondream3Config):
         super().__init__(config)
+        self.objects = None
         self.model = Moondream3Model(config)
         self.vocab_size = config.text_config.vocab_size
         self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
@@ -970,6 +1018,15 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.text_model
 
+    def _prepare_generated_length(
+        self,
+        generation_config,
+        **kwargs,
+    ):
+        generation_config = super()._prepare_generated_length(generation_config, **kwargs)
+        generation_config.max_length += 730
+        return generation_config
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -988,6 +1045,9 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
         logits_to_keep: int = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if pixel_values is not None and inputs_embeds is None:
+            position_ids += self.config.vision_config.prefix_len
+            cache_position += self.config.vision_config.prefix_len
         # Get hidden states from the base model (it already builds the multimodal prefix)
         model_outputs = self.model(
             input_ids=input_ids,
@@ -1005,7 +1065,6 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
         )
-
        hidden_states = model_outputs.last_hidden_state # [B, T, D]
 
        # Compute logits; only keep the tail if requested
@@ -1016,8 +1075,127 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
         else:
             hs = hidden_states
 
+        hs = self.model.text_model.norm(hs)
         logits = self.lm_head(hs) # [B, T', V]
 
+        pred = torch.argmax(logits, dim=-1)
+
+        pos_ids = position_ids[:, -1:] + 1
+        cache_pos = cache_position[-1:] + 1
+        mask = torch.ones(
+            hidden_states.shape[0], 1, device=self.device, dtype=torch.long
+        )
+        while torch.any(pred == 5):
+            batch_mask = (pred[:, -1] == 5)
+            hidden_states = hidden_states[:, -1:, :]
+            x_logits = self.model.region_decoder.decode_coordinate(hidden_states)
+            x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
+            next_embeds = self.model.region_encoder.encode_coordinate(x_center.to(x_logits.dtype)).unsqueeze(1)
+            model_outputs = self.model(
+                input_ids=None,
+                pixel_values=None,
+                tiling=None,
+                attention_mask=mask,
+                position_ids=pos_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=next_embeds,
+                labels=None,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=True,
+                cache_position=cache_pos,
+                logits_to_keep=logits_to_keep,
+            )
+            hidden_states = model_outputs.last_hidden_state # [B, T, D]
+            y_logits = self.model.region_decoder.decode_coordinate(hidden_states)
+            y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
+            next_embeds = self.model.region_encoder.encode_coordinate(y_center.to(y_logits.dtype)).unsqueeze(1)
+            coords = torch.cat([x_center, y_center], dim=1)
+            coords = coords * (batch_mask).unsqueeze(1)
+            pos_ids += 1
+            cache_pos = cache_pos + 1
+            bbox = None
+            if input_ids[0, 1] == 7235:
+                model_outputs = self.model(
+                    input_ids=None,
+                    pixel_values=None,
+                    tiling=None,
+                    attention_mask=mask,
+                    position_ids=pos_ids,
+                    past_key_values=past_key_values,
+                    inputs_embeds=next_embeds,
+                    labels=None,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=True,
+                    cache_position=cache_pos,
+                    logits_to_keep=logits_to_keep,
+                )
+                hidden_states = model_outputs.last_hidden_state # [B, T, D]
+                size_logits = self.model.region_decoder.decode_size(hidden_states)
+                bins = torch.argmax(size_logits, dim=-1)
+                w_bin = bins[:, 0]
+                h_bin = bins[:, 1]
+
+                w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
+                h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
+
+                next_embeds = (
+                    self.model.region_encoder.encode_size(
+                        torch.stack([w, h], dim=-1).to(size_logits.dtype)
+                    )
+                ).unsqueeze(1)
+                bbox = [
+                    x_center.item() - w.item() / 2,
+                    y_center.item() - h.item() / 2,
+                    x_center.item() + w.item() / 2,
+                    y_center.item() + h.item() / 2,
+                ]
+                bbox = bbox * (batch_mask).unsqueeze(1)
+                pos_ids += 1
+                cache_pos = cache_pos + 1
+
+            new = coords.unsqueeze(1) if bbox is None else bbox.unsqueeze(1)
+            if self.objects is None:
+                self.objects = new
+            else:
+                self.objects = torch.cat([self.objects, new], dim=1)
+            model_outputs = self.model(
+                input_ids=None,
+                pixel_values=None,
+                tiling=None,
+                attention_mask=mask,
+                position_ids=pos_ids,
+                past_key_values=past_key_values,
+                inputs_embeds=next_embeds,
+                labels=None,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=True,
+                cache_position=cache_pos,
+                logits_to_keep=logits_to_keep,
+            )
+            pos_ids += 1
+            cache_pos = cache_pos + 1
+            hidden_states = model_outputs.last_hidden_state # [B, T, D]
+
+            indices = torch.tensor(
+                [self.config.text_config.coord_token_id, self.config.text_config.eos_token_id],
+                device=self.device,
+            )
+
+            hidden_states = self.model.text_model.norm(hidden_states)
+            logits = hidden_states @ self.lm_head.weight[indices].T + self.lm_head.bias[indices]
+
+            logits_full = torch.full((logits.shape[0], logits.shape[1], self.config.text_config.vocab_size), float('-inf'), device=logits.device, dtype=logits.dtype)
+            logits_full[:, :, torch.tensor([5, 0])] = logits
+            logits = logits_full
+            pred[batch_mask] = torch.argmax(logits, dim=-1)[batch_mask]
+
+
         loss = None
         if labels is not None:
             # Shift if your training uses standard LM convention; here we assume labels aligned with hs
@@ -1031,6 +1209,47 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
             attentions=getattr(model_outputs, "attentions", None),
         )
 
+    def generate(self, **kwargs) -> Union[Moondream3GenerateOutput, torch.LongTensor]:
+        outputs = super().generate(**kwargs)
+        if len(self.objects) > 0:
+            if isinstance(outputs, torch.Tensor):
+                outputs = self.objects
+                self.objects = []
+            else:
+                outputs = Moondream3GenerateOutput(
+                    **outputs,
+                    objects=self.objects
+                )
+                self.objects = []
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        **model_kwargs
+    ):
+        model_inputs = super().prepare_inputs_for_generation(input_ids, **model_kwargs)
+        model_inputs["position_ids"] += model_inputs["cache_position"].unsqueeze(0) - model_inputs["position_ids"]
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs,
+        model_kwargs,
+        is_encoder_decoder,
+        num_new_tokens: int = 1,
+    ):
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs,
+            model_kwargs,
+            is_encoder_decoder=is_encoder_decoder,
+            num_new_tokens=num_new_tokens,
+        )
+        model_kwargs["pixel_values"] = None
+        model_kwargs["tiling"] = None
+        return model_kwargs
+
+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
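One piece of the diff above that benefits from worked numbers: the size head decodes width and height from 1024 log2-spaced bins via w = 2^((bin/1023)*10 - 10), so bin 0 is about 1/1024 of the normalized image side and bin 1023 is the full side. A standalone check:

import torch

bins = torch.tensor([0.0, 511.0, 1023.0])
sizes = torch.pow(2.0, (bins / 1023.0) * 10.0 - 10.0)
print(sizes)  # tensor([0.0010, 0.0311, 1.0000]) -- fractions of the normalized image side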
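Also from the text-model changes above: the sparse MoE block now takes top-k over the raw router logits and softmaxes only those k values, instead of softmaxing over all experts and renormalizing the top-k slice. A toy sketch of the new order of operations (expert count and k are made up):

import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 0.5, 1.5, -1.0]])  # 1 token, 4 experts (toy values)
top_k = 2

routing_weights, selected_experts = torch.topk(router_logits, top_k, dim=-1)
routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float32)

print(selected_experts)  # tensor([[0, 2]])
print(routing_weights)   # softmax over the two kept logits only, ~[[0.62, 0.38]]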
processing_moondream3.py
CHANGED
@@ -211,40 +211,40 @@ class Moondream3Processor(ProcessorMixin):
 
         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
         text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
-        if "input_ids" in text_inputs:
-            # prepend 1 bos_token_id and 729 image_token_id to the text_inputs
-            for i in range(len(text_inputs["input_ids"])):
-                prepended_tokens = [self.tokenizer.bos_token_id] + [self.image_token_id] * 729
-                text_inputs["input_ids"][i] = prepended_tokens + text_inputs["input_ids"][i]
-        if "attention_mask" in text_inputs:
-            # attend to the 730 prepended tokens
-            for i in range(len(text_inputs["attention_mask"])):
-                prepended_mask = [1] * 730
-                text_inputs["attention_mask"][i] = prepended_mask + text_inputs["attention_mask"][i]
+        # if "input_ids" in text_inputs:
+        #     # prepend 1 bos_token_id and 729 image_token_id to the text_inputs
+        #     for i in range(len(text_inputs["input_ids"])):
+        #         prepended_tokens = [self.tokenizer.bos_token_id] + [self.image_token_id] * 729
+        #         text_inputs["input_ids"][i] = prepended_tokens + text_inputs["input_ids"][i]
+        # if "attention_mask" in text_inputs:
+        #     # attend to the 730 prepended tokens
+        #     for i in range(len(text_inputs["attention_mask"])):
+        #         prepended_mask = [1] * 730
+        #         text_inputs["attention_mask"][i] = prepended_mask + text_inputs["attention_mask"][i]
 
         return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
 
-    def apply_chat_template(
-        self,
-        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
-        chat_template: Optional[str] = None,
-        **kwargs,
-    ) -> str:
-        # Call the original behavior first
-        out = super().apply_chat_template(
-            conversation=conversation,
-            chat_template=chat_template,
-            **kwargs,
-        )
-
-        # Only post-process when:
-        # - user requested assistant mask
-        # - output is a dict (tokenized + return_dict=True path)
-        if isinstance(out, BatchFeature) and kwargs.get("return_assistant_tokens_mask", False):
-            if "assistant_masks" in out and out["assistant_masks"] is not None:
-                out["assistant_masks"] = _rotate_right_array(out["assistant_masks"], 730)
-
-        return out
+    # def apply_chat_template(
+    #     self,
+    #     conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+    #     chat_template: Optional[str] = None,
+    #     **kwargs,
+    # ) -> str:
+    #     # Call the original behavior first
+    #     out = super().apply_chat_template(
+    #         conversation=conversation,
+    #         chat_template=chat_template,
+    #         **kwargs,
+    #     )
+
+    #     # Only post-process when:
+    #     # - user requested assistant mask
+    #     # - output is a dict (tokenized + return_dict=True path)
+    #     if isinstance(out, BatchFeature) and kwargs.get("return_assistant_tokens_mask", False):
+    #         if "assistant_masks" in out and out["assistant_masks"] is not None:
+    #             out["assistant_masks"] = _rotate_right_array(out["assistant_masks"], 730)
+
+    #     return out
 
     @property
     def model_input_names(self):
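Putting the pieces together, a hedged usage sketch: the processor no longer prepends the 730-token image prefix itself (that block is commented out above); the model builds the vision prefix internally, and for detect/point prompts generate() hands back the accumulated coordinates on the objects field. The repo id and auto classes below are assumptions, not something this diff pins down:

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "path/to/moondream3-checkpoint"  # placeholder
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.bfloat16)

image = Image.open("example.jpg")
inputs = processor(images=image, text="detect: person", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32, return_dict_in_generate=True)

print(out.sequences.shape)            # generated token ids
print(getattr(out, "objects", None))  # (x, y) centers or [x0, y0, x1, y1] boxes from the coordinate loop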
|