Skip to content

Commit

Permalink
update the code for backward code
Browse files Browse the repository at this point in the history
  • Loading branch information
wawltor committed Oct 22, 2020
1 parent adb387c commit f230c33
Showing 1 changed file with 373 additions and 0 deletions.
373 changes: 373 additions & 0 deletions paddle/fluid/operators/cudnn_lstm_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,36 @@ void parameter_split(const Tensor* weight, const int& gate_num,
}
}

// Regroups the flat raw parameter list into one TensorList per layer.
//
// The raw parameter sequence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// followed by [FBhi, FBhh, BBhi, BBhh] * num_layers. For every layer this
// appends a list ordered as
//   [FWhi, FWhh, FBhi, FBhh] (+ [BWhi, BWhh, BBhi, BBhh] when bidirectional)
// to *params_vec. Each entry is bound to the matching raw tensor through
// ShareDataWith. gate_num is accepted but not used here.
void reset_parameter_vector(const std::vector<Tensor*>& raw_params_vec,
                            const int& num_layers, const int& gate_num,
                            const bool& is_bidirec,
                            std::vector<TensorList>* params_vec) {
  const int num_directions = is_bidirec ? 2 : 1;
  const int tensors_per_layer = 4 * num_directions;
  // weights occupy the first half of the raw list, biases the second half
  const int bias_offset = (num_layers * tensors_per_layer) / 2;

  for (int layer = 0; layer < num_layers; ++layer) {
    TensorList layer_params;
    layer_params.reserve(tensors_per_layer);
    for (int slot = 0; slot < tensors_per_layer; ++slot) {
      Tensor placeholder;
      layer_params.emplace_back(placeholder);
    }

    const int layer_base = layer * 2 * num_directions;
    for (int direction = 0; direction < num_directions; ++direction) {
      for (int slot = 0; slot < 4; ++slot) {
        // slots 0/1 pick the hi/hh weights, slots 2/3 the matching biases
        int raw_idx = layer_base + direction * 2 + slot % 2;
        if (slot >= 2) {
          raw_idx += bias_offset;
        }
        layer_params[direction * 4 + slot].ShareDataWith(
            *raw_params_vec[raw_idx]);
      }
    }
    params_vec->emplace_back(layer_params);
  }
}
void reset_parameter_vector(const std::vector<const Tensor*>& raw_params_vec,
const int& num_layers, const int& gate_num,
const bool& is_bidirec,
Expand Down Expand Up @@ -821,5 +851,348 @@ class CudnnLSTMCPUKernel : public framework::OpKernel<T> {
}
};

// Base class for the per-layer backward pass of the CPU RNN kernel.
// It supplies the gradient pre-processing helpers shared by the concrete
// layers; operator() and postprocess() are empty here.
template <typename T>
struct GradLayer {
  virtual ~GradLayer() {}

  // Backward entry point for one layer. The base implementation does nothing.
  virtual void operator()(const framework::ExecutionContext& context,
                          const Tensor* input, const Tensor* output,
                          const TensorList& last_h_grad_unbind,
                          const TensorList& last_c_grad_unbind,
                          const TensorList& gate_tensor_unbind,
                          const TensorList& state_tensor_unbind,
                          const Tensor* output_grad,
                          const std::vector<TensorList>& parameter_lists,
                          const Tensor* sequence_length, Tensor* input_grad,
                          Tensor* init_h_grad, Tensor* init_c_grad,
                          const std::vector<TensorList>& weight_list_grad,
                          const int& layer_idx) {}

  // Accumulates grad_output into grad_last_h in place:
  //   grad_last_h += grad_output
  // Both tensors are viewed as Eigen matrices reshaped at their last axis.
  void preprocess(const framework::ExecutionContext& context,
                  const Tensor* grad_output, Tensor* grad_last_h) {
    auto& eigen_dev =
        *context.template device_context<platform::CPUDeviceContext>()
             .eigen_device();
    auto out_grad_mat = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
    auto last_h_mat = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
    // the output gradient contributes to the gradient of last_h
    last_h_mat.device(eigen_dev) = last_h_mat + out_grad_mat;
  }

