feat(train): add note
xingchensong committed Nov 29, 2023
1 parent 25705d0 commit 4d35b5a
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions wenet/transformer/decoder.py
@@ -133,6 +133,10 @@ def forward(
     vocab_size) if use_output_layer is True,
     torch.tensor(0.0), in order to unify api with bidirectional decoder
     olens: (batch, )
+    NOTE(xcsong):
+        We pass the `__call__` method of the modules instead of `forward` to the
+        checkpointing API because `__call__` attaches all the hooks of the module.
+        https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
     """
     tgt = ys_in_pad
     maxlen = tgt.size(1)
4 changes: 4 additions & 0 deletions wenet/transformer/encoder.py
@@ -129,6 +129,10 @@ def forward(
     xs: padded output tensor (B, T' ~= T/subsample_rate, D)
     masks: torch.Tensor batch padding mask after subsample
         (B, 1, T' ~= T/subsample_rate)
+    NOTE(xcsong):
+        We pass the `__call__` method of the modules instead of `forward` to the
+        checkpointing API because `__call__` attaches all the hooks of the module.
+        https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
     """
     T = xs.size(1)
     masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
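The behavior the note describes is easy to verify with a small standalone sketch (not part of this commit). Assuming the checkpointing API in question is torch.utils.checkpoint.checkpoint, passing the bound `forward` method bypasses registered forward hooks, while passing `__call__` dispatches through PyTorch's hook machinery and runs them:

    # Minimal sketch: hooks fire only when a module is invoked via __call__,
    # not when its forward() is called directly.
    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(4, 4)
    layer.register_forward_hook(lambda mod, inp, out: print("hook fired"))

    x = torch.randn(2, 4, requires_grad=True)

    # Bypasses the hook: checkpoint invokes forward() directly.
    y1 = checkpoint(layer.forward, x, use_reentrant=False)   # prints nothing

    # Runs the hook: __call__ attaches all the hooks of the module.
    y2 = checkpoint(layer.__call__, x, use_reentrant=False)  # prints "hook fired"

This is why the checkpointed layer loops that these docstrings document pass `layer.__call__` rather than `layer.forward`.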
