Skip to content

Commit

Permalink
[WIP][MXNET-107] Fused LSTM implementation for CPU (apache#10104)
Browse files Browse the repository at this point in the history
* register RNN fused-API with nnvm, finish single-layer && undirection LSTM forward function

* fix coding style and lint complains

* add single-layer && undirectional LSTM backward function

* make interface universal for other RNN mode

* share intermediate result between forward and backward in a trick way

* add comments for important parameters

* modify testcase

* Fix coding style and error message

* fix openmp collapse error

* fix const

* remove rnn.cu and skip related testcases temporarily for building on GPU

* support multi-layer and bidirectional for lstm inference

* remove some testcaseS in test_gluon_rnn.py to build on GPU

* remove testcase between fp32 and fp64 temporarily

* retrigger ci

* fix some logs

* use a better way to share memory

* fix cudnn registration

* fix invariant calculations and enable some gpu testcases

* add thread local cache for cudnn rnn op

* add thread local cache for rnn op

* fix bugs

* remove some testcases to check segmentfault

* remove cudnn registeration to check segmentfault

* support multi-layer for LSTM Training

* modify lstm testcase

* add bidirectional support for lstm

* fix gluon and coding style

* fix bugs

* remove nnvm registration

* enable gpu testcases

* add detailed descriptions

* add dropout check

* fix workspace size

* dropout is not supported, add unit test for it

* fix review comments
  • Loading branch information
chenchu-zs authored and piiswrong committed May 14, 2018
1 parent 41c6493 commit b8f2625
Show file tree
Hide file tree
Showing 7 changed files with 991 additions and 245 deletions.
4 changes: 1 addition & 3 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

from ...autograd import is_training
from ... import ndarray
from .. import Block
from . import rnn_cell
Expand Down Expand Up @@ -186,8 +185,7 @@ def forward(self, inputs, states=None):
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
if inputs.context.device_type == 'gpu' or \
(not is_training() and self._mode == 'lstm'):
if inputs.context.device_type == 'gpu' or self._mode == 'lstm':
out = self._forward_kernel(inputs, states)
else:
out = self._forward(inputs, states)
Expand Down
3 changes: 2 additions & 1 deletion src/operator/cudnn_rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace mxnet {
namespace op {
#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
template<typename DType>
class CuDNNRNNOp : public Operator {
class CuDNNRNNOp : public Operator{
public:
explicit CuDNNRNNOp(RNNParam param) {
this->param_ = param;
Expand Down Expand Up @@ -101,6 +101,7 @@ class CuDNNRNNOp : public Operator {
CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
Storage::Get()->Free(dropout_states_);
Storage::Get()->Free(reserve_space_);
init_cudnn_ = false;
}
}

Expand Down
Loading

0 comments on commit b8f2625

Please sign in to comment.