Xk cinn trivalop fuse #69

Merged
23 changes: 22 additions & 1 deletion paddle/cinn/frontend/group_pattern.h
@@ -97,7 +97,7 @@ struct ShardableAxesUtil {
return ret;
}

-  static ShardableAxes GetFullyShardableAxes(size_t rank) {
+  static ShardableAxes MakeFullyShardableAxes(const size_t rank) {
ShardableAxes ret;
for (int i = 0; i < rank; ++i) {
ret.emplace_back(ShardableAxis{
@@ -107,6 +107,27 @@ struct ShardableAxesUtil {
}
return ret;
}

static ShardableAxes MakeReduceOpInputShardableAxes(
const size_t input_rank, const std::vector<int64_t>& reduce_axes) {
if (reduce_axes.empty()) return ShardableAxes{};
for (int64_t reduce_axis : reduce_axes) {
CHECK_GE(reduce_axis, 0);
CHECK_LT(reduce_axis, input_rank);
}
const auto IsReduceAxis = [&](int64_t i) {
return std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end();
};
ShardableAxes ret;
for (int64_t i = 0; i < input_rank; ++i) {
if (IsReduceAxis(i)) continue;
ret.emplace_back(ShardableAxis{
.axis=i,
.axis_name=std::string("D") + std::to_string(ShardableAxis::UnqiueSeqNo()),
});
}
return ret;
}
};

struct SoleOutputShardableAxes {
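
For intuition, here is a minimal standalone sketch (not part of this diff) of what MakeReduceOpInputShardableAxes computes, assuming ShardableAxes is a vector of ShardableAxis with an int64_t axis and a string axis_name, as declared above: reducing a rank-3 input over axis 1 leaves axes 0 and 2 shardable.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified stand-ins for the types in group_pattern.h.
    struct ShardableAxis { int64_t axis; std::string axis_name; };
    using ShardableAxes = std::vector<ShardableAxis>;

    int main() {
      const int64_t input_rank = 3;
      const std::vector<int64_t> reduce_axes = {1};
      ShardableAxes ret;
      int seq_no = 0;  // stand-in for ShardableAxis::UnqiueSeqNo()
      for (int64_t i = 0; i < input_rank; ++i) {
        if (std::find(reduce_axes.begin(), reduce_axes.end(), i) !=
            reduce_axes.end()) continue;  // skip reduced axes
        ret.push_back(ShardableAxis{i, "D" + std::to_string(seq_no++)});
      }
      for (const auto& sa : ret)
        std::cout << sa.axis << " -> " << sa.axis_name << "\n";
      // Prints:
      // 0 -> D0
      // 2 -> D1
    }
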
122 changes: 92 additions & 30 deletions paddle/cinn/frontend/group_pattern_util.cc
@@ -206,10 +206,88 @@ size_t GetRank(pir::Value value) {
return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
}

std::vector<int64_t> GetReduceAxes(const pir::Operation* reduce_op) {
const size_t input_rank = GetRank(reduce_op->operand_source(0));
const auto& attr_val = reduce_op->attributes().at("dim");
CHECK(attr_val.isa<::pir::ArrayAttribute>());
const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
std::vector<int64_t> reduce_axes;
for (int i = 0; i < axis_attr.size(); ++i) {
int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data();
if (axis < 0) {
axis += input_rank;
}
CHECK_GE(axis, 0);
CHECK_LT(axis, input_rank);
reduce_axes.push_back(axis);
}
return reduce_axes;
}
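
As a quick worked example of the normalization above (a sketch, not code from the PR): for a rank-3 input, a "dim" attribute of [-1] wraps to reduce axis 2, and [-2] wraps to axis 1.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t input_rank = 3;
      std::vector<int64_t> dims = {-1, 0, -2};  // as read from the "dim" attribute
      for (int64_t& axis : dims) {
        if (axis < 0) axis += input_rank;  // same wrap as GetReduceAxes
        assert(axis >= 0 && axis < input_rank);
      }
      assert(dims[0] == 2 && dims[1] == 0 && dims[2] == 1);
    }
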

bool GetReduceOpKeepDims(const pir::Operation* reduce_op) {
const auto& attr_val = reduce_op->attributes().at("keep_dim");
CHECK(attr_val.isa<::pir::BoolAttribute>());
return attr_val.dyn_cast<::pir::BoolAttribute>().data();
}

ShardableAxes SequeezeShardableAxes(const ShardableAxes& sa) {
ShardableAxes ret_sa(sa);
for (int i = 0; i < ret_sa.size(); ++i) {
for (int j = i + 1; j < ret_sa.size(); ++j) {
CHECK_LT(ret_sa.at(i).axis, ret_sa.at(j).axis);
}
ret_sa.at(i).axis = i;
}
return ret_sa;
}
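
The compaction above assumes the surviving axes arrive strictly increasing (hence the pairwise CHECK_LT) and renumbers them densely from 0, which matches how a keep_dim=false reduce collapses the output layout. A standalone sketch under the same simplified types as before:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    struct ShardableAxis { int64_t axis; std::string axis_name; };

    int main() {
      // Axes {0, 2} survive after reducing axis 1 of a rank-3 input.
      std::vector<ShardableAxis> sa = {{0, "D0"}, {2, "D1"}};
      for (size_t i = 0; i < sa.size(); ++i) sa[i].axis = i;  // densify
      for (const auto& a : sa)
        std::cout << a.axis << " -> " << a.axis_name << "\n";
      // Prints:
      // 0 -> D0
      // 1 -> D1
    }
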

ShardableAxesSignature MakeEmptyShardableAxesSignature(const pir::Operation* op) {
const int result_idx = GetOutputShardableAxesResultIdx(op);
pir::Value output = op->result(result_idx);
ShardableAxes output_sa = ShardableAxesUtil::MakeFullyShardableAxes(GetRank(output));
using InputSignature = std::unordered_map<OpAndOperandIndex, ShardableAxes>;
InputSignature empty_input_sig;
for (int i = 0; i < op->num_operands(); ++i) {
empty_input_sig[OpAndOperandIndex{op, i}] = ShardableAxes{};
}
return ShardableAxesSignature{
.sole_output_sa = SoleOutputShardableAxes{
.shardable_axes=output_sa,
},
.input_shardable_axes = empty_input_sig,
};
}

ShardableAxesSignature MakeShardableAxesSignature4ReduceOp(
const pir::Operation* reduce_op) {
const size_t input_rank = GetRank(reduce_op->operand_source(0));
const auto& reduce_axes = GetReduceAxes(reduce_op);
const ShardableAxes input_sa =
ShardableAxesUtil::MakeReduceOpInputShardableAxes(input_rank, reduce_axes);
using InputSignature = std::unordered_map<OpAndOperandIndex, ShardableAxes>;
const ShardableAxes output_sa =
(GetReduceOpKeepDims(reduce_op) ? input_sa : SequeezeShardableAxes(input_sa));
return ShardableAxesSignature{
.sole_output_sa = SoleOutputShardableAxes{
.shardable_axes=output_sa,
},
.input_shardable_axes = InputSignature{
{OpAndOperandIndex{reduce_op, 0}, input_sa},
},
};
}
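
Putting the two helpers together, a hedged sketch of the signature this function builds for sum(x, dim={1}) on a rank-3 x: the sole input keeps shardable axes {0, 2}; with keep_dim=true the output reuses them verbatim, and with keep_dim=false they are squeezed to {0, 1}. Simplified types, not the PR's API:

    #include <cstdint>
    #include <string>
    #include <vector>

    struct ShardableAxis { int64_t axis; std::string axis_name; };
    using ShardableAxes = std::vector<ShardableAxis>;

    // Mirrors the keep_dim branch above.
    ShardableAxes MakeOutputSA(const ShardableAxes& input_sa, bool keep_dims) {
      if (keep_dims) return input_sa;  // output rank == input rank
      ShardableAxes squeezed(input_sa);
      for (size_t i = 0; i < squeezed.size(); ++i) squeezed[i].axis = i;
      return squeezed;                 // axes compacted to 0..n-1
    }

    int main() {
      const ShardableAxes input_sa = {{0, "D0"}, {2, "D1"}};
      const ShardableAxes kept = MakeOutputSA(input_sa, true);      // axes {0, 2}
      const ShardableAxes squeezed = MakeOutputSA(input_sa, false); // axes {0, 1}
      return (kept.size() == 2 && squeezed.back().axis == 1) ? 0 : 1;
    }
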

bool IsDisabledElementwiseOp(const pir::Operation* op) {
if (op->isa<cinn::dialect::ReshapeOp>()) return true;
return false;
}

ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
const pir::Operation* op) {
-  CHECK(!op->isa<cinn::dialect::ReshapeOp>())
-      << "reshape not supported. TODO(wuzhanfei).";
+  if (IsDisabledElementwiseOp(op)) {
+    LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature found. op_name : " << op->name();
+    return MakeEmptyShardableAxesSignature(op);
+  }
const size_t rank = [&] {
std::optional<size_t> rank;
for (int i = 0; i < op->num_operands(); ++i) {
@@ -229,7 +307,7 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
return rank.value();
}();
const ShardableAxes output_shardable_axes =
-      ShardableAxesUtil::GetFullyShardableAxes(rank);
+      ShardableAxesUtil::MakeFullyShardableAxes(rank);
std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
for (int i = 0; i < op->num_operands(); ++i) {
input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
@@ -244,21 +322,24 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(

ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(
const pir::Operation* op) {
LOG(FATAL) << "TODO(wuzhanfei).";
LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature found. op_name : " << op->name();
return MakeEmptyShardableAxesSignature(op);
}

ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
-  if (kind == hlir::framework::kElementWise) {
+  if (kind == hlir::framework::kReduction) {
+    return MakeShardableAxesSignature4ReduceOp(op);
+  } else if (kind == hlir::framework::kElementWise) {
return MakeShardableAxesSignature4ElementWiseOp(op);
} else if (kind == hlir::framework::kBroadcast) {
return MakeShardableAxesSignature4BroadcastOp(op);
} else {
-    LOG(FATAL)
-        << "only kReduction, kElementWise, kBroadcast supported. op_name:"
+    LOG(ERROR)
+        << "[ShardableAxesSignature] no shardable axes signature found. op_name:"
        << op->name();
  }
-  LOG(FATAL) << "Dead code";
+  return MakeEmptyShardableAxesSignature(op);
}
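
The behavioral change in this dispatcher is that unsupported op kinds now log an error and fall back to an empty signature instead of aborting through LOG(FATAL). A hypothetical caller-side check (simplified types, names assumed, not code from this PR) would treat an all-empty input signature as carrying no sharding information:

    #include <cstdint>
    #include <map>
    #include <vector>

    using ShardableAxes = std::vector<int64_t>;           // simplified: axis indices only
    using InputSignature = std::map<int, ShardableAxes>;  // operand index -> axes

    // Hypothetical helper: does any operand carry shardable axes?
    bool HasShardingInfo(const InputSignature& input_shardable_axes) {
      for (const auto& [operand_idx, sa] : input_shardable_axes) {
        if (!sa.empty()) return true;
      }
      return false;
    }

    int main() {
      const InputSignature empty_sig = {{0, {}}, {1, {}}};  // the fallback shape
      return HasShardingInfo(empty_sig) ? 1 : 0;  // returns 0: no sharding info
    }
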

template<typename InputIt>
@@ -498,7 +579,7 @@ std::unordered_map<pir::Value, ShardableAxes> InferShardableAxesFromSink(
CHECK_GT(op_topo.ops->count(sink), 0);
const int result_idx = GetOutputShardableAxesResultIdx(sink);
size_t rank = GetRank(sink->result(result_idx));
const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
const auto& init_sa = ShardableAxesUtil::MakeFullyShardableAxes(rank);
return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
}

@@ -1718,22 +1799,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
const pir::Operation* reduce_op,
const ShardableAxes& shardable_axes) {
const size_t input_rank = GetRank(reduce_op->operand_source(0));
const auto& reduce_axes = [&]{
const auto& attr_val = reduce_op->attributes().at("dim");
CHECK(attr_val.isa<::pir::ArrayAttribute>());
const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
std::vector<int64_t> reduce_axes;
for (int i = 0; i < axis_attr.size(); ++i) {
int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data();
if (axis < 0) {
axis += input_rank;
}
CHECK_GE(axis, 0);
CHECK_LT(axis, input_rank);
reduce_axes.push_back(axis);
}
return reduce_axes;
}();
const auto& reduce_axes = GetReduceAxes(reduce_op);

// no shardability if input reduced into one element.
if (reduce_axes.empty()) return false;
@@ -1747,11 +1813,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
};
return std::find_if(shardable_axes.begin(), shardable_axes.end(), Condition) != shardable_axes.end();
};
const bool keepdims = [&]{
const auto& attr_val = reduce_op->attributes().at("keep_dim");
CHECK(attr_val.isa<::pir::BoolAttribute>());
return attr_val.dyn_cast<::pir::BoolAttribute>();
}();
const bool keepdims = GetReduceOpKeepDims(reduce_op);
if (keepdims) {
const size_t output_rank = input_rank;
CHECK(!reduce_axes.empty());