Merge pull request #47 from tc20042008/xk-cinn-trivalop-fuse
implement group_pattern_util.InferShardableAxes
tc20042008 authored Mar 10, 2024
2 parents c947ada + de23d96 commit 604afab
Showing 1 changed file with 137 additions and 109 deletions.
246 changes: 137 additions & 109 deletions paddle/cinn/frontend/group_pattern_util.cc
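In short, this change hoists the shardable-axes helpers (GetRank, MakeShardableAxesSignature4Op and friends, ReversedInferShardableAxes) out of StmtFusionHelper into file scope, and adds the public InferShardableAxes entry point. Given the ops of a fusion group with a single sink, it returns, for every pir::Value reached while walking backwards from that sink, the axes along which the value can be sharded. A minimal caller sketch (hypothetical: CollectFusionGroupOps is a stand-in for however the caller gathers the group's ops, and the enclosing namespace is assumed):

    // Sketch only; not part of this diff.
    std::vector<const pir::Operation*> ops = CollectFusionGroupOps(fusion_op);
    const std::unordered_map<pir::Value, ShardableAxes> value2sa = InferShardableAxes(ops);
    for (const auto& [value, sa] : value2sa) {
      // `sa` lists the axes along which `value` may be sharded.
    }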
@@ -157,6 +157,124 @@ std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
};
}

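// Rank of `value`, assuming its type is a pir::DenseTensorType (the dyn_cast
// is not checked).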
size_t GetRank(pir::Value value) {
return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
}

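// An element-wise op preserves layout, so every operand can be sharded exactly
// like the single result: first compute the common rank (all operand and
// result ranks must agree), then hand each input a copy of the fully
// shardable output axes.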
ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(const pir::Operation* op) {
CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported. TODO(wuzhanfei).";
const size_t rank = [&]{
std::optional<size_t> rank;
for (int i = 0; i < op->num_operands(); ++i) {
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->operand_source(i)));
} else {
rank = GetRank(op->operand_source(i));
}
}
CHECK_EQ(op->num_results(), 1);
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->result(0)));
} else {
rank = GetRank(op->result(0));
}
CHECK(rank.has_value());
return rank.value();
}();
const ShardableAxes output_shardable_axes = ShardableAxesUtil::GetFullyShardableAxes(rank);
std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
for (int i = 0; i < op->num_operands(); ++i) {
input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
}
return ShardableAxesSignature{
.output_shardable_axes=output_shardable_axes,
.input_shardable_axes=input_shardable_axes,
};
}
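// Illustrative example (not from this diff): for a rank-2 element-wise
// `c = a + b`, both inputs {op, 0} and {op, 1} receive the same ShardableAxes
// as `c`, i.e. both axis 0 and axis 1 stay shardable.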

ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(const pir::Operation* op) {
LOG(FATAL) << "TODO(wuzhanfei).";
}

ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
if (kind == hlir::framework::kElementWise) {
return MakeShardableAxesSignature4ElementWiseOp(op);
} else if (kind == hlir::framework::kBroadcast) {
return MakeShardableAxesSignature4BroadcastOp(op);
} else {
    LOG(FATAL) << "only kElementWise and kBroadcast are supported here. op_name:" << op->name();
}
LOG(FATAL) << "Dead code";
}

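// Core of the backward inference: seed the sink's result with `init_sa`, then
// visit ops in reverse topological order. For each op, rename the axis names
// in its signature's output to the names already recorded for its result
// (GetOldName2NewName), apply the same renaming to every input's axes, and
// merge each input's axes into the map, keeping only the common shardable
// axes when a value is reached along more than one path.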
std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
common::TopoWalker<const pir::Operation*>& reversed_walker,
const pir::Operation* sink,
const ShardableAxes& init_sa) {
std::unordered_map<pir::Value, ShardableAxes> value2shardable_axes{
{sink->result(0), init_sa}
};
const auto& UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) {
auto iter = value2shardable_axes.find(value);
if (iter != value2shardable_axes.end()) {
iter->second = ShardableAxesUtil::GetCommonShardableAxes(iter->second, sa);
} else {
      value2shardable_axes[value] = sa;  // 'iter' is end() here; insert a new entry rather than dereferencing it
}
};
reversed_walker(sink, [&](const auto* op){
auto shardable_axes_sig = MakeShardableAxesSignature4Op(op);
const auto& old2new = ShardableAxesUtil::GetOldName2NewName(shardable_axes_sig.output_shardable_axes,
value2shardable_axes.at(op->result(0)));
for (auto& pair : shardable_axes_sig.input_shardable_axes) {
const auto& [my_op, input_idx] = pair.first;
CHECK_EQ(my_op, op);
auto* input_shardable_axes = &pair.second;
ShardableAxesUtil::UpdateShardableAxes(old2new, input_shardable_axes);
pir::Value input_value = op->operand_source(input_idx);
UpdateValue2ShardableAxes(input_value, *input_shardable_axes);
}
});
return value2shardable_axes;
}

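// Builds a walker restricted to `ops`; neighbors outside the set are skipped.
// The walker is constructed reversed (downstream ops act as predecessors,
// upstream ops as successors), so walking forward from a sink moves towards
// producers.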
common::TopoWalker<const pir::Operation*> GetOpsTopoWalker(const std::vector<const pir::Operation*>& ops) {
using Cache = std::unordered_set<const pir::Operation*>;
auto ops_set = std::make_shared<Cache>(ops.begin(), ops.end());
const auto VisitUpStreamInOps = [ops_set](const pir::Operation* op, const OpVisitor& DoEach) {
VisitInputOp(op, [&](const auto* input){
if (ops_set->count(input) == 0) return;
DoEach(input);
});
};
const auto VisitDownStreamInOps = [ops_set](const pir::Operation* op, const OpVisitor& DoEach) {
VisitOutputOp(op, [&](const auto* output){
if (ops_set->count(output) == 0) return;
DoEach(output);
});
};
common::TopoWalker<const pir::Operation*> reversed_walker(VisitDownStreamInOps, VisitUpStreamInOps);
return reversed_walker;
}

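// Ops with no predecessors under `topo_walker`. Applied to the reversed
// walker above, "no predecessors" means no downstream consumers inside the op
// set, so the returned list holds the group's sink ops.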
std::list<const pir::Operation*> GetStarts(
const common::TopoWalker<const pir::Operation*>& topo_walker,
const std::vector<const pir::Operation*>& ops) {
const auto IsStart = [&](const pir::Operation* op) {
size_t num_prevs = 0;
topo_walker.VisitPrevNodes(op, [&](const auto*){ ++num_prevs; });
return num_prevs == 0;
};
std::list<const pir::Operation*> starts;
for (const auto* op : ops) {
if (IsStart(op)) {
starts.push_back(op);
}
}
return starts;
}

class StmtFusionHelper {
public:
explicit StmtFusionHelper(const cinn::dialect::FusionOp& fusion_op)
@@ -409,57 +527,6 @@ class StmtFusionHelper {
return std::nullopt;
}

size_t GetRank(pir::Value value) const {
return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
};

ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) const {
const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
if (kind == hlir::framework::kElementWise) {
return MakeShardableAxesSignature4ElementWiseOp(op);
} else if (kind == hlir::framework::kBroadcast) {
return MakeShardableAxesSignature4BroadcastOp(op);
} else {
LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name();
}
LOG(FATAL) << "Dead code";
}

ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(const pir::Operation* op) const {
CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported. TODO(wuzhanfei).";
const size_t rank = [&]{
std::optional<size_t> rank;
for (int i = 0; i < op->num_operands(); ++i) {
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->operand_source(i)));
} else {
rank = GetRank(op->operand_source(i));
}
}
CHECK_EQ(op->num_results(), 1);
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->result(0)));
} else {
rank = GetRank(op->result(0));
}
CHECK(rank.has_value());
return rank.value();
}();
const ShardableAxes output_shardable_axes = ShardableAxesUtil::GetFullyShardableAxes(rank);
std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
for (int i = 0; i < op->num_operands(); ++i) {
input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
}
return ShardableAxesSignature{
.output_shardable_axes=output_shardable_axes,
.input_shardable_axes=input_shardable_axes,
};
}

ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(const pir::Operation* op) const {
LOG(FATAL) << "TODO(wuzhanfei).";
}

struct StmtIterPair {
std::list<StmtPtr>::iterator upstream_iter;
std::list<StmtPtr>::iterator downstream_iter;
@@ -550,36 +617,13 @@ class StmtFusionHelper {
 
   ShardableAxesSignature GetShardableAxesSignature(const std::vector<const pir::Operation*>& ops) const {
     std::unordered_set<const pir::Operation*> ops_set(ops.begin(), ops.end());
-    const auto VisitUpStreamInOps = [&](const pir::Operation* op, const OpVisitor& DoEach) {
-      VisitInputOp(op, [&](const auto* input){
-        if (ops_set.count(input) == 0) return;
-        DoEach(input);
-      });
-    };
-    const auto VisitDownStreamInOps = [&](const pir::Operation* op, const OpVisitor& DoEach) {
-      VisitOutputOp(op, [&](const auto* output){
-        if (ops_set.count(output) == 0) return;
-        DoEach(output);
-      });
-    };
-    const auto IsSinkOp = [&](const pir::Operation* op) {
-      size_t num_downstreams = 0;
-      VisitDownStreamInOps(op, [&](const auto*){ ++num_downstreams; });
-      return num_downstreams == 0;
-    };
+    auto reversed_walker = GetOpsTopoWalker(ops);
     const pir::Operation* sink = [&]{
-      std::optional<const pir::Operation*> sink;
-      for (const auto* op : ops) {
-        if (IsSinkOp(op)) {
-          CHECK(!sink.has_value()) << "only one sink node.";
-        }
-        sink = op;
-      }
-      CHECK(sink.has_value());
-      return sink.value();
+      const auto& sinks = GetStarts(reversed_walker, ops);
+      CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node.";
+      return *sinks.begin();
     }();
     const auto& value2shardable_axes = [&]{
-      common::TopoWalker<const pir::Operation*> reversed_walker(VisitDownStreamInOps, VisitUpStreamInOps);
       size_t rank = GetRank(sink->result(0));
       const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
       return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
@@ -611,37 +655,6 @@ class StmtFusionHelper {
return shardable_axes_sig;
}

std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
common::TopoWalker<const pir::Operation*>& reversed_walker,
const pir::Operation* sink,
const ShardableAxes& init_sa) const {
std::unordered_map<pir::Value, ShardableAxes> value2shardable_axes{
{sink->result(0), init_sa}
};
const auto& UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) {
auto iter = value2shardable_axes.find(value);
if (iter != value2shardable_axes.end()) {
iter->second = ShardableAxesUtil::GetCommonShardableAxes(iter->second, sa);
} else {
iter->second = sa;
}
};
reversed_walker(sink, [&](const auto* op){
auto shardable_axes_sig = MakeShardableAxesSignature4Op(op);
const auto& old2new = ShardableAxesUtil::GetOldName2NewName(shardable_axes_sig.output_shardable_axes,
value2shardable_axes.at(op->result(0)));
for (auto& pair : shardable_axes_sig.input_shardable_axes) {
const auto& [my_op, input_idx] = pair.first;
CHECK_EQ(my_op, op);
auto* input_shardable_axes = &pair.second;
ShardableAxesUtil::UpdateShardableAxes(old2new, input_shardable_axes);
pir::Value input_value = op->operand_source(input_idx);
UpdateValue2ShardableAxes(input_value, *input_shardable_axes);
}
});
return value2shardable_axes;
}

private:
cinn::dialect::FusionOp fusion_op_;
std::function<bool(const pir::Operation*)> IsInThisFusionOp;
@@ -665,4 +678,19 @@ GroupPattern GenerateGroupPatternFromFusionOp(const cinn::dialect::FusionOp& fus
return FuseToGroupPattern(fusion_op);
}

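// Public entry point (mirrors GetShardableAxesSignature above): requires the
// op set to have exactly one sink, seeds that sink's first result as fully
// shardable, and propagates the axes backwards through the group.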
std::unordered_map<pir::Value, ShardableAxes> InferShardableAxes(const std::vector<const pir::Operation*>& ops) {
auto reversed_walker = GetOpsTopoWalker(ops);
const pir::Operation* sink = [&]{
const auto& sinks = GetStarts(reversed_walker, ops);
CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node.";
return *sinks.begin();
}();
const auto& value2shardable_axes = [&]{
size_t rank = GetRank(sink->result(0));
const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
}();
return value2shardable_axes;
}

}
