fix error for rnn encoders flatten_parameters
pyf98 committed May 10, 2022
1 parent 052dd60 commit cd77501
Showing 1 changed file with 4 additions and 2 deletions.
espnet/nets/pytorch_backend/rnn/encoders.py
@@ -69,7 +69,8 @@ def forward(self, xs_pad, ilens, prev_state=None):
                 ilens = torch.tensor(ilens)
             xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
             rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
-            rnn.flatten_parameters()
+            if self.training:
+                rnn.flatten_parameters()
             if prev_state is not None and rnn.bidirectional:
                 prev_state = reset_backward_rnn_state(prev_state)
             ys, states = rnn(
@@ -144,7 +145,8 @@ def forward(self, xs_pad, ilens, prev_state=None):
         if not isinstance(ilens, torch.Tensor):
             ilens = torch.tensor(ilens)
         xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
-        self.nbrnn.flatten_parameters()
+        if self.training:
+            self.nbrnn.flatten_parameters()
         if prev_state is not None and self.nbrnn.bidirectional:
             # We assume that when previous state is passed,
             # it means that we're streaming the input
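In both hunks the change is the same: PyTorch's flatten_parameters(), which compacts an RNN's weights into one contiguous block for cuDNN, is now called only when the module is in training mode instead of on every forward pass. The commit message says the unconditional call raised an error; the guard simply skips it at inference time. Below is a minimal standalone sketch of the pattern for context. ToyEncoder, the layer sizes, and the dummy batch are illustrative and not part of ESPnet:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence


class ToyEncoder(nn.Module):
    """Illustrative encoder applying the same guard as this commit."""

    def __init__(self, idim=40, hdim=64):
        super().__init__()
        self.nbrnn = nn.LSTM(idim, hdim, batch_first=True, bidirectional=True)

    def forward(self, xs_pad, ilens):
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
        if self.training:
            # Compact the RNN weights for cuDNN only in training mode,
            # mirroring the guard added by this commit.
            self.nbrnn.flatten_parameters()
        ys, states = self.nbrnn(xs_pack)
        return ys, states


# Usage: the guard fires under train() and is skipped under eval().
enc = ToyEncoder()
xs = torch.randn(2, 5, 40)  # (batch, time, feature); lengths sorted descending
enc.train()
enc(xs, [5, 3])             # flatten_parameters() runs
enc.eval()
with torch.no_grad():
    enc(xs, [5, 3])         # flatten_parameters() is skipped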
