This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

update
szha committed Jun 29, 2018
1 parent 5d2fa00 commit 1abaa13
Showing 2 changed files with 15 additions and 13 deletions.
15 changes: 5 additions & 10 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -24,7 +24,7 @@
 __all__ = ['RNN', 'LSTM', 'GRU']

 from ... import ndarray, symbol
-from .. import HybridBlock
+from .. import HybridBlock, tensor_types
 from . import rnn_cell

 class _RNNLayer(HybridBlock):
@@ -180,17 +180,13 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
     def hybrid_forward(self, F, inputs, states=None, **kwargs):
         if F is ndarray:
             batch_size = inputs.shape[self._layout.find('N')]
-        if self._input_size == 0:
-            for i in range(self._dir):
-                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
-                self.i2h_weight[i]._finish_deferred_init()
         skip_states = states is None
         if skip_states:
             if F is ndarray:
                 states = self.begin_state(batch_size, ctx=inputs.context)
             else:
                 states = self.begin_state(0, func=symbol.zeros)
-        if isinstance(states, (ndarray.NDArray, symbol.Symbol)):
+        if isinstance(states, tensor_types):
             states = [states]
         if F is ndarray:
             for state, info in zip(states, self.state_info(batch_size)):
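
For context, the tensor_types name imported above is assumed to cover exactly the two classes the removed literal check spelled out, so the isinstance test keeps its previous behaviour. A minimal sketch of that assumption (not part of this commit):

# Assumed definition, exposed from the mxnet.gluon package; shown only to
# clarify the isinstance check in the hunk above.
tensor_types = (symbol.Symbol, ndarray.NDArray)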
@@ -203,13 +199,12 @@ def hybrid_forward(self, F, inputs, states=None, **kwargs):
         # out is (output, state)
         return out[0] if skip_states else out

-    def __call__(self, inputs, *states):
-        if self._input_size == 0 and isinstance(inputs, ndarray.NDArray):
+    def infer_shape(self, inputs, *states):
+        if self._input_size == 0:
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        states = list(filter(lambda x: x is not None, states))
-        return super(_RNNLayer, self).__call__(inputs, *states)
+        return super(_RNNLayer, self).infer_shape(inputs, *states)

     def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
13 changes: 10 additions & 3 deletions tests/python/unittest/test_gluon_rnn.py
@@ -246,7 +246,10 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()
     with mx.autograd.record():
-        out = layer(inputs, states)
+        if states is None:
+            out = layer(inputs)
+        else:
+            out = layer(inputs, states)
     if states is not None:
         assert isinstance(out, (list, tuple)) and len(out) == 2
         out = out[0]
@@ -260,15 +263,19 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.hybridize()

     with mx.autograd.record():
-        out = layer(inputs, states)
         if states is not None:
+            out = layer(inputs, states)
             assert isinstance(out, (list, tuple)) and len(out) == 2
             out = out[0]
         else:
+            out = layer(inputs)
             assert isinstance(out, mx.nd.NDArray)
         out.backward()

-    layer(inputs, states) # test is_training = false
+    if states is not None:
+        layer(inputs, states) # test is_training = false
+    else:
+        layer(inputs)

     if not run_only:
         mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
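
With both branches in place, the helper can now be driven with or without an initial state. Hypothetical calls illustrating the two paths (the layer and shapes are chosen for a 2-layer GRU with hidden size 10 and are not copied from the test file):

# states passed explicitly: layer(inputs, states) returns an (output, states) pair
check_rnn_layer_forward(mx.gluon.rnn.GRU(10, num_layers=2),
                        mx.nd.ones((8, 3, 20)),    # (seq_len, batch, input_size)
                        mx.nd.ones((2, 3, 10)))    # (num_layers, batch, hidden_size)

# states omitted: layer(inputs) returns a single output NDArray
check_rnn_layer_forward(mx.gluon.rnn.GRU(10, num_layers=2),
                        mx.nd.ones((8, 3, 20)))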
