Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
make Gluon RNN layer hybrid block
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jul 27, 2018
1 parent 9b30af2 commit 03ea58d
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 120 deletions.
165 changes: 60 additions & 105 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray
from .. import Block
from . import rnn_cell
from ... import ndarray, symbol
from .. import HybridBlock, tensor_types


class _RNNLayer(Block):
class _RNNLayer(HybridBlock):
"""Implementation of recurrent layers."""
def __init__(self, hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
Expand All @@ -52,33 +50,28 @@ def __init__(self, hidden_size, num_layers, layout,

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

self.i2h_weight = []
self.h2h_weight = []
self.i2h_bias = []
self.h2h_bias = []

ng, ni, nh = self._gates, input_size, hidden_size
for i in range(num_layers):
for j in (['l', 'r'] if self._dir == 2 else ['l']):
self.i2h_weight.append(
self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni),
init=i2h_weight_initializer,
allow_deferred_init=True))
self.h2h_weight.append(
self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh),
init=h2h_weight_initializer,
allow_deferred_init=True))
self.i2h_bias.append(
self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,),
init=i2h_bias_initializer,
allow_deferred_init=True))
self.h2h_bias.append(
self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,),
init=h2h_bias_initializer,
allow_deferred_init=True))
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
ni = nh * self._dir

self._unfused = self._unfuse()
def _register_param(self, name, shape, init):
    """Create a deferred-init parameter in ``self.params`` and expose it
    as an attribute of the same name on this layer.

    Returns the created :class:`Parameter`.
    """
    param = self.params.get(name, shape=shape, init=init,
                            allow_deferred_init=True)
    setattr(self, name, param)
    return param

def __repr__(self):
s = '{name}({mapping}, {_layout}'
Expand All @@ -89,51 +82,26 @@ def __repr__(self):
if self._dir == 2:
s += ', bidirectional'
s += ')'
shape = self.i2h_weight[0].shape
shape = self.l0_i2h_weight.shape
mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def _collect_params_with_prefix(self, prefix=''):
    """Collect parameters keyed in the legacy (pre-fused) naming scheme.

    Keys of directly registered parameters are rewritten so checkpoints
    saved by the old unfused-cell implementation still load.
    """
    if prefix:
        prefix += '.'

    def convert_key(key):  # for compatibility with old parameter format
        # e.g. 'l0_i2h_weight' -> '_unfused.0.l_cell.i2h_weight'
        head, _, tail = key.partition('_')
        return '_unfused.{}.{}_cell.{}'.format(head[1:], head[0], tail)

    ret = {prefix + convert_key(name): param
           for name, param in self._reg_params.items()}
    for name, child in self._children.items():
        ret.update(child._collect_params_with_prefix(prefix + name))
    return ret

def state_info(self, batch_size=0):
    # Abstract: concrete layers (RNN/LSTM/GRU) return a list of dicts
    # describing each recurrent state's shape/layout for `batch_size`.
    raise NotImplementedError

def _unfuse(self):
"""Unfuses the fused RNN in to a stack of rnn cells."""
get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
activation='relu',
**kwargs),
'rnn_tanh': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
activation='tanh',
**kwargs),
'lstm': lambda **kwargs: rnn_cell.LSTMCell(self._hidden_size,
**kwargs),
'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
**kwargs)}[self._mode]

stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params)
with stack.name_scope():
ni = self._input_size
for i in range(self._num_layers):
kwargs = {'input_size': ni,
'i2h_weight_initializer': self._i2h_weight_initializer,
'h2h_weight_initializer': self._h2h_weight_initializer,
'i2h_bias_initializer': self._i2h_bias_initializer,
'h2h_bias_initializer': self._h2h_bias_initializer}
if self._dir == 2:
stack.add(rnn_cell.BidirectionalCell(
get_cell(prefix='l%d_'%i, **kwargs),
get_cell(prefix='r%d_'%i, **kwargs)))
else:
stack.add(get_cell(prefix='l%d_'%i, **kwargs))

if self._dropout > 0 and i != self._num_layers - 1:
stack.add(rnn_cell.DropoutCell(self._dropout))

ni = self._hidden_size * self._dir

return stack

