From cfeafa621f921d966ea9458b84dcf5cc9964e41b Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 19 Dec 2023 09:31:38 +0000 Subject: [PATCH 1/8] DimExpr support print --- paddle/pir/dialect/shape/utils/dim_expr.cc | 60 +++++++++++++++++++ paddle/pir/dialect/shape/utils/dim_expr.h | 5 ++ .../pir/shape_dialect/symbol_dim_expr_test.cc | 14 +++++ 3 files changed, 79 insertions(+) diff --git a/paddle/pir/dialect/shape/utils/dim_expr.cc b/paddle/pir/dialect/shape/utils/dim_expr.cc index 9c46a8841c1e1..5e4295a11af16 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/dialect/shape/utils/dim_expr.cc @@ -124,4 +124,64 @@ bool DimExpr::operator!=(const DimExpr& other) const { return !(*this == other); } +namespace { + +std::string ToTxtStringImpl(std::int64_t dim_expr) { + return std::to_string(dim_expr); +} + +std::string ToTxtStringImpl(const std::string& dim_expr) { return dim_expr; } + +std::string ToTxtStringImpl(const Negative& dim_expr) { + return "-" + ToTxtString(dim_expr->data); +} + +std::string ToTxtStringImpl(const Reciprocal& dim_expr) { + return "1 / (" + ToTxtString(dim_expr->data) + ")"; +} + +std::string ListDimExprToTxtString(const List& dim_exprs, + const std::string& delim = ", ") { + std::string ret; + for (std::size_t i = 0; i < dim_exprs->size(); ++i) { + if (i > 0) { + ret += delim; + } + ret += ToTxtString(dim_exprs->at(i)); + } + return ret; +} + +std::string ToTxtStringImpl(const Add& dim_expr) { + return "Add(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +} + +std::string ToTxtStringImpl(const Mul& dim_expr) { + return "Mul(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +} + +std::string ToTxtStringImpl(const Max& dim_expr) { + return "Max(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +} + +std::string ToTxtStringImpl(const Min& dim_expr) { + return "Min(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +} + +std::string ToTxtStringImpl(const Broadcast& dim_expr) { + return "Broadcast(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +} + +} // namespace + +std::string ToTxtString(const DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return ToTxtStringImpl(impl); }, + dim_expr.variant()); +} + +std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) { + stream << ToTxtString(dim_expr); + return stream; +} + } // namespace symbol diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 7b9b7f81c4ad5..829013c145a42 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -239,4 +240,8 @@ class ValueShape { using ValueShapeDimExprs = ValueShape; +std::string ToTxtString(const DimExpr& dim_expr); + +std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); + } // namespace symbol diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index bbceda3906de1..1298608316fa3 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -92,4 +92,18 @@ TEST(DimExpr, equal) { builder.Broadcast(DimExpr("S0"), constant1)); } +TEST(DimExpr, print) { + imExprBuilder builder{nullptr}; + DimExpr sym0 = DimExpr("S0"); + DimExpr sym1 = DimExpr("S1"); + DimExpr constant1 = DimExpr(1); + ASSERT_EQ(sym0 + sym1 + constant1, "Add(S0, S1, 1)"); + ASSERT_EQ(sym0 - sym1 + constant1, "Add(S0, -S1, 1)"); + ASSERT_EQ(sym0 * sym1, "Mul(S0, S1)"); + ASSERT_EQ(sym0 / sym1, "Mul(S0, 1 / (S1))"); + ASSERT_EQ(builder.Max(sym0, sym1), "Max(S0, S1)"); + ASSERT_EQ(builder.Min(sym0, sym1), "Min(S0, S1)"); + ASSERT_EQ(builder.Broadcast(sym0, sym1), "Broadcast(S0, S1)"); +} + } // namespace symbol::test From e0f37143903eef901af3fb6e4f4d35ee7ca20e66 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 19 Dec 2023 13:02:20 +0000 Subject: [PATCH 2/8] ToTxtString --- test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 1298608316fa3..ed6e8716ad2ef 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -97,13 +97,13 @@ TEST(DimExpr, print) { DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); DimExpr constant1 = DimExpr(1); - ASSERT_EQ(sym0 + sym1 + constant1, "Add(S0, S1, 1)"); - ASSERT_EQ(sym0 - sym1 + constant1, "Add(S0, -S1, 1)"); - ASSERT_EQ(sym0 * sym1, "Mul(S0, S1)"); - ASSERT_EQ(sym0 / sym1, "Mul(S0, 1 / (S1))"); - ASSERT_EQ(builder.Max(sym0, sym1), "Max(S0, S1)"); - ASSERT_EQ(builder.Min(sym0, sym1), "Min(S0, S1)"); - ASSERT_EQ(builder.Broadcast(sym0, sym1), "Broadcast(S0, S1)"); + ASSERT_EQ(ToTxtString(sym0 + sym1 + constant1), "Add(S0, S1, 1)"); + ASSERT_EQ(ToTxtString(sym0 - sym1 + constant1), "Add(S0, -S1, 1)"); + ASSERT_EQ(ToTxtString(sym0 * sym1), "Mul(S0, S1)"); + ASSERT_EQ(ToTxtString(sym0 / sym1), "Mul(S0, 1 / (S1))"); + ASSERT_EQ(ToTxtString(builder.Max(sym0, sym1)), "Max(S0, S1)"); + ASSERT_EQ(ToTxtString(builder.Min(sym0, sym1)), "Min(S0, S1)"); + ASSERT_EQ(ToTxtString(builder.Broadcast(sym0, sym1)), "Broadcast(S0, S1)"); } } // namespace symbol::test From 7593413485fb43a5f18d2eb8d27f848a96a7ca55 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 19 Dec 2023 14:12:29 +0000 Subject: [PATCH 3/8] Fix ASSERT_EQ bug --- test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index ed6e8716ad2ef..8f6f3985d9f31 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -97,13 +97,13 @@ TEST(DimExpr, print) { DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); DimExpr constant1 = DimExpr(1); - ASSERT_EQ(ToTxtString(sym0 + sym1 + constant1), "Add(S0, S1, 1)"); - ASSERT_EQ(ToTxtString(sym0 - sym1 + constant1), "Add(S0, -S1, 1)"); - ASSERT_EQ(ToTxtString(sym0 * sym1), "Mul(S0, S1)"); - ASSERT_EQ(ToTxtString(sym0 / sym1), "Mul(S0, 1 / (S1))"); - ASSERT_EQ(ToTxtString(builder.Max(sym0, sym1)), "Max(S0, S1)"); - ASSERT_EQ(ToTxtString(builder.Min(sym0, sym1)), "Min(S0, S1)"); - ASSERT_EQ(ToTxtString(builder.Broadcast(sym0, sym1)), "Broadcast(S0, S1)"); + ASSERT_EQ((ToTxtString(sym0 + sym1 + constant1)), "Add(S0, S1, 1)"); + ASSERT_EQ((ToTxtString(sym0 - sym1 + constant1)), "Add(S0, -S1, 1)"); + ASSERT_EQ((ToTxtString(sym0 * sym1)), "Mul(S0, S1)"); + ASSERT_EQ((ToTxtString(sym0 / sym1)), "Mul(S0, 1 / (S1))"); + ASSERT_EQ((ToTxtString(builder.Max(sym0, sym1))), "Max(S0, S1)"); + ASSERT_EQ((ToTxtString(builder.Min(sym0, sym1))), "Min(S0, S1)"); + ASSERT_EQ((ToTxtString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)"); } } // namespace symbol::test From 114ea8db75c66893f5d4dfb29f5fe33bebbc1306 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 20 Dec 2023 02:39:40 +0000 Subject: [PATCH 4/8] Fix typo --- test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 8f6f3985d9f31..c95bf712c90ec 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -93,7 +93,7 @@ TEST(DimExpr, equal) { } TEST(DimExpr, print) { - imExprBuilder builder{nullptr}; + DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); DimExpr constant1 = DimExpr(1); From d2a04f6d3500ad4575f76375617a03ed95c45b5a Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Wed, 20 Dec 2023 08:58:06 +0000 Subject: [PATCH 5/8] Fix unittest --- test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index c95bf712c90ec..8ffc7fd3a5870 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -96,9 +96,8 @@ TEST(DimExpr, print) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); - DimExpr constant1 = DimExpr(1); - ASSERT_EQ((ToTxtString(sym0 + sym1 + constant1)), "Add(S0, S1, 1)"); - ASSERT_EQ((ToTxtString(sym0 - sym1 + constant1)), "Add(S0, -S1, 1)"); + ASSERT_EQ((ToTxtString(sym0 + sym1)), "Add(S0, S1)"); + ASSERT_EQ((ToTxtString(sym0 - sym1)), "Add(S0, -S1)"); ASSERT_EQ((ToTxtString(sym0 * sym1)), "Mul(S0, S1)"); ASSERT_EQ((ToTxtString(sym0 / sym1)), "Mul(S0, 1 / (S1))"); ASSERT_EQ((ToTxtString(builder.Max(sym0, sym1))), "Max(S0, S1)"); From 152f7d0aaa21db4f2ae72dd143e64272d899e67e Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Mon, 25 Dec 2023 09:15:06 +0000 Subject: [PATCH 6/8] ToTxtString->ToString --- paddle/pir/dialect/shape/utils/dim_expr.cc | 44 +++++++++---------- paddle/pir/dialect/shape/utils/dim_expr.h | 2 +- .../pir/shape_dialect/symbol_dim_expr_test.cc | 14 +++--- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/paddle/pir/dialect/shape/utils/dim_expr.cc b/paddle/pir/dialect/shape/utils/dim_expr.cc index 5e4295a11af16..0d9b6ece23245 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/dialect/shape/utils/dim_expr.cc @@ -126,61 +126,61 @@ bool DimExpr::operator!=(const DimExpr& other) const { namespace { -std::string ToTxtStringImpl(std::int64_t dim_expr) { +std::string ToStringImpl(std::int64_t dim_expr) { return std::to_string(dim_expr); } -std::string ToTxtStringImpl(const std::string& dim_expr) { return dim_expr; } +std::string ToStringImpl(const std::string& dim_expr) { return dim_expr; } -std::string ToTxtStringImpl(const Negative& dim_expr) { - return "-" + ToTxtString(dim_expr->data); +std::string ToStringImpl(const Negative& dim_expr) { + return "-" + ToString(dim_expr->data); } -std::string ToTxtStringImpl(const Reciprocal& dim_expr) { - return "1 / (" + ToTxtString(dim_expr->data) + ")"; +std::string ToStringImpl(const Reciprocal& dim_expr) { + return "1 / (" + ToString(dim_expr->data) + ")"; } -std::string ListDimExprToTxtString(const List& dim_exprs, - const std::string& delim = ", ") { +std::string ListDimExprToString(const List& dim_exprs, + const std::string& delim = ", ") { std::string ret; for (std::size_t i = 0; i < dim_exprs->size(); ++i) { if (i > 0) { ret += delim; } - ret += ToTxtString(dim_exprs->at(i)); + ret += ToString(dim_exprs->at(i)); } return ret; } -std::string ToTxtStringImpl(const Add& dim_expr) { - return "Add(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +std::string ToStringImpl(const Add& dim_expr) { + return "Add(" + ListDimExprToString(dim_expr.operands, ", ") + ")"; } -std::string ToTxtStringImpl(const Mul& dim_expr) { - return "Mul(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +std::string ToStringImpl(const Mul& dim_expr) { + return "Mul(" + ListDimExprToString(dim_expr.operands, ", ") + ")"; } -std::string ToTxtStringImpl(const Max& dim_expr) { - return "Max(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +std::string ToStringImpl(const Max& dim_expr) { + return "Max(" + ListDimExprToString(dim_expr.operands, ", ") + ")"; } -std::string ToTxtStringImpl(const Min& dim_expr) { - return "Min(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +std::string ToStringImpl(const Min& dim_expr) { + return "Min(" + ListDimExprToString(dim_expr.operands, ", ") + ")"; } -std::string ToTxtStringImpl(const Broadcast& dim_expr) { - return "Broadcast(" + ListDimExprToTxtString(dim_expr.operands, ", ") + ")"; +std::string ToStringImpl(const Broadcast& dim_expr) { + return "Broadcast(" + ListDimExprToString(dim_expr.operands, ", ") + ")"; } } // namespace -std::string ToTxtString(const DimExpr& dim_expr) { - return std::visit([](const auto& impl) { return ToTxtStringImpl(impl); }, +std::string ToString(const DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return ToStringImpl(impl); }, dim_expr.variant()); } std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) { - stream << ToTxtString(dim_expr); + stream << ToString(dim_expr); return stream; } diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 829013c145a42..da7293f897a06 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -240,7 +240,7 @@ class ValueShape { using ValueShapeDimExprs = ValueShape; -std::string ToTxtString(const DimExpr& dim_expr); +std::string ToString(const DimExpr& dim_expr); std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 8ffc7fd3a5870..550a26d508b76 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -96,13 +96,13 @@ TEST(DimExpr, print) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); - ASSERT_EQ((ToTxtString(sym0 + sym1)), "Add(S0, S1)"); - ASSERT_EQ((ToTxtString(sym0 - sym1)), "Add(S0, -S1)"); - ASSERT_EQ((ToTxtString(sym0 * sym1)), "Mul(S0, S1)"); - ASSERT_EQ((ToTxtString(sym0 / sym1)), "Mul(S0, 1 / (S1))"); - ASSERT_EQ((ToTxtString(builder.Max(sym0, sym1))), "Max(S0, S1)"); - ASSERT_EQ((ToTxtString(builder.Min(sym0, sym1))), "Min(S0, S1)"); - ASSERT_EQ((ToTxtString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)"); + ASSERT_EQ((ToString(sym0 + sym1)), "Add(S0, S1)"); + ASSERT_EQ((ToString(sym0 - sym1)), "Add(S0, -S1)"); + ASSERT_EQ((ToString(sym0 * sym1)), "Mul(S0, S1)"); + ASSERT_EQ((ToString(sym0 / sym1)), "Mul(S0, 1 / (S1))"); + ASSERT_EQ((ToString(builder.Max(sym0, sym1))), "Max(S0, S1)"); + ASSERT_EQ((ToString(builder.Min(sym0, sym1))), "Min(S0, S1)"); + ASSERT_EQ((ToString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)"); } } // namespace symbol::test From 515a70bc13dad7d57ff07f9def40efeb63fa8488 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 26 Dec 2023 02:36:24 +0000 Subject: [PATCH 7/8] Fix windows CI --- paddle/pir/dialect/shape/utils/dim_expr.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 697f19c83218d..53041078476e4 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -219,7 +219,7 @@ using DimExprConstraint = std::variant, Broadcastable>; // ShapeOrDataDimExprs = (tShape [DimExpr], tData (opt [DimExpr])) template -class ShapeOrData { +class IR_API ShapeOrData { public: explicit ShapeOrData(const std::vector& shape) : shape_(shape), data_(std::nullopt) {} @@ -249,8 +249,8 @@ class ShapeOrData { using ShapeOrDataDimExprs = ShapeOrData; -std::string ToString(const DimExpr& dim_expr); +IR_API std::string ToString(const DimExpr& dim_expr); -std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); +IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); } // namespace symbol From 69063ac0a3e9498be8240d6cec4eabeb78e739ff Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 26 Dec 2023 09:15:06 +0000 Subject: [PATCH 8/8] ShapeOrData cannot use IR_API --- paddle/pir/dialect/shape/utils/dim_expr.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 53041078476e4..277a6febe66ed 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -219,7 +219,7 @@ using DimExprConstraint = std::variant, Broadcastable>; // ShapeOrDataDimExprs = (tShape [DimExpr], tData (opt [DimExpr])) template -class IR_API ShapeOrData { +class ShapeOrData { public: explicit ShapeOrData(const std::vector& shape) : shape_(shape), data_(std::nullopt) {}