feat(train): Support gradient checkpointing for Conformer & Transformer
xingchensong committed Nov 28, 2023
1 parent 0df2759 commit fbb0557
Showing 3 changed files with 33 additions and 7 deletions.
18 changes: 14 additions & 4 deletions wenet/transformer/decoder.py
@@ -46,6 +46,8 @@ class TransformerDecoder(torch.nn.Module):
src_attention: if false, encoder-decoder cross attention is not
applied, such as CIF model
key_bias: whether use bias in attention.linear_k, False for whisper models.
gradient_checkpointing: if True, trade compute for memory by re-running
    the forward pass of each checkpointed block during backward
    instead of storing its activations.
"""

def __init__(
@@ -65,6 +67,7 @@ def __init__(
src_attention: bool = True,
key_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
):
super().__init__()
attention_dim = encoder_output_size
@@ -102,6 +105,8 @@ def __init__(
) for _ in range(self.num_blocks)
])

self.gradient_checkpointing = gradient_checkpointing

def forward(
self,
memory: torch.Tensor,
@@ -140,8 +145,12 @@ def forward(
tgt_mask = tgt_mask & m
x, _ = self.embed(tgt)
for layer in self.decoders:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
if self.gradient_checkpointing and self.training:
x, tgt_mask, memory, memory_mask = torch.utils.checkpoint.checkpoint(
    layer, x, tgt_mask, memory, memory_mask)
else:
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.use_output_layer:
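For reference, the checkpointed branch above follows PyTorch's activation-checkpointing API: torch.utils.checkpoint.checkpoint(fn, *args) runs fn without keeping its intermediate activations and re-runs fn's forward when backward reaches that segment. A minimal, self-contained sketch with a toy layer (not WeNet code; use_reentrant=False is passed only because recent PyTorch versions ask for an explicit choice, whereas the commit relies on the default):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyDecoderLayer(nn.Module):
    """Stand-in for a decoder layer: consumes and returns (x, mask)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, mask):
        return torch.relu(self.linear(x)) * mask, mask


decoders = nn.ModuleList([ToyDecoderLayer() for _ in range(4)])
x = torch.randn(2, 16, 8, requires_grad=True)
mask = torch.ones(2, 16, 1)

for layer in decoders:
    if decoders.training:  # mirrors the self.training guard in the diff
        # The layer is passed as the first argument; its activations are
        # discarded after the forward pass and recomputed during backward.
        x, mask = checkpoint(layer, x, mask, use_reentrant=False)
    else:
        x, mask = layer(x, mask)

x.sum().backward()  # triggers one extra forward pass per checkpointed layer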
@@ -252,6 +261,7 @@ def __init__(
use_output_layer: bool = True,
normalize_before: bool = True,
key_bias: bool = True,
gradient_checkpointing: bool = False,
):

super().__init__()
@@ -260,14 +270,14 @@
num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before,
key_bias=key_bias)
key_bias=key_bias, gradient_checkpointing=gradient_checkpointing)

self.right_decoder = TransformerDecoder(
vocab_size, encoder_output_size, attention_heads, linear_units,
r_num_blocks, dropout_rate, positional_dropout_rate,
self_attention_dropout_rate, src_attention_dropout_rate,
input_layer, use_output_layer, normalize_before,
key_bias=key_bias)
key_bias=key_bias, gradient_checkpointing=gradient_checkpointing)

def forward(
self,
18 changes: 15 additions & 3 deletions wenet/transformer/encoder.py
@@ -49,6 +49,7 @@ def __init__(
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
gradient_checkpointing: bool = False,
):
"""
Args:
@@ -78,6 +79,8 @@ def __init__(
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
dynamic chunk training
key_bias: whether use bias in attention.linear_k, False for whisper models.
gradient_checkpointing: if True, trade compute for memory by re-running
    the forward pass of each checkpointed block during backward
    instead of storing its activations.
"""
super().__init__()
self._output_size = output_size
@@ -95,6 +98,7 @@ def __init__(
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
self.gradient_checkpointing = gradient_checkpointing

def output_size(self) -> int:
return self._output_size
@@ -138,7 +142,11 @@ def forward(
self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.gradient_checkpointing and self.training:
xs, chunk_masks, _, _ = torch.utils.checkpoint.checkpoint(
layer, xs, chunk_masks, pos_emb, mask_pad)
else:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
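Because of the self.training guard, checkpointing is skipped under model.eval(), and when active it changes only the memory/compute trade-off, not the numerics. A quick sanity check on a toy stack (again not WeNet code) showing that outputs and input gradients are identical with and without checkpointing:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
stack = nn.ModuleList(
    [nn.Sequential(nn.Linear(8, 8), nn.GELU()) for _ in range(3)])


def run(xs: torch.Tensor, use_ckpt: bool) -> torch.Tensor:
    for layer in stack:
        if use_ckpt and stack.training:
            xs = checkpoint(layer, xs, use_reentrant=False)
        else:
            xs = layer(xs)
    return xs


x = torch.randn(4, 8, requires_grad=True)

out_plain = run(x, use_ckpt=False)
(grad_plain,) = torch.autograd.grad(out_plain.sum(), x)

out_ckpt = run(x, use_ckpt=True)
(grad_ckpt,) = torch.autograd.grad(out_ckpt.sum(), x)

# Same math either way; checkpointing only recomputes activations in backward.
assert torch.allclose(out_plain, out_ckpt)
assert torch.allclose(grad_plain, grad_ckpt)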
@@ -315,6 +323,7 @@ def __init__(
use_dynamic_left_chunk: bool = False,
key_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
):
""" Construct TransformerEncoder
@@ -325,7 +334,8 @@
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk,
global_cmvn, use_dynamic_left_chunk)
global_cmvn, use_dynamic_left_chunk,
gradient_checkpointing)
activation = WENET_ACTIVATION_CLASSES[activation_type]()
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
@@ -367,6 +377,7 @@ def __init__(
causal: bool = False,
cnn_module_norm: str = "batch_norm",
key_bias: bool = True,
gradient_checkpointing: bool = False,
):
"""Construct ConformerEncoder
@@ -390,7 +401,8 @@
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk,
global_cmvn, use_dynamic_left_chunk)
global_cmvn, use_dynamic_left_chunk,
gradient_checkpointing)
activation = WENET_ACTIVATION_CLASSES[activation_type]()

# self-attention module definition
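As a usage sketch (assuming the standard WeNet signatures, where ConformerEncoder takes the input feature dimension as its first argument and forward takes padded features plus their lengths), the new flag is just one more constructor keyword:

import torch
from wenet.transformer.encoder import ConformerEncoder

# Small illustrative configuration; gradient_checkpointing is the keyword
# added by this commit and defaults to False.
encoder = ConformerEncoder(
    input_size=80,          # e.g. 80-dim fbank features
    output_size=256,
    attention_heads=4,
    linear_units=1024,
    num_blocks=2,
    gradient_checkpointing=True,
)
encoder.train()  # the checkpointing branch is taken only while training

xs = torch.randn(2, 200, 80)        # (batch, frames, feat_dim)
xs_lens = torch.tensor([200, 160])  # valid frames per utterance
out, out_mask = encoder(xs, xs_lens)
out.sum().backward()  # each encoder layer's forward is re-run here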
4 changes: 4 additions & 0 deletions wenet/utils/train_utils.py
@@ -128,6 +128,10 @@ def add_ddp_args(parser):
action='store_true',
default=False,
help='Use fp16 gradient sync for ddp')
parser.add_argument('--use_grad_checkpoint',
action='store_true',
default=False,
help='Enable gradient_checkpointing during training')
return parser
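The hunk above only adds the CLI switch; how the flag reaches the model config is not shown in these three files. One plausible wiring, offered purely as a hypothetical sketch (apply_grad_checkpoint_flag is an illustrative name, not a WeNet function):

def apply_grad_checkpoint_flag(args, configs: dict) -> dict:
    """Propagate --use_grad_checkpoint into the encoder/decoder config
    sections so the constructors above receive gradient_checkpointing=True."""
    if getattr(args, 'use_grad_checkpoint', False):
        configs.setdefault('encoder_conf', {})['gradient_checkpointing'] = True
        configs.setdefault('decoder_conf', {})['gradient_checkpointing'] = True
    return configs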


