Add some double/triple grad kernel yaml file #42361

Merged — 2 commits, Apr 29, 2022
@@ -22,17 +22,12 @@
 ### Global Variables ###
 ########################
 ops_to_fill_zero_for_empty_grads = set([
-    "split_grad",
-    "rnn_grad",
-    "matmul_double_grad",
-    "matmul_triple_grad",
-    "sigmoid_double_grad",
-    "sigmoid_triple_grad",
-    "add_double_grad",
-    "add_triple_grad",
-    "multiply_double_grad",
-    "multiply_triple_grad",
-    "conv2d_grad_grad",
+    "split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
+    "sigmoid_double_grad", "sigmoid_triple_grad", "add_double_grad",
+    "add_triple_grad", "multiply_double_grad", "multiply_triple_grad",
+    "conv2d_grad_grad", "batch_norm_double_grad", "tanh_double_grad",
+    "tanh_triple_grad", "subtract_double_grad", "divide_double_grad",
+    "log_double_grad", "elu_double_grad"
 ])

 # For API dispatch used at python-level
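Note on the list above: these are ops whose higher-order grad kernels may be handed empty grad inputs by the eager autograd engine, so the generated code substitutes zero tensors for those slots before calling the kernel. A minimal sketch of that rule — `FillZeroIfEmpty` and the `Tensor` alias are illustrative assumptions, not Paddle's actual API:

```cpp
// Toy sketch of the zero-fill rule applied to each op listed in
// ops_to_fill_zero_for_empty_grads; names here are illustrative only.
#include <cstddef>
#include <optional>
#include <vector>

using Tensor = std::vector<double>;

// An empty higher-order grad slot is replaced by a zero tensor of the
// expected element count, so the kernel never sees a missing input.
Tensor FillZeroIfEmpty(const std::optional<Tensor>& grad, std::size_t numel) {
  return grad.has_value() ? *grad : Tensor(numel, 0.0);
}
```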
12 changes: 10 additions & 2 deletions paddle/phi/api/lib/kernel_dispatch.h
@@ -96,8 +96,7 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {

   // TODO(chenweihang): deal with multiple diff input Tensors
   // TODO(chenweihang): add global device guard method to set backend
-  void operator()(const Tensor& x) {
-    const phi::TensorBase& tensor = *x.impl();
+  inline void AssignKernelKeySet(const phi::TensorBase& tensor) {
     key_set.backend_set =
         key_set.backend_set | detail::GetTensorBackendSet(tensor);
     // TODO(chenweihang): select multi layout and dtype
@@ -110,6 +109,8 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
     }
   }

+  void operator()(const Tensor& x) { AssignKernelKeySet(*x.impl()); }
+
   void operator()(const std::vector<Tensor>& x) {
     const phi::TensorBase& tensor = *x.at(0).impl();
     key_set.backend_set =
@@ -119,6 +120,13 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
     key_set.dtype = tensor.dtype();
   }

+  void operator()(const paddle::optional<const Tensor&> x) {
+    if (x.get_ptr() != nullptr) {
+      const phi::TensorBase& tensor = *(x.get_ptr()->impl());
+      AssignKernelKeySet(tensor);
+    }
+  }
+
   // skip other type args, these args don't used in kernel selection
   template <typename T>
   void operator()(const T& x) {
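The refactor above moves the per-tensor logic into `AssignKernelKeySet` so the plain `Tensor` overload and the new `paddle::optional<const Tensor&>` overload share it; an absent optional input simply does not participate in kernel key selection. A self-contained sketch of the pattern, using toy types (backend bits, `std::optional`) rather than Paddle's:

```cpp
// Toy sketch of the visitor pattern above: each argument folds into a
// "kernel key", and the optional overload contributes only when present.
#include <iostream>
#include <optional>

struct KeyParser {
  unsigned backend_bits = 0;  // stand-in for key_set.backend_set

  // Shared helper, analogous to AssignKernelKeySet(tensor).
  void Assign(unsigned backend_bit) { backend_bits |= backend_bit; }

  // Required input: always participates in key selection.
  void operator()(unsigned tensor_backend) { Assign(tensor_backend); }

  // Optional input: skipped when absent, like the
  // paddle::optional<const Tensor&> overload in the diff.
  void operator()(const std::optional<unsigned>& maybe_tensor) {
    if (maybe_tensor.has_value()) Assign(*maybe_tensor);
  }
};

int main() {
  KeyParser parser;
  parser(0b01u);                           // required input
  parser(std::optional<unsigned>{});       // absent optional: no effect
  parser(std::optional<unsigned>{0b10u});  // present optional
  std::cout << parser.backend_bits << "\n";  // prints 3
}
```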
6 changes: 3 additions & 3 deletions paddle/phi/kernels/activation_grad_kernel.h
@@ -82,18 +82,18 @@ void ReluDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void TanhDoubleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
-                          const DenseTensor& ddx,
                           const DenseTensor& dout,
+                          const DenseTensor& ddx,
                           DenseTensor* dout_new,
                           DenseTensor* ddout);

 template <typename T, typename Context>
 void TanhTripleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
-                          const DenseTensor& ddx,
                           const DenseTensor& dout,
-                          const DenseTensor& d_ddout,
+                          const DenseTensor& ddx,
                           const DenseTensor& d_dout_new,
+                          const DenseTensor& d_ddout,
                           DenseTensor* d_out_new,
                           DenseTensor* d_dout,
                           DenseTensor* d_ddx);
12 changes: 6 additions & 6 deletions paddle/phi/kernels/batch_norm_grad_kernel.h
@@ -66,16 +66,16 @@ void BatchNormGradKernel(const Context& dev_ctx,

 template <typename T, typename Context>
 void BatchNormDoubleGradKernel(const Context& dev_ctx,
-                               const DenseTensor& x_grad_grad,
-                               const DenseTensor& scale_grad_grad,
-                               const DenseTensor& bias_grad_grad,
-                               const DenseTensor& y_grad,
                                const DenseTensor& x,
                                const DenseTensor& scale,
-                               const DenseTensor& saved_mean,
-                               const DenseTensor& saved_variance,
                                paddle::optional<const DenseTensor&> mean,
                                paddle::optional<const DenseTensor&> variance,
+                               const DenseTensor& saved_mean,
+                               const DenseTensor& saved_variance,
+                               const DenseTensor& y_grad,
+                               const DenseTensor& x_grad_grad,
+                               const DenseTensor& scale_grad_grad,
+                               const DenseTensor& bias_grad_grad,
                                float momentum,
                                float epsilon,
                                const std::string& data_layout,
12 changes: 6 additions & 6 deletions paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc
@@ -341,16 +341,16 @@ void BatchNormGradKernel(const Context& dev_ctx,

 template <typename T, typename Context>
 void BatchNormDoubleGradKernel(const Context& ctx,
-                               const DenseTensor& x_grad_grad,
-                               const DenseTensor& scale_grad_grad,
-                               const DenseTensor& bias_grad_grad,
-                               const DenseTensor& y_grad,
                                const DenseTensor& x,
                                const DenseTensor& scale,
-                               const DenseTensor& saved_mean,
-                               const DenseTensor& saved_variance,
                                paddle::optional<const DenseTensor&> mean,
                                paddle::optional<const DenseTensor&> variance,
+                               const DenseTensor& saved_mean,
+                               const DenseTensor& saved_variance,
+                               const DenseTensor& y_grad,
+                               const DenseTensor& x_grad_grad,
+                               const DenseTensor& scale_grad_grad,
+                               const DenseTensor& bias_grad_grad,
                                float momentum,
                                float epsilon,
                                const std::string& data_layout_str,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/elementwise_subtract_grad_kernel.cc
@@ -38,9 +38,9 @@ void SubtractGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SubtractDoubleGradKernel(const Context& dev_ctx,
                               const DenseTensor& y,
-                              const DenseTensor& dout,
                               paddle::optional<const DenseTensor&> ddx,
                               paddle::optional<const DenseTensor&> ddy,
+                              const DenseTensor& dout,
                               int axis,
                               DenseTensor* ddout) {
   phi::SubtractDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
2 changes: 1 addition & 1 deletion paddle/phi/kernels/elementwise_subtract_grad_kernel.h
@@ -30,9 +30,9 @@ void SubtractGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SubtractDoubleGradKernel(const Context& dev_ctx,
                               const DenseTensor& y,
-                              const DenseTensor& dout,
                               paddle::optional<const DenseTensor&> ddx,
                               paddle::optional<const DenseTensor&> ddy,
+                              const DenseTensor& dout,
                               int axis,
                               DenseTensor* ddout);

12 changes: 6 additions & 6 deletions paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -908,16 +908,16 @@ void BatchNormGradKernel(const Context &dev_ctx,

 template <typename T, typename Context>
 void BatchNormDoubleGradKernel(const Context &ctx,
-                               const DenseTensor &x_grad_grad,
-                               const DenseTensor &scale_grad_grad,
-                               const DenseTensor &bias_grad_grad,
-                               const DenseTensor &y_grad,
                                const DenseTensor &x,
                                const DenseTensor &scale,
-                               const DenseTensor &saved_mean,
-                               const DenseTensor &saved_variance,
                                paddle::optional<const DenseTensor &> mean,
                                paddle::optional<const DenseTensor &> variance,
+                               const DenseTensor &saved_mean,
+                               const DenseTensor &saved_variance,
+                               const DenseTensor &y_grad,
+                               const DenseTensor &x_grad_grad,
+                               const DenseTensor &scale_grad_grad,
+                               const DenseTensor &bias_grad_grad,
                                float momentum,
                                float epsilon,
                                const std::string &data_layout_str,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/elementwise_subtract_grad_kernel.cu
@@ -46,9 +46,9 @@ void SubtractGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SubtractDoubleGradKernel(const Context& dev_ctx,
                               const DenseTensor& y,
-                              const DenseTensor& dout,
                               paddle::optional<const DenseTensor&> ddx,
                               paddle::optional<const DenseTensor&> ddy,
+                              const DenseTensor& dout,
                               int axis,
                               DenseTensor* ddout) {
   phi::SubtractDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
6 changes: 3 additions & 3 deletions paddle/phi/kernels/impl/activation_grad_impl.h
@@ -152,8 +152,8 @@ void LeakyReluDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void TanhDoubleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
-                          const DenseTensor& ddx,
                           const DenseTensor& dout,
+                          const DenseTensor& ddx,
                           DenseTensor* dout_new,
                           DenseTensor* ddout) {
   if (dout_new) {
@@ -171,10 +171,10 @@ void TanhDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void TanhTripleGradKernel(const Context& dev_ctx,
                           const DenseTensor& out,
-                          const DenseTensor& ddx,
                           const DenseTensor& dout,
-                          const DenseTensor& d_ddout,
+                          const DenseTensor& ddx,
                           const DenseTensor& d_dout_new,
+                          const DenseTensor& d_ddout,
                           DenseTensor* d_out_new,
                           DenseTensor* d_dout,
                           DenseTensor* d_ddx) {
4 changes: 2 additions & 2 deletions paddle/phi/ops/compat/activation_sig.cc
@@ -121,13 +121,13 @@ KernelSignature ReluDoubleGradOpArgumentMapping(
 KernelSignature TanhDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "tanh_double_grad", {"Out", "DDX", "DOut"}, {}, {"DOutNew", "DDOut"});
+      "tanh_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
 }

 KernelSignature TanhTripleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("tanh_triple_grad",
-                         {"Out", "DDX", "DOut", "D_DDOut", "D_DOut_New"},
+                         {"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
                          {},
                          {"D_OutNew", "D_DOut", "D_DDx"});
 }
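These mapping fixes matter because a KernelSignature binds op-level input names to kernel parameters purely by position. A toy illustration of that contract — `Invoke`, `TanhDoubleGradToy`, and the types here are hypothetical, not Paddle's dispatch machinery:

```cpp
// Toy illustration of positional binding: the name list must follow the
// kernel's parameter order exactly.
#include <map>
#include <string>
#include <vector>

// Parameter order (out, dout, ddx) mirrors the reordered declaration of
// TanhDoubleGradKernel; the body computes the dout_new term of the tanh
// double grad as a placeholder payload.
double TanhDoubleGradToy(double out, double dout, double ddx) {
  return -2.0 * out * dout * ddx;
}

double Invoke(const std::map<std::string, double>& named,
              const std::vector<std::string>& order) {
  return TanhDoubleGradToy(named.at(order[0]), named.at(order[1]),
                           named.at(order[2]));
}

int main() {
  std::map<std::string, double> in{{"Out", 0.5}, {"DOut", 2.0}, {"DDX", 3.0}};
  // {"Out", "DOut", "DDX"} matches the kernel's (out, dout, ddx) order;
  // the old {"Out", "DDX", "DOut"} would have silently swapped dout and ddx.
  double dout_new = Invoke(in, {"Out", "DOut", "DDX"});
  (void)dout_new;
}
```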
14 changes: 7 additions & 7 deletions paddle/phi/ops/compat/batch_norm_sig.cc
@@ -82,16 +82,16 @@ KernelSignature BatchNormGradOpArgumentMapping(
 KernelSignature BatchNormGradGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("batch_norm_grad_grad",
-                         {"DDX",
-                          "DDScale",
-                          "DDBias",
-                          "DY",
-                          "X",
+                         {"X",
                           "Scale",
+                          "Mean",
+                          "Variance",
                           "SavedMean",
                           "SavedVariance",
-                          "Mean",
-                          "Variance"},
+                          "DY",
+                          "DDX",
+                          "DDScale",
+                          "DDBias"},
                          {"momentum",
                           "epsilon",
                           "data_layout",
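The same parameter order now appears in four places (the header, the CPU and GPU kernels, and this argument mapping), so drift between them is easy to introduce. One defensive pattern, sketched with toy types rather than Paddle's declarations, is to pin the agreed order in a function-pointer alias so a mismatch fails to compile:

```cpp
// Sketch with toy types (not Paddle's declarations): assigning an
// out-of-order implementation to the alias below fails to compile.
#include <optional>

struct Context {};
struct DenseTensor {};

// Agreed order: x, scale, mean, variance, saved_mean, saved_variance,
// y_grad, x_grad_grad, scale_grad_grad, bias_grad_grad (scalars trimmed).
using BatchNormDoubleGradFn = void (*)(
    const Context&, const DenseTensor&, const DenseTensor&,
    const std::optional<DenseTensor>&, const std::optional<DenseTensor>&,
    const DenseTensor&, const DenseTensor&, const DenseTensor&,
    const DenseTensor&, const DenseTensor&, const DenseTensor&);

void BatchNormDoubleGradCpuToy(
    const Context&, const DenseTensor&, const DenseTensor&,
    const std::optional<DenseTensor>&, const std::optional<DenseTensor>&,
    const DenseTensor&, const DenseTensor&, const DenseTensor&,
    const DenseTensor&, const DenseTensor&, const DenseTensor&) {}

// Compile-time check: reordering any parameter above breaks this line.
[[maybe_unused]] constexpr BatchNormDoubleGradFn kOrderCheck =
    &BatchNormDoubleGradCpuToy;
```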
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/elementwise_sig.cc
@@ -133,7 +133,7 @@ KernelSignature ElementwiseSubGradOpArgumentMapping(
 KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "subtract_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"});
+      "subtract_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"});
 }

 KernelSignature ElementwiseDivGradOpArgumentMapping(
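For context on why subtract_double_grad also joined ops_to_fill_zero_for_empty_grads: subtraction is linear, so its double grad is simply ddout = ddx - ddy, and an absent ddx or ddy behaves as zeros. A scalar sketch of that arithmetic (the real SubtractDoubleGradImpl works on tensors and broadcasts along `axis`):

```cpp
// Scalar sketch of subtract_double_grad semantics: ddout = ddx - ddy,
// with a missing ddx/ddy treated as zero. Toy function, not Paddle's API.
#include <optional>

double SubtractDoubleGradToy(std::optional<double> ddx,
                             std::optional<double> ddy) {
  return ddx.value_or(0.0) - ddy.value_or(0.0);
}
```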