Skip to content

Commit

Permalink
[OpAttr]Adapt tensor axis for argmin/max (#45453)
Browse files Browse the repository at this point in the history
* Adapt tensor axis for argmin/max

* Add UT

* Polish UT
  • Loading branch information
0x45f authored Aug 30, 2022
1 parent 5f1a8e4 commit 6fc1598
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 41 deletions.
10 changes: 9 additions & 1 deletion paddle/fluid/operators/arg_min_max_op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ namespace operators {
// Operator class for the arg min/max ops. It inherits every constructor of
// framework::OperatorWithKernel unchanged.
class ArgMinMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Builds the expected kernel type from the data type of input "X" and the
  // place reported by the execution context.
  // NOTE(review): presumably overridden so that only "X" (and not a
  // tensor-valued axis) decides the kernel data type -- confirm.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto x_dtype =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_dtype, ctx.GetPlace());
  }
};

class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
Expand All @@ -42,7 +49,8 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "Input tensor.");
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.")
.SupportTensor();
AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
AddAttr<bool>("flatten",
"Flatten the input value, and search the min or max indices")
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@
support_trans_dtype : start, end, step

- api : argmax
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
kernel :
func : arg_max

- api : argmin
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
Expand Down
63 changes: 42 additions & 21 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,28 +121,12 @@ void AffineGridInferMeta(const MetaTensor& input,
}

void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims();

PADDLE_ENFORCE_GE(
axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
axis,
x_dims.size()));

PADDLE_ENFORCE_EQ(
(dtype < 0 || dtype == 2 || dtype == 3),
true,
Expand All @@ -156,16 +140,53 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
paddle::framework::DataTypeToString(
static_cast<paddle::framework::proto::VarType::Type>(dtype))));

if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec;
if (flatten) {
vec = {1};
} else {
if (keepdims) {
vec = std::vector<int64_t>(x.dims().size(), -1);
} else {
vec = std::vector<int64_t>(x.dims().size() - 1, -1);
}
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
} else if (dtype == 3) {
out->set_dtype(DataType::INT64);
}
return;
}

auto int_axis = axis.to<int64_t>();
const auto& x_dims = x.dims();

PADDLE_ENFORCE_GE(
int_axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(int_axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_dims.size()));

auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (int_axis < 0) int_axis += x_rank;
if (config.is_runtime) {
if (dtype == paddle::framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);

} else {
all_element_num = x_dims[axis];
all_element_num = x_dims[int_axis];
}
PADDLE_ENFORCE_LE(
all_element_num,
Expand All @@ -182,11 +203,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}
out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void AffineGridInferMeta(const MetaTensor& input,
MetaTensor* output);

void ArgMinMaxInferMeta(const MetaTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/arg_min_max_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ limitations under the License. */

#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand All @@ -30,7 +31,7 @@ void ArgMinKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ struct VisitDataArgMinMaxFunctor {
template <typename Context, typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand All @@ -145,19 +145,19 @@ void ArgMinMaxKernel(const Context& dev_ctx,
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}

template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand All @@ -169,7 +169,7 @@ void ArgMinKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ struct VisitDataCudaArgMinMaxFunctor {
template <typename Context, typename T, class Reducer>
void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand All @@ -213,19 +213,19 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
static_cast<paddle::framework::proto::VarType::Type>(
paddle::framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return;
}
paddle::framework::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis, keepdims, flatten, out));
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
}

template <typename T, typename Context>
void ArgMinKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand All @@ -237,7 +237,7 @@ void ArgMinKernel(const Context& dev_ctx,
template <typename T, typename Context>
void ArgMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int64_t axis,
const Scalar& axis,
bool keepdims,
bool flatten,
int dtype,
Expand Down
88 changes: 88 additions & 0 deletions python/paddle/fluid/tests/unittests/test_arg_min_max_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import print_function

import os
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from test_attribute_var import UnittestBase


class BaseTestCase(OpTest):
Expand Down Expand Up @@ -235,6 +237,92 @@ def setUp(self):
}


