From b052007d88e12630dcb37cbdc526540f4ec36414 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Thu, 14 Mar 2024 11:58:47 +0000 Subject: [PATCH 1/2] implement group_pattern_util.MakeShardableAxesSignature4ReduceOp --- paddle/cinn/frontend/group_pattern.h | 23 +++++- paddle/cinn/frontend/group_pattern_util.cc | 82 ++++++++++++++++------ 2 files changed, 82 insertions(+), 23 deletions(-) diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index 255eab33894d6..3e8925d783c22 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -97,7 +97,7 @@ struct ShardableAxesUtil { return ret; } - static ShardableAxes GetFullyShardableAxes(size_t rank) { + static ShardableAxes GetFullyShardableAxes(const size_t rank) { ShardableAxes ret; for (int i = 0; i < rank; ++i) { ret.emplace_back(ShardableAxis{ @@ -107,6 +107,27 @@ struct ShardableAxesUtil { } return ret; } + + static ShardableAxes GetReduceOpInputShardableAxes( + const size_t input_rank, const std::vector& reduce_axes) { + if (reduce_axes.empty()) return ShardableAxes{}; + for (int64_t reduce_axis : reduce_axes) { + CHECK_GE(reduce_axis, 0); + CHECK_LT(reduce_axis, input_rank); + } + const auto IsReduceAxis = [&](int64_t i) { + return std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end(); + }; + ShardableAxes ret; + for (int64_t i = 0; i < input_rank; ++i) { + if (IsReduceAxis(i)) continue; + ret.emplace_back(ShardableAxis{ + .axis=i, + .axis_name=std::string("D") + std::to_string(ShardableAxis::UnqiueSeqNo()), + }); + } + return ret; + } }; struct SoleOutputShardableAxes { diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index 61638d01df64a..2cc88404eeb1a 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -206,6 +206,61 @@ size_t GetRank(pir::Value value) { return value.type().dyn_cast().dims().size(); } +std::vector GetReduceAxes(const pir::Operation* reduce_op) { + const size_t input_rank = GetRank(reduce_op->operand_source(0)); + const auto& attr_val = reduce_op->attributes().at("dim"); + CHECK(attr_val.isa<::pir::ArrayAttribute>()); + const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>(); + std::vector reduce_axes; + for (int i = 0; i < axis_attr.size(); ++i) { + int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data(); + if (axis < 0) { + axis += input_rank; + } + CHECK_GE(axis, 0); + CHECK_LT(axis, input_rank); + reduce_axes.push_back(axis); + } + return reduce_axes; +} + +bool GetReduceOpKeepDims(const pir::Operation* reduce_op) { + const auto& attr_val = reduce_op->attributes().at("keep_dim"); + CHECK(attr_val.isa<::pir::BoolAttribute>()); + return attr_val.dyn_cast<::pir::BoolAttribute>(); +} + +ShardableAxes SequeezeShardableAxes(const ShardableAxes& sa) { + ShardableAxes ret_sa(sa); + for (int i = 0; i < ret_sa.size(); ++i) { + for (int j = i + 1; j < ret_sa.size(); ++j) { + CHECK_LT(ret_sa.at(i).axis, ret_sa.at(j).axis); + } + ret_sa.at(i).axis = i; + } + return ret_sa; +} + +ShardableAxesSignature MakeShardableAxesSignature4ReduceOp( + const pir::Operation* reduce_op) { + const size_t input_rank = GetRank(reduce_op->operand_source(0)); + const auto& reduce_axes = GetReduceAxes(reduce_op); + const ShardableAxes input_sa = + ShardableAxesUtil::GetReduceOpInputShardableAxes(input_rank, reduce_axes); + using InputSignature = std::unordered_map; + ; + const ShardableAxes output_sa = + (GetReduceOpKeepDims(reduce_op) ? input_sa : SequeezeShardableAxes(input_sa)); + return ShardableAxesSignature{ + .sole_output_sa = SoleOutputShardableAxes{ + .shardable_axes=output_sa, + }, + .input_shardable_axes = InputSignature{ + {OpAndOperandIndex{reduce_op, 0}, input_sa}, + }, + }; +} + ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp( const pir::Operation* op) { CHECK(!op->isa()) @@ -249,7 +304,9 @@ ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp( ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) { const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); - if (kind == hlir::framework::kElementWise) { + if (kind == hlir::framework::kReduction) { + return MakeShardableAxesSignature4ReduceOp(op); + } else if (kind == hlir::framework::kElementWise) { return MakeShardableAxesSignature4ElementWiseOp(op); } else if (kind == hlir::framework::kBroadcast) { return MakeShardableAxesSignature4BroadcastOp(op); @@ -1718,22 +1775,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy { const pir::Operation* reduce_op, const ShardableAxes& shardable_axes) { const size_t input_rank = GetRank(reduce_op->operand_source(0)); - const auto& reduce_axes = [&]{ - const auto& attr_val = reduce_op->attributes().at("dim"); - CHECK(attr_val.isa<::pir::ArrayAttribute>()); - const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>(); - std::vector reduce_axes; - for (int i = 0; i < axis_attr.size(); ++i) { - int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data(); - if (axis < 0) { - axis += input_rank; - } - CHECK_GE(axis, 0); - CHECK_LT(axis, input_rank); - reduce_axes.push_back(axis); - } - return reduce_axes; - }(); + const auto& reduce_axes = GetReduceAxes(reduce_op); // no shardability if input reduced into one element. if (reduce_axes.empty()) return false; @@ -1747,11 +1789,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy { }; return std::find_if(shardable_axes.begin(), shardable_axes.end(), Condition) != shardable_axes.end(); }; - const bool keepdims = [&]{ - const auto& attr_val = reduce_op->attributes().at("keep_dim"); - CHECK(attr_val.isa<::pir::BoolAttribute>()); - return attr_val.dyn_cast<::pir::BoolAttribute>(); - }(); + const bool keepdims = GetReduceOpKeepDims(reduce_op); if (keepdims) { const size_t output_rank = input_rank; CHECK(!reduce_axes.empty()); From 27a647c3f195203f2ecd49bc7773471543c1d8f5 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Thu, 14 Mar 2024 12:23:51 +0000 Subject: [PATCH 2/2] implement group_pattern_util.MakeEmptyShardableAxesSignature --- paddle/cinn/frontend/group_pattern.h | 4 +- paddle/cinn/frontend/group_pattern_util.cc | 44 +++++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h index 3e8925d783c22..3e63f3626a2f1 100644 --- a/paddle/cinn/frontend/group_pattern.h +++ b/paddle/cinn/frontend/group_pattern.h @@ -97,7 +97,7 @@ struct ShardableAxesUtil { return ret; } - static ShardableAxes GetFullyShardableAxes(const size_t rank) { + static ShardableAxes MakeFullyShardableAxes(const size_t rank) { ShardableAxes ret; for (int i = 0; i < rank; ++i) { ret.emplace_back(ShardableAxis{ @@ -108,7 +108,7 @@ struct ShardableAxesUtil { return ret; } - static ShardableAxes GetReduceOpInputShardableAxes( + static ShardableAxes MakeReduceOpInputShardableAxes( const size_t input_rank, const std::vector& reduce_axes) { if (reduce_axes.empty()) return ShardableAxes{}; for (int64_t reduce_axis : reduce_axes) { diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index 2cc88404eeb1a..836ddb850e683 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -241,14 +241,30 @@ ShardableAxes SequeezeShardableAxes(const ShardableAxes& sa) { return ret_sa; } +ShardableAxesSignature MakeEmptyShardableAxesSignature(const pir::Operation* op) { + const int result_idx = GetOutputShardableAxesResultIdx(op); + pir::Value output = op->result(result_idx); + ShardableAxes output_sa = ShardableAxesUtil::MakeFullyShardableAxes(GetRank(output)); + using InputSignature = std::unordered_map; + InputSignature empty_input_sig; + for (int i = 0; i < op->num_operands(); ++i) { + empty_input_sig[OpAndOperandIndex{op, i}] = ShardableAxes{}; + } + return ShardableAxesSignature{ + .sole_output_sa = SoleOutputShardableAxes{ + .shardable_axes=output_sa, + }, + .input_shardable_axes = empty_input_sig, + }; +} + ShardableAxesSignature MakeShardableAxesSignature4ReduceOp( const pir::Operation* reduce_op) { const size_t input_rank = GetRank(reduce_op->operand_source(0)); const auto& reduce_axes = GetReduceAxes(reduce_op); const ShardableAxes input_sa = - ShardableAxesUtil::GetReduceOpInputShardableAxes(input_rank, reduce_axes); + ShardableAxesUtil::MakeReduceOpInputShardableAxes(input_rank, reduce_axes); using InputSignature = std::unordered_map; - ; const ShardableAxes output_sa = (GetReduceOpKeepDims(reduce_op) ? input_sa : SequeezeShardableAxes(input_sa)); return ShardableAxesSignature{ @@ -261,10 +277,17 @@ ShardableAxesSignature MakeShardableAxesSignature4ReduceOp( }; } +bool IsDisabledElementwiseOp(const pir::Operation* op) { + if (op->isa()) return true; + return false; +} + ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp( const pir::Operation* op) { - CHECK(!op->isa()) - << "reshape not supported. TODO(wuzhanfei)."; + if (IsDisabledElementwiseOp(op)) { + LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature found. op_name : " << op->name(); + return MakeEmptyShardableAxesSignature(op); + } const size_t rank = [&] { std::optional rank; for (int i = 0; i < op->num_operands(); ++i) { @@ -284,7 +307,7 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp( return rank.value(); }(); const ShardableAxes output_shardable_axes = - ShardableAxesUtil::GetFullyShardableAxes(rank); + ShardableAxesUtil::MakeFullyShardableAxes(rank); std::unordered_map input_shardable_axes; for (int i = 0; i < op->num_operands(); ++i) { input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes; @@ -299,7 +322,8 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp( ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp( const pir::Operation* op) { - LOG(FATAL) << "TODO(wuzhanfei)."; + LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature found. op_name : " << op->name(); + return MakeEmptyShardableAxesSignature(op); } ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) { @@ -311,11 +335,11 @@ ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) { } else if (kind == hlir::framework::kBroadcast) { return MakeShardableAxesSignature4BroadcastOp(op); } else { - LOG(FATAL) - << "only kReduction, kElementWise, kBroadcast supported. op_name:" + LOG(ERROR) + << "[ShardableAxesSignature] no shardable axes signature found. op_name:" << op->name(); } - LOG(FATAL) << "Dead code"; + return MakeEmptyShardableAxesSignature(op); } template @@ -555,7 +579,7 @@ std::unordered_map InferShardableAxesFromSink( CHECK_GT(op_topo.ops->count(sink), 0); const int result_idx = GetOutputShardableAxesResultIdx(sink); size_t rank = GetRank(sink->result(result_idx)); - const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank); + const auto& init_sa = ShardableAxesUtil::MakeFullyShardableAxes(rank); return ReversedInferShardableAxes(reversed_walker, sink, init_sa); }