From 0f75a78e0c8d459bf8255768931c0abb6f54a781 Mon Sep 17 00:00:00 2001
From: Liang Shuhao
Date: Wed, 3 Apr 2024 11:20:00 +0800
Subject: [PATCH 1/3] [CINN] remove 0D to 1D pass

---
 .../operator/transforms/add_cinn_pass.cc      |   3 -
 .../group_merge/convert_0d_to_1d_pass.cc      | 272 ------------------
 .../group_merge/convert_0d_to_1d_pass.h       |  28 --
 .../policy/relative_judge_policy.h            |   5 +-
 .../policy/shardable_axes_base.cc             |   5 +-
 paddle/cinn/operator_fusion/utils.h           |   4 +
 .../pir/cinn/symbolic/test_cinn_0d_tensor.py  | 198 +++++++++++++
 7 files changed, 208 insertions(+), 307 deletions(-)
 delete mode 100644 paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc
 delete mode 100644 paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.h
 create mode 100644 test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py

diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
index e69b0e7d96bd1e..25090caacedea3 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -31,7 +31,6 @@
 #include "paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h"
-#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.h"
@@ -94,9 +93,7 @@ void ApplyCinnPreprocessPass(
   bool has_dynamic_shape = HasDynamicShape(*program);
 
   if (has_dynamic_shape) {
-    pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass());
     pass_manager->AddPass(pir::CreateShapeOptimizationPass());
-    pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass());
     pass_manager->AddPass(
         cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass());
     pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc
deleted file mode 100644
index 588312cc80114c..00000000000000
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc
+++ /dev/null
@@ -1,272 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.h"
-
-#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
-#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
-#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h"
-#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
-#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
-#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
-#include "paddle/pir/include/core/builtin_type.h"
-#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"
-
-namespace cinn {
-namespace dialect {
-namespace ir {
-
-namespace {
-
-class FullOpPattern : public pir::OpRewritePattern<paddle::dialect::FullOp> {
- public:
-  using pir::OpRewritePattern<paddle::dialect::FullOp>::OpRewritePattern;
-
-  bool Match(paddle::dialect::FullOp op) const override {
-    return op.attribute("shape")
-                   .dyn_cast<paddle::dialect::IntArrayAttribute>()
-                   .data()
-                   .size() == 0 &&
-           op.out().type().dyn_cast<pir::DenseTensorType>().dims().size() == 0;
-  }
-
-  void Rewrite(paddle::dialect::FullOp op,
-               pir::PatternRewriter& rewriter) const override {
-    float factor =
-        op->attribute("value").dyn_cast<::pir::FloatAttribute>().data();
-    phi::DataType dtype = op->attribute("dtype")
-                              .dyn_cast<paddle::dialect::DataTypeAttribute>()
-                              .data();
-    phi::Place place = op->attribute("place")
-                           .dyn_cast<paddle::dialect::PlaceAttribute>()
-                           .data();
-
-    auto full_op = rewriter.Build<paddle::dialect::FullOp>(
-        std::vector<int64_t>({1}), factor, dtype, place);
-    rewriter.ReplaceAllUsesWith(op.result(0), full_op.result(0));
-    rewriter.EraseOp(op);
-  }
-};
-
-class SliceOpPattern : public pir::OpRewritePattern<paddle::dialect::SliceOp> {
- public:
-  using pir::OpRewritePattern<paddle::dialect::SliceOp>::OpRewritePattern;
-
-  bool Match(paddle::dialect::SliceOp op) const override {
-    const auto& tensor_type =
-        op.result(0).type().dyn_cast<pir::DenseTensorType>();
-
-    return tensor_type.dims().size() == 0;
-  }
-
-  void Rewrite(paddle::dialect::SliceOp op,
-               pir::PatternRewriter& rewriter) const override {
-    std::vector<pir::Attribute> vec_dims;
-    pir::Attribute attr_dims =
-        pir::ArrayAttribute::get(pir::IrContext::Instance(), vec_dims);
-
-    op->set_attribute("decrease_axis", attr_dims);
-  }
-};
-
-class SumOpPattern : public pir::OpRewritePattern<paddle::dialect::SumOp> {
- public:
-  using pir::OpRewritePattern<paddle::dialect::SumOp>::OpRewritePattern;
-
-  bool Match(paddle::dialect::SumOp op) const override {
-    const auto& tensor_type =
-        op.result(0).type().dyn_cast<pir::DenseTensorType>();
-    return tensor_type.dims().size() == 0;
-  }
-
-  void Rewrite(paddle::dialect::SumOp op,
-               pir::PatternRewriter& rewriter) const override {
-    std::vector<int64_t> axis{};
-    const auto& dtype = op->attribute("dtype")
-                            .dyn_cast<paddle::dialect::DataTypeAttribute>()
-                            .data();
-    auto new_reduce_op = rewriter.Build<paddle::dialect::SumOp>(
-        op.operand_source(0), axis, dtype, /*keepdim=*/true);
-    auto reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
-        new_reduce_op.result(0), /*shape=*/std::vector<int64_t>({1}));
-    rewriter.ReplaceAllUsesWith(op.result(0), reshape_op.result(0));
-    rewriter.EraseOp(op);
-  }
-};
-
-pir::DenseTensorType Make1DTensorType(const pir::DenseTensorType& tensor_type) {
-  return pir::DenseTensorType::get(pir::IrContext::Instance(),
-                                   tensor_type.dtype(),
-                                   {1},
-                                   tensor_type.data_layout(),
-                                   tensor_type.lod(),
-                                   tensor_type.offset());
-}
-
-void ConvertValue0DTo1D(pir::Value operand) {
-  auto ConvertVectorType0DTo1D =
-      [](const pir::VectorType& vector_tensor_type) -> std::vector<pir::Type> {
-    std::vector<pir::Type> types;
-    for (std::size_t i = 0; i < vector_tensor_type.size(); ++i) {
-      CHECK(vector_tensor_type[i].isa<pir::DenseTensorType>());
-      const auto& dense_type =
-          vector_tensor_type[i].dyn_cast<pir::DenseTensorType>();
-      types.push_back(dense_type.dims().size() == 0
-                          ? Make1DTensorType(dense_type)
-                          : vector_tensor_type[i]);
-    }
-    return types;
-  };
-
-  if (const auto& tensor_type =
-          operand.type().dyn_cast<pir::DenseTensorType>()) {
-    if (tensor_type.dims().size() == 0) {
-      operand.set_type(Make1DTensorType(tensor_type));
-    }
-  } else if (const auto& vector_tensor_type =
-                 operand.type().dyn_cast<pir::VectorType>()) {
-    pir::Builder builder(pir::IrContext::Instance());
-    std::vector<pir::Type> inputs_type =
-        ConvertVectorType0DTo1D(vector_tensor_type);
-    operand.set_type(builder.vec_type(inputs_type));
-  } else {
-    VLOG(4) << "Unsupported operand type: " << operand.type();
-  }
-}
-
-class WhileOpPattern : public pir::OpRewritePattern<paddle::dialect::WhileOp> {
- public:
-  using pir::OpRewritePattern<paddle::dialect::WhileOp>::OpRewritePattern;
-
-  bool Match(paddle::dialect::WhileOp op) const override {
-    for (const auto& value : op.block_args()) {
-      if (const auto& tensor_type =
-              value.type().template dyn_cast<pir::DenseTensorType>()) {
-        if (tensor_type.dims().size() == 0) {
-          return true;
-        }
-      }
-    }
-    return false;
-  }
-
-  void Rewrite(paddle::dialect::WhileOp op,
-               pir::PatternRewriter& rewriter) const override {
-    for (pir::Value value : op.block_args()) {
-      ConvertValue0DTo1D(value);
-    }
-  }
-};
-
-class CombineOpPattern : public pir::OpRewritePattern<pir::CombineOp> {
- public:
-  using pir::OpRewritePattern<pir::CombineOp>::OpRewritePattern;
-
-  bool Match(pir::CombineOp op) const override {
-    for (std::size_t i = 1; i < op->operands().size(); ++i) {
-      if (op.operand_source(i).type() != op.operand_source(0).type()) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  void Rewrite(pir::CombineOp op,
-               pir::PatternRewriter& rewriter) const override {
-    pir::Builder builder(rewriter.ir_context());
-
-    const std::vector<pir::Type> inputs_type = [&]() {
-      std::vector<pir::Type> types;
-      for (auto value : op->operands_source()) {
-        types.push_back(value.type());
-      }
-      return types;
-    }();
-    op.result(0).set_type(builder.vec_type(inputs_type));
-  }
-};
-
-class Convert0DTo1DPass : public pir::Pass {
- public:
-  Convert0DTo1DPass() : pir::Pass("convert_0D_to_1D", 1) {}
-
-  bool Initialize(pir::IrContext* context) override {
-    pir::RewritePatternSet ps(context);
-    ps.Add<FullOpPattern>(context);
-    ps.Add<SliceOpPattern>(context);
-    ps.Add<SumOpPattern>(context);
-    ps.Add<CombineOpPattern>(context);
-    ps.Add<WhileOpPattern>(context);
-    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
-    return true;
-  }
-
-  void Run(pir::Operation* op) override {
-    for (uint32_t i = 0; i < op->num_regions(); ++i) {
-      ApplyPatternOnOperation(op->region(i));
-      for (const auto& block : op->region(i)) {
-        ConvertBlock0DTo1D(block);
-      }
-    }
-  }
-
-  void ApplyPatternOnOperation(pir::Region& region) {  // NOLINT
-    pir::GreedyRewriteConfig cfg;
-    cfg.use_top_down_traversal = true;
-    cfg.max_iterations = 10;
-    const auto& [_, num_rewrites] =
-        pir::ApplyPatternsGreedily(region, patterns_, cfg);
-    AddStatistics(num_rewrites);
-  }
-
-  bool CanApplyOn(pir::Operation* op) const override {
-    return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
-  }
-
-  void ConvertOperation0DTo1D(const pir::Operation& op) {  // NOLINT
-    for (std::size_t i = 0; i < op.num_operands(); ++i) {
-      ConvertValue0DTo1D(op.operand_source(i));
-    }
-    for (std::size_t i = 0; i < op.num_results(); ++i) {
-      ConvertValue0DTo1D(op.result(i));
-    }
-  }
-
-  void ConvertBlock0DTo1D(const pir::Block& block) {
-    for (auto& op : block) {
-      ConvertOperation0DTo1D(op);
-      for (std::size_t i = 0; i < op.num_regions(); ++i) {
-        ApplyPatternOnOperation(op.region(i));
-        for (auto& inner_block : op.region(i)) {
-          ConvertBlock0DTo1D(inner_block);
-        }
-      }
-    }
-  }
-
- private:
-  pir::FrozenRewritePatternSet patterns_;
-};
-
-}  // namespace
-
-std::unique_ptr<::pir::Pass> CreateConvert0DTo1DPass() {
-  return std::make_unique<Convert0DTo1DPass>();
-}
-
-}  // namespace ir
-}  // namespace dialect
-}  // namespace cinn
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.h
deleted file mode 100644
index b3cabacd6b261b..00000000000000
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <memory>
-#include "paddle/pir/include/pass/pass.h"
-
-namespace cinn {
-namespace dialect {
-namespace ir {
-
-// This is a helper pass for converting zero-dim tensor to one-dim tensor
-std::unique_ptr<::pir::Pass> CreateConvert0DTo1DPass();
-}  // namespace ir
-}  // namespace dialect
-}  // namespace cinn
diff --git a/paddle/cinn/operator_fusion/policy/relative_judge_policy.h b/paddle/cinn/operator_fusion/policy/relative_judge_policy.h
index ac7d9037d24f56..ca611d5895266c 100644
--- a/paddle/cinn/operator_fusion/policy/relative_judge_policy.h
+++ b/paddle/cinn/operator_fusion/policy/relative_judge_policy.h
@@ -155,8 +155,9 @@ static ValueDimRelation CreateOpRelativenessForReduce(pir::Operation* op) {
   int out_idx = 0;
   bool keep_dim = GetReduceOpKeepDims(op);
   for (int i = 0; i < input_rank; i++) {
-    if (std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) !=
-        reduce_axis_idx.end()) {
+    if (!reduce_axis_idx.empty() &&
+        std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) ==
+            reduce_axis_idx.end()) {
       res[ValueDim(op->operand_source(0), i)]
          [ValueDim(op->result(0), out_idx)] = true;
       out_idx += 1;
diff --git a/paddle/cinn/operator_fusion/policy/shardable_axes_base.cc b/paddle/cinn/operator_fusion/policy/shardable_axes_base.cc
index a9876ea0b82710..e86a2be77b06eb 100644
--- a/paddle/cinn/operator_fusion/policy/shardable_axes_base.cc
+++ b/paddle/cinn/operator_fusion/policy/shardable_axes_base.cc
@@ -103,8 +103,9 @@ ShardableAxesSignature CreateSignatureForReduce(pir::Operation* reduce_op) {
   auto output_axes = std::vector<std::string>();
 
   for (int i = 0; i < input_rank; i++) {
-    if (std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) !=
-        reduce_axis_idx.end()) {
+    if (reduce_axis_idx.empty() ||
+        std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) !=
+            reduce_axis_idx.end()) {
       if (keep_dim) {
         output_axes.emplace_back(ShardableAxesInfoManager::GetUniqueName());
       }  // else do nothing
diff --git a/paddle/cinn/operator_fusion/utils.h b/paddle/cinn/operator_fusion/utils.h
index 696836fe2a7804..1ef81010795724 100644
--- a/paddle/cinn/operator_fusion/utils.h
+++ b/paddle/cinn/operator_fusion/utils.h
@@ -50,6 +50,10 @@ static std::vector<int64_t> GetReduceAxisIdx(pir::Operation* reduce_op) {
   CHECK(attr_val.isa<::pir::ArrayAttribute>());
   const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
   std::vector<int64_t> reduce_axis_idx;
+  if (input_rank == 0) {
"GetReduceAxisIdx: "; + return reduce_axis_idx; + } for (int i = 0; i < axis_attr.size(); ++i) { int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data(); if (axis < 0) { diff --git a/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py b/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py new file mode 100644 index 00000000000000..8320abbb3c75eb --- /dev/null +++ b/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from os.path import dirname + +import numpy as np + +sys.path.append(dirname(dirname(__file__))) + +import unittest + +import utils + +import paddle +import paddle.nn.functional as F +from paddle.static import InputSpec + + +class TestFunc(unittest.TestCase): + """ + Test Pir API + @to_static + CINN. + """ + + def setUp(self): + paddle.seed(2024) + self.prepare_data() + self.prepare_func() + + def prepare_data(self): + pass + + def prepare_func(self): + pass + + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + + def check_output_shape(self, out): + pass + + def eval_symbolic(self, use_cinn): + paddle.seed(2024) + func = utils.apply_to_static(self.func, use_cinn, self.input_spec) + func.eval() + out = func(*self.input) + if use_cinn: + self.check_jit_kernel_info(func) + self.check_output_shape(out) + return out + + def test_eval_symbolic(self): + if type(self) is TestFunc: + return + cinn_out = self.eval_symbolic(use_cinn=True) + dy_out = self.eval_symbolic(use_cinn=False) + np.testing.assert_allclose( + cinn_out.numpy(), dy_out.numpy(), rtol=1e-6, atol=1e-3 + ) + + +class TestReduce3Dto0D(TestFunc): + def prepare_data(self): + self.input_spec = [InputSpec(shape=[8, None, 64], dtype='float32')] + self.input = [paddle.randn([8, 128, 64])] + + def prepare_func(self): + def func(x): + return paddle.sum(x) + + self.func = func + + def check_output_shape(self, out): + np.testing.assert_equal(out.shape, ()) + + +class TestReduce1Dto0D(TestReduce3Dto0D): + def prepare_data(self): + self.input_spec = [InputSpec(shape=[None], dtype='float32')] + self.input = [paddle.randn([2048])] + + +class TestReduce0Dto0D(TestReduce3Dto0D): + def prepare_data(self): + self.input_spec = [InputSpec(shape=[], dtype='float32')] + self.input = [paddle.randn([])] + + +class TestReduce3Dto0DThenRelu(TestReduce3Dto0D): + def prepare_func(self): + def func(x): + return F.relu(paddle.sum(x)) + + self.func = func + + +class TestReduce3Dto0DThenAdd0D(TestReduce3Dto0D): + def prepare_data(self): + self.input_spec = [ + InputSpec(shape=[8, None, 64], dtype='float32'), + InputSpec(shape=[], dtype='float32'), + ] + self.input = [paddle.randn([8, 128, 64]), paddle.randn([])] + + def prepare_func(self): + def func(x, y): + return paddle.sum(x) + y + + self.func = func + + +class TestAdd0Dto3D(TestFunc): + def prepare_data(self): + self.input_spec = [ + 
+            InputSpec(shape=[], dtype='float32'),
+            InputSpec(shape=[8, 128, 64], dtype='float32'),
+        ]
+        self.input = [paddle.randn([]), paddle.randn([8, 128, 64])]
+
+    def prepare_func(self):
+        def func(x, y):
+            return x + y
+
+        self.func = func
+
+
+class TestAdd0Dto0D(TestAdd0Dto3D):
+    def prepare_data(self):
+        self.input_spec = [
+            InputSpec(shape=[], dtype='float32'),
+            InputSpec(shape=[], dtype='float32'),
+        ]
+        self.input = [paddle.randn([]), paddle.randn([])]
+
+    def check_output_shape(self, out):
+        np.testing.assert_equal(out.shape, ())
+
+
+class TestSoftmax0D(TestReduce0Dto0D):
+    def prepare_func(self):
+        def func(x):
+            x = paddle.exp(x)
+            d = paddle.sum(x, axis=-1, keepdim=True)
+            x = x / d
+            return x
+
+        self.func = func
+
+
+class TestReshape0Dto3D(TestAdd0Dto3D):
+    def prepare_func(self):
+        def func(x, y):
+            return paddle.reshape(x, [1, 1, 1]) + y
+
+        self.func = func
+
+
+class TestReshape0Dto0D(TestAdd0Dto0D):
+    def prepare_func(self):
+        def func(x, y):
+            return paddle.reshape(x, []) + y
+
+        self.func = func
+
+
+class TestExpand0Dto3D(TestFunc):
+    def prepare_data(self):
+        self.input_spec = [InputSpec(shape=[], dtype='float32')]
+        self.input = [paddle.randn([])]
+
+    def prepare_func(self):
+        def func(x):
+            return paddle.expand(x, [8, 128, 64])
+
+        self.func = func
+
+
+class TestExpand0Dto0D(TestAdd0Dto0D):
+    def prepare_func(self):
+        def func(x, y):
+            return paddle.expand(x, []) + y
+
+        self.func = func
+
+
+if __name__ == '__main__':
+    unittest.main()

From 7d5a7459d4bb22a16949348a62ac1065206c5e7b Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Thu, 18 Apr 2024 11:03:30 +0800
Subject: [PATCH 2/3] Update paddle/cinn/operator_fusion/utils.h

---
 paddle/cinn/operator_fusion/utils.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/cinn/operator_fusion/utils.h b/paddle/cinn/operator_fusion/utils.h
index 1ef81010795724..e9eb0806d60299 100644
--- a/paddle/cinn/operator_fusion/utils.h
+++ b/paddle/cinn/operator_fusion/utils.h
@@ -51,7 +51,7 @@ static std::vector<int64_t> GetReduceAxisIdx(pir::Operation* reduce_op) {
   const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
   std::vector<int64_t> reduce_axis_idx;
   if (input_rank == 0) {
-    VLOG(4) << "GetReduceAxisIdx: ";
+    VLOG(4) << "Reduce op has 0D Tensor input, return empty reduce_axis";
     return reduce_axis_idx;
   }
   for (int i = 0; i < axis_attr.size(); ++i) {

From a2bc195b5d917c1700a5669beec6ecc3ec491ed4 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Thu, 18 Apr 2024 11:03:48 +0800
Subject: [PATCH 3/3] Update test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py

---
 test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py b/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py
index 8320abbb3c75eb..d022b9f660d0a9 100644
--- a/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py
+++ b/test/ir/pir/cinn/symbolic/test_cinn_0d_tensor.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
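
---
Not part of the patch series: a minimal, self-contained sketch of the 0D-reduce
path the new tests exercise, for trying the behavior outside the test harness.
It uses only public Paddle APIs (paddle.jit.to_static plus
BuildStrategy.build_cinn_pass, which is essentially what the test-only helper
utils.apply_to_static wraps). It assumes a CINN-enabled build; depending on the
build, extra runtime flags (e.g. FLAGS_enable_pir_api=1) may also be needed.

import numpy as np
import paddle
from paddle.static import InputSpec


def reduce_to_0d(x):
    # Full reduction: a [8, None, 64] input collapses to a 0D tensor, shape [].
    return paddle.sum(x)


# Route the captured graph through CINN instead of the default executor.
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True

static_fn = paddle.jit.to_static(
    reduce_to_0d,
    input_spec=[InputSpec(shape=[8, None, 64], dtype='float32')],
    build_strategy=build_strategy,
    full_graph=True,
)

x = paddle.randn([8, 128, 64])
out = static_fn(x)
assert out.shape == []  # with the 0D->1D pass removed, the output stays 0D
np.testing.assert_allclose(
    out.numpy(), paddle.sum(x).numpy(), rtol=1e-6, atol=1e-3
)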