Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement group_pattern_util.InferShardableAxes #47

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 137 additions & 109 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,124 @@ std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
};
}

// Returns the number of dimensions of `value`, read from its type viewed as a
// pir::DenseTensorType.
// NOTE(review): no check is made that the dyn_cast succeeded — presumably
// callers only pass dense-tensor values; confirm.
size_t GetRank(pir::Value value) {
  auto dense_tensor_type = value.type().dyn_cast<pir::DenseTensorType>();
  return dense_tensor_type.dims().size();
}

// Builds the shardable-axes signature of an element-wise op.
//
// All operands and the single result must have the same rank (enforced with
// CHECKs). The output axes come from
// ShardableAxesUtil::GetFullyShardableAxes(rank) and every operand is given a
// copy of those same axes. Reshape ops are rejected up front.
ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(const pir::Operation* op) {
  CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported. TODO(wuzhanfei).";

  // Rank shared by every operand and the result; set by the first value seen.
  std::optional<size_t> common_rank;
  const auto MergeRank = [&](pir::Value value) {
    if (!common_rank.has_value()) {
      common_rank = GetRank(value);
      return;
    }
    CHECK_EQ(common_rank.value(), GetRank(value));
  };
  for (int idx = 0; idx < op->num_operands(); ++idx) {
    MergeRank(op->operand_source(idx));
  }
  CHECK_EQ(op->num_results(), 1);
  MergeRank(op->result(0));
  CHECK(common_rank.has_value());
  const size_t rank = common_rank.value();

  const ShardableAxes out_axes = ShardableAxesUtil::GetFullyShardableAxes(rank);
  std::unordered_map<OpAndOperandIndex, ShardableAxes> in_axes;
  for (int idx = 0; idx < op->num_operands(); ++idx) {
    // Each operand shares the output's axes.
    in_axes[OpAndOperandIndex{op, idx}] = out_axes;
  }
  return ShardableAxesSignature{
      .output_shardable_axes = out_axes,
      .input_shardable_axes = in_axes,
  };
}

// Builds the shardable-axes signature of a broadcast op.
// Not implemented yet: the body unconditionally hits LOG(FATAL), so no
// signature is ever returned.
ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(const pir::Operation* op) {
LOG(FATAL) << "TODO(wuzhanfei).";
}

// Dispatches on the op's pattern kind to the matching signature builder.
//
// Only kElementWise and kBroadcast have a branch here; any other kind hits
// LOG(FATAL).
// NOTE(review): the fatal message also lists kReduction, but no branch
// handles it — confirm whether the message or the dispatch is out of date.
ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
  switch (GetOpPatternKind(op)) {
    case hlir::framework::kElementWise:
      return MakeShardableAxesSignature4ElementWiseOp(op);
    case hlir::framework::kBroadcast:
      return MakeShardableAxesSignature4BroadcastOp(op);
    default:
      LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name();
  }
  // Not expected to be reached; every path above returns or hits LOG(FATAL).
  LOG(FATAL) << "Dead code";
}

