From 6b51057885d632e3ac789e22a3f9fe164fd75b73 Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Tue, 11 Jul 2023 16:55:02 -0700 Subject: [PATCH 1/2] feat: Added support for aten::tile converter Signed-off-by: Anurag Dixit --- core/conversion/converters/impl/expand.cpp | 115 +++++++++-------- .../conversion/converters/test_expand.cpp | 120 ++++++++++++++++++ 2 files changed, 181 insertions(+), 54 deletions(-) diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp index 6b22fea8d4..692a40b105 100644 --- a/core/conversion/converters/impl/expand.cpp +++ b/core/conversion/converters/impl/expand.cpp @@ -194,6 +194,60 @@ bool add_expand_dynamic( return true; } +bool add_repeat(ConversionCtx* ctx, const torch::jit::Node* n, args& args, const std::string& layer) { + auto in = args[0].ITensorOrFreeze(ctx); + auto input_dims = in->getDimensions(); + auto repeats = args[1].unwrapToIntList().vec(); + int repeats_rank = repeats.size(); + TORCHTRT_CHECK( + repeats_rank >= input_dims.nbDims, + "Number of repeat dimensions cannot be smaller than number of input dimensions"); + + auto num_expand_dims = repeats_rank - input_dims.nbDims; + + if (ctx->input_is_dynamic) { + int input_rank = input_dims.nbDims; + int output_rank = repeats_rank; + auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in); + + auto shuffle = ctx->net->addShuffle(*in); + shuffle->setInput(1, *new_input_shape_tensor); + in = shuffle->getOutput(0); + } else { + if (num_expand_dims > 0) { + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = repeats.size(); + for (int i = 0; i < num_expand_dims; i++) { + reshape_dims.d[i] = 1; + } + for (int i = 0; i < input_dims.nbDims; i++) { + reshape_dims.d[num_expand_dims + i] = input_dims.d[i]; + } + // Add a reshape layer to expand dims + auto reshape_layer = ctx->net->addShuffle(*in); + reshape_layer->setReshapeDimensions(reshape_dims); + in = reshape_layer->getOutput(0); + LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims); + } + LOG_DEBUG("Repeats: " << repeats); + } + + // Concat across all repeat axes. 
+  for (int i = repeats.size() - 1; i >= 0; --i) {
+    std::vector<nvinfer1::ITensor*> tensors_vec;
+    for (int j = 0; j < repeats[i]; j++) {
+      tensors_vec.push_back(in);
+    }
+    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+    concat_layer->setAxis(i);
+    in = concat_layer->getOutput(0);
+  }
+
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+  LOG_DEBUG(layer << " layer output tensor shape: " << out->getDimensions());
+  return true;
+}
+
 auto expand_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -230,59 +284,7 @@ auto expand_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensorOrFreeze(ctx);
-               auto input_dims = in->getDimensions();
-               auto repeats = args[1].unwrapToIntList().vec();
-               int repeats_rank = repeats.size();
-               TORCHTRT_CHECK(
-                   repeats_rank >= input_dims.nbDims,
-                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
-               auto num_expand_dims = repeats_rank - input_dims.nbDims;
-
-               if (ctx->input_is_dynamic) {
-                 int input_rank = input_dims.nbDims;
-                 int output_rank = repeats_rank;
-                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
-
-                 // Add a reshape layer to expand dims
-                 auto shuffle = ctx->net->addShuffle(*in);
-                 shuffle->setInput(1, *new_input_shape_tensor);
-                 in = shuffle->getOutput(0);
-               } else {
-                 if (num_expand_dims > 0) {
-                   nvinfer1::Dims reshape_dims;
-                   reshape_dims.nbDims = repeats.size();
-                   for (int i = 0; i < num_expand_dims; i++) {
-                     reshape_dims.d[i] = 1;
-                   }
-                   for (int i = 0; i < input_dims.nbDims; i++) {
-                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                   }
-                   // Add a reshape layer to expand dims
-                   auto reshape_layer = ctx->net->addShuffle(*in);
-                   reshape_layer->setReshapeDimensions(reshape_dims);
-                   in = reshape_layer->getOutput(0);
-                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-                 }
-                 LOG_DEBUG("Repeats: " << repeats);
-               }
-
-               // Concat across all repeat axes.
-               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-               for (int i = repeats.size() - 1; i >= 0; --i) {
-                 std::vector<nvinfer1::ITensor*> tensors_vec;
-                 for (int j = 0; j < repeats[i]; j++) {
-                   tensors_vec.push_back(in);
-                 }
-                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-                 concat_layer->setAxis(i);
-                 in = concat_layer->getOutput(0);
-               }
-
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-
-               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-               return true;
+               return add_repeat(ctx, n, args, "Repeat");
              }})
         .pattern(
             {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int?
output_size=None) -> (Tensor)", @@ -395,6 +397,11 @@ auto expand_registrations TORCHTRT_UNUSED = return true; }}) + .pattern( + {"aten::tile(Tensor self, int[] dims) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + return add_repeat(ctx, n, args, "Tile"); + }}) .pattern( {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { @@ -484,4 +491,4 @@ auto expand_registrations TORCHTRT_UNUSED = } // namespace converters } // namespace conversion } // namespace core -} // namespace torch_tensorrt +} // namespace torch_tensorrt \ No newline at end of file diff --git a/tests/core/conversion/converters/test_expand.cpp b/tests/core/conversion/converters/test_expand.cpp index 77b42fb1d9..e94d14dd0c 100644 --- a/tests/core/conversion/converters/test_expand.cpp +++ b/tests/core/conversion/converters/test_expand.cpp @@ -670,6 +670,126 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } +TEST(Converters, ATenTileConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[4, 1]]() + %3 : Tensor = aten::tile(%x.1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(jit_in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenTileRepeatRankConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[4, 1, 2]]() + %3 : Tensor = aten::tile(%x.1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(jit_in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[4, 1]]() + %3 : Tensor = aten::tile(%x.1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 3}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(jit_in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); + + 
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenTile3dConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[2, 2, 2]]() + %3 : Tensor = aten::tile(%x.1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(jit_in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + +TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int[] = prim::Constant[value=[2, 2, 2]]() + %3 : Tensor = aten::tile(%x.1, %2) + return (%3))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(jit_in); + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +} + TEST(Converters, ATenMeshGridConvertsCorrectly) { const auto graph = R"IR( graph(%x : Tensor, %y : Tensor, %z : Tensor): From b7b2725d4f66cc0fb31efe78ef2a53c11ab5bb4d Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Wed, 2 Aug 2023 20:23:33 -0700 Subject: [PATCH 2/2] feat: Moved from converter to lowering pass Signed-off-by: Anurag Dixit --- core/conversion/converters/impl/expand.cpp | 115 ++++++++---------- core/lowering/lowering.cpp | 1 + core/lowering/passes/BUILD | 1 + core/lowering/passes/CMakeLists.txt | 1 + core/lowering/passes/passes.h | 1 + core/lowering/passes/tile_to_repeat.cpp | 25 ++++ docsrc/contributors/lowering.rst | 7 ++ .../conversion/converters/test_expand.cpp | 6 + tests/core/lowering/BUILD | 5 + .../lowering/test_tile_to_repeat_pass.cpp | 26 ++++ 10 files changed, 127 insertions(+), 61 deletions(-) create mode 100644 core/lowering/passes/tile_to_repeat.cpp create mode 100644 tests/core/lowering/test_tile_to_repeat_pass.cpp diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp index 692a40b105..6b22fea8d4 100644 --- a/core/conversion/converters/impl/expand.cpp +++ b/core/conversion/converters/impl/expand.cpp @@ -194,60 +194,6 @@ bool add_expand_dynamic( return true; } -bool add_repeat(ConversionCtx* ctx, const torch::jit::Node* n, args& args, const std::string& layer) { - auto in = args[0].ITensorOrFreeze(ctx); - auto input_dims = in->getDimensions(); - auto repeats = args[1].unwrapToIntList().vec(); - int repeats_rank = repeats.size(); - TORCHTRT_CHECK( - repeats_rank >= input_dims.nbDims, - "Number of repeat dimensions cannot be smaller than number of input dimensions"); - - auto num_expand_dims = repeats_rank - input_dims.nbDims; - - if (ctx->input_is_dynamic) { - int 
input_rank = input_dims.nbDims; - int output_rank = repeats_rank; - auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in); - - auto shuffle = ctx->net->addShuffle(*in); - shuffle->setInput(1, *new_input_shape_tensor); - in = shuffle->getOutput(0); - } else { - if (num_expand_dims > 0) { - nvinfer1::Dims reshape_dims; - reshape_dims.nbDims = repeats.size(); - for (int i = 0; i < num_expand_dims; i++) { - reshape_dims.d[i] = 1; - } - for (int i = 0; i < input_dims.nbDims; i++) { - reshape_dims.d[num_expand_dims + i] = input_dims.d[i]; - } - // Add a reshape layer to expand dims - auto reshape_layer = ctx->net->addShuffle(*in); - reshape_layer->setReshapeDimensions(reshape_dims); - in = reshape_layer->getOutput(0); - LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims); - } - LOG_DEBUG("Repeats: " << repeats); - } - - // Concat across all repeat axes. - for (int i = repeats.size() - 1; i >= 0; --i) { - std::vector tensors_vec; - for (int j = 0; j < repeats[i]; j++) { - tensors_vec.push_back(in); - } - auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size()); - concat_layer->setAxis(i); - in = concat_layer->getOutput(0); - } - - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in); - LOG_DEBUG(layer << " layer output tensor shape: " << out->getDimensions()); - return true; -} - auto expand_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() .pattern( @@ -284,7 +230,59 @@ auto expand_registrations TORCHTRT_UNUSED = .pattern( {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - return add_repeat(ctx, n, args, "Repeat"); + auto in = args[0].ITensorOrFreeze(ctx); + auto input_dims = in->getDimensions(); + auto repeats = args[1].unwrapToIntList().vec(); + int repeats_rank = repeats.size(); + TORCHTRT_CHECK( + repeats_rank >= input_dims.nbDims, + "Number of repeat dimensions cannot be smaller than number of input dimensions"); + auto num_expand_dims = repeats_rank - input_dims.nbDims; + + if (ctx->input_is_dynamic) { + int input_rank = input_dims.nbDims; + int output_rank = repeats_rank; + auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in); + + // Add a reshape layer to expand dims + auto shuffle = ctx->net->addShuffle(*in); + shuffle->setInput(1, *new_input_shape_tensor); + in = shuffle->getOutput(0); + } else { + if (num_expand_dims > 0) { + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = repeats.size(); + for (int i = 0; i < num_expand_dims; i++) { + reshape_dims.d[i] = 1; + } + for (int i = 0; i < input_dims.nbDims; i++) { + reshape_dims.d[num_expand_dims + i] = input_dims.d[i]; + } + // Add a reshape layer to expand dims + auto reshape_layer = ctx->net->addShuffle(*in); + reshape_layer->setReshapeDimensions(reshape_dims); + in = reshape_layer->getOutput(0); + LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims); + } + LOG_DEBUG("Repeats: " << repeats); + } + + // Concat across all repeat axes. + // TODO: Implementation might not be performant. Explore other strategies to improve performance. 
+                 for (int i = repeats.size() - 1; i >= 0; --i) {
+                   std::vector<nvinfer1::ITensor*> tensors_vec;
+                   for (int j = 0; j < repeats[i]; j++) {
+                     tensors_vec.push_back(in);
+                   }
+                   auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                   concat_layer->setAxis(i);
+                   in = concat_layer->getOutput(0);
+                 }
+
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+                 LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+                 return true;
               }})
         .pattern(
             {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
@@ -397,11 +395,6 @@ auto expand_registrations TORCHTRT_UNUSED =
                return true;
              }})
-        .pattern(
-            {"aten::tile(Tensor self, int[] dims) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               return add_repeat(ctx, n, args, "Tile");
-             }})
         .pattern(
             {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -491,4 +484,4 @@ auto expand_registrations TORCHTRT_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
\ No newline at end of file
+} // namespace torch_tensorrt
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index cb1fd97327..ce38a1f292 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -153,6 +153,7 @@ void LowerGraph(std::shared_ptr& g, std::vector& graph, std::st
 void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);
 // utility functions exposed for testing
 std::string unmangle_cls_name(const std::string& name);
diff --git a/core/lowering/passes/tile_to_repeat.cpp b/core/lowering/passes/tile_to_repeat.cpp
new file mode 100644
index 0000000000..7ecb2bc13d
--- /dev/null
+++ b/core/lowering/passes/tile_to_repeat.cpp
@@ -0,0 +1,25 @@
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string tile_pattern = R"IR(
+        graph(%input, %1):
+            %2 = aten::tile(%input, %1)
+            return (%2))IR";
+  std::string repeat_pattern = R"IR(
+        graph(%input, %1):
+            %2 = aten::repeat(%input, %1)
+            return (%2))IR";
+  torch::jit::SubgraphRewriter tile_to_repeat;
+  tile_to_repeat.RegisterRewritePattern(tile_pattern, repeat_pattern);
+  tile_to_repeat.runOnGraph(graph);
+  LOG_GRAPH("Mapping tile -> repeat: " << *graph);
+}
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/docsrc/contributors/lowering.rst b/docsrc/contributors/lowering.rst
index 956c2004e1..a82f497ed2 100644
--- a/docsrc/contributors/lowering.rst
+++ b/docsrc/contributors/lowering.rst
@@ -205,3 +205,10 @@ Unroll Loops
   `torch/csrc/jit/passes/loop_unrolling.h `_
 
 Unrolls the operations of compatable loops (e.g. sufficently short) so that you only have to go through the loop once.
+
+Replace Tile with Repeat
+***************************************
+
+    `Torch-TensorRT/core/lowering/passes/tile_to_repeat.cpp `_
+
+Replaces ``aten::tile`` with the equivalent ``aten::repeat`` during lowering, so the existing ``aten::repeat`` converter handles both operators.
diff --git a/tests/core/conversion/converters/test_expand.cpp b/tests/core/conversion/converters/test_expand.cpp
index e94d14dd0c..341fe29aa4 100644
--- a/tests/core/conversion/converters/test_expand.cpp
+++ b/tests/core/conversion/converters/test_expand.cpp
@@ -1,6 +1,7 @@
 #include
 #include
 #include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
@@ -680,6 +681,7 @@ TEST(Converters, ATenTileConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
@@ -704,6 +706,7 @@ TEST(Converters, ATenTileRepeatRankConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
@@ -728,6 +731,7 @@ TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
@@ -752,6 +756,7 @@ TEST(Converters, ATenTile3dConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
@@ -776,6 +781,7 @@ TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index 081443ecb3..30f1fd8e5a 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -103,6 +103,10 @@ lowering_test(
     name = "test_replace_aten_pad_pass",
 )
 
+lowering_test(
+    name = "test_tile_to_repeat_pass",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -122,6 +126,7 @@ test_suite(
         ":test_remove_unnecessary_casts",
         ":test_replace_aten_pad_pass",
        ":test_rewrite_inputs_with_params",
+        ":test_tile_to_repeat_pass",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",
diff --git a/tests/core/lowering/test_tile_to_repeat_pass.cpp b/tests/core/lowering/test_tile_to_repeat_pass.cpp
new file mode 100644
index 0000000000..8357007091
--- /dev/null
+++ b/tests/core/lowering/test_tile_to_repeat_pass.cpp
@@ -0,0 +1,26 @@
+#include
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, TileToRepeatCorrectly) {
+  std::string source_graph = R"IR(
+        graph(%input, %dim):
+            %o : Tensor = aten::tile(%input, %dim)
+            return (%o))IR";
+  std::string target_graph = R"IR(
+        graph(%input, %dim):
+            %o : Tensor = aten::repeat(%input, %dim)
+            return (%o))IR";
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
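---

Not part of the patch: a minimal standalone libtorch sketch of the `aten::tile` / `aten::repeat` equivalence that the `ReplaceTileWithRepeat` lowering pass relies on. The `main()` wrapper, tensor shapes, and repeat dims below are illustrative assumptions, not code from this PR.

```cpp
// Sketch only (assumes a libtorch installation); illustrates why rewriting
// aten::tile to aten::repeat is safe when the dims list has rank >= the
// input rank, which is the case the repeat converter's TORCHTRT_CHECK accepts.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {1, 3});

  // With the same dims list, tile and repeat produce identical results.
  auto tiled = x.tile({4, 1});
  auto repeated = x.repeat({4, 1});
  std::cout << std::boolalpha << torch::equal(tiled, repeated) << "\n";  // true

  // Note: tile also accepts fewer dims than the input rank (it prepends 1s,
  // e.g. x.tile({2}) on a 2-D tensor), while repeat does not; after the
  // rewrite, that case still falls under the converter's rank check.
  return 0;
}
```

Rewriting at the lowering stage keeps a single TensorRT concatenation-based implementation in the `aten::repeat` converter instead of duplicating it for `aten::tile`.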