make gluon rnn layers hybrid blocks (apache#11482)
* make Gluon RNN layer hybrid block

* separate gluon gpu tests

* remove excess assert_raises_cudnn_disabled usage

* add comments and refactor

* add bidirectional test

* temporarily remove hybridize in test_gluon_rnn.test_layer_fill_shape
szha authored and aaronmarkham committed Aug 6, 2018
1 parent ce3b2ce commit 2d11bad
Showing 7 changed files with 449 additions and 238 deletions.
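
A minimal sketch of the usage this change enables (an editorial example, not part of the commit; the hidden sizes, shapes, and variable names are illustrative): Gluon RNN layers now derive from HybridBlock, so they can be hybridized like any other Gluon layer.

import mxnet as mx
from mxnet.gluon import rnn

layer = rnn.LSTM(hidden_size=20, num_layers=2)     # layout defaults to 'TNC'
layer.initialize()
layer.hybridize()                                  # newly possible: _RNNLayer is now a HybridBlock

x = mx.nd.random.uniform(shape=(5, 3, 10))         # (seq_len, batch, input_size)
states = layer.begin_state(batch_size=3)
output, new_states = layer(x, states)              # returns (output, states) when states are passed
print(output.shape)                                # (5, 3, 20)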
132 changes: 62 additions & 70 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,12 +23,11 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray
from .. import Block
from ... import ndarray, symbol
from .. import HybridBlock, tensor_types
from . import rnn_cell


class _RNNLayer(Block):
class _RNNLayer(HybridBlock):
"""Implementation of recurrent layers."""
def __init__(self, hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
@@ -52,33 +51,28 @@ def __init__(self, hidden_size, num_layers, layout,

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

self.i2h_weight = []
self.h2h_weight = []
self.i2h_bias = []
self.h2h_bias = []

ng, ni, nh = self._gates, input_size, hidden_size
for i in range(num_layers):
for j in (['l', 'r'] if self._dir == 2 else ['l']):
self.i2h_weight.append(
self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni),
init=i2h_weight_initializer,
allow_deferred_init=True))
self.h2h_weight.append(
self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh),
init=h2h_weight_initializer,
allow_deferred_init=True))
self.i2h_bias.append(
self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,),
init=i2h_bias_initializer,
allow_deferred_init=True))
self.h2h_bias.append(
self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,),
init=h2h_bias_initializer,
allow_deferred_init=True))
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
ni = nh * self._dir

self._unfused = self._unfuse()
def _register_param(self, name, shape, init):
p = self.params.get(name, shape=shape, init=init,
allow_deferred_init=True)
setattr(self, name, p)
return p
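
The loop above replaces the old per-layer parameter lists with flat attributes named <direction><layer>_{i2h,h2h}_{weight,bias}. A short sketch of the resulting names and shapes for a hypothetical one-layer bidirectional LSTM (4 gates, hidden_size=8, input_size=4); the lstm0_ prefix comes from Gluon's default block naming:

from mxnet.gluon import rnn

layer = rnn.LSTM(hidden_size=8, num_layers=1, bidirectional=True, input_size=4)
for name, param in sorted(layer.collect_params().items()):
    print(name, param.shape)
# lstm0_l0_h2h_bias (32,)
# lstm0_l0_h2h_weight (32, 8)
# lstm0_l0_i2h_bias (32,)
# lstm0_l0_i2h_weight (32, 4)
# ...and the matching r0_* parameters for the reverse direction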

def __repr__(self):
s = '{name}({mapping}, {_layout}'
@@ -89,12 +83,23 @@ def __repr__(self):
if self._dir == 2:
s += ', bidirectional'
s += ')'
shape = self.i2h_weight[0].shape
shape = self.l0_i2h_weight.shape
mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
def convert_key(key): # for compatibility with old parameter format
key = key.split('_')
return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:]))
ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()}
for name, child in self._children.items():
ret.update(child._collect_params_with_prefix(prefix + name))
return ret
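
_collect_params_with_prefix is overridden so parameters serialized by name stay compatible with the old parameter format noted in the inline comment (the nested naming of the unfused cells). A toy reproduction of convert_key for illustration only (not a public API; the input keys are examples):

def convert_key(key):
    # 'l0_i2h_weight' -> direction 'l', layer index '0', remainder 'i2h_weight'
    parts = key.split('_')
    return '_unfused.{}.{}_cell.{}'.format(parts[0][1:], parts[0][0], '_'.join(parts[1:]))

print(convert_key('l0_i2h_weight'))   # _unfused.0.l_cell.i2h_weight
print(convert_key('r1_h2h_bias'))     # _unfused.1.r_cell.h2h_bias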

def state_info(self, batch_size=0):
raise NotImplementedError

@@ -111,7 +116,7 @@ def _unfuse(self):
'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
**kwargs)}[self._mode]

stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params)
stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params)
with stack.name_scope():
ni = self._input_size
for i in range(self._num_layers):
@@ -169,63 +174,50 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states

def forward(self, inputs, states=None):
batch_size = inputs.shape[self._layout.find('N')]
def hybrid_forward(self, F, inputs, states=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
skip_states = states is None
if skip_states:
states = self.begin_state(batch_size, ctx=inputs.context)
if isinstance(states, ndarray.NDArray):
if F is ndarray:
states = self.begin_state(batch_size, ctx=inputs.context)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
if self._input_size == 0:
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
out = self._forward_kernel(inputs, states)
if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
out = self._forward_kernel(F, inputs, states, **kwargs)

# out is (output, state)
return out[0] if skip_states else out
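
Because the batch-size lookup and state-shape validation above only run in imperative mode (F is ndarray), the same forward path also traces to a symbol graph. A minimal export sketch (editorial example; the 'gru_layer' file prefix is a placeholder):

import mxnet as mx
from mxnet.gluon import rnn

layer = rnn.GRU(hidden_size=16, num_layers=1)
layer.initialize()
layer.hybridize()
out = layer(mx.nd.ones((4, 2, 8)))   # first call builds and caches the symbolic graph
layer.export('gru_layer')            # writes gru_layer-symbol.json and gru_layer-0000.params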

def _forward(self, inputs, states):
"""forward using gluon cell"""
ns = len(states)
axis = self._layout.find('T')
states = sum(zip(*((j for j in i) for i in states)), ())
outputs, states = self._unfused.unroll(
inputs.shape[axis], inputs, states,
layout=self._layout, merge_outputs=True)
new_states = []
for i in range(ns):
state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
new_states.append(state)

return outputs, new_states

def _forward_kernel(self, inputs, states):
def _forward_kernel(self, F, inputs, states, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
ctx = inputs.context
params = sum(zip(self.i2h_weight, self.h2h_weight), ())
params += sum(zip(self.i2h_bias, self.h2h_bias), ())
params = (i.data(ctx).reshape((-1,)) for i in params)
params = ndarray.concat(*params, dim=0)

rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h'])
params = F._internal._rnn_param_concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
outputs, states = rnn[0], [rnn[1]]

if self._layout == 'NTC':
outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
outputs = F.swapaxes(outputs, dim1=0, dim2=1)

return outputs, states
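
For reference, a small standalone sketch (not library code) of the flat parameter order the generator expression above produces for the fused RNN operator: all weights before all biases, iterating layers, then directions, then i2h before h2h:

def fused_param_order(num_layers, bidirectional):
    dirs = ['l', 'r'][:2 if bidirectional else 1]
    return ['{}{}_{}_{}'.format(d, l, g, t)
            for t in ('weight', 'bias')
            for l in range(num_layers)
            for d in dirs
            for g in ('i2h', 'h2h')]

print(fused_param_order(num_layers=2, bidirectional=False))
# ['l0_i2h_weight', 'l0_h2h_weight', 'l1_i2h_weight', 'l1_h2h_weight',
#  'l0_i2h_bias', 'l0_h2h_bias', 'l1_i2h_bias', 'l1_h2h_bias']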

127 changes: 102 additions & 25 deletions src/operator/nn/concat.cc
@@ -74,6 +74,65 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
return dshape.Size() != 0;
}

// Concat for RNN param deals with the reverse shape inference from output
// for the special case of concatenating RNN parameters.
// The first (and sometimes the second) input may be unknown on the target axis.
// If the two inputs are unknown, they always have the same shape.
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
using namespace mshadow;
const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
TShape dshape;
index_t size = 0;
int num_zero = 0;
int axis = -1;
for (int i = 0; i < param_.num_args; ++i) {
TShape tmp = (*in_shape)[i];
if (tmp.ndim()) {
axis = CheckAxis(param_.dim, tmp.ndim());
num_zero += tmp[axis] == 0;
size += tmp[axis];
tmp[axis] = 0;
shape_assign(&dshape, tmp);
}
}

TShape tmp = (*out_shape)[0];
if (tmp.ndim()) {
axis = CheckAxis(param_.dim, tmp.ndim());
tmp[axis] = 0;
shape_assign(&dshape, tmp);
}

if (dshape.ndim() == 0) return false;

for (int i = 0; i < param_.num_args; ++i) {
CHECK(shape_assign(&(*in_shape)[i], dshape))
<< "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
}

if (!num_zero) dshape[axis] = size;
CHECK(shape_assign(&(*out_shape)[0], dshape))
<< "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
if ((*out_shape)[0][axis] != 0 && num_zero) {
int residual = (*out_shape)[0][axis] - size;
CHECK_GE(residual, 0)
<< "Input size already exceeds output size. Residual: " << residual;
CHECK(num_zero <= 2 && num_zero >= 0)
<< "Expecting 1 or 2 inputs that need shape inference. Got: " << num_zero;
bool need_infer = !(*out_shape)[0].Size();
for (int i = 0; i < num_zero; i++) {
(*in_shape)[i*2][axis] = residual / num_zero;
need_infer = need_infer || !(*in_shape)[i].Size();
}
return !need_infer;
}

return dshape.Size() != 0;
}
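
Seen from the Python side, the practical effect of this reverse shape inference (a sketch that assumes the rnn_layer.py change above): a layer built without input_size can still be hybridized, and the unknown i2h weight shape is recovered during deferred initialization at the first forward call.

import mxnet as mx
from mxnet.gluon import rnn

layer = rnn.LSTM(hidden_size=8, num_layers=1)   # input_size left unspecified (0)
layer.initialize()
layer.hybridize()
out = layer(mx.nd.ones((5, 2, 4)))              # (seq_len, batch, input_size=4)
print(layer.l0_i2h_weight.shape)                # (32, 4): 4 gates * hidden_size 8, input_size 4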

static bool ConcatType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
@@ -228,6 +287,34 @@ struct ConcatGrad {

DMLC_REGISTER_PARAMETER(ConcatParam);

#define CONCAT_FORWARD_ATTRS \
.set_num_inputs([](const NodeAttrs& attrs) { \
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
return params.num_args; \
}) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ConcatParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
std::vector<std::string> ret; \
for (int i = 0; i < params.num_args; ++i) { \
ret.push_back(std::string("arg") + std::to_string(i)); \
} \
return ret; \
}) \
.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"output"}; \
}) \
.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
.set_attr<std::string>("key_var_num_args", "num_args")


NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
.add_alias("concat")
@@ -268,37 +355,13 @@ Example::
[ 5., 5., 8., 8.]]
)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
std::vector<std::string> ret;
for (int i = 0; i < params.num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

@@ -320,5 +383,19 @@ NNVM_REGISTER_OP(_backward_Concat)
#endif
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

// _rnn_param_concat is a custom concat op with specialized infer_shape,
// which handles the case where the first one or two inputs may have
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", RNNParamConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

} // namespace op
} // namespace mxnet
4 changes: 4 additions & 0 deletions src/operator/nn/concat.cu
@@ -50,6 +50,10 @@ NNVM_REGISTER_OP(Concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_rnn_param_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_backward_Concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

6 changes: 3 additions & 3 deletions src/operator/rnn.cc
@@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
DMLC_REGISTER_PARAMETER(RNNParam);

MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
**Vanilla RNN**
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.
With ReLU activation function:
@@ -63,7 +63,7 @@ With Tanh activation function:
.. math::
h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
Reference paper: Finding structure in time - Elman, 1988.
Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf
**LSTM**
