diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index 56ee0002eb..a8b041054c 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -133,6 +133,10 @@ def forward( vocab_size) if use_output_layer is True, torch.tensor(0.0), in order to unify api with bidirectional decoder olens: (batch, ) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 """ tgt = ys_in_pad maxlen = tgt.size(1) diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py index 22e239a737..844d9680d6 100644 --- a/wenet/transformer/encoder.py +++ b/wenet/transformer/encoder.py @@ -129,6 +129,10 @@ def forward( xs: padded output tensor (B, T' ~= T/subsample_rate, D) masks: torch.Tensor batch padding mask after subsample (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. + https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 """ T = xs.size(1) masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)