From 5e98ccaf870b09b092e8c69182a58a3c0bcc9b98 Mon Sep 17 00:00:00 2001
From: GaoWei8
Date: Wed, 16 Sep 2020 12:14:52 +0000
Subject: [PATCH] add flattern weight of lstm

---
 paddle/fluid/operators/cudnn_lstm_op.cc    |  18 ++-
 paddle/fluid/operators/cudnn_lstm_op.cu.cc | 122 ++++++++++++++++--
 python/paddle/fluid/layers/rnn.py          |  58 +++++----
 .../tests/unittests/test_lstm_cudnn_op.py  |  74 ++++++++---
 4 files changed, 215 insertions(+), 57 deletions(-)

diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc
index 82954bc109a74..e7d76c7abe83c 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cc
@@ -25,7 +25,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
-    OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
     OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
     OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
@@ -122,7 +121,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("W",
              "(Tensor) the learnable hidden-hidden weights."
              " The shape is (N), where N is total weight size of the LSTM. "
-             " cudnn concatenate all the weight to one Tensor");
+             " cudnn concatenate all the weight to one Tensor")
+        .AsDispensable();
+    AddInput("WeightList",
+             "(vector<Tensor>), stores weight and bias data when the weight "
+             "use the list format. ")
+        .AsDispensable()
+        .AsDuplicable();
     AddInput("SequenceLength",
              "(Tensor) When the input data is padding, "
              "set this parameter. This parameter represents "
@@ -216,7 +221,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
-    OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
     OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
     OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
@@ -228,7 +232,8 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
     };
 
     SetOutGradDim("Input");
-    SetOutGradDim("W");
+    ctx->SetOutputsDim(framework::GradVarName("WeightList"),
+                       ctx->GetInputsDim("WeightList"));
     SetOutGradDim("InitH");
     SetOutGradDim("InitC");
   }
@@ -251,7 +256,7 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("InitH", this->Input("InitH"));
     op->SetInput("InitC", this->Input("InitC"));
-    op->SetInput("W", this->Input("W"));
+    op->SetInput("WeightList", this->Input("WeightList"));
     if (this->HasInput("SequenceLength")) {
       op->SetInput("SequenceLength", this->Input("SequenceLength"));
     }
@@ -262,6 +267,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
     op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
 
+    op->SetOutput(framework::GradVarName("WeightList"),
+                  this->InputGrad("WeightList", false));
+
     op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
     op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
     op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
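
The op definition above keeps the flat `W` input and adds a per-parameter `WeightList`; cuDNN ultimately consumes one contiguous parameter buffer, so the GPU kernel in the next file packs the list into a single flat tensor when the inputs are not already contiguous. A rough numpy sketch of that packing (illustrative only, not part of the diff; `flatten_weight_list` is a hypothetical helper name), assuming the WeightList order weight_ih, weight_hh, bias_ih, bias_hh per direction per layer:

# Illustrative numpy sketch (not part of the diff): pack a per-layer weight
# list into one flat buffer, mirroring the offset-based copy loop that
# weight_to_tensor performs with memory::Copy on the GPU.
import numpy as np

def flatten_weight_list(weight_list):
    # Copy each tensor back-to-back into a single 1-D buffer.
    flat = np.empty(sum(w.size for w in weight_list), dtype=weight_list[0].dtype)
    offset = 0
    for w in weight_list:
        flat[offset:offset + w.size] = w.ravel()
        offset += w.size
    return flat

hidden_size, input_size = 4, 3
layer0 = [
    np.zeros((4 * hidden_size, input_size)),   # weight_ih
    np.zeros((4 * hidden_size, hidden_size)),  # weight_hh
    np.zeros(4 * hidden_size),                 # bias_ih
    np.zeros(4 * hidden_size),                 # bias_hh
]
flat_w = flatten_weight_list(layer0)
assert flat_w.size == 4 * hidden_size * (input_size + hidden_size + 2)
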
diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
index 6457d9295dcbf..d57c064709790 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
@@ -25,6 +25,66 @@ namespace operators {
 using LoDTensor = framework::LoDTensor;
 using Tensor = framework::Tensor;
 
+template <typename T>
+bool is_continuous(const std::vector<const Tensor *> &weight_list) {
+  bool continuous = true;
+  for (size_t i = 0; i < weight_list.size() - 1; ++i) {
+    const T *in_data = weight_list[i]->data<T>();
+    const T *in_after_data = weight_list[i + 1]->data<T>();
+    auto in_size = weight_list[i]->numel();
+    bool temp = in_data + in_size == in_after_data;
+    continuous = continuous && temp;
+  }
+  return continuous;
+}
+
+int size_sum(const std::vector<const Tensor *> &weight_list) {
+  int size = 0;
+  for (size_t i = 0; i < weight_list.size(); ++i) {
+    auto in_size = weight_list[i]->numel();
+    size += in_size;
+  }
+  return size;
+}
+
+template <typename T>
+void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
+                      const std::vector<const Tensor *> &weight_list,
+                      Tensor *weight) {
+  auto weight_data = weight->data<T>();
+  int weight_offset = 0;
+  for (size_t i = 0; i < weight_list.size(); ++i) {
+    const T *in_data = weight_list[i]->data<T>();
+    auto in_size = weight_list[i]->numel();
+
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
+                 weight_data + weight_offset,
+                 BOOST_GET_CONST(platform::CUDAPlace, weight_list[i]->place()),
+                 in_data, in_size * sizeof(T), stream);
+    weight_offset += in_size;
+  }
+}
+
+template <typename T>
+void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
+                           std::vector<Tensor *> *weight_grad,
+                           const std::vector<const Tensor *> &weight_input,
+                           const Tensor *weight) {
+  int weight_offset = 0;
+  auto *weight_data = weight->data<T>();
+  for (size_t i = 0; i < weight_input.size(); ++i) {
+    auto in_size = weight_input[i]->numel();
+    T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
+    const T *src = weight_data + weight_offset;
+
+    memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, (*weight_grad)[i]->place()),
+        weight_grad_data, BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
+        src, in_size * sizeof(T), stream);
+    weight_offset += in_size;
+  }
+}
+
 template <typename T>
 void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
                   const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
@@ -70,8 +130,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
     const Tensor *init_h = ctx.Input<Tensor>("InitH");
     const Tensor *init_c = ctx.Input<Tensor>("InitC");
 
-    auto w = ctx.Input<Tensor>("W");
-
     Tensor *out = ctx.Output<Tensor>("Out");
     Tensor *last_h = ctx.Output<Tensor>("LastH");
     Tensor *last_c = ctx.Output<Tensor>("LastC");
@@ -82,12 +140,30 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
     const T *init_h_data = init_h->data<T>();
     const T *init_c_data = init_c->data<T>();
 
-    const T *w_data = w->data<T>();
-
     T *out_data = out->mutable_data<T>(ctx.GetPlace());
     T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
     T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
 
+    auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
+    bool continuous = is_continuous<T>(weight_list);
+    int weight_numel = size_sum(weight_list);
+
+    auto place = ctx.GetPlace();
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
+                      ctx.device_context())
+                      .stream();
+    Tensor weight_whole;
+    T *w_data = nullptr;
+
+    if (!continuous) {
+      // LOG(WARNING) << "Input WeightList , please use op";
+      weight_whole.mutable_data<T>({weight_numel}, place);
+      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
+      w_data = weight_whole.data<T>();
+    } else {
+      w_data = const_cast<T *>(weight_list[0]->data<T>());
+    }
+
     float dropout_prob = ctx.Attr<float>("dropout_prob");
     bool is_bidirec = ctx.Attr<bool>("is_bidirec");
     int hidden_size = ctx.Attr<int>("hidden_size");
@@ -108,7 +184,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
     int seq_length = x->dims()[0];
     int batch_size = x->dims()[1];
     int input_size = x->dims()[2];
-    int weight_numel = w->numel();
     bool state_initialized = state_out->IsInitialized() ? true : false;
 
     size_t workspace_size;
@@ -171,11 +246,11 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     auto *input = ctx.Input<Tensor>("Input");
-    auto *weight = ctx.Input<Tensor>("W");
     auto *init_h = ctx.Input<Tensor>("InitH");
     auto *init_c = ctx.Input<Tensor>("InitC");
     auto *reserve = ctx.Input<Tensor>("Reserve");
     auto *state_out = ctx.Input<Tensor>("StateOut");
+    auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
 
     auto *out = ctx.Input<Tensor>("Out");
     auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -183,9 +258,10 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));
 
     auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
     auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
     auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
+    auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
+        framework::GradVarName("WeightList"));
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto handle = dev_ctx.cudnn_handle();
@@ -194,7 +270,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     auto init_h_dims = init_h->dims();
     auto init_c_dims = init_c->dims();
 
-    auto *weight_data = weight->data<T>();
     auto *init_h_data = init_h->data<T>();
     auto *init_c_data = init_c->data<T>();
     auto *out_data = out->data<T>();
@@ -202,9 +277,29 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     auto *last_h_grad_data = last_h_grad->data<T>();
     auto *last_c_grad_data = last_c_grad->data<T>();
 
+    auto place = ctx.GetPlace();
+    int weight_numel = size_sum(weight_list);
+    bool continuous = is_continuous<T>(weight_list);
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
+                      ctx.device_context())
+                      .stream();
+    Tensor weight_whole;
+    weight_whole.mutable_data<T>({weight_numel}, place);
+    T *weight_data = nullptr;
+
+    if (!continuous) {
+      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
+      weight_data = weight_whole.data<T>();
+    } else {
+      weight_data = const_cast<T *>(weight_list[0]->data<T>());
+    }
+
+    Tensor weight_grad;
+
     math::SetConstant<platform::CUDADeviceContext, T> zero;
-    weight_grad->mutable_data<T>(ctx.GetPlace());
-    zero(dev_ctx, weight_grad, static_cast<T>(0.0));
+    weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
+    zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
 
     in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
     auto *in_grad_data = in_grad->data<T>();
@@ -231,7 +326,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     int seq_length = input_dims[0];
     int batch_size = input->dims()[1];
     int input_size = input->dims()[2];
-    int weight_numel = weight->numel();
 
     size_t workspace_size;
     size_t reserve_size;
@@ -263,7 +357,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
           handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
           rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
           workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
-          weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
+          weight_grad.data<T>(), const_cast<uint8_t *>(reserve_data),
           reserve_size));
     } else {
 #if CUDNN_VERSION >= 7201
@@ -283,7 +377,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
           handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
           rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(), out->data<T>(),
           workspace_data_.data<uint8_t>(), workspace_size,
-          rnn.weight_desc(), weight_grad->data<T>(),
+          rnn.weight_desc(), weight_grad.data<T>(),
           const_cast<uint8_t *>(reserve_data), reserve_size));
 #else
       PADDLE_THROW(platform::errors::Unavailable(
           "The padded input of rnn is supported by cudnnRNNBackwardWeightsEx, "
           "but it only works when the version "
           "of cudnn is larger than 7.2.1"));
 #endif
     }
+    weight_to_tensor_list<T>(place, stream, &weight_grad_list, weight_list,
+                             &weight_grad);
   }
 };
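
The backward kernel above accumulates every parameter gradient into one flat tensor (`weight_grad`) and then scatters slices of it back into the per-parameter grad outputs via `weight_to_tensor_list`. A minimal numpy sketch of that scatter, assuming the same offset ordering as the forward packing (`split_flat_grad` is a hypothetical name, not from the patch):

# Illustrative numpy sketch (not part of the diff): split a flat gradient
# buffer back into per-parameter arrays, mirroring weight_to_tensor_list.
import numpy as np

def split_flat_grad(flat_grad, weight_list):
    grads, offset = [], 0
    for w in weight_list:
        # Each slice has exactly w.size elements and keeps w's shape.
        grads.append(flat_grad[offset:offset + w.size].reshape(w.shape))
        offset += w.size
    return grads

hidden_size, input_size = 4, 3
weight_list = [
    np.zeros((4 * hidden_size, input_size)),   # weight_ih
    np.zeros((4 * hidden_size, hidden_size)),  # weight_hh
    np.zeros(4 * hidden_size),                 # bias_ih
    np.zeros(4 * hidden_size),                 # bias_hh
]
flat_grad = np.arange(sum(w.size for w in weight_list), dtype=np.float64)
grads = split_flat_grad(flat_grad, weight_list)
assert [g.shape for g in grads] == [w.shape for w in weight_list]
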
diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py
index fe8ed83923e88..5d5855af562ff 100644
--- a/python/paddle/fluid/layers/rnn.py
+++ b/python/paddle/fluid/layers/rnn.py
@@ -2439,29 +2439,43 @@ def lstm(input,
     input_shape = list(input.shape)
     input_size = input_shape[-1]
     weight_size = 0
+    num_dirrection = 2 if is_bidirec == True else 1
+
+    weight_list = []
     for i in range(num_layers):
-        if i == 0:
-            input_weight_size = (input_size * hidden_size) * 4
-        else:
-            if is_bidirec:
-                input_weight_size = (hidden_size * 2 * hidden_size) * 4
+        for j in range(num_dirrection):
+            if i == 0:
+                weight_ih = helper.create_parameter(
+                    attr=helper.param_attr,
+                    shape=[4 * hidden_size, input_size],
+                    dtype=dtype,
+                    default_initializer=default_initializer)
             else:
-                input_weight_size = (hidden_size * hidden_size) * 4
-
-        hidden_weight_size = (hidden_size * hidden_size) * 4
-
-        if is_bidirec:
-            weight_size += (input_weight_size + hidden_weight_size) * 2
-            weight_size += hidden_size * 8 * 2
-        else:
-            weight_size += input_weight_size + hidden_weight_size
-            weight_size += hidden_size * 8
-
-    weight = helper.create_parameter(
-        attr=helper.param_attr,
-        shape=[weight_size],
-        dtype=dtype,
-        default_initializer=default_initializer)
+                weight_ih = helper.create_parameter(
+                    attr=helper.param_attr,
+                    shape=[4 * hidden_size, hidden_size],
+                    dtype=dtype,
+                    default_initializer=default_initializer)
+            weight_list.append(weight_ih)
+
+            weight_hh = helper.create_parameter(
+                attr=helper.param_attr,
+                shape=[4 * hidden_size, hidden_size],
+                dtype=dtype,
+                default_initializer=default_initializer)
+            weight_list.append(weight_hh)
+            bias_ih = helper.create_parameter(
+                attr=helper.param_attr,
+                shape=[4 * hidden_size],
+                dtype=dtype,
+                default_initializer=default_initializer)
+            weight_list.append(bias_ih)
+            bias_hh = helper.create_parameter(
+                attr=helper.param_attr,
+                shape=[4 * hidden_size],
+                dtype=dtype,
+                default_initializer=default_initializer)
+            weight_list.append(bias_hh)
 
     out = helper.create_variable_for_type_inference(dtype)
     last_h = helper.create_variable_for_type_inference(dtype)
@@ -2478,7 +2492,7 @@ def lstm(input,
             'Input': input,
             'InitH': init_h,
             'InitC': init_c,
-            'W': weight,
+            'WeightList': weight_list,
         },
         outputs={
             'Out': out,
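
With the layer change above, `fluid.layers.lstm` now creates four parameters per direction per layer, appended in the order weight_ih, weight_hh, bias_ih, bias_hh. A small sketch (illustrative only, not part of the diff) of the expected shapes and a check that their total element count matches the old flat-W formula for the first layer of a unidirectional LSTM:

# Illustrative sketch (not part of the diff): parameter shapes held by the new
# weight_list for one layer/direction, and a size check against the old
# single flat weight computed before this patch.
def layer_param_shapes(input_size, hidden_size, first_layer=True):
    in_dim = input_size if first_layer else hidden_size
    return [
        (4 * hidden_size, in_dim),       # weight_ih
        (4 * hidden_size, hidden_size),  # weight_hh
        (4 * hidden_size,),              # bias_ih
        (4 * hidden_size,),              # bias_hh
    ]

def numel(shape):
    n = 1
    for d in shape:
        n *= d
    return n

input_size, hidden_size = 21, 21
total = sum(numel(s) for s in layer_param_shapes(input_size, hidden_size))
# Old per-layer formula (unidirectional, first layer): 4*I*H + 4*H*H + 8*H.
assert total == 4 * input_size * hidden_size + 4 * hidden_size * hidden_size + 8 * hidden_size
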
diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
index 29a0fa55f7729..1af102389933b 100644
--- a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py
@@ -22,12 +22,38 @@ from op_test import OpTest
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
+import random
+random.seed(1)
 
 SIGMOID_THRESHOLD_MIN = -40.0
 SIGMOID_THRESHOLD_MAX = 13.0
 EXP_MAX_INPUT = 40.0
 
+
+class RandomWeight:
+    def __init__(self):
+        pass
+
+    def updata_weight(self, hidden_size, input_size, dtype):
+        self.hidden_size = hidden_size
+        self.input_size = input_size
+        self.dtype = dtype
+
+        self.weight_ih = np.random.uniform(
+            low=-0.1, high=0.1, size=(4 * self.hidden_size,
+                                      self.input_size)).astype(dtype)
+        self.weight_hh = np.random.uniform(
+            low=-0.1, high=0.1, size=(4 * self.hidden_size,
+                                      self.hidden_size)).astype(dtype)
+        self.bias_ih = np.random.uniform(
+            low=-0.1, high=0.1, size=(4 * self.hidden_size)).astype(dtype)
+        self.bias_hh = np.random.uniform(
+            low=-0.1, high=0.1, size=(4 * self.hidden_size)).astype(dtype)
+
+
+weight = RandomWeight()
+
+
 class LayerMixin(object):
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
@@ -52,15 +78,13 @@ def __init__(self, input_size, hidden_size, bias=True):
         self.dtype = np.float64
         self.parameters = dict()
         std = 1.0 / math.sqrt(hidden_size)
-        self.weight_ih = np.ones(
-            (4 * hidden_size, input_size), dtype=self.dtype)
-        self.weight_hh = np.ones((4 * hidden_size,
-                                  hidden_size)).astype(self.dtype)
+        self.weight_ih = weight.weight_ih
+        self.weight_hh = weight.weight_hh
         self.parameters['weight_ih'] = self.weight_ih
         self.parameters['weight_hh'] = self.weight_hh
         if bias:
-            self.bias_ih = np.ones((4 * hidden_size)).astype(self.dtype)
-            self.bias_hh = np.ones((4 * hidden_size)).astype(self.dtype)
+            self.bias_ih = weight.bias_ih
+            self.bias_hh = weight.bias_hh
             self.parameters['bias_ih'] = self.bias_ih
             self.parameters['bias_hh'] = self.bias_hh
         else:
@@ -354,6 +378,12 @@ def __init__(self,
                      "core is not compiled with CUDA")
 class TestCUDNNLstmOp(OpTest):
     #TODO(GaoWei8): Need to satisfy the result through the new interface
+    def get_weight_names(self):
+        weight_names = []
+        for i in range(4 * self.num_layers):
+            weight_names.append('weight{}'.format(i))
+        return weight_names
+
     def setUp(self):
         self.op_type = "cudnn_lstm"
         self.dtype = np.float64
@@ -365,12 +395,6 @@ def setUp(self):
         input_size = 21
         hidden_size = 21
 
-        input_weight_size = (hidden_size * hidden_size) * 4
-        hidden_weight_size = (hidden_size * hidden_size) * 4
-        weight_size = input_weight_size + hidden_weight_size
-        weight_size += hidden_size * 8
-        weight_size *= self.num_layers
-
         input = np.random.uniform(
             low=-0.1, high=0.1,
             size=(seq_length, batch_size, input_size)).astype(self.dtype)
@@ -379,6 +403,7 @@ def setUp(self):
             input[9][3:][:] = 0
             input[8][4:][:] = 0
 
+        weight.updata_weight(hidden_size, input_size, self.dtype)
         rnn1 = LSTM(
             input_size,
             hidden_size,
@@ -389,7 +414,19 @@ def setUp(self):
         output, (last_hidden, last_cell) = rnn1(
             input, sequence_length=self.sequence_length)
 
-        flat_w = np.ones((weight_size)).astype(self.dtype)
+        flat_w = []
+        for i in range(self.num_layers):
+            if i == 0:
+                weight_ih = weight.weight_ih
+            else:
+                weight_ih = weight.weight_hh
+            weight_hh = weight.weight_hh
+            bias_ih = weight.bias_ih
+            bias_hh = weight.bias_hh
+            flat_w.append(("weight" + str(4 * i), weight_ih))
+            flat_w.append(("weight" + str(4 * i + 1), weight_hh))
+            flat_w.append(("weight" + str(4 * i + 2), bias_ih))
+            flat_w.append(("weight" + str(4 * i + 3), bias_hh))
         init_h = np.zeros((self.num_layers, batch_size,
                            hidden_size)).astype(self.dtype)
         init_c = np.zeros((self.num_layers, batch_size,
@@ -398,7 +435,7 @@ def setUp(self):
 
         self.inputs = {
             'Input': input,
-            'W': flat_w,
+            'WeightList': flat_w,
             'InitH': init_h,
             'InitC': init_c,
             'SequenceLength': self.sequence_length
@@ -428,9 +465,12 @@ def test_output_with_place(self):
 
     def test_grad_with_place(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place,
-                                   set(['Input', 'W', 'InitH', 'InitC']),
-                                   ['Out', 'LastH', 'LastC'])
+        var_name_list = self.get_weight_names()
+        for var_name in var_name_list:
+            self.check_grad_with_place(
+                place,
+                set(['Input', var_name, 'InitH', 'InitC']),
+                ['Out', 'LastH', 'LastC'])
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),