# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union

import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.processing_utils import Unpack
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, load_state_dict
from transformers.generation import GenerationMixin
from transformers.utils import logging, TransformersKwargs

from .moondream3_moe_fused.moe_fused_linear import MoeFusedLinear
from .moondream3_moe_fused.kernels.indexing import get_expert_counts_and_idx
from .configuration_moondream3 import (
    Moondream3Config,
    Moondream3TextConfig,
    Moondream3VisionConfig,
    Moondream3RegionConfig,
)
from . import modeling_moondream3

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Moondream3Config"


class Moondream3FusedSparseMoeBlock(nn.Module):
    """Sparse MoE block that routes each token to `num_experts_per_tok` experts and
    evaluates them with fused MoeFusedLinear projections."""

    def __init__(self, config: Moondream3TextConfig) -> None:
        super().__init__()
        self.num_experts = config.num_experts
        self.num_selected = config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size

        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.gate_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
        self.up_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
        self.down_proj = MoeFusedLinear(self.moe_intermediate_size, self.hidden_size, config.num_experts)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        M = batch_size * sequence_length
        hidden_states = hidden_states.view(M, hidden_dim)

        # router_logits: (M, num_experts)
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
        # routing_weights, selected_experts: (M, num_selected)
        routing_weights, selected_experts = torch.topk(routing_weights, self.num_selected, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        hidden_states = hidden_states.unsqueeze(1).expand(M, self.num_selected, hidden_dim)
        # hidden_states must be contiguous
        hidden_states = hidden_states.reshape(M * self.num_selected, hidden_dim)
        selected_experts = selected_experts.view(M * self.num_selected)

        # Sort selected_experts and hidden_states for better memory coalescence of the weights.
        # It's possible to fuse the sort and a MoeFusedLinear layer, but for now we keep them separate for clarity.
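        # Illustrative example (hypothetical numbers, not taken from the config): with
        # M = 4 tokens, num_selected = 2 and num_experts = 8, `selected_experts` might be
        # [3, 5, 0, 3, 7, 1, 5, 5] (two experts per token). Sorting by expert makes the
        # rows routed to the same expert contiguous; `m_sizes` then holds the per-expert
        # row counts (summing to M * num_selected = 8) and `inv_sort_idx` restores the
        # original token order after the expert MLPs.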
        m_sizes, sort_idx, inv_sort_idx = get_expert_counts_and_idx(selected_experts, self.num_experts)
        hidden_states = hidden_states[sort_idx]

        # It's possible to fuse gate_h and up_h, but this affects the shape of LoRA adapters
        gate_h = self.gate_proj(hidden_states, m_sizes)
        up_h = self.up_proj(hidden_states, m_sizes)
        hidden_states = F.gelu(up_h) * (gate_h + 1)
        del gate_h, up_h
        hidden_states = self.down_proj(hidden_states, m_sizes)

        hidden_states = hidden_states[inv_sort_idx]
        hidden_states = hidden_states.view(M, self.num_selected, hidden_dim)
        hidden_states = torch.einsum("beo,be->bo", hidden_states, routing_weights)
        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
        return hidden_states, router_logits


# Swap in the fused MoE block so that Moondream3TextModel instantiates
# Moondream3FusedSparseMoeBlock instead of the reference Moondream3SparseMoeBlock.
modeling_moondream3.Moondream3SparseMoeBlock = Moondream3FusedSparseMoeBlock

from .modeling_moondream3 import (
    Moondream3Config,
    Moondream3TextConfig,
    Moondream3VisionConfig,
    Moondream3RegionConfig,
    Moondream3PreTrainedModel,
    Moondream3Model,
    Moondream3TextModel,
    Moondream3VisionModel,
    Moondream3ForConditionalGeneration,
)


class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Moondream3Config):
        super().__init__(config)
        self.model = Moondream3Model(config)
        self.vocab_size = config.text_config.vocab_size
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.text_model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.text_model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.text_model = decoder

    def get_decoder(self):
        return self.model.text_model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        tiling: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: int = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # Get hidden states from the base model (it already builds the multimodal prefix)
        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            tiling=tiling,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
        )
        hidden_states = model_outputs.last_hidden_state  # [B, T, D]

        # Compute logits; only keep the tail if requested
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hs = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, slice):
            hs = hidden_states[:, logits_to_keep, :]
        else:
            hs = hidden_states
        logits = self.lm_head(hs)  # [B, T', V]

        loss = None
        if labels is not None:
            # Shift if your training uses the standard LM convention; here we assume labels are aligned with hs
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=getattr(model_outputs, "past_key_values", None),
            hidden_states=getattr(model_outputs, "hidden_states", None),
            attentions=getattr(model_outputs, "attentions", None),
        )
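
    # Checkpoints store one tensor per expert (`...mlp.experts.{i}.{gate,up,down}_proj.weight`),
    # while the fused MoE block expects a single stacked tensor per projection. Loading
    # therefore re-packs the per-expert weights here, and `_fix_state_dict_keys_on_save`
    # unpacks them back to the per-expert layout when saving.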
    @classmethod
    def _load_pretrained_model(
        cls,
        model: "PreTrainedModel",
        state_dict: Optional[dict],
        checkpoint_files: Optional[list[str]],
        pretrained_model_name_or_path,
        weights_only: bool = True,
        **kwargs,
    ):
        if checkpoint_files is not None:
            state_dict = {}
            for file in checkpoint_files:
                sd = load_state_dict(file, map_location="cpu", weights_only=weights_only)
                for key, value in sd.items():
                    state_dict[key] = value

        from collections import defaultdict

        # Discover which layers are MoE layers and how many experts each one has
        moe_layer_experts = defaultdict(set)
        for key in state_dict.keys():
            if key.startswith("model.text_model.layers."):
                parts = key.split(".")
                # Expected: model.text_model.layers.{layer}.mlp.experts.{expert_id}.down_proj.weight
                if len(parts) > 6 and parts[5] == "experts" and parts[3].isdigit() and parts[6].isdigit():
                    layer_idx = int(parts[3])
                    expert_idx = int(parts[6])
                    moe_layer_experts[layer_idx].add(expert_idx)
        moe_layers = {layer: len(experts) for layer, experts in moe_layer_experts.items()}

        # Stack the per-expert weights into one fused tensor per projection
        for layer_idx, num_experts in moe_layers.items():
            prefix = f"model.text_model.layers.{layer_idx}.mlp"
            for proj in ("down_proj", "up_proj", "gate_proj"):
                state_dict[f"{prefix}.{proj}.weight"] = torch.stack(
                    [state_dict.pop(f"{prefix}.experts.{i}.{proj}.weight") for i in range(num_experts)]
                )

        checkpoint_files = None
        model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs = super()._load_pretrained_model(
            model,
            state_dict,
            checkpoint_files,
            pretrained_model_name_or_path,
            **kwargs,
        )
        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs

    def _fix_state_dict_keys_on_save(self, state_dict: dict):
        # Unpack the fused tensors back into per-expert keys so saved checkpoints
        # keep the original layout
        for layer_idx in range(self.config.text_config.moe_start_layer, self.config.text_config.num_hidden_layers):
            layer_key = f"model.text_model.layers.{layer_idx}"
            for proj in ("down_proj", "up_proj", "gate_proj"):
                tensor = state_dict.pop(f"{layer_key}.mlp.{proj}.weight").cpu()
                for i, t in enumerate(torch.unbind(tensor)):
                    state_dict[f"{layer_key}.mlp.experts.{i}.{proj}.weight"] = t.contiguous()
        return state_dict
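
    # Legacy-cache hook for beam search: reorders tuple-format past key/values so
    # they follow the selected beam indices.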
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


__all__ = [
    "Moondream3Config",
    "Moondream3TextConfig",
    "Moondream3VisionConfig",
    "Moondream3RegionConfig",
    "Moondream3PreTrainedModel",
    "Moondream3Model",
    "Moondream3TextModel",
    "Moondream3VisionModel",
    "Moondream3ForConditionalGeneration",
]
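
# Minimal usage sketch (the checkpoint id below is a placeholder, not taken from this
# file; image preprocessing is handled by the repo's processor and is not shown):
#
#   config = Moondream3Config.from_pretrained("<moondream3-checkpoint>")
#   model = Moondream3ForConditionalGeneration.from_pretrained(
#       "<moondream3-checkpoint>", config=config, torch_dtype=torch.bfloat16
#   )
#   # `pixel_values` / `tiling` come from the accompanying image processor;
#   # `model.generate(...)` then works as for any causal LM head.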