Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#31 from feifei-111/cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
MergePattern
  • Loading branch information
feifei-111 authored Mar 6, 2024
2 parents 0a00878 + c1bb050 commit ab42ae4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
6 changes: 3 additions & 3 deletions paddle/cinn/api/op_topo_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct PartialShardablePattern {};
template <typename T>
struct ReductionPattern {
using Nothing = std::monostate;
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern> opt_is_or_ps_input;
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>> opt_inputs;
SingleReductionOpPattern<T> reduction_op_pattern;
};

Expand All @@ -30,8 +30,8 @@ template <typename T>
using ShardableReductionsPattern = std::vector<std::variant<ReductionPattern<T>, PartialShardablePattern<T>>>;

// fuse rules:
// 1. IS * PS -> PS
// 2. PS * PS -> PS
// 1. PS * PS -> PS
// 2. IS * PS -> PS
// 3. IS * R -> R
// 4. PS * R -> R

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct ShardableAxes {
struct ShardableAxesSignature {
using OpOperand = std::pair<const pir::Operation*, /*operand index*/int>;

ShardableAxes output_shardable_axes;
std::vector<ShardableAxes> output_shardable_axes;
std::unordered_map<OpOperand, ShardableAxes> input_shardable_axes;
};

Expand Down
59 changes: 49 additions & 10 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,47 @@ struct InternalFusionHelper {
LOG(FATAL) << "TODO(wuzhanfei)";
}

std::variant<IternalPattern, ErrorGroupPattern> MergePattern(
const IS& upstream,
const PS& downstream){
PS new_pattern = CopyPattern(downstream);
new_pattern.ops.insert(new_pattern.end(), upstream.begin(), upstream.end());
return new_pattern;
}

std::variant<IternalPattern, ErrorGroupPattern> MergePattern(
const PS& upstream,
const PS& downstream){
PS new_pattern = CopyPattern(downstream);
new_pattern.ops.insert(new_pattern.end(), upstream.begin(), upstream.end());
new_pattern.shardable_axes_signature.output_shardable_axes.insert(
new_pattern.shardable_axes_signature.output_shardable_axes.end(),
upstream.shardable_axes_signature.output_shardable_axes.begin(),
upstream.shardable_axes_signature.output_shardable_axes.end()
);
new_pattern.shardable_axes_signature.input_shardable_axes.insert(
upstream.shardable_axes_signature.input_shardable_axes.begin(),
upstream.shardable_axes_signature.input_shardable_axes.end()
);
return new_pattern
}

std::variant<IternalPattern, ErrorGroupPattern> MergePattern(
const IS& upstream,
const R& downstream){
R new_pattern = CopyPattern(downstream);
new_pattern.opt_inputs = CopyPattern(upstream);
return new_pattern;
}

std::variant<IternalPattern, ErrorGroupPattern> MergePattern(
const PS& upstream,
const R& downstream){
R new_pattern = CopyPattern(downstream);
new_pattern.opt_inputs = CopyPattern(upstream);
return new_pattern;
}

SplitedOps SplitInjectiveSourceOps(const cinn::dialect::FusionOp& fusion_op) {
SplitedOps ret;
for (const auto& op : fusion_op.block().ops()) {
Expand All @@ -170,31 +211,27 @@ struct InternalFusionHelper {
std::list<const pir::Operation*> injective_source_ops;
std::list<const pir::Operation*> remainder_ops;
}


std::optional<std::pair<InternalPattern, InternalPattern>> FindConnetedPattenPairWithCondition(
std::vector<InternalPattern>* internal_patterns,
std::function<bool(const IternalPattern&, const IternalPattern&)>& FuseTargetCondition /* first input is upstream, second is downstream */) const {
std::function<bool(const IternalPattern& upstream, const IternalPattern& downstream)>& FuseTargetCondition) const {
for (int i=0; i<internal_patterns.size(); i++){
for (int j=i+1; j<internal_patterns.size(); j++){
bool i_used_j = FirstIsUpstreamOfSecond(internal_patterns[j], internal_patterns[i]);
bool j_used_i = FirstIsUpstreamOfSecond(internal_patterns[i], internal_patterns[j]);

if((!i_used_j && !j_used_i) || LeadToLoop()){
continue;
}

if (i_used_j && FuseTargetCondition(internal_patterns[j], internal_patterns[i])){
return std::make_pair(internal_patterns[j], internal_patterns[i]);
}else if(j_used_i && FuseTargetCondition(internal_patterns[i], internal_patterns[j])){
return std::make_pair(internal_patterns[i], internal_patterns[j]);
}else{
continue;
}
}
}
return {};
return std::nullopt;
}


std::optional<ErrorGroupPattern> FuseIternalPattenPrototype(
std::vector<InternalPattern>* internal_patterns,
std::function<bool(const IternalPattern&, const IternalPattern&)>& FuseTargetCondition) const{
Expand All @@ -206,7 +243,9 @@ struct InternalFusionHelper {
if (!pattern_pair.value()){
break;
}
const InternalPattern& new_pattern = MergePattern(pattern_pair.first, pattern_pair.second);
const std::variant<IternalPattern, ErrorGroupPattern>& new_pattern =
MergePattern(pattern_pair.first, pattern_pair.second);

if (IsErrorGroupPattern(new_pattern)){
return new_pattern;
}
Expand Down Expand Up @@ -261,8 +300,8 @@ std::variant<std::vector<InternalPattern>, ErrorGroupPattern> InternalFusion(con
const auto& IsInjectiveSource = MakeGetterIsInjectiveSource(fusion_op, IsInThisFusionOp);
InternalFusionHelper helper{IsInThisFusionOp, IsInjectiveSource};
std::vector<InternalPattern> internal_patterns = helper.FuseISAndConvertRemainder(fusion_op);
if (const auto& opt_error = helper.Fuse_IS_x_PS_2_PS(&internal_patterns)) return opt_error.value();
if (const auto& opt_error = helper.Fuse_PS_x_PS_2_PS(&internal_patterns)) return opt_error.value();
if (const auto& opt_error = helper.Fuse_IS_x_PS_2_PS(&internal_patterns)) return opt_error.value();
if (const auto& opt_error = helper.Fuse_IS_x_R_2_R(&internal_patterns)) return opt_error.value();
if (const auto& opt_error = helper.Fuse_PS_x_R_2_R(&internal_patterns)) return opt_error.value();
return internal_patterns;
Expand Down

0 comments on commit ab42ae4

Please sign in to comment.