Skip to content

Commit

Permalink
Implement MakeGetterIsInjectiveSource
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Mar 6, 2024
1 parent ae48ead commit 11ae7cc
Showing 1 changed file with 118 additions and 9 deletions.
127 changes: 118 additions & 9 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "paddle/cinn/frontend/group_pattern_util.h"
#include "paddle/cinn/common/topo_walker.h"
#include <optional>

namespace cinn::frontend {
Expand All @@ -11,14 +12,86 @@ using PS = api::PartialShardablePattern<FrontendPattern>;
using InternalPattern = std::variant<IS, R, PS>;


std::function<bool(const pir::Operation*)> MakeGetterIsInThisFusionOp(const pir::FusionOp& fusion_op) {
TODO();
std::function<bool(const pir::Operation*)> MakeGetterIsInThisFusionOp(const cinn::dialect::FusionOp& fusion_op) {
std::set<pir::Operation*> set;
for (const pir::Operation* op : fusion_op.block()->ops()) {
if (!op->isa<pir::YieldOp>()) {
set.insert(op);
}
}
return [set = std::move(set)](const pir::Operation* op) {
return set.count(op) > 0;
};
}

bool IsGeneralInjective(const pir::Operation* op) {
hlir::framework::OpPatternKind op_pattern_kind = GetOpPatternKind(op);
return op_pattern_kind == hlir::framework::kElementWise
|| op_pattern_kind == hlir::framework::kBroadcast
|| op_pattern_kind == hlir::framework::kInjective;
}

std::function<bool(const pir::Operation*)> MakeGetterIsInjectiveSource(
const pir::FusionOp& fusion_op,
const cinn::dialect::FusionOp& fusion_op,
const std::function<bool(const pir::Operation*)>& IsInThisFusionOp) {
TODO();
using NodeVisitor = std::function<void(pir::Operation*)>;
const auto VisitEachInput = [&](const pir::Operation* node, const NodeVisitor& DoEach) {
for (int i = 0; i < op->num_operands(); ++i) {
const auto* input_op = op->operand_source(i).defining_op();
if (IsInThisFusionOp(input_op)) {
DoEach(input_op);
}
}
};
const auto VisitEachOutput = [&](const pir::Operation* node, const NodeVisitor& DoEach) {
for (int i = 0; i < op->num_results(); ++i) {
pir::Value output = op->result(i);
for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) {
const auto* consumer_op = consumer_it->owner();
if (IsInThisFusionOp(consumer_op)) {
DoEach(consumer_op);
}
}
}
};

const auto starts = [&]{
const auto& IsSource = [&](const pir::Operation* op) {
std::size_t num_inputs = 0;
VisitEachInput([&](const pir::Operation*) { ++num_inputs});
return num_inputs == 0;
};
std::list<const pir::Operation*> starts;
for (const auto* op : fusion_op.block().ops()) {
if (!IsInThisFusionOp(op)) continue;
if (IsSource(op)) {
starts.push_back(op);
} else {
// do nothing.
}
}
return starts;
}();

std::unordered_map<pir::Operation*, bool> op_2_is_injective_source;

auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) {
bool is_inputs_all_injective_source = true;
VisitEachInput(op, [&](const pir::Operation* input){
is_inputs_all_injective_source = (is_inputs_all_injective_source && op_2_is_injective_source.at(input));
});
return is_inputs_all_injective_source;
};

common::TopoWalker<const pir::Operation*> walker{VisitEachInput, VisitEachOutput};
walker(starts, [&](const pir::Operation* op){
op_2_is_injective_source[op] = (IsGeneralInjective(op) && IsInputsAllInjectiveSource(op));
});
return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) {
const auto& iter = map.find(op);
CHECK(iter != map.end());
return iter->second;
};
}

void InitInternalFusions(const std::optional<IS> injective_source, std::vector<InternalPattern>* ret) {
Expand All @@ -31,7 +104,7 @@ struct InternalFusionHelper {
const std::function<bool(const pir::Operation*)> IsInThisFusionOp;
const std::function<bool(const pir::Operation*)> IsInjectiveSource;

std::vector<InternalPattern> FuseISAndConvertRemainder(const pir::FusionOp& fusion_op) const {
std::vector<InternalPattern> FuseISAndConvertRemainder(const cinn::dialect::FusionOp& fusion_op) const {
TODO();
}

Expand All @@ -53,7 +126,7 @@ struct InternalFusionHelper {

};

std::variant<std::vector<InternalPattern>, ErrorGroupPattern> InternalFusion(const pir::FusionOp& fusion_op) {
std::variant<std::vector<InternalPattern>, ErrorGroupPattern> InternalFusion(const cinn::dialect::FusionOp& fusion_op) {
const auto& IsInThisFusionOp = MakeGetterIsInThisFusionOp(fusion_op);
const auto& IsInjectiveSource = MakeGetterIsInjectiveSource(fusion_op, IsInThisFusionOp);
InternalFusionHelper helper{IsInThisFusionOp, IsInjectiveSource};
Expand All @@ -65,8 +138,44 @@ std::variant<std::vector<InternalPattern>, ErrorGroupPattern> InternalFusion(con
return internal_patterns;
}

std::variant<GroupPattern, ErrorGroupPattern> LiftToGroupPattern(const std::vector<InternalPattern>& internal_patterns) {
TODO();
std::optional<IS> ConvertToSoleIS(const std::vector<InternalPattern>& internal_patterns) {
std::optional<IS> injective_source;
for (const auto& pattern : internal_patterns) {
if (std::holds_alternative<IS>(pattern)) {
if (injective_source.has_value()) {
LOG(FATAL) << "zero or one InjectiveSource allowed.";
}
injective_source = std::get<IS>(pattern);
}
}
return injective_source;
}

struct ConvertInternalPatternToPSOrR {
std::variant<PS, R> operator()(const IS& pattern) {
LOG(FATAL) << "dead code";
}
std::variant<PS, R> operator()(const PS& pattern) {
return pattern;
}
std::variant<PS, R> operator()(const R& pattern) {
return pattern;
}
}

api::ShardableReductionsPattern<FrontendPattern> LiftToShardableReductionsPattern(
const std::vector<InternalPattern>& internal_patterns) {
api::ShardableReductionsPattern<FrontendPattern> ret;
for (const auto& pattern : internal_patterns) {
ret.emplace_back(std::visit(ConvertInternalPatternToPSOrR{}, pattern));
}
return ret;
}


GroupPattern LiftToGroupPattern(const std::vector<InternalPattern>& internal_patterns) {
if (const auto& opt_injective_src = ConvertToSoleIS(internal_patterns)) return opt_injective_src.value();
return LiftToShardableReductionsPattern(internal_patterns);
}

struct SafeLiftToGroupPattern {
Expand All @@ -81,7 +190,7 @@ struct SafeLiftToGroupPattern {

}

std::variant<GroupPattern, ErrorGroupPattern> GenerateGroupPatternFromFusionOp(const pir::FusionOp& fusion_op) {
std::variant<GroupPattern, ErrorGroupPattern> GenerateGroupPatternFromFusionOp(const cinn::dialect::FusionOp& fusion_op) {
return std::visit(SafeLiftToGroupPattern{}, InternalFusion(fusion_op));
}

Expand Down

0 comments on commit 11ae7cc

Please sign in to comment.