class TestArgMaxTensorAxis(UnittestBase):
    """Static-graph test of paddle.argmax with `axis` built by paddle.assign.

    NOTE(review): `self.temp_dir` and `self.infer_prog()` are not defined in
    this class; presumably UnittestBase provides them -- confirm.
    """

    def init_info(self):
        """Set the input shapes, random inputs and the model save path."""
        self.shapes = [[2, 3, 4]]
        self.x = [np.random.randn(*shape) for shape in self.shapes]
        # Path prefix later handed to save_inference_model in test_static.
        self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())

    def test_static(self):
        """Build and run the program, save it, and compare with np.argmax."""
        main_prog = Program()
        startup_prog = Program()
        with program_guard(main_prog, startup_prog):
            fc = paddle.nn.Linear(4, 10)
            x = paddle.randn([2, 3, 4])
            x.stop_gradient = False
            feat = fc(x)

            out = self.call_func(feat)

            sgd = paddle.optimizer.SGD()
            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
            # assertIn reports both operands on failure, unlike
            # assertTrue(a in b).
            # NOTE(review): "Var[" in the program text presumably marks the
            # axis attribute as a Variable -- confirm.
            self.assertIn(self.var_prefix(), str(main_prog))

            exe = paddle.static.Executor()
            exe.run(startup_prog)
            # No program is passed, so this runs the current default main
            # program, i.e. the one installed by program_guard above.
            res = exe.run(fetch_list=[feat, out])
            paddle.static.save_inference_model(self.save_path, [x],
                                               [feat, out], exe)
            # call_func uses axis 0, so the reference reduces over axis 0.
            expected = np.argmax(res[0], 0)
            np.testing.assert_allclose(res[1], expected)

            # Test for Inference Predictor
            infer_outs = self.infer_prog()
            expected = np.argmax(infer_outs[0], 0)
            np.testing.assert_allclose(infer_outs[1], expected)

    def path_prefix(self):
        """Return the file-name prefix used when saving the model."""
        return 'argmax_tensor_axis'

    def var_prefix(self):
        """Return the substring test_static expects in the program text."""
        return "Var["

    def call_func(self, x):
        """Apply paddle.argmax to `x` with axis 0 supplied via paddle.assign."""
        axis = paddle.assign(0)
        out = paddle.argmax(x, axis)
        return out


class TestArgMinTensorAxis(TestArgMaxTensorAxis):
    """Static-graph test of paddle.argmin with `axis` built by paddle.assign.

    Unlike the parent test, the argmin input is cast to int32, the reduction
    runs over axis 1 and `keepdim=True` is passed.
    """

    def test_static(self):
        """Build and run the program, save it, and compare with np.argmin."""
        main_prog = Program()
        startup_prog = Program()
        with program_guard(main_prog, startup_prog):
            fc = paddle.nn.Linear(4, 10)
            x = paddle.randn([2, 3, 4])
            x.stop_gradient = False
            feat = fc(x)
            # NOTE(review): presumably cast so the op is exercised with an
            # integer input -- confirm.
            feat = paddle.cast(feat, 'int32')
            out = self.call_func(feat)

            sgd = paddle.optimizer.SGD()
            sgd.minimize(paddle.mean(paddle.cast(out, 'float32')))
            # assertIn reports both operands on failure, unlike
            # assertTrue(a in b).
            self.assertIn(self.var_prefix(), str(main_prog))

            exe = paddle.static.Executor()
            exe.run(startup_prog)
            # No program is passed, so this runs the current default main
            # program, i.e. the one installed by program_guard above.
            res = exe.run(fetch_list=[feat, out])
            paddle.static.save_inference_model(self.save_path, [x],
                                               [feat, out], exe)
            # call_func passes keepdim=True; np.squeeze drops the size-1 axis
            # so the result lines up with np.argmin over axis 1.
            expected = np.argmin(res[0], 1)
            np.testing.assert_allclose(np.squeeze(res[1]), expected)

            # Test for Inference Predictor
            infer_outs = self.infer_prog()
            expected = np.argmin(infer_outs[0], 1)
            np.testing.assert_allclose(np.squeeze(infer_outs[1]), expected)

    def path_prefix(self):
        """Return the file-name prefix used when saving the model."""
        return 'argmin_tensor_axis'

    def call_func(self, x):
        """Apply paddle.argmin to `x` over axis 1 (via paddle.assign), keepdim."""
        axis = paddle.assign(1)
        out = paddle.argmin(x, axis, keepdim=True)
        return out


# Script entry point: call paddle.enable_static() first, then start the
# unittest runner.
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
8 changes: 4 additions & 4 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
# [[2, 2, 0, 1]]
"""
if axis is not None and not isinstance(axis, int):
if axis is not None and not isinstance(axis, (int, Variable)):
raise TypeError(
"The type of 'axis' must be int or None in argmax, but received %s."
"The type of 'axis' must be int or Tensor or None in argmax, but received %s."
% (type(axis)))

if dtype is None:
Expand Down Expand Up @@ -244,9 +244,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
print(out4)
# [[1, 1, 1, 2]]
"""
if axis is not None and not isinstance(axis, int):
if axis is not None and not isinstance(axis, (int, Variable)):
raise TypeError(
"The type of 'axis' must be int or None in argmin, but received %s."
"The type of 'axis' must be int or Tensor or None in argmin, but received %s."
% (type(axis)))

if dtype is None:
Expand Down

0 comments on commit 6fc1598

Please sign in to comment.