From a5f699b2c32483e5b6c7fce50fd5e2c9074f906c Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 8 Jan 2024 20:30:40 +0800 Subject: [PATCH 1/3] [PIR+CINN]Fix Pool2d Variant Attibute for kernel_size --- paddle/cinn/hlir/dialect/operator/ir/ops.yaml | 10 +++++ .../operator/transforms/pd_to_cinn_pass.cc | 43 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 2e42323782839..63bbc86dbeabd 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -18,6 +18,16 @@ func : isclose data_type : x +- op : pool2d + args : (Tensor x, int[] kernel_size, int[] stride_size, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + output : Tensor(out) + infer_meta : + func : Pool2DInferMeta + param : [x, kernel_size, stride_size, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] + kernel : + func : pool2d + param : [x, kernel_size, stride_size, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] + - op : reduce_max args : (Tensor x, int64_t[] dim, bool keep_dim) output : Tensor(out) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 352fd9fdde322..9534d6045fb7d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -225,6 +225,48 @@ class ReshapeOpPattern } }; +class Pool2dOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::Pool2dOp op, + pir::PatternRewriter &rewriter) const override { + auto kernel_size_gen_op = + op->operand_source(1).dyn_cast().owner(); + + if (auto full_op = + kernel_size_gen_op->dyn_cast()) { + auto kernel_size_attr = + full_op.attribute("value").dyn_cast().AsVector(); + + // kernel_size is generator by full op + // get attribute value from full op + std::vector kernel_size; + for (size_t i = 0; i < static_cast(kernel_size_attr.size()); + i++) { + pir::Attribute attr = pir::Int32Attribute::get( + pir::IrContext::Instance(), + kernel_size_attr[i].dyn_cast<::pir::Int64Attribute>().data()); + kernel_size.push_back(attr); + } + auto attrs = op->attributes(); + attrs["kernel_size"] = + pir::ArrayAttribute::get(pir::IrContext::Instance(), kernel_size); + attrs["strides"] = attrs.at("stride_size"); + attrs.erase("stride_size"); + + auto cinn_reshape = rewriter.Build( + op->operand_source(0).dyn_cast(), attrs); + rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0)); + rewriter.EraseOp(op); + + return true; + } + return false; + } +}; + class IsCloseOpPattern : public pir::OpRewritePattern { public: @@ -613,6 +655,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns( ps.Add(MinOpPattern().Build(context)); ps.Add(ProdOpPattern().Build(context)); ps.Add(context); + ps.Add(context); ps.Add(context); ps.Add(context); ps.Add(context); From a43a3d97cc5c9f6c2116484c84321ff5e4cf3541 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 8 Jan 2024 20:44:02 +0800 Subject: [PATCH 2/3] fix padding_size --- paddle/cinn/hlir/dialect/operator/ir/ops.yaml | 6 +++--- .../hlir/dialect/operator/transforms/pd_to_cinn_pass.cc | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 63bbc86dbeabd..d8ce6c02fed9f 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -19,14 +19,14 @@ data_type : x - op : pool2d - args : (Tensor x, int[] kernel_size, int[] stride_size, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + args : (Tensor x, int[] kernel_size, int[] stride_size, int[] padding_size, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(out) infer_meta : func : Pool2DInferMeta - param : [x, kernel_size, stride_size, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] + param : [x, kernel_size, stride_size, padding_size, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] kernel : func : pool2d - param : [x, kernel_size, stride_size, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] + param : [x, kernel_size, stride_size, padding_size, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm] - op : reduce_max args : (Tensor x, int64_t[] dim, bool keep_dim) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 9534d6045fb7d..2c97ed0c4c7cd 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -254,7 +254,10 @@ class Pool2dOpPattern attrs["kernel_size"] = pir::ArrayAttribute::get(pir::IrContext::Instance(), kernel_size); attrs["strides"] = attrs.at("stride_size"); - attrs.erase("stride_size"); + attrs["stride_size"] = attrs.at("strides"); + attrs["padding_size"] = attrs.at("paddings"); + attrs.erase("strides"); + attrs.erase("paddings"); auto cinn_reshape = rewriter.Build( op->operand_source(0).dyn_cast(), attrs); From b1ebbdf0ee0fb8ba5e4ab3fb850e41a89ff0aca3 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 8 Jan 2024 20:57:11 +0800 Subject: [PATCH 3/3] fix pooling_type --- .../cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 2c97ed0c4c7cd..66c6a7ddf8c59 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -253,11 +253,12 @@ class Pool2dOpPattern auto attrs = op->attributes(); attrs["kernel_size"] = pir::ArrayAttribute::get(pir::IrContext::Instance(), kernel_size); - attrs["strides"] = attrs.at("stride_size"); attrs["stride_size"] = attrs.at("strides"); attrs["padding_size"] = attrs.at("paddings"); + attrs["pool_type"] = attrs.at("pooling_type"); attrs.erase("strides"); attrs.erase("paddings"); + attrs.erase("pooling_type"); auto cinn_reshape = rewriter.Build( op->operand_source(0).dyn_cast(), attrs);