diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py
index 20977cff87d167..9f488b19888957 100644
--- a/src/transformers/models/hubert/configuration_hubert.py
+++ b/src/transformers/models/hubert/configuration_hubert.py
@@ -94,6 +94,8 @@ class HubertConfig(PretrainedConfig):
             embeddings layer.
         num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
             Number of groups of 1D convolutional positional embeddings layer.
+        conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use batch norm instead of weight norm in the convolutional positional embeddings (`conv_pos`).
         do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
             Whether do apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
             True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
@@ -182,6 +184,7 @@ def __init__(
         conv_bias=False,
         num_conv_pos_embeddings=128,
         num_conv_pos_embedding_groups=16,
+        conv_pos_batch_norm=False,
         do_stable_layer_norm=False,
         apply_spec_augment=True,
         mask_time_prob=0.05,
@@ -209,6 +212,7 @@ def __init__(
         self.conv_bias = conv_bias
         self.num_conv_pos_embeddings = num_conv_pos_embeddings
         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+        self.conv_pos_batch_norm = conv_pos_batch_norm
         self.num_feat_extract_layers = len(self.conv_dim)
         self.num_hidden_layers = num_hidden_layers
         self.intermediate_size = intermediate_size
diff --git a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
index 6478fdadf13de3..4966340493f35c 100644
--- a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
@@ -38,7 +38,8 @@

 MAPPING = {
     "post_extract_proj": "feature_projection.projection",
-    "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
+    "encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
+    "encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
     "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
     "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
     "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
@@ -76,6 +77,12 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
         hf_pointer.weight_v.data = value
     elif weight_type == "bias":
         hf_pointer.bias.data = value
+    elif weight_type == "running_mean":
+        hf_pointer.running_mean.data = value
+    elif weight_type == "running_var":
+        hf_pointer.running_var.data = value
+    elif weight_type == "num_batches_tracked":
+        hf_pointer.num_batches_tracked.data = value
     else:
         hf_pointer.data = value

@@ -116,6 +123,12 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
             weight_type = "weight"
         elif "bias" in name:
             weight_type = "bias"
+        elif "running_mean" in name:
+            weight_type = "running_mean"
+        elif "running_var" in name:
+            weight_type = "running_var"
+        elif "num_batches_tracked" in name:
+            weight_type = "num_batches_tracked"
         else:
             weight_type = None
         set_recursively(hf_model, mapped_key, value, name, weight_type)
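Note on the conversion changes above: as the remapping implies, batch-norm checkpoints store `encoder.pos_conv` with the `BatchNorm1d` at index 0 and the convolution at index 1, and the batch-norm buffers (`running_mean`, `running_var`, `num_batches_tracked`) must be copied alongside ordinary weights. Below is a minimal, self-contained sketch of that key translation; `MAPPING_SUBSET` and `map_pos_conv_key` are hypothetical names written only to illustrate the logic, not part of the PR.

```python
# Hypothetical sketch of the key translation the conversion script performs
# for the positional-convolution block (illustration only, not PR code).

# Subset of MAPPING relevant to the batch-norm pos_conv layout.
MAPPING_SUBSET = {
    "encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
    "encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
}

# Checked in order; batch-norm buffers are detected just like weights/biases.
WEIGHT_TYPES = ("running_mean", "running_var", "num_batches_tracked", "weight", "bias")


def map_pos_conv_key(fairseq_key):
    """Return (hf_key, weight_type) for a fairseq pos_conv parameter name."""
    for src, dst in MAPPING_SUBSET.items():
        if fairseq_key.startswith(src + "."):
            suffix = fairseq_key[len(src) + 1 :]
            weight_type = next((w for w in WEIGHT_TYPES if w in suffix), None)
            return f"{dst}.{suffix}", weight_type
    raise KeyError(f"no mapping for {fairseq_key}")


print(map_pos_conv_key("encoder.pos_conv.0.running_mean"))
# -> ('encoder.pos_conv_embed.batch_norm.running_mean', 'running_mean')
print(map_pos_conv_key("encoder.pos_conv.1.weight"))
# -> ('encoder.pos_conv_embed.conv.weight', 'weight')
```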
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 57f59cf9aab94f..03904a6abfa08b 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -260,7 +260,6 @@ def forward(self, hidden_states):
         return hidden_states


-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
 class HubertPositionalConvEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -272,32 +271,37 @@ def __init__(self, config):
             groups=config.num_conv_pos_embedding_groups,
         )

-        weight_norm = nn.utils.weight_norm
-        if hasattr(nn.utils.parametrizations, "weight_norm"):
-            weight_norm = nn.utils.parametrizations.weight_norm
+        self.batch_norm = None
+        if config.conv_pos_batch_norm:
+            self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+        else:
+            weight_norm = nn.utils.weight_norm
+            if hasattr(nn.utils.parametrizations, "weight_norm"):
+                weight_norm = nn.utils.parametrizations.weight_norm

-        if is_deepspeed_zero3_enabled():
-            import deepspeed
+            if is_deepspeed_zero3_enabled():
+                import deepspeed

-            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = weight_norm(self.conv, name="weight", dim=2)
-            if hasattr(self.conv, "parametrizations"):
-                weight_g = self.conv.parametrizations.weight.original0
-                weight_v = self.conv.parametrizations.weight.original1
+                with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+                    self.conv = weight_norm(self.conv, name="weight", dim=2)
+                if hasattr(self.conv, "parametrizations"):
+                    weight_g = self.conv.parametrizations.weight.original0
+                    weight_v = self.conv.parametrizations.weight.original1
+                else:
+                    weight_g = self.conv.weight_g
+                    weight_v = self.conv.weight_v
+                deepspeed.zero.register_external_parameter(self, weight_v)
+                deepspeed.zero.register_external_parameter(self, weight_g)
             else:
-                weight_g = self.conv.weight_g
-                weight_v = self.conv.weight_v
-            deepspeed.zero.register_external_parameter(self, weight_v)
-            deepspeed.zero.register_external_parameter(self, weight_g)
-        else:
-            self.conv = weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)

         self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]

     def forward(self, hidden_states):
         hidden_states = hidden_states.transpose(1, 2)
-
+        if self.batch_norm is not None:
+            hidden_states = self.batch_norm(hidden_states)
         hidden_states = self.conv(hidden_states)
         hidden_states = self.padding(hidden_states)
         hidden_states = self.activation(hidden_states)
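The modeling change amounts to a choice between two normalization strategies for the positional convolution: batch norm applied to the activations entering the conv (the mhubert-base-25hz layout) versus weight norm reparametrizing the conv kernel (the original wav2vec2/HuBERT layout). Below is a standalone sketch of the two forward paths, with the DeepSpeed ZeRO-3 handling omitted and the default sizes assumed (`hidden_size=768`, kernel 128, 16 groups); `PosConvSketch` is illustrative, not the HF module itself.

```python
import torch
from torch import nn


class PosConvSketch(nn.Module):
    """Illustrative, simplified stand-in for HubertPositionalConvEmbedding."""

    def __init__(self, hidden_size=768, kernel_size=128, groups=16, use_batch_norm=False):
        super().__init__()
        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        # Even kernels produce one extra frame; HubertSamePadLayer trims it.
        self.num_pad_remove = 1 if kernel_size % 2 == 0 else 0
        if use_batch_norm:
            # New path: normalize the activations going into the conv.
            self.batch_norm = nn.BatchNorm1d(hidden_size)
        else:
            # Classic path: reparametrize the conv weights instead.
            self.batch_norm = None
            weight_norm = nn.utils.weight_norm
            if hasattr(nn.utils.parametrizations, "weight_norm"):
                weight_norm = nn.utils.parametrizations.weight_norm
            self.conv = weight_norm(self.conv, name="weight", dim=2)

    def forward(self, hidden_states):  # (batch, time, hidden)
        hidden_states = hidden_states.transpose(1, 2)  # -> (batch, hidden, time)
        if self.batch_norm is not None:
            hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.conv(hidden_states)
        if self.num_pad_remove:
            hidden_states = hidden_states[:, :, : -self.num_pad_remove]
        return nn.functional.gelu(hidden_states).transpose(1, 2)


x = torch.randn(2, 50, 768)
for flag in (False, True):
    print(PosConvSketch(use_batch_norm=flag)(x).shape)  # torch.Size([2, 50, 768])
```

Keeping `batch_norm` as `None` in the weight-norm case lets `forward` branch on a simple attribute check, and gives the conversion script a stable target name (`pos_conv_embed.batch_norm`) for the remapped fairseq keys.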
diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py
index 86f2b4119324ae..191d2f8c88c380 100644
--- a/tests/models/hubert/test_modeling_hubert.py
+++ b/tests/models/hubert/test_modeling_hubert.py
@@ -943,3 +943,40 @@ def test_inference_distilhubert(self):
         self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
         self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
         self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
+
+    def test_inference_hubert_25hz(self):
+        model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
+
+        sample = self._load_datasamples(1)
+        input_speech = torch.tensor(sample[0], dtype=torch.float, device=torch_device).unsqueeze(0)
+
+        with torch.no_grad():
+            outputs = model(input_speech, output_hidden_states=True).hidden_states[11]
+
+        # expected outputs taken from the original textlesslib implementation by:
+        # model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans',
+        # vocab_size=500, deduplicate=False, need_f0=False)
+        # model(wav)['dense']
+        expected_outputs_first = torch.tensor(
+            [
+                [0.0267, 0.1776, -0.1706, -0.4559],
+                [-0.2430, -0.2943, -0.1864, -0.1187],
+                [-0.1812, -0.4239, -0.1916, -0.0858],
+                [-0.1495, -0.4758, -0.4036, 0.0302],
+            ],
+            device=torch_device,
+        )
+        expected_outputs_last = torch.tensor(
+            [
+                [0.3366, -0.2734, -0.1415, -0.3055],
+                [0.2329, -0.3580, -0.1421, -0.3197],
+                [0.1631, -0.4301, -0.1965, -0.2956],
+                [0.3342, -0.2185, -0.2253, -0.2363],
+            ],
+            device=torch_device,
+        )
+        expected_output_sum = 1681.7603
+
+        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
+        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
+        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
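Beyond the integration test, a quick structural check (a sketch with deliberately small layer sizes, not part of this PR's test suite) shows what the new flag actually toggles on a freshly initialized model:

```python
from transformers import HubertConfig, HubertModel

# Tiny config so the check runs instantly; only conv_pos_batch_norm matters here.
config = HubertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    conv_dim=(32, 32),
    conv_stride=(2, 2),
    conv_kernel=(3, 3),
    conv_pos_batch_norm=True,  # the flag added by this PR
)
model = HubertModel(config)

pos_embed = model.encoder.pos_conv_embed
assert pos_embed.batch_norm is not None  # BatchNorm1d is in place...
assert not hasattr(pos_embed.conv, "parametrizations")  # ...and weight norm is not
```

With `conv_pos_batch_norm=False` (the default) the assertions invert: `batch_norm` is `None` and the conv carries the weight-norm parametrization, so existing HuBERT checkpoints are unaffected.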