diff --git a/src/transformers/models/moshi/convert_moshi_transformers.py b/src/transformers/models/moshi/convert_moshi_transformers.py index be7a230d23d86a..57603d669e5a16 100644 --- a/src/transformers/models/moshi/convert_moshi_transformers.py +++ b/src/transformers/models/moshi/convert_moshi_transformers.py @@ -29,7 +29,7 @@ PreTrainedTokenizerFast, logging, ) -from transformers.convert_slow_tokenizer import Converter, SpmConverter, _get_prepend_scheme, import_protobuf +from transformers.convert_slow_tokenizer import Converter, SpmConverter, import_protobuf from transformers.utils import requires_backends @@ -85,7 +85,7 @@ def decoder(self, replacement, add_prefix_space): return decoders.Sequence(sequence) def pre_tokenizer(self, replacement, add_prefix_space): - prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer) + prepend_scheme = "first" return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False) @@ -120,7 +120,7 @@ def _grab_best_device(use_gpu=True): # TRANSFORMERS PART ("gating.linear_in", "mlp.fc1"), ("gating.linear_out", "mlp.fc2"), - ("self_attn.out_proj", "self_attn.o_proj"), + ("self_attn.out_proj", "self_attn.o_proj.linear"), ("norm1", "input_layernorm"), ("norm2", "post_attention_layernorm"), ("layer_scale_1", "self_attn_layer_scale"), @@ -202,8 +202,8 @@ def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): query_layer = mixed_qkv[:, :qkv_dim] key_layer = mixed_qkv[:, qkv_dim : qkv_dim * 2] value_layer = mixed_qkv[:, qkv_dim * 2 :] - state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = query_layer - state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = key_layer + state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = query_layer + state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = key_layer else: qkv_dim = mixed_qkv.size(0) // 3 @@ -211,14 +211,14 @@ def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size): query_layer = mixed_qkv[:qkv_dim] key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] value_layer = mixed_qkv[qkv_dim * 2 :] - state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute( + state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = permute( query_layer, num_heads, hidden_size, hidden_size ) - state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute( + state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = permute( key_layer, num_key_value_heads, key_value_head_dim, hidden_size ) - state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer + state_dict[new_k.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer elif "o_proj" in new_k and "depth_decoder" in new_k: output_layer = state_dict.pop(k) state_dict[new_k] = output_layer.view(config.num_codebooks, -1, output_layer.shape[-1]) @@ -331,8 +331,6 @@ def convert_checkpoint( if args.tokenizer_vocab_path: tokenizer = MoshiConverter(args.tokenizer_vocab_path).tokenizer - tokenizer.save_pretrained(args.pytorch_dump_folder_path) - if args.test_tokenizer: import sentencepiece import tqdm @@ -355,6 +353,8 @@ def convert_checkpoint( f"elements in original: {set(decoded2)-set(decoded1)} \n\n{string}" ) + tokenizer.save_pretrained(args.pytorch_dump_folder_path) + if args.push_to_hub: print("Pushing the tokenizer to the hub...") tokenizer.push_to_hub(args.push_to_hub) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 4be67a89a09c28..cddd0d4f558186 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -305,7 +305,25 @@ def forward(self, x, layer_idx=None): # Multiple layers case: # use einsum to batch the operations (batch_size, num_layers, input_size) -> (batch_size, num_layers, output_size) return torch.einsum("bnh,noh->bno", x, self.weight) - + + +class MoshiLinear(nn.Module): + def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False): + super().__init__() + + self.use_flexible_linear = use_flexible_linear + + if not use_flexible_linear: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = MoshiFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks) + + + def forward(self, x, layer_idx=None): + if self.use_flexible_linear: + return self.linear(x, layer_idx) + else: + return self.linear(x) # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi class MoshiRotaryEmbedding(nn.Module): @@ -374,12 +392,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class MoshiGatingMLP(nn.Module): - def __init__(self, config, num_layers=1, is_depth_mlp=False): + def __init__(self, config, use_flexible_linear=False): super().__init__() self.activation_fn = ACT2FN[config.hidden_act] - ffn_dim = config.ffn_dim if not is_depth_mlp else config.depth_ffn_dim - hidden_size = config.hidden_size if not is_depth_mlp else config.depth_hidden_size + ffn_dim = config.ffn_dim + hidden_size = config.hidden_size + num_layers = config.num_codebooks if use_flexible_linear else 1 if num_layers == 1: self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False) self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False) @@ -413,11 +432,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class MoshiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_depth_attention=False): + def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_flexible_linear=False, use_rope=True): super().__init__() self.config = config self.layer_idx = layer_idx - self.is_depth_attention = is_depth_attention if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " @@ -426,16 +444,12 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_dept ) self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size if not is_depth_attention else config.depth_hidden_size - self.num_heads = config.num_attention_heads if not is_depth_attention else config.depth_num_attention_heads - self.head_dim = config.head_dim if not is_depth_attention else config.depth_head_dim - self.num_key_value_heads = ( - config.num_key_value_heads if not is_depth_attention else config.depth_num_key_value_heads - ) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = ( - config.max_position_embeddings if not is_depth_attention else config.depth_max_position_embeddings - ) + self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.scaling = 1 / math.sqrt(self.head_dim) @@ -446,37 +460,22 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_dept f" and `num_heads`: {self.num_heads})." ) - if not is_depth_attention: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - else: - self.q_proj = MoshiFlexibleLinear( - self.hidden_size, self.num_heads * self.head_dim, num_layers=config.num_codebooks - ) - self.k_proj = MoshiFlexibleLinear( - self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks - ) - self.v_proj = MoshiFlexibleLinear( - self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks - ) - self.o_proj = MoshiFlexibleLinear( - self.num_heads * self.head_dim, self.hidden_size, num_layers=config.num_codebooks - ) + self.q_proj = MoshiLinear(self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear) + self.k_proj = MoshiLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear) + self.v_proj = MoshiLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear) + self.o_proj = MoshiLinear(self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear) # rotary embeddings are not used in the depth decoder - self.rotary_emb = ( - MoshiRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - if not is_depth_attention - else None - ) + self.rotary_emb = None + if use_rope: + self.rotary_emb = MoshiRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + - # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Moshi + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -489,15 +488,9 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = ( - self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) - ) # Ignore copy - key_states = ( - self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) - ) # Ignore copy - value_states = ( - self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) - ) # Ignore copy + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -539,9 +532,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = ( - self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) - ) # Ignore copy + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy if not output_attentions: attn_weights = None @@ -585,15 +576,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = ( - self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) - ) # Ignore copy - key_states = ( - self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) - ) # Ignore copy - value_states = ( - self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) - ) # Ignore copy + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -663,9 +648,7 @@ def forward( ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = ( - self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) - ) # Ignore copy + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy if not output_attentions: attn_weights = None @@ -711,15 +694,9 @@ def forward( bsz, q_len, _ = hidden_states.size() - query_states = ( - self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position) - ) # Ignore copy - key_states = ( - self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position) - ) # Ignore copy - value_states = ( - self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position) - ) # Ignore copy + query_states = self.q_proj(hidden_states, cache_position) # Ignore copy + key_states = self.k_proj(hidden_states, cache_position) # Ignore copy + value_states = self.v_proj(hidden_states, cache_position) # Ignore copy query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -768,9 +745,8 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) - attn_output = ( - self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position) - ) # Ignore copy + attn_output = self.o_proj(attn_output, cache_position) # Ignore copy + return attn_output, None, past_key_value @@ -783,21 +759,16 @@ def forward( class MoshiDecoderLayer(nn.Module): - def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool = False, is_depth_layer=False): + def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool = False, use_rope=True): super().__init__() - self.is_depth_layer = is_depth_layer - self.hidden_size = config.hidden_size if not is_depth_layer else config.depth_hidden_size + self.hidden_size = config.hidden_size self.use_flexible_linear = use_flexible_linear self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx, is_depth_attention=is_depth_layer + config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope ) - self.mlp = ( - MoshiGatingMLP(config) - if not use_flexible_linear - else MoshiGatingMLP(config, config.num_codebooks, is_depth_layer) - ) + self.mlp = MoshiGatingMLP(config, use_flexible_linear) self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -1081,7 +1052,7 @@ def __init__(self, config: MoshiConfig): self.layers = nn.ModuleList( [ - MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, is_depth_layer=True) + MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, use_rope=False) for layer_idx in range(config.depth_num_hidden_layers) ] )