NyxKrage committed
Commit 4cbb231 · verified · 1 Parent(s): 1101625

Update modeling_moondream3.py

Files changed (1)
  1. modeling_moondream3.py +386 -223
modeling_moondream3.py CHANGED
@@ -15,14 +15,10 @@
 
 from typing import Callable, Optional, Tuple, Union
 
- import numpy as np
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
- from PIL import Image
-
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.masking_utils import create_causal_mask
@@ -37,33 +33,35 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.generation import GenerationMixin
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 from transformers.utils import logging, TransformersKwargs
- from .configuration_moondream3 import Moondream3Config, Moondream3TextConfig, Moondream3VisionConfig, Moondream3RegionConfig
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Moondream3Config"
 
- import torch
-
- DEBUG=True
 
 def apply_rotary_pos_emb(
- q: torch.Tensor, # [B, H, L, D]
- k: torch.Tensor, # [B, H, L, D]
- cos: torch.Tensor, # [B, L, rot_dim]
- sin: torch.Tensor, # [B, L, rot_dim]
 rot_dim: int = 32,
 ):
 """
 Apply rotary position embeddings to query and key tensors.
-
 Args:
 q: Query tensor [batch, num_heads, seq_len, head_dim]
 k: Key tensor [batch, num_heads, seq_len, head_dim]
 cos: Cosine frequencies [batch, seq_len, rot_dim]
 sin: Sine frequencies [batch, seq_len, rot_dim]
 rot_dim: Number of dimensions to apply rotation to (default: 32)
-
 Returns:
 Tuple of (rotated_q, rotated_k)
 """
@@ -77,13 +75,15 @@ def apply_rotary_pos_emb(
 xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
 
 xq_out_r = xq_r * cos - xq_i * sin
- xq_out_i = xq_r * sin + xq_i * cos
 
 xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
 
 return torch.cat([xq_out, x_pass], dim=-1)
 return apply_rope(q), apply_rope(k)
 
 class Moondream3RotaryEmbedding(nn.Module):
 inv_freq: torch.Tensor
 
@@ -112,15 +112,17 @@ class Moondream3RotaryEmbedding(nn.Module):
 """
 Computes the inverse frequencies according to the original RoPE implementation
 """
- base = config.rope_parameters["rope_theta"] # Should be 1500000.0 to match original
- dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
 dim //= 2
-
 attention_factor = 1.0
 
- # Compute the inverse frequencies - matches your original formula
 inv_freq = 1.0 / (
- base ** (torch.arange(0, dim, 2, dtype=torch.float64)[: (dim // 2)] / dim)
 )
 if device is not None:
 inv_freq = inv_freq.to(device=device)
@@ -129,38 +131,53 @@
 @torch.no_grad()
 @dynamic_rope_update
 def forward(self, x, position_ids):
- # inv_freq shape: [dim//2]
- # position_ids shape: [batch_size, seq_len]
-
- inv_freq_expanded = self.inv_freq[None, :, None].to(torch.float64).expand(position_ids.shape[0], -1, 1).to(x.device)
- position_ids_expanded = position_ids[:, None, :].to(torch.float64)
-
- freqs = (inv_freq_expanded.to(torch.float64) @ position_ids_expanded.to(torch.float64)).transpose(1, 2)
- cfreqs = torch.exp(1j * freqs).unsqueeze(1).expand(-1, self.config.num_attention_heads, -1, -1)
 
 return cfreqs.real, cfreqs.imag
 
 
 class Moondream3Attention(nn.Module):
- def __init__(self, config: Moondream3TextConfig | Moondream3VisionConfig, layer_idx: Optional[int] = None, use_tau: bool = True):
 super().__init__()
 self.config = config
 self.layer_idx = layer_idx
 self.hidden_size = config.hidden_size
 self.num_heads = config.num_attention_heads
 self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
- self.num_key_value_heads = getattr(config, "num_key_value_heads", self.num_heads)
 attention_bias = config.attention_bias
 self.attention_dropout = config.attention_dropout
 
- # Initialize parameters based on config type
 if isinstance(config, Moondream3TextConfig):
 self.is_causal = True
- elif isinstance(config, Moondream3VisionConfig): # Moondream3VisionConfig
 self.is_causal = False
 else:
 raise TypeError(f"Unsupported config type: {type(config)}")
-
 self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 self.use_tau = use_tau
 
@@ -170,15 +187,29 @@ class Moondream3Attention(nn.Module):
 f" and `num_heads`: {self.num_heads})."
 )
 
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias)
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias)
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias)
-
- # Tau parameters for token-level attention (from original Moondream) - only for text model
 if self.use_tau:
 # In original, tau weights are (n_heads, qkv_dim) where qkv_dim is the combined QKV dimension
- qkv_dim = self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim
 self.tau_wq = nn.Linear(qkv_dim, self.num_heads, bias=False)
 self.tau_wv = nn.Linear(qkv_dim, self.num_heads, bias=False)
 self.tau_alpha = nn.Parameter(torch.empty(self.num_heads))
@@ -196,65 +227,65 @@
 **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 input_shape = hidden_states.shape[:-1]
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(hidden_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_input_states")
 bsz, q_len, _ = hidden_states.size()
 
- # Get qkv combined for tau (before splitting)
 query_states = self.q_proj(hidden_states)
 key_states = self.k_proj(hidden_states)
 value_states = self.v_proj(hidden_states)
 if self.use_tau:
 qkv_out = torch.cat([query_states, key_states, value_states], dim=-1)
 tok_feat = F.gelu(qkv_out)
- tok_q = torch.tanh(self.tau_wq(tok_feat)).permute(0, 2, 1) # (bsz, n_heads, seq_len)
- tok_v = torch.tanh(self.tau_wv(tok_feat)).permute(0, 2, 1) # (bsz, n_heads, seq_len)
 
 pos = position_ids.to(tok_q.dtype) + 1
 alpha = self.tau_alpha.to(tok_q.dtype)
- tau_pos = 1 + (torch.sigmoid(alpha[None, :, None] * pos[:, None, :].log()) - 0.5) # (n_heads, seq_len)
- tau_q = (tok_q + tau_pos).unsqueeze(-1) # (bsz, n_heads, seq_len, 1)
- tau_v = (tok_v + tau_pos).unsqueeze(-1) # (bsz, n_heads, seq_len, 1)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(value_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_pre_tau_value")
 
 if self.use_tau:
 query_states = query_states * tau_q
 
 if self.num_key_value_groups > 1:
- tau_v_repeated = tau_v.repeat(1, self.num_key_value_groups, 1, 1)[:, :self.num_key_value_heads, :, :]
 else:
 tau_v_repeated = tau_v
 value_states = value_states * tau_v_repeated
 
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(value_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_tau_value")
- torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_pre_rope_key")
-
 cos, sin = None, None
 if position_embeddings is not None:
 cos, sin = position_embeddings
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(cos, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_cos")
- torch.save(sin, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_sin")
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_rope_key")
- query_states, key_states = query_states.to(value_states.dtype), key_states.to(value_states.dtype)
 
 if past_key_values is not None:
 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_cache_key")
- torch.save(attention_mask, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_attn_mask")
 
 query_states = query_states.contiguous()
 key_states = key_states.contiguous()
@@ -272,13 +303,19 @@ class Moondream3Attention(nn.Module):
 attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 attn_output = self.o_proj(attn_output)
 
- if isinstance(self.config, Moondream3TextConfig) and DEBUG:
- torch.save(attn_output, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_attn_out")
-
 return attn_output, attn_weights
 
 class Moondream3MLP(nn.Module):
- def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu_pytorch_tanh", out_size: int | None = None, gated: bool = False, bias: bool = True):
 super().__init__()
 self.hidden_size = hidden_size
 self.intermediate_size = intermediate_size
@@ -289,15 +326,15 @@
 self.down_proj = nn.Linear(self.intermediate_size, self.out_size, bias=bias)
 self.gate_proj = None
 if self.gated:
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
 self.act_fn = ACT2FN[self.hidden_act]
 
 def forward(self, x) -> torch.Tensor:
 if self.gated:
- # separate up and gate causes precision issues
- combined_weight = torch.cat([self.up_proj.weight, self.gate_proj.weight], dim=0)
- h_full = F.linear(x, combined_weight)
- h, g = h_full.chunk(2, dim=-1)
 x = self.act_fn(h) * (g + 1)
 else:
 x = self.act_fn(self.up_proj(x))
@@ -305,7 +342,7 @@
 
 
 class Moondream3SparseMoeBlock(nn.Module):
- def __init__(self, config: Moondream3TextConfig, layer_idx = None):
 super().__init__()
 self.layer_idx = layer_idx
 self.hidden_size = config.hidden_size
@@ -314,18 +351,35 @@
 self.top_k = config.num_experts_per_tok
 
 self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=True)
- self.experts = nn.ModuleList([Moondream3MLP(hidden_size=self.hidden_size, intermediate_size=self.moe_intermediate_size, hidden_act="gelu", gated=True, bias=False) for _ in range(self.num_experts)])
 
- def forward(self, hidden_states: torch.Tensor, cache_position=None) -> Tuple[torch.Tensor, torch.Tensor]:
 batch_size, sequence_length, hidden_dim = hidden_states.shape
 hidden_states = hidden_states.view(-1, hidden_dim)
 router_logits: torch.Tensor = self.gate(hidden_states)
- routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
 routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float32)
 routing_weights = routing_weights.to(hidden_states.dtype)
 
 final_hidden_states = torch.zeros(
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
 )
 
 for expert_idx in range(self.num_experts):
@@ -336,12 +390,16 @@ class Moondream3SparseMoeBlock(nn.Module):
 continue
 
 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
- # torch.save(current_state, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_e{expert_idx}")
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
- # torch.save(current_hidden_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_e{expert_idx}")
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
 
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 return final_hidden_states, router_logits
 
 
@@ -358,7 +416,10 @@ class Moondream3DecoderLayer(nn.Module):
 if self.is_moe_layer:
 self.mlp = Moondream3SparseMoeBlock(config, layer_idx=layer_idx)
 else:
- self.mlp = Moondream3MLP(self.hidden_size, self.intermediate_size)
 
 def forward(
 self,
@@ -373,16 +434,10 @@
 position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
 **kwargs,
 ) -> Tuple:
- residual = hidden_states
 
- # Apply layer norm like original
- l_in = self.input_layernorm(hidden_states)
- if DEBUG:
- torch.save(l_in, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_ln_out")
-
- # Attention
 hidden_states_attn, self_attn_weights = self.self_attn(
- hidden_states=l_in,
 attention_mask=attention_mask,
 position_ids=position_ids,
 past_key_values=past_key_values,
@@ -390,20 +445,19 @@
 use_cache=use_cache,
 cache_position=cache_position,
 position_embeddings=position_embeddings,
- **kwargs
 )
 
- # MLP
 if self.is_moe_layer:
- hidden_states_mlp, router_logits = self.mlp(l_in, cache_position=cache_position)
 else:
- hidden_states_mlp = self.mlp(l_in)
 router_logits = None
- if DEBUG:
- torch.save(hidden_states_mlp, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_mlp_out")
 
 # Add both attention and MLP to residual like original
- hidden_states = residual + hidden_states_attn + hidden_states_mlp
 
 outputs = (hidden_states,)
 
@@ -427,14 +481,15 @@ class Moondream3PreTrainedModel(PreTrainedModel):
 _supports_cache_class = True
 
 def _init_weights(self, module):
- # Use text_config initializer_range if available, otherwise use default
- if hasattr(self.config, 'text_config') and hasattr(self.config.text_config, 'initializer_range'):
 std = self.config.text_config.initializer_range
- elif hasattr(self.config, 'initializer_range'):
 std = self.config.initializer_range
 else:
- std = 0.02 # Default initialization range
-
 if isinstance(module, nn.Linear):
 module.weight.data.normal_(mean=0.0, std=std)
 if module.bias is not None:
@@ -453,15 +508,19 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
 self.padding_idx = config.pad_token_id if hasattr(config, "pad_token_id") else 0
 self.vocab_size = config.vocab_size
 
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 self.layers = nn.ModuleList(
- [Moondream3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 )
 self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
 self.rotary_emb = Moondream3RotaryEmbedding(config=config)
 self.gradient_checkpointing = False
 
-
 self.post_init()
 
 def forward(
@@ -478,19 +537,31 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 output_router_logits = (
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
 )
 output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
 use_cache = use_cache if use_cache is not None else self.config.use_cache
 
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one")
 
 if inputs_embeds is None:
 inputs_embeds = self.embed_tokens(input_ids)
509
  past_key_values = DynamicCache()
510
 
511
  if cache_position is None:
512
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
513
  cache_position = torch.arange(
514
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
515
  )
516
 
517
  if position_ids is None:
@@ -538,7 +613,7 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
538
  output_router_logits,
539
  use_cache,
540
  cache_position,
541
- position_embeddings
542
  )
543
  else:
544
  layer_outputs = decoder_layer(
@@ -550,7 +625,7 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
550
  output_router_logits=output_router_logits,
551
  use_cache=use_cache,
552
  cache_position=cache_position,
553
- position_embeddings=position_embeddings
554
  )
555
 
556
  hidden_states = layer_outputs[0]
@@ -571,7 +646,13 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
571
  if not return_dict:
572
  return tuple(
573
  v
574
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
 
 
 
 
 
 
575
  if v is not None
576
  )
577
 
@@ -594,8 +675,14 @@ class Moondream3VisionPatchEmbeddings(nn.Module):
594
  self.grid_size = self.crop_size // self.patch_size
595
  self.num_patches = self.grid_size * self.grid_size
596
 
597
- self.projection = nn.Linear(self.patch_size * self.patch_size * self.num_channels, self.hidden_size, bias=True)
598
- self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches, config.hidden_size))
 
 
 
 
 
 
599
 
600
  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
601
  B, C, H, W = pixel_values.shape
@@ -610,17 +697,23 @@ class Moondream3VisionPatchEmbeddings(nn.Module):
610
  x = self.projection(x)
611
  return x + self.position_embeddings
612
 
 
613
  class Moondream3VisionEncoderLayer(nn.Module):
614
  def __init__(self, config: Moondream3VisionConfig, layer_idx: int):
615
  super().__init__()
616
  self.hidden_size = config.hidden_size
617
  self.intermediate_size = config.intermediate_size
618
  self.layer_idx = layer_idx
619
-
620
- self.self_attn = Moondream3Attention(config, layer_idx=self.layer_idx, use_tau=False)
 
 
621
  self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-5)
622
  self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-5)
623
- self.mlp = Moondream3MLP(hidden_size=self.hidden_size, intermediate_size=self.intermediate_size)
 
 
 
624
 
625
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
626
  residual = hidden_states
@@ -635,6 +728,7 @@ class Moondream3VisionEncoderLayer(nn.Module):
635
 
636
  return hidden_states
637
 
 
638
  class Moondream3VisionModel(Moondream3PreTrainedModel):
639
  config_class = Moondream3VisionConfig
640
  main_input_name = "pixel_values"
@@ -649,9 +743,18 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
649
  self.proj_out_dim = self.config.proj_out_dim
650
 
651
  self.embeddings = Moondream3VisionPatchEmbeddings(config)
652
- self.layers = nn.ModuleList([Moondream3VisionEncoderLayer(config,layer_idx) for layer_idx in range(self.num_hidden_layers)])
 
 
 
 
 
653
  self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=1e-5)
654
- self.vision_projection = Moondream3MLP(hidden_size=self.hidden_size * 2, intermediate_size=self.proj_inner_dim, out_size=self.proj_out_dim)
 
 
 
 
655
  self.gradient_checkpointing = False
656
  self.post_init()
657
 
@@ -688,7 +791,6 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
688
  crop_height, crop_width = crops[0].shape[:2]
689
  margin_pixels = overlap_margin * patch_size
690
 
691
- # Calculate output size (only adding margins once)
692
  output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
693
  output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
694
  reconstructed = torch.zeros(
@@ -701,21 +803,16 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
701
  tile_y = i // tiling_w
702
  tile_x = i % tiling_w
703
 
704
- # For each tile, determine which part to keep
705
- # Keep left margin only for first column
706
  x_start = 0 if tile_x == 0 else margin_pixels
707
- # Keep right margin only for last column
708
  x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
709
- # Keep top margin only for first row
710
  y_start = 0 if tile_y == 0 else margin_pixels
711
- # Keep bottom margin only for last row
712
- y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
 
713
 
714
- # Calculate where this piece belongs in the output
715
  out_x = tile_x * (crop_width - 2 * margin_pixels)
716
  out_y = tile_y * (crop_height - 2 * margin_pixels)
717
 
718
- # Place the piece
719
  reconstructed[
720
  out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
721
  ] = crop[y_start:y_end, x_start:x_end]
@@ -725,16 +822,24 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
725
  def forward(
726
  self,
727
  pixel_values: torch.FloatTensor,
728
- tiling: Tuple[int,int],
729
  output_attentions: Optional[bool] = None,
730
  output_hidden_states: Optional[bool] = None,
731
  return_dict: Optional[bool] = None,
732
  ) -> Union[Tuple, BaseModelOutputWithPast]:
733
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
734
  output_hidden_states = (
735
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
736
  )
737
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
738
 
739
  batch_size, num_crops = pixel_values.shape[:2]
740
  # flatten batch_size and num_crops into same dim
@@ -749,17 +854,19 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
749
  all_hidden_states += (hidden_states,)
750
 
751
  if self.gradient_checkpointing and self.training:
752
- layer_outputs = self._gradient_checkpointing_func(encoder_layer.__call__, hidden_states)
 
 
753
  else:
754
  layer_outputs = encoder_layer(hidden_states)
755
 
756
  hidden_states = layer_outputs
757
 
758
  hidden_states = self.post_layernorm(hidden_states)
759
- # B, _, _
760
 
761
- # back out into batch_size, num_crops
762
- hidden_states = hidden_states.view(batch_size, num_crops, *hidden_states.shape[1:])
 
763
  outputs = []
764
  for b in range(batch_size):
765
  hs = hidden_states[b]
@@ -782,9 +889,12 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
782
 
783
  reconstructed = reconstructed.permute(2, 0, 1)
784
  reconstructed = F.adaptive_avg_pool2d(
785
- reconstructed, output_size=(self.num_hidden_layers, self.num_hidden_layers)
 
 
 
 
786
  )
787
- reconstructed = reconstructed.permute(1, 2, 0).view(729, self.hidden_size)
788
  final_features = torch.cat([global_features, reconstructed], dim=-1)
789
  outputs.append(final_features)
790
  output = torch.stack(outputs, 0)
@@ -795,7 +905,11 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
795
  all_hidden_states += (hidden_states,)
796
 
797
  if not return_dict:
798
- return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
 
 
 
 
799
 
800
  return BaseModelOutputWithPast(
801
  last_hidden_state=hidden_states,
@@ -803,12 +917,13 @@ class Moondream3VisionModel(Moondream3PreTrainedModel):
803
  attentions=all_attentions,
804
  )
805
 
 
806
  class Moondream3RegionEncoder(nn.Module):
807
  def __init__(self, config: Moondream3RegionConfig):
808
  super().__init__()
809
  self.coord_encoder = nn.Linear(config.coord_feat_dim, config.hidden_size)
810
  self.size_encoder = nn.Linear(config.size_feat_dim, config.hidden_size)
811
-
812
  coord_freq = torch.randn(config.coord_feat_dim // 2, 1) * 10.0
813
  size_freq = torch.randn(config.size_feat_dim // 2, 2) * 10.0
814
  self.register_buffer("coord_freq", coord_freq.T)
@@ -826,6 +941,7 @@ class Moondream3RegionEncoder(nn.Module):
826
  fourier_features = self.fourier_features(size, self.size_freq)
827
  return self.size_encoder(fourier_features)
828
 
 
829
  class Moondream3RegionDecoder(nn.Module):
830
  def __init__(self, config: Moondream3RegionConfig):
831
  super().__init__()
@@ -836,7 +952,8 @@ class Moondream3RegionDecoder(nn.Module):
836
  return self.coord_decoder(hidden_state)
837
 
838
  def decode_size(self, hidden_state: torch.Tensor) -> torch.Tensor:
839
- return self.size_decoder(hidden_state).view(hidden_state.shape[0],2,-1)
 
840
 
841
  class Moondream3Model(Moondream3PreTrainedModel):
842
  def __init__(self, config: Moondream3Config):
@@ -844,7 +961,7 @@ class Moondream3Model(Moondream3PreTrainedModel):
844
  self.text_model = Moondream3TextModel(config.text_config)
845
  self.vision_model = Moondream3VisionModel(config.vision_config)
846
  self.vocab_size = config.text_config.vocab_size
847
-
848
  self.region_encoder = Moondream3RegionEncoder(config.region_config)
849
  self.region_decoder = Moondream3RegionDecoder(config.region_config)
850
  self.post_init()
@@ -865,7 +982,7 @@ class Moondream3Model(Moondream3PreTrainedModel):
865
  self,
866
  input_ids: torch.LongTensor = None,
867
  pixel_values: torch.FloatTensor = None,
868
- tiling: Tuple[int,int] = None,
869
  attention_mask: Optional[torch.Tensor] = None,
870
  position_ids: Optional[torch.LongTensor] = None,
871
  past_key_values: Optional[Cache] = None,
@@ -878,11 +995,19 @@ class Moondream3Model(Moondream3PreTrainedModel):
878
  cache_position: Optional[torch.LongTensor] = None,
879
  logits_to_keep: int = 0,
880
  ) -> Union[Tuple, BaseModelOutputWithPast]:
881
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
882
  output_hidden_states = (
883
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
884
  )
885
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
886
 
887
  if (input_ids is not None) == (inputs_embeds is not None):
888
  raise ValueError("Provide exactly one of input_ids or inputs_embeds.")
@@ -890,8 +1015,9 @@ class Moondream3Model(Moondream3PreTrainedModel):
890
  if not ((pixel_values is not None) ^ (tiling is None)):
891
  raise ValueError("You must specify both pixel_values and tiling")
892
 
893
- # Case A: inputs_embeds provided -> assume it already contains BOS+image+text in correct order.
894
- if inputs_embeds is not None and (pixel_values is not None or tiling is not None):
 
895
  raise ValueError(
896
  "When inputs_embeds is provided, do not pass pixel_values/tiling; "
897
  "inputs_embeds must already include BOS+image(+text)."
@@ -904,7 +1030,9 @@ class Moondream3Model(Moondream3PreTrainedModel):
904
  past_key_values = DynamicCache(config=self.config)
905
 
906
  if cache_position is None:
907
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
908
  cache_position: torch.Tensor = torch.arange(
909
  past_seen_tokens, past_seen_tokens, device=inputs_embeds.device
910
  )
@@ -912,20 +1040,24 @@ class Moondream3Model(Moondream3PreTrainedModel):
912
  if position_ids is None:
913
  position_ids = cache_position.unsqueeze(0)
914
 
915
- def image_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int):
916
- # set all up to `self.config.vision_config.prefix_len` to true
917
- return kv_idx <= q_idx
918
-
919
  if pixel_values is not None:
920
- # Vision embeds
921
- pixel_values = pixel_values.to(dtype=self.vision_model.embeddings.projection.weight.dtype)
922
- image_embeds = self.vision_model(pixel_values, tiling=tiling)["last_hidden_state"] # [B,P,D]
 
 
 
923
  prefix = self.text_model.embed_tokens(
924
- torch.full((input_ids.shape[0], 1), self.config.text_config.bos_token_id, dtype=input_ids.dtype, device=input_ids.device)
 
 
 
 
 
925
  )
926
  embeds = torch.cat([prefix, image_embeds], dim=1)
927
  cache_pos = torch.arange(embeds.shape[-2], device=embeds.device)
928
- pos = cache_pos.unsqueeze(0).expand(embeds.shape[0],-1)
929
  attn_mask = torch.full(
930
  (embeds.shape[0], 1, embeds.shape[-2], pos.shape[-1]),
931
  True,
@@ -949,11 +1081,21 @@ class Moondream3Model(Moondream3PreTrainedModel):
949
  attn_mask = create_causal_mask(
950
  config=self.config,
951
  input_embeds=inputs_embeds,
952
- attention_mask=torch.cat([torch.ones(attention_mask.shape[0], cache_position[-1] + 1 - attention_mask.shape[-1], device=attention_mask.device, dtype=attention_mask.dtype), attention_mask], dim=-1),
 
 
 
 
 
 
 
 
 
 
 
953
  cache_position=cache_position,
954
  past_key_values=past_key_values,
955
  position_ids=position_ids,
956
- and_mask_function=image_mask_function
957
  )
958
 
959
  outputs = self.text_model(
@@ -970,12 +1112,16 @@ class Moondream3Model(Moondream3PreTrainedModel):
970
  )
971
 
972
  if not return_dict:
973
- return tuple(v for v in [
974
- outputs.last_hidden_state,
975
- getattr(outputs, "past_key_values", None),
976
- getattr(outputs, "hidden_states", None),
977
- getattr(outputs, "attentions", None),
978
- ] if v is not None)
 
 
 
 
979
 
980
  return BaseModelOutputWithPast(
981
  last_hidden_state=outputs.last_hidden_state,
@@ -984,9 +1130,10 @@ class Moondream3Model(Moondream3PreTrainedModel):
984
  attentions=getattr(outputs, "attentions", None),
985
  )
986
 
 
987
  @dataclass
988
  class Moondream3GenerateOutput(GenerateDecoderOnlyOutput):
989
- objects: Optional[list[dict[str,float]]] = None
990
 
991
 
992
  class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
@@ -997,7 +1144,9 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
997
  self.objects = None
998
  self.model = Moondream3Model(config)
999
  self.vocab_size = config.text_config.vocab_size
1000
- self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
 
 
1001
  self.post_init()
1002
 
1003
  def get_input_embeddings(self):
@@ -1023,8 +1172,10 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1023
  generation_config,
1024
  **kwargs,
1025
  ):
1026
- generation_config = super()._prepare_generated_length(generation_config, **kwargs)
1027
- generation_config.max_length += 730
 
 
1028
  return generation_config
1029
 
1030
  def forward(
@@ -1048,7 +1199,7 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1048
  if pixel_values is not None and inputs_embeds is None:
1049
  position_ids += self.config.vision_config.prefix_len
1050
  cache_position += self.config.vision_config.prefix_len
1051
- # Get hidden states from the base model (it already builds the multimodal prefix)
1052
  model_outputs = self.model(
1053
  input_ids=input_ids,
1054
  pixel_values=pixel_values,
@@ -1065,9 +1216,7 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1065
  cache_position=cache_position,
1066
  logits_to_keep=logits_to_keep,
1067
  )
1068
- hidden_states = model_outputs.last_hidden_state # [B, T, D]
1069
-
1070
- # Compute logits; only keep the tail if requested
1071
  if isinstance(logits_to_keep, int) and logits_to_keep > 0:
1072
  hs = hidden_states[:, -logits_to_keep:, :]
1073
  elif isinstance(logits_to_keep, slice):
@@ -1076,21 +1225,24 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1076
  hs = hidden_states
1077
 
1078
  hs = self.model.text_model.norm(hs)
1079
- logits = self.lm_head(hs) # [B, T', V]
1080
 
1081
  pred = torch.argmax(logits, dim=-1)
1082
 
1083
- pos_ids = position_ids[:,-1:] + 1
1084
  cache_pos = cache_position[-1:] + 1
1085
  mask = torch.ones(
1086
  hidden_states.shape[0], 1, device=self.device, dtype=torch.long
1087
  )
1088
- while torch.any(pred == 5):
1089
- batch_mask = (pred[:, -1] == 5)
 
1090
  hidden_states = hidden_states[:, -1:, :]
1091
  x_logits = self.model.region_decoder.decode_coordinate(hidden_states)
1092
  x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
1093
- next_embeds = self.model.region_encoder.encode_coordinate(x_center.to(x_logits.dtype)).unsqueeze(1)
 
 
1094
  model_outputs = self.model(
1095
  input_ids=None,
1096
  pixel_values=None,
@@ -1107,16 +1259,18 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1107
  cache_position=cache_pos,
1108
  logits_to_keep=logits_to_keep,
1109
  )
1110
- hidden_states = model_outputs.last_hidden_state # [B, T, D]
1111
  y_logits = self.model.region_decoder.decode_coordinate(hidden_states)
1112
  y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
1113
- next_embeds = self.model.region_encoder.encode_coordinate(y_center.to(y_logits.dtype)).unsqueeze(1)
 
 
1114
  coords = torch.cat([x_center, y_center], dim=1)
1115
  coords = coords * (batch_mask).unsqueeze(1)
1116
  pos_ids += 1
1117
  cache_pos = cache_pos + 1
1118
  bbox = None
1119
- if input_ids[0,1] == 7235:
1120
  model_outputs = self.model(
1121
  input_ids=None,
1122
  pixel_values=None,
@@ -1133,18 +1287,18 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1133
  cache_position=cache_pos,
1134
  logits_to_keep=logits_to_keep,
1135
  )
1136
- hidden_states = model_outputs.last_hidden_state # [B, T, D]
1137
  size_logits = self.model.region_decoder.decode_size(hidden_states)
1138
  bins = torch.argmax(size_logits, dim=-1)
1139
- w_bin = bins[:,0]
1140
- h_bin = bins[:,1]
1141
 
1142
  w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
1143
  h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
1144
 
1145
  next_embeds = (
1146
  self.model.region_encoder.encode_size(
1147
- torch.stack([w, h],dim=-1).to(size_logits.dtype)
1148
  )
1149
  ).unsqueeze(1)
1150
  x_center = x_center.squeeze(1)
@@ -1155,7 +1309,7 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1155
  x_center + w / 2,
1156
  y_center + h / 2,
1157
  ]
1158
- bbox = torch.stack(bbox, dim=1)
1159
  bbox = bbox * (batch_mask).unsqueeze(1)
1160
  pos_ids += 1
1161
  cache_pos = cache_pos + 1
@@ -1183,26 +1337,38 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1183
  )
1184
  pos_ids += 1
1185
  cache_pos = cache_pos + 1
1186
- hidden_states = model_outputs.last_hidden_state # [B, T, D]
1187
 
1188
  indices = torch.tensor(
1189
- [self.config.text_config.coord_token_id, self.config.text_config.eos_token_id],
 
 
 
1190
  device=self.device,
1191
  )
1192
 
1193
  hidden_states = self.model.text_model.norm(hidden_states)
1194
- logits = hidden_states @ self.lm_head.weight[indices].T + self.lm_head.bias[indices]
 
 
 
1195
 
1196
- logits_full = torch.full((logits.shape[0], logits.shape[1], self.config.text_config.vocab_size), float('-inf'), device=logits.device, dtype=logits.dtype)
1197
- logits_full[:, :, torch.tensor([5,0])] = logits
 
 
 
 
 
1198
  logits = logits_full
1199
  pred[batch_mask] = torch.argmax(logits, dim=-1)[batch_mask]
1200
-
1201
 
1202
  loss = None
1203
  if labels is not None:
1204
- # Shift if your training uses standard LM convention; here we assume labels aligned with hs
1205
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)
 
1206
 
1207
  return CausalLMOutputWithPast(
1208
  loss=loss,
@@ -1214,25 +1380,20 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1214
 
1215
  def generate(self, **kwargs) -> Union[Moondream3GenerateOutput, torch.LongTensor]:
1216
  outputs = super().generate(**kwargs)
1217
- if len(self.objects) > 0:
1218
  if isinstance(outputs, torch.Tensor):
1219
  outputs = self.objects
1220
- self.objects = []
1221
  else:
1222
- outputs = Moondream3GenerateOutput(
1223
- **outputs,
1224
- objects=self.objects
1225
- )
1226
- self.objects = []
1227
  return outputs
1228
 
1229
- def prepare_inputs_for_generation(
1230
- self,
1231
- input_ids,
1232
- **model_kwargs
1233
- ):
1234
  model_inputs = super().prepare_inputs_for_generation(input_ids, **model_kwargs)
1235
- model_inputs["position_ids"] += model_inputs["cache_position"].unsqueeze(0) - model_inputs["position_ids"]
 
 
1236
  return model_inputs
1237
 
1238
  def _update_model_kwargs_for_generation(
@@ -1252,13 +1413,15 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1252
  model_kwargs["tiling"] = None
1253
  return model_kwargs
1254
 
1255
-
1256
  @staticmethod
1257
  def _reorder_cache(past_key_values, beam_idx):
1258
  reordered_past = ()
1259
  for layer_past in past_key_values:
1260
  reordered_past += (
1261
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
 
 
 
1262
  )
1263
  return reordered_past
1264
 
 
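One functional change in this commit, visible in the Moondream3MLP hunks above, is that the gated forward path no longer fuses the up and gate weights into a single matmul but calls the two projections separately. A minimal standalone sketch of that equivalence follows (illustrative only, not part of the committed file; the tensor sizes, the bias=False setting, and the plain GELU activation are assumptions made for the demo):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter = 8, 16                      # made-up sizes for illustration
x = torch.randn(2, 4, hidden)

up = torch.nn.Linear(hidden, inter, bias=False)
gate = torch.nn.Linear(hidden, inter, bias=False)

# Old path (removed): one fused matmul over the concatenated weights, then split.
combined = torch.cat([up.weight, gate.weight], dim=0)
h_old, g_old = F.linear(x, combined).chunk(2, dim=-1)

# New path (added): two separate projections, as in the updated forward.
h_new, g_new = up(x), gate(x)

assert torch.allclose(h_old, h_new, atol=1e-6)
assert torch.allclose(g_old, g_new, atol=1e-6)

out = F.gelu(h_new) * (g_new + 1)          # gated activation pattern from the diff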
 from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.masking_utils import create_causal_mask
 from transformers.generation import GenerationMixin
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 from transformers.utils import logging, TransformersKwargs
+ from .configuration_moondream3 import (
+ Moondream3Config,
+ Moondream3TextConfig,
+ Moondream3VisionConfig,
+ Moondream3RegionConfig,
+ )
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Moondream3Config"
 
 
 def apply_rotary_pos_emb(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
 rot_dim: int = 32,
 ):
 """
 Apply rotary position embeddings to query and key tensors.
+
 Args:
 q: Query tensor [batch, num_heads, seq_len, head_dim]
 k: Key tensor [batch, num_heads, seq_len, head_dim]
 cos: Cosine frequencies [batch, seq_len, rot_dim]
 sin: Sine frequencies [batch, seq_len, rot_dim]
 rot_dim: Number of dimensions to apply rotation to (default: 32)
+
 Returns:
 Tuple of (rotated_q, rotated_k)
 """
 xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
 
 xq_out_r = xq_r * cos - xq_i * sin
+ xq_out_i = xq_r * sin + xq_i * cos
 
 xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
 
 return torch.cat([xq_out, x_pass], dim=-1)
+
 return apply_rope(q), apply_rope(k)
 
+
 class Moondream3RotaryEmbedding(nn.Module):
 inv_freq: torch.Tensor
 
 """
 Computes the inverse frequencies according to the original RoPE implementation
 """
+ base = config.rope_parameters["rope_theta"]
+ dim = (
+ getattr(config, "head_dim", None)
+ or config.hidden_size // config.num_attention_heads
+ )
 dim //= 2
+
 attention_factor = 1.0
 
 inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)
 )
 if device is not None:
 inv_freq = inv_freq.to(device=device)
 @torch.no_grad()
 @dynamic_rope_update
 def forward(self, x, position_ids):
+ inv_freq_expanded = (
+ self.inv_freq[None, :, None]
+ .to(torch.float32)
+ .expand(position_ids.shape[0], -1, 1)
+ .to(x.device)
+ )
+ position_ids_expanded = position_ids[:, None, :].to(torch.float32)
+
+ freqs = (
+ inv_freq_expanded.to(torch.float32)
+ @ position_ids_expanded.to(torch.float32)
+ ).transpose(1, 2)
+ cfreqs = (
+ torch.exp(1j * freqs)
+ .unsqueeze(1)
+ .expand(-1, self.config.num_attention_heads, -1, -1)
+ )
 
 return cfreqs.real, cfreqs.imag
 
 
 class Moondream3Attention(nn.Module):
+ def __init__(
+ self,
+ config: Moondream3TextConfig | Moondream3VisionConfig,
+ layer_idx: Optional[int] = None,
+ use_tau: bool = True,
+ ):
 super().__init__()
 self.config = config
 self.layer_idx = layer_idx
 self.hidden_size = config.hidden_size
 self.num_heads = config.num_attention_heads
 self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = getattr(
+ config, "num_key_value_heads", self.num_heads
+ )
 attention_bias = config.attention_bias
 self.attention_dropout = config.attention_dropout
 
 if isinstance(config, Moondream3TextConfig):
 self.is_causal = True
+ elif isinstance(config, Moondream3VisionConfig):
 self.is_causal = False
 else:
 raise TypeError(f"Unsupported config type: {type(config)}")
+
 self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 self.use_tau = use_tau
 
 f" and `num_heads`: {self.num_heads})."
 )
 
+ self.q_proj = nn.Linear(
+ self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=attention_bias,
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=attention_bias,
+ )
+ self.o_proj = nn.Linear(
+ self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias
+ )
+
 if self.use_tau:
 # In original, tau weights are (n_heads, qkv_dim) where qkv_dim is the combined QKV dimension
+ qkv_dim = (
+ self.num_heads * self.head_dim
+ + 2 * self.num_key_value_heads * self.head_dim
+ )
 self.tau_wq = nn.Linear(qkv_dim, self.num_heads, bias=False)
 self.tau_wv = nn.Linear(qkv_dim, self.num_heads, bias=False)
 self.tau_alpha = nn.Parameter(torch.empty(self.num_heads))
 **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 input_shape = hidden_states.shape[:-1]
+
 bsz, q_len, _ = hidden_states.size()
 
 query_states = self.q_proj(hidden_states)
 key_states = self.k_proj(hidden_states)
 value_states = self.v_proj(hidden_states)
 if self.use_tau:
 qkv_out = torch.cat([query_states, key_states, value_states], dim=-1)
 tok_feat = F.gelu(qkv_out)
+ tok_q = torch.tanh(self.tau_wq(tok_feat)).permute(0, 2, 1)
+ tok_v = torch.tanh(self.tau_wv(tok_feat)).permute(0, 2, 1)
 
 pos = position_ids.to(tok_q.dtype) + 1
 alpha = self.tau_alpha.to(tok_q.dtype)
+ tau_pos = 1 + (
+ torch.sigmoid(alpha[None, :, None] * pos[:, None, :].log()) - 0.5
+ )
+ tau_q = (tok_q + tau_pos).unsqueeze(-1)
+ tau_v = (tok_v + tau_pos).unsqueeze(-1)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_key_value_heads, self.head_dim
+ ).transpose(1, 2)
 
 if self.use_tau:
 query_states = query_states * tau_q
 
 if self.num_key_value_groups > 1:
+ tau_v_repeated = tau_v.repeat(1, self.num_key_value_groups, 1, 1)[
+ :, : self.num_key_value_heads, :, :
+ ]
 else:
 tau_v_repeated = tau_v
 value_states = value_states * tau_v_repeated
 
 cos, sin = None, None
 if position_embeddings is not None:
 cos, sin = position_embeddings
 
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin
+ )
+
+ query_states, key_states = (
+ query_states.to(value_states.dtype),
+ key_states.to(value_states.dtype),
+ )
 
 if past_key_values is not None:
 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
 
 query_states = query_states.contiguous()
 key_states = key_states.contiguous()
 attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 attn_output = self.o_proj(attn_output)
 
 return attn_output, attn_weights
 
+
 class Moondream3MLP(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str = "gelu_pytorch_tanh",
+ out_size: int | None = None,
+ gated: bool = False,
+ bias: bool = True,
+ ):
 super().__init__()
 self.hidden_size = hidden_size
 self.intermediate_size = intermediate_size
 self.down_proj = nn.Linear(self.intermediate_size, self.out_size, bias=bias)
 self.gate_proj = None
 if self.gated:
+ self.gate_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=bias
+ )
 self.act_fn = ACT2FN[self.hidden_act]
 
 def forward(self, x) -> torch.Tensor:
 if self.gated:
+ h = self.up_proj(x)
+ g = self.gate_proj(x)
 x = self.act_fn(h) * (g + 1)
 else:
 x = self.act_fn(self.up_proj(x))
 
 
 class Moondream3SparseMoeBlock(nn.Module):
+ def __init__(self, config: Moondream3TextConfig, layer_idx=None):
 super().__init__()
 self.layer_idx = layer_idx
 self.hidden_size = config.hidden_size
 self.top_k = config.num_experts_per_tok
 
 self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=True)
+ self.experts = nn.ModuleList(
+ [
+ Moondream3MLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=self.moe_intermediate_size,
+ gated=True,
+ bias=False,
+ hidden_act="gelu"
+ )
+ for _ in range(self.num_experts)
+ ]
+ )
 
+ def forward(
+ self, hidden_states: torch.Tensor, cache_position=None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
 batch_size, sequence_length, hidden_dim = hidden_states.shape
 hidden_states = hidden_states.view(-1, hidden_dim)
 router_logits: torch.Tensor = self.gate(hidden_states)
+ routing_weights, selected_experts = torch.topk(
+ router_logits, self.top_k, dim=-1
+ )
 routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float32)
 routing_weights = routing_weights.to(hidden_states.dtype)
 
 final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
 )
 
 for expert_idx in range(self.num_experts):
 continue
 
 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = (
+ expert_layer(current_state) * routing_weights[top_x, idx, None]
+ )
+ final_hidden_states.index_add_(
+ 0, top_x, current_hidden_states.to(hidden_states.dtype)
+ )
 
+ final_hidden_states = final_hidden_states.reshape(
+ batch_size, sequence_length, hidden_dim
+ )
 return final_hidden_states, router_logits
 
 
 if self.is_moe_layer:
 self.mlp = Moondream3SparseMoeBlock(config, layer_idx=layer_idx)
 else:
+ self.mlp = Moondream3MLP(
+ self.hidden_size,
+ self.intermediate_size,
+ )
 
 def forward(
 self,
 position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
 **kwargs,
 ) -> Tuple:
+ hidden_states_ln = self.input_layernorm(hidden_states)
 
 hidden_states_attn, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states_ln,
 attention_mask=attention_mask,
 position_ids=position_ids,
 past_key_values=past_key_values,
 use_cache=use_cache,
 cache_position=cache_position,
 position_embeddings=position_embeddings,
+ **kwargs,
 )
 
 if self.is_moe_layer:
+ hidden_states_mlp, router_logits = self.mlp(
+ hidden_states_ln, cache_position=cache_position
+ )
 else:
+ hidden_states_mlp = self.mlp(hidden_states_ln)
 router_logits = None
 
 # Add both attention and MLP to residual like original
+ hidden_states = hidden_states + hidden_states_attn + hidden_states_mlp
 
 outputs = (hidden_states,)
 
 _supports_cache_class = True
 
 def _init_weights(self, module):
+ if hasattr(self.config, "text_config") and hasattr(
+ self.config.text_config, "initializer_range"
+ ):
 std = self.config.text_config.initializer_range
+ elif hasattr(self.config, "initializer_range"):
 std = self.config.initializer_range
 else:
+ std = 0.02
+
 if isinstance(module, nn.Linear):
 module.weight.data.normal_(mean=0.0, std=std)
 if module.bias is not None:
 self.padding_idx = config.pad_token_id if hasattr(config, "pad_token_id") else 0
 self.vocab_size = config.vocab_size
 
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.hidden_size, self.padding_idx
+ )
 self.layers = nn.ModuleList(
+ [
+ Moondream3DecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
 )
 self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
 self.rotary_emb = Moondream3RotaryEmbedding(config=config)
 self.gradient_checkpointing = False
 
 self.post_init()
 
 def forward(
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
 output_router_logits = (
+ output_router_logits
+ if output_router_logits is not None
+ else self.config.output_router_logits
 )
 output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
 )
 use_cache = use_cache if use_cache is not None else self.config.use_cache
 
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
 
 if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
 
 if inputs_embeds is None:
 inputs_embeds = self.embed_tokens(input_ids)
 past_key_values = DynamicCache()
 
 if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
 cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
 )
 
 if position_ids is None:
 output_router_logits,
 use_cache,
 cache_position,
+ position_embeddings,
 )
 else:
 layer_outputs = decoder_layer(
 output_router_logits=output_router_logits,
 use_cache=use_cache,
 cache_position=cache_position,
+ position_embeddings=position_embeddings,
 )
 
 hidden_states = layer_outputs[0]
 if not return_dict:
 return tuple(
 v
+ for v in [
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ all_router_logits,
+ ]
 if v is not None
 )
 
 self.grid_size = self.crop_size // self.patch_size
 self.num_patches = self.grid_size * self.grid_size
 
+ self.projection = nn.Linear(
+ self.patch_size * self.patch_size * self.num_channels,
+ self.hidden_size,
+ bias=True,
+ )
+ self.position_embeddings = nn.Parameter(
+ torch.zeros(1, self.num_patches, config.hidden_size)
+ )
 
 def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 B, C, H, W = pixel_values.shape
 x = self.projection(x)
 return x + self.position_embeddings
 
+
 class Moondream3VisionEncoderLayer(nn.Module):
 def __init__(self, config: Moondream3VisionConfig, layer_idx: int):
 super().__init__()
 self.hidden_size = config.hidden_size
 self.intermediate_size = config.intermediate_size
 self.layer_idx = layer_idx
+
+ self.self_attn = Moondream3Attention(
+ config, layer_idx=self.layer_idx, use_tau=False
+ )
 self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-5)
 self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-5)
+ self.mlp = Moondream3MLP(
+ hidden_size=self.hidden_size,
+ intermediate_size=self.intermediate_size,
+ )
 
 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 residual = hidden_states
 
 return hidden_states
 
+
  class Moondream3VisionModel(Moondream3PreTrainedModel):
733
  config_class = Moondream3VisionConfig
734
  main_input_name = "pixel_values"
 
743
  self.proj_out_dim = self.config.proj_out_dim
744
 
745
  self.embeddings = Moondream3VisionPatchEmbeddings(config)
746
+ self.layers = nn.ModuleList(
747
+ [
748
+ Moondream3VisionEncoderLayer(config, layer_idx)
749
+ for layer_idx in range(self.num_hidden_layers)
750
+ ]
751
+ )
752
  self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=1e-5)
753
+ self.vision_projection = Moondream3MLP(
754
+ hidden_size=self.hidden_size * 2,
755
+ intermediate_size=self.proj_inner_dim,
756
+ out_size=self.proj_out_dim,
757
+ )
758
  self.gradient_checkpointing = False
759
  self.post_init()
760
 
 
791
  crop_height, crop_width = crops[0].shape[:2]
792
  margin_pixels = overlap_margin * patch_size
793
 
 
794
  output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
795
  output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
796
  reconstructed = torch.zeros(
 
803
  tile_y = i // tiling_w
804
  tile_x = i % tiling_w
805
 
 
 
806
  x_start = 0 if tile_x == 0 else margin_pixels
 
807
  x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
 
808
  y_start = 0 if tile_y == 0 else margin_pixels
809
+ y_end = (
810
+ crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
811
+ )
812
 
 
813
  out_x = tile_x * (crop_width - 2 * margin_pixels)
814
  out_y = tile_y * (crop_height - 2 * margin_pixels)
815
 
 
816
  reconstructed[
817
  out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
818
  ] = crop[y_start:y_end, x_start:x_end]
 
    def forward(
        self,
        pixel_values: torch.FloatTensor,
+        tiling: Tuple[int, int],
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, num_crops = pixel_values.shape[:2]
        # flatten batch_size and num_crops into same dim

                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__, hidden_states
+                )
            else:
                layer_outputs = encoder_layer(hidden_states)

            hidden_states = layer_outputs

        hidden_states = self.post_layernorm(hidden_states)

+        hidden_states = hidden_states.view(
+            batch_size, num_crops, *hidden_states.shape[1:]
+        )
        outputs = []
        for b in range(batch_size):
            hs = hidden_states[b]

            reconstructed = reconstructed.permute(2, 0, 1)
            reconstructed = F.adaptive_avg_pool2d(
+                reconstructed,
+                output_size=(self.num_hidden_layers, self.num_hidden_layers),
+            )
+            reconstructed = reconstructed.permute(1, 2, 0).view(
+                self.num_hidden_layers * self.num_hidden_layers, self.hidden_size
            )

            final_features = torch.cat([global_features, reconstructed], dim=-1)
            outputs.append(final_features)
        output = torch.stack(outputs, 0)

            all_hidden_states += (hidden_states,)

        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_attentions]
+                if v is not None
+            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,

            attentions=all_attentions,
        )
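Sketch of the fusion step above: the stitched high-resolution map is pooled back to the global grid and concatenated channel-wise, which is why `vision_projection` takes `hidden_size * 2` inputs. All sizes below are placeholders.

import torch
import torch.nn.functional as F

hidden_size, grid, out_grid = 32, 27, 27  # placeholder sizes

# Global crop features and a stitched high-resolution map (H, W, C), as above.
global_features = torch.randn(grid * grid, hidden_size)
stitched = torch.randn(54, 54, hidden_size)

# Pool the stitched map back down to the same grid as the global features ...
pooled = F.adaptive_avg_pool2d(stitched.permute(2, 0, 1), output_size=(out_grid, out_grid))
pooled = pooled.permute(1, 2, 0).reshape(out_grid * out_grid, hidden_size)

# ... and concatenate per position, doubling the channel dim for the projection MLP.
fused = torch.cat([global_features, pooled], dim=-1)
print(fused.shape)  # torch.Size([729, 64])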

+
class Moondream3RegionEncoder(nn.Module):
    def __init__(self, config: Moondream3RegionConfig):
        super().__init__()
        self.coord_encoder = nn.Linear(config.coord_feat_dim, config.hidden_size)
        self.size_encoder = nn.Linear(config.size_feat_dim, config.hidden_size)
+
        coord_freq = torch.randn(config.coord_feat_dim // 2, 1) * 10.0
        size_freq = torch.randn(config.size_feat_dim // 2, 2) * 10.0
        self.register_buffer("coord_freq", coord_freq.T)

        fourier_features = self.fourier_features(size, self.size_freq)
        return self.size_encoder(fourier_features)

+
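The `fourier_features` helper is collapsed above; a common random-Fourier-features formulation consistent with the registered `coord_freq` buffer is sketched below. The 2π factor and the cos/sin ordering are assumptions, not copied from the file.

import math
import torch
import torch.nn as nn

def fourier_features(x: torch.Tensor, freq: torch.Tensor) -> torch.Tensor:
    """Project x (B, D_in) through random frequencies (D_in, F) and return
    [cos, sin] features of size 2 * F."""
    angles = 2 * math.pi * x @ freq
    return torch.cat([angles.cos(), angles.sin()], dim=-1)

coord_feat_dim, hidden_size = 256, 64  # placeholder sizes
coord_freq = (torch.randn(coord_feat_dim // 2, 1) * 10.0).T  # (1, 128), as registered above
coord_encoder = nn.Linear(coord_feat_dim, hidden_size)

coords = torch.rand(4, 1)  # normalized x (or y) positions in [0, 1]
print(coord_encoder(fourier_features(coords, coord_freq)).shape)  # torch.Size([4, 64])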
class Moondream3RegionDecoder(nn.Module):
    def __init__(self, config: Moondream3RegionConfig):
        super().__init__()

        return self.coord_decoder(hidden_state)

    def decode_size(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        return self.size_decoder(hidden_state).view(hidden_state.shape[0], 2, -1)
+
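`decode_size` yields one row of bin logits per axis; the bbox branch in `Moondream3ForConditionalGeneration.forward` further down maps the argmax bin onto a log2 scale. A quick numeric check of that mapping:

import torch

# Bin -> normalized size used by the bbox branch below:
# size = 2 ** ((bin / 1023) * 10 - 10), i.e. a log2 scale from 2**-10 up to 1.0.
bins = torch.tensor([0, 511, 1023])
sizes = torch.pow(2.0, (bins.float() / 1023.0) * 10.0 - 10.0)
print(sizes)  # approximately 0.0010, 0.0312, 1.0000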

class Moondream3Model(Moondream3PreTrainedModel):
    def __init__(self, config: Moondream3Config):

        self.text_model = Moondream3TextModel(config.text_config)
        self.vision_model = Moondream3VisionModel(config.vision_config)
        self.vocab_size = config.text_config.vocab_size
+
        self.region_encoder = Moondream3RegionEncoder(config.region_config)
        self.region_decoder = Moondream3RegionDecoder(config.region_config)
        self.post_init()

        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
+        tiling: Tuple[int, int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,

        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: int = 0,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError("Provide exactly one of input_ids or inputs_embeds.")

        if not ((pixel_values is not None) ^ (tiling is None)):
            raise ValueError("You must specify both pixel_values and tiling")

+        if inputs_embeds is not None and (
+            pixel_values is not None or tiling is not None
+        ):
            raise ValueError(
                "When inputs_embeds is provided, do not pass pixel_values/tiling; "
                "inputs_embeds must already include BOS+image(+text)."

            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
 
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if pixel_values is not None:
+            pixel_values = pixel_values.to(
+                dtype=self.vision_model.embeddings.projection.weight.dtype
+            )
+            image_embeds = self.vision_model(pixel_values, tiling=tiling)[
+                "last_hidden_state"
+            ]
            prefix = self.text_model.embed_tokens(
+                torch.full(
+                    (input_ids.shape[0], 1),
+                    0,
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                )
            )
            embeds = torch.cat([prefix, image_embeds], dim=1)
            cache_pos = torch.arange(embeds.shape[-2], device=embeds.device)
+            pos = cache_pos.unsqueeze(0).expand(embeds.shape[0], -1)
            attn_mask = torch.full(
                (embeds.shape[0], 1, embeds.shape[-2], pos.shape[-1]),
                True,

        attn_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
+            attention_mask=torch.cat(
+                [
+                    torch.ones(
+                        attention_mask.shape[0],
+                        cache_position[-1] + 1 - attention_mask.shape[-1],
+                        device=attention_mask.device,
+                        dtype=attention_mask.dtype,
+                    ),
+                    attention_mask,
+                ],
+                dim=-1,
+            ),
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
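What the `torch.cat` above does to the user-supplied mask: the text-side `attention_mask` is left-padded with ones so it also covers the BOS+image prefix already sitting in the KV cache, making its length match `cache_position[-1] + 1`. The sizes below are made up for illustration.

import torch

prefix_len, text_len = 5, 3
attention_mask = torch.tensor([[1, 1, 1]])                         # (B, text_len)
cache_position = torch.arange(prefix_len, prefix_len + text_len)   # tensor([5, 6, 7])

# Pad on the left so the mask also covers the cached prefix tokens.
pad_len = int(cache_position[-1]) + 1 - attention_mask.shape[-1]
pad = torch.ones(attention_mask.shape[0], pad_len, dtype=attention_mask.dtype)
full_mask = torch.cat([pad, attention_mask], dim=-1)
print(full_mask)  # tensor([[1, 1, 1, 1, 1, 1, 1, 1]]) -> length 8 = prefix + text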

        outputs = self.text_model(

        )

        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    outputs.last_hidden_state,
+                    getattr(outputs, "past_key_values", None),
+                    getattr(outputs, "hidden_states", None),
+                    getattr(outputs, "attentions", None),
+                ]
+                if v is not None
+            )

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,

            attentions=getattr(outputs, "attentions", None),
        )
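The checks at the top of `Moondream3Model.forward` admit only a few argument combinations; the dummy calls below spell them out. Tensor shapes are illustrative placeholders, not the real preprocessing output.

import torch

input_ids = torch.ones(1, 4, dtype=torch.long)
inputs_embeds = torch.randn(1, 4, 8)                  # dummy hidden size
pixel_values = torch.randn(1, 5, 3, 378, 378)         # (B, num_crops, C, H, W) assumed
tiling = (2, 2)

calls = {
    "text only": {"input_ids": input_ids},                                   # ok
    "text + image": {"input_ids": input_ids, "pixel_values": pixel_values,
                     "tiling": tiling},                                      # ok
    "image without tiling": {"input_ids": input_ids,
                             "pixel_values": pixel_values},                  # ValueError
    "embeds + image": {"inputs_embeds": inputs_embeds,
                       "pixel_values": pixel_values, "tiling": tiling},      # ValueError
    "both text inputs": {"input_ids": input_ids,
                         "inputs_embeds": inputs_embeds},                    # ValueError
}
for name, kwargs in calls.items():
    print(name, sorted(kwargs))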

+
@dataclass
class Moondream3GenerateOutput(GenerateDecoderOnlyOutput):
+    objects: Optional[list[dict[str, float]]] = None


class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):

        self.objects = None
        self.model = Moondream3Model(config)
        self.vocab_size = config.text_config.vocab_size
+        self.lm_head = nn.Linear(
+            config.text_config.hidden_size, config.text_config.vocab_size, bias=True
+        )
        self.post_init()

    def get_input_embeddings(self):

        generation_config,
        **kwargs,
    ):
+        generation_config = super()._prepare_generated_length(
+            generation_config, **kwargs
+        )
+        generation_config.max_length += self.config.vision_config.prefix_len
        return generation_config
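Effect of the override above, with illustrative numbers (729 for `prefix_len` is an assumed value): the image prefix occupies KV-cache slots, so the generation length budget is extended accordingly.

prefix_len = 729                          # assumed vision_config.prefix_len
prompt_len, max_new_tokens = 16, 50

max_length = prompt_len + max_new_tokens  # what GenerationMixin would compute
max_length += prefix_len                  # what _prepare_generated_length adds
print(max_length)                         # 795 -> room for prefix + prompt + new tokens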

    def forward(

        if pixel_values is not None and inputs_embeds is None:
            position_ids += self.config.vision_config.prefix_len
            cache_position += self.config.vision_config.prefix_len
+
        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,

            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
        )
+        hidden_states = model_outputs.last_hidden_state
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hs = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, slice):

            hs = hidden_states

        hs = self.model.text_model.norm(hs)
+        logits = self.lm_head(hs)

        pred = torch.argmax(logits, dim=-1)

+        pos_ids = position_ids[:, -1:] + 1
        cache_pos = cache_position[-1:] + 1
        mask = torch.ones(
            hidden_states.shape[0], 1, device=self.device, dtype=torch.long
        )
+        is_processing_point = torch.any(pred == 5)
+        while is_processing_point:
+            batch_mask = pred[:, -1] == 5
            hidden_states = hidden_states[:, -1:, :]
            x_logits = self.model.region_decoder.decode_coordinate(hidden_states)
            x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
+            next_embeds = self.model.region_encoder.encode_coordinate(
+                x_center.to(x_logits.dtype)
+            ).unsqueeze(1)
            model_outputs = self.model(
                input_ids=None,
                pixel_values=None,

                cache_position=cache_pos,
                logits_to_keep=logits_to_keep,
            )
+            hidden_states = model_outputs.last_hidden_state
            y_logits = self.model.region_decoder.decode_coordinate(hidden_states)
            y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
+            next_embeds = self.model.region_encoder.encode_coordinate(
+                y_center.to(y_logits.dtype)
+            ).unsqueeze(1)
            coords = torch.cat([x_center, y_center], dim=1)
            coords = coords * (batch_mask).unsqueeze(1)
            pos_ids += 1
            cache_pos = cache_pos + 1
            bbox = None
+            if input_ids.shape[-1] > 1 and input_ids[0, 1] == 7235:
                model_outputs = self.model(
                    input_ids=None,
                    pixel_values=None,

                    cache_position=cache_pos,
                    logits_to_keep=logits_to_keep,
                )
+                hidden_states = model_outputs.last_hidden_state
                size_logits = self.model.region_decoder.decode_size(hidden_states)
                bins = torch.argmax(size_logits, dim=-1)
+                w_bin = bins[:, 0]
+                h_bin = bins[:, 1]

                w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
                h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)

                next_embeds = (
                    self.model.region_encoder.encode_size(
+                        torch.stack([w, h], dim=-1).to(size_logits.dtype)
                    )
                ).unsqueeze(1)
                x_center = x_center.squeeze(1)

                    x_center + w / 2,
                    y_center + h / 2,
                ]
+                bbox = torch.stack(bbox, dim=1)  # shape (B, 4)
                bbox = bbox * (batch_mask).unsqueeze(1)
                pos_ids += 1
                cache_pos = cache_pos + 1

            )
            pos_ids += 1
            cache_pos = cache_pos + 1
+            hidden_states = model_outputs.last_hidden_state

            indices = torch.tensor(
+                [
+                    self.config.text_config.coord_token_id,
+                    0,
+                ],
                device=self.device,
            )

            hidden_states = self.model.text_model.norm(hidden_states)
+            logits = (
+                hidden_states @ self.lm_head.weight[indices].T
+                + self.lm_head.bias[indices]
+            )

+            logits_full = torch.full(
+                (logits.shape[0], logits.shape[1], self.config.text_config.vocab_size),
+                float("-inf"),
+                device=logits.device,
+                dtype=logits.dtype,
+            )
+            logits_full[:, :, torch.tensor([5, 0])] = logits
            logits = logits_full
            pred[batch_mask] = torch.argmax(logits, dim=-1)[batch_mask]
+            is_processing_point = torch.any(pred == 5)
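The tail of the loop above restricts the next-token choice to two ids by scoring only those rows of `lm_head` and scattering them into a -inf-filled full-vocab tensor. A standalone sketch of that masking trick (vocab size, hidden size, and the second token id are placeholders):

import torch

vocab_size, hidden = 16, 8
allowed = torch.tensor([5, 0])          # the two ids selected in the loop above
weight = torch.randn(vocab_size, hidden)
bias = torch.randn(vocab_size)
hidden_states = torch.randn(1, 1, hidden)

# Score only the allowed rows, then scatter them into a -inf background so the
# argmax can never pick any other token.
logits = hidden_states @ weight[allowed].T + bias[allowed]        # (1, 1, 2)
logits_full = torch.full((1, 1, vocab_size), float("-inf"))
logits_full[:, :, allowed] = logits
print(torch.argmax(logits_full, dim=-1))  # always tensor([[5]]) or tensor([[0]])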

        loss = None
        if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.vocab_size
+            )

        return CausalLMOutputWithPast(
            loss=loss,


    def generate(self, **kwargs) -> Union[Moondream3GenerateOutput, torch.LongTensor]:
        outputs = super().generate(**kwargs)
+        if len(self.objects if self.objects is not None else []) > 0:
            if isinstance(outputs, torch.Tensor):
                outputs = self.objects
+                self.objects = None
            else:
+                outputs = Moondream3GenerateOutput(**outputs, objects=self.objects)
+                self.objects = None

        return outputs
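Standalone illustration of the branch in `generate` above: collected objects either replace a plain tensor result or ride along on a dedicated field of the dict-style output. The dataclass and dict keys here are stand-ins, not the real output schema.

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeGenerateOutput:
    sequences: torch.Tensor
    objects: Optional[list] = None

collected = [{"x": 0.41, "y": 0.37}]    # what the coordinate loop stored
raw = torch.tensor([[1, 2, 3]])         # what super().generate() returned

# Tensor result -> hand back the objects; dict-style result -> attach them.
result = collected if isinstance(raw, torch.Tensor) else FakeGenerateOutput(raw, collected)
print(result)  # [{'x': 0.41, 'y': 0.37}]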

+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        model_inputs = super().prepare_inputs_for_generation(input_ids, **model_kwargs)
+        model_inputs["position_ids"] += (
+            model_inputs["cache_position"].unsqueeze(0) - model_inputs["position_ids"]
+        )
        return model_inputs

    def _update_model_kwargs_for_generation(

        model_kwargs["tiling"] = None
        return model_kwargs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
            )
        return reordered_past