From 50f9eb80ac4b0433c7fe46b4572fef2b34e9ebd0 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Fri, 20 Sep 2024 15:49:25 +0200 Subject: [PATCH] working generation --- .../models/moshi/configuration_moshi.py | 26 +- .../models/moshi/modeling_moshi.py | 971 ++++++++---------- 2 files changed, 430 insertions(+), 567 deletions(-) diff --git a/src/transformers/models/moshi/configuration_moshi.py b/src/transformers/models/moshi/configuration_moshi.py index 9eda765c834477..eff73613a82b80 100644 --- a/src/transformers/models/moshi/configuration_moshi.py +++ b/src/transformers/models/moshi/configuration_moshi.py @@ -97,7 +97,7 @@ class MoshiConfig(PretrainedConfig): Example: - ```python + ```python # TODO(YL): update >>> from transformers import ( ... MoshiConfig, ... EncodecConfig, @@ -189,21 +189,24 @@ def __init__(self, self.depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads self.depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads - super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - - if "audio_encoder" not in kwargs: + audio_encoder_config = kwargs.pop("audio_encoder", None) + if audio_encoder_config is None: raise ValueError("Config has to be initialized with audio_encoder config") - - audio_encoder_config = kwargs.pop("audio_encoder") + audio_encoder_model_type = audio_encoder_config.pop("model_type") self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config) + if self.num_codebooks > self.audio_encoder.num_codebooks: raise ValueError(f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder.num_codebooks}). Please lower it.") self.audio_vocab_size = self.audio_encoder.codebook_size if audio_vocab_size is None else audio_vocab_size + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) - + @property + def sampling_rate(self): + return self.audio_encoder.sampling_rate @classmethod def from_audio_encoder_config( @@ -213,17 +216,12 @@ def from_audio_encoder_config( ): r""" Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration. - + Returns: [`MoshiConfig`]: An instance of a configuration object """ - + return cls( audio_encoder=audio_encoder_config.to_dict(), **kwargs, ) - - @property - # This is a property because you might want to change the codec model on the fly - def sampling_rate(self): - return self.audio_encoder.sampling_rate diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 8afeecc4112a1d..e7efa12aab1729 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -91,7 +91,7 @@ @dataclass class MoshiCausalLMOutputWithPast(ModelOutput): """ - Base class for causal language model (or autoregressive) outputs. + `MoshiForCausalLM` outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): @@ -127,6 +127,60 @@ class MoshiCausalLMOutputWithPast(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor, ...]] = None +@dataclass +class MoshiConditionalGenerationOutputWithPast(ModelOutput): + """ + `MoshiForConditionalGeneration` outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `text_labels` is provided): + Text language modeling loss (for next-token prediction). 
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax).
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        depth_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `audio_labels` is provided):
+            Audio language modeling loss (for next-token prediction).
+        audio_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the audio language modeling heads.
+        depth_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Past key-values of the depth decoder.
+        depth_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Hidden states of the depth decoder
+        depth_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Depth decoder's Attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    depth_loss: Optional[torch.FloatTensor] = None
+    audio_logits: torch.FloatTensor = None
+    depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
     Shift input ids one token to the right.
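The `last_hidden_state` field above is what ties the two decoders together: each position's hidden state from the temporal transformer is handed to the depth decoder together with the text token and the first `num_codebooks - 1` audio codebooks of that same position. A minimal sketch of that per-position layout, mirroring the reshaping done later in `MoshiForConditionalGeneration.forward` in this patch (the standalone helper name and the example shapes are illustrative, not part of the patch):

```python
import torch

# Illustrative only: mirrors the reshaping in MoshiForConditionalGeneration.forward from this patch.
def build_depth_decoder_inputs(last_hidden_state, text_labels, audio_labels):
    # (batch_size, sequence_length, hidden_size) -> (batch_size * sequence_length, 1, hidden_size)
    last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1])
    # (batch_size, sequence_length) -> (batch_size * sequence_length, 1)
    text_labels = text_labels.view(-1, 1)
    # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks)
    audio_labels = audio_labels.transpose(1, 2).reshape(-1, audio_labels.shape[1])
    # per position: [text token, codebook 0, ..., codebook N-2]; the last codebook is only a target
    depth_input_ids = torch.cat([text_labels, audio_labels], dim=1)[:, :-1]
    return last_hidden_state, depth_input_ids

# Example values: batch_size=2, sequence_length=3, hidden_size=4, num_codebooks=8.
hidden, depth_ids = build_depth_decoder_inputs(
    torch.randn(2, 3, 4), torch.randint(0, 10, (2, 3)), torch.randint(0, 10, (2, 8, 3))
)
print(hidden.shape, depth_ids.shape)  # torch.Size([6, 1, 4]) torch.Size([6, 8])
```

Flattening batch and time into one axis lets the depth decoder treat every (text token, codebooks) group as an independent sequence of length `num_codebooks`, independently of the batch size and sequence length of the temporal decoder.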
@@ -248,10 +302,7 @@ def forward(self, x, layer_idx=None): If it's a tensor of shape `(seq_length,)`, will matmul each element of the sequence to the corresponding weights. But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`. """ - if layer_idx is not None and layer_idx.dim()==0: - # Single layer case: select a specific layer (batch_size, 1 , input_size) -> (batch_size, 1, output_size) - return torch.matmul(x, self.weight[layer_idx].T) - elif layer_idx is not None: + if layer_idx is not None: # Use torch.gather to select the corresponding weights for each sample selected_weights = torch.index_select(self.weight, 0, layer_idx) return torch.einsum('bnh,noh->bno', x, selected_weights) @@ -427,9 +478,9 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, position_ids) # Ignore copy - key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, position_ids) # Ignore copy - value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, position_ids) # Ignore copy + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -441,7 +492,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -467,7 +518,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, position_ids) # Ignore copy + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy if not output_attentions: attn_weights = None @@ -511,9 +562,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy # 
Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -529,7 +580,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -580,7 +631,7 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy if not output_attentions: attn_weights = None @@ -626,9 +677,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -641,7 +692,7 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} # Ignore copy key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -674,7 +725,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) # Ignore copy return attn_output, None, past_key_value @@ -926,7 +977,6 @@ def _init_weights(self, module): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ -# TODO: DO it as a depth decoder class MoshiDepthDecoder(MoshiPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MoshiTransformerLayer`] @@ -947,21 +997,11 @@ def __init__(self, config: MoshiConfig): self.input_projections = MoshiFlexibleLinear(config.hidden_size, config.depth_hidden_size, config.num_codebooks) - # TODO: remove if relevant - # nn.ModuleList( - # [nn.Linear(config.hidden_size, config.depth_hidden_size, bias=False) for _ in range(config.num_codebooks)] - # ) - self.layers = nn.ModuleList( [MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, is_depth_layer=True) for layer_idx in range(config.depth_num_hidden_layers)] ) self.lm_heads = MoshiFlexibleLinear(config.depth_hidden_size, config.audio_vocab_size, config.num_codebooks) - - # TODO: remove if relevant - # nn.ModuleList( - # [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)] - # ) self._attn_implementation = config._attn_implementation @@ -971,10 +1011,8 @@ def __init__(self, config: MoshiConfig): def forward( # TODO: update docstrings entirely self, input_ids: Optional[torch.LongTensor] = None, # sequence of oracle input ids, i.e it must be the input ids that are predicted by the decoder # (B, S) - audio_codes: Optional[torch.Tensor] = None, # Same, shoud be oracle audio codebooks, but also with one channel less: # (B, C, S) or C-1 - last_hidden_states: torch.LongTensor = None, # use 8 times (B, S, H_in) | (B*S, H_in) + last_hidden_state: torch.LongTensor = None, # shape: (B*S, 1, hidden_dim) # use 8 times (B, S, H_in) | (B*S, H_in) attention_mask: Optional[torch.BoolTensor] = None, - padding_mask: Optional[torch.BoolTensor] = None, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -983,7 +1021,7 @@ def forward( # TODO: update docstrings entirely return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, - + cache_position: Optional[torch.LongTensor] = None, # TODO: add to docstrings ) -> Union[Tuple, BaseModelOutputWithPast]: """ Args: @@ -1049,19 +1087,23 @@ def forward( # TODO: update docstrings entirely """ # here, we actually predict a sequence length of C # independtly from the batch size and sequence length - # 1/ input ids is passed through text_embed_tokens -> (B, S, H) H=1024 - # 2/ each codebooks is passed through the embedding layer ase well -> (B, C-1, S, H) - # 3/ concat the two precedent results and get (B, C, S ,H) + # 1/ input ids is passed through text_embed_tokens -> (B * S, H) H=1024 + # 2/ each codebooks is passed through the embedding layer ase well -> (B*S, C-1, H) + # 3/ concat the two precedent results and get (B*S, C, ,H) # 4/ then we also pass the last hidden states through the input projection layers, one for each codebooks - # we get (B, C, S, H) - # 5/ sum one and another (B, C, S, H) + # we get (B*S, C, H) + # 5/ sum one and another (B*S, C, H) # 6/ pass (B*S, C, H) through the model and get (B*S, C, H_out) # 7/ for each codebook, pass it through its own lm heads: (B*S, C, H) # 8/ predict the codebook C1, C2 ... 
-> (B, S, C, H) + # generation: + # we start with last hidden states and text tokens + # depending on position ids chose which embedding layer + # TODO: can we suppose B*S each time instead of B,S # in the generation mode, it's different: - # S=1 + # text_token (B*S, ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1079,44 +1121,40 @@ def forward( # TODO: update docstrings entirely if use_cache and past_key_values is None and not self.training: past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used + # If inputs_embeds is provided, it has the priority over input_ids, which won't be used if inputs_embeds is None: - if input_ids is None and audio_codes is None: - raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") - - if input_ids is not None: - inputs_embeds = self.text_embed_tokens(input_ids) - - # TODO: this should actually use embed_tokens depending on which position ids is asked for - # We should actually use one codebook embedding per element of the sequence - if audio_codes is not None: # TODO(YL): make sure it's C-1 - audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) - inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds - - # TODO: check position ids shape - inputs_embeds += self.input_projections(last_hidden_states, position_ids) + inputs_embeds = [] + for position_idx in cache_position: + if position_idx == 0: + inputs_embeds.append(self.text_embed_tokens(input_ids[:, [position_idx]])) + else: + inputs_embeds.append(self.embed_tokens[(position_idx-1)](input_ids[:, [position_idx - past_seen_tokens]])) + + inputs_embeds = torch.cat(inputs_embeds, dim=1) + + inputs_embeds += self.input_projections(last_hidden_state, cache_position) causal_mask = None if attention_mask is not None: causal_mask = self._update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None - + hidden_states = inputs_embeds for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -1157,26 +1195,13 @@ def forward( # TODO: update docstrings entirely next_cache = next_decoder_cache if use_cache else None - - # TODO: check position ids shape # TODO: remove the float() operation in v4.46 - logits = self.lm_heads(hidden_states, position_ids).float() + logits = self.lm_heads(hidden_states, cache_position).float() loss = None if labels is not None: - 
# TODO: it's probably not the right way to do it - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = 0 + # TODO(YL) if not return_dict: return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1189,6 +1214,77 @@ def forward( # TODO: update docstrings entirely attentions=all_self_attns, ) + # Copied from transformers.models.gemma.modeling_gemma.GemmaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + # TODO: use this to make sure max_tokens = num_codebooks + super()._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1201,11 +1297,6 @@ def prepare_inputs_for_generation( num_logits_to_keep=None, **kwargs, ): - # TODO(YL): on the first step, `input_ids` is used - # then every new input_ids are passed as `audio_codes` instead! - # Make sure cache_positions is correct - # do we use num_logits_to_keep? - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here @@ -1263,6 +1354,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, + "last_hidden_state": kwargs.get("last_hidden_state") # Ignore copy } ) return model_inputs @@ -1316,7 +1408,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, # TODO(YL): add to docstrings ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1718,82 +1810,25 @@ def prepare_inputs_for_generation( "for speech-to-speech.", MOSHI_START_DOCSTRING, ) -class MoshiForConditionalGeneration(MoshiPreTrainedModel): # TODO(YL): don't think I can initialize like this for a composite model +class MoshiForConditionalGeneration(MoshiPreTrainedModel): config_class = MoshiConfig main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - def __init__( - self, - config: Optional[MoshiConfig] = None, - audio_encoder: Optional[PreTrainedModel] = None, - decoder: Optional[MoshiForCausalLM] = None, - depth_decoder: Optional[MoshiDepthDecoder] = None, - ): - if config is None and (audio_encoder is None or decoder is None or depth_decoder is None): - raise ValueError( - "Either a configuration has to be provided, or all three of Moshi depth decoder, audio encoder and Moshi decoder." - ) - if config is None: - config = MoshiConfig.from_audio_encoder_config(audio_encoder.config) - else: - if not isinstance(config, self.config_class): - raise ValueError(f"Config: {config} has to be of type {self.config_class}") - - # TODO: verify decoder and depth decoder not incompatible - # TODO: does the decoder and depth decoder makes sense? 
- - # initialize with config + def __init__(self, config: MoshiConfig): super().__init__(config) - - if audio_encoder is None: - from ..auto.modeling_auto import AutoModel - - audio_encoder = AutoModel.from_config(config.audio_encoder) - - if decoder is None: - decoder = MoshiForCausalLM(config) - - if depth_decoder is None: - depth_decoder = MoshiDepthDecoder(config) - - self.depth_decoder = depth_decoder - self.decoder = decoder - self.audio_encoder = audio_encoder - # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel. self.embed_tokens = nn.ModuleList( [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)] ) - - if self.decoder.config.to_dict() != self.config.to_dict(): - logger.warning( - f"Config of the decoder: {self.decoder.__class__} is overwritten by shared config:" - f" {self.config}" - ) - if self.depth_decoder.config.to_dict() != self.config.to_dict(): - logger.warning( - f"Config of the depth decoder: {depth_decoder.decoder.__class__} is overwritten by shared config:" - f" {self.config}" - ) - if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict(): - logger.warning( - f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:" - f" {self.config.audio_encoder}" - ) - - # make sure that the individual model's config refers to the shared config - # so that the updates to the config will be synced - self.audio_encoder.config = self.config.audio_encoder - self.decoder.config = self.config - self.depth_decoder.config = self.config + self.audio_encoder = AutoModel.from_config(config.audio_encoder, attn_implementation=config._attn_implementation) + self.decoder = MoshiForCausalLM(config) + self.depth_decoder = MoshiDepthDecoder(config) self.num_codebooks = config.num_codebooks - - # tie text encoder, decoder weights if config set accordingly - self.tie_weights() + self.post_init() def get_audio_encoder(self): return self.audio_encoder @@ -1813,40 +1848,20 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Example: - - ```python - >>> from transformers import MoshiForConditionalGeneration - - >>> model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko") - ```""" - - # At the moment fast initialization is not supported for composite models - if kwargs.get("_fast_init", False): - logger.warning( - "Fast initialization is currently not supported for MoshiForConditionalGeneration. " - "Falling back to slow initialization..." 
- ) - kwargs["_fast_init"] = False - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, - input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it - padding_mask: Optional[torch.BoolTensor] = None, - audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings - and decide if it's 16 codebooks or (8 and another audio_values) + user_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + user_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, text_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings - audio_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings + audio_labels: Optional[torch.LongTensor] = None, #TODO: update do docstrings - must be 16 channels (first user than moshi?) use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1892,13 +1907,13 @@ def forward( argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } - # TODO: encode input_values - # TODO: how to deal with both streams, we actually two input_values stream - if input_values is not None and audio_codes is None: - # TODO: should be 16 codebooks - audio_codes = self.audio_encoder.encode(input_values, padding_mask, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] - + kwargs_depth_decoder = { + argument[len("depth_decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("depth_decoder_") + } + # TODO: we need to have same number of timestamps, and same number of batch + + if (text_labels is not None) and (input_ids is None and inputs_embeds is None): input_ids = shift_tokens_right( text_labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id @@ -1907,6 +1922,15 @@ def forward( # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used if inputs_embeds is None: + if user_input_values is not None and user_audio_codes is None: + user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + if moshi_input_values is not None and moshi_audio_codes is None: + moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder)[0] + + # TODO: make sure it's the right order (user than moshi) + make sure it's done over the right dim + audio_codes = torch.cat([user_audio_codes, moshi_audio_codes], dim=1) + if input_ids is None and audio_codes is None: raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") @@ -1930,60 +1954,80 @@ def forward( **kwargs_decoder, ) - # TODO: how to deal with loss here ? 
maybe we can do one loss for text
-        # and one loss for audio_labels?
-        decoder_last_hidden_states = decoder_outputs.last_hidden_state
-        # TODO: we want to pass the audio_codes and audio_labels from the model inputs
+        decoder_last_hidden_state = decoder_outputs.last_hidden_state

         depth_decoder_outputs = None
         if text_labels is not None and audio_labels is not None:
-            # TODO: how to deal with padding mask and attention mask ?
-            # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids
+
+            # (batch_size, sequence_length) -> (batch_size * sequence_length, 1)
+            text_labels = text_labels.view(-1, 1)
+            # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks)
+            audio_labels = audio_labels.transpose(1,2).reshape(-1, audio_labels.shape[1])
+
+            depth_input_ids = torch.cat([text_labels, audio_labels], dim=1)
+            # keep the last codebook out of input_ids
+            depth_input_ids = depth_input_ids[:, :-1]
+
+            # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
+            decoder_last_hidden_state = decoder_last_hidden_state.view(-1, 1, decoder_last_hidden_state.shape[-1])
+
             depth_decoder_outputs = self.depth_decoder(
-                hidden_states=decoder_last_hidden_states,
-                input_ids=text_labels, # probably need to reshape to (B*S)
-                audio_codes=audio_labels, # probably need to reshape to (B*S)
+                last_hidden_state=decoder_last_hidden_state,
+                input_ids=depth_input_ids,
                 attention_mask=attention_mask,
-                padding_mask=padding_mask,
             )
-
         if not return_dict:
             outputs = decoder_outputs.to_tuple()
             if depth_decoder_outputs is not None:
                 outputs += depth_decoder_outputs.to_tuple()
-            return outputs# TODO + encoder_outputs
+            return outputs

-        # TODO: change
-        return Seq2SeqLMOutput(
-            loss=decoder_outputs.loss, # TODO: it's the text loss
+        return MoshiConditionalGenerationOutputWithPast(
+            loss=decoder_outputs.loss,
             logits=decoder_outputs.logits,
+            last_hidden_state=decoder_last_hidden_state,
             past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
+            hidden_states=decoder_outputs.hidden_states,
+            attentions=decoder_outputs.attentions,
+            depth_loss=None if depth_decoder_outputs is None else depth_decoder_outputs.loss,
+            audio_logits=None if depth_decoder_outputs is None else depth_decoder_outputs.logits,
+            depth_past_key_values=None if depth_decoder_outputs is None else depth_decoder_outputs.past_key_values,
+            depth_hidden_states=None if depth_decoder_outputs is None else depth_decoder_outputs.hidden_states,
+            depth_attentions=None if depth_decoder_outputs is None else depth_decoder_outputs.attentions,
         )

     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        input_values: Optional[torch.FloatTensor] = None,
-        audio_codes: Optional[torch.Tensor] = None,
+        user_input_values: Optional[torch.FloatTensor] = None,
+        user_audio_codes: Optional[torch.Tensor] = None,
+        moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it
+        moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
         # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used
         if inputs_embeds is None:
-            # TODO: here we have to decide how to deal with audio codes from the user
-            # also have to decide how to deal with number of channels
-
-            if
input_values is not None and audio_codes is None: - # TODO: should be 16 codebooks - audio_codes = self.audio_encoder.encode(input_values, num_quantizers=self.num_codebooks,)[0] + if input_ids is None and user_input_values is None and user_audio_codes is None and moshi_input_values is None and moshi_audio_codes is None: + raise ValueError("You must provide at least one of `input_ids`, `user_input_values`, `moshi_input_values`, `user_audio_codes` or `moshi_audio_codes`.") + # TODO: make sure batch size and sequence length is concording - if input_ids is None and audio_codes is None: - raise ValueError("You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`.") + if user_input_values is not None and user_audio_codes is None: + user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks)[0] + + if moshi_input_values is not None and moshi_audio_codes is None: + moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks)[0] + + audio_codes = None + if user_audio_codes is not None and moshi_audio_codes is not None: + # TODO: make sure it's the right order (user than moshi) + make sure it's done over the right dim + audio_codes = torch.cat([user_audio_codes, moshi_audio_codes], dim=1) + elif user_audio_codes is not None: + audio_codes = user_audio_codes + elif moshi_audio_codes is not None: + audio_codes = moshi_audio_codes if input_ids is not None: inputs_embeds = self.decoder.model.embed_tokens(input_ids) @@ -1992,103 +2036,194 @@ def _prepare_inputs_embeds_for_generation( audio_inputs_embeds = sum([self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]) inputs_embeds = audio_inputs_embeds if inputs_embeds is None else audio_inputs_embeds + inputs_embeds - return inputs_embeds - + return inputs_embeds, moshi_audio_codes - def prepare_inputs_for_generation( + @torch.no_grad() + def generate( self, - input_ids, - past_key_values=None, - attention_mask=None, - use_cache=None, - decoder_delay_pattern_mask=None, - guidance_scale=None, + input_ids: Optional[torch.LongTensor] = None, + user_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + user_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + moshi_input_values: Optional[torch.FloatTensor] = None, # audio_codes has priority over input_values - precise it + moshi_audio_codes: Optional[torch.Tensor] = None, # TODO add to docstrings + inputs_embeds: Optional[torch.FloatTensor] = None, **kwargs, - ): - if decoder_delay_pattern_mask is None: - input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( - input_ids, - self.generation_config.pad_token_id, - max_length=self.generation_config.max_length, - ) + ) -> torch.LongTensor: + """ + # TODO: modify + Generates sequences of token ids for models with a language modeling head. - # apply the delay pattern mask - input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask) + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. + kwargs (`Dict[str, Any]`, *optional*): + Remaining dictionary of keyword arguments that are passed to the `generate` method. 
Refers to the + original [`generate` docstrings](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) + for more information on how to use them. + Note that keywords with a *depth_* prefix will be input for the `generate` method of the + depth decoder. Otherwise, the latter will use its default generation config. + + Return: # TODO + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. - if guidance_scale is not None and guidance_scale > 1: - # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these - # before sampling) - input_ids = input_ids.repeat((2, 1)) - if attention_mask is not None: - attention_mask = attention_mask.repeat((2, 1)) + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: - input_ids = input_ids[:, remove_prefix_length:] + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + inputs_embeds, moshi_audio_codes = self._prepare_inputs_embeds_for_generation( + input_ids=input_ids, + user_input_values=user_input_values, + user_audio_codes=user_audio_codes, + moshi_input_values=moshi_input_values, + moshi_audio_codes=moshi_audio_codes, + inputs_embeds=inputs_embeds, + ) + + self.generated_audio_codes = moshi_audio_codes + + outputs = super().generate(inputs_embeds=inputs_embeds, **kwargs) - return { - "input_ids": None, # TODO encoder_outputs is defined. input_ids not needed - "past_key_values": past_key_values, - "input_ids": input_ids, - "attention_mask": attention_mask, - "use_cache": use_cache, - } + # check if outputs is a dict or a Tensor (depending on unaccessed `generation_config.return_dict_in_generate`) + if isinstance(outputs, torch.Tensor): + output_text_ids = outputs + else: + output_text_ids = outputs.sequences + + output_audio_codes = self.generated_audio_codes + + + output_values = self.audio_encoder.decode( + output_audio_codes, + ).audio_values + + + return output_text_ids, output_values - def _prepare_audio_encoder_kwargs_for_generation( - self, input_values, model_kwargs, model_input_name: Optional[str] = None + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, ): - # 1. get audio encoder - encoder = self.get_audio_encoder() - # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device - # as the inputs. - if hasattr(encoder, "_hf_hook"): - encoder._hf_hook.io_same_device = True - - # 2. Prepare encoder args and encoder kwargs from model kwargs. 
- irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) - } - encoder_signature = set(inspect.signature(encoder.forward).parameters) - encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature - if not encoder_accepts_wildcard: - encoder_kwargs = { - argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature - } + # 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds - # 3. make sure that encoder returns `ModelOutput` - model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name - encoder_kwargs["return_dict"] = True + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) - encoder_kwargs[model_input_name] = input_values - audio_encoder_outputs = encoder.encode(**encoder_kwargs) - audio_codes = audio_encoder_outputs.audio_codes - audio_scales = audio_encoder_outputs.audio_scales + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - frames, bsz, codebooks, seq_len = audio_codes.shape + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min - if frames != 1: - raise ValueError( - f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is " - "disabled by setting `chunk_length=None` in the audio encoder." 
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, ) - input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len) - - model_kwargs["input_ids"] = input_ids - model_kwargs["audio_scales"] = audio_scales + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + + # 2. Now that everything is prepared, generate audio_codes using the depth decoder + + # we want to do it after a first token has been generated + if model_inputs["input_ids"] is not None: + last_hidden_state = kwargs.get("last_hidden_state") + # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim) + last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1]) + + input_ids = model_inputs.pop("input_ids") + + # TODO: allow passing generation kwargs + generated_audio_codes = self.depth_decoder.generate( + last_hidden_state=last_hidden_state, + input_ids=input_ids.view(-1, 1), + min_length=self.num_codebooks + 1,# TODO: change + max_length=self.num_codebooks + 1,# TODO: change + ) + + # the first tokens are text tokens + generated_audio_codes = generated_audio_codes[:, 1:].unsqueeze(2) + + self.generated_audio_codes = torch.cat([self.generated_audio_codes, generated_audio_codes], dim=2) + + # TODO: for now, we don't use blank user input ids !! + inputs_embeds, _ = self._prepare_inputs_embeds_for_generation(input_ids, moshi_audio_codes=generated_audio_codes) + + model_inputs["input_ids"] = None + model_inputs["inputs_embeds"] = inputs_embeds + + return model_inputs + + def _update_model_kwargs_for_generation(self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens) + + # update last_hidden_state that'll be used in the depth decoder + model_kwargs["last_hidden_state"] = outputs.get("last_hidden_state") return model_kwargs def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): @@ -2115,274 +2250,4 @@ def freeze_depth_decoder(self): """ for param in self.depth_decoder.parameters(): param.requires_grad = False - self.depth_decoder._requires_grad = False - - def _maybe_initialize_input_ids_for_generation( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.LongTensor: - """Initializes input ids for generation, if necessary.""" - if inputs is not None: - return inputs - - if bos_token_id is None: - raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") - - # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with - # soft-prompting or in multimodal implementations built on top of decoder-only language models. 
- batch_size = 1 - for value in model_kwargs.values(): - if isinstance(value, torch.Tensor): - batch_size = value.shape[0] - break - return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id - - def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None - ) -> int: - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - synced_gpus: Optional[bool] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ): - """ - - Generates sequences of token ids for models with a language modeling head. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](./generation_strategies). - - - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. 
- synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - kwargs (`Dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateDecoderOnlyOutput`], - - [`~generation.GenerateBeamDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateEncoderDecoderOutput`], - - [`~generation.GenerateBeamEncoderDecoderOutput`] - """ - # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects - if generation_config is None: - generation_config = self.generation_config - - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs - generation_config.validate() - self._validate_model_kwargs(model_kwargs.copy()) - - - # 2. Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - requires_attention_mask = False # TODO "encoder_outputs" not in model_kwargs - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - - # 3. Define model inputs - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - batch_size = inputs_tensor.shape[0] - self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) - - # 4. Define other model kwargs - model_kwargs["use_cache"] = generation_config.use_cache - model_kwargs["guidance_scale"] = generation_config.guidance_scale - - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor - ) - - - if "input_ids" not in model_kwargs and "input_values" in model_kwargs: - model_kwargs = self._prepare_audio_encoder_kwargs_for_generation( - model_kwargs["input_values"], - model_kwargs, - ) - - # 5. 
Prepare `input_ids` which will be used for auto-regressive generation - input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( - batch_size=batch_size, - model_input_name=model_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=generation_config._decoder_start_token_tensor, - bos_token_id=generation_config._bos_token_tensor, - device=inputs_tensor.device, - ) - - # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_length = input_ids.shape[-1] - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None - generation_config = self._prepare_generated_length( - generation_config=generation_config, - has_default_max_length=has_default_max_length, - has_default_min_length=has_default_min_length, - model_input_name=model_input_name, - inputs_tensor=inputs_tensor, - input_ids_length=input_ids_length, - ) - - # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Moshi) - input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( - input_ids, - pad_token_id=generation_config._decoder_start_token_tensor, - max_length=generation_config.max_length, - ) - # stash the delay mask so that we don't have to recompute in each forward pass - model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask - - # input_ids are ready to be placed on the streamer (if used) - if streamer is not None: - streamer.put(input_ids.cpu()) - - # 7. determine generation mode - generation_mode = generation_config.get_generation_mode() - - # 8. prepare batched CFG externally (to enable coexistance with the unbatched CFG) - if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1: - logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) - generation_config.guidance_scale = None - - # 9. prepare distribution pre_processing samplers - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_length, - encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=None, - logits_processor=logits_processor, - device=input_ids.device, - ) - - # 10. prepare stopping criteria - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - - if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # expand input_ids with `num_return_sequences` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 11. run sample - outputs = self._sample( - input_ids, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - generation_config=generation_config, - synced_gpus=synced_gpus, - streamer=streamer, - **model_kwargs, - ) - - else: - raise ValueError( - "Got incompatible mode for generation, should be one of greedy or sampling. " - "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." 
- ) - - if generation_config.return_dict_in_generate: - output_ids = outputs.sequences - else: - output_ids = outputs - - # apply the pattern mask to the final ids - output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) - - # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( - batch_size, self.decoder.num_codebooks, -1 - ) - - # append the frame dimension back to the audio codes - output_ids = output_ids[None, ...] - - audio_scales = model_kwargs.get("audio_scales") - if audio_scales is None: - audio_scales = [None] * batch_size - - output_values = self.audio_encoder.decode( - output_ids, - audio_scales=audio_scales, - ).audio_values - - - if generation_config.return_dict_in_generate: - outputs.sequences = output_values - return outputs - else: - return output_values \ No newline at end of file + self.depth_decoder._requires_grad = False \ No newline at end of file