  // Masked variant of preprocess(). With m the mask broadcast along the
  // last axis, the statements run in this order:
  //   grad_pre_h  = (1 - m) * grad_last_h
  //   grad_pre_c  = (1 - m) * grad_last_c
  //   grad_last_h = m * grad_last_h
  //   grad_last_c = m * grad_last_c
  //   grad_last_h = grad_last_h + m * grad_output
  // NOTE(review): the mask is read as a (mask_tensor.dims()[1], 1) matrix,
  // so mask_tensor is presumably 2-D with the batch on axis 1 - confirm
  // against what Unbind() returns.
  void mask_preprocess(const framework::ExecutionContext& context,
                       const Tensor* grad_output, Tensor* grad_last_h,
                       Tensor* grad_last_c, Tensor* grad_pre_h,
                       Tensor* grad_pre_c, const Tensor& mask_tensor) {
    auto& eigen_dev =
        *context.template device_context<platform::CPUDeviceContext>()
             .eigen_device();
    auto mask_column = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
    auto mask =
        mask_column.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));

    auto last_h_mat = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
    auto last_c_mat = framework::EigenMatrix<T>::Reshape(
        *grad_last_c, grad_last_c->dims().size() - 1);
    auto pre_h_mat = framework::EigenMatrix<T>::Reshape(
        *grad_pre_h, grad_pre_h->dims().size() - 1);
    auto pre_c_mat = framework::EigenMatrix<T>::Reshape(
        *grad_pre_c, grad_pre_c->dims().size() - 1);
    auto out_grad_mat = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);

    // the order below is significant: the pre_* values must be derived
    // from last_* before last_* is masked in place
    pre_h_mat.device(eigen_dev) = (1 - mask) * last_h_mat;
    pre_c_mat.device(eigen_dev) = (1 - mask) * last_c_mat;
    last_h_mat.device(eigen_dev) = mask * last_h_mat;
    last_c_mat.device(eigen_dev) = mask * last_c_mat;

    // the output gradient contributes to the gradient of last_h
    last_h_mat.device(eigen_dev) = last_h_mat + mask * out_grad_mat;
  }

  // Hook run after the layer computation; currently empty.
  void postprocess(const framework::ExecutionContext& context) {}
};

// Backward pass of one unidirectional RNN layer.
//
// Walks the time steps from last to first, folding the per-step output
// gradient into the running hidden-state gradient. The recurrent cell
// itself is not wired in yet (see the TODO in the loop).
template <typename T, typename GradCellType>
struct SingleGradLayer : GradLayer<T> {
  explicit SingleGradLayer(GradCellType& cell) : cell_(cell) {}
  virtual ~SingleGradLayer() {}

  // NOTE(review): this overload takes TensorList* for the init-state
  // gradients whereas GradLayer<T>::operator() takes Tensor*, so it hides
  // the base virtual instead of overriding it.
  //
  // input               forward input of this layer; dims()[0] is used as
  //                     the number of time steps
  // last_h_grad_unbind  per-layer gradients of the last hidden state
  // last_c_grad_unbind  per-layer gradients of the last cell state
  // output_grad         gradient w.r.t. this layer's output, unbound along
  //                     axis 0 into one tensor per time step
  // sequence_length     optional; when non-null a mask is built from it and
  //                     the masked pre-processing is used
  // init_h_grad_unbind / init_c_grad_unbind
  //                     per-layer destinations for the initial-state
  //                     gradients; may be empty, in which case scratch
  //                     tensors are allocated instead
  // layer_idx           index of this layer in the per-layer lists
  // The remaining parameters are not used yet.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      const std::vector<TensorList>& weight_list_grad, const int& layer_idx) {
    const int& time_step = input->dims()[0];

    const bool has_sequence_length = sequence_length != nullptr;
    Tensor mask_matrix;
    TensorList mask_tensor_list;
    if (has_sequence_length) {
      mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
      create_mask_matrix<T>(context, sequence_length, &mask_matrix, false);
      mask_tensor_list = Unbind(mask_matrix);
    }

    // work on copies of the last_h / last_c gradients so the caller's
    // tensors are not modified while the buffers are ping-ponged below
    Tensor dynamic_last_h;
    Tensor dynamic_last_c;
    dynamic_last_h.Resize(last_h_grad_unbind[layer_idx].dims());
    dynamic_last_h.mutable_data<T>(context.GetPlace());
    dynamic_last_c.Resize(last_c_grad_unbind[layer_idx].dims());
    dynamic_last_c.mutable_data<T>(context.GetPlace());
    framework::TensorCopy(last_h_grad_unbind[layer_idx], context.GetPlace(),
                          &dynamic_last_h);
    framework::TensorCopy(last_c_grad_unbind[layer_idx], context.GetPlace(),
                          &dynamic_last_c);

