Skip to content

Commit

Permalink
Add flatten weight of lstm (PaddlePaddle#27192)
Browse files Browse the repository at this point in the history
* add flatten weight of lstm
  • Loading branch information
GaoWei8 authored and chen-zhiyu committed Oct 15, 2020
1 parent 3bcf055 commit 9ec8766
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 62 deletions.
43 changes: 37 additions & 6 deletions paddle/fluid/operators/cudnn_lstm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
Expand All @@ -25,7 +26,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");

Expand Down Expand Up @@ -122,7 +122,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
" cudnn concatenate all the weight to one Tensor")
.AsDispensable();
AddInput("WeightList",
"(vector<Tensor>), stores weight and bias data when the weight "
"use the list format. ")
.AsDispensable()
.AsDuplicable();
AddInput("SequenceLength",
"(Tensor) When the input data is padding, "
"set this parameter. This parameter represents "
Expand Down Expand Up @@ -216,7 +222,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");

Expand All @@ -228,7 +233,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
};

SetOutGradDim("Input");
SetOutGradDim("W");
if (ctx->HasInputs("WeightList")) {
ctx->SetOutputsDim(framework::GradVarName("WeightList"),
ctx->GetInputsDim("WeightList"));
}
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
Expand All @@ -251,7 +259,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Input", this->Input("Input"));
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("WeightList")) {
op->SetInput("WeightList", this->Input("WeightList"));
}
if (this->HasInput("SequenceLength")) {
op->SetInput("SequenceLength", this->Input("SequenceLength"));
}
Expand All @@ -262,8 +272,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));

if (this->HasInput("WeightList")) {
op->SetOutput(framework::GradVarName("WeightList"),
this->InputGrad("WeightList", false));
}

op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
op->SetAttrMap(this->Attrs());
Expand All @@ -290,3 +304,20 @@ REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);

// CPU registrations are stubs — cudnn_lstm is a cuDNN (GPU-only) op; the
// NotImpleKernel presumably reports "not implemented" on CPU — confirm in
// its definition.
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);

// TODO(Shixiaowei02) Add ModifyInput support
// Version checkpoint: records the interface change (new dispensable inputs
// WeightList/SequenceLength, new outputs StateOut/Reserve) so programs saved
// before this change can still be loaded and upgraded.
REGISTER_OP_VERSION(cudnn_lstm)
.AddCheckpoint(
R"ROC(
Upgrade cudnn_lstm add a new input [WeightList] and modify input [W] to dispensable.)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput(
"WeightList",
"The WeightList stores weight and bias data. WeightList is "
"dispensable.")
.NewInput("SequenceLength",
"When the input data is padding, set this parameter. "
"SequenceLength is dispensable.")
.NewOutput("StateOut", "Store the global drop state when training")
.NewOutput("Reserve",
"A temporary output Tensor to store the reserve_data"));
162 changes: 144 additions & 18 deletions paddle/fluid/operators/cudnn_lstm_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,66 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

// Returns true when the tensors in `weight_list` are laid out back-to-back in
// memory, i.e. each tensor's data buffer ends exactly where the next one
// begins. An empty or single-element list is trivially continuous.
//
// Fixes: the original loop bound `weight_list.size() - 1` underflowed for an
// empty list (size_t wrap-around), causing an out-of-bounds read. Also exits
// early on the first gap instead of scanning the whole list.
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
  if (weight_list.size() <= 1) {
    return true;
  }
  for (size_t i = 0; i + 1 < weight_list.size(); ++i) {
    auto *in_data = weight_list[i]->template data<T>();
    auto *in_after_data = weight_list[i + 1]->template data<T>();
    auto in_size = weight_list[i]->numel();
    // One gap is enough to decide; no need to inspect the remaining pairs.
    if (in_data + in_size != in_after_data) {
      return false;
    }
  }
  return true;
}

// Total element count across all tensors in `weight_list`
// (i.e. the flattened weight size).
int size_sum(const std::vector<const Tensor *> &weight_list) {
  int total = 0;
  for (const auto *tensor : weight_list) {
    total += tensor->numel();
  }
  return total;
}

// Packs every tensor of `weight_list`, in order, into the single flat tensor
// `weight` using async device-to-device copies on `stream`.
// NOTE(review): `place` is unused here — the copies use the tensors' own
// places; kept for signature symmetry with weight_to_tensor_list.
template <typename T>
void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
                      const std::vector<const Tensor *> &weight_list,
                      Tensor *weight) {
  T *dst_base = weight->data<T>();
  int offset = 0;
  for (const auto *src_tensor : weight_list) {
    const T *src_data = src_tensor->data<T>();
    auto src_numel = src_tensor->numel();
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
                 dst_base + offset,
                 BOOST_GET_CONST(platform::CUDAPlace, src_tensor->place()),
                 src_data, src_numel * sizeof(T), stream);
    offset += src_numel;
  }
}

