Skip to content

Commit

Permalink
[Wav2Vec2] Fix torch srcipt (huggingface#24062)
Browse files Browse the repository at this point in the history
* [Wav2Vec2] Fix torch srcipt

* fix more
  • Loading branch information
patrickvonplaten authored and novice03 committed Jun 23, 2023
1 parent df895a9 commit d1b7753
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
5 changes: 2 additions & 3 deletions src/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,8 +1178,7 @@ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)):
module.gradient_checkpointing = value

@property
def _adapters(self):
def _get_adapters(self):
if self.config.adapter_attn_dim is None:
raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")

Expand Down Expand Up @@ -1339,7 +1338,7 @@ def load_adapter(self, target_lang: str, **kwargs):
f" directory containing a file named {filepath}."
)

adapter_weights = self._adapters
adapter_weights = self._get_adapters()
unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())

Expand Down
6 changes: 3 additions & 3 deletions tests/models/wav2vec2/test_modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def create_and_check_model_with_attn_adapter(self, config, input_values, attenti
config.adapter_attn_dim = 16
model = Wav2Vec2ForCTC(config=config)

self.parent.assertIsNotNone(model._adapters)
self.parent.assertIsNotNone(model._get_adapters())

model.to(torch_device)
model.eval()
Expand Down Expand Up @@ -1146,7 +1146,7 @@ def get_logits(model, input_features):
model = Wav2Vec2ForCTC.from_pretrained(tempdir)

logits = get_logits(model, input_features)
adapter_weights = model._adapters
adapter_weights = model._get_adapters()

# save safe weights
safe_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_SAFE_FILE.format("eng"))
Expand All @@ -1168,7 +1168,7 @@ def get_logits(model, input_features):
model = Wav2Vec2ForCTC.from_pretrained(tempdir)

logits = get_logits(model, input_features)
adapter_weights = model._adapters
adapter_weights = model._get_adapters()

# save pt weights
pt_filepath = os.path.join(tempdir, WAV2VEC2_ADAPTER_PT_FILE.format("eng"))
Expand Down

0 comments on commit d1b7753

Please sign in to comment.