Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API improvement for paddle.argsort and paddle.sort 易用性提升 #63513

Merged
merged 9 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@
func : angle_grad

- backward_op : argsort_grad
forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices)
args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending)
forward : argsort (Tensor x, int axis, bool descending, bool stable) -> Tensor(out), Tensor(indices)
args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending, bool stable)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/op_version.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@
- delete_attr : atol
comment : The attribute 'atol' is deleted. The reason why it is deleted is that
attributes do not support a float64 value and it is changed to a tensor.

- op : argsort
  version :
    # Checkpoint recorded when the [stable] attribute was added to argsort
    # (fixes typo: "agsort" -> "argsort" in the checkpoint message).
    - checkpoint : Upgrade argsort, add a new attribute [stable]
      action :
        - add_attr : stable
          comment : If true, it will use stable sorting algorithm which preserves the order
                    of equivalent elements. Otherwise, the order of equivalent elements will
                    not be guaranteed to be preserved.
          default : "false"

- op : assign_value
version :
- checkpoint : Upgrade assign_value, remove plain attributes in favor of generic attribute.
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : argsort
args : (Tensor x, int axis=-1, bool descending=false)
args : (Tensor x, int axis=-1, bool descending=false, bool stable=false)
output : Tensor(out), Tensor(indices)
infer_meta :
func : ArgsortInferMeta
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
void ArgsortInferMeta(const MetaTensor& input,
int axis,
bool descending,
bool stable,
MetaTensor* output,
MetaTensor* indices) {
auto in_dims = input.dims();
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
void ArgsortInferMeta(const MetaTensor& input,
int axis,
bool descending,
bool stable,
MetaTensor* output,
MetaTensor* indices);

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/argsort_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending,
bool stable,
DenseTensor* in_grad);

} // namespace phi
4 changes: 4 additions & 0 deletions paddle/phi/kernels/argsort_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace phi {
* algorithm how to sort the input data.
* If descending is true, will sort by descending order,
* else if false, sort by ascending order
* @param stable Indicate whether to use stable sorting algorithm, which
* guarantees that the order of equivalent elements is
* preserved.
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释缩进有点问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revised here:
#63796

* @param out The sorted tensor of Argsort op, with the same shape as
* x
* @param indices The indices of a tensor giving the sorted order, with
Expand All @@ -43,6 +46,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices);

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/argsort_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending UNUSED,
bool stable UNUSED,
DenseTensor* in_grad) {
auto in_dims = indices.dims();
auto rank = input.dims().size();
Expand Down
50 changes: 35 additions & 15 deletions paddle/phi/kernels/cpu/argsort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ static void FullSort(Type input_height,
const DenseTensor* input,
T* t_out,
Type* t_indices,
bool descending) {
bool descending,
bool stable) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
Expand All @@ -48,18 +49,34 @@ static void FullSort(Type input_height,
col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
}
}
std::sort(col_vec.begin(),
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
if (stable) {
std::stable_sort(
col_vec.begin(),
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
} else {
std::sort(col_vec.begin(),
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
}

for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + j] = col_vec[j].first;
Expand All @@ -73,6 +90,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
Expand Down Expand Up @@ -100,7 +118,8 @@ void ArgsortKernel(const Context& dev_ctx,
&input,
out_data,
ids_data,
descending);
descending,
stable);
} else {
// If not full sort do transpose
std::vector<int> trans;
Expand Down Expand Up @@ -141,7 +160,8 @@ void ArgsortKernel(const Context& dev_ctx,
&trans_inp,
t_out,
t_ind,
descending);
descending,
stable);