// Scatters the flat tensor `weight` back into the per-parameter tensors of
// `weight_grad`; slice sizes are taken from the matching entries of
// `weight_input`. Copies run asynchronously on `stream`.
template <typename T>
void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
                           std::vector<Tensor *> *weight_grad,
                           const std::vector<const Tensor *> &weight_input,
                           const Tensor *weight) {
  const T *flat_data = weight->data<T>();
  int offset = 0;
  for (size_t idx = 0; idx < weight_input.size(); ++idx) {
    auto slice_numel = weight_input[idx]->numel();
    T *dst_data = (*weight_grad)[idx]->mutable_data<T>(place);
    const T *slice_begin = flat_data + offset;
    memory::Copy(
        BOOST_GET_CONST(platform::CUDAPlace, (*weight_grad)[idx]->place()),
        dst_data, BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
        slice_begin, slice_numel * sizeof(T), stream);
    offset += slice_numel;
  }
}

template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
Expand Down Expand Up @@ -75,8 +135,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const Tensor *init_h = ctx.Input<Tensor>("InitH");
const Tensor *init_c = ctx.Input<Tensor>("InitC");

auto w = ctx.Input<Tensor>("W");

Tensor *out = ctx.Output<Tensor>("Out");
Tensor *last_h = ctx.Output<Tensor>("LastH");
Tensor *last_c = ctx.Output<Tensor>("LastC");
Expand All @@ -87,8 +145,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const T *init_h_data = init_h->data<T>();
const T *init_c_data = init_c->data<T>();

const T *w_data = w->data<T>();

T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
Expand All @@ -113,11 +169,45 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
int weight_numel = w->numel();
bool state_initialized = state_out->IsInitialized() ? true : false;

size_t workspace_size;
size_t reserve_size;
Tensor weight_whole;
T *w_data = nullptr;
int weight_numel;
bool w_initialized = false;
auto place = ctx.GetPlace();
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
if (is_test && ctx.HasInput("W")) {
auto *W = ctx.Input<Tensor>("W");
w_initialized = W->IsInitialized() ? true : false;
weight_numel = W->numel();
}
if (!w_initialized) {
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
weight_numel = size_sum(weight_list);

if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not "
"continuous, less efficient calculation will be "
"called. Please call coalesce_tensor op to make the "
"input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>();
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
} else {
auto *W = ctx.Input<Tensor>("W");
w_data = const_cast<T *>(W->data<T>());
}

ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
Expand All @@ -136,6 +226,12 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size);
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
W->mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, W);
}
} else {
if (!has_seq_length) {
// for train
Expand Down Expand Up @@ -176,21 +272,22 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<Tensor>("Input");
auto *weight = ctx.Input<Tensor>("W");
auto *init_h = ctx.Input<Tensor>("InitH");
auto *init_c = ctx.Input<Tensor>("InitC");
auto *reserve = ctx.Input<Tensor>("Reserve");
auto *state_out = ctx.Input<Tensor>("StateOut");
auto weight_list = ctx.MultiInput<Tensor>("WeightList");

auto *out = ctx.Input<Tensor>("Out");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
framework::GradVarName("WeightList"));

auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
Expand All @@ -199,26 +296,57 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();

auto *weight_data = weight->data<T>();
auto *init_h_data = init_h->data<T>();
auto *init_c_data = init_c->data<T>();
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
auto *last_h_grad_data = last_h_grad->data<T>();
auto *last_c_grad_data = last_c_grad->data<T>();

auto place = ctx.GetPlace();
int weight_numel = size_sum(weight_list);
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);

auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
Tensor weight_whole;
T *weight_data = nullptr;

if (!continuous) {
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
}

Tensor weight_grad;
math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
weight_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();

int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}

in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
auto *in_grad_data = in_grad->data<T>();

init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad->data<T>();
if (init_h_grad) init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;

init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad->data<T>();
if (init_c_grad) init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;

float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
Expand All @@ -236,7 +364,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
int weight_numel = weight->numel();

size_t workspace_size;
size_t reserve_size;
Expand Down Expand Up @@ -268,8 +395,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
Expand All @@ -288,7 +414,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
rnn.weight_desc(), weight_grad->data<T>(),
rnn.weight_desc(), weight_grad_data,
const_cast<uint8_t *>(reserve_data), reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
Expand Down
20 changes: 7 additions & 13 deletions python/paddle/fluid/layers/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,23 +2443,17 @@ def lstm(input,
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
num_dirrection = 2 if is_bidirec == True else 1

for i in range(num_layers):
if i == 0:
input_weight_size = (input_size * hidden_size) * 4
input_weight_size = (input_size * hidden_size) * 4 * num_dirrection
else:
if is_bidirec:
input_weight_size = (hidden_size * 2 * hidden_size) * 4
else:
input_weight_size = (hidden_size * hidden_size) * 4
input_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection
hidden_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection

hidden_weight_size = (hidden_size * hidden_size) * 4

if is_bidirec:
weight_size += (input_weight_size + hidden_weight_size) * 2
weight_size += hidden_size * 8 * 2
else:
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8 * num_dirrection

weight = helper.create_parameter(
attr=helper.param_attr,
Expand Down
Loading

0 comments on commit 9ec8766

Please sign in to comment.