def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
"""Initial state for this cell.
Expand Down Expand Up @@ -169,63 +137,50 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states

def forward(self, inputs, states=None):
batch_size = inputs.shape[self._layout.find('N')]
def hybrid_forward(self, F, inputs, states=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
skip_states = states is None
if skip_states:
states = self.begin_state(batch_size, ctx=inputs.context)
if isinstance(states, ndarray.NDArray):
if F is ndarray:
states = self.begin_state(batch_size, ctx=inputs.context)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
if self._input_size == 0:
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
out = self._forward_kernel(inputs, states)
if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
out = self._forward_kernel(F, inputs, states, **kwargs)

# out is (output, state)
return out[0] if skip_states else out

def _forward(self, inputs, states):
"""forward using gluon cell"""
ns = len(states)
axis = self._layout.find('T')
states = sum(zip(*((j for j in i) for i in states)), ())
outputs, states = self._unfused.unroll(
inputs.shape[axis], inputs, states,
layout=self._layout, merge_outputs=True)
new_states = []
for i in range(ns):
state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
new_states.append(state)

return outputs, new_states

def _forward_kernel(self, inputs, states):
def _forward_kernel(self, F, inputs, states, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
ctx = inputs.context
params = sum(zip(self.i2h_weight, self.h2h_weight), ())
params += sum(zip(self.i2h_bias, self.h2h_bias), ())
params = (i.data(ctx).reshape((-1,)) for i in params)
params = ndarray.concat(*params, dim=0)

rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
params = (kwargs['{}{}_{}_{}'.format(j, i, c, p)].reshape(-1)
for p in ['weight', 'bias']
for c in ['i2h', 'h2h']
for i in range(self._num_layers)
for j in (['l', 'r'] if self._dir == 2 else ['l']))
params = F._internal._rnn_param_concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
outputs, states = rnn[0], [rnn[1]]

if self._layout == 'NTC':
outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
outputs = F.swapaxes(outputs, dim1=0, dim2=1)

return outputs, states

Expand Down
87 changes: 87 additions & 0 deletions src/operator/nn/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,57 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
return dshape.Size() != 0;
}

// Shape inference for _rnn_param_concat. Behaves like the regular Concat
// shape inference, with one extension: when the output size along the
// concat axis is known but some inputs still have an unknown (zero) size
// on that axis, the residual output size is split evenly among those
// inputs so their shapes can be back-inferred. This supports the RNN
// layers' deferred-init i2h weights, whose input dimension is unknown
// until the first forward pass.
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
                                std::vector<TShape> *in_shape,
                                std::vector<TShape> *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  TShape dshape;     // common shape of all inputs with the concat axis zeroed
  index_t size = 0;  // sum of the known input sizes along the concat axis
  int num_zero = 0;  // how many inputs have an unknown size on that axis
  int axis = -1;
  for (int i = 0; i < param_.num_args; ++i) {
    TShape tmp = (*in_shape)[i];
    if (tmp.ndim()) {
      axis = CheckAxis(param_.dim, tmp.ndim());
      num_zero += tmp[axis] == 0;
      size += tmp[axis];
      tmp[axis] = 0;
      shape_assign(&dshape, tmp);
    }
  }

  // Fold the output shape (if already known) into the common shape too.
  TShape tmp = (*out_shape)[0];
  if (tmp.ndim()) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = 0;
    shape_assign(&dshape, tmp);
  }

  if (dshape.ndim() == 0) return false;  // nothing known yet; try again later

  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  if (!num_zero) dshape[axis] = size;
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
  // Back-infer the unknown input sizes from the known output size.
  index_t residual = (*out_shape)[0][axis] - size;
  if ((*out_shape)[0].Size() != 0 && residual > 0 && num_zero) {
    bool need_infer = false;
    // NOTE(review): this writes to inputs 0..num_zero-1 rather than to the
    // inputs whose axis size is actually 0 — correct only if the unknown
    // inputs always come first in the argument list (as in the parameter
    // order generated by the Gluon RNN layer); confirm for other callers.
    for (int i = 0; i < num_zero; i++) {
      (*in_shape)[i][axis] = residual / num_zero;
      need_infer = need_infer || (*in_shape)[i].Size() == 0;
    }
    return !need_infer;
  }

  return dshape.Size() != 0;
}

static bool ConcatType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
Expand Down Expand Up @@ -320,5 +371,41 @@ NNVM_REGISTER_OP(_backward_Concat)
#endif
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);


// Internal operator used by the Gluon RNN layers to concatenate all
// per-layer weights/biases into the single flat parameter blob consumed
// by the fused RNN operator. Identical to Concat except for its shape
// inference (RNNParamConcatShape), which can back-infer unknown input
// shapes from a known output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
    [](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  std::vector<std::string> ret;
  for (int i = 0; i < params.num_args; ++i) {
    ret.push_back(std::string("arg") + std::to_string(i));
  }
  return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
    [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
// MKL-DNN concat path needs scratch space.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr<nnvm::FInferShape>("FInferShape", RNNParamConcatShape)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
// Forward/backward computation is shared with the regular Concat op.
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

} // namespace op
} // namespace mxnet
4 changes: 4 additions & 0 deletions src/operator/nn/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ NNVM_REGISTER_OP(Concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

// GPU registration for _rnn_param_concat: computation is identical to the
// regular Concat, so the generic GPU kernels are reused.
NNVM_REGISTER_OP(_rnn_param_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_backward_Concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

Expand Down
6 changes: 3 additions & 3 deletions src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
DMLC_REGISTER_PARAMETER(RNNParam);

MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
**Vanilla RNN**
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.
With ReLU activation function:
Expand All @@ -63,7 +63,7 @@ With Tanh activtion function:
.. math::
h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
Reference paper: Finding structure in time - Elman, 1988.
Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf
**LSTM**
Expand Down
Loading

0 comments on commit 03ea58d

Please sign in to comment.