[PIR+CINN] Fix Pool2d Variant Attribute for kernel_size (#60623)
* [PIR+CINN] Fix Pool2d Variant Attribute for kernel_size

* fix padding_size

* fix pooling_type
Aurelius84 authored Jan 9, 2024
1 parent 9982819 commit 640c759
Showing 2 changed files with 57 additions and 0 deletions.
10 changes: 10 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -18,6 +18,16 @@
    func : isclose
    data_type : x

- op : pool2d
  args : (Tensor x, int[] kernel_size, int[] stride_size, int[] padding_size, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
  output : Tensor(out)
  infer_meta :
    func : Pool2DInferMeta
    param : [x, kernel_size, stride_size, padding_size, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]
  kernel :
    func : pool2d
    param : [x, kernel_size, stride_size, padding_size, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]

- op : reduce_max
  args : (Tensor x, int64_t[] dim, bool keep_dim)
  output : Tensor(out)
47 changes: 47 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -225,6 +225,52 @@ class ReshapeOpPattern
  }
};

class Pool2dOpPattern
    : public pir::OpRewritePattern<paddle::dialect::Pool2dOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::Pool2dOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::Pool2dOp op,
                       pir::PatternRewriter &rewriter) const override {
    auto kernel_size_gen_op =
        op->operand_source(1).dyn_cast<pir::OpResult>().owner();

    if (auto full_op =
            kernel_size_gen_op->dyn_cast<paddle::dialect::FullIntArrayOp>()) {
      auto kernel_size_attr =
          full_op.attribute("value").dyn_cast<pir::ArrayAttribute>().AsVector();

      // kernel_size is generated by a full op;
      // get the attribute value from the full op.
      std::vector<pir::Attribute> kernel_size;
      for (size_t i = 0; i < static_cast<size_t>(kernel_size_attr.size());
           i++) {
        pir::Attribute attr = pir::Int32Attribute::get(
            pir::IrContext::Instance(),
            kernel_size_attr[i].dyn_cast<::pir::Int64Attribute>().data());
        kernel_size.push_back(attr);
      }
      auto attrs = op->attributes();
      attrs["kernel_size"] =
          pir::ArrayAttribute::get(pir::IrContext::Instance(), kernel_size);
      attrs["stride_size"] = attrs.at("strides");
      attrs["padding_size"] = attrs.at("paddings");
      attrs["pool_type"] = attrs.at("pooling_type");
      attrs.erase("strides");
      attrs.erase("paddings");
      attrs.erase("pooling_type");

      auto cinn_pool2d = rewriter.Build<cinn::dialect::Pool2dOp>(
          op->operand_source(0).dyn_cast<pir::OpResult>(), attrs);
      rewriter.ReplaceAllUsesWith(op.result(0), cinn_pool2d.result(0));
      rewriter.EraseOp(op);

      return true;
    }
    return false;
  }
};

class IsCloseOpPattern
    : public pir::OpRewritePattern<paddle::dialect::IscloseOp> {
 public:
@@ -613,6 +659,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
  ps.Add(MinOpPattern().Build(context));
  ps.Add(ProdOpPattern().Build(context));
  ps.Add<ReshapeOpPattern>(context);
  ps.Add<Pool2dOpPattern>(context);
  ps.Add<ConcatOpPattern>(context);
  ps.Add<SliceOpPattern>(context);
  ps.Add<PowOpPattern>(context);

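For readers following the pattern above: the core of the change is lifting the kernel_size operand (an int64 list produced by a FullIntArrayOp) into an int32 ArrayAttribute on the new cinn pool2d op. Below is a minimal, standalone C++ sketch of that narrowing step only; it does not use the PIR APIs, and the helper name LiftKernelSize and the range check are illustrative assumptions rather than code from this commit.

// Standalone sketch (not Paddle/PIR code): mirrors the element-wise
// int64 -> int32 narrowing that Pool2dOpPattern performs when it turns the
// FullIntArrayOp value into the cinn pool2d kernel_size attribute.
#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical helper, not in the commit: narrow an int64 kernel_size list
// to int32, rejecting values that would not fit.
std::vector<int32_t> LiftKernelSize(const std::vector<int64_t>& full_values) {
  std::vector<int32_t> kernel_size;
  kernel_size.reserve(full_values.size());
  for (int64_t v : full_values) {
    if (v < 0 || v > std::numeric_limits<int32_t>::max()) {
      throw std::runtime_error("kernel_size value out of int32 range");
    }
    kernel_size.push_back(static_cast<int32_t>(v));
  }
  return kernel_size;
}

int main() {
  // e.g. a 3x3 pooling window coming from full_int_array([3, 3])
  for (int32_t v : LiftKernelSize({3, 3})) std::cout << v << ' ';
  std::cout << '\n';  // prints: 3 3
}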