diff --git a/example/rnn/bucketing/cudnn_lstm_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py similarity index 87% rename from example/rnn/bucketing/cudnn_lstm_bucketing.py rename to example/rnn/bucketing/cudnn_rnn_bucketing.py index 84cfc9d43805..29a66a8f4843 100644 --- a/example/rnn/bucketing/cudnn_lstm_bucketing.py +++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py @@ -65,6 +65,8 @@ help='stack fused RNN cells to reduce communication overhead') parser.add_argument('--dropout', type=float, default='0.0', help='dropout probability (1.0 - keep probability)') +parser.add_argument('--rnntype', type=str, default='lstm', + help='rnn type: gru and lstm are supported') #buckets = [32] buckets = [10, 20, 30, 40, 50, 60] @@ -97,13 +99,13 @@ def train(args): cell = mx.rnn.SequentialRNNCell() for i in range(args.num_layers): cell.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1, - mode='lstm', prefix='lstm_l%d'%i, + mode=args.rnntype, prefix='%s_l%d'%(args.rnntype,i), bidirectional=args.bidirectional)) - if args.dropout > 0 and i < args.num_layers - 1: - cell.add(mx.rnn.DropoutCell(args.dropout, prefix='lstm_d%d'%i)) + if args.dropout > 0 and i < args.num_layers - 1 and args.rnntype == 'lstm': + cell.add(mx.rnn.DropoutCell(args.dropout, prefix='%s_d%d'%(args.rnntype,i))) else: cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout, - mode='lstm', bidirectional=args.bidirectional) + mode=args.rnntype, bidirectional=args.bidirectional) def sym_gen(seq_len): data = mx.sym.Variable('data') @@ -168,16 +170,25 @@ def test(args): if not args.stack_rnn: stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, - mode='lstm', bidirectional=args.bidirectional).unfuse() + mode=args.rnntype, bidirectional=args.bidirectional).unfuse() else: stack = mx.rnn.SequentialRNNCell() for i in range(args.num_layers): - cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i) - if args.bidirectional: - cell = mx.rnn.BidirectionalCell( - cell, - mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i), - output_prefix='bi_lstm_%d'%i) + if args.rnntype == 'lstm': + cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i)) + if args.bidirectional: + cell = mx.rnn.BidirectionalCell( + cell, + mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)), + output_prefix='bi_%s_%d'%(args.rnntype,i)) + elif args.rnntype == 'gru': + cell = mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i)) + if args.bidirectional: + cell = mx.rnn.BidirectionalCell( + cell, + mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)), + output_prefix='bi_%s_%d'%(args.rnntype,i)) + stack.add(cell) def sym_gen(seq_len): diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 056c1d517c0e..d9dc98ece486 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -190,7 +190,7 @@ def forward(self, inputs, states=None): self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2]) self.i2h_weight[i]._finish_deferred_init() if inputs.context.device_type == 'gpu' or \ - self._mode == 'lstm' and not self._dropout: + self._mode in ['lstm', 'gru'] and not self._dropout: out = self._forward_kernel(inputs, states) else: out = self._forward(inputs, states) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index eded6aeed8a9..99531739afa6 100644 --- a/src/operator/rnn-inl.h +++ 
b/src/operator/rnn-inl.h @@ -101,12 +101,14 @@ inline size_t GetRNNWorkspaceSize(int seq_length, switch (mode) { case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - case rnn_enum::kGru: - LOG(FATAL) << "Only LSTM is supported at the moment"; + LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; break; case rnn_enum::kLstm: size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2 - + seq_length * batch_size * hidden_size * direction; + + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8; + break; + case rnn_enum::kGru: + size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8; break; default: LOG(FATAL) << "unknown RNN mode " << mode; @@ -125,12 +127,16 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, switch (mode) { case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - case rnn_enum::kGru: - LOG(FATAL) << "Only LSTM is supported at the moment"; + LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; break; case rnn_enum::kLstm: size = num_layer * direction * seq_length * batch_size * hidden_size * 6; break; + case rnn_enum::kGru: + size = seq_length * batch_size * hidden_size * direction * num_layer * 8 + + batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 + + seq_length * batch_size * 7 * hidden_size * direction; + break; default: LOG(FATAL) << "unknown RNN mode " << mode; break; @@ -221,14 +227,18 @@ void RNNForwardTraining(DType* ws, switch (mode) { case rnn_enum::kRnnTanh: case rnn_enum::kRnnRelu: - case rnn_enum::kGru: - LOG(FATAL) << "Only LSTM is supported at the moment"; + LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; break; case rnn_enum::kLstm: LstmForwardTraining(ws, rs, state_outputs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); break; + case rnn_enum::kGru: + GruForwardTraining(ws, rs, state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, + w_ptr, y_ptr, hy_ptr); + break; default: LOG(FATAL) << "unknown RNN mode " << mode; break; @@ -256,14 +266,18 @@ void RNNForwardInference(DType* ws, switch (mode) { case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - case rnn_enum::kGru: - LOG(FATAL) << "Only LSTM is supported at the moment"; + LOG(FATAL) << "Only LSTM and GRU are supported at the moment"; break; case rnn_enum::kLstm: LstmForwardInference(ws, state_outputs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr); break; + case rnn_enum::kGru: + GruForwardInference(ws, state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, + w_ptr, y_ptr, hy_ptr); + break; default: LOG(FATAL) << "unknown RNN mode" << mode; break; @@ -292,16 +306,26 @@ void RNNBackward(DType* ws, DType* dcx_ptr, DType* dw_ptr, DType* db_ptr, + int req_data, + int req_params, + int req_state, + int req_statecell, int mode) { switch (mode) { case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: - case rnn_enum::kGru: break; case rnn_enum::kLstm: LstmBackward(ws, rs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr, - dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr); + dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr, + req_data, req_params, req_state, req_statecell); + break; + case rnn_enum::kGru: + 
GruBackward(ws, rs, num_layers, direction, seq_length, batch_size, + input_size, state_size, x_ptr, hx_ptr, w_ptr, + dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr, + req_data, req_params, req_state); break; default: LOG(FATAL) << "unknown RNN mode" << mode; @@ -330,7 +354,8 @@ class RNNOp : public Operator{ const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment."; + CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru) + << "Only lstm and gru mode are supported at the moment."; CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment."; size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; @@ -442,8 +467,10 @@ class RNNOp : public Operator{ const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment."; + CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru) + << "Only lstm and gru mode are supported at the moment."; CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment."; + size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3; size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2; if (!param_.state_outputs) { @@ -535,6 +562,10 @@ class RNNOp : public Operator{ dcx_ptr, dw.dptr_, db_ptr, + req[rnn_enum::kData], + req[rnn_enum::kParams], + req[rnn_enum::kState], + req[rnn_enum::kStateCell], param_.mode); } diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index 2ee374bbf569..e92a18218f91 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -40,6 +40,10 @@ #include "./mshadow_op.h" #include "./linalg.h" + +namespace mxnet { +namespace op { + template inline DType sigmoid(DType x) { return 1.0f / (1.0f + exp(-x)); @@ -297,6 +301,7 @@ void LstmForwardInference(DType* ws, template void LstmBackwardSingleLayer(DType* ws, DType* rs, + DType* tmp_buf, bool bid, const int T, const int N, @@ -314,7 +319,11 @@ void LstmBackwardSingleLayer(DType* ws, DType* dcy_ptr, DType* w_ptr, DType* dw_ptr, - DType* db_ptr) { + DType* db_ptr, + int req_data, + int req_params, + int req_state, + int req_statecell) { using namespace mshadow; const Tensor wx(w_ptr, Shape2(H * 4, I)); const Tensor wh(w_ptr + I * H * 4, Shape2(H * 4, H)); @@ -336,6 +345,7 @@ void LstmBackwardSingleLayer(DType* ws, const DType alpha = 1.0; const DType beta0 = 0.0; const DType beta1 = 1.0; + const DType beta2 = 2.0; const int cell_size = N * H; if (dhy_ptr != NULL) { memcpy(dh.dptr_, dhy_ptr, cell_size * sizeof(DType)); @@ -367,24 +377,67 @@ void LstmBackwardSingleLayer(DType* ws, difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft); difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt); difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot); - dcnext[j][k] = dc[j][k] * ft; + if (req_statecell != kNullOp || i > 0) { + dcnext[j][k] = dc[j][k] * ft; + } if (i) { htmp[j][k] = y[tnext][j][k + offset]; } } Tensor dyh(difgo[t].dptr_, Shape2(N, H * 4)); - linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false); - linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false); + if (req_state != kNullOp || i > 0) { + linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false); + } + if (req_params != kNullOp) { + if (req_params != kAddTo) { + linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false); + } else { + linalg_gemm(dyh, hnext, dwh, alpha, beta2, true, false); + + // generate dwx every time step for AddTo + Tensor 
x_t(x.dptr_ + i * N * I, Shape2(N, I)); + Tensor dyx_t(difgo.dptr_ + i * N * H * 4, Shape2(N, H * 4)); + linalg_gemm(dyx_t, x_t, dwx, alpha, beta2, true, false); + } + } } Tensor dyx(difgo.dptr_, Shape2(T * N, H * 4)); - linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false); - linalg_gemm(dyx, x, dwx, alpha, beta0, true, false); + if (req_data != kNullOp) { + linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false); + } + if (req_params != kNullOp && req_params != kAddTo) { + linalg_gemm(dyx, x, dwx, alpha, beta0, true, false); + } const int row = T * N; const int col = H * 4; - for (int i = 0; i < row; ++i) { - for (int j = 0; j < col; ++j) { - dbx[j] += dyx[i][j]; - dbh[j] = dbx[j]; + if (req_params != kNullOp) { + if (req_params != kAddTo) { + for (int i = 0; i < row; ++i) { + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < col; ++j) { + dbx[j] += dyx[i][j]; + dbh[j] = dbx[j]; + } + } + } else { + const Tensor tmp_dbx(tmp_buf, Shape2(col, T)); + const Tensor tmp_dbh(tmp_buf + col * T, Shape2(col, T)); + memset(tmp_dbx.dptr_, 0, col * T * sizeof(DType)); + memset(tmp_dbh.dptr_, 0, col * T * sizeof(DType)); + for (int t = T - 1; t >= 0; --t) { + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < col; ++j) { + for (int i = 0; i < N; ++i) { + tmp_dbx[j][t] += dyx[t * N + i][j]; + tmp_dbh[j][t] = tmp_dbx[j][t]; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < col; ++j) { + dbx[j] += tmp_dbx[j][t] + dbx[j]; + dbh[j] += tmp_dbh[j][t] + dbh[j]; + } + } } } } @@ -410,7 +463,13 @@ void LstmBackward(DType* ws, DType* dhx_ptr, DType* dcx_ptr, DType* dw_ptr, - DType* db_ptr) { + DType* db_ptr, + int req_data, + int req_params, + int req_state, + int req_statecell) { + DType* tmp_buf = ws; + DType* ws2 = tmp_buf + 8 * T * H; const int total_layers = D * L; Tensor hx(hx_ptr, Shape3(total_layers, N, H)); Tensor cx(cx_ptr, Shape3(total_layers, N, H)); @@ -422,7 +481,7 @@ void LstmBackward(DType* ws, const int w_size1 = (I + H) * H * 4; // first layer const int w_size2 = (D * H + H) * H * 4; // other layers const int cell_size = N * H; - DType* dy_tmp_ptr = ws + T * cell_size * 4 + cell_size * 3; + DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3; for (int i = L - 1; i >= 0; --i) { const int input_size = i ? H * D : I; const int w_size = i ? w_size2 : w_size1; @@ -437,9 +496,10 @@ void LstmBackward(DType* ws, Tensor dy(dy_ptr, Shape3(T, N, H * D)); Tensor x(i ? y.dptr_ - r_size : x_ptr, Shape2(T * N, input_size)); Tensor dx(i ? dy_tmp_ptr : dx_ptr, Shape2(T * N, input_size)); - LstmBackwardSingleLayer(ws, rs_cur_ptr, false, T, N, input_size, H, + LstmBackwardSingleLayer(ws2, rs_cur_ptr, tmp_buf, false, T, N, input_size, H, x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx], - dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr); + dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr, + req_data, req_params, req_state, req_statecell); if (D == 2) { w_cur_ptr += w_size; dw_cur_ptr += w_size; @@ -447,11 +507,874 @@ void LstmBackward(DType* ws, ++idx; dhy_cur_ptr = dhy_ptr ? dhy_cur_ptr + cell_size : NULL; dcy_cur_ptr = dcy_ptr ? 
dcy_cur_ptr + cell_size : NULL; - LstmBackwardSingleLayer(ws, rs_cur_ptr, true, T, N, input_size, H, + LstmBackwardSingleLayer(ws2, rs_cur_ptr, tmp_buf, true, T, N, input_size, H, x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx], - dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr); + dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr, + req_data, req_params, req_state, req_statecell); } dy_ptr = dx.dptr_; } } + +template +void GruForwardInferenceSingleLayer(DType* ws, + DType* tmp_buf, + bool state_outputs, + const int D, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + DType* wx_ptr, + DType* wh_ptr, + DType* bx_ptr, + DType* bh_ptr, + DType* y_ptr, + DType* hy_ptr) { + DType* ht = y_ptr; + DType* ht_1 = y_ptr; + DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H; + DType* back_ht = back_ht_1; + DType* gemmC1 = ws; // [D, T, N, 3 * H] + DType* gemmC2 = gemmC1 + D * T * N * 3 * H; // N * 3 * H + DType* rt = gemmC2 + N * 3 * H; + DType* zt = rt + N * H; + DType* nt = zt + N * H; + DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H; + DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H; + DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL; + DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2: NULL; + DType* back_gemmC1 = gemmC1 + T * N * 3 * H; + DType* gemmC1_t = gemmC1; + + const Tensor wx(wx_ptr, Shape2(H * 3, I)); + const Tensor wh(wh_ptr, Shape2(H * 3, H)); + const Tensor bx(bx_ptr, Shape2(3, H)); + const Tensor bh(bh_ptr, Shape2(3, H)); + const Tensor back_wx(back_wx_ptr, Shape2(H * 3, I)); + const Tensor back_wh(back_wh_ptr, Shape2(H * 3, H)); + const Tensor back_bx(back_bx_ptr, Shape2(3, H)); + const Tensor back_bh(back_bh_ptr, Shape2(3, H)); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (D == 1) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * H + j] = hx[i][j]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * D * H + j] = hx[i][j]; + back_ht_1[i * D * H + j] = hx[N + i][j]; + } + } + Tensor dgemmC1(ws, Shape2(T * N, 3 * H)); + Tensor dgemmC2(gemmC2, Shape2(N, 3 * H)); + Tensor dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H)); + + // x * wx.T : [T * N, I] * [I, 3 * H] + DType alpha = 1.0; + DType beta = 0.0; + linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); + if (D == 2) { + linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); + } + + for (int t = 0; t < T; t++) { + // perform the first direction, X * wx and H * wh for each step + // ht-1 * wh, ht-1:[N, H] wh:[3 * H, H] + Tensor dht_1(ht_1, Shape2(N, D * H)); + if (D == 1) { + linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); + } else { + Tensor dht_1_tmp = Tensor(reinterpret_cast(tmp_buf), + Shape3(D, H, N)); + dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); + linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); + } + gemmC1_t = gemmC1 + t * N * 3 * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int rtb = i * 3 * H; + int ztb = i * 3 * H + H; + int ntb = i * 3 * H + 2 * H; + rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j] + + bx[0][j] + bh[0][j]); + zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j] + + bx[1][j] + bh[1][j]); + nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] + + rt[i * H + j] 
* (gemmC2[ntb + j] + bh[2][j])); + ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] + + zt[i * H + j] * ht_1[i * D * H + j]; + } + } + ht_1 = ht; + ht = ht + D * H * N; + // perform the second direction + if (D == 2) { + gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H; + Tensor dback_ht_1(back_ht_1, Shape2(N, D * H)); + Tensor dback_ht_1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N)); + linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, true); + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int rtb = i * 3 * H; + int ztb = i * 3 * H + H; + int ntb = i * 3 * H + 2 * H; + rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + + gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]); + zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + + gemmC2[ztb + j] + back_bx[1][j]+ back_bh[1][j]); + nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j] + + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j])); + back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j] + + zt[i * H + j] * back_ht_1[i * D * H + j]; + } + } + back_ht_1 = back_ht; + back_ht = back_ht - D * H * N; + } + } + // copy last state to hy, from(N, H * D) to (D, N, H) + if (state_outputs) { + if (D == 1) { + DType* y_start = y_ptr + (T - 1) * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * H + j]; + } + } else { + DType* y_start = y_ptr + (T - 1) * N * H * D; + DType* y_back_start = y_ptr + H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * D * H + j]; + hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; + } + } + } +} + +template +void GruForwardInference(DType* ws, + bool state_outputs, + const int L, + const int D, + const int T, + const int N, + int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* w_ptr, + DType* y_ptr, + DType* hy_ptr) { + DType* wx = w_ptr; + DType* wh = wx + I * H * 3; + DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) + + (L - 1) * ((D + 1) * H) * H * 3 * D; + DType* bh = bx + H * 3; + + DType* y_tmp = ws; + DType* y_l = x_ptr; + DType* tmp_buf = y_tmp + D * T * N * H; + DType* ws2 = y_tmp + D * T * N * H + D * H * N; + + DType* wx_l = wx; + DType* wh_l = wh; + DType* bx_l = bx; + DType* bh_l = bh; + Tensor hx(hx_ptr, Shape3(D * L, N, H)); + DType* hy_l = hy_ptr; + for (int l = 0; l < L; l++) { + Tensor x_l(y_l, Shape2(T * N, I)); + if ((L + l) % 2) { + y_l = y_ptr; + } else { + y_l = y_tmp; + } + Tensor hx_l = hx[D * l]; + GruForwardInferenceSingleLayer(ws2, tmp_buf, state_outputs, D, T, N, I, H, + x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l, hy_l); + hy_l = hy_l + D * N * H; + bx_l = bx_l + 3 * H * D * 2; + bh_l = bh_l + 3 * H * D * 2; + wx_l = wx_l + I * H * 3 * D + H * H * 3 * D; + if (l == 0) { + I = D * H; + } + wh_l = wx_l + I * 3 * H; + } +} + + +template +void GruForwardTrainingSingleLayer(DType* ws, + DType* tmp_buf, + bool state_outputs, + const int D, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + DType* wx_ptr, + DType* wh_ptr, + DType* bx_ptr, + DType* bh_ptr, + DType* gateR, + DType* gateZ, + DType* gateN, + DType* Mnh, + DType* y_ptr, + DType* hy_ptr) { + DType* ht = y_ptr; + DType* ht_1 = y_ptr; + DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H; + DType* back_ht = 
back_ht_1; + + DType* gemmC1 = ws; // [D, T, N, 3 * H] + DType* gemmC2 = gemmC1 + D * T * N * 3 * H; // N * 3 * H + DType* rt = gateR; + DType* zt = gateZ; + DType* nt = gateN; + DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H; + DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H; + DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL; + DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2 : NULL; + DType* back_gateR = gateR + T * N * H; + DType* back_gateZ = gateZ + T * N * H; + DType* back_gateN = gateN + T * N * H; + DType* back_Mnh = Mnh + T * N * H; + DType* back_gemmC1 = gemmC1 + T * N * 3 * H; + DType* gemmC1_t = gemmC1; + + const Tensor wx(wx_ptr, Shape2(H * 3, I)); + const Tensor wh(wh_ptr, Shape2(H * 3, H)); + const Tensor bx(bx_ptr, Shape2(3, H)); + const Tensor bh(bh_ptr, Shape2(3, H)); + const Tensor back_wx(back_wx_ptr, Shape2(H * 3, I)); + const Tensor back_wh(back_wh_ptr, Shape2(H * 3, H)); + const Tensor back_bx(back_bx_ptr, Shape2(3, H)); + const Tensor back_bh(back_bh_ptr, Shape2(3, H)); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (D == 1) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * H + j] = hx[i][j]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + y_ptr[i * D * H + j] = hx[i][j]; + back_ht_1[i * D * H + j] = hx[N + i][j]; + } + } + + Tensor dgemmC1(ws, Shape2(T * N, 3 * H)); + Tensor dgemmC2(gemmC2, Shape2(N, 3 * H)); + Tensor dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H)); + + // x * wx.T : [T * N, I] * [I, 3 * H] + DType alpha = 1.0; + DType beta = 0.0; + linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true); + if (D == 2) { + linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true); + } + + for (int t = 0; t < T; t++) { + // perform the first direction, X * wx and H * wh for each step + // ht-1 * wh, ht-1:[N, H] wh:[3 * H, H] + Tensor dht_1(ht_1, Shape2(N, D * H)); + if (D == 1) { + linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true); + } else { + Tensor dht_1_tmp = Tensor(reinterpret_cast(tmp_buf), + Shape3(D, H, N)); + dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N)); + linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true); + } + rt = gateR + t * N * H; + zt = gateZ + t * N * H; + nt = gateN + t * N * H; + gemmC1_t = gemmC1 + t * N * 3 * H; + DType* Mnht = Mnh + t * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int rtb = i * 3 * H; + int ztb = i * 3 * H + H; + int ntb = i * 3 * H + 2 * H; + Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j]; + rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j] + + bx[0][j] + bh[0][j]); + zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j] + + bx[1][j] + bh[1][j]); + nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] + + rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j])); + ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] + + zt[i * H + j] * ht_1[i * D * H + j]; + } + } + ht_1 = ht; + ht = ht + D * H * N; + // perform the second direction + if (D == 2) { + rt = back_gateR + (T - 1 - t) * N * H; + zt = back_gateZ + (T - 1 - t) * N * H; + nt = back_gateN + (T - 1 - t) * N * H; + gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H; + Tensor dback_ht_1(back_ht_1, Shape2(N, D * H)); + Tensor dback_ht_1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + dback_ht_1_tmp = 
reshape(dback_ht_1.T(), Shape3(D, H, N)); + linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, true); + + DType* back_Mnht = back_Mnh + (T - 1 - t) * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int rtb = i * 3 * H; + int ztb = i * 3 * H + H; + int ntb = i * 3 * H + 2 * H; + back_Mnht[i * H + j] = gemmC2[ntb + j] + back_bh[2][j]; + rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + + gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]); + zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + + gemmC2[ztb + j] + back_bx[1][j] + back_bh[1][j]); + nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j] + + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j])); + back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j] + + zt[i * H + j] * back_ht_1[i * D * H + j]; + } + } + back_ht_1 = back_ht; + back_ht = back_ht - D * H * N; + } + } + + // copy last state to hy, from(N, H * D) to (D, N, H) + if (state_outputs) { + if (D == 1) { + DType* y_start = y_ptr + (T - 1) * N * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * H + j]; + } + } else { + DType* y_start = y_ptr + (T - 1) * N * H * D; + DType* y_back_start = y_ptr + H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) + for (int j = 0; j < H; j++) { + hy_ptr[i * H + j] = y_start[i * D * H + j]; + hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j]; + } + } + } +} + +template +void GruForwardTraining(DType* ws, + DType* rs, + bool state_outputs, + const int L, + const int D, + const int T, + const int N, + int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* w_ptr, + DType* y_ptr, + DType* hy_ptr) { + DType* wx = w_ptr; + DType* wh = wx + I * H * 3; + DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) + + (L - 1) * ((D + 1) * H) * H * 3 * D; + DType* bh = bx + H * 3; + Tensor hx(hx_ptr, Shape3(D * L, N, H)); + DType* hy_l = hy_ptr; + DType* gateR_l = rs; + DType* gateZ_l = gateR_l + L * T * D * N * H; + DType* gateN_l = gateZ_l + L * T * D * N * H; + DType* y_l = gateN_l + L * T * D * N * H; + DType* Mnh_l = y_l + L * T * N * H * D; + DType* tmp_buf = Mnh_l + L * D * T * N * H; + DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N; + DType* wx_l = wx; + DType* wh_l = wh; + DType* bx_l = bx; + DType* bh_l = bh; + DType* y_tmp = x_ptr; + + for (int l = 0; l < L; l++) { + if (l != 0) { + y_tmp = y_l; + y_l = y_l + T * N * H * D; + } + Tensor x_l(y_tmp, Shape2(T * N, I)); + Tensor hx_l = hx[D * l]; + GruForwardTrainingSingleLayer(ws2, tmp_buf, state_outputs, D, T, N, I, H, + x_l, hx_l, wx_l, wh_l, bx_l, bh_l, + gateR_l, gateZ_l, gateN_l, Mnh_l, y_l, hy_l); + gateR_l = gateR_l + T * D * N * H; + gateZ_l = gateZ_l + T * D * N * H; + gateN_l = gateN_l + T * D * N * H; + Mnh_l = Mnh_l + T * D * N * H; + hy_l = hy_l + D * N * H; + bx_l = bx_l + 3 * H * D * 2; + bh_l = bh_l + 3 * H * D * 2; + + wx_l = wx_l + I * H * 3 * D + H * H * 3 * D; + if (l == 0) { + I = D * H; + } + wh_l = wx_l + I * 3 * H; + } + memcpy(y_ptr, y_l, T * N * H * D * sizeof(DType)); +} + +template +void GruBackwardSingleLayer(DType* ws, + DType* tmp_buf, + const int D, + const int T, + const int N, + const int I, + const int H, + const Tensor &x, + const Tensor &hx, + DType* wx_ptr, + DType* wh_ptr, + DType* y_ptr, + DType* dy_ptr, + DType* dhy_ptr, + DType* gateR, + DType* gateZ, + DType* gateN, + DType* Mnh, + DType* dx, + DType* dhx, + DType* 
dwx, + DType* dwh, + DType* dbx, + DType* dbh, + int req_data, + int req_params, + int req_state) { + DType* dyt; + DType* ht1; // [N, D, H] + DType* rt; + DType* zt; + DType* nt; + DType* dat; + DType* dart; + DType* dar = ws; // [T, N, 3 * H] + DType* da = dar + T * N * 3 * H; // [T, N, 3 * H] + DType* dht1 = da + T * N * 3 * H; // [D, N, H] + DType* hx_ = dht1 + D * N * H; // [N, D, H] + DType* Mnht = Mnh; + DType* back_ht1; + DType* back_dht1 = dht1 + N * H; // [N, H] + DType* back_Mnht = Mnh + T * N * H; + DType* back_gateR = gateR + T * N * H; + DType* back_gateZ = gateZ + T * N * H; + DType* back_gateN = gateN + T * N * H; + DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H; + DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H; + DType* back_dwx = dwx + I * 3 * H + H * 3 * H; + DType* back_dwh = dwh + I * 3 * H + H * 3 * H; + DType* back_dbx = dbx + 3 * H * 2; + DType* back_dbh = dbh + 3 * H * 2; + + DType alpha = 1.0; + DType beta = 0.0; + const Tensor wx(wx_ptr, Shape2(H * 3, I)); + const Tensor wh(wh_ptr, Shape2(H * 3, H)); + const Tensor back_wx(back_wx_ptr, Shape2(H * 3, I)); + const Tensor back_wh(back_wh_ptr, Shape2(H * 3, H)); + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N * H; ++i) { + if (dhy_ptr) { + dht1[i] = dhy_ptr[i]; + } else { + dht1[i] = 0; + } + } + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + hx_[i * D * H + j] = hx[i][j]; + } + } + + if (D == 2) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N * H; ++i) { + if (dhy_ptr) { + back_dht1[i] = dhy_ptr[N * H + i]; + } else { + back_dht1[i] = 0; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + hx_[i * D * H + H + j] = hx[N + i][j]; + } + } + } + for (int t = T - 1; t >= 0; --t) { + if (t) { + ht1 = y_ptr + (t - 1) * N * D * H; + } else { + ht1 = hx_; + } + // add dy[T, N, D, H] to dhy[D, N, H] + dyt = dy_ptr + t * N * D * H; + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + dht1[i * H + j] += dyt[i * D * H + j]; + } + } + + rt = gateR + t * N * H; + zt = gateZ + t * N * H; + nt = gateN + t * N * H; + Mnht = Mnh + t * N * H; + dat = da + t * N * 3 * H; + dart = dar + t * N * 3 * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int nid = i * 3 * H + 2 * H + j; + int zid = i * 3 * H + H + j; + int rid = i * 3 * H + j; + int id = i * H + j; + dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]); + dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) * + zt[id] * (1 - zt[id]); + dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] * + (1 - rt[id]); + dart[nid] = dat[nid] * rt[id]; + dht1[id] = dht1[id] * zt[id]; + } + } + if (req_params != kNullOp) { + alpha = 1.0; + beta = 1.0; + // dht1 = dart * wh [N, H] = [N, 3 * H] * [3 * H, H] + Tensor d_dht1(dht1, Shape2(N, H)); + Tensor d_dart(dart, Shape2(N, 3 * H)); + linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false); + + if (req_params == kAddTo) { + beta = 2.0; + // dwx = da.T * x [3 * H, I] = [3 * H, N] * [N, I] for AddTo + Tensor d_xt(x.dptr_ + t * N * I, Shape2(N, I)); + Tensor d_dat(dat, Shape2(N, 3 * H)); + Tensor d_dwx(dwx, Shape2(3 * H, I)); + linalg_gemm(d_dat, d_xt, d_dwx, alpha, beta, true, false); + } + // dwh 
= dart.T * ht1 [3 * H, H] = [3 * H, N] * [N, H] + Tensor d_ht1(ht1, Shape2(N, D * H)); + Tensor d_dwh(dwh, Shape2(3 * H, H)); + Tensor d_ht1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N)); + linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true); + } + } + + if (req_params != kNullOp) { + // dbx = e * da [1, 3 * H] = [1, N] * [N, 3 * H] + if (req_params != kAddTo) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < 3 * H; ++i) { + for (int j = 0; j < N * T; ++j) { + dbx[i] += da[j * 3 * H + i]; + dbh[i] += dar[j * 3 * H + i]; + } + } + } else { + const Tensor tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T)); + const Tensor tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T)); + memset(tmp_dbx.dptr_, 0, H * T * 3 * sizeof(DType)); + memset(tmp_dbh.dptr_, 0, H * T * 3 * sizeof(DType)); + + for (int t = T - 1; t >= 0; --t) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < 3 * H; ++i) { + for (int j = 0; j < N; ++j) { + tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i]; + tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i]; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < 3 * H; ++i) { + dbx[i] += tmp_dbx[i][t] + dbx[i]; + dbh[i] += tmp_dbh[i][t] + dbh[i]; + } + } + } + } + alpha = 1.0; + beta = 0.0; + + // dx = da * wx [T * N, I] = [T * N, 3 * H] * [3 * H, I] + Tensor d_da(da, Shape2(T * N, 3 * H)); + if (req_data != kNullOp) { + Tensor d_dx(dx, Shape2(T * N, I)); + linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false); + } + + // dwx = da.T * x [3 * H, I] = [3 * H, T * N] * [T * N, I] + if (req_params != kNullOp && req_params != kAddTo) { + Tensor d_dwx(dwx, Shape2(3 * H, I)); + linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false); + } + + if (D == 2) { + for (int t = 0; t < T; ++t) { + if (t == T-1) { + back_ht1 = hx_; + } else { + back_ht1 = y_ptr + (t + 1) * N * D * H; + } + + // add dy[T, N, D, H] to dhy[D, N, H] + dyt = dy_ptr + t * N * D * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + back_dht1[i * H + j] += dyt[i * D * H + H + j]; + } + } + + rt = back_gateR + t * N * H; + zt = back_gateZ + t * N * H; + nt = back_gateN + t * N * H; + back_Mnht = Mnh + (T + t) * N * H; + dat = da + t * N * 3 * H; + dart = dar + t * N * 3 * H; + + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; ++i) { + for (int j = 0; j < H; ++j) { + int nid = i * 3 * H + 2 * H + j; + int zid = i * 3 * H + H + j; + int rid = i * 3 * H + j; + int id = i * H + j; + dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]); + dart[zid] = dat[zid] = back_dht1[id] * (back_ht1[i * D * H + H + j] - + nt[id]) * zt[id] * (1 - zt[id]); + dart[rid] = dat[rid] = dat[nid] * back_Mnht[id] * rt[id] * + (1 - rt[id]); + dart[nid] = dat[nid] * rt[id]; + back_dht1[id] = back_dht1[id] * zt[id]; + } + } + + if (req_params != kNullOp) { + alpha = 1.0; + beta = 1.0; + // dht1 = da * wh [N, H] = [N, 3 * H] * [3 * H, H] + Tensor d_dart(dart, Shape2(N, 3 * H)); + Tensor d_back_dht1(back_dht1, Shape2(N, H)); + linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false); + + // dwh = da.T * ht1 [3 * H, H] = [3 * H, N] * [N, H] + Tensor d_back_dwh(back_dwh, Shape2(3 * H, H)); + Tensor d_back_ht1(back_ht1 + H, Shape2(N, D * H)); + Tensor d_back_ht1_tmp = Tensor + (reinterpret_cast(tmp_buf), Shape3(D, H, N)); + d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, 
N)); + if (req_params == kAddTo) { + beta = 2.0; + // dwx = da.T * x [3 * H, I] = [3 * H, N] * [N, I] for AddTo + Tensor d_xt(x.dptr_ + t * N * I, Shape2(N, I)); + Tensor d_dat(dat, Shape2(N, 3 * H)); + Tensor d_back_dwx(back_dwx, Shape2(3 * H, I)); + linalg_gemm(d_dat, d_xt, d_back_dwx, alpha, beta, true, false); + } + linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true); + } + } + + if (req_params != kNullOp) { + // dbx = e * da [1, 3 * H] = [1, N] * [N, 3 * H] + if (req_params != kAddTo) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < 3 * H; ++i) { + for (int j = 0; j < N * T; ++j) { + back_dbx[i] += da[j * 3 * H + i]; + back_dbh[i] += dar[j * 3 * H + i]; + } + } + } else { + const Tensor tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T)); + const Tensor tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T)); + memset(tmp_dbx.dptr_, 0, H * T * 3 * sizeof(DType)); + memset(tmp_dbh.dptr_, 0, H * T * 3 * sizeof(DType)); + + for (int t = T - 1; t >= 0; --t) { + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < 3 * H; ++i) { + for (int j = 0; j < N; ++j) { + tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i]; + tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i]; + } + } + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < 3 * H; ++i) { + back_dbx[i] += tmp_dbx[i][t] + back_dbx[i]; + back_dbh[i] += tmp_dbh[i][t] + back_dbh[i]; + } + } + } + } + alpha = 1.0; + beta = 1.0; + // dxt = da * wx [T * N, I] = [T * N, 3 * H] * [3 * H, I] + Tensor d_da2(da, Shape2(T * N, 3 * H)); + if (req_data != kNullOp) { + Tensor d_dx(dx, Shape2(T * N, I)); + linalg_gemm(d_da2, back_wx, d_dx, alpha, beta, false, false); + } + alpha = 1.0; + beta = 0.0; + // dwx = da.T * x [3 * H, I] = [3 * H, T * N] * [T * N, I] + if (req_params != kNullOp && req_params != kAddTo) { + Tensor d_back_dwx(back_dwx, Shape2(3 * H, I)); + linalg_gemm(d_da2, x, d_back_dwx, alpha, beta, true, false); + } + } + if (req_state != kNullOp) { + memcpy(dhx, dht1, N * H * D * sizeof(DType)); + } +} + +template +void GruBackward(DType* ws, + DType* rs, + const int L, + const int D, + const int T, + const int N, + int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* w_ptr, + DType* dy_ptr, + DType* dhy_ptr, + DType* dx_ptr, + DType* dhx_ptr, + DType* dw_ptr, + int req_data, + int req_params, + int req_state) { + DType* wx = w_ptr; + DType* dwx = dw_ptr; + DType* dwh = dwx + I * H * 3; + DType* dbx = dwh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3) + + (L - 1) * ((D + 1) * H) * H * 3 * D; + DType* gateR_l = rs + (L - 1) * T * D * N * H; + DType* gateZ_l = gateR_l + L * T * D * N * H; + DType* gateN_l = gateZ_l + L * T * D * N * H; + DType* y_l = gateN_l + L * T * D * N * H; + DType* Mnh_l = y_l + L * T * N * H * D; + DType* tmp_buf = Mnh_l + L * D * T * N * H; + DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2; + DType* ws2 = dx_l + T * N * D * H; + DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H + + D * I * 3 * H + D * H * 3 * H; + DType* wh_l = wx_l; + if (L == 1) { + wh_l = wh_l + I * H * 3; + } else { + wh_l = wh_l + (D * H) * H * 3; + } + DType* dhy_l = NULL; + if (dhy_ptr) + dhy_l = dhy_ptr + (L - 1) * D * N * H; + DType* dwx_l = (L == 1)? 
dwx : dwx + (L - 2) * D * (D + 1) * H * 3 * H + + D * I * 3 * H + D * H * 3 * H; + DType* dwh_l = NULL; + if (L == 1) { + dwh_l = dwx_l + I * H * 3; + } else { + dwh_l = dwx_l + (D * H) * H * 3; + } + DType* dbx_l = dbx + (L - 1) * D * 3 * H * 2; + DType* dbh_l = dbx_l + 3 * H; + DType* dhx_l = dhx_ptr + (L - 1) * D * N * H; + DType* dy_l = dy_ptr; + Tensor hx(hx_ptr, Shape3(L, D * N, H)); + int inputsize = I; + DType* y_tmp = y_l - T * N * H * D; + for (int l = L - 1; l >= 0; --l) { + if (l == 0) { + I = inputsize; + y_tmp = x_ptr; + dx_l = dx_ptr; + } else { + I = D * H; + } + Tensor hx_l = hx[l]; + Tensor x_l(y_tmp, Shape2(T * N, I)); + GruBackwardSingleLayer(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l, + dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l, + dwx_l, dwh_l, dbx_l, dbh_l, req_data, req_params, req_state); + if (l > 0) { + memcpy(dy_l, dx_l, T * N * H * D * sizeof(DType)); + gateR_l = gateR_l - T * D * N * H; + gateZ_l = gateZ_l - T * D * N * H; + gateN_l = gateN_l - T * D * N * H; + Mnh_l = Mnh_l - T * D * N * H; + dhx_l = dhx_l - D * N * H; + if (dhy_l) + dhy_l = dhy_l - D * N * H; + y_l = y_l - T * N * H * D; + y_tmp = y_l; + if (l == 1) { + wx_l = wx_l - (inputsize + H) * H * 3 * D; + wh_l = wx_l + inputsize * 3 * H; + dwx_l = dwx_l - (inputsize + H) * H * 3 * D; + dwh_l = dwx_l + inputsize * 3 * H; + } else { + wx_l = wx_l - (I + H) * H * 3 * D; + wh_l = wx_l + I * 3 * H; + dwx_l = dwx_l - (I + H) * H * 3 * D; + dwh_l = dwx_l + I * 3 * H; + } + dbx_l = dbx_l - D * 3 * H * 2; + dbh_l = dbx_l + 3 * H; + } + } +} +} // namespace op +} // namespace mxnet #endif // MXNET_OPERATOR_RNN_IMPL_H_ diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1eb23cc92281..ab03973e8e86 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -28,17 +28,17 @@ from common import setup_module, with_seed import unittest -def check_rnn_consistency(cell1, cell2, T, N, I, H): +def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req): dshape = (N, T, I) data = mx.sym.Variable('data') Y1, _ = cell1.unroll(T, data, layout='NTC', merge_outputs=True) mod1 = mx.mod.Module(Y1, label_names=None, context=default_context()) - mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True) + mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req) Y2, _ = cell2.unroll(T, data, layout='NTC', merge_outputs=True) mod2 = mx.mod.Module(Y2, label_names=None, context=default_context()) - mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True) + mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req) mod1.init_params() args, auxs = mod1.get_params() @@ -60,8 +60,14 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H): dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape) mod1.backward(out_grads=[dy]) - mod2.backward(out_grads=[dy]) - assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4) + mod2.backward(out_grads=[dy]) + if grad_req != 'null': + assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4) + else: + assert(mod1.get_input_grads()[0] == None) + assert(mod2.get_input_grads()[0] == None) + + @with_seed() def test_lstm_sym(): @@ -71,8 +77,10 @@ def test_lstm_sym(): stack.add(mx.rnn.LSTMCell(H, prefix='l0_')) stack.add(mx.rnn.LSTMCell(H, 
prefix='l1_'))
     stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
-    check_rnn_consistency(fused, stack, T, N, I, H)
-    check_rnn_consistency(stack, fused, T, N, I, H)
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 def test_lstm_bidirectional():
@@ -90,8 +98,45 @@ def test_lstm_bidirectional():
                 mx.rnn.LSTMCell(H, prefix='r1_'),
                 output_prefix='bi_lstm_1_'))
 
-    check_rnn_consistency(stack, fused, T, N, I, H)
-    check_rnn_consistency(fused, stack, T, N, I, H)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_gru_sym():
+    T, N, I, H = 5, 32, 800, 800
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.GRUCell(H, prefix='l0_'))
+    stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
+    stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_gru_bidirectional():
+    T, N, I, H = 5, 20, 800, 800
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru',
+                                bidirectional=True, get_next_state=True, prefix='')
+
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(H, prefix='l0_'),
+                mx.rnn.GRUCell(H, prefix='r0_'),
+                output_prefix='bi_gru_0_'))
+
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(H, prefix='l1_'),
+                mx.rnn.GRUCell(H, prefix='r1_'),
+                output_prefix='bi_gru_1_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
 # Currently, fused LSTM operator doesn't support dropout.
 # Will change this test after dropout is supported
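
For reference, the gate math implemented by the new `GruForwardInferenceSingleLayer`/`GruForwardTrainingSingleLayer` kernels is the cuDNN-style GRU cell: the recurrent term of the candidate state is gated after the matrix product, and that pre-gate quantity (`gemmC2[n] + bh[n]`) is what the training path stores in the reserve space as `Mnh` for reuse in `GruBackward`. Below is a minimal NumPy sketch of one forward step, assuming the `[r, z, n]` gate ordering used by the kernel; the `gru_step` helper is illustrative only and not part of the patch.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, wx, wh, bx, bh):
    """One GRU step for a single direction and layer (illustrative sketch).

    x_t:    (N, I)   input at time t
    h_prev: (N, H)   previous hidden state
    wx:     (3H, I)  input weights, gate order [r, z, n]
    wh:     (3H, H)  recurrent weights, same gate order
    bx, bh: (3, H)   input-side and recurrent-side biases
    """
    H = h_prev.shape[1]
    gx = x_t @ wx.T       # plays the role of gemmC1 in the kernel
    gh = h_prev @ wh.T    # plays the role of gemmC2 in the kernel
    r = sigmoid(gx[:, :H]     + gh[:, :H]     + bx[0] + bh[0])
    z = sigmoid(gx[:, H:2*H]  + gh[:, H:2*H]  + bx[1] + bh[1])
    # The recurrent candidate term is gated after the matmul;
    # gh[:, 2H:] + bh[2] is the quantity the training kernel saves as Mnh.
    n = np.tanh(gx[:, 2*H:] + bx[2] + r * (gh[:, 2*H:] + bh[2]))
    return (1.0 - z) * n + z * h_prev

# Example shapes: N=2 samples, I=4 inputs, H=3 hidden units.
N, I, H = 2, 4, 3
h = gru_step(np.random.rand(N, I), np.zeros((N, H)),
             np.random.rand(3*H, I), np.random.rand(3*H, H),
             np.random.rand(3, H), np.random.rand(3, H))
```

With the operator in place, the renamed example script accepts `--rnntype gru`, and the new `test_gru_sym`/`test_gru_bidirectional` tests compare the fused kernel against stacked `GRUCell`s under the `write`, `add`, and `null` gradient requests handled by the backward path.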