From 9ec87665599723f64aa939420ece47d78828f0e1 Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Mon, 12 Oct 2020 15:55:24 +0800 Subject: [PATCH] Add flattern weight of lstm (#27192) * add flattern weight of lstm --- paddle/fluid/operators/cudnn_lstm_op.cc | 43 ++++- paddle/fluid/operators/cudnn_lstm_op.cu.cc | 162 ++++++++++++++++-- python/paddle/fluid/layers/rnn.py | 20 +-- .../tests/unittests/test_lstm_cudnn_op.py | 127 +++++++++++--- 4 files changed, 290 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index 50486ad041aa4..31f0c26a3f3a1 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -25,7 +26,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM"); - OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM"); OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM"); OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM"); @@ -122,7 +122,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("W", "(Tensor) the learnable hidden-hidden weights." " The shape is (N), where N is total weight size of the LSTM. " - " cudnn concatenate all the weight to one Tensor"); + " cudnn concatenate all the weight to one Tensor") + .AsDispensable(); + AddInput("WeightList", + "(vector), stores weight and bias data when the weight " + "use the list format. ") + .AsDispensable() + .AsDuplicable(); AddInput("SequenceLength", "(Tensor) When the input data is padding, " "set this parameter. This parameter represents " @@ -216,7 +222,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad"); - OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad"); OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad"); OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad"); @@ -228,7 +233,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { }; SetOutGradDim("Input"); - SetOutGradDim("W"); + if (ctx->HasInputs("WeightList")) { + ctx->SetOutputsDim(framework::GradVarName("WeightList"), + ctx->GetInputsDim("WeightList")); + } SetOutGradDim("InitH"); SetOutGradDim("InitC"); } @@ -251,7 +259,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Input", this->Input("Input")); op->SetInput("InitH", this->Input("InitH")); op->SetInput("InitC", this->Input("InitC")); - op->SetInput("W", this->Input("W")); + if (this->HasInput("WeightList")) { + op->SetInput("WeightList", this->Input("WeightList")); + } if (this->HasInput("SequenceLength")) { op->SetInput("SequenceLength", this->Input("SequenceLength")); } @@ -262,8 +272,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC")); op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH")); + if (this->HasInput("WeightList")) { + op->SetOutput(framework::GradVarName("WeightList"), + this->InputGrad("WeightList", false)); + } + op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); - op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH")); op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC")); op->SetAttrMap(this->Attrs()); @@ -290,3 +304,20 @@ REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel); REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel); + +// TODO(Shixiaowei02) Add ModifyInput support +REGISTER_OP_VERSION(cudnn_lstm) + .AddCheckpoint( + R"ROC( + Upgrade cudnn_lstm add a new input [WeightList] and modify input [W] to dispensable.)ROC", + paddle::framework::compatible::OpVersionDesc() + .NewInput( + "WeightList", + "The WeightList stores weight and bias data. WeightList is " + "dispensable.") + .NewInput("SequenceLength", + "When the input data is padding, set this parameter. " + "SequenceLength is dispensable.") + .NewOutput("StateOut", "Store the global drop state when training") + .NewOutput("Reserve", + "A temporary output Tensor to store the reserve_data")); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index 6ac75b78d7058..bea7d9c02ca7d 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -30,6 +30,66 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; +template +bool is_continuous(const Type &weight_list) { + bool continuous = true; + for (size_t i = 0; i < weight_list.size() - 1; ++i) { + auto *in_data = weight_list[i]->template data(); + auto *in_after_data = weight_list[i + 1]->template data(); + auto in_size = weight_list[i]->numel(); + bool temp = in_data + in_size == in_after_data; + continuous = continuous && temp; + } + return continuous; +} + +int size_sum(const std::vector &weight_list) { + int size = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + auto in_size = weight_list[i]->numel(); + size += in_size; + } + return size; +} + +template +void weight_to_tensor(const platform::Place &place, cudaStream_t stream, + const std::vector &weight_list, + Tensor *weight) { + auto weight_data = weight->data(); + int weight_offset = 0; + for (size_t i = 0; i < weight_list.size(); ++i) { + const T *in_data = weight_list[i]->data(); + auto in_size = weight_list[i]->numel(); + + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()), + weight_data + weight_offset, + BOOST_GET_CONST(platform::CUDAPlace, weight_list[i]->place()), + in_data, in_size * sizeof(T), stream); + weight_offset += in_size; + } +} + +template +void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream, + std::vector *weight_grad, + const std::vector &weight_input, + const Tensor *weight) { + int weight_offset = 0; + auto *weight_data = weight->data(); + for (size_t i = 0; i < weight_input.size(); ++i) { + auto in_size = weight_input[i]->numel(); + T *weight_grad_data = (*weight_grad)[i]->mutable_data(place); + const T *src = weight_data + weight_offset; + + memory::Copy( + BOOST_GET_CONST(platform::CUDAPlace, (*weight_grad)[i]->place()), + weight_grad_data, BOOST_GET_CONST(platform::CUDAPlace, weight->place()), + src, in_size * sizeof(T), stream); + weight_offset += in_size; + } +} + template void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle, const int &seq_length, ScopedRNNBase *rnn, const T *x_data, @@ -75,8 +135,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { const Tensor *init_h = ctx.Input("InitH"); const Tensor *init_c = ctx.Input("InitC"); - auto w = ctx.Input("W"); - Tensor *out = ctx.Output("Out"); Tensor *last_h = ctx.Output("LastH"); Tensor *last_c = ctx.Output("LastC"); @@ -87,8 +145,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { const T *init_h_data = init_h->data(); const T *init_c_data = init_c->data(); - const T *w_data = w->data(); - T *out_data = out->mutable_data(ctx.GetPlace()); T *last_h_data = last_h->mutable_data(ctx.GetPlace()); T *last_c_data = last_c->mutable_data(ctx.GetPlace()); @@ -113,11 +169,45 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { int seq_length = x->dims()[0]; int batch_size = x->dims()[1]; int input_size = x->dims()[2]; - int weight_numel = w->numel(); bool state_initialized = state_out->IsInitialized() ? true : false; size_t workspace_size; size_t reserve_size; + Tensor weight_whole; + T *w_data = nullptr; + int weight_numel; + bool w_initialized = false; + auto place = ctx.GetPlace(); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + if (is_test && ctx.HasInput("W")) { + auto *W = ctx.Input("W"); + w_initialized = W->IsInitialized() ? true : false; + weight_numel = W->numel(); + } + if (!w_initialized) { + auto weight_list = ctx.MultiInput("WeightList"); + bool continuous = + is_continuous>(weight_list); + weight_numel = size_sum(weight_list); + + if (!continuous) { + LOG_FIRST_N(WARNING, 2) + << "If the memory space of the Input WeightList is not " + "continuous, less efficient calculation will be " + "called. Please call coalesce_tensor op to make the " + "input memory continuous."; + weight_whole.mutable_data({weight_numel}, place); + weight_to_tensor(place, stream, weight_list, &weight_whole); + w_data = weight_whole.data(); + } else { + w_data = const_cast(weight_list[0]->data()); + } + } else { + auto *W = ctx.Input("W"); + w_data = const_cast(W->data()); + } ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size, num_layers, dropout_prob, seed, weight_numel, @@ -136,6 +226,12 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { LSTMInferece(has_seq_length, handle, seq_length, &rnn, x_data, init_h_data, init_c_data, w_data, out_data, last_h_data, last_c_data, &workspace_data_, workspace_size); + if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) { + auto *W = const_cast(ctx.Input("W")); + auto weight_list = ctx.MultiInput("WeightList"); + W->mutable_data({weight_numel}, place); + weight_to_tensor(place, stream, weight_list, W); + } } else { if (!has_seq_length) { // for train @@ -176,11 +272,11 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *input = ctx.Input("Input"); - auto *weight = ctx.Input("W"); auto *init_h = ctx.Input("InitH"); auto *init_c = ctx.Input("InitC"); auto *reserve = ctx.Input("Reserve"); auto *state_out = ctx.Input("StateOut"); + auto weight_list = ctx.MultiInput("WeightList"); auto *out = ctx.Input("Out"); auto *out_grad = ctx.Input(framework::GradVarName("Out")); @@ -188,9 +284,10 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto *last_c_grad = ctx.Input(framework::GradVarName("LastC")); auto *in_grad = ctx.Output(framework::GradVarName("Input")); - auto *weight_grad = ctx.Output(framework::GradVarName("W")); auto *init_h_grad = ctx.Output(framework::GradVarName("InitH")); auto *init_c_grad = ctx.Output(framework::GradVarName("InitC")); + auto weight_grad_list = ctx.MultiOutput( + framework::GradVarName("WeightList")); auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); @@ -199,7 +296,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto init_h_dims = init_h->dims(); auto init_c_dims = init_c->dims(); - auto *weight_data = weight->data(); auto *init_h_data = init_h->data(); auto *init_c_data = init_c->data(); auto *out_data = out->data(); @@ -207,18 +303,50 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto *last_h_grad_data = last_h_grad->data(); auto *last_c_grad_data = last_c_grad->data(); + auto place = ctx.GetPlace(); + int weight_numel = size_sum(weight_list); + bool continuous = + is_continuous>(weight_list); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + Tensor weight_whole; + T *weight_data = nullptr; + + if (!continuous) { + weight_whole.mutable_data({weight_numel}, place); + weight_to_tensor(place, stream, weight_list, &weight_whole); + weight_data = weight_whole.data(); + } else { + weight_data = const_cast(weight_list[0]->data()); + } + + Tensor weight_grad; math::SetConstant zero; - weight_grad->mutable_data(ctx.GetPlace()); - zero(dev_ctx, weight_grad, static_cast(0.0)); + weight_grad.mutable_data({weight_numel}, ctx.GetPlace()); + zero(dev_ctx, &weight_grad, static_cast(0.0)); + T *weight_grad_data = weight_grad.data(); + + int offset = 0; + for (size_t i = 0; i < weight_grad_list.size(); ++i) { + size_t len = weight_grad_list[i]->numel(); + auto dim = weight_grad_list[i]->dims(); + weight_grad_list[i] + ->ShareDataWith(weight_grad.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } in_grad->mutable_data(input_dims, ctx.GetPlace()); auto *in_grad_data = in_grad->data(); - init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); - auto *init_h_grad_data = init_h_grad->data(); + if (init_h_grad) init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); + auto *init_h_grad_data = init_h_grad ? init_h_grad->data() : nullptr; - init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); - auto *init_c_grad_data = init_c_grad->data(); + if (init_c_grad) init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); + auto *init_c_grad_data = init_c_grad ? init_c_grad->data() : nullptr; float dropout_prob = ctx.Attr("dropout_prob"); bool is_bidirec = ctx.Attr("is_bidirec"); @@ -236,7 +364,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { int seq_length = input_dims[0]; int batch_size = input->dims()[1]; int input_size = input->dims()[2]; - int weight_numel = weight->numel(); size_t workspace_size; size_t reserve_size; @@ -268,8 +395,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data(), rnn.init_h_desc(), init_h->data(), rnn.y_descs(), out->data(), workspace_data_.data(), workspace_size, rnn.weight_desc(), - weight_grad->data(), const_cast(reserve_data), - reserve_size)); + weight_grad_data, const_cast(reserve_data), reserve_size)); } else { #if CUDNN_VERSION >= 7201 // for train @@ -288,7 +414,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data(), rnn.init_h_desc(), init_h->data(), rnn.y_seq_desc(), out->data(), workspace_data_.data(), workspace_size, - rnn.weight_desc(), weight_grad->data(), + rnn.weight_desc(), weight_grad_data, const_cast(reserve_data), reserve_size)); #else PADDLE_THROW(platform::errors::Unavailable( diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index 8ac46ad2648fd..57c2489194337 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -2443,23 +2443,17 @@ def lstm(input, input_shape = list(input.shape) input_size = input_shape[-1] weight_size = 0 + num_dirrection = 2 if is_bidirec == True else 1 + for i in range(num_layers): if i == 0: - input_weight_size = (input_size * hidden_size) * 4 + input_weight_size = (input_size * hidden_size) * 4 * num_dirrection else: - if is_bidirec: - input_weight_size = (hidden_size * 2 * hidden_size) * 4 - else: - input_weight_size = (hidden_size * hidden_size) * 4 + input_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection + hidden_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection - hidden_weight_size = (hidden_size * hidden_size) * 4 - - if is_bidirec: - weight_size += (input_weight_size + hidden_weight_size) * 2 - weight_size += hidden_size * 8 * 2 - else: - weight_size += input_weight_size + hidden_weight_size - weight_size += hidden_size * 8 + weight_size += input_weight_size + hidden_weight_size + weight_size += hidden_size * 8 * num_dirrection weight = helper.create_parameter( attr=helper.param_attr, diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py index 29a0fa55f7729..82443f8c5493b 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py @@ -20,14 +20,44 @@ import paddle.fluid.core as core from op_test import OpTest +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers +import random +random.seed(2) +np.set_printoptions(threshold=np.inf) +paddle.enable_static() SIGMOID_THRESHOLD_MIN = -40.0 SIGMOID_THRESHOLD_MAX = 13.0 EXP_MAX_INPUT = 40.0 +class RandomWeight: + def __init__(self): + pass + + def updata_weight(self, hidden_size, input_size, dtype): + std = 1.0 / math.sqrt(hidden_size) + self.hidden_size = hidden_size + self.input_size = input_size + self.dtype = dtype + + self.weight_ih = np.random.uniform( + low=-std, high=std, size=(4 * self.hidden_size, + self.input_size)).astype(dtype) + self.weight_hh = np.random.uniform( + low=-std, high=std, size=(4 * self.hidden_size, + self.hidden_size)).astype(dtype) + self.bias_ih = np.random.uniform( + low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype) + self.bias_hh = np.random.uniform( + low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype) + + +weight = RandomWeight() + + class LayerMixin(object): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -51,16 +81,13 @@ def __init__(self, input_size, hidden_size, bias=True): self.bias = bias self.dtype = np.float64 self.parameters = dict() - std = 1.0 / math.sqrt(hidden_size) - self.weight_ih = np.ones( - (4 * hidden_size, input_size), dtype=self.dtype) - self.weight_hh = np.ones((4 * hidden_size, - hidden_size)).astype(self.dtype) + self.weight_ih = weight.weight_ih + self.weight_hh = weight.weight_hh self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh if bias: - self.bias_ih = np.ones((4 * hidden_size)).astype(self.dtype) - self.bias_hh = np.ones((4 * hidden_size)).astype(self.dtype) + self.bias_ih = weight.bias_ih + self.bias_hh = weight.bias_hh self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_hh'] = self.bias_hh else: @@ -353,24 +380,26 @@ def __init__(self, @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNLstmOp(OpTest): - #TODO(GaoWei8): Need to satisfy the result through the new interface + def get_weight_names(self): + weight_names = [] + for i in range(2 * self.num_layers): + weight_names.append('weight{}'.format(i)) + for i in range(2 * self.num_layers): + weight_names.append('bias{}'.format(i)) + return weight_names + def setUp(self): self.op_type = "cudnn_lstm" self.dtype = np.float64 self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32) self.num_layers = 1 + self.set_attrs() seq_length = 12 batch_size = 5 input_size = 21 hidden_size = 21 - input_weight_size = (hidden_size * hidden_size) * 4 - hidden_weight_size = (hidden_size * hidden_size) * 4 - weight_size = input_weight_size + hidden_weight_size - weight_size += hidden_size * 8 - weight_size *= self.num_layers - input = np.random.uniform( low=-0.1, high=0.1, size=(seq_length, batch_size, input_size)).astype(self.dtype) @@ -379,17 +408,39 @@ def setUp(self): input[9][3:][:] = 0 input[8][4:][:] = 0 + weight.updata_weight(hidden_size, input_size, self.dtype) rnn1 = LSTM( input_size, hidden_size, - self.num_layers, + num_layers=self.num_layers, time_major=True, direction="forward") output, (last_hidden, last_cell) = rnn1( input, sequence_length=self.sequence_length) - flat_w = np.ones((weight_size)).astype(self.dtype) + flat_w = [] + num = 0 + for i in range(self.num_layers): + if i == 0: + weight_ih = weight.weight_ih + else: + weight_ih = weight.weight_hh + flat_w.append(("weight" + str(num), weight_ih)) + num += 1 + for i in range(self.num_layers): + weight_hh = weight.weight_hh + flat_w.append(("weight" + str(num), weight_hh)) + num += 1 + num = 0 + for i in range(self.num_layers): + bias_ih = weight.bias_ih + flat_w.append(("bias" + str(num), bias_ih)) + num += 1 + for i in range(self.num_layers): + bias_hh = weight.bias_hh + flat_w.append(("bias" + str(num), bias_hh)) + num += 1 init_h = np.zeros((self.num_layers, batch_size, hidden_size)).astype(self.dtype) init_c = np.zeros((self.num_layers, batch_size, @@ -398,7 +449,7 @@ def setUp(self): self.inputs = { 'Input': input, - 'W': flat_w, + 'WeightList': flat_w, 'InitH': init_h, 'InitC': init_c, 'SequenceLength': self.sequence_length @@ -408,7 +459,7 @@ def setUp(self): 'is_bidirec': False, 'input_size': input_size, 'hidden_size': hidden_size, - 'num_layers': 1, + 'num_layers': self.num_layers, } self.outputs = { 'Out': output, @@ -428,16 +479,42 @@ def test_output_with_place(self): def test_grad_with_place(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, - set(['Input', 'W', 'InitH', 'InitC']), - ['Out', 'LastH', 'LastC']) + var_name_list = self.get_weight_names() + for var_name in var_name_list: + self.check_grad_with_place( + place, + set(['Input', var_name, 'InitH', 'InitC']), + ['Out', 'LastH', 'LastC']) @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") -class TestCUDNNLstmOp2(TestCUDNNLstmOp): - def set_attrs(self): - self.num_layers = 2 +class TestCUDNNlstmAPI(unittest.TestCase): + def test_lstm(self): + seq_len = 20 + batch_size = 5 + hidden_size = 20 + dropout_prob = 0.0 + num_layers = 1 + input = fluid.data( + name='input', + shape=[seq_len, batch_size, hidden_size], + dtype='float64') + init_h = layers.fill_constant([num_layers, batch_size, hidden_size], + 'float64', 0.0) + init_c = layers.fill_constant([num_layers, batch_size, hidden_size], + 'float64', 0.0) + rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len, + hidden_size, num_layers, + dropout_prob, False) + exe = fluid.Executor(fluid.CUDAPlace(0)) + exe.run(fluid.default_startup_program()) + input_i = np.random.uniform( + low=-0.1, high=0.1, size=(seq_len, batch_size, + hidden_size)).astype("float64") + out = exe.run(fluid.default_main_program(), + feed={'input': input_i}, + fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0']) @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -448,7 +525,7 @@ def test_lstm(self): batch_size = 5 hidden_size = 20 dropout_prob = 0.0 - num_layers = 1 + num_layers = 2 input = fluid.data( name='input', shape=[seq_len, batch_size, hidden_size],