Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbaizhou committed Mar 20, 2024
1 parent 277b202 commit c02814b
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions paddle/cinn/frontend/cluster_ops/fusion_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,83 @@ std::optional<ErrorGroupPattern> StmtFusionHelper::Fuse_PS_x_R_2_R(
return FuseFilteredStmtPatterns<FusePolicy_PS_x_R_2_R>(stmt_patterns);
}

bool StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::FuseCondition(const StmtPattern& upstream,
const StmtPattern& downstream) {
return IsISPattern(upstream) && IsPSPattern(downstream);
}
std::variant<StmtPattern, ErrorGroupPattern> StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream) {
return MergePatternImpl(std::get<IS>(upstream), std::get<PS>(downstream));
}
std::variant<StmtPattern, ErrorGroupPattern> StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::MergePatternImpl(
const IS& upstream, const PS& downstream) {
const auto& ops = [&] {
std::vector<const pir::Operation*> ops(upstream.ops.begin(),
upstream.ops.end());
for (const auto* downstream_op : downstream.ops) {
if (std::find(ops.begin(), ops.end(), downstream_op) == ops.end()) {
ops.push_back(downstream_op);
}
}
return ops;
}();
const auto& shardable_axes_signature =
MergeShardableAxesSignature(upstream, downstream);
return StmtPattern(PS{
.ops = ops,
.sole_sink = downstream.sole_sink,
.shardable_axes_signature = shardable_axes_signature,
});
}

ShardableAxesSignature StmtFusionHelper::FusePolicy_IS_x_PS_2_PS::MergeShardableAxesSignature(
const IS& upstream, const PS& downstream) {
LOG(FATAL) << "TODO(tianchao)";
}


bool StmtFusionHelper::FusePolicy_IS_x_R_2_R::FuseCondition(const StmtPattern& upstream,
const StmtPattern& downstream) {
return IsISPattern(upstream) && IsRPattern(downstream);
}
std::variant<StmtPattern, ErrorGroupPattern> StmtFusionHelper::FusePolicy_IS_x_R_2_R::MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream) {
return MergePatternImpl(std::get<IS>(upstream), std::get<R>(downstream));
}
std::variant<StmtPattern, ErrorGroupPattern> StmtFusionHelper::FusePolicy_IS_x_R_2_R::MergePatternImpl(
const IS& upstream, const R& downstream) {
if (downstream.HasFusedInput()) {
return ErrorGroupPattern{
.ops = {downstream.reduce_op_pattern.reduce_op},
.error_string = "The input of reduce has been fused.",
};
}
R new_pattern = R(downstream);
new_pattern.input = upstream;
return StmtPattern(std::move(new_pattern));
}

bool StmtFusionHelper::FusePolicy_PS_x_R_2_R::FuseCondition(const StmtPattern& upstream,
const StmtPattern& downstream) {
return IsISPattern(upstream) && IsRPattern(downstream);
}
std::variant<StmtPattern, ErrorGroupPattern> StmtFusionHelper::FusePolicy_PS_x_R_2_R::MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream) {
return MergePatternImpl(std::get<PS>(upstream), std::get<R>(downstream));
}
std::variant<StmtPattern, ErrorGroupPattern> StmtFusionHelper::FusePolicy_PS_x_R_2_R::MergePatternImpl(
const PS& upstream, const R& downstream) {
if (downstream.HasFusedInput()) {
return ErrorGroupPattern{
.ops = {downstream.reduce_op_pattern.reduce_op},
.error_string = "The input of reduce has been fused.",
};
}
R new_pattern = R(downstream);
new_pattern.input = upstream;
return StmtPattern(new_pattern);
}

StmtPattern StmtFusionHelper::ConvertToStmtPattern(const pir::Operation* op) {
const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
if (IsInjectiveSource(op)) {
Expand Down

0 comments on commit c02814b

Please sign in to comment.