From ff2b6180edf32788b9f92d4b9f178a00b54263eb Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Sat, 9 Nov 2024 13:45:08 +0800
Subject: [PATCH 1/9] [Paddle TensorRT] No.8,9 Add pd_op.(argmin,argsort)
 converter

---
 .../transforms/tensorrt/trt_op_marker_pass.cc | 74 ++++++++++++++++++
 python/paddle/tensorrt/impls/search.py        | 50 ++++++++++++
 test/tensorrt/test_converter_search.py        | 76 +++++++++++++++++++
 3 files changed, 200 insertions(+)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 0da97028510e0f..bb76cf95d14fdd 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -29,6 +29,7 @@
 #include "paddle/pir/include/core/builtin_op.h"
 #include "paddle/pir/include/pass/pass.h"
 #include "paddle/pir/include/pass/pass_registry.h"
+#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
 
 namespace {
 
@@ -1235,6 +1236,78 @@ class ArgmaxOpPattern
   }
 };
 
+class ArgminOpPattern
+    : public pir::OpRewritePattern<paddle::dialect::ArgminOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::ArgminOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::ArgminOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    if (!op.axis().defining_op()->isa<paddle::dialect::FullOp>()) {
+      VLOG(3) << "Skip to convert into TRT while found axis is not a constant "
+                 "data in arg_max.";
+      return false;
+    }
+    auto x = op.x();
+    auto x_tensor_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
+    auto data_type = paddle::dialect::TransToPhiDataType(x_tensor_type.dtype());
+    if (!(data_type == phi::DataType::FLOAT32 ||
+          data_type == phi::DataType::FLOAT16 ||
+          data_type == phi::DataType::FLOAT64)) {
+      return false;
+    }
+    int axis = static_cast<int>(op.axis()
+                                    .defining_op()
+                                    ->attribute<pir::DoubleAttribute>("value")
+                                    .data());
+
+    bool flatten = op.attribute<pir::BoolAttribute>("flatten").data();
+    phi::DataType dtype =
+        op.attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
+    if (axis == 0 || flatten ||
+        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64))
+      return false;
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+
+class ArgsortPattern
+    : public pir::OpRewritePattern<paddle::dialect::ArgsortOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::ArgsortOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::ArgsortOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    const std::vector<std::string> required_attrs = {"axis", "descending"};
+    for (const auto &attr : required_attrs) {
+      if (!op->HasAttribute(attr)) {
+        VLOG(3) << "Argsort " << attr << " attribute does not exist";
+        return false;
+      }
+    }
+    auto x = op.x();
+    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
+    auto x_shape = x_type.dims();
+    int axis = op->attribute<pir::Int32Attribute>("axis").data();
+    if (axis < 0) {
+      axis += x_shape.size();
+    }
+    if (x_shape[axis] > 3840 || x_shape[axis] < 0) {
+      VLOG(3) << "The axis dim of input should be less than 3840 and greater "
+                 "than 0 in Tensorrt argsort";
+      return false;
+    }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
 class BilinearInterpV2Pattern
     : public pir::OpRewritePattern<paddle::dialect::BilinearInterpOp> {
  public:
@@ -1653,6 +1726,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
+    ps.Add(std::make_unique<ArgminOpPattern>(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
diff --git a/python/paddle/tensorrt/impls/search.py b/python/paddle/tensorrt/impls/search.py
index 80ba4c7f9a4918..afaab0231368c3 100644
--- a/python/paddle/tensorrt/impls/search.py
+++ b/python/paddle/tensorrt/impls/search.py
@@ -66,6 +66,56 @@ def argmax_converter(network, paddle_op, inputs):
     return squeeze_layer.get_output(0)
 
 
+@converter_registry.register("pd_op.argmin", trt_version="8.x")
+def argmin_converter(network, paddle_op, inputs):
+    x = inputs[0]
+    input_dims = x.shape
+    rank = len(input_dims)
+    axis = int(
+        paddle_op.operands()[1]
+        .source()
+        .get_defining_op()
+        .attrs()
+        .get("value", -1)
+    )
+    keepdims = paddle_op.attrs()["keepdims"]
+
+    if axis < 0:
+        axis += rank
+
+    topk_layer = network.add_topk(
+        input=x, op=trt.TopKOperation.Min, k=1, axes=(1 << axis)
+    )
+
+    if keepdims:
+        return topk_layer.get_output(1)
+    else:
+        squeeze_layer = network.add_shuffle(topk_layer.get_output(1))
+        output_dims = []
+        for i in range(len(input_dims)):
+            if i == axis:
+                continue
+            output_dims.append(input_dims[i])
+        squeeze_layer.reshape_dims = tuple(output_dims)
+        return squeeze_layer.get_output(0)
+
+
+@converter_registry.register("pd_op.argsort", trt_version="8.x")
+def argsort_converter(network, paddle_op, inputs):
+    input_tensor = inputs[0]
+    input_shape = input_tensor.shape
+    # The following two attributes is judged in Marker Pass.
+    # Default value maybe redundant.
+    axis = paddle_op.attrs().get("axis", -1)
+    descending = paddle_op.attrs().get("descending", False)
+    if axis < 0:
+        axis += len(input_shape)
+    topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN
+    k = input_shape[axis]
+    topk_layer = network.add_topk(input_tensor, topk_op, k, 1 << axis)
+    return topk_layer.get_output(1)
+
+
 @converter_registry.register("pd_op.topk", trt_version="8.x")
 def topk_converter(network, paddle_op, inputs):
     input_tensor = inputs[0]
diff --git a/test/tensorrt/test_converter_search.py b/test/tensorrt/test_converter_search.py
index 5e5f34ab0abe71..4c4fe9cf80404b 100644
--- a/test/tensorrt/test_converter_search.py
+++ b/test/tensorrt/test_converter_search.py
@@ -35,6 +35,82 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestArgminCase1TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmin
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype(np.float32),
+            "axis": -1,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestArgminCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmin
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype(np.int64),
+            "axis": -1,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestArgsortCase1TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argsort
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype(np.float32),
+            "axis": -1,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestArgsortCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argsort
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype(np.int64),
+            "axis": -1,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestArgsortCase3TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argsort
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype(np.int64),
+            "axis": -1,
+            "descending": True,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [1, 3]}
+        self.max_shape = {"x": [5, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
 class TestTopkCase1TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.topk
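The two converters above lean on a single TensorRT primitive: ITopKLayer with
k=1 returns both the extreme values (output 0) and their int32 indices
(output 1), so argmin is MIN-top-k plus an optional shuffle that drops the
reduced axis. A self-contained sketch of that lowering against the plain
TensorRT 8.x Python API (the network and tensor names are illustrative, not
taken from the patch):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    x = network.add_input("x", trt.float32, (2, 3))

    axis = 1  # reduce the last axis of the (2, 3) input
    # Bit `axis` of the axes mask selects the dimension to reduce over.
    topk = network.add_topk(x, trt.TopKOperation.MIN, 1, 1 << axis)

    # Output 1 carries the int32 indices that pd_op.argmin needs; with
    # keepdims=False the converter appends a shuffle to squeeze the k=1 axis.
    network.mark_output(topk.get_output(1))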
= {"feed_list": ["x"]} + self.min_shape = {"x": [1, 3]} + self.max_shape = {"x": [5, 3]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestArgsortCase3TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.argsort + self.api_args = { + "x": np.random.randn(2, 3).astype(np.int64), + "axis": -1, + "descending": True, + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [1, 3]} + self.max_shape = {"x": [5, 3]} + + def test_trt_result(self): + self.check_trt_result() + + class TestTopkCase1TRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.topk From f884a7b1c6c446658c5a157097314f7e60a7f3ba Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Sat, 9 Nov 2024 13:52:24 +0800 Subject: [PATCH 2/9] fix --- paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index a079a8457122a8..1e6c28f7c7fb55 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -1755,6 +1755,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass { ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); + ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); From 2b3946a32bb98f5c1e9466ccd358ca266f996035 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Sat, 9 Nov 2024 14:12:40 +0800 Subject: [PATCH 3/9] Fix typos --- paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 1e6c28f7c7fb55..9a1977a746efda 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -1276,7 +1276,7 @@ class ArgminOpPattern } }; -class ArgsortPattern +class ArgsortOpPattern : public pir::OpRewritePattern { public: using pir::OpRewritePattern::OpRewritePattern; From 40eaa67f75d4f781de1c011c45f2bf01fa2de95b Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 11 Nov 2024 16:16:36 +0800 Subject: [PATCH 4/9] Apply review --- .../transforms/tensorrt/trt_op_marker_pass.cc | 52 ++++++++++++------- python/paddle/tensorrt/impls/search.py | 38 +++++++++++--- test/tensorrt/test_converter_search.py | 31 ++++------- 3 files changed, 74 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 9a1977a746efda..0756e0b36d0023 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -1213,12 +1213,13 @@ class ArgmaxOpPattern "data in arg_max."; return false; } - auto x = op.x(); - auto x_tensor_type = x.type().dyn_cast(); - auto data_type = paddle::dialect::TransToPhiDataType(x_tensor_type.dtype()); - if (!(data_type == phi::DataType::FLOAT32 || - data_type == phi::DataType::FLOAT16 || - data_type == phi::DataType::FLOAT64)) { + pir::Value x = op.x(); + auto data_type = pir::GetDataTypeFromValue(x); + if (!(data_type.isa() || + data_type.isa() || + data_type.isa())) { + VLOG(3) << "At present, 
From 40eaa67f75d4f781de1c011c45f2bf01fa2de95b Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Mon, 11 Nov 2024 16:16:36 +0800
Subject: [PATCH 4/9] Apply review

---
 .../transforms/tensorrt/trt_op_marker_pass.cc | 52 ++++++++++++------
 python/paddle/tensorrt/impls/search.py        | 38 +++++++++++---
 test/tensorrt/test_converter_search.py        | 31 ++++-------
 3 files changed, 74 insertions(+), 47 deletions(-)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 9a1977a746efda..0756e0b36d0023 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1213,12 +1213,13 @@ class ArgmaxOpPattern
                  "data in arg_max.";
       return false;
     }
-    auto x = op.x();
-    auto x_tensor_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
-    auto data_type = paddle::dialect::TransToPhiDataType(x_tensor_type.dtype());
-    if (!(data_type == phi::DataType::FLOAT32 ||
-          data_type == phi::DataType::FLOAT16 ||
-          data_type == phi::DataType::FLOAT64)) {
+    pir::Value x = op.x();
+    auto data_type = pir::GetDataTypeFromValue(x);
+    if (!(data_type.isa<pir::Float32Type>() ||
+          data_type.isa<pir::Float16Type>() ||
+          data_type.isa<pir::Float64Type>())) {
+      VLOG(3) << "At present, pd_op.argmax only support float32 or float16 or "
+                 "float64 into trt.";
       return false;
     }
     int axis = static_cast<int>(op.axis()
@@ -1230,8 +1231,12 @@ class ArgmaxOpPattern
                                     .defining_op()
                                     ->attribute<pir::DoubleAttribute>("value")
                                     .data());
 
     bool flatten = op.attribute<pir::BoolAttribute>("flatten").data();
     phi::DataType dtype =
         op.attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
     if (axis == 0 || flatten ||
-        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64))
+        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
+      VLOG(3) << "Skipping TRT conversion in pd_op.argmax: axis is zero, "
+                 "flatten is True, or "
+                 "dtype is int32/int64";
       return false;
+    }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
     return true;
   }
@@ -1249,15 +1254,16 @@ class ArgminOpPattern
     }
     if (!op.axis().defining_op()->isa<paddle::dialect::FullOp>()) {
       VLOG(3) << "Skip to convert into TRT while found axis is not a constant "
-                 "data in arg_max.";
+                 "data in arg_min.";
       return false;
     }
-    auto x = op.x();
-    auto x_tensor_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
-    auto data_type = paddle::dialect::TransToPhiDataType(x_tensor_type.dtype());
-    if (!(data_type == phi::DataType::FLOAT32 ||
-          data_type == phi::DataType::FLOAT16 ||
-          data_type == phi::DataType::FLOAT64)) {
+    pir::Value x = op.x();
+    auto data_type = pir::GetDataTypeFromValue(x);
+    if (!(data_type.isa<pir::Float32Type>() ||
+          data_type.isa<pir::Float16Type>() ||
+          data_type.isa<pir::Float64Type>())) {
+      VLOG(3) << "At present, pd_op.argmin only support float32 or float16 or "
+                 "float64 into trt.";
       return false;
     }
     int axis = static_cast<int>(op.axis()
@@ -1269,8 +1275,13 @@ class ArgminOpPattern
     if (axis == 0 || flatten ||
-        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64))
+        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
+      VLOG(3) << "Skipping TRT conversion in pd_op.argmin: axis is zero, "
+                 "flatten is True, or "
+                 "dtype is int32/int64";
       return false;
+    }
+
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
     return true;
   }
@@ -1289,20 +1300,21 @@ class ArgsortOpPattern
     const std::vector<std::string> required_attrs = {"axis", "descending"};
     for (const auto &attr : required_attrs) {
       if (!op->HasAttribute(attr)) {
-        VLOG(3) << "Argsort " << attr << " attribute does not exist";
+        VLOG(3) << "pd_op.argsort " << attr << " attribute does not exist";
         return false;
       }
     }
-    auto x = op.x();
-    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
+    pir::Value x = op.x();
+    auto x_type = pir::GetDataTypeFromValue(x);
     auto x_shape = x_type.dims();
     int axis = op->attribute<pir::Int32Attribute>("axis").data();
     if (axis < 0) {
       axis += x_shape.size();
     }
     if (x_shape[axis] > 3840 || x_shape[axis] < 0) {
-      VLOG(3) << "The axis dim of input should be less than 3840 and greater "
-                 "than 0 in Tensorrt argsort";
+      VLOG(3) << "In pd_op.argsort, the axis dim of input should be less than "
+                 "3840 and greater "
+                 "than 0 in Tensorrt";
       return false;
     }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
diff --git a/python/paddle/tensorrt/impls/search.py b/python/paddle/tensorrt/impls/search.py
index 8fde645449b3c7..039a1492d7ef0e 100644
--- a/python/paddle/tensorrt/impls/search.py
+++ b/python/paddle/tensorrt/impls/search.py
@@ -16,8 +16,11 @@
 import tensorrt as trt
 
 from paddle.tensorrt.converter_utils import (
+    get_shape_tensor_element,
     squeeze_trt,
     trt_cast,
+    trt_reshape,
+    trt_shape,
     unsqueeze_trt,
 )
 from paddle.tensorrt.register import converter_registry
@@ -104,16 +107,37 @@ def argmin_converter(network, paddle_op, inputs):
 def argsort_converter(network, paddle_op, inputs):
     input_tensor = inputs[0]
     input_shape = input_tensor.shape
-    # The following two attributes is judged in Marker Pass.
-    # Default value maybe redundant.
-    axis = paddle_op.attrs().get("axis", -1)
-    descending = paddle_op.attrs().get("descending", False)
+    in_type = input_tensor.dtype
+    in_rank = len(input_shape)
+    axis = paddle_op.attrs()["axis"]
+    descending = paddle_op.attrs()["descending"]
     if axis < 0:
         axis += len(input_shape)
     topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN
-    k = input_shape[axis]
-    topk_layer = network.add_topk(input_tensor, topk_op, k, 1 << axis)
-    return topk_layer.get_output(1)
+    need_cast = True if in_type == trt.DataType.FLOAT else False
+    if in_rank == 1:
+        unsqueeze_shape = trt.Dims([1, -1])
+        input_tensor = trt_reshape(
+            network, input_tensor, unsqueeze_shape, is_shape_tensor=True
+        )
+        axis = 1
+    if need_cast:
+        input_tensor = trt_cast(network, input_tensor, trt.DataType.FLOAT)
+    topk_layer = network.add_topk(input_tensor, topk_op, 1, 1 << axis)
+    shape = trt_shape(network, input_shape)
+    k_tensor = get_shape_tensor_element(network, shape, axis, True)
+    topk_layer.set_input(1, k_tensor)
+    out = topk_layer.get_output(0)
+    indices = topk_layer.get_ouput(1)
+    if in_rank == 1:
+        squeeze_shape = trt.Dims([1, -1])
+        out = trt_reshape(network, out, squeeze_shape, is_shape_tensor=True)
+        indices = trt_reshape(
+            network, indices, squeeze_shape, is_shape_tensor=True
+        )
+    out_tensor = trt_cast(network, out, in_type)
+    indices_tensor = trt_cast(network, indices, indices.dtype)
+    return out_tensor, indices_tensor
 
 
 @converter_registry.register("pd_op.where", trt_version="8.x")
diff --git a/test/tensorrt/test_converter_search.py b/test/tensorrt/test_converter_search.py
index c17ad34d75c86d..4644f737e7d8a7 100644
--- a/test/tensorrt/test_converter_search.py
+++ b/test/tensorrt/test_converter_search.py
@@ -35,17 +35,20 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
-class TestArgminCase1TRTPattern(TensorRTBaseTest):
+class TestArgminTRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argmin
         self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.float32),
+            "x": np.random.randn(2, 3).astype("float32"),
             "axis": -1,
         }
         self.program_config = {"feed_list": ["x"]}
         self.min_shape = {"x": [1, 3]}
         self.max_shape = {"x": [5, 3]}
 
+    def test_trt_result(self):
+        self.check_trt_result()
+
 
 class TestWhereTRTPatternCase1(TensorRTBaseTest):
     def setUp(self):
@@ -63,26 +66,11 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
-class TestArgminCase2TRTPattern(TensorRTBaseTest):
-    def setUp(self):
-        self.python_api = paddle.argmin
-        self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.int64),
-            "axis": -1,
-        }
-        self.program_config = {"feed_list": ["x"]}
-        self.min_shape = {"x": [1, 3]}
-        self.max_shape = {"x": [5, 3]}
-
-    def test_trt_result(self):
-        self.check_trt_result()
-
-
 class TestArgsortCase1TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argsort
         self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.float32),
+            "x": np.random.randn(2, 3).astype("float32"),
             "axis": -1,
         }
         self.program_config = {"feed_list": ["x"]}
@@ -97,7 +85,7 @@ class TestArgsortCase2TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argsort
         self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.int64),
+            "x": np.random.randn(2, 3).astype("int64"),
             "axis": -1,
         }
         self.program_config = {"feed_list": ["x"]}
@@ -112,7 +100,7 @@ class TestArgsortCase3TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argsort
         self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.int64),
+            "x": np.random.randn(2, 3).astype("float32"),
             "axis": -1,
             "descending": True,
         }
@@ -120,6 +108,9 @@ def setUp(self):
         self.min_shape = {"x": [1, 3]}
         self.max_shape = {"x": [5, 3]}
 
+    def test_trt_result(self):
+        self.check_trt_result()
+
 
 class TestWhereTRTPatternCase2(TensorRTBaseTest):
     def setUp(self):
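The review rewrite above stops hard-coding k from the static shape and
instead wires a runtime k into the TopK layer, which is what keeps the
converter correct once min_shape/max_shape make the sorted axis dynamic. A
hedged sketch of what the trt_shape and get_shape_tensor_element helpers are
assumed to reduce to in raw TensorRT calls (helper semantics inferred, not
copied from converter_utils; a tensor-valued k needs TensorRT >= 8.5):

    import numpy as np
    import tensorrt as trt

    def full_length_topk(network, x, axis, op=trt.TopKOperation.MIN):
        # Runtime shape of x as a 1-D int32 shape tensor.
        shape = network.add_shape(x).get_output(0)
        # Gather dims[axis] out of the shape tensor; result has shape (1,).
        idx = network.add_constant(
            (1,), np.array([axis], dtype=np.int32)
        ).get_output(0)
        k_1d = network.add_gather(shape, idx, 0).get_output(0)
        # TopK's k input must be a 0-D tensor, so squeeze (1,) down to ().
        squeeze = network.add_shuffle(k_1d)
        squeeze.reshape_dims = ()
        # The static k=1 is a placeholder; set_input(1, ...) overrides it.
        topk = network.add_topk(x, op, 1, 1 << axis)
        topk.set_input(1, squeeze.get_output(0))
        return topk.get_output(0), topk.get_output(1)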
From 893d0e334571604d4aa2c4eb2a7634261f5a189c Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Mon, 11 Nov 2024 23:04:57 +0800
Subject: [PATCH 5/9] Fix typos

---
 .../fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 2 +-
 python/paddle/tensorrt/impls/search.py                  | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 0756e0b36d0023..5facd5938f697c 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1305,7 +1305,7 @@ class ArgsortOpPattern
       }
     }
     pir::Value x = op.x();
-    auto x_type = pir::GetDataTypeFromValue(x);
+    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
     auto x_shape = x_type.dims();
     int axis = op->attribute<pir::Int32Attribute>("axis").data();
     if (axis < 0) {
diff --git a/python/paddle/tensorrt/impls/search.py b/python/paddle/tensorrt/impls/search.py
index 039a1492d7ef0e..625223894d852c 100644
--- a/python/paddle/tensorrt/impls/search.py
+++ b/python/paddle/tensorrt/impls/search.py
@@ -87,7 +87,7 @@ def argmin_converter(network, paddle_op, inputs):
         axis += rank
 
     topk_layer = network.add_topk(
-        input=x, op=trt.TopKOperation.Min, k=1, axes=(1 << axis)
+        input=x, op=trt.TopKOperation.MIN, k=1, axes=(1 << axis)
    )
 
     if keepdims:
@@ -114,7 +114,7 @@ def argsort_converter(network, paddle_op, inputs):
     if axis < 0:
         axis += len(input_shape)
     topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN
-    need_cast = True if in_type == trt.DataType.FLOAT else False
+    need_cast = True if in_type != trt.DataType.FLOAT else False
     if in_rank == 1:
         unsqueeze_shape = trt.Dims([1, -1])
         input_tensor = trt_reshape(
@@ -124,13 +124,13 @@ def argsort_converter(network, paddle_op, inputs):
     if need_cast:
         input_tensor = trt_cast(network, input_tensor, trt.DataType.FLOAT)
     topk_layer = network.add_topk(input_tensor, topk_op, 1, 1 << axis)
-    shape = trt_shape(network, input_shape)
+    shape = trt_shape(network, input_tensor)
     k_tensor = get_shape_tensor_element(network, shape, axis, True)
     topk_layer.set_input(1, k_tensor)
     out = topk_layer.get_output(0)
     indices = topk_layer.get_ouput(1)
     if in_rank == 1:
-        squeeze_shape = trt.Dims([1, -1])
+        squeeze_shape = trt.Dims([-1])
         out = trt_reshape(network, out, squeeze_shape, is_shape_tensor=True)
         indices = trt_reshape(
             network, indices, squeeze_shape, is_shape_tensor=True

From 53a5840fecfd9635596b8d2eb86f4f9026b29c82 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Tue, 12 Nov 2024 09:53:12 +0800
Subject: [PATCH 6/9] Fix Typos

---
 python/paddle/tensorrt/impls/search.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/tensorrt/impls/search.py b/python/paddle/tensorrt/impls/search.py
index 625223894d852c..ea1edb37a33c0c 100644
--- a/python/paddle/tensorrt/impls/search.py
+++ b/python/paddle/tensorrt/impls/search.py
@@ -128,7 +128,7 @@ def argsort_converter(network, paddle_op, inputs):
     k_tensor = get_shape_tensor_element(network, shape, axis, True)
     topk_layer.set_input(1, k_tensor)
     out = topk_layer.get_output(0)
-    indices = topk_layer.get_ouput(1)
+    indices = topk_layer.get_output(1)
     if in_rank == 1:
         squeeze_shape = trt.Dims([-1])
         out = trt_reshape(network, out, squeeze_shape, is_shape_tensor=True)
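Two of the patch-5 changes above are behavioral rather than cosmetic despite
the "Fix typos" subject: trt_shape must read the possibly-unsqueezed
input_tensor instead of the stale static input_shape, and flipping == to !=
is what actually routes non-float32 inputs through the cast before TopK. The
corrected predicate, spelled out without the True if ... else False detour it
is equivalent to:

    # Cast exactly the inputs that are not already float32 before TopK.
    need_cast = in_type != trt.DataType.FLOAT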
From 83b6e9649e8d52057bcfa0f96c418afb26309667 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Wed, 13 Nov 2024 16:13:27 +0800
Subject: [PATCH 7/9] Add more tests to pass coverage ci

---
 .../transforms/tensorrt/trt_op_marker_pass.cc |   4 +-
 test/tensorrt/test_converter_search.py        | 122 ++++++++++++++++--
 2 files changed, 116 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 5facd5938f697c..68ba215db9ea1f 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1234,7 +1234,7 @@ class ArgmaxOpPattern
         (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
       VLOG(3) << "Skipping TRT conversion in pd_op.argmax: axis is zero, "
                  "flatten is True, or "
-                 "dtype is int32/int64";
+                 "dtype isn't int32/int64";
       return false;
     }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
@@ -1278,7 +1278,7 @@ class ArgminOpPattern
         (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
       VLOG(3) << "Skipping TRT conversion in pd_op.argmin: axis is zero, "
                  "flatten is True, or "
-                 "dtype is int32/int64";
+                 "dtype isn't int32/int64";
       return false;
     }
 
diff --git a/test/tensorrt/test_converter_search.py b/test/tensorrt/test_converter_search.py
index 4644f737e7d8a7..be8c742e9f8f1c 100644
--- a/test/tensorrt/test_converter_search.py
+++ b/test/tensorrt/test_converter_search.py
@@ -20,11 +20,11 @@
 import paddle
 
 
-class TestArgmaxTRTPattern(TensorRTBaseTest):
+class TestArgmaxCase1TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argmax
         self.api_args = {
-            "x": np.random.randn(2, 3).astype(np.float32),
+            "x": np.random.randn(2, 3).astype("float32"),
             "axis": -1,
         }
         self.program_config = {"feed_list": ["x"]}
@@ -35,7 +35,53 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
-class TestArgminTRTPattern(TensorRTBaseTest):
+class TestArgmaxCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmax
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("int64"),
+            "axis": -1,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argmax"
+
+    def test_trt_result(self):
+        # test input's dtype
+        self.check_marker(expected_result=False)
+
+
+class TestArgmaxCase3TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmax
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": 0,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argmax"
+
+    def test_trt_result(self):
+        # test axis
+        self.check_marker(expected_result=False)
+
+
+class TestArgmaxCase4TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmax
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": -1,
+            "dtype": "float32",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argmax"
+
+    def test_trt_result(self):
+        # test dtype attr
+        self.check_marker(expected_result=False)
+
+
+class TestArgminCase1TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argmin
         self.api_args = {
@@ -50,6 +96,52 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestArgminCase2TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmin
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("int64"),
+            "axis": -1,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argmin"
+
+    def test_trt_result(self):
+        # test input's dtype
+        self.check_marker(expected_result=False)
+
+
+class TestArgminCase3TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmin
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": 0,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argmin"
+
+    def test_trt_result(self):
+        # test axis
+        self.check_marker(expected_result=False)
+
+
+class TestArgminCase4TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmin
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": -1,
+            "dtype": "float32",
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argmin"
+
+    def test_trt_result(self):
+        # test dtype attr
+        self.check_marker(expected_result=False)
+
+
 class TestWhereTRTPatternCase1(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.where
@@ -85,12 +177,12 @@ class TestArgsortCase2TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argsort
         self.api_args = {
-            "x": np.random.randn(2, 3).astype("int64"),
+            "x": np.random.randn(2).astype("float32"),
             "axis": -1,
         }
         self.program_config = {"feed_list": ["x"]}
-        self.min_shape = {"x": [1, 3]}
-        self.max_shape = {"x": [5, 3]}
+        self.min_shape = {"x": [1]}
+        self.max_shape = {"x": [5]}
 
     def test_trt_result(self):
         self.check_trt_result()
@@ -100,9 +192,8 @@ class TestArgsortCase3TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argsort
         self.api_args = {
-            "x": np.random.randn(2, 3).astype("float32"),
+            "x": np.random.randn(2, 3).astype("int64"),
             "axis": -1,
-            "descending": True,
         }
         self.program_config = {"feed_list": ["x"]}
         self.min_shape = {"x": [1, 3]}
@@ -112,6 +203,21 @@ def test_trt_result(self):
         self.check_trt_result()
 
 
+class TestArgsortCase4TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argsort
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": 3940,
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.target_marker_op = "pd_op.argsort"
+
+    def test_trt_result(self):
+        # test axis attr
+        self.check_marker(expected_result=False)
+
+
 class TestWhereTRTPatternCase2(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.where
{"feed_list": ["x"]} + self.target_marker_op = "pd_op.argmin" + + def test_trt_result(self): + # test input's dtype + self.check_marker(expected_result=False) + + +class TestArgminCase3TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.argmin + self.api_args = { + "x": np.random.randn(2, 3).astype("float32"), + "axis": 0, + } + self.program_config = {"feed_list": ["x"]} + self.target_marker_op = "pd_op.argmin" + + def test_trt_result(self): + # test axis + self.check_marker(expected_result=False) + + +class TestArgminCase4TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.argmin + self.api_args = { + "x": np.random.randn(2, 3).astype("float32"), + "axis": -1, + "dtype": "float32", + } + self.program_config = {"feed_list": ["x"]} + self.target_marker_op = "pd_op.argmin" + + def test_trt_result(self): + # test dtype attr + self.check_marker(expected_result=False) + + class TestWhereTRTPatternCase1(TensorRTBaseTest): def setUp(self): self.python_api = paddle.where @@ -85,12 +177,12 @@ class TestArgsortCase2TRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.argsort self.api_args = { - "x": np.random.randn(2, 3).astype("int64"), + "x": np.random.randn(2).astype("float32"), "axis": -1, } self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [1, 3]} - self.max_shape = {"x": [5, 3]} + self.min_shape = {"x": [1]} + self.max_shape = {"x": [5]} def test_trt_result(self): self.check_trt_result() @@ -100,9 +192,8 @@ class TestArgsortCase3TRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.argsort self.api_args = { - "x": np.random.randn(2, 3).astype("float32"), + "x": np.random.randn(2, 3).astype("int64"), "axis": -1, - "descending": True, } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [1, 3]} @@ -112,6 +203,21 @@ def test_trt_result(self): self.check_trt_result() +class TestArgsortCase4TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.argsort + self.api_args = { + "x": np.random.randn(2, 3).astype("float32"), + "axis": 3940, + } + self.program_config = {"feed_list": ["x"]} + self.target_marker_op = "pd_op.argsort" + + def test_trt_result(self): + # test axis attr + self.check_marker(expected_result=False) + + class TestWhereTRTPatternCase2(TensorRTBaseTest): def setUp(self): self.python_api = paddle.where From 911c63e6fa664c95d1ed8dfc28ae0cf8621af00e Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Wed, 13 Nov 2024 23:11:12 +0800 Subject: [PATCH 8/9] Fix tests --- .../transforms/tensorrt/trt_op_marker_pass.cc | 5 ++- python/paddle/tensorrt/impls/search.py | 6 ++-- test/tensorrt/test_converter_search.py | 36 ++----------------- 3 files changed, 7 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 68ba215db9ea1f..e369f903c427ca 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -1311,10 +1311,9 @@ class ArgsortOpPattern if (axis < 0) { axis += x_shape.size(); } - if (x_shape[axis] > 3840 || x_shape[axis] < 0) { + if (x_shape[axis] > 3840) { VLOG(3) << "In pd_op.argsort,the axis dim of input should be less than " - "3840 and greater " - "than 0 in Tensorrt"; + "3840"; return false; } op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); diff --git a/python/paddle/tensorrt/impls/search.py b/python/paddle/tensorrt/impls/search.py 
index ea1edb37a33c0c..688860abc83983 100644
--- a/python/paddle/tensorrt/impls/search.py
+++ b/python/paddle/tensorrt/impls/search.py
@@ -118,7 +118,7 @@ def argsort_converter(network, paddle_op, inputs):
     if in_rank == 1:
         unsqueeze_shape = trt.Dims([1, -1])
         input_tensor = trt_reshape(
-            network, input_tensor, unsqueeze_shape, is_shape_tensor=True
+            network, input_tensor, unsqueeze_shape, is_shape_tensor=False
         )
         axis = 1
     if need_cast:
@@ -131,9 +131,9 @@ def argsort_converter(network, paddle_op, inputs):
     indices = topk_layer.get_output(1)
     if in_rank == 1:
         squeeze_shape = trt.Dims([-1])
-        out = trt_reshape(network, out, squeeze_shape, is_shape_tensor=True)
+        out = trt_reshape(network, out, squeeze_shape, is_shape_tensor=False)
         indices = trt_reshape(
-            network, indices, squeeze_shape, is_shape_tensor=True
+            network, indices, squeeze_shape, is_shape_tensor=False
         )
diff --git a/test/tensorrt/test_converter_search.py b/test/tensorrt/test_converter_search.py
index be8c742e9f8f1c..da2c8a30ab5146 100644
--- a/test/tensorrt/test_converter_search.py
+++ b/test/tensorrt/test_converter_search.py
@@ -65,22 +65,6 @@ def test_trt_result(self):
         self.check_marker(expected_result=False)
 
 
-class TestArgmaxCase4TRTPattern(TensorRTBaseTest):
-    def setUp(self):
-        self.python_api = paddle.argmax
-        self.api_args = {
-            "x": np.random.randn(2, 3).astype("float32"),
-            "axis": -1,
-            "dtype": "float32",
-        }
-        self.program_config = {"feed_list": ["x"]}
-        self.target_marker_op = "pd_op.argmax"
-
-    def test_trt_result(self):
-        # test dtype attr
-        self.check_marker(expected_result=False)
-
-
 class TestArgminCase1TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argmin
@@ -126,22 +110,6 @@ def test_trt_result(self):
         self.check_marker(expected_result=False)
 
 
-class TestArgminCase4TRTPattern(TensorRTBaseTest):
-    def setUp(self):
-        self.python_api = paddle.argmin
-        self.api_args = {
-            "x": np.random.randn(2, 3).astype("float32"),
-            "axis": -1,
-            "dtype": "float32",
-        }
-        self.program_config = {"feed_list": ["x"]}
-        self.target_marker_op = "pd_op.argmin"
-
-    def test_trt_result(self):
-        # test dtype attr
-        self.check_marker(expected_result=False)
-
-
 class TestWhereTRTPatternCase1(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.where
@@ -207,8 +175,8 @@ class TestArgsortCase4TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argsort
         self.api_args = {
-            "x": np.random.randn(2, 3).astype("float32"),
-            "axis": 3940,
+            "x": np.random.randn(2, 4000).astype("float32"),
+            "axis": 1,
         }
         self.program_config = {"feed_list": ["x"]}
         self.target_marker_op = "pd_op.argsort"
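The is_shape_tensor flips above follow from how a reshape helper like
trt_reshape is typically split: trt.Dims([1, -1]) and trt.Dims([-1]) are
static specifications that belong on IShuffleLayer.reshape_dims, and only a
runtime shape tensor would be wired through set_input(1, ...). A hedged
sketch of that split (the real helper lives in converter_utils; this is its
assumed shape, not a copy):

    def trt_reshape_sketch(network, tensor, shape, is_shape_tensor=False):
        layer = network.add_shuffle(tensor)
        if is_shape_tensor:
            layer.set_input(1, shape)   # shape is an ITensor built at runtime
        else:
            layer.reshape_dims = shape  # shape is a static trt.Dims / tuple
        return layer.get_output(0)

The 3840 bound the marker pass keeps enforcing is likewise a TensorRT
constraint: as of TensorRT 8.x, ITopKLayer only accepts k values up to 3840,
so longer axes cannot be sorted through this lowering at all.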
From 5f09dc7b54e3923916ff062efef6e1ab9ed94b8a Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Thu, 14 Nov 2024 06:58:25 +0800
Subject: [PATCH 9/9] Fix tests

---
 .../transforms/tensorrt/trt_op_marker_pass.cc | 10 +++----
 test/tensorrt/test_converter_search.py        | 30 +++++++++++++++++++
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index e369f903c427ca..447eade9d41e72 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1233,8 +1233,7 @@ class ArgmaxOpPattern
     if (axis == 0 || flatten ||
         (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
       VLOG(3) << "Skipping TRT conversion in pd_op.argmax: axis is zero, "
-                 "flatten is True, or "
-                 "dtype isn't int32/int64";
+                 "flatten is True, or dtype isn't int32/int64";
       return false;
     }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
@@ -1277,8 +1276,7 @@ class ArgminOpPattern
     if (axis == 0 || flatten ||
         (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
       VLOG(3) << "Skipping TRT conversion in pd_op.argmin: axis is zero, "
-                 "flatten is True, or "
-                 "dtype isn't int32/int64";
+                 "flatten is True, or dtype isn't int32/int64";
       return false;
     }
 
@@ -1312,8 +1310,8 @@ class ArgsortOpPattern
       axis += x_shape.size();
     }
     if (x_shape[axis] > 3840) {
-      VLOG(3) << "In pd_op.argsort, the axis dim of input should be less than "
-                 "3840";
+      VLOG(3)
+          << "In pd_op.argsort, the axis dim of input should be less than 3840";
       return false;
     }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
diff --git a/test/tensorrt/test_converter_search.py b/test/tensorrt/test_converter_search.py
index da2c8a30ab5146..4fdc9a03008dba 100644
--- a/test/tensorrt/test_converter_search.py
+++ b/test/tensorrt/test_converter_search.py
@@ -65,6 +65,21 @@ def test_trt_result(self):
         self.check_marker(expected_result=False)
 
 
+class TestArgmaxCase4TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmax
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": np.random.randn(1).astype("int64"),
+        }
+        self.program_config = {"feed_list": ["x", "axis"]}
+        self.target_marker_op = "pd_op.argmax"
+
+    def test_trt_result(self):
+        # test axis Value
+        self.check_marker(expected_result=False)
+
+
 class TestArgminCase1TRTPattern(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.argmin
@@ -125,6 +140,21 @@ def test_trt_result(self):
         self.check_marker(expected_result=False)
 
 
+class TestArgminCase4TRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.argmin
+        self.api_args = {
+            "x": np.random.randn(2, 3).astype("float32"),
+            "axis": np.random.randn(1).astype("int64"),
+        }
+        self.program_config = {"feed_list": ["x", "axis"]}
+        self.target_marker_op = "pd_op.argmin"
+
+    def test_trt_result(self):
+        # test axis Value
+        self.check_marker(expected_result=False)
+
+
 class TestWhereTRTPatternCase1(TensorRTBaseTest):
     def setUp(self):
         self.python_api = paddle.where