    // bind the "previous step" buffers to the init_h / init_c gradients when
    // they are requested, otherwise allocate scratch tensors
    Tensor dynamic_pre_h;
    Tensor dynamic_pre_c;
    if (init_h_grad_unbind->size() > 0) {
      dynamic_pre_h.ShareDataWith((*init_h_grad_unbind)[layer_idx]);
    } else {
      dynamic_pre_h.Resize(dynamic_last_h.dims());
      dynamic_pre_h.mutable_data<T>(context.GetPlace());
    }
    if (init_c_grad_unbind->size() > 0) {
      // fixed: this previously bound the cell buffer to init_h_grad_unbind
      dynamic_pre_c.ShareDataWith((*init_c_grad_unbind)[layer_idx]);
    } else {
      dynamic_pre_c.Resize(dynamic_last_c.dims());
      dynamic_pre_c.mutable_data<T>(context.GetPlace());
    }

    // the roles of the "last" and "pre" buffers are exchanged after every
    // step, so the loop works on pointers rather than on the tensors
    Tensor* grad_last_h = &dynamic_last_h;
    Tensor* grad_last_c = &dynamic_last_c;
    Tensor* grad_pre_h = &dynamic_pre_h;
    Tensor* grad_pre_c = &dynamic_pre_c;

    // unbind the output gradient along axis 0: one tensor per time step
    auto output_grad_unbind = Unbind(*output_grad);
    for (int i = time_step - 1; i >= 0; --i) {
      if (has_sequence_length) {
        this->mask_preprocess(context, &output_grad_unbind[i], grad_last_h,
                              grad_last_c, grad_pre_h, grad_pre_c,
                              mask_tensor_list[i]);
      } else {
        this->preprocess(context, &output_grad_unbind[i], grad_last_h);
      }
      // TODO(wawltor) add the rnn cell

      SwapPoniter(&grad_last_h, &grad_pre_h);
      SwapPoniter(&grad_last_c, &grad_pre_c);
    }

    // After an even number of swaps grad_last_* points at the private
    // dynamic_last_* copies again, so the result has to be copied out.
    // After an odd number it already points at the buffer bound to
    // init_h / init_c above.
    if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
      framework::TensorCopy(*grad_last_h, context.GetPlace(),
                            &((*init_h_grad_unbind)[layer_idx]));
    }
    if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) {
      framework::TensorCopy(*grad_last_c, context.GetPlace(),
                            &((*init_c_grad_unbind)[layer_idx]));
    }
  }

  GradCellType cell_;
};

// Backward layer for the bidirectional case.
// Placeholder: operator() has an empty body, so no gradients are produced.
template <typename T, typename GradCellType>
struct BidirGradLayer : GradLayer<T> {
  explicit BidirGradLayer(GradCellType& cell) : cell_(cell) {}
  virtual ~BidirGradLayer() {}

  // Same parameter list as GradLayer<T>::operator(); intentionally empty.
  void operator()(const framework::ExecutionContext& /*context*/,
                  const Tensor* /*input*/, const Tensor* /*output*/,
                  const TensorList& /*last_h_grad_unbind*/,
                  const TensorList& /*last_c_grad_unbind*/,
                  const TensorList& /*gate_tensor_unbind*/,
                  const TensorList& /*state_tensor_unbind*/,
                  const Tensor* /*output_grad*/,
                  const std::vector<TensorList>& /*parameter_lists*/,
                  const Tensor* /*sequence_length*/, Tensor* /*input_grad*/,
                  Tensor* /*init_h_grad*/, Tensor* /*init_c_grad*/,
                  const std::vector<TensorList>& /*weight_list_grad*/,
                  const int& /*layer_idx*/) override {}

  GradCellType cell_;
};

// Base class for the per-step backward cell; the call operator is a no-op.
template <typename T>
struct GradCell {
  virtual ~GradCell() = default;
  virtual void operator()(const Tensor* /*input*/) {}
};