// Infers shardable axes for values reachable from `sink` via `reversed_walker`.
//
// Result 0 of `sink` is seeded with `init_sa`. For every op the walker visits,
// the op's own signature (MakeShardableAxesSignature4Op) is renamed against the
// axes already recorded for the op's result 0, and the renamed input axes are
// merged into the entry of the corresponding operand value.
//
// Params:
//   reversed_walker: walker invoked as reversed_walker(sink, callback).
//   sink:            op whose result 0 starts the inference.
//   init_sa:         axes assigned to sink->result(0).
// Returns: map from value to its inferred shardable axes.
// Note: `.at(op->result(0))` throws std::out_of_range if an op is visited
// before its result 0 has an entry.
std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
    common::TopoWalker<const pir::Operation*>& reversed_walker,
    const pir::Operation* sink,
    const ShardableAxes& init_sa) {
  std::unordered_map<pir::Value, ShardableAxes> value2shardable_axes{
      {sink->result(0), init_sa}
  };
  const auto UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) {
    auto iter = value2shardable_axes.find(value);
    if (iter == value2shardable_axes.end()) {
      // First time this value is seen: insert it. The previous code wrote
      // through the end() iterator here, which is undefined behaviour and
      // never actually recorded the value.
      value2shardable_axes.emplace(value, sa);
      return;
    }
    // Value already has axes (it feeds more than one visited op): keep only
    // what GetCommonShardableAxes reports for both.
    iter->second = ShardableAxesUtil::GetCommonShardableAxes(iter->second, sa);
  };
  reversed_walker(sink, [&](const auto* op) {
    auto shardable_axes_sig = MakeShardableAxesSignature4Op(op);
    const auto& old2new = ShardableAxesUtil::GetOldName2NewName(
        shardable_axes_sig.output_shardable_axes,
        value2shardable_axes.at(op->result(0)));
    for (auto& pair : shardable_axes_sig.input_shardable_axes) {
      const auto& [my_op, input_idx] = pair.first;
      CHECK_EQ(my_op, op);
      auto* input_shardable_axes = &pair.second;
      // Rewrite the operand's axes in place using the old->new name mapping.
      ShardableAxesUtil::UpdateShardableAxes(old2new, input_shardable_axes);
      pir::Value input_value = op->operand_source(input_idx);
      UpdateValue2ShardableAxes(input_value, *input_shardable_axes);
    }
  });
  return value2shardable_axes;
}

// Builds a TopoWalker restricted to the ops in `ops`.
//
// Two visitors are created: one forwards VisitInputOp results, the other
// VisitOutputOp results, and both drop any op that is not in `ops`. The walker
// is constructed with the output-op visitor first and the input-op visitor
// second. The op set is held in a shared_ptr captured by value, so the
// visitors do not refer to `ops` after this function returns.
common::TopoWalker<const pir::Operation*> GetOpsTopoWalker(const std::vector<const pir::Operation*>& ops) {
  using OpSet = std::unordered_set<const pir::Operation*>;
  auto allowed_ops = std::make_shared<OpSet>(ops.begin(), ops.end());

  const auto VisitUpStreamInOps = [allowed_ops](const pir::Operation* op, const OpVisitor& DoEach) {
    VisitInputOp(op, [&](const auto* producer) {
      if (allowed_ops->count(producer) != 0) {
        DoEach(producer);
      }
    });
  };
  const auto VisitDownStreamInOps = [allowed_ops](const pir::Operation* op, const OpVisitor& DoEach) {
    VisitOutputOp(op, [&](const auto* consumer) {
      if (allowed_ops->count(consumer) != 0) {
        DoEach(consumer);
      }
    });
  };

  common::TopoWalker<const pir::Operation*> walker(VisitDownStreamInOps, VisitUpStreamInOps);
  return walker;
}

// Returns, in the order they appear in `ops`, the ops for which
// `topo_walker.VisitPrevNodes` reports no previous node.
std::list<const pir::Operation*> GetStarts(
    const common::TopoWalker<const pir::Operation*>& topo_walker,
    const std::vector<const pir::Operation*>& ops) {
  std::list<const pir::Operation*> result;
  for (const auto* candidate : ops) {
    bool has_prev = false;
    topo_walker.VisitPrevNodes(candidate, [&](const auto*) { has_prev = true; });
    if (!has_prev) {
      result.push_back(candidate);
    }
  }
  return result;
}

