diff --git a/paddle2onnx/mapper/tensor/argmax.cc b/paddle2onnx/mapper/tensor/argmax.cc index f5a997098..3327eea3d 100644 --- a/paddle2onnx/mapper/tensor/argmax.cc +++ b/paddle2onnx/mapper/tensor/argmax.cc @@ -17,6 +17,17 @@ namespace paddle2onnx { REGISTER_MAPPER(arg_max, ArgMaxMapper) +int32_t ArgMaxMapper::GetMinOpset(bool verbose) { + if (IsAttrVar("axis") && !IsConstant(GetAttrVar("axis")[0])) { + Error() << "While Attribute(axis)'s type is Tensor, it's not " + "supported " + "unless it's a constant tensor." + << std::endl; + return -1; + } + return 7; +} + void ArgMaxMapper::Opset7() { auto input_info = parser_->GetOpInput(block_idx_, op_idx_, "X"); auto output_info = parser_->GetOpOutput(block_idx_, op_idx_, "Out"); @@ -32,7 +43,20 @@ void ArgMaxMapper::Opset7() { need_unsqueeze = true; } } - + if (IsAttrVar("axis")) { + auto axis_info = GetAttrVar("axis"); + std::vector<int64_t> temp; + TryGetValue(axis_info[0], &temp); + axis_ = temp[0]; + } else { + GetAttr("axis", &axis_); + } + if (input_info[0].dtype == P2ODataType::FP64) { + input = helper_->AutoCast(input, P2ODataType::FP64, P2ODataType::FP32); + } + if (input_info[0].dtype == P2ODataType::INT64) { + input = helper_->AutoCast(input, P2ODataType::INT64, P2ODataType::INT32); + } auto arg_node = helper_->MakeNode("ArgMax", {input}); AddAttribute(arg_node, "axis", axis_); AddAttribute(arg_node, "keepdims", static_cast<int64_t>(keepdims_)); diff --git a/paddle2onnx/mapper/tensor/argmax.h b/paddle2onnx/mapper/tensor/argmax.h old mode 100644 new mode 100755 index df3146bc8..f7375e94e --- a/paddle2onnx/mapper/tensor/argmax.h +++ b/paddle2onnx/mapper/tensor/argmax.h @@ -27,9 +27,9 @@ class ArgMaxMapper : public Mapper { : Mapper(p, helper, block_id, op_id) { GetAttr("flatten", &flatten_); GetAttr("keepdims", &keepdims_); - GetAttr("axis", &axis_); GetAttr("dtype", &dtype_); } + int32_t GetMinOpset(bool verbose = false); void Opset7(); private: diff --git a/paddle2onnx/mapper/tensor/argmin.cc 
b/paddle2onnx/mapper/tensor/argmin.cc index 0c2402165..f0be6eda1 100644 --- a/paddle2onnx/mapper/tensor/argmin.cc +++ b/paddle2onnx/mapper/tensor/argmin.cc @@ -17,6 +17,17 @@ namespace paddle2onnx { REGISTER_MAPPER(arg_min, ArgMinMapper) +int32_t ArgMinMapper::GetMinOpset(bool verbose) { + if (IsAttrVar("axis") && !IsConstant(GetAttrVar("axis")[0])) { + Error() << "While Attribute(axis)'s type is Tensor, it's not " + "supported " + "unless it's a constant tensor." + << std::endl; + return -1; + } + return 7; +} + void ArgMinMapper::Opset7() { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); @@ -32,7 +43,20 @@ void ArgMinMapper::Opset7() { need_unsqueeze = true; } } - + if (IsAttrVar("axis")) { + auto axis_info = GetAttrVar("axis"); + std::vector<int64_t> temp; + TryGetValue(axis_info[0], &temp); + axis_ = temp[0]; + } else { + GetAttr("axis", &axis_); + } + if (input_info[0].dtype == P2ODataType::FP64) { + input = helper_->AutoCast(input, P2ODataType::FP64, P2ODataType::FP32); + } + if (input_info[0].dtype == P2ODataType::INT64) { + input = helper_->AutoCast(input, P2ODataType::INT64, P2ODataType::INT32); + } auto arg_node = helper_->MakeNode("ArgMin", {input}); AddAttribute(arg_node, "axis", axis_); AddAttribute(arg_node, "keepdims", static_cast<int64_t>(keepdims_)); diff --git a/paddle2onnx/mapper/tensor/argmin.h b/paddle2onnx/mapper/tensor/argmin.h old mode 100644 new mode 100755 index 5186cd007..ca29d1a59 --- a/paddle2onnx/mapper/tensor/argmin.h +++ b/paddle2onnx/mapper/tensor/argmin.h @@ -27,9 +27,9 @@ class ArgMinMapper : public Mapper { : Mapper(p, helper, block_id, op_id) { GetAttr("flatten", &flatten_); GetAttr("keepdims", &keepdims_); - GetAttr("axis", &axis_); GetAttr("dtype", &dtype_); } + int32_t GetMinOpset(bool verbose = false); void Opset7(); private: diff --git a/tests/test_auto_scan_argminmax.py b/tests/test_auto_scan_argminmax.py new file mode 100755 index 000000000..fe8eaea63 --- /dev/null +++ b/tests/test_auto_scan_argminmax.py @@ -0,0 
+1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import OPConvertAutoScanTest, BaseNet +from hypothesis import reproduce_failure +import hypothesis.strategies as st +import numpy as np +import unittest +import paddle +import random + +op_api_map = { + "arg_min": paddle.argmin, + "arg_max": paddle.argmax, +} + +opset_version_map = { + "arg_min": [7, 9, 15], + "arg_max": [7, 9, 15], +} + + +class Net(BaseNet): + """ + simple Net + """ + + def forward(self, inputs): + """ + forward + """ + if self.config["tensor_attr"]: + axis = paddle.assign(self.config["axis"]) + else: + axis = self.config["axis"] + x = op_api_map[self.config["op_names"]](inputs, + axis=axis, + keepdim=self.config["keep_dim"], + dtype=self.config["out_dtype"]) + return x + + +class TestArgMinMaxConvert(OPConvertAutoScanTest): + """ + api: paddle.argmin/argmax + OPset version: 7, 9, 15 + """ + + def sample_convert_config(self, draw): + input_shape = draw( + st.lists( + st.integers( + min_value=2, max_value=10), min_size=2, max_size=4)) + + input_spec = [-1] * len(input_shape) + + dtype = draw(st.sampled_from(["float32", "float64", "int32", "int64"])) + + axis = draw( + st.integers( + min_value=-len(input_shape), max_value=len(input_shape) - 1)) + + keep_dim = draw(st.booleans()) + + out_dtype = draw(st.sampled_from(["int32", "int64"])) + + tensor_attr = draw(st.booleans()) + + config = { + "op_names": 
["reduce_max"], + "test_data_shapes": [input_shape], + "test_data_types": [[dtype]], + "opset_version": [7, 9, 15], + "axis": axis, + "out_dtype": out_dtype, + "keep_dim": keep_dim, + "tensor_attr": tensor_attr, + "input_spec_shape": [], + "delta": 1e-4, + "rtol": 1e-4 + } + + models = list() + op_names = list() + opset_versions = list() + for op_name, i in op_api_map.items(): + config["op_names"] = op_name + models.append(Net(config)) + op_names.append(op_name) + opset_versions.append(opset_version_map[op_name]) + config["op_names"] = op_names + config["opset_version"] = opset_versions + + return (config, models) + + def test(self): + self.run_and_statis(max_examples=30, max_duration=-1) + + +if __name__ == "__main__": + unittest.main()