// LSTM specialisation of GradCell. Placeholder: the call operator is empty.
template <typename T>
struct LSTMGradCell : GradCell<T> {
  void operator()(const Tensor* /*input*/) override {}
};

// Drives the backward pass of the CPU RNN kernel.
//
// Reads the forward tensors ("Input", "WeightList", "InitH", "InitC",
// "Reserve", "Out") and the incoming gradients of "Out", "LastH" and
// "LastC"; allocates the gradients of "Input", "InitH" and "InitC"; then
// runs the layers from the last one down to the first, feeding each layer's
// input gradient to the layer below as its output gradient.
//
// gate_num  number of gate blocks per layer stored at the front of "Reserve"
// cell_num  number of cell-state blocks per layer stored after the gates
//           (no state is sliced when it is < 1)
template <typename GradCellType,
          template <typename, typename> class SingleGradLayerT,
          template <typename, typename> class BidirGradLayerT, typename T>
void RnnGradFunc(const framework::ExecutionContext& ctx, const int& gate_num,
                 const int& cell_num) {
  // get the tensor pointers for the inputs
  auto* input = ctx.Input<Tensor>("Input");
  auto weight_list = ctx.MultiInput<Tensor>("WeightList");
  auto* init_h = ctx.Input<Tensor>("InitH");
  auto* init_c = ctx.Input<Tensor>("InitC");
  auto* reserve_state = ctx.Input<Tensor>("Reserve");
  auto* output = ctx.Input<Tensor>("Out");
  auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto* last_h_grad = ctx.Input<Tensor>(framework::GradVarName("LastH"));
  auto* last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));

  const Tensor* sequence_length = nullptr;
  if (ctx.HasInput("SequenceLength")) {
    sequence_length = ctx.Input<Tensor>("SequenceLength");
  }

  // get the tensor pointers for the outputs
  auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
  auto weight_grad_list =
      ctx.MultiOutput<framework::Tensor>(framework::GradVarName("WeightList"));
  auto* init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
  auto* init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));

  // get the attributes used by the computation
  const int num_layers = ctx.Attr<int>("num_layers");
  const bool is_bidirec = ctx.Attr<bool>("is_bidirec");
  const int hidden_size = ctx.Attr<int>("hidden_size");
  const int direction_num = is_bidirec ? 2 : 1;

  // input is read as [time_step, batch_size, input_size]
  const int time_step = static_cast<int>(input->dims()[0]);
  const int batch_size = static_cast<int>(input->dims()[1]);

  // allocate the memory of the gradient outputs
  input_grad->mutable_data<T>(input->dims(), ctx.GetPlace());
  if (init_h_grad) init_h_grad->mutable_data<T>(init_h->dims(), ctx.GetPlace());
  if (init_c_grad) init_c_grad->mutable_data<T>(init_c->dims(), ctx.GetPlace());

  // regroup the parameters and their gradients per layer
  std::vector<TensorList> parameter_lists;
  parameter_lists.reserve(num_layers);
  reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists);
  std::vector<TensorList> parameter_lists_grad;
  parameter_lists_grad.reserve(num_layers);
  reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists_grad);

  // Split the reserve tensor into gate / state / hidden regions.
  // NOTE(review): the row-wise Slice() calls below presumably rely on
  // "Reserve" being laid out as
  //   [numel / block_size, block_size]
  // with block_size = time_step * batch_size * hidden_size * direction_num;
  // nothing in this function enforces that layout - confirm with the
  // forward kernel.
  Tensor gate_tensor;
  Tensor state_tensor;
  Tensor hidden_tensor;
  gate_tensor = reserve_state->Slice(0, gate_num * num_layers);
  gate_tensor.Resize({num_layers, time_step * direction_num, batch_size,
                      hidden_size * gate_num});
  if (cell_num >= 1) {
    // fixed: the slice was taken from the still-empty state_tensor itself
    state_tensor = reserve_state->Slice(gate_num * num_layers,
                                        (gate_num + cell_num) * num_layers);
    // NOTE(review): the last dim (hidden_size * gate_num) looks larger than
    // the sliced region when cell_num != gate_num - verify the intended shape
    state_tensor.Resize({num_layers, time_step * direction_num, batch_size,
                         hidden_size * gate_num});
  }
  if (num_layers > 1) {
    hidden_tensor =
        reserve_state->Slice((gate_num + 1) * num_layers,
                             (gate_num + 1) * num_layers + num_layers - 1);
    // NOTE(review): same shape concern as for state_tensor above
    hidden_tensor.Resize({num_layers - 1, time_step * direction_num, batch_size,
                          hidden_size * gate_num});
  }

  // unbind along the layer axis
  auto last_h_grad_unbind = Unbind(*last_h_grad);
  auto last_c_grad_unbind = Unbind(*last_c_grad);
  auto gate_tensor_unbind = Unbind(gate_tensor);
  // only unbind tensors that were actually populated above
  TensorList state_tensor_unbind;
  if (cell_num >= 1) {
    state_tensor_unbind = Unbind(state_tensor);
  }
  TensorList hidden_tensor_unbind;
  if (num_layers > 1) {
    hidden_tensor_unbind = Unbind(hidden_tensor);
  }
  // fixed: these lists are the write targets for the initial-state
  // gradients, so they are built from the gradient outputs rather than from
  // the forward InitH / InitC inputs
  TensorList init_h_grad_unbind;
  TensorList init_c_grad_unbind;
  if (init_h_grad != nullptr) {
    init_h_grad_unbind = Unbind(*init_h_grad);
  }
  if (init_c_grad != nullptr) {
    init_c_grad_unbind = Unbind(*init_c_grad);
  }

  GradCellType cell;
  // fixed: these were uninitialized Tensor* that got dereferenced
  Tensor layer_input;
  Tensor layer_output;
  // Two scratch buffers for the intermediate layers' input gradients. Layer i
  // writes into grad_buffers[i % 2] while reading the buffer written by layer
  // i + 1, so consecutive layers never share a buffer.
  Tensor grad_buffers[2];
  const Tensor* layer_output_grad = output_grad;

  for (int i = num_layers - 1; i >= 0; --i) {
    // the layer inputs / outputs were saved by the forward pass, reuse them
    if (i > 0) {
      layer_input.ShareDataWith(hidden_tensor_unbind[i - 1]);
    } else {
      layer_input.ShareDataWith(*input);
    }
    if (i == num_layers - 1) {
      layer_output.ShareDataWith(*output);
    } else {
      layer_output.ShareDataWith(hidden_tensor_unbind[i]);
    }

    // the first layer writes directly into the gradient of "Input"
    Tensor* layer_input_grad = input_grad;
    if (i > 0) {
      Tensor& buffer = grad_buffers[i % 2];
      buffer.Resize(layer_input.dims());
      buffer.mutable_data<T>(ctx.GetPlace());
      layer_input_grad = &buffer;
    }

    if (is_bidirec) {
      BidirGradLayerT<T, GradCellType> layer(cell);
      // fixed: the layer was constructed but never invoked
      layer(ctx, &layer_input, &layer_output, last_h_grad_unbind,
            last_c_grad_unbind, gate_tensor_unbind, state_tensor_unbind,
            layer_output_grad, parameter_lists, sequence_length,
            layer_input_grad, init_h_grad, init_c_grad, parameter_lists_grad,
            i);
    } else {
      SingleGradLayerT<T, GradCellType> layer(cell);
      layer(ctx, &layer_input, &layer_output, last_h_grad_unbind,
            last_c_grad_unbind, gate_tensor_unbind, state_tensor_unbind,
            layer_output_grad, parameter_lists, sequence_length,
            layer_input_grad, &init_h_grad_unbind, &init_c_grad_unbind,
            parameter_lists_grad, i);
    }

    // this layer's input gradient is the output gradient of the layer below
    layer_output_grad = layer_input_grad;
  }
}

// CPU gradient kernel. Dispatches on the "cell_type" attribute; only "lstm"
// is handled, any other value leaves the outputs untouched.
template <typename DeviceContext, typename T>
class CudnnLSTMCPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    if (ctx.Attr<std::string>("cell_type") != "lstm") {
      return;
    }
    // values forwarded to RnnGradFunc as gate_num / cell_num
    const int lstm_gate_num = 4;
    const int lstm_cell_num = 1;
    RnnGradFunc<LSTMGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
        ctx, lstm_gate_num, lstm_cell_num);
  }
};
} // namespace operators
} // namespace paddle

0 comments on commit f230c33

Please sign in to comment.