Commit

propose better readability and code organisation
ylacombe committed Sep 27, 2024
1 parent c2b1e60 commit 2499419
Showing 2 changed files with 70 additions and 99 deletions.
20 changes: 10 additions & 10 deletions src/transformers/models/moshi/convert_moshi_transformers.py
@@ -29,7 +29,7 @@
PreTrainedTokenizerFast,
logging,
)
-from transformers.convert_slow_tokenizer import Converter, SpmConverter, _get_prepend_scheme, import_protobuf
+from transformers.convert_slow_tokenizer import Converter, SpmConverter, import_protobuf
from transformers.utils import requires_backends


@@ -85,7 +85,7 @@ def decoder(self, replacement, add_prefix_space):
return decoders.Sequence(sequence)

def pre_tokenizer(self, replacement, add_prefix_space):
-prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+prepend_scheme = "first"
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)


@@ -120,7 +120,7 @@ def _grab_best_device(use_gpu=True):
# TRANSFORMERS PART
("gating.linear_in", "mlp.fc1"),
("gating.linear_out", "mlp.fc2"),
("self_attn.out_proj", "self_attn.o_proj"),
("self_attn.out_proj", "self_attn.o_proj.linear"),
("norm1", "input_layernorm"),
("norm2", "post_attention_layernorm"),
("layer_scale_1", "self_attn_layer_scale"),
@@ -202,23 +202,23 @@ def permute(w, n_heads, dim1=hidden_size, dim2=hidden_size):
query_layer = mixed_qkv[:, :qkv_dim]
key_layer = mixed_qkv[:, qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[:, qkv_dim * 2 :]
state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = query_layer
state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = key_layer
state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = query_layer
state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = key_layer

else:
qkv_dim = mixed_qkv.size(0) // 3

query_layer = mixed_qkv[:qkv_dim]
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
value_layer = mixed_qkv[qkv_dim * 2 :]
state_dict[new_k.replace("in_proj_weight", "q_proj.weight")] = permute(
state_dict[new_k.replace("in_proj_weight", "q_proj.linear.weight")] = permute(
query_layer, num_heads, hidden_size, hidden_size
)
state_dict[new_k.replace("in_proj_weight", "k_proj.weight")] = permute(
state_dict[new_k.replace("in_proj_weight", "k_proj.linear.weight")] = permute(
key_layer, num_key_value_heads, key_value_head_dim, hidden_size
)

state_dict[new_k.replace("in_proj_weight", "v_proj.weight")] = value_layer
state_dict[new_k.replace("in_proj_weight", "v_proj.linear.weight")] = value_layer
elif "o_proj" in new_k and "depth_decoder" in new_k:
output_layer = state_dict.pop(k)
state_dict[new_k] = output_layer.view(config.num_codebooks, -1, output_layer.shape[-1])
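The extra ".linear" segment in the updated q/k/v/o keys matches the MoshiLinear wrapper introduced in modeling_moshi.py below: the wrapped projection is registered as an attribute named linear, so its parameters sit one level deeper in the state dict. A minimal sketch of that naming effect (illustrative only, not part of the diff):

# --- Illustrative example, not part of the diff; Wrapper is a hypothetical stand-in for MoshiLinear. ---
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8, bias=False)  # submodule registered under the name "linear"

print(list(nn.Linear(8, 8, bias=False).state_dict().keys()))  # ['weight']
print(list(Wrapper().state_dict().keys()))                    # ['linear.weight']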
@@ -331,8 +331,6 @@ def convert_checkpoint(
if args.tokenizer_vocab_path:
tokenizer = MoshiConverter(args.tokenizer_vocab_path).tokenizer

-tokenizer.save_pretrained(args.pytorch_dump_folder_path)

if args.test_tokenizer:
import sentencepiece
import tqdm
@@ -355,6 +353,8 @@ def convert_checkpoint(
f"elements in original: {set(decoded2)-set(decoded1)} \n\n{string}"
)

+tokenizer.save_pretrained(args.pytorch_dump_folder_path)

if args.push_to_hub:
print("Pushing the tokenizer to the hub...")
tokenizer.push_to_hub(args.push_to_hub)
149 changes: 60 additions & 89 deletions src/transformers/models/moshi/modeling_moshi.py
@@ -305,7 +305,25 @@ def forward(self, x, layer_idx=None):
# Multiple layers case:
# use einsum to batch the operations (batch_size, num_layers, input_size) -> (batch_size, num_layers, output_size)
return torch.einsum("bnh,noh->bno", x, self.weight)



+class MoshiLinear(nn.Module):
+def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False):
+super().__init__()
+
+self.use_flexible_linear = use_flexible_linear
+
+if not use_flexible_linear:
+self.linear = nn.Linear(input_dim, output_dim, bias=False)
+else:
+self.linear = MoshiFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks)
+
+
+def forward(self, x, layer_idx=None):
+if self.use_flexible_linear:
+return self.linear(x, layer_idx)
+else:
+return self.linear(x)
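A quick usage sketch of the wrapper above (illustrative only; it assumes a transformers build that contains this branch, and the shapes are made up). The point of the refactor is that call sites can always pass a layer_idx and let the wrapper decide whether it is needed:

# --- Illustrative example, not part of the diff; assumes this branch is installed. ---
import torch
from transformers.models.moshi.modeling_moshi import MoshiLinear

x = torch.randn(2, 8, 16)  # (batch, num_codebooks, hidden_size) -- made-up sizes

plain = MoshiLinear(16, 32, num_codebooks=8)                           # wraps a single nn.Linear
flex = MoshiLinear(16, 32, num_codebooks=8, use_flexible_linear=True)  # one weight matrix per codebook

print(plain(x, layer_idx=None).shape)  # torch.Size([2, 8, 32]); layer_idx is ignored
print(flex(x, layer_idx=None).shape)   # torch.Size([2, 8, 32]); batched einsum over the codebook dim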

# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi
class MoshiRotaryEmbedding(nn.Module):
@@ -374,12 +392,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):


class MoshiGatingMLP(nn.Module):
-def __init__(self, config, num_layers=1, is_depth_mlp=False):
+def __init__(self, config, use_flexible_linear=False):
super().__init__()

self.activation_fn = ACT2FN[config.hidden_act]
-ffn_dim = config.ffn_dim if not is_depth_mlp else config.depth_ffn_dim
-hidden_size = config.hidden_size if not is_depth_mlp else config.depth_hidden_size
+ffn_dim = config.ffn_dim
+hidden_size = config.hidden_size
+num_layers = config.num_codebooks if use_flexible_linear else 1
if num_layers == 1:
self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False)
self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False)
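The fc2 input width of ffn_dim // 2 suggests a gated activation in which the fc1 output is split in half and one half gates the other. The forward pass is outside this hunk, so the sketch below is only an assumed pattern (gated_mlp is a hypothetical helper), not Moshi's verified implementation:

# --- Illustrative sketch of a gated MLP; assumed pattern only, gated_mlp is hypothetical. ---
import torch
from torch import nn

def gated_mlp(x, fc1, fc2, act=nn.functional.gelu):
    h = fc1(x)                 # hidden_size -> ffn_dim
    a, b = h.chunk(2, dim=-1)  # two halves of width ffn_dim // 2
    return fc2(act(a) * b)     # gate one half with the other, then project back down

fc1 = nn.Linear(16, 64, bias=False)  # hidden_size=16 -> ffn_dim=64
fc2 = nn.Linear(32, 16, bias=False)  # ffn_dim // 2 = 32 -> hidden_size=16
print(gated_mlp(torch.randn(2, 3, 16), fc1, fc2).shape)  # torch.Size([2, 3, 16])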
@@ -413,11 +432,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class MoshiAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

-def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_depth_attention=False):
+def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_flexible_linear=False, use_rope=True):
super().__init__()
self.config = config
self.layer_idx = layer_idx
-self.is_depth_attention = is_depth_attention
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
@@ -426,16 +444,12 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_dept
)

self.attention_dropout = config.attention_dropout
-self.hidden_size = config.hidden_size if not is_depth_attention else config.depth_hidden_size
-self.num_heads = config.num_attention_heads if not is_depth_attention else config.depth_num_attention_heads
-self.head_dim = config.head_dim if not is_depth_attention else config.depth_head_dim
-self.num_key_value_heads = (
-config.num_key_value_heads if not is_depth_attention else config.depth_num_key_value_heads
-)
+self.hidden_size = config.hidden_size
+self.num_heads = config.num_attention_heads
+self.head_dim = config.head_dim
+self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-self.max_position_embeddings = (
-config.max_position_embeddings if not is_depth_attention else config.depth_max_position_embeddings
-)
+self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = 1 / math.sqrt(self.head_dim)
@@ -446,37 +460,22 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, is_dept
f" and `num_heads`: {self.num_heads})."
)

-if not is_depth_attention:
-self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-else:
-self.q_proj = MoshiFlexibleLinear(
-self.hidden_size, self.num_heads * self.head_dim, num_layers=config.num_codebooks
-)
-self.k_proj = MoshiFlexibleLinear(
-self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks
-)
-self.v_proj = MoshiFlexibleLinear(
-self.hidden_size, self.num_key_value_heads * self.head_dim, num_layers=config.num_codebooks
-)
-self.o_proj = MoshiFlexibleLinear(
-self.num_heads * self.head_dim, self.hidden_size, num_layers=config.num_codebooks
-)
+self.q_proj = MoshiLinear(self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear)
+self.k_proj = MoshiLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear)
+self.v_proj = MoshiLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear)
+self.o_proj = MoshiLinear(self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear)

# rotary embeddings are not used in the depth decoder
-self.rotary_emb = (
-MoshiRotaryEmbedding(
-self.head_dim,
-max_position_embeddings=self.max_position_embeddings,
-base=self.rope_theta,
-)
-if not is_depth_attention
-else None
-)
+self.rotary_emb = None
+if use_rope:
+self.rotary_emb = MoshiRotaryEmbedding(
+self.head_dim,
+max_position_embeddings=self.max_position_embeddings,
+base=self.rope_theta,
+)


-# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Moshi
+# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -489,15 +488,9 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

-query_states = (
-self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position)
-) # Ignore copy
-key_states = (
-self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position)
-) # Ignore copy
-value_states = (
-self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position)
-) # Ignore copy
+query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
+key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
+value_states = self.v_proj(hidden_states, cache_position) # Ignore copy

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -539,9 +532,7 @@ def forward(
attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.view(bsz, q_len, -1)
-attn_output = (
-self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position)
-) # Ignore copy
+attn_output = self.o_proj(attn_output, cache_position) # Ignore copy

if not output_attentions:
attn_weights = None
@@ -585,15 +576,9 @@ def forward(

bsz, q_len, _ = hidden_states.size()

-query_states = (
-self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position)
-) # Ignore copy
-key_states = (
-self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position)
-) # Ignore copy
-value_states = (
-self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position)
-) # Ignore copy
+query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
+key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
+value_states = self.v_proj(hidden_states, cache_position) # Ignore copy

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
@@ -663,9 +648,7 @@ def forward(
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
-attn_output = (
-self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position)
-) # Ignore copy
+attn_output = self.o_proj(attn_output, cache_position) # Ignore copy

if not output_attentions:
attn_weights = None
@@ -711,15 +694,9 @@ def forward(

bsz, q_len, _ = hidden_states.size()

-query_states = (
-self.q_proj(hidden_states) if not self.is_depth_attention else self.q_proj(hidden_states, cache_position)
-) # Ignore copy
-key_states = (
-self.k_proj(hidden_states) if not self.is_depth_attention else self.k_proj(hidden_states, cache_position)
-) # Ignore copy
-value_states = (
-self.v_proj(hidden_states) if not self.is_depth_attention else self.v_proj(hidden_states, cache_position)
-) # Ignore copy
+query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
+key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
+value_states = self.v_proj(hidden_states, cache_position) # Ignore copy

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -768,9 +745,8 @@ def forward(
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)

-attn_output = (
-self.o_proj(attn_output) if not self.is_depth_attention else self.o_proj(attn_output, cache_position)
-) # Ignore copy
+attn_output = self.o_proj(attn_output, cache_position) # Ignore copy


return attn_output, None, past_key_value

@@ -783,21 +759,16 @@ def forward(


class MoshiDecoderLayer(nn.Module):
-def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool = False, is_depth_layer=False):
+def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool = False, use_rope=True):
super().__init__()
-self.is_depth_layer = is_depth_layer
-self.hidden_size = config.hidden_size if not is_depth_layer else config.depth_hidden_size
+self.hidden_size = config.hidden_size
self.use_flexible_linear = use_flexible_linear

self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](
-config=config, layer_idx=layer_idx, is_depth_attention=is_depth_layer
+config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
)

-self.mlp = (
-MoshiGatingMLP(config)
-if not use_flexible_linear
-else MoshiGatingMLP(config, config.num_codebooks, is_depth_layer)
-)
+self.mlp = MoshiGatingMLP(config, use_flexible_linear)
self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
@@ -1081,7 +1052,7 @@ def __init__(self, config: MoshiConfig):

self.layers = nn.ModuleList(
[
-MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, is_depth_layer=True)
+MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, use_rope=False)
for layer_idx in range(config.depth_num_hidden_layers)
]
)
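With this change the depth decoder no longer needs a dedicated layer variant: the same MoshiDecoderLayer is configured through its flags (per-codebook projections via use_flexible_linear, rotary embeddings disabled via use_rope=False). A hypothetical sketch of the two call patterns (assumes a transformers build containing this branch; the reduced config values are made up):

# --- Illustrative example, not part of the diff; config values are assumptions. ---
from transformers import MoshiConfig
from transformers.models.moshi.modeling_moshi import MoshiDecoderLayer

config = MoshiConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2, ffn_dim=128)

# Main (temporal) decoder layer: plain nn.Linear projections, rotary embeddings enabled.
main_layer = MoshiDecoderLayer(config, layer_idx=0)

# Depth decoder layer: one projection weight per codebook, rotary embeddings disabled.
depth_layer = MoshiDecoderLayer(config, layer_idx=0, use_flexible_linear=True, use_rope=False)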
