[Wav2Vec2 - MMS] Correct directly loading adapters weights #24335

Merged 7 commits on Jun 20, 2023

Changes from 1 commit
Correct direct lang loading
patrickvonplaten committed Jun 17, 2023
commit 7c00aed3939e24e07b4463470e791109d9624397
9 changes: 9 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1275,6 +1275,12 @@ def tie_weights(self):
if hasattr(module, "_tie_weights"):
module._tie_weights()

def load_adaptive_weights(self):
"""
Load adaptive weights after the state dict has been loaded. If required, this method should be overridden by the derived class.
"""
pass
Collaborator
I'd raise an exception here if it's called; otherwise it fails silently.

Suggested change
-    pass
+    raise NotImplementedError

Contributor Author
Yeah, not 100% sure about the design here, but it seems much more in line with:

I think something like `if hasattr(self, "load_adaptive_weights")` is also not great.

Also, it's a bit questionable whether `load_adaptive_weights` is general enough to warrant a place in `modeling_utils.py`, but there is really no other way to support the `from_pretrained(..., target_lang="...")` functionality.

@LysandreJik @sgugger wdyt?
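
For illustration, here is a minimal sketch, not the actual transformers code, contrasting the two designs being weighed: a no-op hook defined on the base class and always called, versus no base-class method plus a `hasattr` check at the call site of `from_pretrained`. The class names are hypothetical.

```python
# Minimal sketch of the two designs discussed above; all names are
# hypothetical and only the control flow matters.


class BaseModelHookStyle:
    """Design 1: a no-op hook on the base class that is always called."""

    def load_adaptive_weights(self):
        # Subclasses that support adapters override this; the base
        # implementation intentionally does nothing (or could raise
        # NotImplementedError, as suggested above).
        pass

    @classmethod
    def from_pretrained(cls):
        model = cls()
        # ... state dict loading would happen here ...
        model.load_adaptive_weights()  # always safe to call
        return model


class BaseModelHasattrStyle:
    """Design 2: no hook on the base class; check for it at the call site."""

    @classmethod
    def from_pretrained(cls):
        model = cls()
        # ... state dict loading would happen here ...
        if hasattr(model, "load_adaptive_weights"):
            model.load_adaptive_weights()
        return model
```

Design 1 keeps the call site clean but does nothing for models without adapters, which is the silent-failure concern raised above; design 2 avoids a new base-class method at the cost of an implicit protocol.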

Collaborator
For now I would override the `_tie_encoder_decoder_weights` method in Wav2Vec2 only and not add those changes in `modeling_utils`. If things change and we get lots of models with adapters, we can revisit the decision and do something like this, but I'd wait for it to be necessary.
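
A rough sketch of that suggestion follows, assuming `tie_weights` (which `from_pretrained` already invokes) as the hook being reused; the class and helper names are simplified stand-ins, not the code from this PR.

```python
# Rough sketch: reuse a hook that `from_pretrained` already calls instead of
# adding a new base-class API; `tie_weights` is assumed as that hook and all
# classes here are simplified stand-ins.
from types import SimpleNamespace


class PreTrainedModelSketch:
    def tie_weights(self):
        pass  # existing hook, already called at the end of `from_pretrained`

    @classmethod
    def from_pretrained(cls, **kwargs):
        model = cls(**kwargs)
        # ... state dict loading would happen here ...
        model.tie_weights()  # existing call site, so no new API is needed
        return model


class Wav2Vec2ForCTCSketch(PreTrainedModelSketch):
    def __init__(self, target_lang=None, adapter_attn_dim=16):
        self.target_lang = target_lang
        self.config = SimpleNamespace(adapter_attn_dim=adapter_attn_dim)

    def tie_weights(self):
        # Repurposed so adapter weights are loaded *after* the state dict,
        # which is the "pretty hacky" trade-off acknowledged below.
        if self.target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(
                f"Cannot pass `target_lang`: {self.target_lang} if `config.adapter_attn_dim` is not defined."
            )
        elif self.target_lang is not None:
            self.load_adapter(self.target_lang)

    def load_adapter(self, target_lang):
        print(f"loading adapter weights for {target_lang}")  # placeholder
```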

Contributor Author
That's pretty hacky, but yeah, ok for me. Will add a big comment explaining that we do this to avoid introducing a new API.


@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
uninitialized_encoder_weights: List[str] = []
@@ -2897,6 +2903,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# make sure token embedding weights are still tied if needed
model.tie_weights()

# make sure adaptive weights can be loaded dynamically
model.load_adaptive_weights()

# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()

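With this hook wired into `from_pretrained`, loading an MMS checkpoint with a specific adapter language would look roughly like the sketch below; the checkpoint name `facebook/mms-1b-all` and the `ignore_mismatched_sizes` flag are assumptions drawn from general MMS usage rather than from this diff.

```python
from transformers import Wav2Vec2ForCTC

# Hypothetical usage sketch: `target_lang` is forwarded to
# `Wav2Vec2ForCTC.__init__`, and the matching adapter weights are pulled in
# right after the state dict via the `load_adaptive_weights()` call added above.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/mms-1b-all",         # example MMS checkpoint (assumed)
    target_lang="fra",
    ignore_mismatched_sizes=True,  # the CTC head size varies per language
)
```
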
11 changes: 7 additions & 4 deletions src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -1854,11 +1854,12 @@ def forward(
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
def __init__(self, config, target_lang=None):
def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)

self.wav2vec2 = Wav2Vec2Model(config)
self.dropout = nn.Dropout(config.final_dropout)
self.target_lang = target_lang

if config.vocab_size is None:
raise ValueError(
@@ -1872,16 +1873,18 @@ def __init__(self, config, target_lang=None):
)
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

# Initialize weights and apply final processing
self.post_init()

def load_adaptive_weights(self):
target_lang = self.target_lang
if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
logger.info("By default `target_lang` is set to 'eng'.")
elif target_lang is not None:
self.load_adapter(target_lang)

# Initialize weights and apply final processing
self.post_init()

def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature encoder so that its parameter will