Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIN+PIR]Fix SplitOpPattern Bug in pd_to_cinn_pass #60669

Merged
merged 4 commits into from
Jan 10, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 87 additions & 102 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {

bool MatchAndRewrite(paddle::dialect::ScaleOp op,
pir::PatternRewriter &rewriter) const override {
auto scale_factor_gen_op =
op->operand_source(1).dyn_cast<pir::OpResult>().owner();
auto scale_factor_gen_op = op->operand_source(1).defining_op();

if (auto full_op =
scale_factor_gen_op->dyn_cast<paddle::dialect::FullOp>()) {
Expand Down Expand Up @@ -190,8 +189,7 @@ class ReshapeOpPattern

bool MatchAndRewrite(paddle::dialect::ReshapeOp op,
pir::PatternRewriter &rewriter) const override {
auto scale_factor_gen_op =
op->operand_source(1).dyn_cast<pir::OpResult>().owner();
auto scale_factor_gen_op = op->operand_source(1).defining_op();

if (auto full_op =
scale_factor_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>()) {
Expand Down Expand Up @@ -232,8 +230,7 @@ class Pool2dOpPattern

bool MatchAndRewrite(paddle::dialect::Pool2dOp op,
pir::PatternRewriter &rewriter) const override {
auto kernel_size_gen_op =
op->operand_source(1).dyn_cast<pir::OpResult>().owner();
auto kernel_size_gen_op = op->operand_source(1).defining_op();

if (auto full_op =
kernel_size_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>()) {
Expand Down Expand Up @@ -279,13 +276,11 @@ class IsCloseOpPattern
bool MatchAndRewrite(paddle::dialect::IscloseOp op,
pir::PatternRewriter &rewriter) const override {
auto rtol_op = op->operand_source(2)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullOp>();

auto atol_op = op->operand_source(3)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullOp>();

if (rtol_op && atol_op) {
Expand Down Expand Up @@ -318,13 +313,11 @@ class SliceOpPattern : public pir::OpRewritePattern<paddle::dialect::SliceOp> {
bool MatchAndRewrite(paddle::dialect::SliceOp op,
pir::PatternRewriter &rewriter) const override {
auto start_gen_op = op->operand_source(1)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>();

auto end_gen_op = op->operand_source(2)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>();

if (start_gen_op && end_gen_op) {
Expand Down Expand Up @@ -360,18 +353,13 @@ class ConcatOpPattern

bool MatchAndRewrite(paddle::dialect::ConcatOp op,
pir::PatternRewriter &rewriter) const override {
auto axis_gen_op = op->operand_source(1).dyn_cast<pir::OpResult>().owner();
auto axis_gen_op = op->operand_source(1).defining_op();
if (auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>()) {
int axis = phi::Scalar(full_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
.data())
.to<int>();
int axis = static_cast<int>(
full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data());

auto input_ops = op->operand_source(0)
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<pir::CombineOp>()
.inputs();
auto input_ops =
op->operand(0).owner()->dyn_cast<pir::CombineOp>().inputs();

auto cinn_concat =
rewriter.Build<cinn::dialect::ConcatOp>(input_ops, axis);
Expand Down Expand Up @@ -413,12 +401,10 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
bool MatchAndRewrite(paddle::dialect::SplitOp op,
pir::PatternRewriter &rewriter) const override {
auto sections_gen_op = op->operand_source(1)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>();
auto axis_gen_op = op->operand_source(2)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullOp>();
if (sections_gen_op && axis_gen_op) {
auto section_attr = sections_gen_op.attribute("value")
Expand All @@ -432,11 +418,9 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
section_attr[i].dyn_cast<::pir::Int64Attribute>().data());
}
}

int axis = phi::Scalar(axis_gen_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
.data())
.to<int>();
int axis = static_cast<int>(axis_gen_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里似乎有一个假设,这里的full出来的axis值从float到int的转型一定是精度安全的

.data());

auto input_ele = op->operand_source(0)
.type()
Expand All @@ -448,15 +432,76 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
auto cinn_split = rewriter.Build<cinn::dialect::SplitOp>(
op->operand_source(0), vec_sections, axis);

auto build_split =
op->result(0).first_use().owner()->dyn_cast<::pir::SplitOp>();
auto orig_out = op.result(0);
for (auto it = orig_out.use_begin(); it != orig_out.use_end();) {
auto slice_op = (it++)->owner();
CHECK(slice_op->isa<::pir::SliceOp>());
int index = slice_op->dyn_cast<::pir::SliceOp>()
.attribute("index")
.dyn_cast<::pir::Int32Attribute>()
.data();
rewriter.ReplaceAllUsesWith(slice_op->result(0),
cinn_split.result(index));
rewriter.EraseOp(slice_op);
}
rewriter.EraseOp(op);

return true;
}
return false;
}
};

class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::SplitWithNumOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op,
pir::PatternRewriter &rewriter) const override {
auto axis_gen_op = op->operand_source(1).defining_op();
if (auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>()) {
int axis = static_cast<int>(
full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data());

for (size_t i = 0; i < build_split->num_results(); ++i) {
rewriter.ReplaceAllUsesWith(build_split->result(i),
cinn_split.result(i));
auto input_ele = op->operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>();
if (axis < 0) {
axis += input_ele.dims().size();
}
std::vector<int> sections;

auto split_dim = input_ele.dims()[axis];

auto split_num =
op->attribute("num").dyn_cast<::pir::Int32Attribute>().data();
auto part_ele = (split_dim + split_num - 1) / split_num;

rewriter.EraseOp(build_split);
int total_split_num = 0;
for (int i = 0; i < split_num - 1; ++i) {
sections.push_back(part_ele);
total_split_num += part_ele;
}

sections.push_back(split_dim - total_split_num);

auto cinn_split = rewriter.Build<cinn::dialect::SplitOp>(
op->operand_source(0), sections, axis);

auto orig_out = op.result(0);
for (auto it = orig_out.use_begin(); it != orig_out.use_end();) {
auto slice_op = (it++)->owner();
CHECK(slice_op->isa<::pir::SliceOp>());
int index = slice_op->dyn_cast<::pir::SliceOp>()
.attribute("index")
.dyn_cast<::pir::Int32Attribute>()
.data();
rewriter.ReplaceAllUsesWith(slice_op->result(0),
cinn_split.result(index));
rewriter.EraseOp(slice_op);
}

rewriter.EraseOp(op);

Expand All @@ -472,10 +517,8 @@ class AddNOpPattern : public pir::OpRewritePattern<paddle::dialect::AddNOp> {

bool MatchAndRewrite(paddle::dialect::AddNOp op,
pir::PatternRewriter &rewriter) const override {
auto combine_op = op->operand_source(0)
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<pir::CombineOp>();
auto combine_op =
op->operand_source(0).defining_op()->dyn_cast<pir::CombineOp>();
auto input_ops = combine_op.inputs();

auto tmp = input_ops[0];
Expand All @@ -501,8 +544,7 @@ class ExpandOpPattern
bool MatchAndRewrite(paddle::dialect::ExpandOp op,
pir::PatternRewriter &rewriter) const override {
auto out_shape_gen_op = op->operand_source(1)
.dyn_cast<pir::OpResult>()
.owner()
.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>();

if (out_shape_gen_op) {
Expand Down Expand Up @@ -541,63 +583,6 @@ class ExpandOpPattern
}
};

class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::SplitWithNumOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op,
pir::PatternRewriter &rewriter) const override {
auto axis_gen_op = op->operand_source(1).dyn_cast<pir::OpResult>().owner();
if (auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>()) {
int axis = phi::Scalar(full_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
.data())
.to<int>();

auto input_ele = op->operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>();
if (axis < 0) {
axis += input_ele.dims().size();
}
std::vector<int> sections;

auto split_dim = input_ele.dims()[axis];

auto split_num =
op->attribute("num").dyn_cast<::pir::Int32Attribute>().data();
auto part_ele = (split_dim + split_num - 1) / split_num;

int total_split_num = 0;
for (int i = 0; i < split_num - 1; ++i) {
sections.push_back(part_ele);
total_split_num += part_ele;
}

sections.push_back(split_dim - total_split_num);

auto cinn_split = rewriter.Build<cinn::dialect::SplitOp>(
op->operand_source(0), sections, axis);

int index = 0;
auto orig_out = op.result(0);
for (auto it = orig_out.use_begin(); it != orig_out.use_end();) {
auto split_op = (it++)->owner();
rewriter.ReplaceAllUsesWith(split_op->result(0),
cinn_split.result(index++));
rewriter.EraseOp(split_op);
}

rewriter.EraseOp(op);

return true;
}
return false;
}
};

class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
Expand Down