Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
make gluon rnn layers hybrid blocks WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jul 19, 2018
1 parent 9b30af2 commit de75d33
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 57 deletions.
106 changes: 59 additions & 47 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray
from .. import Block
from ... import ndarray, symbol
from .. import HybridBlock, tensor_types
from . import rnn_cell


class _RNNLayer(Block):
class _RNNLayer(HybridBlock):
"""Implementation of recurrent layers."""
def __init__(self, hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
Expand Down Expand Up @@ -78,7 +77,13 @@ def __init__(self, hidden_size, num_layers, layout,
allow_deferred_init=True))
ni = nh * self._dir

<<<<<<< HEAD
self._unfused = self._unfuse()
=======
for param_list in [self.i2h_weight, self.h2h_weight, self.i2h_bias, self.h2h_bias]:
for p in param_list:
self._reg_params[p.name] = p
>>>>>>> make gluon rnn layers hybrid blocks WIP

def __repr__(self):
s = '{name}({mapping}, {_layout}'
Expand All @@ -98,8 +103,15 @@ def __repr__(self):
def state_info(self, batch_size=0):
raise NotImplementedError

def _unfuse(self):
"""Unfuses the fused RNN in to a stack of rnn cells."""
def unfuse(self):
"""Unfuses the fused RNN in to a stack of rnn cells.
Returns
-------
cell : SequentialRNNCell
A sequential RNN cell that replicates the structure of the RNN layer, with shared
weights.
"""
get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
activation='relu',
**kwargs),
Expand Down Expand Up @@ -169,63 +181,63 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states

def forward(self, inputs, states=None):
batch_size = inputs.shape[self._layout.find('N')]
def hybrid_forward(self, F, inputs, states=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
if self._input_size == 0:
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
skip_states = states is None
if skip_states:
states = self.begin_state(batch_size, ctx=inputs.context)
if isinstance(states, ndarray.NDArray):
if F is ndarray:
states = self.begin_state(batch_size, ctx=inputs.context)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
if self._input_size == 0:
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
out = self._forward_kernel(inputs, states)
if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
out = self._forward_kernel(F, inputs, states, **kwargs)

# out is (output, state)
return out[0] if skip_states else out

def _forward(self, inputs, states):
"""forward using gluon cell"""
ns = len(states)
axis = self._layout.find('T')
states = sum(zip(*((j for j in i) for i in states)), ())
outputs, states = self._unfused.unroll(
inputs.shape[axis], inputs, states,
layout=self._layout, merge_outputs=True)
new_states = []
for i in range(ns):
state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
new_states.append(state)

return outputs, new_states

def _forward_kernel(self, inputs, states):
def __call__(self, inputs, *states):
if self._input_size == 0 and isinstance(inputs, ndarray.NDArray):
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
states = list(filter(lambda x: x is not None, states))
return super(_RNNLayer, self).__call__(inputs, *states)

def _forward_kernel(self, F, inputs, states, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
ctx = inputs.context
params = sum(zip(self.i2h_weight, self.h2h_weight), ())
params += sum(zip(self.i2h_bias, self.h2h_bias), ())
params = (i.data(ctx).reshape((-1,)) for i in params)
params = ndarray.concat(*params, dim=0)

rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
prefix = self._prefix[:-1] if self._prefix[-1] == '_' else self._prefix
params = (kwargs['{}_{}{}_{}_{}'.format(prefix, j, i, c, p)].reshape((-1,))
for p in ['weight', 'bias']
for c in ['i2h', 'h2h']
for i in range(self._num_layers)
for j in (['l', 'r'] if self._dir == 2 else ['l']))
params = F.concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
outputs, states = rnn[0], [rnn[1]]

if self._layout == 'NTC':
outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
outputs = F.swapaxes(outputs, dim1=0, dim2=1)

return outputs, states

Expand Down
6 changes: 3 additions & 3 deletions src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
DMLC_REGISTER_PARAMETER(RNNParam);

MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
**Vanilla RNN**
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.
With ReLU activation function:
Expand All @@ -63,7 +63,7 @@ With Tanh activtion function:
.. math::
h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
Reference paper: Finding structure in time - Elman, 1988.
Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf
**LSTM**
Expand Down
32 changes: 25 additions & 7 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,12 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
layer.collect_params().initialize()
inputs.attach_grad()
with mx.autograd.record():
out = layer(inputs, states)
if states is None:
out = layer(inputs)
else:
out = layer(inputs, states)
if states is not None:
assert isinstance(out, tuple) and len(out) == 2
assert isinstance(out, (list, tuple)) and len(out) == 2
out = out[0]
else:
assert isinstance(out, mx.nd.NDArray)
Expand All @@ -410,15 +413,19 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
layer.hybridize()

with mx.autograd.record():
out = layer(inputs, states)
if states is not None:
assert isinstance(out, tuple) and len(out) == 2
out = layer(inputs, states)
assert isinstance(out, (list, tuple)) and len(out) == 2
out = out[0]
else:
out = layer(inputs)
assert isinstance(out, mx.nd.NDArray)
out.backward()

layer(inputs, states) # test is_training = false
if states is not None:
layer(inputs, states) # test is_training = false
else:
layer(inputs)

if not run_only:
mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
Expand Down Expand Up @@ -448,15 +455,26 @@ def test_rnn_layers():
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)

net = gluon.nn.Sequential()
net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
net = gluon.nn.HybridSequential()
net.add(gluon.rnn.LSTM(10, bidirectional=True))
net.add(gluon.nn.BatchNorm(axis=2))
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(3, activation='relu'))
net.hybridize()
net.collect_params().initialize()
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()

net2 = gluon.nn.HybridSequential()
net2.add(gluon.rnn.LSTM(10, bidirectional=True))
net2.add(gluon.nn.BatchNorm(axis=2))
net2.add(gluon.nn.Flatten())
net2.add(gluon.nn.Dense(3, activation='relu'))
net2.hybridize()
net2.collect_params().initialize()
with mx.autograd.record():
net2(mx.nd.ones((2, 3, 10))).backward()


def test_rnn_unroll_variant_length():
# Test for imperative usage
Expand Down

0 comments on commit de75d33

Please sign in to comment.