
add vRNN and dropout (#11399)
Hao Li authored and piiswrong committed Jun 26, 2018
1 parent e494cee commit 0538ad9
Showing 4 changed files with 1,099 additions and 54 deletions.
16 changes: 15 additions & 1 deletion example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -66,7 +66,7 @@
 parser.add_argument('--dropout', type=float, default='0.0',
                     help='dropout probability (1.0 - keep probability)')
 parser.add_argument('--rnntype', type=str, default='lstm',
-                    help='rnn type: gru and lstm are supported')
+                    help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported')
 
 #buckets = [32]
 buckets = [10, 20, 30, 40, 50, 60]
@@ -188,6 +188,20 @@ def test(args):
                         cell,
                         mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
                         output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'rnn_tanh':
+                cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                        cell,
+                        mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dr0_'%(args.rnntype,i)),
+                        output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'rnn_relu':
+                cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                        cell,
+                        mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dr0_'%(args.rnntype,i)),
+                        output_prefix='bi_%s_%d'%(args.rnntype,i))
 
             stack.add(cell)
 
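The new branches mirror the existing LSTM/GRU ones: for --rnntype rnn_tanh or rnn_relu the script stacks plain mx.rnn.RNNCell instances, optionally wrapped in a BidirectionalCell. A minimal standalone sketch of the same construction, with illustrative sizes rather than the script's defaults:

import mxnet as mx

num_hidden, num_layers, seq_len = 200, 2, 20
stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
    # One vanilla RNN layer with tanh activation, as in the rnn_tanh branch above.
    stack.add(mx.rnn.RNNCell(num_hidden=num_hidden, activation='tanh',
                             prefix='rnn_tanh_%dl0_' % i))

# Unroll over a fixed bucket length; 'data' stands in for the embedded input sequence.
outputs, states = stack.unroll(seq_len, inputs=mx.sym.Variable('data'),
                               merge_outputs=True)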
74 changes: 43 additions & 31 deletions src/operator/rnn-inl.h
@@ -99,17 +99,17 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
                                   int mode) {
   size_t size = 0;
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
              + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
       break;
     case rnn_enum::kGru:
       size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
       break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -125,18 +125,20 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
                                      int mode) {
   size_t size = 0;
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
-      size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
+      size = direction * seq_length * batch_size * hidden_size * (num_layer * 7 - 1);
       break;
     case rnn_enum::kGru:
-      size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
+      size = seq_length * batch_size * hidden_size * direction * (num_layer * 9 - 1) +
           batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
           seq_length * batch_size * 7 * hidden_size * direction;
       break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1) +
+          batch_size * hidden_size * direction * 3 + hidden_size * seq_length * 2 +
+          seq_length * batch_size * 2 * hidden_size * direction;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
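To make the new kRnnRelu/kRnnTanh branch concrete, the same reserve-space formula can be mirrored in Python and evaluated for a small, purely illustrative configuration (the helper name and the numbers below are not from the codebase):

def vanilla_rnn_reserve_size(num_layer, direction, seq_length, batch_size, hidden_size):
    # Element count (not bytes), matching the kRnnRelu/kRnnTanh case above.
    return (seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1)
            + batch_size * hidden_size * direction * 3
            + hidden_size * seq_length * 2
            + seq_length * batch_size * 2 * hidden_size * direction)

print(vanilla_rnn_reserve_size(1, 1, 2, 3, 4))  # 120 + 36 + 16 + 48 = 220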
@@ -223,21 +225,24 @@ void RNNForwardTraining(DType* ws,
                         DType* y_ptr,
                         DType* hy_ptr,
                         DType* cy_ptr,
+                        const float dropout,
                         int mode) {
   switch (mode) {
-    case rnn_enum::kRnnTanh:
-    case rnn_enum::kRnnRelu:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
-                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
+                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, dropout);
       break;
     case rnn_enum::kGru:
       GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                 batch_size, input_size, state_size, x_ptr, hx_ptr,
-                                w_ptr, y_ptr, hy_ptr);
+                                w_ptr, y_ptr, hy_ptr, dropout);
       break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
+                                       batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                       w_ptr, y_ptr, hy_ptr, dropout, mode);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
@@ -264,10 +269,6 @@ void RNNForwardInference(DType* ws,
                          DType* cy_ptr,
                          int mode) {
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
                                   batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
@@ -278,6 +279,12 @@ void RNNForwardInference(DType* ws,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr,
                                  w_ptr, y_ptr, hy_ptr);
       break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
+                                        batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                        w_ptr, y_ptr, hy_ptr, mode);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
       break;
@@ -310,22 +317,27 @@ void RNNBackward(DType* ws,
                  int req_params,
                  int req_state,
                  int req_statecell,
+                 const float dropout,
                  int mode) {
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      break;
     case rnn_enum::kLstm:
       LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                           input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
                           dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
-                          req_data, req_params, req_state, req_statecell);
+                          req_data, req_params, req_state, req_statecell, dropout);
       break;
     case rnn_enum::kGru:
       GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                          input_size, state_size, x_ptr, hx_ptr, w_ptr,
                          dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
-                         req_data, req_params, req_state);
+                         req_data, req_params, req_state, dropout);
       break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
+                                input_size, state_size, x_ptr, hx_ptr, w_ptr,
+                                dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
+                                req_data, req_params, req_state, dropout, mode);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
@@ -354,9 +366,8 @@ class RNNOp : public Operator{
                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
-        << "Only lstm and gru mode are supported at the moment.";
-    CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
+        << "unsupported dropout value, should be 0 <= dropout < 1";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -436,6 +447,7 @@ class RNNOp : public Operator{
                                 y.dptr_,
                                 hy_ptr,
                                 cy_ptr,
+                                param_.p,
                                 param_.mode);
     } else {
       RNNForwardInference<DType>(workspace.dptr_,
@@ -467,9 +479,8 @@ class RNNOp : public Operator{
                 const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
-        << "Only lstm and gru mode are supported at the moment.";
-    CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
+        << "unsupported dropout value, should be 0 <= dropout < 1";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -566,6 +577,7 @@ class RNNOp : public Operator{
                       req[rnn_enum::kParams],
                       req[rnn_enum::kState],
                       req[rnn_enum::kStateCell],
+                      param_.p,
                       param_.mode);
     }
 
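Taken together, the operator-level changes let the fused RNN symbol run the vanilla RNN modes with a non-zero dropout probability on CPU. A rough usage sketch (the sizes and the 0.5 dropout value are illustrative, and the parameter/state shapes are left to shape inference):

import mxnet as mx

data   = mx.sym.Variable('data')        # (seq_length, batch_size, input_size)
params = mx.sym.Variable('rnn_params')  # flattened weights and biases
state  = mx.sym.Variable('rnn_state')   # initial hidden state

rnn = mx.sym.RNN(data=data, parameters=params, state=state,
                 state_size=200, num_layers=2, mode='rnn_tanh',
                 p=0.5, state_outputs=True, name='vanilla_rnn')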