Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#33 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
Xk cinn trivalop fuse
  • Loading branch information
tc20042008 authored Mar 7, 2024
2 parents fedac6c + 1ea7ff5 commit 9b379f5
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 49 deletions.
13 changes: 11 additions & 2 deletions paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <atomic>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/cinn/api/op_topo_pattern.h"
#include "paddle/pir/include/core/operation.h"

Expand Down Expand Up @@ -28,15 +30,22 @@ struct SingleReductionOpPattern<frontend::FrontendPattern> {
const pir::Operation* reduce_op;
};

struct ShardableAxes {
// One shardable tensor axis: its positional index plus a generated
// name used to match axes across ops.
struct ShardableAxis {
  int axis;
  std::string axis_name;

  // Process-wide monotonically increasing id; the first call yields 1.
  // Thread-safe via an atomic counter.
  // NOTE: the "Unqiue" misspelling is kept — callers use this exact name.
  static int64_t UnqiueSeqNo() {
    static std::atomic<int64_t> seq_counter(0);
    return seq_counter.fetch_add(1) + 1;
  }
};

using ShardableAxes = std::vector<ShardableAxis>;

// Shardability contract of one op: which axes of its output and of each
// input operand can be sharded consistently.
struct ShardableAxesSignature {
  // One input operand of an op, identified by (op, operand index).
  using OpOperand = std::pair<const pir::Operation*, /*operand index*/int>;

