add flattened weight of LSTM
GaoWei8 committed Sep 17, 2020
1 parent 54b81fa commit 5e98cca
Showing 4 changed files with 215 additions and 57 deletions.
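In short: cuDNN consumes all LSTM weights and biases as a single flat buffer, and this commit replaces the user-facing flat `W` input with a `WeightList` of per-layer, per-gate tensors that the operator flattens internally, copying only when the tensors are not already contiguous in memory. A tiny numpy illustration of the two situations the kernels distinguish (all names are illustrative, not Paddle API):

import numpy as np

# One backing allocation sliced into views: already flat, no copy needed.
buf = np.zeros(24, dtype=np.float32)
contiguous = [buf[0:8], buf[8:16], buf[16:24]]

# Independently allocated parameters: the kernel must gather them first.
scattered = [np.zeros(8, dtype=np.float32) for _ in range(3)]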
18 changes: 13 additions & 5 deletions paddle/fluid/operators/cudnn_lstm_op.cc
@@ -25,7 +25,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");

@@ -122,7 +121,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
" cudnn concatenate all the weight to one Tensor")
.AsDispensable();
AddInput("WeightList",
"(vector<Tensor>), stores weight and bias data when the weight "
"use the list format. ")
.AsDispensable()
.AsDuplicable();
AddInput("SequenceLength",
"(Tensor) When the input data is padding, "
"set this parameter. This parameter represents "
@@ -216,7 +221,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");

@@ -228,7 +232,8 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
};

SetOutGradDim("Input");
SetOutGradDim("W");
ctx->SetOutputsDim(framework::GradVarName("WeightList"),
ctx->GetInputsDim("WeightList"));
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
@@ -251,7 +256,7 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Input", this->Input("Input"));
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
op->SetInput("WeightList", this->Input("WeightList"));
if (this->HasInput("SequenceLength")) {
op->SetInput("SequenceLength", this->Input("SequenceLength"));
}
@@ -262,6 +267,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));

op->SetOutput(framework::GradVarName("WeightList"),
this->InputGrad("WeightList", false));

op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
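The `(N)` in the `W` docstring above is the total cuDNN weight size. A worked example of that size, following the formulas the pre-change Python layer used (see the removed code in rnn.py below); the function name is ours, not Paddle's:

def lstm_flat_weight_size(input_size, hidden_size, num_layers, is_bidirec):
    # Mirrors the size computation removed from python/paddle/fluid/layers/rnn.py.
    size = 0
    for i in range(num_layers):
        if i == 0:
            input_weight_size = (input_size * hidden_size) * 4
        else:
            mult = 2 if is_bidirec else 1
            input_weight_size = (hidden_size * mult * hidden_size) * 4
        hidden_weight_size = (hidden_size * hidden_size) * 4
        dirs = 2 if is_bidirec else 1
        size += (input_weight_size + hidden_weight_size) * dirs
        size += hidden_size * 8 * dirs  # two bias vectors of 4*hidden_size each
    return size

print(lstm_flat_weight_size(input_size=32, hidden_size=150,
                            num_layers=2, is_bidirec=False))  # 291600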
122 changes: 109 additions & 13 deletions paddle/fluid/operators/cudnn_lstm_op.cu.cc
@@ -25,6 +25,66 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

+template <typename T>
+bool is_continuous(const std::vector<const Tensor *> &weight_list) {
+  bool continuous = true;
+  for (size_t i = 0; i < weight_list.size() - 1; ++i) {
+    const T *in_data = weight_list[i]->data<T>();
+    const T *in_after_data = weight_list[i + 1]->data<T>();
+    auto in_size = weight_list[i]->numel();
+    bool temp = in_data + in_size == in_after_data;
+    continuous = continuous && temp;
+  }
+  return continuous;
+}
+
+int size_sum(const std::vector<const Tensor *> &weight_list) {
+  int size = 0;
+  for (size_t i = 0; i < weight_list.size(); ++i) {
+    auto in_size = weight_list[i]->numel();
+    size += in_size;
+  }
+  return size;
+}
+
+template <typename T>
+void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
+                      const std::vector<const Tensor *> &weight_list,
+                      Tensor *weight) {
+  auto weight_data = weight->data<T>();
+  int weight_offset = 0;
+  for (size_t i = 0; i < weight_list.size(); ++i) {
+    const T *in_data = weight_list[i]->data<T>();
+    auto in_size = weight_list[i]->numel();
+
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
+                 weight_data + weight_offset,
+                 BOOST_GET_CONST(platform::CUDAPlace, weight_list[i]->place()),
+                 in_data, in_size * sizeof(T), stream);
+    weight_offset += in_size;
+  }
+}
+
+template <typename T>
+void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
+                           std::vector<Tensor *> *weight_grad,
+                           const std::vector<const Tensor *> &weight_input,
+                           const Tensor *weight) {
+  int weight_offset = 0;
+  auto *weight_data = weight->data<T>();
+  for (size_t i = 0; i < weight_input.size(); ++i) {
+    auto in_size = weight_input[i]->numel();
+    T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
+    const T *src = weight_data + weight_offset;
+
+    memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, (*weight_grad)[i]->place()),
+        weight_grad_data, BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
+        src, in_size * sizeof(T), stream);
+    weight_offset += in_size;
+  }
+}
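The four helpers added above implement, in order: a contiguity test (does one tensor's storage end exactly where the next begins?), a total-element count, a gather from a tensor list into one flat buffer, and a scatter from a flat buffer back into a tensor list. A rough numpy analogue for intuition, with `ctypes.data` standing in for the raw CUDA pointer comparison (illustrative only, not Paddle API):

import numpy as np

def is_continuous(ws):
    # True when each tensor's storage ends exactly where the next begins.
    return all(a.ctypes.data + a.nbytes == b.ctypes.data
               for a, b in zip(ws, ws[1:]))

def size_sum(ws):
    return sum(w.size for w in ws)

def weight_to_tensor(ws):
    # Gather every parameter into one flat buffer (a device copy in the kernel).
    return np.concatenate([w.ravel() for w in ws])

def weight_to_tensor_list(flat, ws):
    # Scatter flat data back into per-parameter tensors of the original shapes.
    out, offset = [], 0
    for w in ws:
        out.append(flat[offset:offset + w.size].reshape(w.shape))
        offset += w.size
    return out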

template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
@@ -70,8 +130,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const Tensor *init_h = ctx.Input<Tensor>("InitH");
const Tensor *init_c = ctx.Input<Tensor>("InitC");

-    auto w = ctx.Input<Tensor>("W");
-
Tensor *out = ctx.Output<Tensor>("Out");
Tensor *last_h = ctx.Output<Tensor>("LastH");
Tensor *last_c = ctx.Output<Tensor>("LastC");
@@ -82,12 +140,30 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const T *init_h_data = init_h->data<T>();
const T *init_c_data = init_c->data<T>();

-    const T *w_data = w->data<T>();
-
T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());

+    auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
+    bool continuous = is_continuous<T>(weight_list);
+    int weight_numel = size_sum(weight_list);
+
+    auto place = ctx.GetPlace();
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
+                      ctx.device_context())
+                      .stream();
+    Tensor weight_whole;
+    T *w_data = nullptr;
+
+    if (!continuous) {
+      // LOG(WARNING) << "Input WeightList , please use op";
+      weight_whole.mutable_data<T>({weight_numel}, place);
+      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
+      w_data = weight_whole.data<T>();
+    } else {
+      w_data = const_cast<T *>(weight_list[0]->data<T>());
+    }

float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
int hidden_size = ctx.Attr<int>("hidden_size");
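The branch above is the heart of the change: when the `WeightList` tensors already sit back-to-back in one allocation, the kernel hands the first tensor's pointer straight to cuDNN; otherwise it gathers them into `weight_whole` at the cost of one extra device copy per forward pass. Continuing the numpy sketch (using `buf`, `contiguous`, and the helpers defined earlier; all names are illustrative):

ws = contiguous                    # or `scattered`, from the earlier sketch
if is_continuous(ws):
    w_flat = buf                   # zero copy: one shared flat buffer exists
else:
    w_flat = weight_to_tensor(ws)  # gather copy, as in the !continuous branch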
@@ -108,7 +184,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
-    int weight_numel = w->numel();
bool state_initialized = state_out->IsInitialized() ? true : false;

size_t workspace_size;
@@ -171,21 +246,22 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<Tensor>("Input");
-    auto *weight = ctx.Input<Tensor>("W");
auto *init_h = ctx.Input<Tensor>("InitH");
auto *init_c = ctx.Input<Tensor>("InitC");
auto *reserve = ctx.Input<Tensor>("Reserve");
auto *state_out = ctx.Input<Tensor>("StateOut");
+    auto weight_list = ctx.MultiInput<Tensor>("WeightList");

auto *out = ctx.Input<Tensor>("Out");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
+    auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
+        framework::GradVarName("WeightList"));

auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
@@ -194,17 +270,36 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();

-    auto *weight_data = weight->data<T>();
auto *init_h_data = init_h->data<T>();
auto *init_c_data = init_c->data<T>();
auto *out_data = out->data<T>();
auto *out_grad_data = out_grad->data<T>();
auto *last_h_grad_data = last_h_grad->data<T>();
auto *last_c_grad_data = last_c_grad->data<T>();

+    auto place = ctx.GetPlace();
+    int weight_numel = size_sum(weight_list);
+    bool continuous = is_continuous<T>(weight_list);
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
+                      ctx.device_context())
+                      .stream();
+    Tensor weight_whole;
+    weight_whole.mutable_data<T>({weight_numel}, place);
+    T *weight_data = nullptr;
+
+    if (!continuous) {
+      weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
+      weight_data = weight_whole.data<T>();
+    } else {
+      weight_data = const_cast<T *>(weight_list[0]->data<T>());
+    }

+    Tensor weight_grad;

math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
-    weight_grad->mutable_data<T>(ctx.GetPlace());
-    zero(dev_ctx, weight_grad, static_cast<T>(0.0));
+    weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
+    zero(dev_ctx, &weight_grad, static_cast<T>(0.0));

in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
auto *in_grad_data = in_grad->data<T>();
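On the backward side the kernel allocates one flat `weight_grad` buffer of `weight_numel` elements and zero-fills it first; cuDNN's `cudnnRNNBackwardWeights` is documented to accumulate into the weight-gradient buffer rather than overwrite it. After the cuDNN call (at the end of this kernel), `weight_to_tensor_list` scatters the flat gradient back into `weight_grad_list`. Continuing the numpy sketch:

flat_grad = np.zeros(size_sum(ws), dtype=np.float32)
# ...cudnnRNNBackwardWeights adds each contribution into flat_grad...
grad_list = weight_to_tensor_list(flat_grad, ws)  # per-parameter gradients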
@@ -231,7 +326,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
-    int weight_numel = weight->numel();

size_t workspace_size;
size_t reserve_size;
@@ -263,7 +357,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
-        weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
+        weight_grad.data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
} else {
#if CUDNN_VERSION >= 7201
@@ -283,7 +377,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
-          rnn.weight_desc(), weight_grad->data<T>(),
+          rnn.weight_desc(), weight_grad.data<T>(),
const_cast<uint8_t *>(reserve_data), reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
@@ -292,6 +386,8 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
"of cudnn is larger than 7.2.1"));
#endif
}
+    weight_to_tensor_list<T>(place, stream, &weight_grad_list, weight_list,
+                             &weight_grad);
}
};

58 changes: 36 additions & 22 deletions python/paddle/fluid/layers/rnn.py
@@ -2439,29 +2439,43 @@ def lstm(input,
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
+num_dirrection = 2 if is_bidirec == True else 1
+
+weight_list = []
for i in range(num_layers):
-    if i == 0:
-        input_weight_size = (input_size * hidden_size) * 4
-    else:
-        if is_bidirec:
-            input_weight_size = (hidden_size * 2 * hidden_size) * 4
-        else:
-            input_weight_size = (hidden_size * hidden_size) * 4
-
-    hidden_weight_size = (hidden_size * hidden_size) * 4
-
-    if is_bidirec:
-        weight_size += (input_weight_size + hidden_weight_size) * 2
-        weight_size += hidden_size * 8 * 2
-    else:
-        weight_size += input_weight_size + hidden_weight_size
-        weight_size += hidden_size * 8
-
-weight = helper.create_parameter(
-    attr=helper.param_attr,
-    shape=[weight_size],
-    dtype=dtype,
-    default_initializer=default_initializer)
+    for j in range(num_dirrection):
+        if i == 0:
+            weight_ih = helper.create_parameter(
+                attr=helper.param_attr,
+                shape=[4 * hidden_size, input_size],
+                dtype=dtype,
+                default_initializer=default_initializer)
+        else:
+            weight_ih = helper.create_parameter(
+                attr=helper.param_attr,
+                shape=[4 * hidden_size, hidden_size],
+                dtype=dtype,
+                default_initializer=default_initializer)
+        weight_list.append(weight_ih)
+
+        weight_hh = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=[4 * hidden_size, hidden_size],
+            dtype=dtype,
+            default_initializer=default_initializer)
+        weight_list.append(weight_hh)
+        bias_ih = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=[4 * hidden_size],
+            dtype=dtype,
+            default_initializer=default_initializer)
+        weight_list.append(bias_ih)
+        bias_hh = helper.create_parameter(
+            attr=helper.param_attr,
+            shape=[4 * hidden_size],
+            dtype=dtype,
+            default_initializer=default_initializer)
+        weight_list.append(bias_hh)

out = helper.create_variable_for_type_inference(dtype)
last_h = helper.create_variable_for_type_inference(dtype)
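The parameter-creation loop above produces, per layer and per direction, `weight_ih`, `weight_hh`, `bias_ih`, and `bias_hh`, in that order. A small sketch of the shapes this loop yields (mirroring the committed code, including its `[4 * hidden_size, hidden_size]` choice for `weight_ih` in layers above the first); the function name is ours:

def lstm_param_shapes(input_size, hidden_size, num_layers, is_bidirec):
    num_direction = 2 if is_bidirec else 1
    shapes = []
    for i in range(num_layers):
        for _ in range(num_direction):
            ih_cols = input_size if i == 0 else hidden_size
            shapes += [(4 * hidden_size, ih_cols),      # weight_ih
                       (4 * hidden_size, hidden_size),  # weight_hh
                       (4 * hidden_size,),              # bias_ih
                       (4 * hidden_size,)]              # bias_hh
    return shapes

print(lstm_param_shapes(32, 150, num_layers=2, is_bidirec=False))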
@@ -2478,7 +2492,7 @@ def lstm(input,
'Input': input,
'InitH': init_h,
'InitC': init_c,
-            'W': weight,
+            'WeightList': weight_list,
},
outputs={
'Out': out,
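End to end, the public `fluid.layers.lstm` signature is unchanged; only the parameter storage moves from one flat `W` to a `WeightList`. A hedged usage sketch, assuming the fluid-era API of this commit's timeframe (argument order from memory, verify against the installed version):

import paddle.fluid as fluid
import paddle.fluid.layers as layers

batch_size, max_len, emb_dim, hidden_size, num_layers = 20, 100, 32, 150, 2

# cudnn_lstm expects seq-major input: [seq_len, batch_size, input_size]
x = fluid.data(name='x', shape=[max_len, batch_size, emb_dim], dtype='float32')
init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
                              'float32', 0.0)
init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
                              'float32', 0.0)

out, last_h, last_c = layers.lstm(x, init_h, init_c, max_len,
                                  hidden_size, num_layers,
                                  dropout_prob=0.2, is_bidirec=False)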
