From fef2d4532390c90af41d80e432845d3315d628ef Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Thu, 28 Dec 2023 06:49:11 +0000 Subject: [PATCH 1/7] add SubstituteDimExpr --- paddle/cinn/common/dim_expr_util.cc | 94 +++++++++++++++++++++++++++++ paddle/cinn/common/dim_expr_util.h | 15 +++++ 2 files changed, 109 insertions(+) create mode 100644 paddle/cinn/common/dim_expr_util.cc create mode 100644 paddle/cinn/common/dim_expr_util.h diff --git a/paddle/cinn/common/dim_expr_util.cc b/paddle/cinn/common/dim_expr_util.cc new file mode 100644 index 0000000000000..59e03dcf3b51d --- /dev/null +++ b/paddle/cinn/common/dim_expr_util.cc @@ -0,0 +1,94 @@ +#include "paddle/cinn/common/dim_expr_util.h" + +namespace cinn::common { +using namespace symbol; + +namespace { + +class SubstituteDimExprHelper final { + public: + explicit SubstituteDimExprHelper( + const std::unordered_map& pattern_to_replacement) + : pattern_to_replacement_(pattern_to_replacement) {} + + std::optional Substitute(const DimExpr& dim_expr) { + auto iter = pattern_to_replacement_.find(dim_expr); + if (iter != pattern_to_replacement_.end()) return iter->second; + return std::visit([&](const auto& impl) { + return SubstituteImpl(impl); + }, dim_expr.variant()); + } + + private: + std::optional SubstituteImpl(const std::int64_t& value) { + // `Substitute` has handled the case that `value` is matched. + return std::nullopt; + } + std::optional SubstituteImpl(const std::string& value) { + // `Substitute` has handled the case that `value` is matched. + return std::nullopt; + } + + std::optional SubstituteImpl(const Negative& dim_expr) { + return SubstituteUnary(dim_expr); + } + std::optional SubstituteImpl(const Reciprocal& dim_expr) { + return SubstituteUnary(dim_expr); + } + + template + std::optional SubstituteUnary(const T& dim_expr) { + const auto& [operand] = *dim_expr; + const auto& substituted_operand = Substitute(operand); + if (!substituted_operand.has_value()) return std::nullopt; + return T{substituted_operand.value()}; + } + + std::optional SubstituteImpl(const Add& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Mul& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Max& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Min& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Broadcast& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + template + std::optional SubstituteVariadic(const T& dim_expr) { + const auto& [operands] = *dim_expr; + List substituted_operands{}; + size_t replace_cnt = 0; + for (const auto& operand : operands) { + const auto& substituted_operand = Substitute(operand); + replace_cnt += substituted_operand.has_value(); + substituted_operands.push_back( + substituted_operand.has_value() ? substituted_operand.value() : operand); + } + if (replace_cnt == 0) return std::nullopt; + return T{substituted_operands}; + } + + std::unordered_map pattern_to_replacement_; +}; + +} // namespace + +symbol::DimExpr SubstituteDimExpr( + const symbol::DimExpr& dim_expr, + const std::unordered_map& pattern_to_replacement) { + const auto& opt_replaced = SubstituteDimExprHelper(pattern_to_replacement).Substitute(dim_expr); + return opt_replaced.has_value() ? opt_replaced.value() : dim_expr; +} + +} // namespace cinn::common \ No newline at end of file diff --git a/paddle/cinn/common/dim_expr_util.h b/paddle/cinn/common/dim_expr_util.h new file mode 100644 index 0000000000000..58cb95a1ff02f --- /dev/null +++ b/paddle/cinn/common/dim_expr_util.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "paddle/pir/core/builder.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" + + +namespace cinn::common { + +symbol::DimExpr SubstituteDimExpr( + const symbol::DimExpr& dim_expr, + const std::unordered_map& pattern_to_replacement); + +} \ No newline at end of file From e04a525d51f6ca563b1ec2a7c8325039c12bb9f9 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 2 Jan 2024 02:18:04 +0000 Subject: [PATCH 2/7] Fix compile error --- paddle/cinn/common/CMakeLists.txt | 4 +- paddle/cinn/common/dim_expr_util.cc | 43 +++++++++++----- paddle/cinn/common/dim_expr_util_test.cc | 62 ++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 14 deletions(-) create mode 100644 paddle/cinn/common/dim_expr_util_test.cc diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt index b71055169945c..ff024385d3479 100644 --- a/paddle/cinn/common/CMakeLists.txt +++ b/paddle/cinn/common/CMakeLists.txt @@ -23,7 +23,8 @@ gather_srcs( nvgpu_dev_info.cc integer_set.cc dim_expr_simplify.cc - dim_expr_converter.cc) + dim_expr_converter.cc + dim_expr_util.cc) cinn_cc_test(test_equation_graph_topo_walker SRCS equation_graph_topo_walker_test.cc DEPS gtest glog) @@ -48,6 +49,7 @@ if(WITH_CUDA) gtest glog) endif() if(NOT CINN_ONLY) + cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore) cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS cinncore) cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS diff --git a/paddle/cinn/common/dim_expr_util.cc b/paddle/cinn/common/dim_expr_util.cc index 59e03dcf3b51d..0d0a9090429a0 100644 --- a/paddle/cinn/common/dim_expr_util.cc +++ b/paddle/cinn/common/dim_expr_util.cc @@ -1,22 +1,36 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/cinn/common/dim_expr_util.h" namespace cinn::common { -using namespace symbol; +using namespace symbol; // NOLINT namespace { class SubstituteDimExprHelper final { public: explicit SubstituteDimExprHelper( - const std::unordered_map& pattern_to_replacement) - : pattern_to_replacement_(pattern_to_replacement) {} + const std::unordered_map& + pattern_to_replacement) + : pattern_to_replacement_(pattern_to_replacement) {} std::optional Substitute(const DimExpr& dim_expr) { auto iter = pattern_to_replacement_.find(dim_expr); if (iter != pattern_to_replacement_.end()) return iter->second; - return std::visit([&](const auto& impl) { - return SubstituteImpl(impl); - }, dim_expr.variant()); + return std::visit([&](const auto& impl) { return SubstituteImpl(impl); }, + dim_expr.variant()); } private: @@ -38,7 +52,7 @@ class SubstituteDimExprHelper final { template std::optional SubstituteUnary(const T& dim_expr) { - const auto& [operand] = *dim_expr; + const auto& operand = dim_expr->data; const auto& substituted_operand = Substitute(operand); if (!substituted_operand.has_value()) return std::nullopt; return T{substituted_operand.value()}; @@ -66,14 +80,15 @@ class SubstituteDimExprHelper final { template std::optional SubstituteVariadic(const T& dim_expr) { - const auto& [operands] = *dim_expr; + const auto& operands = *(dim_expr.operands); List substituted_operands{}; size_t replace_cnt = 0; for (const auto& operand : operands) { const auto& substituted_operand = Substitute(operand); replace_cnt += substituted_operand.has_value(); - substituted_operands.push_back( - substituted_operand.has_value() ? substituted_operand.value() : operand); + substituted_operands->push_back(substituted_operand.has_value() + ? substituted_operand.value() + : operand); } if (replace_cnt == 0) return std::nullopt; return T{substituted_operands}; @@ -86,9 +101,11 @@ class SubstituteDimExprHelper final { symbol::DimExpr SubstituteDimExpr( const symbol::DimExpr& dim_expr, - const std::unordered_map& pattern_to_replacement) { - const auto& opt_replaced = SubstituteDimExprHelper(pattern_to_replacement).Substitute(dim_expr); + const std::unordered_map& + pattern_to_replacement) { + const auto& opt_replaced = + SubstituteDimExprHelper(pattern_to_replacement).Substitute(dim_expr); return opt_replaced.has_value() ? opt_replaced.value() : dim_expr; } -} // namespace cinn::common \ No newline at end of file +} // namespace cinn::common diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc new file mode 100644 index 0000000000000..ef6422d63420f --- /dev/null +++ b/paddle/cinn/common/dim_expr_util_test.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/dim_expr_util.h" + +#include "gtest/gtest.h" + +namespace cinn::dialect { +using namespace symbol; // NOLINT + +namespace { +DimExpr CreateExampleDimExpr() { + DimExprBuilder dim_expr_builder{nullptr}; + DimExpr sym0 = DimExpr("S0"); + DimExpr sym1 = DimExpr("S1"); + DimExpr constant = DimExpr(2); + DimExpr expr1 = (sym0 - sym1) * constant / sym0; + DimExpr expr2 = dim_expr_builder.Max(expr1, sym0); + DimExpr output = dim_expr_builder.Min(expr2, sym1); + return output; +} +} // namespace + +TEST(DimExprUtil, Substitute) { + DimExpr dim_expr = CreateExampleDimExpr(); + const auto& opt_expr = SubstituteDimExpr( + dim_expr, [](const DimExpr& expr) -> std::optional { + if (expr == DimExpr("S0")) { + return DimExpr("symbol0"); + } else if (expr == DimExpr("S1")) { + return DimExpr("symbol1"); + } else { + return std::nullopt; + } + }); + ASSERT_TRUE(opt_expr.has_value()); + const auto& ret_expr = SubstituteDimExpr( + opt_expr.value(), [](const DimExpr& expr) -> std::optional { + if (expr == DimExpr("symbol0")) { + return DimExpr("S0"); + } else if (expr == DimExpr("symbol1")) { + return DimExpr("S1"); + } else { + return std::nullopt; + } + }); + ASSERT_TRUE(ret_expr.has_value()); + ASSERT_EQ(ret_expr.value(), dim_expr); +} + +} // namespace cinn::dialect From 82fd504f5d0599f2a7cd2533c70ffb9dc9a18a81 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 2 Jan 2024 03:01:00 +0000 Subject: [PATCH 3/7] Code format --- paddle/cinn/common/dim_expr_util.h | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/paddle/cinn/common/dim_expr_util.h b/paddle/cinn/common/dim_expr_util.h index 58cb95a1ff02f..163aeb226ab0d 100644 --- a/paddle/cinn/common/dim_expr_util.h +++ b/paddle/cinn/common/dim_expr_util.h @@ -1,15 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" - namespace cinn::common { symbol::DimExpr SubstituteDimExpr( const symbol::DimExpr& dim_expr, - const std::unordered_map& pattern_to_replacement); + const std::unordered_map& + pattern_to_replacement); -} \ No newline at end of file +} From f17e788d2a99f8b2f887b12475147f96331202bd Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 2 Jan 2024 03:25:50 +0000 Subject: [PATCH 4/7] Polish DimExprUtilTest --- paddle/cinn/common/dim_expr_util_test.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc index ef6422d63420f..deb9d339b0ad1 100644 --- a/paddle/cinn/common/dim_expr_util_test.cc +++ b/paddle/cinn/common/dim_expr_util_test.cc @@ -21,14 +21,10 @@ using namespace symbol; // NOLINT namespace { DimExpr CreateExampleDimExpr() { - DimExprBuilder dim_expr_builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); DimExpr constant = DimExpr(2); - DimExpr expr1 = (sym0 - sym1) * constant / sym0; - DimExpr expr2 = dim_expr_builder.Max(expr1, sym0); - DimExpr output = dim_expr_builder.Min(expr2, sym1); - return output; + return (sym0 - sym1) * constant / sym0; } } // namespace From fcaf2e07a6e003f9062dab2fcc159576a2d64b16 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 2 Jan 2024 04:37:15 +0000 Subject: [PATCH 5/7] Change namesapce --- paddle/cinn/common/dim_expr_util_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc index deb9d339b0ad1..340618294c27f 100644 --- a/paddle/cinn/common/dim_expr_util_test.cc +++ b/paddle/cinn/common/dim_expr_util_test.cc @@ -16,7 +16,7 @@ #include "gtest/gtest.h" -namespace cinn::dialect { +namespace cinn::common { using namespace symbol; // NOLINT namespace { @@ -55,4 +55,4 @@ TEST(DimExprUtil, Substitute) { ASSERT_EQ(ret_expr.value(), dim_expr); } -} // namespace cinn::dialect +} // namespace cinn::common From c14a8710c17ef1da00d22b9c52c44cfc30227bdf Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 2 Jan 2024 05:20:56 +0000 Subject: [PATCH 6/7] Fix unittest --- paddle/cinn/common/dim_expr_util_test.cc | 28 +++++++----------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc index 340618294c27f..937c38d6ece52 100644 --- a/paddle/cinn/common/dim_expr_util_test.cc +++ b/paddle/cinn/common/dim_expr_util_test.cc @@ -30,27 +30,15 @@ DimExpr CreateExampleDimExpr() { TEST(DimExprUtil, Substitute) { DimExpr dim_expr = CreateExampleDimExpr(); - const auto& opt_expr = SubstituteDimExpr( - dim_expr, [](const DimExpr& expr) -> std::optional { - if (expr == DimExpr("S0")) { - return DimExpr("symbol0"); - } else if (expr == DimExpr("S1")) { - return DimExpr("symbol1"); - } else { - return std::nullopt; - } - }); + std::unordered_map naive_to_full_name{ + {DimExpr("S0"), DimExpr("symbol0")}, {DimExpr("S1"), DimExpr("symbol1")}}; + std::unordered_map full_name_to_naive{ + {DimExpr("symbol0"), DimExpr("S0")}, {DimExpr("symbol1"), DimExpr("S1")}}; + + const auto& opt_expr = SubstituteDimExpr(dim_expr, naive_to_full_name); ASSERT_TRUE(opt_expr.has_value()); - const auto& ret_expr = SubstituteDimExpr( - opt_expr.value(), [](const DimExpr& expr) -> std::optional { - if (expr == DimExpr("symbol0")) { - return DimExpr("S0"); - } else if (expr == DimExpr("symbol1")) { - return DimExpr("S1"); - } else { - return std::nullopt; - } - }); + const auto& ret_expr = + SubstituteDimExpr(opt_expr.value(), full_name_to_naive); ASSERT_TRUE(ret_expr.has_value()); ASSERT_EQ(ret_expr.value(), dim_expr); } From 501b03834208f123bf5eb9be222b78eba269599e Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 2 Jan 2024 06:39:40 +0000 Subject: [PATCH 7/7] Polish DimExprUtilTest --- paddle/cinn/common/dim_expr_util_test.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc index 937c38d6ece52..82b300fc5bfe2 100644 --- a/paddle/cinn/common/dim_expr_util_test.cc +++ b/paddle/cinn/common/dim_expr_util_test.cc @@ -35,12 +35,9 @@ TEST(DimExprUtil, Substitute) { std::unordered_map full_name_to_naive{ {DimExpr("symbol0"), DimExpr("S0")}, {DimExpr("symbol1"), DimExpr("S1")}}; - const auto& opt_expr = SubstituteDimExpr(dim_expr, naive_to_full_name); - ASSERT_TRUE(opt_expr.has_value()); - const auto& ret_expr = - SubstituteDimExpr(opt_expr.value(), full_name_to_naive); - ASSERT_TRUE(ret_expr.has_value()); - ASSERT_EQ(ret_expr.value(), dim_expr); + const auto& mid_expr = SubstituteDimExpr(dim_expr, naive_to_full_name); + const auto& ret_expr = SubstituteDimExpr(mid_expr, full_name_to_naive); + ASSERT_EQ(ret_expr, dim_expr); } } // namespace cinn::common