class StmtFusionHelper {
public:
explicit StmtFusionHelper(const cinn::dialect::FusionOp& fusion_op)
Expand Down Expand Up @@ -409,57 +527,6 @@ class StmtFusionHelper {
return std::nullopt;
}

// Returns the number of dimensions of `value`, read from its type viewed as a
// pir::DenseTensorType.
// NOTE(review): no check is made that the dyn_cast succeeded — presumably
// callers only pass dense-tensor values; confirm.
size_t GetRank(pir::Value value) const {
  auto dense_tensor_type = value.type().dyn_cast<pir::DenseTensorType>();
  return dense_tensor_type.dims().size();
}

// Dispatches on the op's pattern kind to the matching signature builder.
//
// Only kElementWise and kBroadcast have a branch here; any other kind hits
// LOG(FATAL).
// NOTE(review): the fatal message also lists kReduction, but no branch
// handles it — confirm whether the message or the dispatch is out of date.
ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) const {
  switch (GetOpPatternKind(op)) {
    case hlir::framework::kElementWise:
      return MakeShardableAxesSignature4ElementWiseOp(op);
    case hlir::framework::kBroadcast:
      return MakeShardableAxesSignature4BroadcastOp(op);
    default:
      LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->name();
  }
  // Not expected to be reached; every path above returns or hits LOG(FATAL).
  LOG(FATAL) << "Dead code";
}

// Builds the shardable-axes signature of an element-wise op.
//
// All operands and the single result must have the same rank (enforced with
// CHECKs). The output axes come from
// ShardableAxesUtil::GetFullyShardableAxes(rank) and every operand is given a
// copy of those same axes. Reshape ops are rejected up front.
ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(const pir::Operation* op) const {
  CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported. TODO(wuzhanfei).";

  // Rank shared by every operand and the result; set by the first value seen.
  std::optional<size_t> common_rank;
  const auto MergeRank = [&](pir::Value value) {
    if (!common_rank.has_value()) {
      common_rank = GetRank(value);
      return;
    }
    CHECK_EQ(common_rank.value(), GetRank(value));
  };
  for (int idx = 0; idx < op->num_operands(); ++idx) {
    MergeRank(op->operand_source(idx));
  }
  CHECK_EQ(op->num_results(), 1);
  MergeRank(op->result(0));
  CHECK(common_rank.has_value());
  const size_t rank = common_rank.value();

  const ShardableAxes out_axes = ShardableAxesUtil::GetFullyShardableAxes(rank);
  std::unordered_map<OpAndOperandIndex, ShardableAxes> in_axes;
  for (int idx = 0; idx < op->num_operands(); ++idx) {
    // Each operand shares the output's axes.
    in_axes[OpAndOperandIndex{op, idx}] = out_axes;
  }
  return ShardableAxesSignature{
      .output_shardable_axes = out_axes,
      .input_shardable_axes = in_axes,
  };
}

// Builds the shardable-axes signature of a broadcast op.
// Not implemented yet: the body unconditionally hits LOG(FATAL), so no
// signature is ever returned.
ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(const pir::Operation* op) const {
LOG(FATAL) << "TODO(wuzhanfei).";
}

