diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index c00f8a39d8c3..31015faf6fcf 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -675,8 +675,11 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N if not self._get_next_state: outputs, states = rnn, [] elif self._mode == 'lstm': + rnn[1]._set_attr(__layout__='LNC') + rnn[2]._set_attr(__layout__='LNC') outputs, states = rnn[0], [rnn[1], rnn[2]] else: + rnn[1]._set_attr(__layout__='LNC') outputs, states = rnn[0], [rnn[1]] if axis == 1: