Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:gouzil/Paddle into dy2st_pir_api…
Browse files Browse the repository at this point in the history
…_push_9

# Conflicts:
#	python/paddle/pir/math_op_patch.py
  • Loading branch information
gouzil committed Nov 24, 2023
2 parents 607b7ba + 7a1795c commit b834e89
Show file tree
Hide file tree
Showing 176 changed files with 5,488 additions and 1,318 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion);

for (auto group : group_list) {
auto ir_compiler = std::make_shared<cinn::hlir::framework::PirCompiler>(
auto ir_compiler = cinn::hlir::framework::PirCompilerManager::Create(
*program, target, scope);
if (FLAGS_cinn_enable_map_expr) {
cinn::adt::TryGenerateMapExprFromGroup(group);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,8 @@ class GeneralFusionMergePassHelper {
}
// master node
for (auto& node : consumer->master_ops) {
if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node) ==
OpPatternKind::kReduction) {
fused_group->master_ops.insert(node);
}
}
Expand Down Expand Up @@ -1474,7 +1475,8 @@ class GeneralFusionMergePassHelper {
++consumer) {
::pir::Operation* master_node = nullptr;
for (auto& node : (*consumer)->master_ops) {
if (GetOpKind(node->name()) != OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node) !=
OpPatternKind::kReduction) {
master_node = node;
break;
}
Expand Down Expand Up @@ -1609,7 +1611,8 @@ class GeneralFusionMergePassHelper {
}
// master nodes
for (auto& node : producer->master_ops) {
if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node) ==
OpPatternKind::kReduction) {
fused_group->master_ops.insert(node);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ static bool ReduceFuseReduce1(const OpGroupPtr& first,
// }
std::unique_ptr<cinn::dialect::ir::OpNode> reducer_0 = nullptr;
for (auto op : first.GetGroup()->CollectOps()) {
if (GetOpKind(op->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*op) ==
OpPatternKind::kReduction) {
reducer_0.reset(new cinn::dialect::ir::OpNode(op));
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ inline bool elementwise_fuse_reduce(const std::shared_ptr<ir::Group>& first,
// if reduce using block_reduce, can't fuse producer.
::pir::Operation* reducer = nullptr;
for (auto& node : second->master_ops) {
if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node) ==
OpPatternKind::kReduction) {
reducer = node;
break;
}
Expand Down Expand Up @@ -291,7 +292,8 @@ inline bool broadcast_fuse_reduce(const std::shared_ptr<ir::Group>& first,
}
::pir::Operation* reducer = nullptr;
for (auto& node : second->master_ops) {
if (GetOpKind(node->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node) ==
OpPatternKind::kReduction) {
reducer = node;
break;
}
Expand Down Expand Up @@ -339,7 +341,7 @@ inline bool horizontal_relation(const std::shared_ptr<ir::Group>& first,
OpPatternKind kind) {
std::unordered_set<::pir::Operation*> selected;
for (auto node : nodes) {
if (GetOpKind(node->name()) == kind) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node) == kind) {
selected.insert(node);
}
}
Expand Down Expand Up @@ -425,7 +427,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr<ir::Group>& first,
// required that each consumer of type Broadcast meet the same shape after
// broadcast as before reduce.
for (auto& node_in_master : first->master_ops) {
if (GetOpKind(node_in_master->name()) != OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*node_in_master) !=
OpPatternKind::kReduction) {
continue;
}
::pir::Operation* reducer = node_in_master;
Expand Down Expand Up @@ -488,7 +491,8 @@ inline bool reduce_fuse_broadcast(const std::shared_ptr<ir::Group>& first,
visited_set.insert(consumer);
candidates.push(consumer);
}
if (GetOpKind(consumer->name()) == OpPatternKind::kBroadcast &&
if (hlir::framework::pir::CompatibleInfo::OpKind(*consumer) ==
OpPatternKind::kBroadcast &&
second->OpSet().find(consumer) != second->OpSet().end()) {
broadcasters.insert(consumer);
}
Expand Down Expand Up @@ -552,7 +556,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
}
::pir::Operation* reducer_0 = nullptr;
for (auto& reducer : first->master_ops) {
if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*reducer) ==
OpPatternKind::kReduction) {
reducer_0 = reducer;
break;
}
Expand All @@ -561,7 +566,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,

::pir::Operation* reducer_1 = nullptr;
for (auto& reducer : second->master_ops) {
if (GetOpKind(reducer->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*reducer) ==
OpPatternKind::kReduction) {
reducer_1 = reducer;
break;
}
Expand Down Expand Up @@ -598,7 +604,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
auto shared_size = 0;
for (auto& fusion_group : {first, second}) {
for (auto* master : fusion_group->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*master) ==
OpPatternKind::kReduction) {
shared_size += GetSharedSize(master);
}
}
Expand All @@ -619,7 +626,8 @@ inline bool reduce_fuse_reduce(const std::shared_ptr<ir::Group>& first,
auto shared_size = 0;
for (auto& fusion_group : {first, second}) {
for (auto* master : fusion_group->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*master) ==
OpPatternKind::kReduction) {
shared_size += GetSharedSize(master);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class OpNode {
: node_(node), input_tensors_(node), output_tensors_(node) {}

OpPatternKind kind() const {
auto kind = GetOpKind(node_->name());
auto kind = hlir::framework::pir::CompatibleInfo::OpKind(*node_);
if (kind == OpPatternKind::kBroadcast) {
// As binary op was defined as broadcast, actually it should be
// element-wise.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,45 +33,6 @@ namespace cinn {
namespace dialect {
namespace ir {

std::unordered_map<std::string, OpPatternKind> OpKindMap = {
{"pd_op.add", OpPatternKind::kElementWise},
{"pd_op.subtract", OpPatternKind::kElementWise},
{"pd_op.multiply", OpPatternKind::kElementWise},
{"pd_op.divide", OpPatternKind::kElementWise},
{"pd_op.sqrt", OpPatternKind::kElementWise},
{"pd_op.rsqrt", OpPatternKind::kElementWise},
{"pd_op.full", OpPatternKind::kElementWise},
{"pd_op.relu", OpPatternKind::kElementWise},
{"pd_op.exp", OpPatternKind::kElementWise},
{"pd_op.sin", OpPatternKind::kElementWise},
{"pd_op.cos", OpPatternKind::kElementWise},
{"pd_op.pow", OpPatternKind::kElementWise},
{"pd_op.elementwise_pow", OpPatternKind::kElementWise},
{"pd_op.sum", OpPatternKind::kReduction},
{"cinn_op.reshape", OpPatternKind::kElementWise},
{"pd_op.cast", OpPatternKind::kElementWise},
{"pd_op.greater_than", OpPatternKind::kElementWise},
{"pd_op.greater_equal", OpPatternKind::kElementWise},
{"pd_op.transpose", OpPatternKind::kInjective},
{"pd_op.gather_nd", OpPatternKind::kInjective},
{"cinn_op.scale", OpPatternKind::kElementWise},
{"cinn_op.concat", OpPatternKind::kInjective},
{"cinn_op.slice", OpPatternKind::kInjective},
{"cinn_op.reduce_sum", OpPatternKind::kReduction},
{"cinn_op.reduce_max", OpPatternKind::kReduction},
{"cinn_op.broadcast", OpPatternKind::kBroadcast},
{"cinn_op.uniform_random", OpPatternKind::kElementWise}};

OpPatternKind GetOpKind(const std::string& op_name) {
auto found_it = OpKindMap.find(op_name);
if (found_it == OpKindMap.end()) {
PADDLE_THROW(phi::errors::Unavailable(
"not support [%s] op yet in op kind map", op_name));
}

return found_it->second;
}

std::vector<pir::Operation*> GetProducerOpsReverseSort(
pir::Operation* op,
const std::unordered_map<pir::Operation*, size_t>& op2id) {
Expand Down Expand Up @@ -323,7 +284,8 @@ class OpFusionPassHelper {
}

// group type
group->op_pattern_kind = GetOpKind(op->name());
group->op_pattern_kind =
hlir::framework::pir::CompatibleInfo::OpKind(*op);
// use current op as master op for schedule
group->master_ops.insert(op);

Expand Down Expand Up @@ -389,7 +351,8 @@ class OpFusionPassHelper {
private:
void DoOpFusion() {
for (auto consumer : ops_) {
auto consumer_kind = GetOpKind(consumer->name());
auto consumer_kind =
hlir::framework::pir::CompatibleInfo::OpKind(*consumer);
// kNonFusible op can't fuse any other op.
if (consumer_kind == OpPatternKind::kNonFusible) {
continue;
Expand Down Expand Up @@ -418,7 +381,8 @@ class OpFusionPassHelper {
continue;
}
// kNonFusible op can't fuse any other op.
auto producer_kind = GetOpKind(producer->name());
auto producer_kind =
hlir::framework::pir::CompatibleInfo::OpKind(*producer);
if (producer_kind == OpPatternKind::kNonFusible) {
continue;
}
Expand Down Expand Up @@ -625,13 +589,17 @@ class OpFusionPassHelper {
}

bool CanFuse(::pir::Operation* producer, const ::pir::Operation* consumer) {
auto& relation = fusion_relation_map_[GetOpKind(producer->name())];
auto& relation =
fusion_relation_map_[hlir::framework::pir::CompatibleInfo::OpKind(
*producer)];
// first step: check producer can be fused into consumer
if (relation.op_kind.count(GetOpKind(consumer->name()))) {
if (relation.op_kind.count(
hlir::framework::pir::CompatibleInfo::OpKind(*consumer))) {
auto& consumer_group = fusion_groups_[consumer];
// second step: check producer can be fused into consumer group
VLOG(3) << "Call ConditionFunction, Producer Op Pattern : "
<< GetOpKind(producer->name()) << " , Consumer Group Pattern : "
<< hlir::framework::pir::CompatibleInfo::OpKind(*producer)
<< " , Consumer Group Pattern : "
<< consumer_group->op_pattern_kind;

return relation.fusion_op_kind[consumer_group->op_pattern_kind](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/core/operation.h"
Expand Down Expand Up @@ -61,8 +62,6 @@ std::vector<T> GetVectorAttr(const ::pir::Operation* op,
return vec_res;
}

OpPatternKind GetOpKind(const std::string& op_name);

phi::DDim GetFirstInputShape(const ::pir::Operation* op);

phi::DDim GetValueShape(const ::pir::Value& value);
Expand Down Expand Up @@ -114,7 +113,8 @@ inline bool reduce_fuse_reduce(::pir::Operation* producer,
const std::shared_ptr<Group>& consumer) {
::pir::Operation* reducer = NULL;
for (auto* master : consumer->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*master) ==
OpPatternKind::kReduction) {
reducer = master;
break;
}
Expand Down Expand Up @@ -157,7 +157,8 @@ inline bool reduce_fuse_reduce(::pir::Operation* producer,
if (input_shape_same || without_last_dim) {
auto shared_size = GetSharedSize(producer);
for (auto* master : consumer->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*master) ==
OpPatternKind::kReduction) {
shared_size += GetSharedSize(master);
}
}
Expand Down Expand Up @@ -207,7 +208,8 @@ inline bool is_horizontal_relation(::pir::Operation* producer,
};

for (auto op : consumer->ops_set) {
if (GetOpKind(op->name()) != consumer->op_pattern_kind) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*op) !=
consumer->op_pattern_kind) {
continue;
}
if (check_depency(op)) {
Expand All @@ -228,7 +230,8 @@ inline bool horizontal_or_vertical_reduce_relation(
// reducer op in fusion op.
::pir::Operation* reducer = NULL;
for (auto* master : consumer->master_ops) {
if (GetOpKind(master->name()) == OpPatternKind::kReduction) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*master) ==
OpPatternKind::kReduction) {
reducer = master;
break;
}
Expand Down Expand Up @@ -385,7 +388,8 @@ inline bool reduce_fuse_broadcast(::pir::Operation* producer,
};

for (auto op : consumer->ops_set) {
if (GetOpKind(op->name()) != OpPatternKind::kBroadcast) {
if (hlir::framework::pir::CompatibleInfo::OpKind(*op) !=
OpPatternKind::kBroadcast) {
continue;
}

Expand Down
4 changes: 0 additions & 4 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,6 @@ std::vector<ir::LoweredFunc> OpLowererImpl::Lower(const GroupPtr& group,
bool apply_pass) {
VLOG(3) << "Lowering Group : " << group->group_id
<< " , Op Pattern : " << group->op_pattern_kind;
// TODO(Aurelius84): The logic shoule be moved into op_fusion module.
if (group->ops.size() >= 1U & group->output_ops.size() == 0) {
group->output_ops.insert(group->ops[group->ops.size() - 1]);
}
group->input_names.clear();
group->output_names.clear();
switch (group->op_pattern_kind) {
Expand Down
Loading

0 comments on commit b834e89

Please sign in to comment.