diff --git a/espnet/nets/pytorch_backend/rnn/encoders.py b/espnet/nets/pytorch_backend/rnn/encoders.py
index f01acd5a6a4..7ab90e6c3ae 100644
--- a/espnet/nets/pytorch_backend/rnn/encoders.py
+++ b/espnet/nets/pytorch_backend/rnn/encoders.py
@@ -69,7 +69,8 @@ def forward(self, xs_pad, ilens, prev_state=None):
                 ilens = torch.tensor(ilens)
             xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
             rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
-            rnn.flatten_parameters()
+            if self.training:
+                rnn.flatten_parameters()
             if prev_state is not None and rnn.bidirectional:
                 prev_state = reset_backward_rnn_state(prev_state)
             ys, states = rnn(
@@ -144,7 +145,8 @@ def forward(self, xs_pad, ilens, prev_state=None):
         if not isinstance(ilens, torch.Tensor):
             ilens = torch.tensor(ilens)
         xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
-        self.nbrnn.flatten_parameters()
+        if self.training:
+            self.nbrnn.flatten_parameters()
         if prev_state is not None and self.nbrnn.bidirectional:
             # We assume that when previous state is passed,
             # it means that we're streaming the input