diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index c00f8a39d8c3..d0505f87ac40 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -672,11 +672,15 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N mode=self._mode, name=self._prefix+'rnn', **states) + attr = {'__layout__' : 'LNC'} if not self._get_next_state: outputs, states = rnn, [] elif self._mode == 'lstm': + rnn[1]._set_attr(**attr) + rnn[2]._set_attr(**attr) outputs, states = rnn[0], [rnn[1], rnn[2]] else: + rnn[1]._set_attr(**attr) outputs, states = rnn[0], [rnn[1]] if axis == 1: