From cff8bb6b9db3720a79dfc1fa5fa69a2559dda662 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Sun, 10 Mar 2024 07:16:25 +0000 Subject: [PATCH 1/2] declare group_pattern.InferShardableAxes --- paddle/cinn/frontend/group_pattern.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index ea69cc1db06ca0..4b23ef86313616 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -143,4 +143,6 @@ namespace cinn::frontend { using ErrorGroupPattern = api::ErrorPattern; using GroupPattern = api::OpTopoPattern; +std::unordered_map InferShardableAxes(const cinn::pir::FusionOp& fusion_op); + } \ No newline at end of file From 8e74d2e38b760d06688f8c098f4461c75c05db15 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Sun, 10 Mar 2024 07:20:38 +0000 Subject: [PATCH 2/2] refine signature of group_pattern.InferShardableAxes --- paddle/cinn/frontend/group_pattern.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index 4b23ef86313616..9c9d7d4c638d84 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -143,6 +143,6 @@ namespace cinn::frontend { using ErrorGroupPattern = api::ErrorPattern; using GroupPattern = api::OpTopoPattern; -std::unordered_map InferShardableAxes(const cinn::pir::FusionOp& fusion_op); +std::unordered_map InferShardableAxes(const std::vector& ops); } \ No newline at end of file