import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

# transformers == 4.45.1
from .configuration_unirec import UniRecConfig
from transformers import M2M100PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100ScaledWordEmbedding, M2M100Decoder
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput, BaseModelOutput
from transformers.generation import GenerationMixin
from openrec.modeling.encoders.focalsvtr import FocalSVTR
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError('self.model.config.pad_token_id has to be defined.')
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
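
# Illustrative example (not part of the original module): with pad_token_id=1 and
# decoder_start_token_id=2,
#   shift_tokens_right(torch.tensor([[5, 6, -100]]), 1, 2) -> tensor([[2, 5, 6]])
# The sequence is shifted right, the start token is prepended, and any -100 label
# padding that survives the shift is replaced by pad_token_id.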


class UniRecEncoder(M2M100PreTrainedModel):
    """
    Vision encoder that replaces the text encoder of the original seq2seq model.
    Images are encoded with a [`FocalSVTR`] backbone and projected to
    `config.d_model` so the decoder can cross-attend to the visual tokens.

    Args:
        config: UniRecConfig
    """

    def __init__(self, config: UniRecConfig):
        super().__init__(config)
        self.config = config
        self.vision_encoder = FocalSVTR(img_size=[1408, 960],
                                        depths=[2, 2, 9, 2],
                                        embed_dim=96,
                                        sub_k=[[2, 2], [2, 2], [2, 2],
                                               [-1, -1]],
                                        focal_levels=[3, 3, 3, 3],
                                        max_khs=[7, 3, 3, 3],
                                        focal_windows=[3, 3, 3, 3],
                                        last_stage=False,
                                        feat2d=False)
        self.vision_fc = nn.Linear(config.dims[-1], config.d_model)

    def forward(
        self,
        pixel_values: torch.Tensor = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            pixel_values (`torch.FloatTensor`):
                Batch of input images fed to the FocalSVTR vision backbone.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of input sequence tokens in the vocabulary. Kept for interface compatibility; this encoder
                consumes images, so token ids are not used.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Extract FocalSVTR features and project them to the decoder's hidden size.
        hidden_states = self.vision_encoder(pixel_values)
        hidden_states = self.vision_fc(hidden_states)
        # hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions]
                if v is not None)
        return BaseModelOutput(last_hidden_state=hidden_states,
                               hidden_states=encoder_states,
                               attentions=all_attentions)
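
# Minimal encoder sketch (illustrative only; assumes a default-constructible config
# and 3-channel inputs at the resolution the FocalSVTR backbone above is configured
# for):
#
#   config = UniRecConfig()
#   encoder = UniRecEncoder(config)
#   images = torch.randn(1, 3, 1408, 960)
#   visual_tokens = encoder(pixel_values=images).last_hidden_state
#   # visual_tokens: (batch_size, num_visual_tokens, config.d_model)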


class UniRecModel(M2M100PreTrainedModel):
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0
        self.shared = M2M100ScaledWordEmbedding(vocab_size,
                                                config.d_model,
                                                padding_idx,
                                                embed_scale=embed_scale)

        self.encoder = UniRecEncoder(config)
        self.decoder = M2M100Decoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        pixel_values: torch.Tensor = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(pixel_values)
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1]
                if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2]
                if len(encoder_outputs) > 2 else None,
            )

        # The visual token sequence has no padding, so the decoder attends to every
        # encoder position.
        attention_mask = torch.ones(encoder_outputs[0].shape[:2],
                                    dtype=torch.long,
                                    device=encoder_outputs[0].device)

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
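
# Direct-call sketch (illustrative only; tensor contents are placeholders):
#
#   model = UniRecModel(config)
#   out = model(pixel_values=images,
#               decoder_input_ids=torch.tensor([[config.decoder_start_token_id]]))
#   out.last_hidden_state  # (batch_size, target_length, d_model); the LM head that
#                          # turns this into vocabulary logits lives in
#                          # UniRecForConditionalGenerationNew below.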


class UniRecForConditionalGenerationNew(M2M100PreTrainedModel, GenerationMixin):
    base_model_prefix = 'model'
    _tied_weights_keys = [
        'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight',
        'lm_head.weight'
    ]

    def __init__(self, config: UniRecConfig):
        super().__init__(config)
        self.model = UniRecModel(config)
        self.lm_head = nn.Linear(config.d_model,
                                 self.model.shared.num_embeddings,
                                 bias=False)
        self.loss_fct = CrossEntropyLoss(
            ignore_index=config.pad_token_id,
            label_smoothing=config.label_smoothing)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        length: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                # decoder_input_ids = shift_tokens_right(
                #     labels, self.config.pad_token_id, self.config.decoder_start_token_id
                # )
                if length is not None:
                    # Truncate to the longest target in the batch before teacher forcing.
                    max_len = length.max()
                    decoder_input_ids = labels[:, :1 + max_len]
                    labels = labels[:, 1:2 + max_len]
                else:
                    decoder_input_ids = labels[:, :-1]
                    labels = labels[:, 1:]

        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=None,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        # The loss is only computed in training mode when labels are provided.
        masked_lm_loss = None
        if self.training and labels is not None:
            masked_lm_loss = self.loss_fct(
                lm_logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1))

        if not return_dict:
            output = (lm_logits, ) + outputs[1:]
            return ((masked_lm_loss, ) +
                    output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
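
    # Worked example for the label slicing in `forward` above. The layout is an
    # assumption (labels start with the decoder start token, `length` holds the
    # per-sample target length excluding that start token):
    #
    #   labels            = [<s>, c1, c2, </s>, pad, pad],  length = 3, max_len = 3
    #   decoder_input_ids = labels[:, :1 + max_len]  -> [<s>, c1, c2, </s>]
    #   labels            = labels[:, 1:2 + max_len] -> [c1, c2, </s>, pad]
    #
    # i.e. standard teacher forcing, truncated to the longest target in the batch;
    # trailing pad positions are ignored by the loss via ignore_index=pad_token_id.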
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        pixel_values=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            'input_ids': None,  # encoder_outputs is defined. input_ids not needed
            'encoder_outputs': encoder_outputs,
            'past_key_values': past_key_values,
            'decoder_input_ids': decoder_input_ids,
            'attention_mask': attention_mask,
            'head_mask': head_mask,
            'decoder_head_mask': decoder_head_mask,
            'cross_attn_head_mask': cross_attn_head_mask,
            'use_cache': use_cache,  # change this to avoid caching (presumably for debugging)
            'pixel_values': pixel_values,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past), )
        return reordered_past
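
# End-to-end inference sketch (illustrative only: the default-constructed config,
# the input resolution, and the generation arguments are assumptions, not part of
# this module; weights would normally come from a trained checkpoint):
#
#   config = UniRecConfig()
#   model = UniRecForConditionalGenerationNew(config).eval()
#   images = torch.randn(1, 3, 1408, 960)  # preprocessed page image(s)
#   with torch.no_grad():
#       generated_ids = model.generate(pixel_values=images, max_new_tokens=128)
#   # generated_ids are then decoded with the tokenizer paired with this model.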