Skip to content

Commit

Permalink
【Hackathon 4 No.17】Add cummax / cummin API to Paddle (#53546)
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick-Star125 authored Jun 13, 2023
1 parent 1bcf437 commit 3a3fb1f
Show file tree
Hide file tree
Showing 15 changed files with 1,579 additions and 0 deletions.
20 changes: 20 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,26 @@
func : cross_grad
data_type : out_grad

# Gradient of cummax: scatter-adds out_grad back to the positions that
# produced each running maximum (recorded in `indices` by the forward op).
- backward_op : cummax_grad
forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
output : Tensor(x_grad)
infer_meta :
# x_grad has the same shape/dtype as x, so reuse UnchangedInferMeta on [x].
func : UnchangedInferMeta
param: [x]
kernel :
func : cummax_grad

# Gradient of cummin: scatter-adds out_grad back to the positions that
# produced each running minimum (recorded in `indices` by the forward op).
- backward_op : cummin_grad
forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
output : Tensor(x_grad)
infer_meta :
# x_grad has the same shape/dtype as x, so reuse UnchangedInferMeta on [x].
func : UnchangedInferMeta
param: [x]
kernel :
func : cummin_grad

- backward_op : cumprod_grad
forward : cumprod (Tensor x, int dim) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int dim)
Expand Down
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,24 @@
data_type : input
backward : cross_entropy_with_softmax_grad

# cummax: cumulative maximum along `axis`, returning the running max values
# (`out`) and the index of each max (`indices`).
# NOTE(review): dtype=3 presumably selects int64 indices via the framework's
# dtype enum — confirm against the proto enum definition.
- op : cummax
args : (Tensor x, int axis=-1, int dtype=3)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummax
backward : cummax_grad

# cummin: cumulative minimum along `axis`, returning the running min values
# (`out`) and the index of each min (`indices`).
# NOTE(review): dtype=3 presumably selects int64 indices via the framework's
# dtype enum — confirm against the proto enum definition.
- op : cummin
args : (Tensor x, int axis=-1, int dtype=3)
output : Tensor(out), Tensor(indices)
infer_meta :
func : CumWithIndicesInferMeta
kernel :
func : cummin
backward : cummin_grad

- op : cumprod
args : (Tensor x, int dim)
output : Tensor(out)
Expand Down
63 changes: 63 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,69 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
CumInferMeta(x, axis.to<int>(), flatten, exclusive, reverse, out);
}

// Shape/dtype inference for cummax/cummin.
//
// `out` mirrors x (same dims, dtype, lod); `indices` mirrors x's dims but
// carries the integer type selected by `dtype`, which must map to int32 or
// int64. For int32 indices, the extent along `axis` must fit in int32.
//
// BUGFIX: the original checked the int32-overflow condition BEFORE validating
// `axis`, so a 0-D tensor (or any out-of-range axis) indexed the dim vector
// out of bounds (UB). Axis validation now happens first, and the overflow
// check is skipped for 0-D tensors, which have no extent to overflow.
void CumWithIndicesInferMeta(const MetaTensor& x,
                             int axis,
                             int dtype,
                             MetaTensor* out,
                             MetaTensor* indices) {
  auto x_dims = x.dims();
  auto indices_type = phi::TransToPhiDataType(dtype);
  PADDLE_ENFORCE_EQ(
      (indices_type == DataType::INT32 || indices_type == DataType::INT64),
      true,
      phi::errors::InvalidArgument("dtype of indices must be int32 or int64"));

  // Validate `axis` before using it to index into x_dims.
  if (x_dims.size() > 0) {
    PADDLE_ENFORCE_GE(
        axis,
        -x_dims.size(),
        phi::errors::OutOfRange(
            "axis is out of range (expected to be in range of [%ld, "
            "%ld), but got %ld).",
            -(x_dims.size()),
            x_dims.size(),
            axis));
    PADDLE_ENFORCE_LT(
        axis,
        x_dims.size(),
        phi::errors::OutOfRange(
            "axis is out of range (expected to be in range of [%ld, "
            "%ld), but got %ld).",
            -(x_dims.size()),
            x_dims.size(),
            axis));
  } else {
    // 0-D tensors only admit the degenerate axes -1 and 0.
    PADDLE_ENFORCE_EQ(
        (axis == 0 || axis == -1),
        true,
        errors::InvalidArgument("The axis must be -1 or 0 in 0D Tensor, "
                                "but the value given is %d.",
                                axis));
  }

  // int32 indices cannot address an axis longer than INT32_MAX.
  if (indices_type == DataType::INT32 && x_dims.size() > 0) {
    const int _axis = axis < 0 ? axis + x_dims.size() : axis;
    PADDLE_ENFORCE_LT(
        x_dims[_axis],
        INT32_MAX,
        phi::errors::OutOfRange(
            "cummax with axis %ld may be overflow, set dtype int64 to continue",
            axis));
  }

  out->set_dims(x_dims);
  out->set_dtype(x.dtype());
  out->share_lod(x);
  indices->set_dims(x_dims);
  indices->set_dtype(indices_type);
  indices->share_lod(x);
}

void CropInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
bool reverse,
MetaTensor* out);

// InferMeta for cummax/cummin: `out` gets x's dims and dtype; `indices` gets
// x's dims and an integer dtype (int32 or int64, selected by `dtype`).
void CumWithIndicesInferMeta(const MetaTensor& x,
                             int axis,
                             int dtype,
                             MetaTensor* out,
                             MetaTensor* indices);

void DecodeJpegInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* out);
Expand Down
91 changes: 91 additions & 0 deletions paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cum_maxmin_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}

template <typename T, typename Context>
void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}

} // namespace phi

// Register the CPU gradient kernels for the value dtypes the forward ops
// support (float/double/int32/int64); all memory layouts accepted.
PD_REGISTER_KERNEL(cummax_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::CummaxGradKernel,
                   float,
                   double,
                   int32_t,
                   int64_t) {}

PD_REGISTER_KERNEL(cummin_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::CumminGradKernel,
                   float,
                   double,
                   int32_t,
                   int64_t) {}
Loading

0 comments on commit 3a3fb1f

Please sign in to comment.