struct StmtIterPair {
std::list<StmtPtr>::iterator upstream_iter;
std::list<StmtPtr>::iterator downstream_iter;
Expand Down Expand Up @@ -550,36 +617,13 @@ class StmtFusionHelper {

ShardableAxesSignature GetShardableAxesSignature(const std::vector<const pir::Operation*>& ops) const {
std::unordered_set<const pir::Operation*> ops_set(ops.begin(), ops.end());
const auto VisitUpStreamInOps = [&](const pir::Operation* op, const OpVisitor& DoEach) {
VisitInputOp(op, [&](const auto* input){
if (ops_set.count(input) == 0) return;
DoEach(input);
});
};
const auto VisitDownStreamInOps = [&](const pir::Operation* op, const OpVisitor& DoEach) {
VisitOutputOp(op, [&](const auto* output){
if (ops_set.count(output) == 0) return;
DoEach(output);
});
};
const auto IsSinkOp = [&](const pir::Operation* op) {
size_t num_donwstreams = 0;
VisitDownStreamInOps(op, [&](const auto*){ ++num_donwstreams; });
return num_donwstreams == 0;
};
auto reversed_walker = GetOpsTopoWalker(ops);
const pir::Operation* sink = [&]{
std::optional<const pir::Operation*> sink;
for (const auto* op : ops) {
if (IsSinkOp(op)) {
CHECK(!sink.has_value()) << "only one sink node.";
}
sink = op;
}
CHECK(sink.has_value());
return sink.value();
const auto& sinks = GetStarts(reversed_walker, ops);
CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node.";
return *sinks.begin();
}();
const auto& value2shardable_axes = [&]{
common::TopoWalker<const pir::Operation*> reversed_walker(VisitDownStreamInOps, VisitUpStreamInOps);
size_t rank = GetRank(sink->result(0));
const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
Expand Down Expand Up @@ -611,37 +655,6 @@ class StmtFusionHelper {
return shardable_axes_sig;
}

// Infers shardable axes for values reachable from `sink` via `reversed_walker`.
//
// Result 0 of `sink` is seeded with `init_sa`. For every op the walker visits,
// the op's own signature (MakeShardableAxesSignature4Op) is renamed against the
// axes already recorded for the op's result 0, and the renamed input axes are
// merged into the entry of the corresponding operand value.
//
// Params:
//   reversed_walker: walker invoked as reversed_walker(sink, callback).
//   sink:            op whose result 0 starts the inference.
//   init_sa:         axes assigned to sink->result(0).
// Returns: map from value to its inferred shardable axes.
// Note: `.at(op->result(0))` throws std::out_of_range if an op is visited
// before its result 0 has an entry.
std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
    common::TopoWalker<const pir::Operation*>& reversed_walker,
    const pir::Operation* sink,
    const ShardableAxes& init_sa) const {
  std::unordered_map<pir::Value, ShardableAxes> value2shardable_axes{
      {sink->result(0), init_sa}
  };
  const auto UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) {
    auto iter = value2shardable_axes.find(value);
    if (iter == value2shardable_axes.end()) {
      // First time this value is seen: insert it. The previous code wrote
      // through the end() iterator here, which is undefined behaviour and
      // never actually recorded the value.
      value2shardable_axes.emplace(value, sa);
      return;
    }
    // Value already has axes (it feeds more than one visited op): keep only
    // what GetCommonShardableAxes reports for both.
    iter->second = ShardableAxesUtil::GetCommonShardableAxes(iter->second, sa);
  };
  reversed_walker(sink, [&](const auto* op) {
    auto shardable_axes_sig = MakeShardableAxesSignature4Op(op);
    const auto& old2new = ShardableAxesUtil::GetOldName2NewName(
        shardable_axes_sig.output_shardable_axes,
        value2shardable_axes.at(op->result(0)));
    for (auto& pair : shardable_axes_sig.input_shardable_axes) {
      const auto& [my_op, input_idx] = pair.first;
      CHECK_EQ(my_op, op);
      auto* input_shardable_axes = &pair.second;
      // Rewrite the operand's axes in place using the old->new name mapping.
      ShardableAxesUtil::UpdateShardableAxes(old2new, input_shardable_axes);
      pir::Value input_value = op->operand_source(input_idx);
      UpdateValue2ShardableAxes(input_value, *input_shardable_axes);
    }
  });
  return value2shardable_axes;
}

private:
cinn::dialect::FusionOp fusion_op_;
std::function<bool(const pir::Operation*)> IsInThisFusionOp;
Expand All @@ -665,4 +678,19 @@ GroupPattern GenerateGroupPatternFromFusionOp(const cinn::dialect::FusionOp& fus
return FuseToGroupPattern(fusion_op);
}

// Infers shardable axes for the values touched by `ops`.
//
// Builds the walker from GetOpsTopoWalker(ops), requires that GetStarts finds
// exactly one start op for it (treated as the sink), seeds that op's result 0
// with ShardableAxesUtil::GetFullyShardableAxes(rank), and returns whatever
// ReversedInferShardableAxes produces from there.
std::unordered_map<pir::Value, ShardableAxes> InferShardableAxes(const std::vector<const pir::Operation*>& ops) {
  auto reversed_walker = GetOpsTopoWalker(ops);

  const auto sinks = GetStarts(reversed_walker, ops);
  CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node.";
  const pir::Operation* sink = sinks.front();

  const size_t rank = GetRank(sink->result(0));
  const auto init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
  return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
}

}