diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index 82824e5fa8..712bef8051 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -46,6 +46,8 @@ class TransformerDecoder(torch.nn.Module):
         src_attention: if false, encoder-decoder cross attention is not
             applied, such as CIF model
         key_bias: whether use bias in attention.linear_k, False for whisper models.
+        gradient_checkpointing: if True, rerun a forward-pass segment for each
+            checkpointed segment during the backward pass to save memory.
     """

     def __init__(
@@ -65,6 +67,7 @@ def __init__(
         src_attention: bool = True,
         key_bias: bool = True,
         activation_type: str = "relu",
+        gradient_checkpointing: bool = False,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -102,6 +105,8 @@ def __init__(
             ) for _ in range(self.num_blocks)
         ])

+        self.gradient_checkpointing = gradient_checkpointing
+
     def forward(
         self,
         memory: torch.Tensor,
@@ -140,8 +145,12 @@ def forward(
             tgt_mask = tgt_mask & m
         x, _ = self.embed(tgt)
         for layer in self.decoders:
-            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
-                                                     memory_mask)
+            if self.gradient_checkpointing and self.training:
+                x, tgt_mask, memory, memory_mask = torch.utils.checkpoint.checkpoint(
+                    layer, x, tgt_mask, memory, memory_mask)
+            else:
+                x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
+                                                         memory_mask)
         if self.normalize_before:
             x = self.after_norm(x)
         if self.use_output_layer:
@@ -252,6 +261,7 @@ def __init__(
         use_output_layer: bool = True,
         normalize_before: bool = True,
         key_bias: bool = True,
+        gradient_checkpointing: bool = False,
     ):

         super().__init__()
@@ -260,14 +270,14 @@ def __init__(
             num_blocks, dropout_rate, positional_dropout_rate,
             self_attention_dropout_rate, src_attention_dropout_rate,
             input_layer, use_output_layer, normalize_before,
-            key_bias=key_bias)
+            key_bias=key_bias, gradient_checkpointing=gradient_checkpointing)

         self.right_decoder = TransformerDecoder(
             vocab_size, encoder_output_size, attention_heads, linear_units,
             r_num_blocks, dropout_rate, positional_dropout_rate,
             self_attention_dropout_rate, src_attention_dropout_rate,
             input_layer, use_output_layer, normalize_before,
-            key_bias=key_bias)
+            key_bias=key_bias, gradient_checkpointing=gradient_checkpointing)

     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 8179d47374..dac6f3846d 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -49,6 +49,7 @@ def __init__(
         use_dynamic_chunk: bool = False,
         global_cmvn: torch.nn.Module = None,
         use_dynamic_left_chunk: bool = False,
+        gradient_checkpointing: bool = False,
     ):
         """
         Args:
@@ -78,6 +79,8 @@ def __init__(
             use_dynamic_left_chunk (bool): whether use dynamic left chunk in
                 dynamic chunk training
             key_bias: whether use bias in attention.linear_k, False for whisper models.
+            gradient_checkpointing: if True, rerun a forward-pass segment for each
+                checkpointed segment during the backward pass to save memory.
         """
         super().__init__()
         self._output_size = output_size
@@ -95,6 +98,7 @@ def __init__(
         self.static_chunk_size = static_chunk_size
         self.use_dynamic_chunk = use_dynamic_chunk
         self.use_dynamic_left_chunk = use_dynamic_left_chunk
+        self.gradient_checkpointing = gradient_checkpointing

     def output_size(self) -> int:
         return self._output_size
@@ -138,7 +142,11 @@ def forward(
                                               self.static_chunk_size,
                                               num_decoding_left_chunks)
         for layer in self.encoders:
-            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+            if self.gradient_checkpointing and self.training:
+                xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
+                    layer, xs, chunk_masks, pos_emb, mask_pad)
+            else:
+                xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
@@ -315,6 +323,7 @@ def __init__(
         use_dynamic_left_chunk: bool = False,
         key_bias: bool = True,
         activation_type: str = "relu",
+        gradient_checkpointing: bool = False,
     ):
         """ Construct TransformerEncoder

@@ -325,7 +334,8 @@ def __init__(
                          positional_dropout_rate, attention_dropout_rate,
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk,
-                         global_cmvn, use_dynamic_left_chunk)
+                         global_cmvn, use_dynamic_left_chunk,
+                         gradient_checkpointing)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
@@ -367,6 +377,7 @@ def __init__(
         causal: bool = False,
         cnn_module_norm: str = "batch_norm",
         key_bias: bool = True,
+        gradient_checkpointing: bool = False,
     ):
         """Construct ConformerEncoder

@@ -390,7 +401,8 @@ def __init__(
                          positional_dropout_rate, attention_dropout_rate,
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk,
-                         global_cmvn, use_dynamic_left_chunk)
+                         global_cmvn, use_dynamic_left_chunk,
+                         gradient_checkpointing)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()

         # self-attention module definition
diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index 37e8b85fff..955a922e31 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -128,6 +128,10 @@ def add_ddp_args(parser):
                         action='store_true',
                         default=False,
                         help='Use fp16 gradient sync for ddp')
+    parser.add_argument('--use_grad_checkpoint',
+                        action='store_true',
+                        default=False,
+                        help='Enable gradient_checkpointing during training')
     return parser