Commit: Hide aux
mthrok committed Sep 21, 2021
1 parent b5df833 commit a5cba17
Showing 3 changed files with 12 additions and 14 deletions.
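In short, this commit removes `aux` from the public `Wav2Vec2Model` constructor and stores the readout as the private attribute `_aux`, which `_get_model` now attaches after construction. A minimal sketch of the before/after shape of the API; the `nn.Identity()` stand-ins are hypothetical placeholders for illustration only, since the real components come from internal helpers used by `_get_model`:

from torch import nn
from torchaudio.models import Wav2Vec2Model

# Hypothetical stand-in modules, only to show the constructor change.
feature_extractor = nn.Identity()
encoder = nn.Identity()
aux = nn.Linear(768, 32)

# Before this commit, the readout was a public constructor argument:
#     model = Wav2Vec2Model(feature_extractor, encoder, aux)
# After this commit, the constructor no longer accepts it; the factory code
# attaches it as a private attribute instead:
model = Wav2Vec2Model(feature_extractor, encoder)
model._aux = aux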
@@ -118,7 +118,7 @@ def _test_import_finetune(self, original, imported, config):
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
         ref = original.lm_head(x)
-        hyp = imported.aux(x)
+        hyp = imported._aux(x)
         self.assertEqual(ref, hyp)
         # The whole model without mask
         x = torch.randn(3, 1024)
@@ -195,8 +195,8 @@ def _test_recreate(self, imported, reloaded, config):
         self.assertEqual(ref, hyp)
         # Readout
         x = torch.randn(3, 10, config["hidden_size"])
-        ref = imported.aux(x)
-        hyp = reloaded.aux(x)
+        ref = imported._aux(x)
+        hyp = reloaded._aux(x)
         self.assertEqual(ref, hyp)
         # The whole model
         x = torch.randn(3, 1024)
@@ -208,7 +208,7 @@ def _test_recreate(self, imported, reloaded, config):
     def test_recreate_pretrain(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.aux.out_features)
+        reloaded = factory_func(num_out=imported._aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
@@ -217,7 +217,7 @@ def test_recreate_pretrain(self, config, factory_func):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.aux.out_features)
+        reloaded = factory_func(num_out=imported._aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
14 changes: 6 additions & 8 deletions torchaudio/models/wav2vec2/model.py
@@ -20,20 +20,16 @@ class Wav2Vec2Model(Module):
         encoder (torch.nn.Module):
             Encoder that converts the audio features into the sequence of probability
             distribution (in negative log-likelihood) over labels.
-        aux (Optional[torch.nn.Module]):
-            Auxiliary module. If provided, the output from encoder is passed to this module.
     """
     def __init__(
-            self,
-            feature_extractor: Module,
-            encoder: Module,
-            aux: Optional[Module] = None,
+            self,
+            feature_extractor: Module,
+            encoder: Module,
     ):
         super().__init__()
         self.feature_extractor = feature_extractor
         self.encoder = encoder
-        self.aux = aux
+        self._aux = None

     @torch.jit.export
     def extract_features(
@@ -95,8 +91,8 @@ def forward(
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
         x = self.encoder(x, lengths)
-        if self.aux is not None:
-            x = self.aux(x)
+        if self._aux is not None:
+            x = self._aux(x)
         return x, lengths

@@ -142,7 +138,9 @@ def _get_model(
             in_features=encoder_embed_dim,
             out_features=aux_num_out,
         )
-    return Wav2Vec2Model(feature_extractor, encoder, aux)
+    model = Wav2Vec2Model(feature_extractor, encoder)
+    model._aux = aux
+    return model


 def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
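For context, here is a minimal sketch of the factory path after this change, assuming a torchaudio build that includes this commit (where `wav2vec2_base(num_out)` is exported from `torchaudio.models` as in this branch): the factory builds the feature extractor and encoder, attaches the readout as `model._aux`, and `forward` applies it to the encoder output when it is not `None`.

import torch
from torchaudio.models import wav2vec2_base

# Build the base architecture with a 32-class readout; _get_model attaches
# the torch.nn.Linear readout as the private attribute `_aux`.
model = wav2vec2_base(num_out=32).eval()

waveforms = torch.randn(3, 1024)
lengths = torch.tensor([1024, 1024, 1024])
with torch.no_grad():
    logits, out_lengths = model(waveforms, lengths)
# logits has shape (batch, frames, 32); the last dimension matches
# model._aux.out_features.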
2 changes: 1 addition & 1 deletion torchaudio/models/wav2vec2/utils/import_huggingface.py
@@ -42,7 +42,7 @@ def _build(config, original):
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
     if original.__class__.__name__ == 'Wav2Vec2ForCTC':
-        imported.aux.load_state_dict(original.lm_head.state_dict())
+        imported._aux.load_state_dict(original.lm_head.state_dict())
     return imported

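Putting the pieces together, here is a sketch of the round trip that the tests above exercise: import a Hugging Face `Wav2Vec2ForCTC` checkpoint, then recreate the model without the transformers dependency, sizing the readout from the now-private `_aux` attribute. The checkpoint name below is only an example; any base-sized `Wav2Vec2ForCTC` checkpoint would serve.

import torch
from transformers import Wav2Vec2ForCTC
from torchaudio.models import wav2vec2_base
from torchaudio.models.wav2vec2.utils import import_huggingface_model

original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
imported = import_huggingface_model(original).eval()

# Recreate via the factory function, copying the converted weights over.
reloaded = wav2vec2_base(num_out=imported._aux.out_features)
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()

waveforms = torch.randn(1, 16000)
lengths = torch.tensor([16000])
with torch.no_grad():
    ref, _ = imported(waveforms, lengths)
    hyp, _ = reloaded(waveforms, lengths)
torch.testing.assert_allclose(ref, hyp)  # the two models should agree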
