NyxKrage committed on
Commit c466d58 · verified · 1 Parent(s): ca700c7

Upload folder using huggingface_hub

chat_template.jinja CHANGED
@@ -57,15 +57,18 @@
   {{ raise_exception("caption length must be one of: short, normal, long.") }}
   {%- endif -%}
   <|md_reserved_0|>describe<|md_reserved_1|>{{ length }}<|md_reserved_2|>
+  {%- elif lower.startswith('reason:') -%}
+  {% set q = text[7:] | trim -%}
+  <|md_reserved_0|>query<|md_reserved_1|>{{ q }}<|md_reserved_2|><|md_reserved_3|>
   {%- elif lower.startswith('query:') -%}
   {% set q = text[6:] | trim -%}
   <|md_reserved_0|>query<|md_reserved_1|>{{ q }}<|md_reserved_2|>
   {%- elif lower.startswith('detect:') -%}
   {% set q = text[7:] | trim -%}
-  <|md_reserved_0|>det<|md_reserved_1|>{{ q }}<|md_reserved_2|>
+  <|md_reserved_0|>det<|md_reserved_1|> {{ q }}<|md_reserved_2|>
   {%- elif lower.startswith('point:') -%}
   {% set q = text[6:] | trim -%}
-  <|md_reserved_0|>point<|md_reserved_1|>{{ q }}<|md_reserved_2|>
+  <|md_reserved_0|>point<|md_reserved_1|> {{ q }}<|md_reserved_2|>
   {%- else -%}
   {% set q = text -%}
   <|md_reserved_0|>query<|md_reserved_1|>{{ q }}<|md_reserved_2|>
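Note: the new `reason:` prefix is routed to the same `query` command as plain questions, with an extra `<|md_reserved_3|>` marker appended. A minimal sketch in plain Python of the string this branch renders; the prompt text is illustrative and the real formatting is done by the Jinja template above:

# Illustrative re-implementation of the new 'reason:' branch.
text = "reason: why is the plate empty?"
q = text[7:].strip()  # mirrors `text[7:] | trim` in the template
rendered = f"<|md_reserved_0|>query<|md_reserved_1|>{q}<|md_reserved_2|><|md_reserved_3|>"
print(rendered)
# <|md_reserved_0|>query<|md_reserved_1|>why is the plate empty?<|md_reserved_2|><|md_reserved_3|>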
config.json CHANGED
@@ -37,7 +37,9 @@
   "output_router_logits": false,
   "prefix_attn": 730,
   "bos_token_id": 0,
-  "rms_norm_eps": 1e-06,
+  "eos_token_id": 0,
+  "coord_token_id": 5,
+  "rms_norm_eps": 1e-05,
   "rope_parameters": {
     "rope_theta": 1500000.0,
     "rope_type": "default"
configuration_moondream3.py CHANGED
@@ -54,7 +54,7 @@ class Moondream3TextConfig(PretrainedConfig):
             The non-linear activation function.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
             The epsilon used by the rms normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions.
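The epsilon only enters the RMS normalization denominator. Moondream3's own norm module is not part of this diff, so the following is a generic sketch of the standard RMSNorm form, just to show where the 1e-06 → 1e-05 change takes effect:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Standard RMSNorm: scale by the reciprocal root-mean-square of the features.
    # `eps` (the value changed in this commit) guards the rsqrt against ~0 variance.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (weight * (x.float() * torch.rsqrt(variance + eps))).to(x.dtype)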
image_processing_moondream3.py CHANGED
@@ -204,7 +204,13 @@ def prepare_crops(image, max_crops=12, overlap_margin=4):
     )
     all_crops = overlap_crops["crops"]
     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
-    all_crops = (((all_crops / 255.0) - 0.5) / 0.5)
+    all_crops = (
+        torch.from_numpy(all_crops)
+        .to(device="cpu", dtype=torch.bfloat16)
+        .div_(255.0)
+        .sub_(0.5)
+        .div_(0.5)
+    )
     return all_crops.tolist(), overlap_crops["tiling"]
 
 class Moondream3ImageProcessor(BaseImageProcessor):
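The normalization itself is unchanged (map [0, 255] pixels to [-1, 1]); the crops are now converted to a bfloat16 torch tensor first. A quick sketch checking that the two paths agree up to bfloat16 rounding; the array shape here is arbitrary:

import numpy as np
import torch

crops = np.random.randint(0, 256, size=(2, 3, 32, 32)).astype(np.float32)

old_path = ((crops / 255.0) - 0.5) / 0.5          # previous NumPy normalization
new_path = (
    torch.from_numpy(crops)
    .to(device="cpu", dtype=torch.bfloat16)
    .div_(255.0)
    .sub_(0.5)
    .div_(0.5)
)

# bfloat16 keeps roughly 3 significant decimal digits, so use a loose tolerance.
assert torch.allclose(torch.from_numpy(old_path).to(torch.bfloat16), new_path, atol=5e-2)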
modeling_moondream3.py CHANGED
@@ -26,6 +26,7 @@ from PIL import Image
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.masking_utils import create_causal_mask
 
29
  from transformers.modeling_outputs import (
30
  BaseModelOutputWithPast,
31
  CausalLMOutputWithPast,
@@ -34,6 +35,7 @@ from transformers.processing_utils import Unpack
34
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
  from transformers.generation import GenerationMixin
 
37
  from transformers.utils import logging, TransformersKwargs
38
  from .configuration_moondream3 import Moondream3Config, Moondream3TextConfig, Moondream3VisionConfig, Moondream3RegionConfig
39
 
@@ -41,48 +43,49 @@ logger = logging.get_logger(__name__)
41
 
42
  _CONFIG_FOR_DOC = "Moondream3Config"
43
 
44
- def rotate_half(x):
45
- """Rotates half the hidden dims of the input."""
46
- x1 = x[..., : x.shape[-1] // 2]
47
- x2 = x[..., x.shape[-1] // 2 :]
48
- return torch.cat((-x2, x1), dim=-1)
49
-
50
 
51
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
52
- """Applies Rotary Position Embedding to the query and key tensors.
53
 
54
  Args:
55
- q (`torch.Tensor`): The query tensor.
56
- k (`torch.Tensor`): The key tensor.
57
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
58
- sin (`torch.Tensor`): The sine part of the rotary embedding.
59
- position_ids (`torch.Tensor`, *optional*):
60
- Deprecated and unused.
61
- unsqueeze_dim (`int`, *optional*, defaults to 1):
62
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
63
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
64
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
65
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
66
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
67
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
68
  Returns:
69
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
70
  """
71
- rot_dim = cos.shape[-1]
72
 
73
- def rotate_prefix(x):
74
- x_rot = x[..., :rot_dim]
75
- x_pass = x[..., rot_dim:]
76
- x_rot = (x_rot * cos) + (rotate_half(x_rot) * sin)
77
- return torch.cat([x_rot, x_pass], dim=-1)
78
 
79
- cos = cos.unsqueeze(unsqueeze_dim)
80
- sin = sin.unsqueeze(unsqueeze_dim)
81
 
82
- return rotate_prefix(q), rotate_prefix(k)
 
83
 
84
  class Moondream3RotaryEmbedding(nn.Module):
85
- inv_freq: torch.Tensor # fix linting for `register_buffer`
86
 
87
  def __init__(self, config: Moondream3Config, device=None):
88
  super().__init__()
@@ -108,43 +111,34 @@ class Moondream3RotaryEmbedding(nn.Module):
108
  ) -> tuple["torch.Tensor", float]:
109
  """
110
  Computes the inverse frequencies according to the original RoPE implementation
111
- Args:
112
- config ([`~transformers.PreTrainedConfig`]):
113
- The model configuration.
114
- device (`torch.device`):
115
- The device to use for initialization of the inverse frequencies.
116
- seq_len (`int`, *optional*):
117
- The current sequence length. Unused for this type of RoPE.
118
- Returns:
119
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
120
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
121
  """
122
- base = config.rope_parameters["rope_theta"]
123
  dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
124
- dim = dim // 2
125
-
126
- attention_factor = 1.0 # Unused in this type of RoPE
127
 
128
- # Compute the inverse frequencies
129
  inv_freq = 1.0 / (
130
- base ** (torch.arange(0, dim, 2, dtype=torch.float).to(device=device)[: (dim // 2)] / dim)
131
  )
 
 
132
  return inv_freq, attention_factor
133
 
134
  @torch.no_grad()
135
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
136
  def forward(self, x, position_ids):
137
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
138
- position_ids_expanded = position_ids[:, None, :].float()
 
 
140
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
141
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
142
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
143
- emb = torch.cat((freqs, freqs), dim=-1)
144
- cos = emb.cos() * self.attention_scaling
145
- sin = emb.sin() * self.attention_scaling
146
 
147
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
148
 
149
 
150
  class Moondream3Attention(nn.Module):
@@ -202,6 +196,8 @@ class Moondream3Attention(nn.Module):
202
  **kwargs,
203
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
204
  input_shape = hidden_states.shape[:-1]
 
 
205
  bsz, q_len, _ = hidden_states.size()
206
 
207
  # Get qkv combined for tau (before splitting)
@@ -223,7 +219,9 @@ class Moondream3Attention(nn.Module):
223
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
224
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
225
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
226
-
 
 
227
 
228
  if self.use_tau:
229
  query_states = query_states * tau_q
@@ -234,20 +232,38 @@ class Moondream3Attention(nn.Module):
234
  tau_v_repeated = tau_v
235
  value_states = value_states * tau_v_repeated
236
 
237
  cos, sin = None, None
238
  if position_embeddings is not None:
239
  cos, sin = position_embeddings
240
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
241
 
242
- if past_key_values is not None:
 
 
243
 
 
244
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
245
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
246
 
247
  attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS["sdpa"](
248
  self,
249
  query_states,
250
- key_states,
251
  value_states,
252
  attention_mask,
253
  dropout=0.0 if not self.training else self.attention_dropout,
@@ -256,6 +272,9 @@ class Moondream3Attention(nn.Module):
256
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
257
  attn_output = self.o_proj(attn_output)
258
 
259
  return attn_output, attn_weights
260
 
261
  class Moondream3MLP(nn.Module):
@@ -266,7 +285,6 @@ class Moondream3MLP(nn.Module):
266
  self.out_size = self.hidden_size if out_size is None else out_size
267
  self.hidden_act = hidden_act
268
  self.gated = gated
269
- # Ungated MLP: up_proj and down_proj following HF conventions
270
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
271
  self.down_proj = nn.Linear(self.intermediate_size, self.out_size, bias=bias)
272
  self.gate_proj = None
@@ -276,17 +294,20 @@ class Moondream3MLP(nn.Module):
276
 
277
  def forward(self, x) -> torch.Tensor:
278
  if self.gated:
279
- h = self.act_fn(self.up_proj(x))
280
- g = self.gate_proj(x)
281
- x = h * (g + 1)
 
 
282
  else:
283
  x = self.act_fn(self.up_proj(x))
284
  return self.down_proj(x)
285
 
286
 
287
  class Moondream3SparseMoeBlock(nn.Module):
288
- def __init__(self, config: Moondream3TextConfig):
289
  super().__init__()
 
290
  self.hidden_size = config.hidden_size
291
  self.moe_intermediate_size = config.moe_intermediate_size
292
  self.num_experts = config.num_experts
@@ -295,32 +316,29 @@ class Moondream3SparseMoeBlock(nn.Module):
295
  self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=True)
296
  self.experts = nn.ModuleList([Moondream3MLP(hidden_size=self.hidden_size, intermediate_size=self.moe_intermediate_size, hidden_act="gelu", gated=True, bias=False) for _ in range(self.num_experts)])
297
 
298
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
299
  batch_size, sequence_length, hidden_dim = hidden_states.shape
300
  hidden_states = hidden_states.view(-1, hidden_dim)
301
  router_logits: torch.Tensor = self.gate(hidden_states)
302
-
303
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
304
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
305
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
306
  routing_weights = routing_weights.to(hidden_states.dtype)
307
 
308
  final_hidden_states = torch.zeros(
309
  (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
310
  )
311
 
312
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
313
-
314
  for expert_idx in range(self.num_experts):
315
  expert_layer = self.experts[expert_idx]
316
- idx, top_x = torch.where(expert_mask[expert_idx])
317
 
318
  if top_x.shape[0] == 0:
319
  continue
320
 
321
  current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
 
322
  current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
323
-
324
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
325
 
326
  final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@@ -330,6 +348,7 @@ class Moondream3SparseMoeBlock(nn.Module):
330
  class Moondream3DecoderLayer(nn.Module):
331
  def __init__(self, config: Moondream3TextConfig, layer_idx: int):
332
  super().__init__()
 
333
  self.hidden_size = config.hidden_size
334
  self.intermediate_size = config.intermediate_size
335
  self.self_attn = Moondream3Attention(config, layer_idx, use_tau=True)
@@ -337,7 +356,7 @@ class Moondream3DecoderLayer(nn.Module):
337
 
338
  self.is_moe_layer = layer_idx >= config.moe_start_layer
339
  if self.is_moe_layer:
340
- self.mlp = Moondream3SparseMoeBlock(config)
341
  else:
342
  self.mlp = Moondream3MLP(self.hidden_size, self.intermediate_size)
343
 
@@ -358,6 +377,8 @@ class Moondream3DecoderLayer(nn.Module):
358
 
359
  # Apply layer norm like original
360
  l_in = self.input_layernorm(hidden_states)
 
 
361
 
362
  # Attention
363
  hidden_states_attn, self_attn_weights = self.self_attn(
@@ -374,10 +395,12 @@ class Moondream3DecoderLayer(nn.Module):
374
 
375
  # MLP
376
  if self.is_moe_layer:
377
- hidden_states_mlp, router_logits = self.mlp(l_in)
378
  else:
379
  hidden_states_mlp = self.mlp(l_in)
380
  router_logits = None
 
 
381
 
382
  # Add both attention and MLP to residual like original
383
  hidden_states = residual + hidden_states_attn + hidden_states_mlp
@@ -538,8 +561,6 @@ class Moondream3TextModel(Moondream3PreTrainedModel):
538
  if output_router_logits and layer_outputs[-1] is not None:
539
  all_router_logits += (layer_outputs[-1],)
540
 
541
- hidden_states = self.norm(hidden_states)
542
-
543
  if output_hidden_states:
544
  all_hidden_states += (hidden_states,)
545
 
@@ -794,8 +815,8 @@ class Moondream3RegionEncoder(nn.Module):
794
  self.register_buffer("size_freq", size_freq.T)
795
 
796
  def fourier_features(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
797
- x_proj = torch.matmul(x, w) * 2 * torch.pi
798
- return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
799
 
800
  def encode_coordinate(self, coord: torch.Tensor) -> torch.Tensor:
801
  fourier_features = self.fourier_features(coord, self.coord_freq)
@@ -815,7 +836,7 @@ class Moondream3RegionDecoder(nn.Module):
815
  return self.coord_decoder(hidden_state)
816
 
817
  def decode_size(self, hidden_state: torch.Tensor) -> torch.Tensor:
818
- return self.size_decoder(hidden_state)
819
 
820
  class Moondream3Model(Moondream3PreTrainedModel):
821
  def __init__(self, config: Moondream3Config):
@@ -885,38 +906,59 @@ class Moondream3Model(Moondream3PreTrainedModel):
885
  if cache_position is None:
886
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
887
  cache_position: torch.Tensor = torch.arange(
888
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
889
  )
890
 
891
  if position_ids is None:
892
  position_ids = cache_position.unsqueeze(0)
893
 
894
- causal_mask = create_causal_mask(
895
  config=self.config,
896
  input_embeds=inputs_embeds,
897
- attention_mask=attention_mask,
898
  cache_position=cache_position,
899
  past_key_values=past_key_values,
900
  position_ids=position_ids,
 
901
  )
902
 
903
- if pixel_values is not None and input_ids.shape[-1] > 1:
904
- # Vision embeds
905
- pixel_values = pixel_values.to(dtype=self.vision_model.embeddings.projection.weight.dtype)
906
- image_embeds = self.vision_model(pixel_values, tiling=tiling)["last_hidden_state"] # [B,P,D]
907
- prefix = inputs_embeds[:, :1, :] # keep the first token
908
- suffix = inputs_embeds[:, 1 + image_embeds.shape[1] :, :] # keep the rest after the image span
909
- inputs_embeds = torch.cat([prefix, image_embeds, suffix], dim=1)
910
-
911
- # N/A when doing BSZ 1 since create_causal_mask returns None in the case since theres no padding tokens
912
- if causal_mask is not None:
913
- img_len = image_embeds.shape[1]
914
- causal_mask[:, :, :1 + img_len, :1 + img_len] = True
915
- causal_mask[:, :, :1 + img_len, 1 + img_len:] = False
916
-
917
  outputs = self.text_model(
918
  input_ids=None,
919
- attention_mask=causal_mask,
920
  position_ids=position_ids,
921
  past_key_values=past_key_values,
922
  inputs_embeds=inputs_embeds,
@@ -942,11 +984,17 @@ class Moondream3Model(Moondream3PreTrainedModel):
942
  attentions=getattr(outputs, "attentions", None),
943
  )
944
 
945
  class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
946
  _tied_weights_keys = ["lm_head.weight"]
947
 
948
  def __init__(self, config: Moondream3Config):
949
  super().__init__(config)
 
950
  self.model = Moondream3Model(config)
951
  self.vocab_size = config.text_config.vocab_size
952
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
@@ -970,6 +1018,15 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
970
  def get_decoder(self):
971
  return self.model.text_model
972
 
 
973
  def forward(
974
  self,
975
  input_ids: torch.LongTensor = None,
@@ -988,6 +1045,9 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
988
  logits_to_keep: int = 0,
989
  **kwargs: Unpack[TransformersKwargs],
990
  ) -> Union[Tuple, CausalLMOutputWithPast]:
991
  # Get hidden states from the base model (it already builds the multimodal prefix)
992
  model_outputs = self.model(
993
  input_ids=input_ids,
@@ -1005,7 +1065,6 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1005
  cache_position=cache_position,
1006
  logits_to_keep=logits_to_keep,
1007
  )
1008
-
1009
  hidden_states = model_outputs.last_hidden_state # [B, T, D]
1010
 
1011
  # Compute logits; only keep the tail if requested
@@ -1016,8 +1075,127 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1016
  else:
1017
  hs = hidden_states
1018
 
 
1019
  logits = self.lm_head(hs) # [B, T', V]
1020
 
1021
  loss = None
1022
  if labels is not None:
1023
  # Shift if your training uses standard LM convention; here we assume labels aligned with hs
@@ -1031,6 +1209,47 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1031
  attentions=getattr(model_outputs, "attentions", None),
1032
  )
1033
 
1034
  @staticmethod
1035
  def _reorder_cache(past_key_values, beam_idx):
1036
  reordered_past = ()
 
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.masking_utils import create_causal_mask
29
+ from dataclasses import dataclass
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPast,
32
  CausalLMOutputWithPast,
 
35
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
36
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
  from transformers.generation import GenerationMixin
38
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
39
  from transformers.utils import logging, TransformersKwargs
40
  from .configuration_moondream3 import Moondream3Config, Moondream3TextConfig, Moondream3VisionConfig, Moondream3RegionConfig
41
 
 
43
 
44
  _CONFIG_FOR_DOC = "Moondream3Config"
45
 
46
+ import torch
47
 
48
+ DEBUG=True
 
49
 
50
+ def apply_rotary_pos_emb(
51
+ q: torch.Tensor, # [B, H, L, D]
52
+ k: torch.Tensor, # [B, H, L, D]
53
+ cos: torch.Tensor, # [B, L, rot_dim]
54
+ sin: torch.Tensor, # [B, L, rot_dim]
55
+ rot_dim: int = 32,
56
+ ):
57
+ """
58
+ Apply rotary position embeddings to query and key tensors.
59
+
60
  Args:
61
+ q: Query tensor [batch, num_heads, seq_len, head_dim]
62
+ k: Key tensor [batch, num_heads, seq_len, head_dim]
63
+ cos: Cosine frequencies [batch, seq_len, rot_dim]
64
+ sin: Sine frequencies [batch, seq_len, rot_dim]
65
+ rot_dim: Number of dimensions to apply rotation to (default: 32)
66
+
67
  Returns:
68
+ Tuple of (rotated_q, rotated_k)
69
  """
 
70
 
71
+ def apply_rope(x):
72
+ dtype = x.dtype
73
+ x = x.to(torch.float64)
74
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
75
+
76
+ d_q = x_rot.shape[-1] // 2
77
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
78
+
79
+ xq_out_r = xq_r * cos - xq_i * sin
80
+ xq_out_i = xq_r * sin + xq_i * cos
81
 
82
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
 
83
 
84
+ return torch.cat([xq_out, x_pass], dim=-1)
85
+ return apply_rope(q), apply_rope(k)
86
 
87
  class Moondream3RotaryEmbedding(nn.Module):
88
+ inv_freq: torch.Tensor
89
 
90
  def __init__(self, config: Moondream3Config, device=None):
91
  super().__init__()
 
111
  ) -> tuple["torch.Tensor", float]:
112
  """
113
  Computes the inverse frequencies according to the original RoPE implementation
114
  """
115
+ base = config.rope_parameters["rope_theta"] # Should be 1500000.0 to match original
116
  dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
117
+ dim //= 2
118
+
119
+ attention_factor = 1.0
120
 
121
+ # Compute the inverse frequencies - matches your original formula
122
  inv_freq = 1.0 / (
123
+ base ** (torch.arange(0, dim, 2, dtype=torch.float64)[: (dim // 2)] / dim)
124
  )
125
+ if device is not None:
126
+ inv_freq = inv_freq.to(device=device)
127
  return inv_freq, attention_factor
128
 
129
  @torch.no_grad()
130
+ @dynamic_rope_update
131
  def forward(self, x, position_ids):
132
+ # inv_freq shape: [dim//2]
133
+ # position_ids shape: [batch_size, seq_len]
134
+
135
+ inv_freq_expanded = self.inv_freq[None, :, None].to(torch.float64).expand(position_ids.shape[0], -1, 1).to(x.device)
136
+ position_ids_expanded = position_ids[:, None, :].to(torch.float64)
137
 
138
+ freqs = (inv_freq_expanded.to(torch.float64) @ position_ids_expanded.to(torch.float64)).transpose(1, 2)
139
+ cfreqs = torch.exp(1j * freqs).unsqueeze(1).expand(-1, self.config.num_attention_heads, -1, -1)
140
 
141
+ return cfreqs.real, cfreqs.imag
142
 
143
 
144
  class Moondream3Attention(nn.Module):
 
196
  **kwargs,
197
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
198
  input_shape = hidden_states.shape[:-1]
199
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
200
+ torch.save(hidden_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_input_states")
201
  bsz, q_len, _ = hidden_states.size()
202
 
203
  # Get qkv combined for tau (before splitting)
 
219
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
220
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
221
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
222
+
223
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
224
+ torch.save(value_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_pre_tau_value")
225
 
226
  if self.use_tau:
227
  query_states = query_states * tau_q
 
232
  tau_v_repeated = tau_v
233
  value_states = value_states * tau_v_repeated
234
 
235
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
236
+ torch.save(value_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_tau_value")
237
+ torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_pre_rope_key")
238
+
239
  cos, sin = None, None
240
  if position_embeddings is not None:
241
  cos, sin = position_embeddings
242
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
243
+ torch.save(cos, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_cos")
244
+ torch.save(sin, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_sin")
245
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
246
 
247
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
248
+ torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_rope_key")
249
+ query_states, key_states = query_states.to(value_states.dtype), key_states.to(value_states.dtype)
250
 
251
+ if past_key_values is not None:
252
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
253
  key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
254
 
255
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
256
+ torch.save(key_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_post_cache_key")
257
+ torch.save(attention_mask, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_attn_mask")
258
+
259
+ query_states = query_states.contiguous()
260
+ key_states = key_states.contiguous()
261
+ value_states = value_states.contiguous()
262
+
263
  attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS["sdpa"](
264
  self,
265
  query_states,
266
+ key_states,
267
  value_states,
268
  attention_mask,
269
  dropout=0.0 if not self.training else self.attention_dropout,
 
272
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
273
  attn_output = self.o_proj(attn_output)
274
 
275
+ if isinstance(self.config, Moondream3TextConfig) and DEBUG:
276
+ torch.save(attn_output, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_attn_out")
277
+
278
  return attn_output, attn_weights
279
 
280
  class Moondream3MLP(nn.Module):
 
285
  self.out_size = self.hidden_size if out_size is None else out_size
286
  self.hidden_act = hidden_act
287
  self.gated = gated
 
288
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
289
  self.down_proj = nn.Linear(self.intermediate_size, self.out_size, bias=bias)
290
  self.gate_proj = None
 
294
 
295
  def forward(self, x) -> torch.Tensor:
296
  if self.gated:
297
+ # separate up and gate causes precision issues
298
+ combined_weight = torch.cat([self.up_proj.weight, self.gate_proj.weight], dim=0)
299
+ h_full = F.linear(x, combined_weight)
300
+ h, g = h_full.chunk(2, dim=-1)
301
+ x = self.act_fn(h) * (g + 1)
302
  else:
303
  x = self.act_fn(self.up_proj(x))
304
  return self.down_proj(x)
305
 
306
 
307
  class Moondream3SparseMoeBlock(nn.Module):
308
+ def __init__(self, config: Moondream3TextConfig, layer_idx = None):
309
  super().__init__()
310
+ self.layer_idx = layer_idx
311
  self.hidden_size = config.hidden_size
312
  self.moe_intermediate_size = config.moe_intermediate_size
313
  self.num_experts = config.num_experts
 
316
  self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=True)
317
  self.experts = nn.ModuleList([Moondream3MLP(hidden_size=self.hidden_size, intermediate_size=self.moe_intermediate_size, hidden_act="gelu", gated=True, bias=False) for _ in range(self.num_experts)])
318
 
319
+ def forward(self, hidden_states: torch.Tensor, cache_position=None) -> Tuple[torch.Tensor, torch.Tensor]:
320
  batch_size, sequence_length, hidden_dim = hidden_states.shape
321
  hidden_states = hidden_states.view(-1, hidden_dim)
322
  router_logits: torch.Tensor = self.gate(hidden_states)
323
+ routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1)
324
+ routing_weights = F.softmax(routing_weights, dim=-1, dtype=torch.float32)
 
 
325
  routing_weights = routing_weights.to(hidden_states.dtype)
326
 
327
  final_hidden_states = torch.zeros(
328
  (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
329
  )
330
 
 
 
331
  for expert_idx in range(self.num_experts):
332
  expert_layer = self.experts[expert_idx]
333
+ top_x, idx = (selected_experts == expert_idx).nonzero(as_tuple=True)
334
 
335
  if top_x.shape[0] == 0:
336
  continue
337
 
338
  current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
339
+ # torch.save(current_state, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_e{expert_idx}")
340
  current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
341
+ # torch.save(current_hidden_states, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_e{expert_idx}")
342
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
343
 
344
  final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 
348
  class Moondream3DecoderLayer(nn.Module):
349
  def __init__(self, config: Moondream3TextConfig, layer_idx: int):
350
  super().__init__()
351
+ self.layer_idx = layer_idx
352
  self.hidden_size = config.hidden_size
353
  self.intermediate_size = config.intermediate_size
354
  self.self_attn = Moondream3Attention(config, layer_idx, use_tau=True)
 
356
 
357
  self.is_moe_layer = layer_idx >= config.moe_start_layer
358
  if self.is_moe_layer:
359
+ self.mlp = Moondream3SparseMoeBlock(config, layer_idx=layer_idx)
360
  else:
361
  self.mlp = Moondream3MLP(self.hidden_size, self.intermediate_size)
362
 
 
377
 
378
  # Apply layer norm like original
379
  l_in = self.input_layernorm(hidden_states)
380
+ if DEBUG:
381
+ torch.save(l_in, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_ln_out")
382
 
383
  # Attention
384
  hidden_states_attn, self_attn_weights = self.self_attn(
 
395
 
396
  # MLP
397
  if self.is_moe_layer:
398
+ hidden_states_mlp, router_logits = self.mlp(l_in, cache_position=cache_position)
399
  else:
400
  hidden_states_mlp = self.mlp(l_in)
401
  router_logits = None
402
+ if DEBUG:
403
+ torch.save(hidden_states_mlp, f"dbg/hf_l{self.layer_idx}_c{cache_position[-1].item()}_mlp_out")
404
 
405
  # Add both attention and MLP to residual like original
406
  hidden_states = residual + hidden_states_attn + hidden_states_mlp
 
561
  if output_router_logits and layer_outputs[-1] is not None:
562
  all_router_logits += (layer_outputs[-1],)
563
 
 
 
564
  if output_hidden_states:
565
  all_hidden_states += (hidden_states,)
566
 
 
815
  self.register_buffer("size_freq", size_freq.T)
816
 
817
  def fourier_features(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
818
+ x_proj = 2 * torch.pi * x @ w
819
+ return torch.cat([x_proj.cos(), x_proj.sin()], dim=-1)
820
 
821
  def encode_coordinate(self, coord: torch.Tensor) -> torch.Tensor:
822
  fourier_features = self.fourier_features(coord, self.coord_freq)
 
836
  return self.coord_decoder(hidden_state)
837
 
838
  def decode_size(self, hidden_state: torch.Tensor) -> torch.Tensor:
839
+ return self.size_decoder(hidden_state).view(hidden_state.shape[0],2,-1)
840
 
841
  class Moondream3Model(Moondream3PreTrainedModel):
842
  def __init__(self, config: Moondream3Config):
 
906
  if cache_position is None:
907
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
908
  cache_position: torch.Tensor = torch.arange(
909
+ past_seen_tokens, past_seen_tokens, device=inputs_embeds.device
910
  )
911
 
912
  if position_ids is None:
913
  position_ids = cache_position.unsqueeze(0)
914
 
915
+ def image_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int):
916
+ # set all up to `self.config.vision_config.prefix_len` to true
917
+ return kv_idx <= q_idx
918
+
919
+ if pixel_values is not None:
920
+ # Vision embeds
921
+ pixel_values = pixel_values.to(dtype=self.vision_model.embeddings.projection.weight.dtype)
922
+ image_embeds = self.vision_model(pixel_values, tiling=tiling)["last_hidden_state"] # [B,P,D]
923
+ prefix = self.text_model.embed_tokens(
924
+ torch.full((input_ids.shape[0], 1), self.config.text_config.bos_token_id, dtype=input_ids.dtype, device=input_ids.device)
925
+ )
926
+ embeds = torch.cat([prefix, image_embeds], dim=1)
927
+ cache_pos = torch.arange(embeds.shape[-2], device=embeds.device)
928
+ pos = cache_pos.unsqueeze(0).expand(embeds.shape[0],-1)
929
+ attn_mask = torch.full(
930
+ (embeds.shape[0], 1, embeds.shape[-2], pos.shape[-1]),
931
+ True,
932
+ dtype=torch.bool,
933
+ device=embeds.device,
934
+ )
935
+
936
+ outputs = self.text_model(
937
+ input_ids=None,
938
+ attention_mask=attn_mask,
939
+ position_ids=pos,
940
+ past_key_values=past_key_values,
941
+ inputs_embeds=embeds,
942
+ use_cache=use_cache,
943
+ output_attentions=output_attentions,
944
+ output_hidden_states=output_hidden_states,
945
+ return_dict=True,
946
+ cache_position=cache_pos,
947
+ )
948
+
949
+ attn_mask = create_causal_mask(
950
  config=self.config,
951
  input_embeds=inputs_embeds,
952
+ attention_mask=torch.cat([torch.ones(attention_mask.shape[0], cache_position[-1] + 1 - attention_mask.shape[-1], device=attention_mask.device, dtype=attention_mask.dtype), attention_mask], dim=-1),
953
  cache_position=cache_position,
954
  past_key_values=past_key_values,
955
  position_ids=position_ids,
956
+ and_mask_function=image_mask_function
957
  )
958
 
959
  outputs = self.text_model(
960
  input_ids=None,
961
+ attention_mask=attn_mask,
962
  position_ids=position_ids,
963
  past_key_values=past_key_values,
964
  inputs_embeds=inputs_embeds,
 
984
  attentions=getattr(outputs, "attentions", None),
985
  )
986
 
987
+ @dataclass
988
+ class Moondream3GenerateOutput(GenerateDecoderOnlyOutput):
989
+ objects: Optional[list[dict[str,float]]] = None
990
+
991
+
992
  class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
993
  _tied_weights_keys = ["lm_head.weight"]
994
 
995
  def __init__(self, config: Moondream3Config):
996
  super().__init__(config)
997
+ self.objects = None
998
  self.model = Moondream3Model(config)
999
  self.vocab_size = config.text_config.vocab_size
1000
  self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
 
1018
  def get_decoder(self):
1019
  return self.model.text_model
1020
 
1021
+ def _prepare_generated_length(
1022
+ self,
1023
+ generation_config,
1024
+ **kwargs,
1025
+ ):
1026
+ generation_config = super()._prepare_generated_length(generation_config, **kwargs)
1027
+ generation_config.max_length += 730
1028
+ return generation_config
1029
+
1030
  def forward(
1031
  self,
1032
  input_ids: torch.LongTensor = None,
 
1045
  logits_to_keep: int = 0,
1046
  **kwargs: Unpack[TransformersKwargs],
1047
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1048
+ if pixel_values is not None and inputs_embeds is None:
1049
+ position_ids += self.config.vision_config.prefix_len
1050
+ cache_position += self.config.vision_config.prefix_len
1051
  # Get hidden states from the base model (it already builds the multimodal prefix)
1052
  model_outputs = self.model(
1053
  input_ids=input_ids,
 
1065
  cache_position=cache_position,
1066
  logits_to_keep=logits_to_keep,
1067
  )
 
1068
  hidden_states = model_outputs.last_hidden_state # [B, T, D]
1069
 
1070
  # Compute logits; only keep the tail if requested
 
1075
  else:
1076
  hs = hidden_states
1077
 
1078
+ hs = self.model.text_model.norm(hs)
1079
  logits = self.lm_head(hs) # [B, T', V]
1080
 
1081
+ pred = torch.argmax(logits, dim=-1)
1082
+
1083
+ pos_ids = position_ids[:,-1:] + 1
1084
+ cache_pos = cache_position[-1:] + 1
1085
+ mask = torch.ones(
1086
+ hidden_states.shape[0], 1, device=self.device, dtype=torch.long
1087
+ )
1088
+ while torch.any(pred == 5):
1089
+ batch_mask = (pred[:, -1] == 5)
1090
+ hidden_states = hidden_states[:, -1:, :]
1091
+ x_logits = self.model.region_decoder.decode_coordinate(hidden_states)
1092
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
1093
+ next_embeds = self.model.region_encoder.encode_coordinate(x_center.to(x_logits.dtype)).unsqueeze(1)
1094
+ model_outputs = self.model(
1095
+ input_ids=None,
1096
+ pixel_values=None,
1097
+ tiling=None,
1098
+ attention_mask=mask,
1099
+ position_ids=pos_ids,
1100
+ past_key_values=past_key_values,
1101
+ inputs_embeds=next_embeds,
1102
+ labels=None,
1103
+ use_cache=use_cache,
1104
+ output_attentions=output_attentions,
1105
+ output_hidden_states=output_hidden_states,
1106
+ return_dict=True,
1107
+ cache_position=cache_pos,
1108
+ logits_to_keep=logits_to_keep,
1109
+ )
1110
+ hidden_states = model_outputs.last_hidden_state # [B, T, D]
1111
+ y_logits = self.model.region_decoder.decode_coordinate(hidden_states)
1112
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
1113
+ next_embeds = self.model.region_encoder.encode_coordinate(y_center.to(y_logits.dtype)).unsqueeze(1)
1114
+ coords = torch.cat([x_center, y_center], dim=1)
1115
+ coords = coords * (batch_mask).unsqueeze(1)
1116
+ pos_ids += 1
1117
+ cache_pos = cache_pos + 1
1118
+ bbox = None
1119
+ if input_ids[0,1] == 7235:
1120
+ model_outputs = self.model(
1121
+ input_ids=None,
1122
+ pixel_values=None,
1123
+ tiling=None,
1124
+ attention_mask=mask,
1125
+ position_ids=pos_ids,
1126
+ past_key_values=past_key_values,
1127
+ inputs_embeds=next_embeds,
1128
+ labels=None,
1129
+ use_cache=use_cache,
1130
+ output_attentions=output_attentions,
1131
+ output_hidden_states=output_hidden_states,
1132
+ return_dict=True,
1133
+ cache_position=cache_pos,
1134
+ logits_to_keep=logits_to_keep,
1135
+ )
1136
+ hidden_states = model_outputs.last_hidden_state # [B, T, D]
1137
+ size_logits = self.model.region_decoder.decode_size(hidden_states)
1138
+ bins = torch.argmax(size_logits, dim=-1)
1139
+ w_bin = bins[:,0]
1140
+ h_bin = bins[:,1]
1141
+
1142
+ w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
1143
+ h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
1144
+
1145
+ next_embeds = (
1146
+ self.model.region_encoder.encode_size(
1147
+ torch.stack([w, h],dim=-1).to(size_logits.dtype)
1148
+ )
1149
+ ).unsqueeze(1)
1150
+ bbox = [
1151
+ x_center.item() - w.item() / 2,
1152
+ y_center.item() - h.item() / 2,
1153
+ x_center.item() + w.item() / 2,
1154
+ y_center.item() + h.item() / 2,
1155
+ ]
1156
+ bbox = bbox * (batch_mask).unsqueeze(1)
1157
+ pos_ids += 1
1158
+ cache_pos = cache_pos + 1
1159
+
1160
+ new = coords.unsqueeze(1) if bbox is None else bbox.unsqueeze(1)
1161
+ if self.objects is None:
1162
+ self.objects = new
1163
+ else:
1164
+ self.objects = torch.cat([self.objects, new], dim=1)
1165
+ model_outputs = self.model(
1166
+ input_ids=None,
1167
+ pixel_values=None,
1168
+ tiling=None,
1169
+ attention_mask=mask,
1170
+ position_ids=pos_ids,
1171
+ past_key_values=past_key_values,
1172
+ inputs_embeds=next_embeds,
1173
+ labels=None,
1174
+ use_cache=use_cache,
1175
+ output_attentions=output_attentions,
1176
+ output_hidden_states=output_hidden_states,
1177
+ return_dict=True,
1178
+ cache_position=cache_pos,
1179
+ logits_to_keep=logits_to_keep,
1180
+ )
1181
+ pos_ids += 1
1182
+ cache_pos = cache_pos + 1
1183
+ hidden_states = model_outputs.last_hidden_state # [B, T, D]
1184
+
1185
+ indices = torch.tensor(
1186
+ [self.config.text_config.coord_token_id, self.config.text_config.eos_token_id],
1187
+ device=self.device,
1188
+ )
1189
+
1190
+ hidden_states = self.model.text_model.norm(hidden_states)
1191
+ logits = hidden_states @ self.lm_head.weight[indices].T + self.lm_head.bias[indices]
1192
+
1193
+ logits_full = torch.full((logits.shape[0], logits.shape[1], self.config.text_config.vocab_size), float('-inf'), device=logits.device, dtype=logits.dtype)
1194
+ logits_full[:, :, torch.tensor([5,0])] = logits
1195
+ logits = logits_full
1196
+ pred[batch_mask] = torch.argmax(logits, dim=-1)[batch_mask]
1197
+
1198
+
1199
  loss = None
1200
  if labels is not None:
1201
  # Shift if your training uses standard LM convention; here we assume labels aligned with hs
 
1209
  attentions=getattr(model_outputs, "attentions", None),
1210
  )
1211
 
1212
+ def generate(self, **kwargs) -> Union[Moondream3GenerateOutput, torch.LongTensor]:
1213
+ outputs = super().generate(**kwargs)
1214
+ if len(self.objects) > 0:
1215
+ if isinstance(outputs, torch.Tensor):
1216
+ outputs = self.objects
1217
+ self.objects = []
1218
+ else:
1219
+ outputs = Moondream3GenerateOutput(
1220
+ **outputs,
1221
+ objects=self.objects
1222
+ )
1223
+ self.objects = []
1224
+ return outputs
1225
+
1226
+ def prepare_inputs_for_generation(
1227
+ self,
1228
+ input_ids,
1229
+ **model_kwargs
1230
+ ):
1231
+ model_inputs = super().prepare_inputs_for_generation(input_ids, **model_kwargs)
1232
+ model_inputs["position_ids"] += model_inputs["cache_position"].unsqueeze(0) - model_inputs["position_ids"]
1233
+ return model_inputs
1234
+
1235
+ def _update_model_kwargs_for_generation(
1236
+ self,
1237
+ outputs,
1238
+ model_kwargs,
1239
+ is_encoder_decoder,
1240
+ num_new_tokens: int = 1,
1241
+ ):
1242
+ model_kwargs = super()._update_model_kwargs_for_generation(
1243
+ outputs,
1244
+ model_kwargs,
1245
+ is_encoder_decoder=is_encoder_decoder,
1246
+ num_new_tokens=num_new_tokens,
1247
+ )
1248
+ model_kwargs["pixel_values"] = None
1249
+ model_kwargs["tiling"] = None
1250
+ return model_kwargs
1251
+
1252
+
1253
  @staticmethod
1254
  def _reorder_cache(past_key_values, beam_idx):
1255
  reordered_past = ()
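Note on the new detect/point generation loop above: when the model emits the coordinate token (coord_token_id = 5 from config.json), x/y centers are read off the `decode_coordinate` argmaxes, and for detection the width/height come from 1024 bins decoded with a log-scale mapping. A small worked sketch of that mapping, values shown only for illustration:

import torch

def size_bin_to_extent(bin_idx: torch.Tensor) -> torch.Tensor:
    # Same formula as the generation loop above: bins 0..1023 map
    # log-uniformly onto normalized extents in [2**-10, 1].
    return torch.pow(2.0, (bin_idx.float() / 1023.0) * 10.0 - 10.0)

print(size_bin_to_extent(torch.tensor(0)).item())     # ~0.00098 = 2**-10, smallest box edge
print(size_bin_to_extent(torch.tensor(1023)).item())  # 1.0, box spans the full image
# The box is then [x_c - w/2, y_c - h/2, x_c + w/2, y_c + h/2] in normalized coordinates.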
processing_moondream3.py CHANGED
@@ -211,40 +211,40 @@ class Moondream3Processor(ProcessorMixin):
 
         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
         text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
-        if "input_ids" in text_inputs:
-            # prepend 1 bos_token_id and 729 image_token_id to the text_inputs
-            for i in range(len(text_inputs["input_ids"])):
-                prepended_tokens = [self.tokenizer.bos_token_id] + [self.image_token_id] * 729
-                text_inputs["input_ids"][i] = prepended_tokens + text_inputs["input_ids"][i]
-        if "attention_mask" in text_inputs:
-            # attend to the 730 prepended tokens
-            for i in range(len(text_inputs["attention_mask"])):
-                prepended_mask = [1] * 730
-                text_inputs["attention_mask"][i] = prepended_mask + text_inputs["attention_mask"][i]
+        # if "input_ids" in text_inputs:
+        #     # prepend 1 bos_token_id and 729 image_token_id to the text_inputs
+        #     for i in range(len(text_inputs["input_ids"])):
+        #         prepended_tokens = [self.tokenizer.bos_token_id] + [self.image_token_id] * 729
+        #         text_inputs["input_ids"][i] = prepended_tokens + text_inputs["input_ids"][i]
+        # if "attention_mask" in text_inputs:
+        #     # attend to the 730 prepended tokens
+        #     for i in range(len(text_inputs["attention_mask"])):
+        #         prepended_mask = [1] * 730
+        #         text_inputs["attention_mask"][i] = prepended_mask + text_inputs["attention_mask"][i]
 
         return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
 
-    def apply_chat_template(
-        self,
-        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
-        chat_template: Optional[str] = None,
-        **kwargs,
-    ) -> str:
-        # Call the original behavior first
-        out = super().apply_chat_template(
-            conversation=conversation,
-            chat_template=chat_template,
-            **kwargs,
-        )
-
-        # Only post-process when:
-        # - user requested assistant mask
-        # - output is a dict (tokenized + return_dict=True path)
-        if isinstance(out, BatchFeature) and kwargs.get("return_assistant_tokens_mask", False):
-            if "assistant_masks" in out and out["assistant_masks"] is not None:
-                out["assistant_masks"] = _rotate_right_array(out["assistant_masks"], 730)
-
-        return out
+    # def apply_chat_template(
+    #     self,
+    #     conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+    #     chat_template: Optional[str] = None,
+    #     **kwargs,
+    # ) -> str:
+    #     # Call the original behavior first
+    #     out = super().apply_chat_template(
+    #         conversation=conversation,
+    #         chat_template=chat_template,
+    #         **kwargs,
+    #     )
+
+    #     # Only post-process when:
+    #     # - user requested assistant mask
+    #     # - output is a dict (tokenized + return_dict=True path)
+    #     if isinstance(out, BatchFeature) and kwargs.get("return_assistant_tokens_mask", False):
+    #         if "assistant_masks" in out and out["assistant_masks"] is not None:
+    #             out["assistant_masks"] = _rotate_right_array(out["assistant_masks"], 730)
+
+    #     return out
 
     @property
     def model_input_names(self):
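For context, the block commented out above is what the processor used to do: prepend one BOS token plus 729 image placeholder tokens, i.e. the 730-token prefix matching "prefix_attn": 730 in config.json; after this commit the prefix is accounted for inside the model (see the position_ids / cache_position offset by vision_config.prefix_len in modeling_moondream3.py). A minimal sketch of that removed behaviour; `image_token_id` and the prompt ids are illustrative placeholders:

# Sketch of the prefix the processor previously prepended.
bos_token_id = 0              # bos_token_id from config.json
image_token_id = 12345        # illustrative placeholder for the image token id
prompt_ids = [101, 102, 103]  # illustrative tokenized prompt

input_ids = [bos_token_id] + [image_token_id] * 729 + prompt_ids
attention_mask = [1] * 730 + [1] * len(prompt_ids)
assert len(input_ids) == 730 + len(prompt_ids)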