  // Shardable axes of the op's (single) output.
  ShardableAxes output_shardable_axes;
  // Shardable axes per input operand.
  // NOTE(review): std::unordered_map with a std::pair key requires a
  // std::hash specialization for OpOperand — confirm one is provided
  // elsewhere, otherwise this will not compile.
  std::unordered_map<OpOperand, ShardableAxes> input_shardable_axes;
};

Expand Down
193 changes: 146 additions & 47 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "paddle/cinn/frontend/group_pattern_util.h"
#include "paddle/cinn/common/topo_walker.h"
#include "paddle/cinn/common/bfs_walker.h"
#include "paddle/cinn/hlir/framework/op.h"
#include <optional>

Expand All @@ -16,7 +17,20 @@ hlir::framework::OpPatternKind GetOpPatternKind(const ::pir::Operation* node) {
return hlir::framework::pir::CompatibleInfo::OpKind(*node);
}

std::function<bool(const pir::Operation*)> MakeGetterIsInThisFusionOp(const cinn::dialect::FusionOp& fusion_op) {
// Returns a function mapping each op of `fusion_op`'s block to its
// 1-based position in the block's op list; used as a deterministic
// sort key when ordering fused ops.
std::function<size_t(const pir::Operation*)> MakeGetterOrderValue4Op(const cinn::dialect::FusionOp& fusion_op) {
  // BUG FIX: key must be `const pir::Operation*` — both the loop variable
  // and the returned lambda's argument are pointers-to-const, which do not
  // convert to the non-const key type.
  std::unordered_map<const pir::Operation*, size_t> op2order_in_block;
  size_t order = 0;
  for (const pir::Operation* op : fusion_op.block()->ops()) {
    op2order_in_block[op] = ++order;
  }
  // Capture the map by move so the getter owns its data.
  return [map=std::move(op2order_in_block)](const pir::Operation* op) {
    const auto& iter = map.find(op);
    CHECK(iter != map.end());
    return iter->second;
  };
}

std::function<bool(const pir::Operation*)> MakePredicatorIsInThisFusionOp(const cinn::dialect::FusionOp& fusion_op) {
std::set<pir::Operation*> set;
for (const pir::Operation* op : fusion_op.block()->ops()) {
if (!op->isa<pir::YieldOp>()) {
Expand All @@ -35,19 +49,19 @@ bool IsGeneralInjective(const pir::Operation* op) {
|| op_pattern_kind == hlir::framework::kInjective;
}

std::function<bool(const pir::Operation*)> MakeGetterIsInjectiveSource(
std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const cinn::dialect::FusionOp& fusion_op,
const std::function<bool(const pir::Operation*)>& IsInThisFusionOp) {
using NodeVisitor = std::function<void(pir::Operation*)>;
const auto VisitEachInput = [&](const pir::Operation* node, const NodeVisitor& DoEach) {
const auto VisitEachInput = [&](const pir::Operation* op, const NodeVisitor& DoEach) {
for (int i = 0; i < op->num_operands(); ++i) {
const auto* input_op = op->operand_source(i).defining_op();
if (IsInThisFusionOp(input_op)) {
DoEach(input_op);
}
}
};
const auto VisitEachOutput = [&](const pir::Operation* node, const NodeVisitor& DoEach) {
const auto VisitEachOutput = [&](const pir::Operation* op, const NodeVisitor& DoEach) {
for (int i = 0; i < op->num_results(); ++i) {
pir::Value output = op->result(i);
for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) {
Expand Down Expand Up @@ -98,52 +112,153 @@ std::function<bool(const pir::Operation*)> MakeGetterIsInjectiveSource(
};
}

struct StmtFusionHelper {
const std::function<bool(const pir::Operation*)> IsInThisFusionOp;
const std::function<bool(const pir::Operation*)> IsInjectiveSource;
class StmtFusionHelper {
public:
// Precomputes the two predicates (op membership in this fusion op, and
// injective-source classification) that every fusion pass below reuses.
explicit StmtFusionHelper(const cinn::dialect::FusionOp& fusion_op)
: fusion_op_(fusion_op) {
this->IsInThisFusionOp = MakePredicatorIsInThisFusionOp(fusion_op_);
// Depends on IsInThisFusionOp being initialized first.
this->IsInjectiveSource = MakePredicatorIsInjectiveSource(fusion_op_, this->IsInThisFusionOp);
}

std::vector<StmtPattern> FuseISAndConvertRemainder(const cinn::dialect::FusionOp& fusion_op) const {
const auto& [injective_source_ops, remainder_ops] = SplitInjectiveSourceOps(fusion_op);
std::vector<StmtPattern> FuseISAndConvertRemainder() const {
std::vector<StmtPattern> ret;
FuseInjectiveSourceThenAppend(injective_source_ops, &ret);
for (const auto& op : remainder_ops) {
FuseInjectiveSourceThenAppend(fusion_op_, &ret);
for (const auto* op : fusion_op_.block()->ops()) {
if (!IsInThisFusionOp(op)) continue;
if (IsInjectiveSource(op)) continue;
ret.emplace_back(ConvertNonInjectiveSourceToStmtPattern(op));
}
return ret;
}

void FuseInjectiveSourceThenAppend(
const std::list<const pir::Operation*>& injective_source_ops,
std::vector<StmtPattern>* ret) {
using IterType = std::list<const pir::Operation*>::iterator;
TODO();
std::vector<StmtPattern>* ret) const {
auto GetOrder = MakeGetterOrderValue4Op(fusion_op_);
auto Cmp = [&](const auto* lhs, const auto& rhs) {
return GetOrder(lhs) < GetOrder(rhs);
};
VisitConnectedInjectiveSource([&](std::vector<const pir::Operation*>&& ops){
std::sort(ops.begin(), ops.end(), Cmp);
ret->emplace_back(IS{ops});
});
}

StmtPattern ConvertNonInjectiveSourceToStmtPattern(const pir::Operation* op) {
template <typename DoEachT>
void VisitConnectedInjectiveSource(
const DoEachT& DoEach) const {
const auto VisitNext = [&](const pir::Operation* node, const OpVisitor& DoEach) {
VisitInputInjectiveSource(node, DoEach);
VisitOutputInjectiveSource(node, DoEach);
};
common::BfsWalker<const pir::Operation*> bfs_walker(VisitNext);
std::unordered_set<const pir::Operation*> visisted_ops;
for (const auto* start : fusion_op_.block()->ops()) {
if (!IsInThisFusionOp(start)) continue;
if (!IsInjectiveSource(start)) continue;
if (visisted_ops.count(start) > 0) continue;
std::vector<const pir::Operation*> current_visited_ops;
bfs_walker(start, [&](const pir::Operation* op){
CHECK(visisted_ops.emplace(op).second);
current_visited_ops.push_back(op);
});
DoEach(std::move(current_visited_ops));
}
}

using OpVisitor = std::function<void(const pir::Operation*)>;

// Visits each producer of `op`'s operands that is an injective-source
// op inside the current fusion op.
void VisitInputInjectiveSource(const pir::Operation* op, const OpVisitor& DoEach) const {
  for (int idx = 0; idx < op->num_operands(); ++idx) {
    const auto* producer = op->operand_source(idx).defining_op();
    if (!IsInThisFusionOp(producer)) continue;
    if (!IsInjectiveSource(producer)) continue;
    DoEach(producer);
  }
}

// Visits each consumer of `op`'s results that is an injective-source
// op inside the current fusion op.
void VisitOutputInjectiveSource(const pir::Operation* op, const OpVisitor& DoEach) const {
  for (int i = 0; i < op->num_results(); ++i) {
    pir::Value output = op->result(i);
    for (auto consumer_it = output.use_begin(); consumer_it != output.use_end(); ++consumer_it) {
      const auto* consumer_op = consumer_it->owner();
      // BUG FIX: was `IsInjectiveSource(input_op)` — `input_op` does not
      // exist in this scope (copy-paste from VisitInputInjectiveSource);
      // the consumer is what must be tested.
      if (IsInThisFusionOp(consumer_op) && IsInjectiveSource(consumer_op)) {
        DoEach(consumer_op);
      }
    }
  }
}

// Converts a single non-injective-source op into the stmt pattern
// matching its kind (R for reduction, PS for elementwise/broadcast).
// FIX: removed stale pre-diff lines that duplicated the return
// statements under their old names.
StmtPattern ConvertNonInjectiveSourceToStmtPattern(const pir::Operation* op) const {
  const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
  if (kind == hlir::framework::kReduction) {
    return ConvertReductionOpToReductionPattern(op);
  } else if (kind == hlir::framework::kElementWise) {
    return ConvertElementwiseOpToPS(op);
  } else if (kind == hlir::framework::kBroadcast) {
    return ConvertBroadcastOpToPS(op);
  } else {
    LOG(FATAL) << "only kReduction, kElementWise, kBroadcast supported. op_name:" << op->op_name();
  }
  LOG(FATAL) << "Dead code";
}

StmtPattern ConvertReductionOpToStmtPattern(const pir::Operation* op) {
// Wraps a reduction op into an R pattern; the first field (trivial ops
// fused into the reduction) starts out empty.
R ConvertReductionOpToReductionPattern(const pir::Operation* op) const {
return R{{}, {op}};
}

StmtPattern ConvertElementwiseOpToStmtPattern(const pir::Operation* op) {
CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported.";
TODO();
PS ConvertElementwiseOpToPS(const pir::Operation* op) const {
CHECK(!op->isa<cinn::dialect::ReshapeOp>()) << "reshape not supported. TODO(wuzhanfei).";
const auto& GetRank = [](pir::Value value) -> size_t {
return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
};
const size_t rank = [&]{
std::optional<size_t> rank;
for (int i = 0; i < op->num_operands(); ++i) {
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->operand_source(i)));
} else {
rank = GetRank(op->operand_source(i));
}
}
CHECK_EQ(op->num_results(), 1);
if (rank.has_value()) {
CHECK_EQ(rank.value(), GetRank(op->result(0)));
} else {
rank = GetRank(op->result(0));
}
CHECK(rank.has_value());
return rank.value();
}();
const auto& shardable_axes_signature = [&]{
const ShardableAxes shardable_axes = GetElementwiseOpShardableAxes(rank);
std::unordered_map<OpOperand, ShardableAxes> input_shardable_axes;
for (int i = 0; i < op->num_operands(); ++i) {
input_shardable_axes[std::pair(op, i)] = shardable_axes;
}
return ShardableAxesSignature{
.output_shardable_axes,
.input_shardable_axes=input_shardable_axes,
};
}();
return PS{
.ops={op},
.shardable_axes_signature=shardable_axes_signature,
};
}

// Builds one freshly-named shardable axis per dimension; every axis of an
// elementwise op is shardable.
ShardableAxes GetElementwiseOpShardableAxes(size_t rank) const {
  ShardableAxes ret;
  ret.reserve(rank);  // one entry per dim — avoid reallocation
  // FIX: use size_t to avoid a signed/unsigned comparison with `rank`.
  for (size_t i = 0; i < rank; ++i) {
    ret.emplace_back(ShardableAxis{
      .axis=static_cast<int>(i),
      .axis_name=std::string("D") + std::to_string(ShardableAxis::UnqiueSeqNo())
    });
  }
  return ret;
}

StmtPattern ConvertBroadcastOpToStmtPattern(const pir::Operation* op) {
LOG(FATAL) << "TODO(wuzhanfei)";
// TODO(wuzhanfei): broadcast-op conversion is unimplemented; reaching
// this method aborts the process via LOG(FATAL).
PS ConvertBroadcastOpToPS(const pir::Operation* op) const {
LOG(FATAL) << "TODO(wuzhanfei).";
}

std::variant<IternalPattern, ErrorGroupPattern> MergePattern(
Expand Down Expand Up @@ -187,24 +302,6 @@ struct StmtFusionHelper {
return new_pattern;
}

SplitedOps SplitInjectiveSourceOps(const cinn::dialect::FusionOp& fusion_op) {
SplitedOps ret;
for (const auto& op : fusion_op.block().ops()) {
if (!IsInThisFusionOp(op)) continue;
if (IsInjectiveSource(op)) {
ret.injective_source_ops.push_back(op);
} else {
ret.remainder_ops.push_back(op);
}
}
return ret;
}

struct SplitedOps {
std::list<const pir::Operation*> injective_source_ops;
std::list<const pir::Operation*> remainder_ops;
}

std::optional<std::pair<StmtPattern, StmtPattern>> FindConnetedPattenPairWithCondition(
std::vector<StmtPattern>* stmt_patterns,
std::function<bool(const IternalPattern& upstream, const IternalPattern& downstream)>& FuseTargetCondition) const {
Expand Down Expand Up @@ -286,13 +383,15 @@ struct StmtFusionHelper {
);
}

private:
cinn::dialect::FusionOp fusion_op_;
std::function<bool(const pir::Operation*)> IsInThisFusionOp;
std::function<bool(const pir::Operation*)> IsInjectiveSource;
};

GroupPattern FuseToGroupPattern(const cinn::dialect::FusionOp& fusion_op) {
const auto& IsInThisFusionOp = MakeGetterIsInThisFusionOp(fusion_op);
const auto& IsInjectiveSource = MakeGetterIsInjectiveSource(fusion_op, IsInThisFusionOp);
StmtFusionHelper helper{IsInThisFusionOp, IsInjectiveSource};
std::vector<StmtPattern> stmt_patterns = helper.FuseISAndConvertRemainder(fusion_op);
StmtFusionHelper helper(fusion_op);
std::vector<StmtPattern> stmt_patterns = helper.FuseISAndConvertRemainder();
if (const auto& error = helper.Fuse_PS_x_PS_2_PS(&stmt_patterns)) return error.value();
if (const auto& error = helper.Fuse_IS_x_PS_2_PS(&stmt_patterns)) return error.value();
if (const auto& error = helper.Fuse_IS_x_R_2_R(&stmt_patterns)) return error.value();
Expand Down

0 comments on commit 9b379f5

Please sign in to comment.