dev_ctx.template Alloc<int64_t>(indices);
TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/argsort_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending,
bool stable,
DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad);
phi::funcs::set_constant(dev_ctx, in_grad, static_cast<T>(0.0));
Expand Down
31 changes: 24 additions & 7 deletions paddle/phi/kernels/gpu/argsort_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
Expand All @@ -251,14 +252,30 @@ void ArgsortKernel(const Context& dev_ctx,
// Compared to the following 'Special case for full sort', ascending sort is
// 34 times faster and descending sort is 31 times faster.
if (size == in_dims[axis]) {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data);
if (descending) {
thrust::reverse(thrust::device, out_data, out_data + size);
thrust::reverse(thrust::device, ids_data, ids_data + size);
if (stable) {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
if (descending) {
thrust::stable_sort_by_key(thrust::device,
out_data,
out_data + size,
ids_data,
thrust::greater<T>());
} else {
thrust::stable_sort_by_key(
thrust::device, out_data, out_data + size, ids_data);
}
return;
} else {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data);
if (descending) {
thrust::reverse(thrust::device, out_data, out_data + size);
thrust::reverse(thrust::device, ids_data, ids_data + size);
}
return;
}
return;
}

// Special case for full sort, speedup ~190x.
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/argsort_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending,
bool stable,
DenseTensor* in_grad) {
auto in_dims = indices.dims();
auto rank = in_dims.size();
Expand Down
41 changes: 29 additions & 12 deletions paddle/phi/kernels/xpu/argsort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,16 @@ static inline void xpu_argsort(xpu::Context* ctx,
TID* indices_data,
int m,
int n,
bool descending) {
int ret =
xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
bool descending,
bool stable) {
int ret;
if (stable) {
ret = xpu::stable_sort(
ctx, input_data, output_data, indices_data, m, n, descending);
} else {
ret =
xpu::sort(ctx, input_data, output_data, indices_data, m, n, descending);
}
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "sort");
}

Expand Down Expand Up @@ -60,7 +67,8 @@ struct XPUArgsort {
int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute,
bool descending) {
bool descending,
bool stable) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
Expand All @@ -79,7 +87,8 @@ struct XPUArgsort {
indices_data_trans,
m,
n,
descending);
descending,
stable);
xpu_transpose(
ctx, output_data_trans, output_data, trans_data_shape, permute);
xpu_transpose(
Expand All @@ -95,7 +104,8 @@ struct XPUArgsort<T, false, true> {
int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute,
bool descending) {
bool descending,
bool stable) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
Expand All @@ -115,7 +125,8 @@ struct XPUArgsort<T, false, true> {
indices_data_trans,
m,
n,
descending);
descending,
stable);
xpu_transpose(
ctx, output_data_trans, output_data, trans_data_shape, permute);
xpu_cast(ctx, indices_data_trans, cast_data_int64, len);
Expand All @@ -132,7 +143,8 @@ struct XPUArgsort<int64_t, true, true> {
int64_t* indices_data,
const std::vector<int>& data_shape,
const std::vector<int>& permute,
bool descending) {
bool descending,
bool stable) {
xpu::ctx_guard RAII_GUARD(ctx);
int m = data_shape[0] * data_shape[2];
int n = data_shape[1];
Expand All @@ -154,7 +166,8 @@ struct XPUArgsort<int64_t, true, true> {
indices_data_trans,
m,
n,
descending);
descending,
stable);

xpu_cast(ctx, output_data_trans, cast_data_int64, len);
xpu_transpose(ctx, cast_data_int64, output_data, trans_data_shape, permute);
Expand All @@ -169,6 +182,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
Expand Down Expand Up @@ -217,7 +231,8 @@ void ArgsortKernel(const Context& dev_ctx,
indices_data,
data_shape,
permute_vec,
descending);
descending,
stable);
} else if (index_need_cast) {
XPUArgsort<XPUType, false, true>()(
dev_ctx.x_context(),
Expand All @@ -226,7 +241,8 @@ void ArgsortKernel(const Context& dev_ctx,
indices_data,
data_shape,
permute_vec,
descending);
descending,
stable);
} else {
XPUArgsort<XPUType, false, false>()(
dev_ctx.x_context(),
Expand All @@ -235,7 +251,8 @@ void ArgsortKernel(const Context& dev_ctx,
indices_data,
data_shape,
permute_vec,
descending);
descending,
stable);
}
}

Expand Down
Loading