diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index dcebb9ab2a6c..1c78f65f942a 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -31,7 +31,7 @@ ) from nemo.collections.asr.metrics import BLEU, WER from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel -from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin +from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import ( GenericTranscriptionType, InternalTranscribeConfig, @@ -115,7 +115,7 @@ def __post_init__(self): self.prompt = parse_multitask_prompt(self.prompt) -class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin): +class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin): """Base class for AED multi-task models""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -225,6 +225,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False ) # Wer is handling logging + # Setup encoder adapters (from ASRAdapterModelMixin) + self.setup_adapters() + def change_decoding_strategy(self, decoding_cfg: DictConfig): """ Changes decoding strategy used during Multi Task decoding process. @@ -1057,6 +1060,10 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa text = [self.decoding.strip_special_tokens(t) for t in text] return text + @property + def adapter_module_names(self) -> List[str]: + return ['', 'encoder', 'transf_encoder', 'transf_decoder'] + def parse_multitask_prompt(prompt: dict | None) -> list[dict]: if prompt is None or not prompt: diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 093419c3ca0c..7540532d371b 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -879,6 +879,10 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: return results + @property + def adapter_module_names(self) -> List[str]: + return ['', 'encoder', 'decoder'] + @property def wer(self): return self._wer diff --git a/nemo/collections/asr/modules/transformer/transformer.py b/nemo/collections/asr/modules/transformer/transformer.py index 718448aa1c7c..0ea376340d18 100644 --- a/nemo/collections/asr/modules/transformer/transformer.py +++ b/nemo/collections/asr/modules/transformer/transformer.py @@ -13,18 +13,21 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Dict, Optional +from typing import Dict, List, Optional import torch -from omegaconf.omegaconf import MISSING +from omegaconf.omegaconf import MISSING, DictConfig from nemo.collections.asr.modules.transformer.decoder_module import DecoderModule from nemo.collections.asr.modules.transformer.encoder_module import EncoderModule -from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder +from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder, TransformerDecoderAdapter from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder from nemo.collections.asr.modules.transformer.transformer_modules import TransformerEmbedding +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import adapter_mixins from nemo.core.neural_types import ChannelType, NeuralType @@ -155,6 +158,8 @@ def input_example(self, max_batch=1, max_dim=256): class TransformerDecoderNM(DecoderModule, Exportable): + DECODER_TYPE: type = TransformerDecoder + def __init__( self, vocab_size: int, @@ -192,7 +197,7 @@ def __init__( learn_positional_encodings=learn_positional_encodings, ) - self._decoder = TransformerDecoder( + self._decoder = self.DECODER_TYPE( hidden_size=self.hidden_size, num_layers=num_layers, inner_size=inner_size, @@ -207,7 +212,12 @@ def __init__( @typecheck() def forward( - self, input_ids, decoder_mask, encoder_embeddings, encoder_mask, decoder_mems=None, + self, + input_ids, + decoder_mask, + encoder_embeddings, + encoder_mask, + decoder_mems=None, ): start_pos = 0 if decoder_mems is not None: @@ -274,3 +284,36 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"last_hidden_states": NeuralType(('B', 'D', 'T', 'D'), ChannelType())} else: return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} + + +class TransformerDecoderNMAdapter(TransformerDecoderNM, adapter_mixins.AdapterModuleMixin): + DECODER_TYPE: type = TransformerDecoderAdapter + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + self._decoder.add_adapter(name, cfg) # type: adapter_mixins.AdapterModuleMixin + + def is_adapter_available(self) -> bool: + return self._decoder.is_adapter_available() # type: adapter_mixins.AdapterModuleMixin + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + self._decoder.set_enabled_adapters(name=name, enabled=enabled) # # type: adapter_mixins.AdapterModuleMixin + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + names.update(self._decoder.get_enabled_adapters()) # type: adapter_mixins.AdapterModuleMixin + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._hidden_size) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerDecoderNM) is None: + adapter_mixins.register_adapter(base_class=TransformerDecoderNM, adapter_class=TransformerDecoderNMAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_decoders.py 
b/nemo/collections/asr/modules/transformer/transformer_decoders.py index a5b2c299393c..30c6179b85a6 100644 --- a/nemo/collections/asr/modules/transformer/transformer_decoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_decoders.py @@ -13,17 +13,22 @@ # limitations under the License. import copy +from typing import List, Optional, Set import torch import torch.nn as nn +from omegaconf import DictConfig from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import form_attention_mask +from nemo.core.classes.mixins import adapter_mixins __all__ = ["TransformerDecoder"] -class TransformerDecoderBlock(nn.Module): +class TransformerDecoderBlock(nn.Module, AttentionAdapterModuleMixin): """ Building block of Transformer decoder. @@ -63,6 +68,9 @@ def __init__( self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5) self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) + # Information for the adapter module mixin + self.self_attention_model = "transf_abs" + def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): """ Pre-LayerNorm block @@ -74,6 +82,17 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) self_attn_output += residual + if self.is_adapter_available(): + # Call the MHA adapters + pack_input = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': decoder_mask, + 'pos_emb': None, + } + pack_input = self.forward_enabled_adapters(pack_input) + self_attn_output = pack_input['x'] + residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) @@ -84,6 +103,15 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state output_states = self.third_sub_layer(enc_dec_attn_output) output_states += residual + if self.is_adapter_available(): + # Call the Linear adapters + pack_input = { + 'x': output_states, + 'loc': 'post', + } + pack_input = self.forward_enabled_adapters(pack_input) + output_states = pack_input['x'] + return output_states def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): @@ -93,6 +121,18 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat """ self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) self_attn_output += decoder_query + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': decoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + self_attn_output = self.layer_norm_1(self_attn_output) enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) @@ -101,6 +141,16 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat output_states = self.third_sub_layer(enc_dec_attn_output) output_states += enc_dec_attn_output + + if self.is_adapter_available(): + # Call the linear adapters + pack_ip = { + 'x': output_states, + 
'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + return self.layer_norm_3(output_states) def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): @@ -109,6 +159,19 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc else: return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask) + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + class TransformerDecoder(nn.Module): def __init__( @@ -131,6 +194,8 @@ def __init__( else: self.final_layer_norm = None + self.d_model = hidden_size + layer = TransformerDecoderBlock( hidden_size, inner_size, @@ -219,3 +284,38 @@ def input_example(self, max_batch=1, max_dim=256): input_ids = torch.randint(low=0, high=2048, size=(max_batch, max_dim, 1024), device=sample.device) encoder_mask = torch.randint(low=0, high=1, size=(max_batch, max_dim), device=sample.device) return tuple([input_ids, encoder_mask, input_ids, encoder_mask]) + + +class TransformerDecoderAdapter(TransformerDecoder, adapter_mixins.AdapterModuleMixin): + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.add_adapter(name, cfg) + + def is_adapter_available(self) -> bool: + return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.set_enabled_adapters(name=name, enabled=enabled) + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + names.update(transformer_layer.get_enabled_adapters()) + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerDecoder) is None: + adapter_mixins.register_adapter(base_class=TransformerDecoder, adapter_class=TransformerDecoderAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index 544d561267cf..d3116db82482 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -13,17 +13,22 @@ # limitations under the License. 
import copy +from typing import List, Optional, Set import torch import torch.nn as nn +from omegaconf import DictConfig from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import form_attention_mask +from nemo.core.classes.mixins import adapter_mixins __all__ = ["TransformerEncoder"] -class TransformerEncoderBlock(nn.Module): +class TransformerEncoderBlock(nn.Module, AttentionAdapterModuleMixin): """ Building block of Transformer encoder. @@ -59,6 +64,9 @@ def __init__( self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=1e-5) self.second_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) + # Information for the adapter module mixin + self.self_attention_model = "transf_abs" + def forward_preln(self, encoder_query, encoder_mask, encoder_keys): """ Pre-LayerNorm block @@ -70,11 +78,31 @@ def forward_preln(self, encoder_query, encoder_mask, encoder_keys): self_attn_output = self.first_sub_layer(encoder_query, encoder_keys, encoder_keys, encoder_mask) self_attn_output += residual + if self.is_adapter_available(): + # Call the MHA adapters + pack_input = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': encoder_mask, + 'pos_emb': None, + } + pack_input = self.forward_enabled_adapters(pack_input) + self_attn_output = pack_input['x'] + residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) output_states = self.second_sub_layer(self_attn_output) output_states += residual + if self.is_adapter_available(): + # Call the Linear adapters + pack_input = { + 'x': output_states, + 'loc': 'post', + } + pack_input = self.forward_enabled_adapters(pack_input) + output_states = pack_input['x'] + return output_states def forward_postln(self, encoder_query, encoder_mask, encoder_keys): @@ -84,10 +112,32 @@ def forward_postln(self, encoder_query, encoder_mask, encoder_keys): """ self_attn_output = self.first_sub_layer(encoder_query, encoder_keys, encoder_keys, encoder_mask) self_attn_output += encoder_query + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': encoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + self_attn_output = self.layer_norm_1(self_attn_output) output_states = self.second_sub_layer(self_attn_output) output_states += self_attn_output + + if self.is_adapter_available(): + # Call the linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + output_states = self.layer_norm_2(output_states) return output_states @@ -98,6 +148,19 @@ def forward(self, encoder_query, encoder_mask, encoder_keys): else: return self.forward_postln(encoder_query, encoder_mask, encoder_keys) + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + class TransformerEncoder(nn.Module): def __init__( @@ -121,6 +184,8 @@ def __init__( else: self.final_layer_norm = None + self.d_model = hidden_size + 
layer = TransformerEncoderBlock( hidden_size, inner_size, @@ -172,3 +237,38 @@ def forward(self, encoder_states, encoder_mask, encoder_mems_list=None, return_m return cached_mems_list else: return cached_mems_list[-1] + + +class TransformerEncoderAdapter(TransformerEncoder, adapter_mixins.AdapterModuleMixin): + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.add_adapter(name, cfg) + + def is_adapter_available(self) -> bool: + return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.set_enabled_adapters(name=name, enabled=enabled) + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + names.update(transformer_layer.get_enabled_adapters()) + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerEncoder) is None: + adapter_mixins.register_adapter(base_class=TransformerEncoder, adapter_class=TransformerEncoderAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 4061f54a907a..1a38e7fa4b6c 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -173,7 +173,7 @@ def _forward( def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): - with self.as_frozen(): + with torch.inference_mode(): results = self._forward( decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores=return_beam_scores ) @@ -188,8 +188,7 @@ def __call__( return prefixes, scores, tgt def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for param in self.embedding.parameters(): param.requires_grad = False self.embedding.eval() @@ -201,8 +200,7 @@ def freeze(self) -> None: self.log_softmax.eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. 
- """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for param in self.embedding.parameters(): param.requires_grad = True self.embedding.train() @@ -357,13 +355,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -463,7 +461,10 @@ def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) return lm_log_probs, lm_mems_list @@ -639,13 +640,13 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -697,12 +698,11 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return tgt def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_beam_scores=False): - with self.as_frozen(): + with torch.inference_mode(): return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = False @@ -718,8 +718,7 @@ def freeze(self) -> None: self.encoders[model_num].eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. 
- """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = True @@ -781,13 +780,20 @@ def _one_step_forward( ): nmt_log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos, + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, ) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) @@ -863,13 +869,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) diff --git a/nemo/collections/asr/modules/transformer/transformer_modules.py b/nemo/collections/asr/modules/transformer/transformer_modules.py index 25fb781f0cd4..d090604287cb 100644 --- a/nemo/collections/asr/modules/transformer/transformer_modules.py +++ b/nemo/collections/asr/modules/transformer/transformer_modules.py @@ -65,7 +65,9 @@ def forward(self, position_ids): f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.' 
) self._build_pos_enc( - hidden_size=self._hidden_size, max_sequence_length=max_pos_id + 1, device=position_ids.device, + hidden_size=self._hidden_size, + max_sequence_length=max_pos_id + 1, + device=position_ids.device, ) embeddings = torch.embedding(self.pos_enc, position_ids) @@ -203,8 +205,9 @@ def forward(self, queries, keys, values, attention_mask): attention_probs = self.attn_dropout(attention_probs) context = torch.matmul(attention_probs, value) + context_hidden_size = context.size()[-1] * self.num_attention_heads context = context.permute(0, 2, 1, 3).contiguous() - new_context_shape = context.size()[:-2] + (self.hidden_size,) + new_context_shape = context.size()[:-2] + (context_hidden_size,) context = context.view(*new_context_shape) # output projection diff --git a/nemo/collections/asr/modules/transformer/transformer_utils.py b/nemo/collections/asr/modules/transformer/transformer_utils.py index da9ffb8fbd00..5de1652ee1b0 100644 --- a/nemo/collections/asr/modules/transformer/transformer_utils.py +++ b/nemo/collections/asr/modules/transformer/transformer_utils.py @@ -113,6 +113,7 @@ def get_nemo_transformer( else: raise ValueError(f"Unknown arch = {arch}") else: + model = TransformerDecoderNM( vocab_size=cfg.get('vocab_size'), hidden_size=cfg.get('hidden_size'), diff --git a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py index f452acd19847..bd0607f2c4f3 100644 --- a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py +++ b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py @@ -21,7 +21,7 @@ class ASRAdapterModelMixin(AdapterModelPTMixin): - """ ASR Adapter Mixin that can augment any Encoder module with Adapter module support. + """ASR Adapter Mixin that can augment any Encoder module with Adapter module support. This mixin class should be used only with a top level ModelPT subclass, that includes an `encoder` submodule. This mixin class adds several utility methods which are propagated to the `encoder`. @@ -54,14 +54,10 @@ def setup_adapters(self): supports_adapters = False # At least the encoder must extend AdapterModuleMixin - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - supports_adapters |= True + valid_adapter_names = [x for x in self.adapter_module_names if x != ''] + for module_name in valid_adapter_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + supports_adapters |= True # If adapters are supported, setup the adapter config + any modules (pre-existing adapter modules) if supports_adapters: @@ -87,24 +83,30 @@ def add_adapter(self, name: str, cfg: DictConfig): else: module_names = [module_name] + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + # Update the model.cfg with information about the new adapter from cfg with open_dict(self.cfg): for module_name in module_names: # Check if encoder adapters should be added - if module_name in ('', 'encoder'): - # Dispatch the call to the encoder. 
- self.encoder.add_adapter(name=name, cfg=cfg) - - # Check if decoder adapters should be added - if module_name == 'decoder': - # Dispatch call to the decoder. - self.decoder.add_adapter(name=name, cfg=cfg) + if module_name == '': + if hasattr(self, default_module_name): + # Dispatch the call to the default model. + getattr(self, default_module_name).add_adapter(name=name, cfg=cfg) - # Check if joint adapters should be added; - # Note: We need additional check if joint even exists in model (for CTC models) - if hasattr(self, 'joint') and module_name == 'joint': - # Dispatch call to the joint. - self.joint.add_adapter(name=name, cfg=cfg) + elif module_name in valid_module_names: + # Check if module exists + if hasattr(self, module_name): + # Dispatch the call to the module. + getattr(self, module_name).add_adapter(name=name, cfg=cfg) def is_adapter_available(self) -> bool: """ @@ -116,15 +118,12 @@ def is_adapter_available(self) -> bool: """ config_contains_adapter = super().is_adapter_available() - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - config_contains_adapter |= self.encoder.is_adapter_available() - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - config_contains_adapter |= self.decoder.is_adapter_available() + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - config_contains_adapter |= self.joint.is_adapter_available() + # Forward the method call to the individual modules + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + config_contains_adapter |= getattr(self, module_name).is_adapter_available() return config_contains_adapter @@ -160,23 +159,29 @@ def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) else: module_names = [module_name] + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + + # Forward the method call to the individual modules if they exist for module_name in module_names: # Check if encoder adapters should be used - # Dispatch the call to the encoder. - if name is None or module_name in ('', 'encoder'): - if self.encoder.is_adapter_available(): - self.encoder.set_enabled_adapters(name=name, enabled=enabled) - - # Dispatch the call to the decoder. - if name is None or module_name == 'decoder': - if self.decoder.is_adapter_available(): - self.decoder.set_enabled_adapters(name=name, enabled=enabled) - - # Dispatch the call to the joint. - # Note: We need additional check for joint, since it may not exist (CTC models). - if name is None or module_name == 'joint': - if hasattr(self, 'joint') and self.joint.is_adapter_available(): - self.joint.set_enabled_adapters(name=name, enabled=enabled) + + if module_name == '': + if hasattr(self, default_module_name): + # Dispatch the call to the default model. + getattr(self, default_module_name).set_enabled_adapters(name=name, enabled=enabled) + + elif module_name in valid_module_names: + if hasattr(self, module_name): + # Dispatch the call to the module. 
+ getattr(self, module_name).set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> List[str]: """ @@ -187,15 +192,12 @@ def get_enabled_adapters(self) -> List[str]: """ enabled_adapters = super().get_enabled_adapters() - # Check if encoder adapters should be used or are enabled - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - enabled_adapters.extend(self.encoder.get_enabled_adapters()) + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - enabled_adapters.extend(self.decoder.get_enabled_adapters()) - - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - enabled_adapters.extend(self.joint.get_enabled_adapters()) + # Check if encoder adapters should be used or are enabled + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + enabled_adapters.extend(getattr(self, module_name).get_enabled_adapters()) enabled_adapters = list(sorted(list(set(enabled_adapters)))) @@ -208,44 +210,19 @@ def check_valid_model_with_adapter_support_(self): # Obtain the global adapter config if possible, otherwise use sensible defaults. global_cfg = self._get_global_cfg() - # Test whether the encoder supports adapters - use_encoder_adapter = global_cfg.get('check_encoder_adapter', True) - if use_encoder_adapter: - if not hasattr(self, 'encoder'): - logging.warning( - "Cannot add adapter to this object as it does not have an `encoder` sub-module!", - mode=logging_mode.ONCE, - ) - - if hasattr(self, 'encoder') and not isinstance(self.encoder, AdapterModuleMixin): - logging.warning( - f'{self.encoder.__class__.__name__} does not implement `AdapterModuleMixin`', - mode=logging_mode.ONCE, - ) - - # Test whether the decoder supports adapters - use_decoder_adapter = global_cfg.get('check_decoder_adapter', True) - if use_decoder_adapter: - if not hasattr(self, 'decoder'): - logging.warning( - "Cannot add adapter to this object as it does not have an `decoder` sub-module!", - mode=logging_mode.ONCE, - ) - - if hasattr(self, 'decoder') and not isinstance(self.decoder, AdapterModuleMixin): - logging.warning( - f'{self.decoder.__class__.__name__} does not implement `AdapterModuleMixin`', - mode=logging_mode.ONCE, - ) - - # Test whether the joint supports adapters - use_joint_adapter = global_cfg.get('check_joint_adapter', True) - if use_joint_adapter: - # Joint is only for RNNT models, skip assertion that it must always exist. 
- if hasattr(self, 'joint') and not isinstance(self.joint, AdapterModuleMixin): - logging.warning( - f'{self.joint.__class__.__name__} does not implement `AdapterModuleMixin`', mode=logging_mode.ONCE - ) + valid_module_names = [x for x in self.adapter_module_names if x != ''] + + for module_name in valid_module_names: + check_adapter_support = global_cfg.get(f'check_{module_name}_adapter', True) + + if check_adapter_support: + # Test whether the module supports adapters + if hasattr(self, module_name) and not isinstance(getattr(self, module_name), AdapterModuleMixin): + logging.warning( + f'Module `{module_name}` exists, but {getattr(self, module_name).__class__.__name__} ' + f'does not implement `AdapterModuleMixin`', + mode=logging_mode.ONCE, + ) def resolve_adapter_module_name_(self, name: str) -> Tuple[str, str]: """ @@ -293,3 +270,7 @@ def _get_global_cfg(self): def adapter_module_names(self) -> List[str]: valid_module_names = ['', 'encoder', 'decoder', 'joint'] return valid_module_names + + @property + def default_adapter_module_name(self) -> str: + return 'encoder' diff --git a/nemo/collections/asr/parts/submodules/adapters/__init__.py b/nemo/collections/asr/parts/submodules/adapters/__init__.py index 6aa05d07dea1..c51d935bddd4 100644 --- a/nemo/collections/asr/parts/submodules/adapters/__init__.py +++ b/nemo/collections/asr/parts/submodules/adapters/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# fmt: off +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( MHAResidualAddAdapterStrategy, MHAResidualAddAdapterStrategyConfig, @@ -24,3 +26,9 @@ RelPositionMultiHeadAttentionAdapter, RelPositionMultiHeadAttentionAdapterConfig, ) +from nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module import ( + TransformerMultiHeadAttentionAdapter, + TransformerMultiHeadAttentionAdapterConfig, +) + +# fmt: on diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py new file mode 100644 index 000000000000..0c1852773072 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -0,0 +1,119 @@ +import torch + +from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import logging, logging_mode + + +class AttentionAdapterModuleMixin(adapter_mixins.AdapterModuleMixin): + """ + Utility class that implements a custom forward method for Modules that are attention based. + Attention based adapters can support either linear adapters, and Multi-Head Attention adapters. + + However, Multi Head Attention adapters require additional arguments, such as `att_mask` and `pos_emb`. + This utility class unifies the adapter forward pass for both types of adapters. + + .. Usage: + + To use this class, inherit from this class, and when calling self.foward_enabled_adapters() pass the following: + + .. 
code-block:: python + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': residual, + 'loc': 'mha', + 'att_mask': att_mask, + 'pos_emb': pos_emb, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + residual = pack_ip['x'] + + if self.is_adapter_available(): + # Call the Linear adapters + pack_ip = { + 'x': x, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + x = pack_ip['x'] + """ + + def forward_single_enabled_adapter_( + self, + input: dict, + adapter_module: torch.nn.Module, + *, + adapter_name: str, + adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', + ): + """ + Perform the forward step of a single adapter module on some input data. + + **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. + + Args: + input: Dictionary of packed tensors. The dict should contain at least + `x`: output tensor + `loc`: Semantic location in module where this adapter was called. Can be 'mha' or 'post'. + `att_mask`: Optional, Attention mask + `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. + The output tensor of the calling module is the input to the first adapter, whose output + is then chained to the next adapter until all adapters are consumed. + adapter_module: The adapter module that is currently required to perform the forward pass. + adapter_name: The resolved name of the adapter that is undergoing the current forward pass. + adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the + output of the adapter should be merged with the input, or if it should be merged at all. + + Returns: + The result tensor, after the current active adapter has finished its forward pass. + """ + if not hasattr(self, 'self_attention_model'): + raise RuntimeError( + "self_attention_model attribute not found in the module! Please set in the module " + "a string attribute 'self_attention_model' with value 'abs_pos', 'rel_pos' or " + "other supported self-attention model types." 
+ ) + + # Collect imports to prevent circular imports + from nemo.collections.asr.modules.transformer import transformer_modules as transformer_mha + from nemo.collections.asr.parts.submodules import multi_head_attention as conformer_mha + + # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') + x = input['x'] + loc = input['loc'] + att_mask = input.get('att_mask', None) + pos_emb = input.get('pos_emb', None) + + from nemo.collections.common.parts import adapter_modules + + if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': + output = adapter_strategy(x, adapter_module, module=self) + + elif isinstance(adapter_module, conformer_mha.MultiHeadAttention) and loc == 'mha': + if self.self_attention_model == 'rel_pos': + x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) + output = adapter_strategy(x, adapter_module, module=self) + + elif self.self_attention_model == 'abs_pos': + x = dict(query=x, key=x, value=x, mask=att_mask) + output = adapter_strategy(x, adapter_module, module=self) + + else: + raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") + + elif isinstance(adapter_module, transformer_mha.MultiHeadAttention) and loc == 'mha': + x = dict(queries=x, keys=x, values=x, attention_mask=att_mask) + output = adapter_strategy(x, adapter_module, module=self) + + else: + # No adapter compatible, skip + logging.warning( + "No adapter compatible with the current module. Skipping adapter forward pass.", mode=logging_mode.ONCE + ) + + output = x + + input['x'] = output + + return input diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index 3df51092ac4b..2617ed6f575b 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -29,7 +29,7 @@ class MHAResidualAddAdapterStrategy(adapter_mixin_strategies.ResidualAddAdapterS An implementation of residual addition of an adapter module with its input for the MHA Adapters. """ - def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin'): + def forward(self, input: dict, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin'): """ A basic strategy, comprising of a residual connection over the input, after forward pass by the underlying adapter. Additional work is done to pack and unpack the dictionary of inputs and outputs. @@ -55,18 +55,29 @@ def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'Ada """ out = self.compute_output(input, adapter, module=module) + value_name = None + if 'value' in input: + value_name = 'value' + elif 'values' in input: + value_name = 'values' + else: + raise ValueError( + "Input dictionary must contain 'value' or 'values' key for residual connection. Input " + f"dictionary keys: {input.keys()}" + ) + # If not in training mode, or probability of stochastic depth is 0, skip step. 
p = self.stochastic_depth if not module.training or p == 0.0: pass else: - out = self.apply_stochastic_depth(out, input['value'], adapter, module=module) + out = self.apply_stochastic_depth(out, input[value_name], adapter, module=module) # Return the residual connection output = input + adapter(input) - result = input['value'] + out + result = input[value_name] + out # If l2_lambda is activated, register the loss value - self.compute_auxiliary_losses(result, input['value'], adapter, module=module) + self.compute_auxiliary_losses(result, input[value_name], adapter, module=module) return result @@ -105,16 +116,16 @@ class MHAResidualAddAdapterStrategyConfig(adapter_mixin_strategies.ResidualAddAd class MultiHeadAttentionAdapter(mha.MultiHeadAttention, adapter_modules.AdapterModuleUtil): """Multi-Head Attention layer of Transformer. - Args: - n_head (int): number of heads - n_feat (int): size of the features - dropout_rate (float): dropout rate - proj_dim (int, optional): Optional integer value for projection before computing attention. - If None, then there is no projection (equivalent to proj_dim = n_feat). - If > 0, then will project the n_feat to proj_dim before calculating attention. - If <0, then will equal n_head, so that each head has a projected dimension of 1. - adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. - """ + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. + """ def __init__( self, @@ -300,7 +311,6 @@ class RelPositionMultiHeadAttentionAdapterConfig: class PositionalEncodingAdapter(mha.PositionalEncoding, adapter_modules.AdapterModuleUtil): - """ Absolute positional embedding adapter. @@ -327,7 +337,11 @@ def __init__( ): super().__init__( - d_model=d_model, dropout_rate=0.0, max_len=max_len, xscale=xscale, dropout_rate_emb=0.0, + d_model=d_model, + dropout_rate=0.0, + max_len=max_len, + xscale=xscale, + dropout_rate_emb=0.0, ) # Setup adapter strategy diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py new file mode 100644 index 000000000000..4319a6962f4f --- /dev/null +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from torch import nn as nn + +from nemo.collections.asr.modules.transformer import transformer_modules +from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( + MHAResidualAddAdapterStrategy, + MHAResidualAddAdapterStrategyConfig, +) +from nemo.collections.common.parts import adapter_modules +from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins + + +class TransformerMultiHeadAttentionAdapter(transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil): + """Multi-Head Attention layer of Transformer Encoder. + + Args: + hidden_size (int): number of heads + num_attention_heads (int): size of the features + attn_score_dropout (float): dropout rate for the attention scores + attn_layer_dropout (float): dropout rate for the layer + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attn_score_dropout: float = 0.0, + attn_layer_dropout: float = 0.0, + proj_dim: Optional[int] = None, + adapter_strategy: MHAResidualAddAdapterStrategy = None, + ): + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attn_score_dropout=attn_score_dropout, + attn_layer_dropout=attn_layer_dropout, + ) + + self.pre_norm = nn.LayerNorm(hidden_size) + + # Set the projection dim to number of heads automatically + if proj_dim is not None and proj_dim < 1: + proj_dim = num_attention_heads + + self.proj_dim = proj_dim + + # Recompute weights for projection dim + if self.proj_dim is not None: + if self.proj_dim % num_attention_heads != 0: + raise ValueError(f"proj_dim ({proj_dim}) is not divisible by n_head ({num_attention_heads})") + + self.attn_head_size = self.proj_dim // num_attention_heads + self.attn_scale = math.sqrt(math.sqrt(self.attn_head_size)) + self.query_net = nn.Linear(hidden_size, self.proj_dim) + self.key_net = nn.Linear(hidden_size, self.proj_dim) + self.value_net = nn.Linear(hidden_size, self.proj_dim) + self.out_projection = nn.Linear(self.proj_dim, hidden_size) + + # Setup adapter strategy + self.setup_adapter_strategy(adapter_strategy) + + # reset parameters for Q to be identity operation + self.reset_parameters() + + def forward(self, queries, keys, values, attention_mask): + """Compute 'Scaled Dot Product Attention'. 
+ Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + cache (torch.Tensor) : (batch, time_cache, size) + + returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) + """ + # Need to perform duplicate computations as at this point the tensors have been + # separated by the adapter forward + query = self.pre_norm(queries) + key = self.pre_norm(keys) + value = self.pre_norm(values) + + return super().forward(query, key, value, attention_mask) + + def reset_parameters(self): + with torch.no_grad(): + nn.init.zeros_(self.out_projection.weight) + nn.init.zeros_(self.out_projection.bias) + + def get_default_strategy_config(self) -> 'dataclass': + return MHAResidualAddAdapterStrategyConfig() + + +@dataclass +class TransformerMultiHeadAttentionAdapterConfig: + hidden_size: int + num_attention_heads: int + attn_score_dropout: float = 0.0 + attn_layer_dropout: float = 0.0 + proj_dim: Optional[int] = None + adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig()) + _target_: str = "{0}.{1}".format( + TransformerMultiHeadAttentionAdapter.__module__, TransformerMultiHeadAttentionAdapter.__name__ + ) diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 093cde63c439..c2d897d63225 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -17,6 +17,7 @@ from torch import nn as nn from torch.nn import LayerNorm +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.batchnorm import FusedBatchNorm1d from nemo.collections.asr.parts.submodules.causal_convs import CausalConv1D from nemo.collections.asr.parts.submodules.multi_head_attention import ( @@ -25,15 +26,13 @@ RelPositionMultiHeadAttentionLongformer, ) from nemo.collections.asr.parts.utils.activations import Swish -from nemo.collections.common.parts import adapter_modules from nemo.collections.common.parts.utils import activation_registry from nemo.core.classes.mixins import AccessMixin -from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin __all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerLayer'] -class ConformerLayer(torch.nn.Module, AdapterModuleMixin, AccessMixin): +class ConformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): """A single block of the Conformer encoder. 
Args: @@ -184,14 +183,14 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan if self.is_adapter_available(): # Call the MHA adapters - pack_ip = { + pack_input = { 'x': residual, 'loc': 'mha', 'att_mask': att_mask, 'pos_emb': pos_emb, } - pack_ip = self.forward_enabled_adapters(pack_ip) - residual = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + residual = pack_input['x'] x = self.norm_conv(residual) x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time) @@ -207,12 +206,12 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan if self.is_adapter_available(): # Call the adapters - pack_ip = { + pack_input = { 'x': x, 'loc': 'post', } - pack_ip = self.forward_enabled_adapters(pack_ip) - x = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + x = pack_input['x'] if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get( 'save_encoder_tensors', False @@ -223,64 +222,6 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan else: return x, cache_last_channel, cache_last_time - def forward_single_enabled_adapter_( - self, - input: dict, - adapter_module: torch.nn.Module, - *, - adapter_name: str, - adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', - ): - """ - Perform the forward step of a single adapter module on some input data. - - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. - - Args: - input: Dictionary of packed tensors. The dict should contain at least - `x`: output tensor - `loc`: Semantic location in module where this adapter was called - `att_mask`: Optional, Attention mask - `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. - The output tensor of the calling module is the input to the first adapter, whose output - is then chained to the next adapter until all adapters are consumed. - adapter_module: The adapter module that is currently required to perform the forward pass. - adapter_name: The resolved name of the adapter that is undergoing the current forward pass. - adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the - output of the adapter should be merged with the input, or if it should be merged at all. - - Returns: - The result tensor, after the current active adapter has finished its forward pass. - """ - # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') - x = input['x'] - loc = input['loc'] - att_mask = input.get('att_mask', None) - pos_emb = input.get('pos_emb', None) - - if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': - output = adapter_strategy(x, adapter_module, module=self) - - elif isinstance(adapter_module, MultiHeadAttention) and loc == 'mha': - if self.self_attention_model == 'rel_pos': - x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) - output = adapter_strategy(x, adapter_module, module=self) - - elif self.self_attention_model == 'abs_pos': - x = dict(query=x, key=x, value=x, mask=att_mask) - output = adapter_strategy(x, adapter_module, module=self) - - else: - raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") - - else: - # No adapter compatible, skip - output = x - - input['x'] = output - - return input - class ConformerConvolution(nn.Module): """The convolution module for the Conformer model. 
diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index ef3a0cddb286..25becda6fa75 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -201,8 +201,7 @@ class BeamRNNTInfer(Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "encoded_lengths": NeuralType(tuple('B'), LengthsType()), @@ -211,8 +210,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -369,7 +367,7 @@ def __call__( return_hat_ilm_default = self.joint.return_hat_ilm self.joint.return_hat_ilm = self.hat_subtract_ilm - with torch.no_grad(): + with torch.inference_mode(): # Apply optional preprocessing encoder_output = encoder_output.transpose(1, 2) # (B, T, D) @@ -384,38 +382,34 @@ def __call__( unit='sample', ) as idx_gen: - # Freeze the decoder and joint to prevent recording of gradients - # during the beam loop. - with self.decoder.as_frozen(), self.joint.as_frozen(): - - _p = next(self.joint.parameters()) - dtype = _p.dtype + _p = next(self.joint.parameters()) + dtype = _p.dtype - # Decode every sample in the batch independently. - for batch_idx in idx_gen: - inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] - logitlen = encoded_lengths[batch_idx] + # Decode every sample in the batch independently. 
+ for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] - if inseq.dtype != dtype: - inseq = inseq.to(dtype=dtype) + if inseq.dtype != dtype: + inseq = inseq.to(dtype=dtype) - # Extract partial hypothesis if exists - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + # Extract partial hypothesis if exists + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - # Execute the specific search strategy - nbest_hyps = self.search_algorithm( - inseq, logitlen, partial_hypotheses=partial_hypothesis - ) # sorted list of hypothesis + # Execute the specific search strategy + nbest_hyps = self.search_algorithm( + inseq, logitlen, partial_hypotheses=partial_hypothesis + ) # sorted list of hypothesis - # Prepare the list of hypotheses - nbest_hyps = pack_hypotheses(nbest_hyps) + # Prepare the list of hypotheses + nbest_hyps = pack_hypotheses(nbest_hyps) - # Pack the result - if self.return_best_hypothesis: - best_hypothesis = nbest_hyps[0] # type: Hypothesis - else: - best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses - hypotheses.append(best_hypothesis) + # Pack the result + if self.return_best_hypothesis: + best_hypothesis = nbest_hyps[0] # type: Hypothesis + else: + best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses + hypotheses.append(best_hypothesis) self.decoder.train(decoder_training_state) self.joint.train(joint_training_state) @@ -639,7 +633,10 @@ def default_beam_search( # keep those hypothesis that have scores greater than next search generation hyps_max = float(max(hyps, key=lambda x: x.score).score) - kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) # If enough hypothesis have scores greater than next search generation, # stop beam search. 
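Note (illustrative, not part of the patch): the hunks above replace `torch.no_grad()` with `torch.inference_mode()` and drop the `decoder`/`joint` `as_frozen()` wrappers, presumably because inference mode already guarantees that no autograd state is recorded for anything computed inside it. A tiny standalone sketch of that behaviour:

import torch

x = torch.randn(2, 3, requires_grad=True)
with torch.inference_mode():
    y = x * 2  # no autograd graph is recorded for y
assert y.requires_grad is False
assert torch.is_inference(y)  # y is an inference tensor and cannot participate in autograd later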
diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 420e49c96142..70ab74e7b014 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -383,14 +383,13 @@ def forward( hypotheses = [] # Process each sequence independently - with self.decoder.as_frozen(), self.joint.as_frozen(): - for batch_idx in range(encoder_output.size(0)): - inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] - logitlen = encoded_lengths[batch_idx] + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) - hypotheses.append(hypothesis) + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, encoded_lengths) @@ -720,12 +719,11 @@ def forward( self.decoder.eval() self.joint.eval() - with self.decoder.as_frozen(), self.joint.as_frozen(): - inseq = encoder_output # [B, T, D] + inseq = encoder_output # [B, T, D] - hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses - ) + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) # Pack the hypotheses results packed_result = pack_hypotheses(hypotheses, logitlen) @@ -2487,14 +2485,13 @@ def forward( hypotheses = [] # Process each sequence independently - with self.decoder.as_frozen(), self.joint.as_frozen(): - for batch_idx in range(encoder_output.size(0)): - inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] - logitlen = encoded_lengths[batch_idx] + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) - hypotheses.append(hypothesis) + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, encoded_lengths) @@ -2775,11 +2772,10 @@ def forward( self.decoder.eval() self.joint.eval() - with self.decoder.as_frozen(), self.joint.as_frozen(): - inseq = encoder_output # [B, T, D] - hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses - ) + inseq = encoder_output # [B, T, D] + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) # Pack the hypotheses results packed_result = pack_hypotheses(hypotheses, logitlen) diff --git a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py index ff2cf7c5b3cc..212320e1f76f 100644 --- 
a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py +++ b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py @@ -16,14 +16,13 @@ from torch import nn as nn from torch.nn import LayerNorm +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.conformer_modules import ConformerConvolution, ConformerFeedForward from nemo.collections.asr.parts.submodules.multi_head_attention import ( MultiHeadAttention, RelPositionMultiHeadAttention, ) -from nemo.collections.common.parts import adapter_modules from nemo.core.classes.mixins import AccessMixin -from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin __all__ = ['SqueezeformerLayer', 'ConformerFeedForward', 'SqueezeformerLayer'] @@ -57,7 +56,7 @@ def forward(self, x): return x * scale + bias -class SqueezeformerLayer(torch.nn.Module, AdapterModuleMixin, AccessMixin): +class SqueezeformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): """A single block of the Squeezeformer encoder. Args: @@ -197,64 +196,6 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None): return x - def forward_single_enabled_adapter_( - self, - input: dict, - adapter_module: torch.nn.Module, - *, - adapter_name: str, - adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', - ): - """ - Perform the forward step of a single adapter module on some input data. - - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. - - Args: - input: Dictionary of packed tensors. The dict should contain at least - `x`: output tensor - `loc`: Semantic location in module where this adapter was called - `att_mask`: Optional, Attention mask - `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. - The output tensor of the calling module is the input to the first adapter, whose output - is then chained to the next adapter until all adapters are consumed. - adapter_module: The adapter module that is currently required to perform the forward pass. - adapter_name: The resolved name of the adapter that is undergoing the current forward pass. - adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the - output of the adapter should be merged with the input, or if it should be merged at all. - - Returns: - The result tensor, after the current active adapter has finished its forward pass. 
- """ - # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') - x = input['x'] - loc = input['loc'] - att_mask = input.get('att_mask', None) - pos_emb = input.get('pos_emb', None) - - if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': - output = adapter_strategy(x, adapter_module, module=self) - - elif isinstance(adapter_module, MultiHeadAttention) and loc == 'mha': - if self.self_attention_model == 'rel_pos': - x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) - output = adapter_strategy(x, adapter_module, module=self) - - elif self.self_attention_model == 'abs_pos': - x = dict(query=x, key=x, value=x, mask=att_mask) - output = adapter_strategy(x, adapter_module, module=self) - - else: - raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") - - else: - # No adapter compatible, skip - output = x - - input['x'] = output - - return input - def reset_parameters(self): # Used for Squeezeformer initialization only self.feed_forward1.reset_parameters_ff() diff --git a/nemo/collections/asr/parts/utils/adapter_utils.py b/nemo/collections/asr/parts/utils/adapter_utils.py index 5b74a296419a..b85bdee7051a 100644 --- a/nemo/collections/asr/parts/utils/adapter_utils.py +++ b/nemo/collections/asr/parts/utils/adapter_utils.py @@ -21,6 +21,8 @@ # Constants LINEAR_ADAPTER_CLASSPATH = "nemo.collections.common.parts.adapter_modules.LinearAdapter" + +# Conformer Adapters MHA_ADAPTER_CLASSPATH = ( "nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.MultiHeadAttentionAdapter" ) @@ -32,6 +34,9 @@ "nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.RelPositionalEncodingAdapter" ) +# Transformer Adapters +TRANSFORMER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapter" + def convert_adapter_cfg_to_dict_config(cfg: DictConfig): # Convert to DictConfig from dict or Dataclass @@ -58,7 +63,7 @@ def update_adapter_cfg_input_dim(module: torch.nn.Module, cfg: DictConfig, *, mo """ cfg = convert_adapter_cfg_to_dict_config(cfg) - input_dim_valid_keys = ['in_features', 'n_feat'] + input_dim_valid_keys = ['in_features', 'n_feat', 'hidden_size'] input_key = None for key in input_dim_valid_keys: diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py index 6e17151dcd1b..9bac89f61135 100644 --- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py +++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py @@ -179,8 +179,7 @@ def __call__( ) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for param in self.embedding.parameters(): param.requires_grad = False self.embedding.eval() @@ -192,8 +191,7 @@ def freeze(self) -> None: self.log_softmax.eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. 
- """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for param in self.embedding.parameters(): param.requires_grad = True self.embedding.train() @@ -347,13 +345,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -453,7 +451,10 @@ def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) return lm_log_probs, lm_mems_list @@ -629,13 +630,13 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -691,8 +692,7 @@ def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = False @@ -708,8 +708,7 @@ def freeze(self) -> None: self.encoders[model_num].eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. 
- """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = True @@ -730,6 +729,40 @@ def as_frozen(self): Context manager which temporarily freezes embedding, decoder, and log_softmax modules, yields control and finally unfreezes the modules. """ + grad_module_list = {'embeddings': {}, 'decoders': {}, 'log_softmaxes': {}, 'encoders': {}} + training_mode_module_list = {'embeddings': {}, 'decoders': {}, 'log_softmaxes': {}, 'encoders': {}} + + def gather_grad_values(module_name): + map_values = [{} for _ in range(self.num_models)] + for model_num in range(self.num_models): + for name, param in getattr(self, module_name)[model_num].named_parameters(): + map_values[model_num][name].append(param.requires_grad) + return map_values + + def reset_grad_values(module_name, map_values, require_grad_default: bool): + for model_num in range(self.num_models): + for name, param in getattr(self, module_name)[model_num].named_parameters(): + if name in map_values[model_num]: + param.requires_grad = map_values[model_num].pop() + else: + param.requires_grad = require_grad_default + + def gather_reset_training_mode_values(module_name, map_values: dict = None): + map_values = [{} for _ in range(self.num_models)] if not map_values else map_values + get_values = len(map_values) == 0 + + for model_num in range(self.num_models): + if get_values: + map_values[model_num] = getattr(self, module_name)[model_num].training + else: + getattr(self, module_name)[model_num].train(map_values[model_num]) + return map_values + + # Cache the param.require_grad state of each module + for module_name in grad_module_list.keys(): + grad_module_list[module_name] = gather_grad_values(module_name) + training_mode_module_list[module_name] = gather_reset_training_mode_values(module_name) + self.freeze() try: @@ -737,6 +770,11 @@ def as_frozen(self): finally: self.unfreeze() + # Reset the param.require_grad state of each module + for module_name in grad_module_list.keys(): + reset_grad_values(module_name, grad_module_list[module_name], require_grad_default=True) + gather_reset_training_mode_values(module_name, map_values=training_mode_module_list[module_name]) + class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator): def __init__( @@ -771,13 +809,20 @@ def _one_step_forward( ): nmt_log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos, + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, ) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) @@ -853,13 +898,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * 
len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 2a05f374d464..05ac9b429d85 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -15,7 +15,7 @@ import inspect from abc import ABC from dataclasses import dataclass, is_dataclass -from typing import List, Optional, Set, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -123,8 +123,72 @@ def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: Dict return cfg +def update_module_class_with_adapter_class( + module: nn.Module, cfg: DictConfig, update_config: bool = True, verbose: bool = True +): + """ + Recursively walks through the module and its children, checking if the class is registered in the adapter registry. + If it is, the module's class is swapped with the registered adapter class. + Also updates the config with the adapter classpath, if required. + + Args: + module: torch.nn.Module to recurse through. + cfg: DictConfig object or dict that contains the config of the module. + update_config: Bool, whether to update the config with the adapter classpath. + verbose: Bool, whether to log the changes made to the module and config. + """ + + def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str, adapter_class_path: str): + """ + Utility function to recursively walk through a dictionary and update the classpath if required. + Update is done inplace + + Args: + d: Dict to recurse through. + base_class_path: The str classpath of the base class. + adapter_class_path: The str classpath of the adapter class. + """ + for k, v in d.items(): # Loop through all k, v pairs + if isinstance(v, (dict, DictConfig)): # If value is a dict, recurse through it + inplace_recursive_walk_dict(v, base_class_path, adapter_class_path) + + # If key is target and value is base class, update the value to adapter class + elif k in ('target', '_target_') and isinstance(v, str) and v == base_class_path: + if verbose: + logging.info( + f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " f"class)" + ) + + # Update the value inplace + d[k] = adapter_class_path + + if not isinstance(module, AdapterModuleMixin): + info = get_registered_adapter(module.__class__) + if info is not None: + if verbose: + logging.info( + f"Swapping class {info.base_class_path} with adapter compatible class: " + f"{info.adapter_class_path}" + ) + + # Swap the registered class with its registered adapter class. + # Due to direct inheritance of the Adapter subclass from the original class, + # the module's class container will be replaced with the adapter class. + + adapter_cls = info.adapter_class + module.__class__ = adapter_cls + + if update_config: + # Update the adapter config with the registered adapter config + # Find the location where the original module was registered in config + # and replace it with the adapter classpath. 
+ original_classpath = info.base_class_path + adapter_classpath = info.adapter_class_path + inplace_recursive_walk_dict(cfg, original_classpath, adapter_classpath) + + class AdapterModuleMixin(ABC): - """ Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. + """Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. This mixin class adds a hierarchical way to add any type of Adapter modules to a pre-existing module. Since Models are inherently also nn.Module, this mixin can be attached to any Model or Module. @@ -171,21 +235,7 @@ def add_adapter(self, name: str, cfg: Union[DictConfig, AdapterConfig], **kwargs cfg = DictConfig(cfg) adapter_types = self.get_accepted_adapter_types() - _pass_types = False - if len(adapter_types) > 0: - test = model_utils.import_class_by_path(cfg._target_) - for _type in adapter_types: - # TODO: (@adithyare) should revisit if subclass is the best check... - if issubclass(test, _type): - _pass_types = True - break - if not _pass_types: - raise ValueError( - f"Config: \n{OmegaConf.to_yaml(cfg)}\n" - f"It creates adapter class {test} \n" - f"that is not in the list of accepted adapter types.\n" - f"Accepted adapters: {[t for t in adapter_types]}" - ) + self.check_supported_adapter_type_(cfg, adapter_types) # Convert to DictConfig from dict or Dataclass if is_dataclass(cfg): @@ -363,7 +413,9 @@ def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> N self._accepted_adapter_types = set(types) - def get_accepted_adapter_types(self,) -> Set[type]: + def get_accepted_adapter_types( + self, + ) -> Set[type]: """ Utility function to get the set of all classes that are accepted by the module. @@ -543,9 +595,38 @@ def forward_single_enabled_adapter_( output = adapter_strategy(input, adapter_module, module=self) return output + def check_supported_adapter_type_( + self, adapter_cfg: DictConfig, supported_adapter_types: Optional[Iterable[type]] = None + ): + """ + Utility method to check if the adapter module is a supported type by the module. + + This method should be called by the subclass to ensure that the adapter module is a supported type. + """ + _pass_types = False + + if supported_adapter_types is None: + supported_adapter_types = self.get_accepted_adapter_types() + + if len(supported_adapter_types) > 0: + test = model_utils.import_class_by_path(adapter_cfg['_target_']) + for _type in supported_adapter_types: + # TODO: (@adithyare) should revisit if subclass is the best check... + if issubclass(test, _type): + _pass_types = True + break + + if not _pass_types: + raise ValueError( + f"Config: \n{OmegaConf.to_yaml(adapter_cfg)}\n" + f"It creates adapter class {test} \n" + f"that is not in the list of accepted adapter types.\n" + f"Accepted adapters: {[t for t in supported_adapter_types]}" + ) + class AdapterModelPTMixin(AdapterModuleMixin): - """ Adapter Mixin that can augment a ModelPT subclass with Adapter support. + """Adapter Mixin that can augment a ModelPT subclass with Adapter support. This mixin class should be used only with a top level ModelPT subclass. 
This mixin class adds several utility methods which should be subclassed and overriden to @@ -641,7 +722,9 @@ def add_adapter(self, name: str, cfg: Union[DictConfig, AdapterConfig]): self.cfg.adapters = OmegaConf.create({}) self.cfg.adapters = _prepare_default_adapter_config( - global_key=self.adapter_global_cfg_key, meta_key=self.adapter_metadata_cfg_key, cfg=self.cfg.adapters, + global_key=self.adapter_global_cfg_key, + meta_key=self.adapter_metadata_cfg_key, + cfg=self.cfg.adapters, ) # If the adapter is not being restored, force unique name to be provided for all adapters. @@ -970,6 +1053,19 @@ def update_adapter_cfg(self, cfg: DictConfig): if isinstance(module, AdapterModuleMixin): module.adapter_cfg = cfg + def replace_adapter_compatible_modules(self, update_config: bool = True, verbose: bool = True): + """ + Utility method to replace all child modules with Adapter variants, if they exist. + Does NOT recurse through children of children modules (only immediate children). + + Args: + update_config: A flag that determines if the config should be updated or not. + verbose: A flag that determines if the method should log the changes made or not. + """ + # Update the given module itself, and then all its children modules + for name, mod in self.named_modules(): + update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config, verbose=verbose) + @property def adapter_module_names(self) -> List[str]: """ @@ -982,6 +1078,22 @@ def adapter_module_names(self) -> List[str]: Returns: A list of str, one for each of the adapter modules that are supported. By default, the subclass - should support the "global adapter" (''). + should support the "default adapter" (''). """ return [''] + + @property + def default_adapter_module_name(self) -> Optional[str]: + """ + Name of the adapter module that is used as "default" if a name of '' is provided. + + .. note:: + + Subclasses should override this property and return a str name of the module + that they wish to denote as the default. + + Returns: + A str name of a module, which is denoted as 'default' adapter or None. If None, then no default + adapter is supported. + """ + return None diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index c520bd4c1292..cac1eb2fcdf3 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os + import pytest import torch from omegaconf import DictConfig, ListConfig, OmegaConf -from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecRNNTModel -from nemo.collections.asr.parts.submodules.adapters import multi_head_attention_adapter_module +from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecMultiTaskModel, EncDecRNNTModel +from nemo.collections.asr.parts.submodules.adapters import ( + multi_head_attention_adapter_module, + transformer_multi_head_attention_adapter_module, +) from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import adapter_modules from nemo.core.classes.mixins.access_mixins import AccessMixin @@ -286,8 +291,130 @@ def rnnt_model(): return model_instance +@pytest.fixture() +def multitask_model(test_data_dir): + preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} + + # fmt: off + tokenizer = { + 'dir': None, + 'type': 'agg', + 'langs': { + 'spl_tokens': { + 'dir': os.path.join(test_data_dir, 'asr', 'tokenizers', 'canary'), + 'type': 'bpe', + }, + 'en': { + 'dir': os.path.join(test_data_dir, 'asr', 'tokenizers', 'an4_spe_128'), + 'type': 'bpe', + } + }, + 'custom_tokenizer': { + '_target_': 'nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer', + 'tokenizers': None, + } + } + # fmt: on + + model_defaults = {"asr_enc_hidden": 128, "lm_enc_hidden": 128, "lm_dec_hidden": 128} + + # Test case where Encoder (default) is not adapter compatible + encoder = { + '_target_': 'nemo.collections.asr.modules.ConformerEncoder', + 'feat_in': 64, + 'feat_out': -1, + 'n_layers': 2, + 'd_model': 128, + 'subsampling': 'striding', + 'subsampling_factor': 4, + 'self_attention_model': 'rel_pos', + 'n_heads': 4, + 'conv_kernel_size': 31, + } + + transf_encoder = { + "_target_": "nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder", + "num_layers": 1, + "hidden_size": "${model_defaults.lm_enc_hidden}", + "inner_size": int(4 * model_defaults['lm_enc_hidden']), + "num_attention_heads": 8, + "ffn_dropout": 0.1, + "attn_score_dropout": 0.1, + "attn_layer_dropout": 0.1, + "mask_future": False, + "pre_ln": True, + "pre_ln_final_layer_norm": True, + } + + transf_decoder = { + "_target_": "nemo.collections.asr.modules.transformer.get_nemo_transformer", + "model_name": None, + "pretrained": False, + "encoder": None, + "pre_ln_final_layer_norm": True, + "config_dict": { + "max_sequence_length": 512, + "num_token_types": 0, + "embedding_dropout": 0.1, + "learn_positional_encodings": False, + "hidden_size": "${model_defaults.lm_dec_hidden}", + "inner_size": "${multiply:${model_defaults.lm_dec_hidden}, 4}", + "num_layers": 2, + "num_attention_heads": 8, + "ffn_dropout": 0.1, + "attn_score_dropout": 0.1, + "attn_layer_dropout": 0.1, + "hidden_act": "relu", + "pre_ln": True, + "vocab_size": None, # Will be set by the model at runtime + "adapter": True, # Add support for adapter class + }, + } + + head = { + "_target_": "nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier", + "num_layers": 1, + "activation": "relu", + "log_softmax": True, + "hidden_size": "${transf_decoder.config_dict.hidden_size}", + "num_classes": None, # Will be set by the model at runtime + "dropout": 0.0, + "use_transformer_init": True, + } + + decoding = {'strategy': 'beam', 'beam': {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}} + + loss = { + "_target_": 
"nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss", + "label_smoothing": 0.0, + "pad_id": None, + } + + modelConfig = DictConfig( + { + 'sample_rate': 16000, + 'prompt_format': 'canary', + 'preprocessor': DictConfig(preprocessor), + 'model_defaults': DictConfig(model_defaults), + 'tokenizer': DictConfig(tokenizer), + 'encoder': DictConfig(encoder), + 'transf_encoder': DictConfig(transf_encoder), + 'transf_decoder': DictConfig(transf_decoder), + 'head': DictConfig(head), + 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), + } + ) + + model_instance = EncDecMultiTaskModel(cfg=modelConfig) + + # Execute the model class swap logic + model_instance.replace_adapter_compatible_modules() + return model_instance + + def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **kwargs): - valid_types = ['linear', 'mha', 'relmha'] + valid_types = ['linear', 'mha', 'relmha', 'transf_mha'] if atype not in valid_types: raise ValueError(f"Invalid type. Valid types = {atype}") @@ -295,7 +422,15 @@ def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **k cfg = adapter_modules.LinearAdapterConfig(in_features=in_features, dim=dim, norm_position=norm_pos) elif atype == 'mha': cfg = multi_head_attention_adapter_module.MultiHeadAttentionAdapterConfig( - n_head=kwargs.get('n_head', 1), n_feat=in_features + n_head=kwargs.get('n_head', 1), + n_feat=in_features, + proj_dim=kwargs.get('proj_dim', None), + ) + elif atype == 'transf_mha': + cfg = transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapterConfig( + num_attention_heads=kwargs.get('n_head', 1), + hidden_size=in_features, + proj_dim=kwargs.get('proj_dim', None), ) elif atype == 'relmha': cfg = multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapterConfig( @@ -375,12 +510,14 @@ def test_asr_model_constructor_joint_module_ctc_skip(self, model): original_num_params = model.num_weights # this step should exit without adding adapters and without errors - model.add_adapter(name='joint:adapter_0', cfg=get_adapter_cfg()) + with pytest.raises(ValueError): + model.add_adapter(name='joint:adapter_0', cfg=get_adapter_cfg()) new_num_params = model.num_weights assert new_num_params == original_num_params @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_asr_model_constructor_joint_module_rnnt(self, rnnt_model): @@ -467,6 +604,74 @@ def test_squeezeformer_forward_mha(self, squeezeformer_ctc_adapter, name): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 + @pytest.mark.unit + @pytest.mark.parametrize('adapter_type', ['linear', 'attn']) + @pytest.mark.parametrize( + 'name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0', 'transf_decoder:adapter_0'] + ) + def test_canary_forward_mha(self, multitask_model, name, adapter_type): + multitask_model.eval() + torch.random.manual_seed(0) + input_signal = torch.randn(2, 512) + input_signal_length = torch.tensor([512, 512], dtype=torch.int32) + transcript = torch.randint(0, multitask_model.tokenizer.vocab_size, size=(2, 10)) + transcript_len = torch.tensor([10, 9], dtype=torch.int32) + + origial_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) + og_logprob 
= origial_output[0] + og_enc_out = origial_output[2] + + if adapter_type == 'attn': + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type, proj_dim=4)) + + new_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) + + new_logprob = new_output[0] + new_enc_out = new_output[2] + + assert torch.mean(torch.abs(og_logprob - new_logprob)) < 1e-5 + assert torch.mean(torch.abs(og_enc_out - new_enc_out)) < 1e-5 + + if 'linear' in adapter_type: + mod_name = name.split(":")[-1] + for mod in multitask_model.modules(): + if isinstance(mod, AdapterModuleMixin): + amodule = mod.get_adapter_module(mod_name) + if amodule is not None: + assert isinstance(amodule, adapter_modules.LinearAdapter) + + # Try to use incorrect adapter + with pytest.raises(ValueError): + multitask_model.add_adapter( + name="transf_encoder:adapter_1", cfg=get_adapter_cfg(in_features=128, atype='mha') + ) + + @pytest.mark.unit + @pytest.mark.parametrize('name', ['transf_decoder:adapter_0']) + def test_canary_forward_mha_decoder_fails_without_support(self, multitask_model, name): + multitask_model.eval() + torch.random.manual_seed(0) + + # Change internal class of transf_decoder module + adapter_class = multitask_model.transf_decoder.__class__ + multitask_model.transf_decoder.__class__ = get_registered_adapter(adapter_class).base_class + + with pytest.raises(AttributeError): + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type)) + @pytest.mark.unit @pytest.mark.parametrize('name1', ['adapter_0', 'encoder:adapter_0', 'decoder:adapter_0']) @pytest.mark.parametrize('name2', ['adapter_1', 'encoder:adapter_1', 'decoder:adapter_1']) @@ -488,7 +693,8 @@ def test_asr_multi_adapter_forward(self, model, name1, name2): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.parametrize('name1', ['decoder:adapter_0', 'joint:adapter_0']) @pytest.mark.parametrize('name2', ['decoder:adapter_1', 'joint:adapter_1']) @@ -582,7 +788,8 @@ def test_constructor_pretrained(self): assert model.num_weights < 1e5 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py index c4ee4b97a2a6..ffaf1e640f3e 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py @@ -111,6 +111,22 @@ def test_rel_pos_encoding_adapter_config(self): assert cls_subset is None assert dataclass_subset is None + @pytest.mark.unit + def test_transformer_mha_adapter_config(self): + IGNORED_ARGS = ['_target_'] + + result = config_utils.assert_dataclass_signature_match( + adapter_modules.TransformerMultiHeadAttentionAdapter, + 
adapter_modules.TransformerMultiHeadAttentionAdapterConfig, + ignore_args=IGNORED_ARGS, + ) + + signatures_match, cls_subset, dataclass_subset = result + + assert signatures_match + assert cls_subset is None + assert dataclass_subset is None + @pytest.mark.unit @pytest.mark.parametrize('n_head', [1, 2, 10]) @pytest.mark.parametrize('proj_dim', [None, -1]) @@ -194,6 +210,31 @@ def test_relpos_encoding_init(self): assert (out - x).sum().abs() <= 1e-8 assert out.shape == x.shape + @pytest.mark.unit + @pytest.mark.parametrize('n_head', [1, 2, 10]) + @pytest.mark.parametrize('proj_dim', [None, -1]) + def test_transformer_mha_adapter_init(self, n_head, proj_dim): + torch.random.manual_seed(0) + x = torch.randn(2, 32, 50) + lengths = torch.randint(1, x.size(1), size=(x.size(0),)) + lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1) + + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter( + num_attention_heads=n_head, hidden_size=50, attn_layer_dropout=0.0, proj_dim=proj_dim + ) + + pad_mask, att_mask = get_mask(lengths) + att_mask = att_mask.unsqueeze(1) + + with torch.no_grad(): + assert adapter.out_projection.weight.sum() == 0 + if hasattr(adapter.out_projection, 'bias') and adapter.out_projection.bias is not None: + assert adapter.out_projection.bias.sum() == 0 + + out = adapter(x, x, x, att_mask) + assert out.sum().abs() <= 1e-8 + assert out.shape == x.shape + @pytest.mark.unit def test_mha_adapter_strategy(self): adapter = adapter_modules.MultiHeadAttentionAdapter(n_head=1, n_feat=50, dropout_rate=0.0) @@ -225,3 +266,13 @@ def test_relpos_encoding_adapter_strategy(self): assert adapter.adapter_strategy is not None # assert default strategy is set assert isinstance(adapter.adapter_strategy, adapter_mixin_strategies.ReturnResultAdapterStrategy) + + @pytest.mark.unit + def test_transformer_mha_adapter_strategy(self): + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter( + num_attention_heads=1, hidden_size=50, attn_layer_dropout=0.0 + ) + assert hasattr(adapter, 'adapter_strategy') + assert adapter.adapter_strategy is not None + # assert default strategy is set + assert isinstance(adapter.adapter_strategy, adapter_modules.MHAResidualAddAdapterStrategy) diff --git a/tests/core/mixins/adapters/test_adapter_model_mixin.py b/tests/core/mixins/adapters/test_adapter_model_mixin.py index 87c6b4e4cfb3..20ced653ceb6 100644 --- a/tests/core/mixins/adapters/test_adapter_model_mixin.py +++ b/tests/core/mixins/adapters/test_adapter_model_mixin.py @@ -14,12 +14,12 @@ import os import shutil import tempfile -from typing import Tuple +from typing import List, Optional, Tuple import pytest import torch from hydra.utils import instantiate -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict from nemo.core import ModelPT, NeuralModule from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins @@ -28,7 +28,7 @@ class DefaultModule(NeuralModule): - """ Define a default neural module (without adapter support)""" + """Define a default neural module (without adapter support)""" def __init__(self): super().__init__() @@ -51,7 +51,7 @@ def num_params(self): class DefaultModuleAdapter(DefaultModule, AdapterModuleMixin): - """ Subclass the DefaultModule, adding adapter module support""" + """Subclass the DefaultModule, adding adapter module support""" def forward(self, x): x = super(DefaultModuleAdapter, self).forward(x) @@ -66,7 +66,7 @@ def forward(self, x): class DefaultModelAdapterMixin(AdapterModelPTMixin): - """ 
Mixin class that implements this model's specific overrides to AdapterModelPTMixin + """Mixin class that implements this model's specific overrides to AdapterModelPTMixin It will container two modules, an encoder and a decoder, and both can have adapters. By default, encoder adapters are enabled, and decoder adapters are diabled. Decoder adapters can be enabled via the global_cfg in model.cfg.adapters. @@ -79,13 +79,13 @@ class DefaultModelAdapterMixin(AdapterModelPTMixin): def setup_adapters(self): supports_adapters = False - # Check the inheriting class' modules supports adapters or not - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - supports_adapters |= True + # At least the encoder must extend AdapterModuleMixin + valid_adapter_names = [x for x in self.adapter_module_names if x != ''] + for module_name in valid_adapter_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + supports_adapters |= True + # If adapters are supported, setup the adapter config + any modules (pre-existing adapter modules) if supports_adapters: super().setup_adapters() @@ -96,66 +96,98 @@ def add_adapter(self, name: str, cfg: DictConfig): # Resolve module name and adapter name module_name, adapter_name = self.resolve_adapter_module_name_(name) - # Try to retrieve global adapter config - global_config = self._get_global_cfg() - - # forward the method call to the individual modules - # If module name is empty, it is a global adapter, otherwise it is a local adapter - if (module_name == '' and global_config.get('encoder_adapter', True)) or (module_name == 'encoder'): - if hasattr(self, 'encoder'): - self.encoder.add_adapter(name, cfg) - - if (module_name == '' and global_config.get('decoder_adapter', False)) or (module_name == 'decoder'): - if hasattr(self, 'decoder'): - self.decoder.add_adapter(name, cfg) + # Use + as a splitter, in order to share one name across multiple modules + if '+' in module_name: + module_names = module_name.split('+') + else: + module_names = [module_name] + + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Update the model.cfg with information about the new adapter from cfg + for module_name in module_names: + # Check if encoder adapters should be added + if module_name == '': + for default in default_module_name: # This model has multiple default modules + if hasattr(self, default): + # Dispatch the call to the default model. + getattr(self, default).add_adapter(name=name, cfg=cfg) + + elif module_name in valid_module_names: + # Check if module exists + if hasattr(self, module_name): + # Dispatch the call to the module. 
+ getattr(self, module_name).add_adapter(name=name, cfg=cfg) def set_enabled_adapters(self, name=None, enabled: bool = True): # check if valid model with some adapter support super().set_enabled_adapters(name, enabled) - # Resolve module name and adapter name + # Resolve the module name and adapter name if name is not None: module_name, _ = self.resolve_adapter_module_name_(name) else: module_name = None - # Try to retrieve global adapter config - global_config = self._get_global_cfg() - - # Forward the method call to the individual modules - if name is None or global_config.get('encoder_adapter', True) or module_name in ('', 'encoder'): - if hasattr(self, 'encoder') and self.encoder.is_adapter_available(): - self.encoder.set_enabled_adapters(name, enabled) - - if name is None or global_config.get('decoder_adapter', False) or module_name == 'decoder': - if hasattr(self, 'decoder') and self.decoder.is_adapter_available(): - self.decoder.set_enabled_adapters(name, enabled) + # Use + as a splitter, in order to share one name across multiple modules + if module_name is not None and '+' in module_name: + module_names = module_name.split('+') + else: + module_names = [module_name] + + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + + # Forward the method call to the individual modules if they exist + for module_name in module_names: + # Check if encoder adapters should be used + + if module_name == '': + for default in default_module_name: + if hasattr(self, default) and isinstance(getattr(self, default), AdapterModuleMixin): + if getattr(self, default).is_adapter_available(): + # Dispatch the call to the default model. + getattr(self, default).set_enabled_adapters(name=name, enabled=enabled) + + elif module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + if getattr(self, module_name).is_adapter_available(): + # Dispatch the call to the module. 
+ getattr(self, module_name).set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> list: enabled_adapters = super().get_enabled_adapters() - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - encoder_adapters = self.encoder.get_enabled_adapters() - enabled_adapters.extend(encoder_adapters) + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - decoder_adapters = self.decoder.get_enabled_adapters() - enabled_adapters.extend(decoder_adapters) + # Check if encoder adapters should be used or are enabled + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + enabled_adapters.extend(getattr(self, module_name).get_enabled_adapters()) + + enabled_adapters = list(sorted(list(set(enabled_adapters)))) return enabled_adapters def is_adapter_available(self) -> bool: adapters_available = super().is_adapter_available() - # Try to retrieve global adapter config - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - print("Encoder is adapter available", self.encoder.is_adapter_available()) - adapters_available |= self.encoder.is_adapter_available() + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - adapters_available |= self.decoder.is_adapter_available() + # Forward the method call to the individual modules + for module_name in valid_module_names: + print("Module name", module_name) + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + adapters_available |= getattr(self, module_name).is_adapter_available() + print("Adapter available for module", module_name, getattr(self, module_name).is_adapter_available()) return adapters_available @@ -198,6 +230,19 @@ def adapter_module_names(self) -> list: valid_adapter_modules = ['', 'encoder', 'decoder'] return valid_adapter_modules + @property + def default_adapter_module_name(self) -> Optional[List[str]]: + global_config = self._get_global_cfg() + default_modules = [] + encoder_adapter = global_config.get('encoder_adapter', True) + decoder_adapter = global_config.get('decoder_adapter', False) + + if encoder_adapter: + default_modules.append('encoder') + if decoder_adapter: + default_modules.append('decoder') + return default_modules + class DefaultAdapterModel(ModelPT, DefaultModelAdapterMixin): def __init__(self, cfg, trainer=None): @@ -302,6 +347,23 @@ def test_base_model_no_support_for_adapters(self, caplog): logging._logger.propagate = False logging.set_verbosity(original_verbosity) + @pytest.mark.unit + def test_base_model_replace_adapter_compatible_modules(self, caplog): + cfg = get_model_config(in_features=50, update_adapter_cfg=False) + model = DefaultAdapterModel(cfg) + + with pytest.raises(AttributeError): + model.add_adapter(name='adapter_0', cfg=get_adapter_cfg()) + + # Replace the modules of the model dynamically to support adapters + model.replace_adapter_compatible_modules() + + assert isinstance(model.encoder, AdapterModuleMixin) + assert model.encoder.is_adapter_available() is False + + model.add_adapter(name='encoder:adapter_0', cfg=get_adapter_cfg()) + assert model.encoder.is_adapter_available() is True + @pytest.mark.unit def 
test_single_adapter(self): cfg = get_model_config(in_features=50) @@ -934,8 +996,18 @@ def test_multiple_decoder_save_load_adapter_only_exact_name(self): assert (original_state_dict[ogkey] - restored_state_dict[newkey]).abs().mean() < 1e-6 @pytest.mark.unit - @pytest.mark.parametrize("decoder", ["adapter_0",]) # "decoder:adapter_0" - @pytest.mark.parametrize("encoder", ["adapter_1",]) # "encoder:adapter_1" + @pytest.mark.parametrize( + "decoder", + [ + "adapter_0", + ], + ) # "decoder:adapter_0" + @pytest.mark.parametrize( + "encoder", + [ + "adapter_1", + ], + ) # "encoder:adapter_1" def test_multiple_save_load_adapter_with_multiple_load(self, decoder, encoder): # create a model config, but do not add global_cfg to it # we want to test just module level adapter