From 8f598d9fba34e39169ba96fd4828ce874dc133ce Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Wed, 27 Dec 2023 16:27:57 +0800 Subject: [PATCH 001/142] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.75-76?= =?UTF-8?q?=E3=80=91Migrate=20some=20ops=20into=20pir=20(#59627)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/nn/functional/common.py | 2 +- python/paddle/nn/functional/loss.py | 10 +++++++--- test/legacy_test/test_fold_op.py | 4 ++-- test/legacy_test/test_sigmoid_focal_loss.py | 2 ++ 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 8988e89111c096..1fb678efd0b13f 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -2401,7 +2401,7 @@ def _is_list_or_turple_(data): "of 2 or 4 integers" ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.fold( x, output_sizes, kernel_sizes, strides, paddings, dilations ) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index d1611106b7c52f..809056cf39aafe 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3176,7 +3176,7 @@ def sigmoid_focal_loss( ) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): place = _current_expected_place() one = _C_ops.full(logit.shape, 1.0, logit.dtype, place) @@ -3193,7 +3193,10 @@ def sigmoid_focal_loss( ), ) - alpha = base.dygraph.base.to_variable([alpha], dtype=loss.dtype) + if in_dynamic_mode(): + alpha = base.dygraph.base.to_variable([alpha], dtype=loss.dtype) + else: + alpha = paddle.to_tensor(alpha, dtype=loss.dtype) alpha_t = _C_ops.add( _C_ops.multiply(alpha, label), _C_ops.multiply( @@ -3202,7 +3205,8 @@ def sigmoid_focal_loss( ) loss = _C_ops.multiply(alpha_t, loss) - gamma = base.dygraph.base.to_variable([gamma], dtype=loss.dtype) + if in_dynamic_mode(): + gamma = base.dygraph.base.to_variable([gamma], dtype=loss.dtype) gamma_t = _C_ops.pow(_C_ops.subtract(one, p_t), gamma) loss = _C_ops.multiply(gamma_t, loss) diff --git a/test/legacy_test/test_fold_op.py b/test/legacy_test/test_fold_op.py index 8e4ab1971b7ae6..18aa7886bff7bd 100644 --- a/test/legacy_test/test_fold_op.py +++ b/test/legacy_test/test_fold_op.py @@ -133,10 +133,10 @@ def setUp(self): self.set_data() def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(['X'], 'Y', check_pir=True) class TestFold_Complex64(TestFoldOp): diff --git a/test/legacy_test/test_sigmoid_focal_loss.py b/test/legacy_test/test_sigmoid_focal_loss.py index b151d4c56a21e1..9142375f376945 100644 --- a/test/legacy_test/test_sigmoid_focal_loss.py +++ b/test/legacy_test/test_sigmoid_focal_loss.py @@ -18,6 +18,7 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api def call_sfl_functional( @@ -119,6 +120,7 @@ def calc_sigmoid_focal_loss( class TestSigmoidFocalLoss(unittest.TestCase): + @test_with_pir_api def test_SigmoidFocalLoss(self): logit_np = np.random.uniform(0.1, 0.8, size=(2, 3, 4, 10)).astype( np.float64 From c79c631b7d728baac87e312dfa5a7a9694790229 Mon Sep 17 00:00:00 2001 From: enzodechine Date: Wed, 27 Dec 2023 17:37:25 +0800 Subject: [PATCH 002/142] bind bf16 strided_slice&grad (#60382) --- paddle/phi/kernels/strided_slice_grad_kernel.cc | 3 ++- paddle/phi/kernels/strided_slice_kernel.cc | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.cc b/paddle/phi/kernels/strided_slice_grad_kernel.cc index 7582f751bf16a5..8c5c90783133c9 100644 --- a/paddle/phi/kernels/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/strided_slice_grad_kernel.cc @@ -78,5 +78,6 @@ PD_REGISTER_KERNEL(strided_slice_grad, int, int16_t, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} #endif diff --git a/paddle/phi/kernels/strided_slice_kernel.cc b/paddle/phi/kernels/strided_slice_kernel.cc index 0852cc8830e2c0..2bc9325de1ee7f 100644 --- a/paddle/phi/kernels/strided_slice_kernel.cc +++ b/paddle/phi/kernels/strided_slice_kernel.cc @@ -76,5 +76,6 @@ PD_REGISTER_KERNEL(strided_slice, int, int16_t, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} #endif From fd9e67c14b6f45d5bb9c2e754d89ae0b3b80b88e Mon Sep 17 00:00:00 2001 From: enzodechine Date: Wed, 27 Dec 2023 17:37:44 +0800 Subject: [PATCH 003/142] [XPU]support bf16 elementwise_sub and div (#60386) * support bf16 elementwise_sub and div * support bf16 elementwise_sub and div --- paddle/phi/backends/xpu/xpu3_op_list.cc | 10 ++++++++-- .../phi/kernels/xpu/elementwise_divide_grad_kernel.cc | 1 + paddle/phi/kernels/xpu/elementwise_divide_kernel.cc | 1 + .../kernels/xpu/elementwise_subtract_grad_kernel.cc | 1 + paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc | 1 + 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 623f63444c3084..016e5ef917af57 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -266,10 +266,13 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT64, phi::DataType::INT32})}, {"elementwise_div_grad", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"elementwise_div", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT64, phi::DataType::INT32})}, {"elementwise_floordiv", @@ -295,10 +298,13 @@ XPUOpMap& get_kl3_ops() { {"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_sub_grad", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::INT32, phi::DataType::INT64})}, {"elementwise_mod", diff --git a/paddle/phi/kernels/xpu/elementwise_divide_grad_kernel.cc b/paddle/phi/kernels/xpu/elementwise_divide_grad_kernel.cc index 3b20874b5f312e..eeba11974c3041 100644 --- a/paddle/phi/kernels/xpu/elementwise_divide_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_divide_grad_kernel.cc @@ -59,4 +59,5 @@ PD_REGISTER_KERNEL(divide_grad, ALL_LAYOUT, phi::DivideGradKernel, phi::dtype::float16, + phi::dtype::bfloat16, float) {} diff --git a/paddle/phi/kernels/xpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/xpu/elementwise_divide_kernel.cc index 2f608879cd7e03..41f20b061fae67 100644 --- a/paddle/phi/kernels/xpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_divide_kernel.cc @@ -50,5 +50,6 @@ PD_REGISTER_KERNEL(divide, phi::DivideKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) {} diff --git a/paddle/phi/kernels/xpu/elementwise_subtract_grad_kernel.cc b/paddle/phi/kernels/xpu/elementwise_subtract_grad_kernel.cc index d22b369619d40d..f61a5f5de94109 100644 --- a/paddle/phi/kernels/xpu/elementwise_subtract_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_subtract_grad_kernel.cc @@ -53,4 +53,5 @@ PD_REGISTER_KERNEL(subtract_grad, ALL_LAYOUT, phi::SubtractGradKernel, phi::dtype::float16, + phi::dtype::bfloat16, float) {} diff --git a/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc b/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc index a3252b7534dbf4..8ba3c47a456e9f 100644 --- a/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc @@ -44,5 +44,6 @@ PD_REGISTER_KERNEL(subtract, phi::SubtractKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int, int64_t) {} From 363a11b3043fc4db9fb9b4e25b91c71218b54b61 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 27 Dec 2023 18:34:08 +0800 Subject: [PATCH 004/142] [PIR] inplace pass support sub block (#60369) * inplace pass support sub block * update * update * fix typo --- paddle/fluid/pir/transforms/inplace_pass.cc | 125 ++++++++++-------- .../transforms/transform_general_functions.cc | 58 +++++++- .../transforms/transform_general_functions.h | 30 ++++- paddle/fluid/pybind/control_flow_api.cc | 50 +------ 4 files changed, 150 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index eaaaeba7b28b64..b836617321f8cf 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -28,6 +28,7 @@ #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/phi/core/flags.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/operation.h" @@ -36,17 +37,17 @@ PHI_DECLARE_string(ir_inplace_kernel_blacklist); -namespace details { +namespace { using TensorType = paddle::dialect::AllocatedDenseTensorType; -static std::unordered_set ignore_shape_check_ops = { +std::unordered_set IgnoreShapeCheckOps = { paddle::dialect::ReshapeOp::name(), paddle::dialect::SqueezeOp::name(), paddle::dialect::UnsqueezeOp::name(), }; -static std::unordered_set relax_shape_check_ops = { +std::unordered_set RelaxShapeCheckOps = { paddle::dialect::ReshapeGradOp::name(), paddle::dialect::AddGradOp::name(), }; @@ -54,7 +55,7 @@ static std::unordered_set relax_shape_check_ops = { // NOTE(zhangbo): Which kind of value can be deleted? // (1) Value's type needs to be AllocatedDenseTensorType or // AllocatedSelectedRowsType; (2) Value's is not persisable. -static bool CanBeDeleted(pir::Value value) { +bool CanBeDeleted(pir::Value value) { if (!value.type()) { return false; } @@ -66,10 +67,10 @@ static bool CanBeDeleted(pir::Value value) { return !(persist_attr && persist_attr.data()); } -static bool CanDoInplace(const std::unordered_set& eager_dels, - pir::Value input, - pir::Value output, - const std::string& op_name) { +bool CanDoInplace(const std::unordered_set& eager_dels, + pir::Value input, + pir::Value output, + const std::string& op_name) { if (!input.type() || !output.type()) { return false; } @@ -83,7 +84,7 @@ static bool CanDoInplace(const std::unordered_set& eager_dels, return false; } - if (details::ignore_shape_check_ops.count(op_name) > 0 && + if (IgnoreShapeCheckOps.count(op_name) > 0 && eager_dels.count(input) != 0) { VLOG(9) << " -- reshape, squeeze, unsqueeze do not need check shape, " "can do inplace"; @@ -141,7 +142,7 @@ static bool CanDoInplace(const std::unordered_set& eager_dels, return in_numel == out_numel; }; bool equal = false; - bool relax = (details::relax_shape_check_ops.count(op_name) > 0); + bool relax = (RelaxShapeCheckOps.count(op_name) > 0); if (relax) { equal = is_numel_euqal_loose_version(input_alloc_tensor_type, output_alloc_tensor_type); @@ -164,7 +165,7 @@ static bool CanDoInplace(const std::unordered_set& eager_dels, return true; } -static bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { +bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) != 0) { VLOG(8) << op->name() @@ -194,9 +195,9 @@ static bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { // NOTE(zhangbo): pd_op.feed's output and pd_op.fetch's input can not be eager // deleted. -static std::unordered_set GetSkipDeletionValues(pir::Block* block) { +std::unordered_set GetSkipDeletionValues(const pir::Block& block) { std::unordered_set skip_dels; - for (auto& op : *block) { + for (auto& op : block) { if (op.dialect()->name().compare(paddle::dialect::KernelDialect::name()) != 0) { continue; @@ -223,11 +224,11 @@ static std::unordered_set GetSkipDeletionValues(pir::Block* block) { // NOTE(zhangbo): For inplace Pass, currently only the kernel_dialect operator // is supported. Therefore, this function only returns the values in the // kernel_dialect operator that can be eager deleted. -static void GetEagerDelValueOfOp( - pir::Block* block, +void GetEagerDelValueOfOp( + const pir::Block& block, const std::unordered_set& skip_dels, std::unordered_map* del_value_2_op) { - for (auto& op : *block) { + for (auto& op : block) { std::string upper_op_name = op.name(); if (op.dialect()->name().compare(paddle::dialect::KernelDialect::name()) == 0) { @@ -259,18 +260,19 @@ static void GetEagerDelValueOfOp( } } - if (op.isa()) { - auto if_op = op.dyn_cast(); - GetEagerDelValueOfOp(&if_op.true_block(), skip_dels, del_value_2_op); - VLOG(8) << "GetEagerDelValueOfOp for IfOp true block"; - GetEagerDelValueOfOp(&if_op.false_block(), skip_dels, del_value_2_op); - VLOG(8) << "GetEagerDelValueOfOp for IfOp false block"; + if (op.num_regions() > 0) { + for (size_t i = 0; i < op.num_regions(); ++i) { + for (const auto& inner_block : op.region(i)) { + GetEagerDelValueOfOp(inner_block, skip_dels, del_value_2_op); + } + } + VLOG(8) << "GetEagerDelValueOfOp for sub block"; } } } -static std::unordered_map> -GetEagerDeletionValues(pir::Block* block) { +std::unordered_map> +GetEagerDeletionValues(const pir::Block& block) { std::unordered_set skip_dels = GetSkipDeletionValues(block); std::unordered_map del_value_2_op; @@ -285,8 +287,8 @@ GetEagerDeletionValues(pir::Block* block) { return eager_dels; } -static std::unordered_map GetInplaceOps( - pir::Block* block) { +std::unordered_map GetInplaceOps( + const pir::Block& block) { const auto eager_dels = GetEagerDeletionValues(block); std::unordered_map inplace_ops; @@ -295,7 +297,7 @@ static std::unordered_map GetInplaceOps( std::unordered_set reused_input_values; std::unordered_set reused_output_values; - for (auto& op : *block) { + for (auto& op : block) { for (size_t i = 0; i < op.num_operands(); ++i) { visited_values.insert(op.operand_source(i)); } @@ -391,6 +393,8 @@ static std::unordered_map GetInplaceOps( std::unordered_map inplace_out_2_in = upper_inplace_op_info_parser.GetInplaceIdMap(); + const auto used_external_values = GetUsedExternalValue(block); + bool can_do_inplace = true; for (auto& kv : inplace_out_2_in) { uint32_t out_slot = kv.first; @@ -403,12 +407,19 @@ static std::unordered_map GetInplaceOps( (visited_values.count(op.result(out_slot)) > 0) || (!CanBeDeleted(op.result(out_slot))) || (reused_input_values.count(op.operand_source(in_slot)) > 0) || - (reused_output_values.count(op.result(out_slot)) > 0)) { + (reused_output_values.count(op.result(out_slot)) > 0) || + (std::find(used_external_values.begin(), + used_external_values.end(), + op.operand_source(in_slot)) != + used_external_values.end()) || + (std::find(used_external_values.begin(), + used_external_values.end(), + op.result(out_slot)) != used_external_values.end())) { can_do_inplace = false; VLOG(6) << upper_op_name << "'s value has been visited or reused by other inplace op, " "so that can't do inplace when setting relax to :" - << (details::relax_shape_check_ops.count(upper_op_name) > 0); + << (RelaxShapeCheckOps.count(upper_op_name) > 0); VLOG_IF( 8, ((in_slot < op.num_operands()) && (out_slot < op.num_results()))) << " -- operand " << in_slot << " and result " << out_slot @@ -450,45 +461,43 @@ static std::unordered_map GetInplaceOps( } return inplace_ops; } -} // namespace details +} // namespace class InplacePass : public pir::Pass { public: InplacePass() : pir::Pass("inplace_pass", 3) {} void Run(pir::Operation* op) override { - auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "inplace_pass should run on module op."); - auto& block = module_op.block(); - - auto inplace_ops = details::GetInplaceOps(&block); int64_t num_rewrites_{0}; - for (auto kv : inplace_ops) { - VLOG(6) << "Do inplace for: " - << kv.first->attributes() - .at("op_name") - .dyn_cast() - .AsString(); - pir::Block::Iterator insert_pos = - std::find(block.begin(), block.end(), *kv.first); - IR_ENFORCE(insert_pos != block.end(), - "Operator %s not found in block.", - kv.first->name()); - - kv.first->set_attribute( - "op_name", - pir::StrAttribute::get(pir::IrContext::Instance(), kv.second)); - kv.first->set_attribute( - "is_inplace", - pir::BoolAttribute::get(pir::IrContext::Instance(), true)); - num_rewrites_++; + for (size_t i = 0; i < op->num_regions(); ++i) { + auto& region = op->region(i); + for (auto& block : region) { + auto inplace_ops = GetInplaceOps(block); + + for (const auto& kv : inplace_ops) { + VLOG(6) << "Do inplace for: " + << kv.first->attributes() + .at("op_name") + .dyn_cast() + .AsString(); + pir::Block::Iterator insert_pos = + std::find(block.begin(), block.end(), *kv.first); + IR_ENFORCE(insert_pos != block.end(), + "Operator %s not found in block.", + kv.first->name()); + + kv.first->set_attribute( + "op_name", + pir::StrAttribute::get(pir::IrContext::Instance(), kv.second)); + kv.first->set_attribute( + "is_inplace", + pir::BoolAttribute::get(pir::IrContext::Instance(), true)); + num_rewrites_++; + } + } } PrintStatistics(num_rewrites_); } - - bool CanApplyOn(pir::Operation* op) const override { - return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; - } }; namespace pir { diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc index d0d44b1a720af9..7f9f74cb6710ac 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -14,12 +14,47 @@ #include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include + #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/op_operand.h" #include "paddle/pir/core/parameter.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" + +namespace { + +void GetUsedExternalValueImpl( + std::unordered_set& defined_values, // NOLINT + std::vector& used_values, // NOLINT + const pir::Operation& op) { + for (size_t index = 0; index < op.num_operands(); ++index) { + pir::Value value = op.operand_source(index); + if (defined_values.find(value) == defined_values.end()) { + used_values.push_back(value); + defined_values.insert(value); + } + } + for (auto& region : op) { + for (auto& block : region) { + for (auto value : block.args()) { + defined_values.insert(value); + } + } + for (auto& block : region) { + for (auto& inner_op : block) { + GetUsedExternalValueImpl(defined_values, used_values, inner_op); + } + } + } + for (size_t index = 0; index < op.num_results(); ++index) { + defined_values.insert(op.result(index)); + } +} + +} // namespace namespace pir { @@ -58,7 +93,7 @@ pir::Type GetDataTypeFromValue(pir::Value value) { return value.type().dyn_cast().dtype(); } -Operation* GetDefiningOpForInput(Operation* op, uint32_t index) { +Operation* GetDefiningOpForInput(const Operation* op, uint32_t index) { PADDLE_ENFORCE_EQ( index < op->num_operands() && op->operand_source(index), true, @@ -66,8 +101,8 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) { return op->operand_source(index).dyn_cast().owner(); } -std::vector> GetUseOpsForOutput(Operation* op, - uint32_t index) { +std::vector> GetUseOpsForOutput( + const Operation* op, uint32_t index) { PADDLE_ENFORCE_EQ( index < op->num_results(), true, @@ -80,4 +115,21 @@ std::vector> GetUseOpsForOutput(Operation* op, return use_ops; } +std::vector GetUsedExternalValue(const pir::Operation& op) { + std::unordered_set defined_values{nullptr}; + std::vector used_values; + GetUsedExternalValueImpl(defined_values, used_values, op); + return used_values; +} + +std::vector GetUsedExternalValue(const pir::Block& block) { + auto& args = block.args(); + std::unordered_set defined_values(args.begin(), args.end()); + std::vector used_values; + for (auto& op : block) { + GetUsedExternalValueImpl(defined_values, used_values, op); + } + return used_values; +} + } // namespace pir diff --git a/paddle/fluid/pir/transforms/transform_general_functions.h b/paddle/fluid/pir/transforms/transform_general_functions.h index e653f5d4713c1d..3c909accf1b5f7 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.h +++ b/paddle/fluid/pir/transforms/transform_general_functions.h @@ -57,23 +57,43 @@ pir::Type GetDataTypeFromValue(pir::Value value); /** * @brief Get an operation that defines the specific input of the operation. * - * @param Operation* pointer to an operation + * @param const Operation* const pointer to an operation * @param uint32_t index of operand of the operation * * @return Operation* */ -Operation* GetDefiningOpForInput(Operation* op, uint32_t index); +Operation* GetDefiningOpForInput(const Operation* op, uint32_t index); /** * @brief Get operations and the index of designative op operand (op result) that use the specific output of the operation. * - * @param Operation* pointer to an operation + * @param const Operation* cosnt pointer to an operation * @param uint32_t index of result of the operation * @return std::vector> */ -std::vector> GetUseOpsForOutput(Operation* op, - uint32_t index); +std::vector> GetUseOpsForOutput( + const Operation* op, uint32_t index); + +/** +* @brief Get the value of the input and output of the specified op in the +external block. +* +* @param const Operation& const reference to an operation + +* @return std::vector +*/ +std::vector GetUsedExternalValue(const Operation& op); + +/** + * @brief Get the external value of the input and output of all op which in the + specified block. + * + * @param const Block& const reference to an block + + * @return std::vector + */ +std::vector GetUsedExternalValue(const Block& block); } // namespace pir diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index 2979d944e0bbf4..2cf9bcd424ffe6 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -24,6 +24,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" @@ -111,51 +112,6 @@ void BindAssertOp(py::module* m) { "as_operation", &AssertOp::operation, return_value_policy::reference); } -void GetUsedExternalValueImpl( - std::unordered_set& defined_values, // NOLINT - std::vector& used_values, // NOLINT - const Operation& op) { - for (size_t index = 0; index < op.num_operands(); ++index) { - Value value = op.operand_source(index); - if (defined_values.find(value) == defined_values.end()) { - used_values.push_back(value); - defined_values.insert(value); - } - } - for (auto& region : op) { - for (auto& block : region) { - for (auto value : block.args()) { - defined_values.insert(value); - } - } - for (auto& block : region) { - for (auto& inner_op : block) { - GetUsedExternalValueImpl(defined_values, used_values, inner_op); - } - } - } - for (size_t index = 0; index < op.num_results(); ++index) { - defined_values.insert(op.result(index)); - } -} - -std::vector GetUsedExternalValue(const Operation& op) { - std::unordered_set defined_values{nullptr}; - std::vector used_values; - GetUsedExternalValueImpl(defined_values, used_values, op); - return used_values; -} - -std::vector GetUsedExternalValue(const Block& block) { - auto& args = block.args(); - std::unordered_set defined_values(args.begin(), args.end()); - std::vector used_values; - for (auto& op : block) { - GetUsedExternalValueImpl(defined_values, used_values, op); - } - return used_values; -} - Value BuildHasElementsOp(Operation& fwd_op) { // NOLINT PADDLE_ENFORCE(fwd_op.isa(), phi::errors::PreconditionNotMet( @@ -246,9 +202,9 @@ void PyIfOp::UpdateOutput() { void BindControlFlowApi(py::module* m) { m->def("get_used_external_value", - [](const Operation& op) { return GetUsedExternalValue(op); }); + [](const Operation& op) { return pir::GetUsedExternalValue(op); }); m->def("get_used_external_value", - [](const Block& block) { return GetUsedExternalValue(block); }); + [](const Block& block) { return pir::GetUsedExternalValue(block); }); m->def("build_pipe_for_block", BuildPipeForBlock); m->def("cf_has_elements", BuildHasElementsOp); m->def("cf_yield", [](py::list inputs) { From d1344c9427b9199d195b0a5c13532e2c6b47d552 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Wed, 27 Dec 2023 18:43:05 +0800 Subject: [PATCH 005/142] =?UTF-8?q?=E3=80=90pir=E3=80=91delete=20wrong=20o?= =?UTF-8?q?ld=20ir=20while=5Floop=20test=20add=20pir=20test=20(#60328)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * optimize backward * modfiy while_loop * delete print * modify append_full_like use copy value * clear * clear --- python/paddle/autograd/ir_backward.py | 17 ++--- test/legacy_test/test_while_loop_op.py | 95 +++++++++++++------------- 2 files changed, 51 insertions(+), 61 deletions(-) diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index a8ac124e6e2b15..eed96992a1d528 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -574,7 +574,6 @@ def make_input_with_input_stopgradient(op): return inputs, input_grad_stopgradients def update_input_grad_map(op, input_grads, all_inputs): - _, fwd_value_to_block_argument_map = argument_to_value(op) i = 0 for input, grad_semantic in zip(all_inputs, get_grad_semantic_info(op)): if not grad_semantic: @@ -631,8 +630,11 @@ def append_yield( if len(state.value_to_valuegrad[value]) > 1: append_add_n(value) else: + new_value = return_map_value( + value, control_flow_value_to_copyvalue_map + ) value_grad = append_full_like( - 0.0, value, value, state, backward_ops + 0.0, new_value, value, state, backward_ops ) input_grad = state.value_to_valuegrad[value][0][0] @@ -762,16 +764,6 @@ def argument_to_value(while_op): for sub_fwd_block, sub_bwd_block in zip( op.blocks(), grad_op.blocks() ): - # update grad_op structure - if grad_op.name() == "pd_op.while": - ( - _, - sub_bwd_block_argument_to_value_map, - ) = argument_to_value(grad_op) - else: - sub_bwd_block_argument_to_value_map = ( - ValueDict() - ) sub_state = state.copy(sub_fwd_block) sub_backward_ops = [] append_backward_ops( @@ -784,7 +776,6 @@ def argument_to_value(while_op): no_grad_set, sub_backward_ops, sub_state, - sub_bwd_block_argument_to_value_map, ) # update input_grad map update_input_grad_map(op, input_grads, origin_inputs) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index ca874defb6b0d5..42582d092fa6ff 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -254,66 +254,63 @@ def internal_body(j, init, sums): class TestApiWhileLoop_Backward(unittest.TestCase): - # TODO(zhangbo): Support while grad exe for pir - # @test_with_pir_api def test_while_loop_backward(self): - def cond(i, x): - return paddle.less_than(i, eleven) + with paddle.pir_utils.IrGuard(): + + def cond(i, x): + return paddle.less_than(i, eleven) + + def body(i, x): + x = paddle.multiply(x=i, y=i) + i = paddle.increment(i) + return [i, x] + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + i = paddle.static.data(name='i', shape=[1], dtype='float32') + i.stop_gradient = False + i.persistable = True + eleven = paddle.tensor.fill_constant( + shape=[1], dtype='float32', value=11 + ) + one = paddle.tensor.fill_constant( + shape=[1], dtype='float32', value=1 + ) + x = paddle.static.data(name='x', shape=[1], dtype='float32') + x.stop_gradient = False + x.persistable = True - def body(i, x): - x = paddle.multiply(x=i, y=i) - i = paddle.increment(i) - return [i, x] + out = paddle.static.nn.while_loop(cond, body, [i, x]) + mean = paddle.mean(out[1]) + grad_list = append_backward(mean) - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - i = paddle.static.data(name='i', shape=[1], dtype='float32') - i.stop_gradient = False - i.persistable = True - eleven = paddle.tensor.fill_constant( - shape=[1], dtype='float32', value=11 - ) - one = paddle.tensor.fill_constant( - shape=[1], dtype='float32', value=1 + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() ) - x = paddle.static.data(name='x', shape=[1], dtype='float32') - x.stop_gradient = False - x.persistable = True - - out = paddle.static.nn.while_loop(cond, body, [i, x]) - mean = paddle.mean(out[1]) - grad_list = append_backward(mean) + exe = base.Executor(place) - place = ( - base.CUDAPlace(0) - if core.is_compiled_with_cuda() - else base.CPUPlace() - ) - exe = base.Executor(place) + feed_i = np.ones(1).astype('float32') + feed_x = np.ones(1).astype('float32') + data = np.asarray([100]).astype('float32') + i_grad = np.asarray([0]).astype('float32') + x_grad = np.asarray([0]).astype('float32') - feed_i = np.ones(1).astype('float32') - feed_x = np.ones(1).astype('float32') - data = np.asarray([100]).astype('float32') - i_grad = np.asarray([110]).astype('float32') - - if paddle.framework.in_pir_mode(): for p, g in grad_list: - if p == i: + if p.is_same(i): di = g + elif p.is_same(x): + dx = g res = exe.run( main_program, feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean, di], + fetch_list=[mean, di, dx], ) - else: - res = exe.run( - main_program, - feed={'i': feed_i, 'x': feed_x}, - fetch_list=[mean.name, i.grad_name, x.grad_name], - ) - np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) - np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) + np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05) @test_with_pir_api def test_while_loop_backward2(self): @@ -356,6 +353,7 @@ def body(i, x): fetch_list = [out[1]] for p, g in grad_list: fetch_list.append(g) + res = exe.run( main_program, feed={'i': feed_i, 'x': feed_x}, @@ -367,6 +365,7 @@ def body(i, x): feed={'i': feed_i, 'x': feed_x}, fetch_list=[out[1].name, i.grad_name, x.grad_name], ) + np.testing.assert_allclose(np.asarray(res[0]), data, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[1]), i_grad, rtol=1e-05) np.testing.assert_allclose(np.asarray(res[2]), x_grad, rtol=1e-05) From cdeb3a632de460323dc2fec5e872898fafaeb7ca Mon Sep 17 00:00:00 2001 From: lijialin03 <124568209+lijialin03@users.noreply.github.com> Date: Wed, 27 Dec 2023 19:07:03 +0800 Subject: [PATCH 006/142] fix bug of lbfgs test=develop (#60219) * fix bug of lbfgs test=develop * update 1 * update 2 * update 3 test file --- python/paddle/optimizer/lbfgs.py | 12 +++++------- test/legacy_test/test_lbfgs_class.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/paddle/optimizer/lbfgs.py b/python/paddle/optimizer/lbfgs.py index 215473ff3a7406..936b71b232d4d9 100644 --- a/python/paddle/optimizer/lbfgs.py +++ b/python/paddle/optimizer/lbfgs.py @@ -155,12 +155,7 @@ def _strong_wolfe( gtd_new = paddle.dot(grad_new, d) # bracket an interval containing a point satisfying the Wolfe criteria - t_prev, f_prev, g_prev, gtd_prev = ( - paddle.to_tensor(0, dtype=grad.dtype), - loss, - grad, - gtd, - ) + t_prev, f_prev, g_prev, gtd_prev = (0, loss, grad, gtd) done = False ls_iter = 0 while ls_iter < max_ls: @@ -227,7 +222,10 @@ def _strong_wolfe( low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) while not done and ls_iter < max_ls: # line-search bracket is so small - if paddle.abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: + bracket_ls = bracket[1] - bracket[0] + if not isinstance(bracket_ls, paddle.Tensor): + bracket_ls = paddle.to_tensor(bracket_ls, dtype=gtd_new.dtype) + if paddle.abs(bracket_ls) * d_norm < tolerance_change: break # compute new trial value diff --git a/test/legacy_test/test_lbfgs_class.py b/test/legacy_test/test_lbfgs_class.py index 47c0d36b9ecddc..631d21962e398b 100644 --- a/test/legacy_test/test_lbfgs_class.py +++ b/test/legacy_test/test_lbfgs_class.py @@ -498,6 +498,16 @@ def func3(x, alpha, d): paddle.to_tensor([1.0]), max_ls=1, ) + lbfgs._strong_wolfe( + func2, + paddle.to_tensor([1.0]), + -0.001, + paddle.to_tensor([1.0]), + paddle.to_tensor([1.0]), + paddle.to_tensor([1.0]), + paddle.to_tensor([1.0]), + max_ls=1, + ) lbfgs._strong_wolfe( func3, From 9faa23f7e835b24d698d014e29e7765f0fd105a5 Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Wed, 27 Dec 2023 21:18:00 +0800 Subject: [PATCH 007/142] fix the randomness in c_softmax_with_cross_entropy (#60370) --- .../collective/c_softmax_with_cross_entropy_op.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu index f8f43d5c9da48c..88bd57f55016c9 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu @@ -295,10 +295,8 @@ struct CSoftmaxWithCrossEntropyFunctor { sum_exp_logits = ctx.AllocateTmpTensor({N, 1}, dev_ctx); sum_exp_logits.mutable_data(place); - auto eigen_sum_exp_logits = - phi::funcs::EigenMatrix::From(sum_exp_logits); - eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) = - eigen_softmax.sum(along_axis); + phi::SumKernel( + dev_ctx, softmax_2d, {-1}, softmax_2d.dtype(), true, &sum_exp_logits); if (comm_ctx) { comm_ctx->AllReduce(&sum_exp_logits, sum_exp_logits, ncclSum, stream); @@ -333,6 +331,8 @@ struct CSoftmaxWithCrossEntropyFunctor { N); } + auto eigen_sum_exp_logits = + phi::funcs::EigenMatrix::From(sum_exp_logits); eigen_softmax.device(*dev_ctx.eigen_device()) = (eigen_softmax * eigen_sum_exp_logits.inverse().broadcast(one_by_class)); From 04bceca9d67057b3495e4cb75cc15f580bcf711f Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Wed, 27 Dec 2023 21:34:01 +0800 Subject: [PATCH 008/142] [PIR] support mutable loop_vars in while_loop. (#60330) --- .../dialect/operator/ir/control_flow_op.cc | 122 ++++++++++++++++-- .../pir/dialect/operator/ir/control_flow_op.h | 7 +- .../pir/dialect/operator/ir/op_dialect.cc | 3 +- paddle/fluid/pybind/control_flow_api.cc | 85 ++++++++++-- paddle/fluid/pybind/control_flow_api.h | 16 +++ paddle/fluid/pybind/pir.cc | 10 +- paddle/phi/infermeta/unary.cc | 1 + paddle/pir/core/block.cc | 16 ++- paddle/pir/core/block.h | 2 + paddle/pir/core/interface_support.h | 4 +- paddle/pir/core/interface_value.h | 4 +- paddle/pir/core/region.cc | 10 ++ paddle/pir/core/region.h | 7 +- python/paddle/static/nn/control_flow.py | 14 +- test/ir/pir/test_ir_pybind.py | 3 - test/ir/pir/test_while_api.py | 10 +- test/legacy_test/test_while_loop_op.py | 6 +- 17 files changed, 258 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index a898965f1f7025..040fbb28377115 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -287,20 +287,30 @@ std::vector> IfOp::Vjp( void WhileOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT pir::Value cond, - const std::vector &inputs) { + const std::vector &inputs, + bool construct_body) { argument.AddInput(cond); argument.AddInputs(inputs); - auto &body = argument.AddRegion().emplace_back(); std::vector outs_stop_gradient; - for (auto val : inputs) { - argument.AddOutput(val.type()); - auto arg = body.AddArgument(val.type()); - - auto bool_attr = val.attribute(kStopGradientAttrName); - arg.set_attribute(kStopGradientAttrName, - bool_attr ? bool_attr : builder.bool_attr(false)); - outs_stop_gradient.push_back(bool_attr ? bool_attr - : builder.bool_attr(false)); + if (construct_body) { + auto &body = argument.AddRegion().emplace_back(); + for (auto val : inputs) { + argument.AddOutput(val.type()); + auto arg = body.AddArgument(val.type()); + auto bool_attr = val.attribute(kStopGradientAttrName); + outs_stop_gradient.push_back(bool_attr ? bool_attr + : builder.bool_attr(false)); + arg.set_attribute(kStopGradientAttrName, + bool_attr ? bool_attr : builder.bool_attr(false)); + } + } else { + argument.AddRegion(nullptr); + for (auto val : inputs) { + argument.AddOutput(val.type()); + auto bool_attr = val.attribute(kStopGradientAttrName); + outs_stop_gradient.push_back(bool_attr ? bool_attr + : builder.bool_attr(false)); + } } argument.AddAttribute( @@ -343,6 +353,96 @@ void WhileOp::Print(pir::IrPrinter &printer) { os << "\n }"; } +void WhileOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: WhileOp."; + auto input_size = num_operands(); + PADDLE_ENFORCE_GE( + input_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be greater or equal to 1.", input_size)); + + if (auto cond_type = operand_type(0).dyn_cast()) { + PADDLE_ENFORCE_EQ( + cond_type.dtype().isa(), + true, + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input, it should be a " + "bool DenseTensorType.")); + } else if (auto cond_type = + operand_type(0).dyn_cast()) { + PADDLE_ENFORCE_EQ( + cond_type.dtype().isa(), + true, + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input, it should be a " + "bool DenseTensorType.")); + } else { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "Currently, the while op cond input only support bool dense_tensor " + "and bool allocated_dense_tensor.")); + } + PADDLE_ENFORCE_EQ((*this)->num_regions(), + 1u, + phi::errors::PreconditionNotMet( + "The size %d of regions must be equal to 1.", + (*this)->num_regions())); + auto output_size = num_results(); + PADDLE_ENFORCE_EQ(output_size + 1, + input_size, + phi::errors::PreconditionNotMet( + "The result size (%d) not equal to input size(%d) + 1.", + num_results(), + input_size)); + for (size_t index = 0; index < output_size; ++index) { + PADDLE_ENFORCE_EQ( + operand_type(index + 1), + result_type(index), + phi::errors::PreconditionNotMet( + "The (%d) result and operand type is not equal.", index)); + } +} + +void WhileOp::VerifyRegion() { + VLOG(4) << "Start verifying sub regions for: WhileOp."; + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of body_region must be 1.", + (*this)->region(0).size())); + auto &body_block = body(); + auto output_size = num_results(); + PADDLE_ENFORCE_EQ( + body_block.args_size(), + output_size, + phi::errors::PreconditionNotMet( + "The result size (%d) not equal to block args size(%d) + 1.", + output_size, + body_block.args_size())); + + PADDLE_ENFORCE_EQ( + body_block.empty(), + false, + phi::errors::PreconditionNotMet("The body block is empty.")); + + auto yield_op = body_block.back().dyn_cast(); + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + yield_op && yield_op.num_operands() == input_size, + true, + phi::errors::PreconditionNotMet( + "The body block yield size not equal to operands size.")); + // Todo: fix other bugs and make the following code work. + // for (size_t index = 0; index < input_size; ++index) { + // PADDLE_ENFORCE_EQ( + // operand_type(index), + // yield_op.operand_type(index), + // phi::errors::PreconditionNotMet( + // "The (%d) operand and block yield type is not equal.", index)); + // } + VLOG(4) << "Successful end verifying sub regions for: WhileOp."; +} + std::vector> WhileOp::Vjp( pir::Operation *op, const std::vector> &inputs, diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index baffcadc127184..3c86d56d116165 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -77,13 +77,14 @@ class WhileOp : public pir::Op { static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT pir::Value cond, - const std::vector &inputs); + const std::vector &inputs, + bool construct_body = true); TEST_API pir::Block &body(); pir::Value cond(); const pir::Block::ArgListType &block_args() { return body().args(); } void Print(pir::IrPrinter &printer); // NOLINT - void VerifySig() {} - void VerifyRegion() {} + void VerifySig(); + void VerifyRegion(); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 8cd6375dbe7b64..7b5959a542e7af 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -35,8 +35,7 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx) ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name()); info.AttachInterface(std::move( - pir::InterfaceValue:: - Get())); + pir::InterfaceValue::Get())); } void OperatorDialect::initialize() { diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index 2cf9bcd424ffe6..42beed478d8219 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -40,6 +40,8 @@ using paddle::dialect::AssertOp; using paddle::dialect::HasElementsOp; using paddle::dialect::IfOp; using paddle::dialect::WhileOp; +using paddle::pybind::PyIfOp; +using paddle::pybind::PyWhileOp; using pir::Block; using pir::Builder; using pir::Operation; @@ -51,8 +53,6 @@ using pir::Type; using pir::Value; using pir::YieldOp; using pybind11::return_value_policy; - -using paddle::pybind::PyIfOp; namespace { void BindIfOp(py::module* m) { @@ -79,22 +79,24 @@ void BindIfOp(py::module* m) { } void BindWhileOp(py::module* m) { - m->def("build_while_op", [](Value cond, py::list loop_vars) { + m->def("build_while_op", [](Value cond, py::list loop_vars) -> PyWhileOp { std::vector loop_values; for (auto var : loop_vars) { loop_values.push_back(var.cast()); } - return ApiBuilder::Instance().GetBuilder()->Build(cond, - loop_values); + return PyWhileOp( + ApiBuilder::Instance().GetBuilder()->Build(cond, loop_values)); }); - py::class_ while_op(*m, "WhileOp", R"DOC( + py::class_ while_op(*m, "WhileOp", R"DOC( WhileOp in python api. )DOC"); - while_op.def("body", &WhileOp::body, return_value_policy::reference) - .def("as_operation", &WhileOp::operation, return_value_policy::reference) + while_op.def("body", &PyWhileOp::body, return_value_policy::reference) + .def( + "as_operation", &PyWhileOp::operation, return_value_policy::reference) .def("block_arguments", &WhileOp::block_args, - return_value_policy::reference); + return_value_policy::reference) + .def("optimize_update", &PyWhileOp::OptimizeUpdate); } void BindAssertOp(py::module* m) { @@ -183,7 +185,7 @@ PyIfOp::PyIfOp(IfOp if_op) : IfOp(if_op) { void PyIfOp::UpdateOutput() { PADDLE_ENFORCE_NOT_NULL( - *this, + operation_, paddle::platform::errors::InvalidArgument( "The if_op in PyIfOp used to update output can't be nullptr")); auto block = parent(); @@ -197,7 +199,68 @@ void PyIfOp::UpdateOutput() { cond(), true_region().TakeBack(), false_region().TakeBack()); block->Assign(iter, new_if_op); IfOp::operator=(new_if_op); - VerifyRegion(); + operation_->Verify(); +} + +PyWhileOp::PyWhileOp(WhileOp while_op) : WhileOp(while_op) { + PADDLE_ENFORCE_NOT_NULL( + operation_, + paddle::platform::errors::InvalidArgument( + "The while_op used to construct PyWhileOp can't be nullptr")); +} + +std::vector PyWhileOp::OptimizeUpdate() { + PADDLE_ENFORCE_NOT_NULL(operation_, + paddle::platform::errors::InvalidArgument( + "The while_op in PyWhileOp used to remove unused " + "loop vars can't be nullptr")); + auto parent_block = parent(); + PADDLE_ENFORCE_NOT_NULL( + parent_block, + paddle::platform::errors::InvalidArgument( + "The parent block of while_op which used to remove " + "unused loop vars can't be nullptr")); + + operation_->Verify(); + auto& body_block = body(); + auto yield_op = body_block.back().dyn_cast(); + auto operand_num = operation_->num_operands(); + bool no_change = true; + std::vector index_vec; + std::vector res, new_input, new_yield_val{yield_op.operand_source(0)}; + for (uint32_t i = 0; i < num_results(); ++i) { + res.push_back(result(i)); + } + for (size_t operand_index = 1u, arg_index = 0u; operand_index < operand_num; + ++operand_index) { + if (yield_op.operand_source(operand_index) == body_block.arg(arg_index)) { + body_block.arg(arg_index).ReplaceAllUsesWith( + operand_source(operand_index)); + body_block.EraseArgument(arg_index); + no_change = false; + res[operand_index - 1u] = operand_source(operand_index); + } else { + new_input.push_back(operand_source(operand_index)); + index_vec.push_back(operand_index - 1u); + new_yield_val.push_back(yield_op.operand_source(operand_index)); + ++arg_index; + } + } + if (no_change) return res; + Block::Iterator iter = **this; + Builder builder(ir_context(), false); + auto new_while_op = builder.Build(cond(), new_input, false); + new_while_op->region(0).swap(std::move(operation_->region(0))); + parent_block->Assign(iter, new_while_op); + WhileOp::operator=(new_while_op); + body_block.pop_back(); + builder.SetInsertionPointToBlockEnd(&body_block); + builder.Build(new_yield_val); + operation_->Verify(); + for (size_t result_index = 0; result_index < num_results(); ++result_index) { + res[index_vec[result_index]] = result(result_index); + } + return res; } void BindControlFlowApi(py::module* m) { diff --git a/paddle/fluid/pybind/control_flow_api.h b/paddle/fluid/pybind/control_flow_api.h index 18905bdc096787..020904a6d999dc 100644 --- a/paddle/fluid/pybind/control_flow_api.h +++ b/paddle/fluid/pybind/control_flow_api.h @@ -25,6 +25,22 @@ class PyIfOp : public dialect::IfOp { void UpdateOutput(); }; +class PyWhileOp : public dialect::WhileOp { + public: + explicit PyWhileOp(dialect::WhileOp while_op); + + /// + /// \brief Construct a new while_op to replace the original while_op. The + /// input, output, and parameters of the new while_op no longer contain the + /// variables that have not been modified in the loop. The size of the return + /// value is equal to the output size of the original while_op, where the + /// value of the read-only loop variable is the corresponding operand of the + /// original while_op, and the value of the non-read-only loop variable is the + /// corresponding output of the new while_op, + /// + std::vector OptimizeUpdate(); +}; + void BindControlFlowApi(pybind11::module *m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 330f5650caf1a9..7e1d46b3364c8d 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -527,14 +527,8 @@ void BindOperation(py::module *m) { }) .def("as_if_op", [](Operation &self) { return PyIfOp(self.dyn_cast()); }) - .def("as_while_op", [](Operation &self) -> WhileOp { - auto while_op = self.dyn_cast(); - if (!while_op) { - PADDLE_THROW(phi::errors::InvalidArgument( - "Can't cast non-while type Operation to WhileOp.")); - } - return while_op; - }); + .def("as_while_op", + [](Operation &self) { return PyWhileOp(self.dyn_cast()); }); py::class_ block_container( *m, "Operation_BlockContainer", R"DOC( The Operation_BlockContainer only use to walk all blocks in the operation. diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d221c139689105..90987398057fe9 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1859,6 +1859,7 @@ void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out) { product(x.dims()))); out->set_dims(x.dims()); out->share_lod(x); + out->set_layout(x.layout()); out->set_dtype(x.dtype()); } diff --git a/paddle/pir/core/block.cc b/paddle/pir/core/block.cc index 73902960c95ab7..49389454545d10 100644 --- a/paddle/pir/core/block.cc +++ b/paddle/pir/core/block.cc @@ -32,6 +32,12 @@ void Block::push_back(Operation *op) { insert(ops_.end(), op); } void Block::push_front(Operation *op) { insert(ops_.begin(), op); } +void Block::pop_back() { + IR_ENFORCE(!ops_.empty(), "can't pop back from empty block."); + ops_.back()->Destroy(); + ops_.pop_back(); +} + Operation *Block::GetParentOp() const { return parent_ ? parent_->GetParent() : nullptr; } @@ -50,8 +56,7 @@ Block::Iterator Block::erase(ConstIterator position) { void Block::clear() { while (!empty()) { - ops_.back()->Destroy(); - ops_.pop_back(); + pop_back(); } } @@ -103,6 +108,13 @@ Value Block::AddArgument(Type type) { return argument; } +void Block::EraseArgument(uint32_t index) { + auto argument = arg(index); + IR_ENFORCE(argument.use_empty(), + "Erase a block argument that is still in use."); + argument.dyn_cast().Destroy(); + arguments_.erase(arguments_.begin() + index); +} bool Block::TopoOrderCheck(const OpListType &op_list) { std::unordered_set visited_values; for (Operation *op : op_list) { diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h index a912676f7fb684..373f97e12c51ef 100644 --- a/paddle/pir/core/block.h +++ b/paddle/pir/core/block.h @@ -69,6 +69,7 @@ class IR_API Block { void push_back(Operation *op); void push_front(Operation *op); + void pop_back(); Iterator insert(ConstIterator iterator, Operation *op); Iterator erase(ConstIterator position); void clear(); @@ -111,6 +112,7 @@ class IR_API Block { Type arg_type(uint32_t index) const { return arguments_[index].type(); } void ClearArguments(); Value AddArgument(Type type); + void EraseArgument(uint32_t index); template void AddArguments(TypeIter first, TypeIter last); template diff --git a/paddle/pir/core/interface_support.h b/paddle/pir/core/interface_support.h index f8fc83efa31720..60211a9437d7bb 100644 --- a/paddle/pir/core/interface_support.h +++ b/paddle/pir/core/interface_support.h @@ -39,8 +39,8 @@ class ConstructInterfacesOrTraits { /// Placement new interface. template static void ConstrctInterface(InterfaceSet &interface_set) { // NOLINT - InterfaceValue val = InterfaceValue:: - Get>(); + InterfaceValue val = + InterfaceValue::Get>(); auto suceess = interface_set.insert(std::move(val)).second; IR_ENFORCE(suceess, "Interface: id[%u] is already registered. inset failed", diff --git a/paddle/pir/core/interface_value.h b/paddle/pir/core/interface_value.h index 3115dc47a365e1..4c28e35c72ca22 100644 --- a/paddle/pir/core/interface_value.h +++ b/paddle/pir/core/interface_value.h @@ -22,7 +22,7 @@ namespace pir { class IR_API InterfaceValue { public: - template + template static InterfaceValue Get(); TypeId type_id() const { return type_id_; } void *model() const { return model_; } @@ -52,7 +52,7 @@ class IR_API InterfaceValue { void *model_{nullptr}; }; -template +template InterfaceValue InterfaceValue::Get() { InterfaceValue val; val.type_id_ = TypeId::get(); diff --git a/paddle/pir/core/region.cc b/paddle/pir/core/region.cc index 66e2e9d407f755..21a09198f1d791 100644 --- a/paddle/pir/core/region.cc +++ b/paddle/pir/core/region.cc @@ -70,6 +70,16 @@ void Region::clear() { } } +void Region::swap(Region &&other) { + blocks_.swap(other.blocks_); + for (auto iter = begin(); iter != end(); ++iter) { + iter->SetParent(this, iter); + } + for (auto iter = other.begin(); iter != other.end(); ++iter) { + iter->SetParent(&other, iter); + } +} + template void Region::Walk(FuncT &&callback) { for (auto &block : *this) { diff --git a/paddle/pir/core/region.h b/paddle/pir/core/region.h index 9a4675990c8156..c8d4daadaa74ca 100644 --- a/paddle/pir/core/region.h +++ b/paddle/pir/core/region.h @@ -55,7 +55,6 @@ class IR_API Region { Block &front() { return *blocks_.front(); } Block &back() { return *blocks_.back(); } - const Block &front() const { return *blocks_.front(); } const Block &back() const { return *blocks_.back(); } @@ -65,6 +64,7 @@ class IR_API Region { Iterator insert(ConstIterator position, Block *block); Iterator erase(ConstIterator position); void clear(); + void swap(Region &&other); /// Operation Walkers, walk the operations in this region. The callback method /// is called for each nested region, block or operation, @@ -77,7 +77,6 @@ class IR_API Region { void TakeBody(Region &&other); Operation *GetParent() const { return parent_; } - void set_parent(Operation *parent) { parent_ = parent; } // return the program which contains this region. // if region is not in a program, return nullptr. Program *parent_program() const; @@ -85,7 +84,7 @@ class IR_API Region { IrContext *ir_context() const; private: - Operation *parent_{nullptr}; // not owned - std::list blocks_; // owned + Operation *const parent_{nullptr}; // not owned + std::list blocks_; // owned }; } // namespace pir diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 5ba3a14469d8ec..3d2f9858a1feb5 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -687,21 +687,23 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): if in_pir_mode(): while_op = build_while_op(pre_cond, flatten(loop_vars)) with while_op.body() as cur_block: - args = cur_block.args() - next_var = body(*args) + args = pack_sequence_as(loop_vars, cur_block.args()) + next_vars = body(*args) try: assert_same_structure( - flatten(next_var), flatten(loop_vars), check_types=False + flatten(next_vars), flatten(loop_vars), check_types=False ) except ValueError as e: raise ValueError( "body in while_loop should return the same arity " f"(length and structure) as loop_vars: {e}" ) - next_cond = cond(*next_var) + if not isinstance(next_vars, (list, tuple)): + next_vars = [next_vars] + next_cond = cond(*next_vars) next_cond.stop_gradient = True - cf_yield([next_cond, *next_var]) - return while_op.as_operation().results() + cf_yield([next_cond, *flatten(next_vars)]) + return pack_sequence_as(loop_vars, while_op.optimize_update()) if in_dygraph_mode(): now_cond = pre_cond.item() diff --git a/test/ir/pir/test_ir_pybind.py b/test/ir/pir/test_ir_pybind.py index fda8236020b4df..9ae4a3ebbf633e 100644 --- a/test/ir/pir/test_ir_pybind.py +++ b/test/ir/pir/test_ir_pybind.py @@ -42,7 +42,6 @@ def get_ir_program(): class TestPybind(unittest.TestCase): def test_program(self): pir_program = get_ir_program() - print(pir_program) block = pir_program.global_block() program = block.program @@ -152,7 +151,6 @@ def test_type(self): pir_program = get_ir_program() matmul_op = pir_program.global_block().ops[1] add_op = pir_program.global_block().ops[2] - print(matmul_op.result(0).type()) self.assertEqual( matmul_op.result(0).type() == add_op.result(0).type(), True ) @@ -184,7 +182,6 @@ def test_attr(self): ) pir_program = pir.translate_to_pir(main_program.desc) - print(pir_program) conv_attr = pir_program.global_block().ops[3].attrs() full_attr = pir_program.global_block().ops[8].attrs() self.assertEqual(conv_attr["stop_gradient"], [False]) diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index cc07cdbb58ad66..1a5ee3186d692a 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -57,7 +57,7 @@ def test_while_base(self): out = last_op.results() self.assertEqual(out[0].stop_gradient, False) self.assertEqual(last_op.name(), "pd_op.while") - self.assertEqual(len(out), 2) + self.assertEqual(len(out), 1) def test_get_used_external_value(self): main_program = paddle.static.Program() @@ -177,20 +177,20 @@ def test_backward(self): ) self.assertEqual( main_program.global_block() - .ops[-1] + .ops[-3] .as_while_op() .body() - .ops[-2] + .ops[-4] .name(), "cf.has_elements", ) self.assertEqual( main_program.global_block() - .ops[-1] + .ops[-3] .as_while_op() .body() - .ops[-3] + .ops[-5] .name(), "pd_op.add_grad", ) diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 42582d092fa6ff..83fecc6b5ad7f5 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -22,7 +22,6 @@ from paddle import base from paddle.base import core from paddle.base.backward import append_backward -from paddle.base.framework import program_guard from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -98,6 +97,7 @@ def body(i, mem): np.testing.assert_allclose(np.asarray(res[1]), data, rtol=1e-05) @compare_legacy_with_pt + @test_with_pir_api def test_var_dict(self): def cond(i, ten, test_dict, test_list, test_list_dict): return paddle.less_than(i, ten) @@ -118,7 +118,7 @@ def body(i, ten, test_dict, test_list, test_list_dict): main_program = paddle.static.Program() startup_program = paddle.static.Program() - with program_guard(main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): i = paddle.zeros(shape=[1], dtype='int64') ten = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=10 @@ -130,7 +130,7 @@ def body(i, ten, test_dict, test_list, test_list_dict): test_dict = {"test_key": test_data} test_list = [ paddle.tensor.fill_constant( - shape=[1, 2], dtype='int64', value=0 + shape=[2, 1], dtype='int64', value=0 ) ] test_list_dict = [ From de1fe4ba7b2ded5773b0a62aba09bd8b1a297ef2 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Thu, 28 Dec 2023 07:22:06 +0800 Subject: [PATCH 009/142] [Prim][PIR] decomp support Inference (#60141) * inference support decomp * polish code * add decomp base define * add decomp base define2 * change decomp infer * fix symbol overload * fix test case * debug --- .../fluid/inference/api/analysis_predictor.cc | 8 + .../tensor_operants_gen.py | 23 +-- paddle/fluid/primitive/base/decomp_trans.cc | 28 ++-- paddle/fluid/primitive/base/decomp_trans.h | 21 ++- paddle/fluid/pybind/pybind.cc | 3 +- .../test_decomp_inference_predictor_run.py | 155 ++++++++++++++++++ 6 files changed, 210 insertions(+), 28 deletions(-) create mode 100644 test/ir/inference/test_decomp_inference_predictor_run.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index c70ef74e94baad..4af55a7c6c9337 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -56,6 +56,8 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/prim/utils/utils.h" +#include "paddle/fluid/primitive/base/decomp_trans.h" #include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/backend.h" @@ -786,6 +788,12 @@ bool AnalysisPredictor::PrepareExecutor() { pir_program_ = std::move( paddle::TranslateLegacyProgramToProgram(*inference_program_)); + if (paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) { + VLOG(4) << "[Prim] Decomp program in predictor begin."; + DecompProgram decomp_object(pir_program_.get()); + decomp_object.decomp_program(); + } + if (config_.use_gpu()) { ::pir::PassManager gpu_pm(::pir::IrContext::Instance(), 2); //----------------------------------------------------------------------------------------------// diff --git a/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py b/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py index 378f57a468cd46..6cf66150752828 100644 --- a/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py +++ b/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py @@ -216,6 +216,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { #include "paddle/fluid/primitive/type/lazy_tensor.h" PHI_DECLARE_bool(enable_pir_api); +PHI_DECLARE_bool(enable_pir_in_executor); """ @@ -228,7 +229,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { using LazyTensor = paddle::primitive::LazyTensor; Tensor StaticTensorOperants::add(const Tensor& x, const Scalar& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::add(x, paddle::primitive::backend::full(x.shape(), y, x.dtype(), x.place())); } else { return paddle::prim::add(x, paddle::prim::full(x.shape(), y, x.dtype(), x.place())); @@ -236,7 +237,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::subtract(x, paddle::primitive::backend::full(x.shape(), y, x.dtype(), x.place())); } else { return paddle::prim::subtract(x, paddle::prim::full(x.shape(), y, x.dtype(), x.place())); @@ -244,7 +245,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::scale(x, y, 0.0f, true); } else { return paddle::prim::scale(x, y, 0.0f, true); @@ -252,7 +253,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::divide(x, paddle::primitive::backend::full(x.shape(), y, x.dtype(), x.place())); } else { return paddle::prim::divide(x, paddle::prim::full(x.shape(), y, x.dtype(), x.place())); @@ -260,7 +261,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::add(paddle::primitive::backend::full(y.shape(), x, y.dtype(), y.place()), y); } else { return paddle::prim::add(paddle::prim::full(y.shape(), x, y.dtype(), y.place()), y); @@ -269,7 +270,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::subtract(paddle::primitive::backend::full(y.shape(), x, y.dtype(), y.place()), y); } else { return paddle::prim::subtract(paddle::prim::full(y.shape(), x, y.dtype(), y.place()), y); @@ -277,7 +278,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::scale(y, x, 0.0f, true); } else { return paddle::prim::scale(y, x, 0.0f, true); @@ -285,7 +286,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::divide(paddle::primitive::backend::full(y.shape(), x, y.dtype(), y.place()), y); } else { return paddle::prim::divide(paddle::prim::full(y.shape(), x, y.dtype(), y.place()), y); @@ -293,7 +294,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::elementwise_pow(x, y); } else { return paddle::prim::elementwise_pow(x, y); @@ -301,7 +302,7 @@ class TEST_API StaticTensorOperants : public TensorOperantsBase { } Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) { - if (FLAGS_enable_pir_api) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { return paddle::primitive::backend::elementwise_pow(x, paddle::primitive::backend::full(x.shape(), y, x.dtype(), x.place())); } else { return paddle::prim::elementwise_pow(x, paddle::prim::full(x.shape(), y, x.dtype(), x.place())); @@ -394,7 +395,7 @@ def gene_static_tensor_func_call(self): ) static_func_parameters = self.get_func_args() - static_tensor_func_call = f"""if (FLAGS_enable_pir_api) {{ + static_tensor_func_call = f"""if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) {{ return {backend_static_func_name}({static_func_parameters}); }} else {{ return {prim_static_func_name}({static_func_parameters}); diff --git a/paddle/fluid/primitive/base/decomp_trans.cc b/paddle/fluid/primitive/base/decomp_trans.cc index 6dde6c8b940027..df0111d56f8afc 100644 --- a/paddle/fluid/primitive/base/decomp_trans.cc +++ b/paddle/fluid/primitive/base/decomp_trans.cc @@ -124,8 +124,8 @@ void DecompProgram::check_decomp_outputs( for (size_t i = 0; i < orig_outs.size(); i++) { if (skip_invalid_op_check && paddle::dialect::IsEmptyValue(decomp_outs[i])) { - VLOG(0) << "[Prim] Decomp op skip check of output index " << i - << " of op " << op_name; + VLOG(4) << "[Prim] Decomp op skip check of " << i + << "-index output of op " << op_name; } else { PADDLE_ENFORCE( !paddle::dialect::IsEmptyValue(orig_outs[i]), @@ -238,6 +238,14 @@ std::vector DecompProgram::construct_dst_vars( return tar_vars; } +std::vector DecompProgram::get_dst_vars() { + if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) { + return src_vars_; + } else { + return dst_vars_; + } +} + bool DecompProgram::enable_decomp_by_filter(const std::string& op_name) { bool flag = true; @@ -266,16 +274,7 @@ std::vector> call_decomp_rule(pir::Operation* op) { return decomp_res; } -DecompProgram::DecompProgram(pir::Program* program, - const std::vector& src_vars, - const std::set& blacklist, - const std::set& whitelist) - : program_(program), - src_vars_(src_vars), - blacklist_(blacklist), - whitelist_(whitelist) {} - -std::vector DecompProgram::decomp_program() { +void DecompProgram::decomp_program() { std::ostringstream orig_prog_stream; std::unordered_map orig_vars_dict; for (size_t i = 0; i < src_vars_.size(); i++) { @@ -285,7 +284,7 @@ std::vector DecompProgram::decomp_program() { VLOG(4) << "[Prim] Origin program bofore decomp :\n" << orig_prog_stream.str(); if (!paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) { - return src_vars_; + return; } std::vector tar_vars(src_vars_.size()); pir::Block* block = program_->block(); @@ -338,7 +337,8 @@ std::vector DecompProgram::decomp_program() { std::ostringstream decomp_prog_stream; program_->Print(decomp_prog_stream); VLOG(4) << "[Prim] New program after decomp :\n" << decomp_prog_stream.str(); - return tar_vars; + dst_vars_ = tar_vars; + return; } } // namespace paddle diff --git a/paddle/fluid/primitive/base/decomp_trans.h b/paddle/fluid/primitive/base/decomp_trans.h index 550d8beab80314..4f3a83d326b337 100644 --- a/paddle/fluid/primitive/base/decomp_trans.h +++ b/paddle/fluid/primitive/base/decomp_trans.h @@ -26,12 +26,18 @@ namespace paddle { class DecompProgram { public: + explicit DecompProgram(pir::Program* program) : program_(program) {} + DecompProgram(pir::Program* program, const std::vector& src_vars, const std::set& blacklist, - const std::set& whitelist); + const std::set& whitelist) + : program_(program), + src_vars_(src_vars), + blacklist_(blacklist), + whitelist_(whitelist) {} - std::vector decomp_program(); + void decomp_program(); bool check_decomp_dynamic_shape(pir::Operation* op); void check_decomp_outputs(const std::string& op_name, const std::vector& orig_outs, @@ -46,10 +52,21 @@ class DecompProgram { const std::vector& decomp_outs, std::unordered_map orig_vars_dict); bool enable_decomp_by_filter(const std::string& op_name); + void set_src_vars(const std::vector& src_vars) { + src_vars_ = src_vars; + } + void set_blacklist(const std::set& blacklist) { + blacklist_ = blacklist; + } + void set_whitelist(const std::set& whitelist) { + whitelist_ = whitelist; + } + std::vector get_dst_vars(); private: pir::Program* program_; std::vector src_vars_; + std::vector dst_vars_; std::set blacklist_; std::set whitelist_; }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index efeeb4855205e2..53df4c25034abd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -785,7 +785,8 @@ void BindDecomp(pybind11::module *m) { VLOG(4) << "[Prim] Bind Decomp sinking_decomp begin."; py::list res; DecompProgram decomp_object(program, src_vars, blacklist, whitelist); - auto tar_vars = decomp_object.decomp_program(); + decomp_object.decomp_program(); + std::vector tar_vars = decomp_object.get_dst_vars(); for (size_t i = 0; i < tar_vars.size(); ++i) { if (!tar_vars[i]) { res.append(nullptr); diff --git a/test/ir/inference/test_decomp_inference_predictor_run.py b/test/ir/inference/test_decomp_inference_predictor_run.py new file mode 100644 index 00000000000000..687f28c1bcf159 --- /dev/null +++ b/test/ir/inference/test_decomp_inference_predictor_run.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np + +import paddle +from paddle.inference import Config, create_predictor + + +class TestNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc1 = paddle.nn.Linear(64, 32) + self.fc2 = paddle.nn.Linear(64, 32) + + def forward(self, x1, x2): + y1 = self.fc1(x1) + y2 = self.fc2(x2) + y3 = y1 + y2 + y4 = paddle.nn.functional.layer_norm(y3, y3.shape[1:]) + z = paddle.nn.functional.softmax(y4) + return z + + +class TestPredictorRunWithTensor(unittest.TestCase): + def setUp(self): + self.use_gpu = paddle.is_compiled_with_cuda() + np.random.seed(2023) + self.shape = [4, 8, 16, 64] + self.x = np.random.random(self.shape).astype(np.float32) + self.y = np.random.random(self.shape).astype(np.float32) + self.temp_dir = tempfile.TemporaryDirectory() + net = TestNet() + model = paddle.jit.to_static( + net, + input_spec=[ + paddle.static.InputSpec( + shape=self.shape, dtype='float32', name='input0' + ), + paddle.static.InputSpec( + shape=self.shape, dtype='float32', name='input1' + ), + ], + ) + paddle.jit.save( + model, + os.path.join( + self.temp_dir.name, 'test_predictor_run_model/inference' + ), + ) + + def tearDown(self): + self.temp_dir.cleanup() + + def enable_pir(self, flag: bool): + paddle.set_flags({'FLAGS_enable_pir_in_executor': flag}) + + def init_predictor(self): + config = Config( + os.path.join( + self.temp_dir.name, + 'test_predictor_run_model/inference.pdmodel', + ), + os.path.join( + self.temp_dir.name, + 'test_predictor_run_model/inference.pdiparams', + ), + ) + if self.use_gpu: + config.enable_use_gpu(256, 0) + config.switch_ir_optim(False) + config.enable_new_executor() + predictor = create_predictor(config) + return predictor + + def get_inputs(self): + input0_tensor = paddle.to_tensor(self.x) + input1_tensor = paddle.to_tensor(self.y) + + return [input0_tensor, input1_tensor] + + def get_disorder_output(self, predictor): + [input0_tensor, input1_tensor] = self.get_inputs() + + input_names = predictor.get_input_names() + input0_tensor.name = input_names[0] + input1_tensor.name = input_names[1] + + # disorder + inputs = [input1_tensor, input0_tensor] + outputs = predictor.run(inputs) + + return outputs[0] + + def get_inorder_output(self, predictor): + [input0_tensor, input1_tensor] = self.get_inputs() + + # inorder + inputs = [input0_tensor, input1_tensor] + outputs = predictor.run(inputs) + + return outputs[0] + + def test_output_prim_inorder(self): + self.enable_pir(False) + predictor = self.init_predictor() + output = self.get_inorder_output(predictor) + self.enable_pir(True) + paddle.core._set_prim_all_enabled(True) + pir_predictor = self.init_predictor() + pir_output = self.get_inorder_output(pir_predictor) + paddle.core._set_prim_all_enabled(False) + + np.testing.assert_allclose( + output.numpy().flatten(), + pir_output.numpy().flatten(), + rtol=1e-6, + atol=1e-6, + ) + + def test_output_prim_disorder(self): + self.enable_pir(False) + predictor = self.init_predictor() + output = self.get_disorder_output(predictor) + self.enable_pir(True) + paddle.core._set_prim_all_enabled(True) + pir_predictor = self.init_predictor() + pir_output = self.get_disorder_output(pir_predictor) + paddle.core._set_prim_all_enabled(False) + + np.testing.assert_allclose( + output.numpy().flatten(), + pir_output.numpy().flatten(), + rtol=1e-6, + atol=1e-6, + ) + + +if __name__ == '__main__': + unittest.main() From f1b736daa9474efb696620b4b639f10a3eedd6a6 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 28 Dec 2023 08:11:55 +0800 Subject: [PATCH 010/142] [auto parallel] add recompute to pp ut (#60406) --- .../hybrid_strategy/test_semi_auto_parallel_llama_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py index 36b6c1d5d0e978..3ace2754c7123c 100644 --- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py @@ -74,9 +74,8 @@ def setUp(self): "backend": ["gpu"], "use_sp": ["true", "false"], "use_param_group": ["false", "true"], - # TODO(Yuang Liu): add recompute ut to pp after fixing pp probs - # "recompute": ["true", "false"], - # "recompute_granularity": ["full", "full_attn", "core_attn"], + "recompute": ["true", "false"], + "recompute_granularity": ["full", "full_attn", "core_attn"], } def test_simple_net_hybrid_strategy(self): @@ -103,6 +102,8 @@ def setUp(self): } self._changeable_envs = { "backend": ["gpu"], + "recompute": ["true", "false"], + "recompute_granularity": ["full", "full_attn", "core_attn"], } def test_simple_net_hybrid_strategy_acc(self): From a216f5b067c0a219b72f2972240afdc8bcaab90f Mon Sep 17 00:00:00 2001 From: Liujie0926 <44688141+Liujie0926@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:11:39 +0800 Subject: [PATCH 011/142] fix bug (#60354) --- tools/auto_parallel/ci_auto_parallel.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/auto_parallel/ci_auto_parallel.sh b/tools/auto_parallel/ci_auto_parallel.sh index 09095d7f6122b4..848a5ca1b1bbde 100644 --- a/tools/auto_parallel/ci_auto_parallel.sh +++ b/tools/auto_parallel/ci_auto_parallel.sh @@ -160,8 +160,6 @@ if [[ ${#case_list[*]} -ne 0 ]];then elif [[ ${case} == "gpt-3_dygraph" ]];then bash /workspace/PaddleNLP/scripts/distribute/ci_case_dy.sh llm_gpt_case_list_dygraph $FLAGS_install_deps $FLAGS_download_data print_info $? `ls -lt ${log_path} | grep "llm_gpt" | head -n 1 | awk '{print $9}'` ${case} - export FLAGS_install_deps=1 - export FLAGS_download_data="llm_gpt ""$FLAGS_download_data" let case_num++ elif [[ ${case} == "dygraph_unit_test" ]];then bash /workspace/Paddle/tools/auto_parallel/ci_case_unit.sh dygraph_unit_test From 95b5a6846c73bed0b0746ded0bf617a1f011ada3 Mon Sep 17 00:00:00 2001 From: zhink <33270771+zhink@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:22:16 +0800 Subject: [PATCH 012/142] [paddle inference]support tgt_mask in block_multihead_attention (#60389) [paddle inference]support tgt_mask in block_multihead_attention (#60389) --- paddle/phi/kernels/fusion/gpu/block_attn.h | 38 ++++++++++++++++++- .../test_block_multihead_attention.py | 10 ++++- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h index 73be0901c6f36e..500ffe939870f2 100644 --- a/paddle/phi/kernels/fusion/gpu/block_attn.h +++ b/paddle/phi/kernels/fusion/gpu/block_attn.h @@ -38,6 +38,10 @@ struct Block_AttN_params { // [bsz, 1, 1, time_step(cache_seq_length)+1] const T *attn_mask; + // mask_length is the 3th dimension of attn_mask. + int mask_length; + bool mask_broadcast_num_heads; + // k_cache [max_block_num, num_head, block_size, head_size] // v_cache [max_block_num, num_head, block_size, head_size] T *k_cache; @@ -312,6 +316,14 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( } if (tid == 0) { qk *= params.inv_sqrt_dh; + if (params.attn_mask) { + auto mask_bhi = bhi; + if (params.mask_broadcast_num_heads) { + mask_bhi = bi; + } + T mask = params.attn_mask[mask_bhi * params.mask_length + act_time_step]; + qk += static_cast(mask); + } qk_max = qk; qk_smem[act_time_step] = qk; } @@ -372,7 +384,14 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( } float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); - + if (params.attn_mask) { + auto mask_bhi = bhi; + if (params.mask_broadcast_num_heads) { + mask_bhi = bi; + } + T mask = params.attn_mask[mask_bhi * params.mask_length + ti]; + qk += static_cast(mask); + } if (ti < act_time_step && tid % THREADS_PER_KEY == 0) { qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; @@ -786,8 +805,25 @@ void blha(const phi::GPUContext &dev_ctx, params.max_num_blocks_per_seq = max_num_blocks_per_seq; params.neox_rotary_style = neox_rotary_style; + params.attn_mask = nullptr; + bool mask_broadcast_num_heads = false; if (src_mask_tensor) { + if (src_mask_tensor->dims()[1] == 1) { + // all head share a mask. + mask_broadcast_num_heads = true; + } else if (src_mask_tensor->dims()[1] == num_head) { + mask_broadcast_num_heads = false; + } else { + PADDLE_THROW(errors::InvalidArgument( + "Unknow dimension for attn_mask, the num_head(2nd) " + "dimension is invalid, it should be 1 or num_head(%d), " + "but got %d", + num_head, + src_mask_tensor->dims()[1])); + } params.attn_mask = src_mask_tensor->data(); + params.mask_broadcast_num_heads = mask_broadcast_num_heads; + params.mask_length = src_mask_tensor->dims()[3]; } else { params.attn_mask = nullptr; } diff --git a/test/legacy_test/test_block_multihead_attention.py b/test/legacy_test/test_block_multihead_attention.py index 04919ca3d82402..7f3033044e1c57 100644 --- a/test/legacy_test/test_block_multihead_attention.py +++ b/test/legacy_test/test_block_multihead_attention.py @@ -306,6 +306,12 @@ def setUp(self): ] * self.batch_size, ) + + self.tgt_mask = paddle.randn( + [self.batch_size, self.num_head, 1, self.seq_len + 1], + dtype=self.dtype, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) @@ -462,7 +468,7 @@ def test_all(self): naive_cache_v, None, None, - None, + self.tgt_mask, self.scale, ) .transpose([0, 2, 1, 3]) @@ -492,7 +498,7 @@ def test_all(self): None, # out_smooth None, # rotary_embs None, # attn_mask - None, # tgt_mask + self.tgt_mask, # tgt_mask 1, # seq_len, self.blocksize, False, # use_neox_rotary_style From b81deac6d3898d8ed09f5a639030353a7ce5a0b6 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 28 Dec 2023 10:23:45 +0800 Subject: [PATCH 013/142] [PIR] OneDNN Pir onednn instruction (#60257) * onednn dialect gend --- .gitignore | 2 + .../framework/new_executor/CMakeLists.txt | 10 + .../onednn_legacy_kernel_instruction.cc | 52 +++ .../onednn/onednn_legacy_kernel_instruction.h | 72 ++++ .../onednn_mixed_phi_kernel_instruction.cc | 61 +++ .../onednn_mixed_phi_kernel_instruction.h | 42 ++ .../onednn/onednn_phi_kernel_instruction.cc | 388 ++++++++++++++++++ .../onednn/onednn_phi_kernel_instruction.h | 82 ++++ .../framework/new_executor/pir_interpreter.cc | 19 + .../ir_adaptor/translator/op_translator.cc | 40 +- .../fluid/ir_adaptor/translator/translate.cc | 6 + paddle/fluid/ir_adaptor/translator/utils.cc | 6 + paddle/fluid/pir/dialect/CMakeLists.txt | 56 ++- .../pir/dialect/kernel/ir/kernel_dialect.cc | 103 +++++ .../pir/dialect/kernel/ir/kernel_dialect.h | 22 + .../fluid/pir/dialect/kernel/ir/kernel_op.cc | 127 ++++++ .../fluid/pir/dialect/kernel/ir/kernel_op.h | 43 ++ .../fluid/pir/dialect/op_generator/op_gen.py | 188 ++++++++- .../pir/dialect/op_generator/ops_api_gen.py | 1 + .../op_generator/ops_onednn_extra_parser.py | 86 ++++ .../fluid/pir/dialect/operator/ir/onednn.yaml | 9 + .../dialect/operator/ir/op_onednn_dialect.cc | 168 ++++++++ .../dialect/operator/ir/op_onednn_dialect.h | 44 ++ paddle/fluid/pir/dialect/operator/ir/ops.yaml | 9 + .../dialect/operator/ir/ops_onednn_extra.yaml | 33 ++ .../fluid/pir/dialect/operator/trait/onednn.h | 49 +++ .../fluid/pir/dialect/operator/trait/trait.cc | 10 +- .../operator/utils/op_yaml_info_util.h | 20 +- .../fluid/pir/dialect/operator/utils/utils.cc | 7 + .../fluid/pir/dialect/operator/utils/utils.h | 4 + .../pir/transforms/pd_op_to_kernel_pass.cc | 329 +++++++++++++-- paddle/phi/api/lib/data_transform.h | 5 + paddle/phi/api/yaml/op_compat.yaml | 9 + .../cpu/onednn_to_paddle_layout_kernel.cc | 94 +++++ .../kernels/onednn_to_paddle_layout_kernel.h | 28 ++ test/mkldnn/test_conv2d_mkldnn_op.py | 91 ++++ 36 files changed, 2257 insertions(+), 58 deletions(-) create mode 100644 paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc create mode 100644 paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h create mode 100644 paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc create mode 100644 paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h create mode 100644 paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc create mode 100644 paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h create mode 100644 paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py create mode 100644 paddle/fluid/pir/dialect/operator/ir/onednn.yaml create mode 100644 paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc create mode 100644 paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h create mode 100644 paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml create mode 100644 paddle/fluid/pir/dialect/operator/trait/onednn.h create mode 100644 paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc create mode 100644 paddle/phi/kernels/onednn_to_paddle_layout_kernel.h diff --git a/.gitignore b/.gitignore index 232d8fa08b4bd6..c4046a8d6b6e38 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,8 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.* paddle/fluid/pir/dialect/operator/ir/op_decomp.cc paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc paddle/fluid/pir/dialect/operator/ir/pd_op.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.* +paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.* paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused.* paddle/fluid/pir/dialect/operator/ir/pd_op_fused_bwd.* diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index df01de6d424919..990f82efa8edeb 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -5,6 +5,16 @@ if(NOT (WITH_CINN AND NOT CINN_ONLY)) ${CMAKE_CURRENT_SOURCE_DIR}/instruction/cinn_jit_instruction.cc) endif() +if(NOT WITH_MKLDNN) + list( + REMOVE_ITEM + standalone_executor_srcs + ${CMAKE_CURRENT_SOURCE_DIR}/instruction/onednn/onednn_legacy_kernel_instruction.cc + ${CMAKE_CURRENT_SOURCE_DIR}/instruction/onednn/onednn_phi_kernel_instruction.cc + ${CMAKE_CURRENT_SOURCE_DIR}/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc + ) +endif() + set(standalone_executor_deps pir program_translator diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc new file mode 100644 index 00000000000000..6d1944219a2dc9 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h" + +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +namespace paddle { +namespace framework { + +OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo* value_exec_info) + : InstructionBase(id, place), value_exec_info_(value_exec_info) { + PADDLE_THROW(platform::errors::Unimplemented( + "OneDNNLegacyKernelInstruction not defined now.")); +} + +OneDNNLegacyKernelInstruction::~OneDNNLegacyKernelInstruction() {} + +void OneDNNLegacyKernelInstruction::Run() { + PADDLE_THROW(platform::errors::Unimplemented( + "OneDNNLegacyKernelInstruction not defined now.")); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h new file mode 100644 index 00000000000000..e5c7b0cd151765 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; +class ValueExecutionInfo; + +class OneDNNLegacyKernelInstruction : public InstructionBase { + public: + OneDNNLegacyKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo* value_exec_info); + + ~OneDNNLegacyKernelInstruction(); + phi::Kernel* PhiKernel() const { return phi_kernel_; } + + const phi::InferMetaContext& InferMetaContext() const { + return infer_meta_context_; + } + + paddle::dialect::InferMetaInterface::Concept* InferMetaInterface() const { + return infer_meta_interface_; + } + + void Run() override; + + const std::string& Name() const override { return legacy_op_name_; } + + ::pir::Operation* Operation() const override { return op_; } + + private: + std::string legacy_op_name_; + + paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{ + nullptr}; // not owned + + phi::InferMetaContext infer_meta_context_; + + paddle::framework::ExecutionContext* kernel_context_{nullptr}; + std::shared_ptr runtime_context_; + std::shared_ptr operator_base_; + + phi::Kernel* phi_kernel_{nullptr}; // not owned + + ::pir::Operation* op_{nullptr}; // not owned + + const ValueExecutionInfo* value_exec_info_; // not owned +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc new file mode 100644 index 00000000000000..572c26eb420789 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h" + +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +#include "dnnl.hpp" // NOLINT +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" +#include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/backends/onednn/onednn_helper.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" + +namespace paddle { +namespace framework { + +OneDNNMixedPhiKernelInstruction::OneDNNMixedPhiKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo* value_exec_info) + : OneDNNPhiKernelInstruction(id, place, op, value_exec_info) {} + +void OneDNNMixedPhiKernelInstruction::Run() { + // Step1. Mixed Dynamic Choose Kernel + // todo if (input_tensor.layout() != phi::DataLayout::ONEDNN) + + OneDNNPhiKernelInstruction::Run(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h new file mode 100644 index 00000000000000..d39e5fa9d1fea0 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; +class ValueExecutionInfo; + +using RuntimeAttribute = phi::Attribute; +using PIRAttribute = pir::Attribute; + +class OneDNNMixedPhiKernelInstruction : public OneDNNPhiKernelInstruction { + public: + OneDNNMixedPhiKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo* value_exec_info); + + void Run() override; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc new file mode 100644 index 00000000000000..71385619cb958b --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc @@ -0,0 +1,388 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h" + +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +#include "dnnl.hpp" // NOLINT +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" +#include "paddle/phi/backends/onednn/onednn_context.h" +#include "paddle/phi/backends/onednn/onednn_helper.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" + +namespace paddle { +namespace framework { + +static RuntimeAttribute ConvertPirAttribute2RuntimeAttribute( + PIRAttribute attr, + const std::string& attr_name, + const paddle::dialect::OpYamlInfoParser& op_yaml_info) { + auto& attr_type_name = op_yaml_info.AttrTypeName(attr_name); + if (attr_type_name == "pir::Int32Attribute") { + return attr.dyn_cast().data(); + } else if (attr_type_name == "pir::FloatAttribute") { + return attr.dyn_cast().data(); + } else if (attr_type_name == "pir::BoolAttribute") { + return attr.dyn_cast().data(); + } else if (attr_type_name == "pir::StrAttribute") { + return attr.dyn_cast().AsString(); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr.dyn_cast().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ(array_list[0].isa(), + true, + phi::errors::Unimplemented( + "the 0th elementwise MUST be pir::Int32Attribute")); + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back(array_list[i].dyn_cast().data()); + } + } + return vec_res; + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr.dyn_cast().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + if (array_list[0].isa()) { + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back( + array_list[i].dyn_cast().data()); + } + + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ConvertPirAttribute2RuntimeAttribute not support [%s] ", + attr_type_name)); + } + } + return vec_res; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "ConvertPirAttribute2RuntimeAttribute not support [%s] ", + attr_type_name)); + } +} + +void TensorNameMap(pir::Operation* op, + const ValueExecutionInfo& value_exec_info, + const paddle::dialect::OpYamlInfoParser& op_yaml_info, + std::map>& + inputs_tensor_name_map, // NOLINT + std::map>& + outputs_tensor_name_map) { // NOLINT + const Scope* inner_scope = value_exec_info.GetScope(); + VLOG(6) << "TensorNameMap in scope[" << inner_scope << "]"; + + auto& vec_kernel_fn_tensor_params = op_yaml_info.TensorParams(true); + + auto& name2id = op_yaml_info.InputName2Id(); + + std::string fluid_op_name = op_yaml_info.GetOriginOpName(); + + auto& op_normalizer = paddle::translator::OpNameNormalizer::instance(); + + for (auto& name : vec_kernel_fn_tensor_params) { + PADDLE_ENFORCE_EQ( + name2id.count(name), + true, + phi::errors::NotFound("param [%s] MUST in name2id map", name)); + auto index = name2id.at(name); + pir::Value ptr = op->operand_source(index); + + if (!IsInvalid(ptr)) { + continue; + } + + auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); + auto in_var_name = value_exec_info.GetVarName(ptr); + PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), + phi::errors::PreconditionNotMet( + "can not find var[%s] in scope", in_var_name)); + + auto type = ptr.type(); + if (type.isa() || + type.isa()) { + inputs_tensor_name_map[legacy_arg_name] = {in_var_name}; + } else if (type.isa()) { + auto var = inner_scope->FindVar(in_var_name); + auto var_ref = var->Get(); + std::vector vec_tmp; + vec_tmp.reserve(var_ref.size()); + for (size_t k = 0; k < var_ref.size(); ++k) { + vec_tmp.push_back(value_exec_info.GetVarName(var_ref[k])); + } + inputs_tensor_name_map[legacy_arg_name] = vec_tmp; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support AllocatedDenseTensor, AllocatedSelectedRowsType and " + "pir::vector type")); + } + } + + auto& output_name_list = op_yaml_info.OutputNames(); + for (size_t i = 0; i < output_name_list.size(); ++i) { + auto name = output_name_list[i]; + pir::Value ptr = op->result(i); + auto legacy_arg_name = op_normalizer.GetLegacyArgName(fluid_op_name, name); + + if (!IsInvalid(ptr)) { + continue; + } + + auto out_var_name = value_exec_info.GetVarName(ptr); + + PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(out_var_name), + phi::errors::PreconditionNotMet( + "can not find var[%s] in scope", out_var_name)); + + auto type = ptr.type(); + if (type.isa() || + type.isa()) { + outputs_tensor_name_map[legacy_arg_name] = {out_var_name}; + } else if (type.isa()) { + auto var = inner_scope->FindVar(out_var_name); + auto var_ref = var->Get(); + std::vector vec_tmp; + vec_tmp.reserve(var_ref.size()); + for (size_t k = 0; k < var_ref.size(); ++k) { + vec_tmp.push_back(value_exec_info.GetVarName(var_ref[k])); + } + outputs_tensor_name_map[legacy_arg_name] = vec_tmp; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support AllocatedDenseTensor, AllocatedSelectedRowsType and " + "pir::vector type")); + } + } +} + +OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo* value_exec_info) + : InstructionBase(id, place), value_exec_info_(value_exec_info) { + // Step1: build phi kernel instruction as PhiKernelInstruction + auto op_attributes = op->attributes(); + auto op_name = + op_attributes.at("op_name").dyn_cast().AsString(); + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + op_ = op; + phi_op_name_ = op_name; + VLOG(6) << "construct phi kernel instruction for: " << phi_op_name_; + + SetKernelType(AnalyseOpFuncType(op, place)); + VLOG(6) << "finish process analyse kernel type"; + + infer_meta_interface_ = + op_info.GetInterfaceImpl(); + VLOG(6) << "finish process infer_meta_interface_"; + + auto yaml_interface = + op_info.GetInterfaceImpl(); + PADDLE_ENFORCE_NOT_NULL( + yaml_interface, + phi::errors::PreconditionNotMet( + "can not find OpYamlInfoInterface from [%s]", phi_op_name_)); + paddle::dialect::OpYamlInfoParser yaml_info_parser( + yaml_interface->get_op_info_(), + paddle::dialect::IsOneDNNLegacyOp(op_name)); + VLOG(6) << "finish process yaml_info_parser"; + + if (infer_meta_interface_) { + BuildPhiContext< + phi::InferMetaContext, + phi::MetaTensor, + phi::MetaTensor, + paddle::small_vector, + paddle::small_vector, + false>(op, *value_exec_info_, yaml_info_parser, &infer_meta_context_); + } + VLOG(6) << "finish process infer meta context"; + + auto kernel_name = + op_attributes.at("kernel_name").dyn_cast().AsString(); + auto kernel_key = op_attributes.at("kernel_key") + .dyn_cast() + .data(); + + phi_kernel_ = new phi::Kernel( + phi::KernelFactory::Instance().SelectKernel(kernel_name, kernel_key)); + PADDLE_ENFORCE_EQ( + phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); + VLOG(6) << "finish process select kernel"; + + BuildPhiContext, + paddle::small_vector, + true>( + op, *value_exec_info_, yaml_info_parser, &kernel_context_); + + kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( + phi::TransToPhiPlace(kernel_key.backend()))); + VLOG(6) << "finish process kernel context"; + + SetDeviceContext( + ParseDeviceContext(op, + phi::DeviceContextPool::Instance().Get( + phi::TransToPhiPlace(kernel_key.backend())), + place, + GetExecutionStream(), + GetStreamPriority())); + VLOG(6) << "finish process device context"; + + InitInputsOutputsIds(op, *value_exec_info); + VLOG(6) << "finish process inputs outputs index"; + + auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); + std::unordered_set no_need_buffer_values; + for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { + no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id])); + } + SetNoNeedBuffer(no_need_buffer_values); + VLOG(6) << "finish process no need buffer"; + + // Step2: build layout_transform information + if (op_attributes.count("layout_transform_arg")) { + auto layout_transform_arg = op_attributes.at("layout_transform_arg") + .dyn_cast() + .AsString(); + auto data_layout = op_attributes.at(layout_transform_arg) + .dyn_cast() + .AsString(); + input_layout_ = common::StringToDataLayout(data_layout); + std::vector layout_transform_inputs_attr = + op->attributes() + .at("layout_transform_inputs") + .dyn_cast() + .AsVector(); + std::vector layout_transform_inputs; + for (auto& attr : layout_transform_inputs_attr) { + auto pair = kernel_context_.InputRangeAt(value_exec_info_->GetIdByName( + attr.dyn_cast().AsString())); + for (int i = pair.first; i < pair.second; ++i) { + layout_transform_inputs_.insert(i); + } + } + } + + // Step3: build extra attr information + if (op_attributes.count("extra_args")) { + std::vector extra_args_attr = + op->attributes() + .at("extra_args") + .dyn_cast() + .AsVector(); + std::vector extra_args; + for (auto& attr : extra_args_attr) { + auto attr_name = attr.dyn_cast().AsString(); + extra_attr_[attr_name] = ConvertPirAttribute2RuntimeAttribute( + op_attributes.at(attr_name), attr_name, yaml_info_parser); + } + } + TensorNameMap(op, *value_exec_info_, yaml_info_parser, inputs_, outputs_); +} + +OneDNNPhiKernelInstruction::~OneDNNPhiKernelInstruction() { + if (phi_kernel_ != nullptr) { + delete phi_kernel_; + } +} + +void OneDNNPhiKernelInstruction::Run() { + // Step1. TransLayout + auto inputs = kernel_context_.InputsBetween( + size_t(0), kernel_context_.InputsSize()); + for (size_t i = 0; i < inputs.size(); ++i) { + auto input = inputs[i]; + if (input->layout() != phi::DataLayout::ONEDNN) { + phi::DataLayout from_layout = input->layout(); + + // Handle 'layout_transform' in + // ops_onednn_extra.yaml(GetKernelTypeForVar) + if (layout_transform_inputs_.count(i) && + input_layout_ != phi::DataLayout::kAnyLayout) { + from_layout = input_layout_; + } + + auto transed_tensor = const_cast(input); + + if (from_layout == DataLayout::kNHWC || + from_layout == DataLayout::kNDHWC) { + phi::funcs::MatchShapeToLayout( + transed_tensor, from_layout, phi::DataLayout::ONEDNN); + // We register only NHWC assuming that model is consistent e.g. either + // NHWC or NCHW + phi::OneDNNContext::tls().set_cur_paddle_data_layout(from_layout); + } + + if (from_layout == DataLayout::kAnyLayout) { + from_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + } + + dnnl::memory::desc out_mem_desc = + phi::funcs::make_memory_desc(*input, from_layout); + transed_tensor->set_mem_desc(out_mem_desc); + } + } + + // Step2. Append extra information into ctx + // SetDnnAttrIntoDeviceContext + // SetInputsName SetOutputsName + auto one_dnn_ctx = const_cast( + &kernel_context_.GetDeviceContext()); + for (auto& attr : extra_attr_) { + one_dnn_ctx->SetDnnAttr(attr.first, attr.second); + } + one_dnn_ctx->SetInputsName(inputs_); + one_dnn_ctx->SetOutputsName(outputs_); + + // Step3. InferMeta + if (infer_meta_interface_) { + infer_meta_interface_->infer_meta_(&(infer_meta_context_)); + } + + // Step4. Run kernel + VLOG(6) << "Run op " << phi_op_name_ << " infer meta."; + (*(phi_kernel_))(&(kernel_context_)); + VLOG(6) << "Run op " << phi_op_name_ << " kernel."; + + // Step5. ClearDnnAttr + one_dnn_ctx->ClearDnnAttr(); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h new file mode 100644 index 00000000000000..c15a69728f9c3d --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; +class ValueExecutionInfo; + +using RuntimeAttribute = phi::Attribute; +using PIRAttribute = pir::Attribute; + +class OneDNNPhiKernelInstruction : public InstructionBase { + public: + OneDNNPhiKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo* value_exec_info); + + ~OneDNNPhiKernelInstruction(); + + phi::Kernel* PhiKernel() const { return phi_kernel_; } + + const phi::KernelContext& KernelContext() const { return kernel_context_; } + + const phi::InferMetaContext& InferMetaContext() const { + return infer_meta_context_; + } + + paddle::dialect::InferMetaInterface::Concept* InferMetaInterface() const { + return infer_meta_interface_; + } + + ::pir::Operation* Operation() const override { return op_; } + + void Run() override; + + const std::string& Name() const override { return phi_op_name_; } + + private: + paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{ + nullptr}; // not owned + + phi::InferMetaContext infer_meta_context_; + + phi::KernelContext kernel_context_; + + phi::Kernel* phi_kernel_{nullptr}; // not owned + + std::string phi_op_name_; + + ::pir::Operation* op_{nullptr}; // not owned + + const ValueExecutionInfo* value_exec_info_; // not owned + + std::set layout_transform_inputs_{}; + phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout}; + std::map extra_attr_{}; + std::map> inputs_{}; + std::map> outputs_{}; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 7dbb514513fc24..1cd1117d0ea1d2 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -34,6 +34,9 @@ #include "paddle/phi/core/sparse_csr_tensor.h" #ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_legacy_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_mixed_phi_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.h" #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -728,6 +731,22 @@ void PirInterpreter::BuildInstruction() { } else { CREATE_INSTR(PhiKernelInstruction); } +#ifdef PADDLE_WITH_DNNL + } else if (op.dialect()->name() == "pd_onednn_kernel") { + auto op_name = op.attributes() + .at("op_name") + .dyn_cast<::pir::StrAttribute>() + .AsString(); + VLOG(6) << "process " << op_name; + + if (op.isa()) { + CREATE_INSTR(OneDNNPhiKernelInstruction); + } else if (op.isa()) { + CREATE_INSTR(OneDNNMixedPhiKernelInstruction); + } else { + CREATE_INSTR(OneDNNLegacyKernelInstruction); + } +#endif #ifdef PADDLE_WITH_CINN } else if (op.dialect()->name() == "cinn_runtime") { CREATE_INSTR(CinnJitInstruction); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 76a787cda64bf4..626073d143e3e3 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -44,6 +44,9 @@ #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#endif // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/pir/dialect/CMakeLists.txt. #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -77,7 +80,10 @@ using AttributeHandlerFn = std::function; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT -constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT +#ifdef PADDLE_WITH_DNNL +constexpr char kOneDNNTargetDialectPrefix[] = "pd_onednn_op."; // NOLINT +#endif +constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT static const std::unordered_set SpecialNonInplaceOps = {}; @@ -223,12 +229,36 @@ inline pir::Operation* InsertCreateArrayOp(pir::IrContext* ctx, return create_array_op.operation(); } +inline std::string GetPrefix(pir::IrContext* ctx, const OpDesc& op_desc) { +#ifdef PADDLE_WITH_DNNL + if (op_desc.GetAttrIfExists("use_mkldnn")) { + std::string target_op_name = + kOneDNNTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { + target_op_name += "_"; + } + auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + VLOG(3) << op_desc.Type() + << "'s use_mkldnn == True, but PIR not support OneDNN for this " + "op right now."; + return kTargetDialectPrefix; + } else { + return kOneDNNTargetDialectPrefix; + } + } else { + return kTargetDialectPrefix; + } +#else + return kTargetDialectPrefix; +#endif +} } // namespace pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) { std::string target_op_name = - kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } @@ -321,7 +351,7 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, op_desc.Type(), target_op_name); - target_op_name = kTargetDialectPrefix + target_op_name; + target_op_name = GetPrefix(ctx, op_desc) + target_op_name; if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } @@ -1054,7 +1084,7 @@ struct EmbeddingGradOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { std::string target_op_name = - kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); bool is_sparse = paddle::get(op_desc.GetAttr("is_sparse")); @@ -1307,7 +1337,7 @@ struct AddNOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { std::string target_op_name = - kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); + GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); if (IsInplace(op_desc)) { target_op_name += "_"; } else { diff --git a/paddle/fluid/ir_adaptor/translator/translate.cc b/paddle/fluid/ir_adaptor/translator/translate.cc index 7a7081fe1acbf2..04ddf1d13a5a8a 100644 --- a/paddle/fluid/ir_adaptor/translator/translate.cc +++ b/paddle/fluid/ir_adaptor/translator/translate.cc @@ -22,6 +22,9 @@ #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/program.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#endif namespace paddle { using LegacyProgramDesc = ::paddle::framework::ProgramDesc; @@ -31,6 +34,9 @@ std::unique_ptr TranslateLegacyProgramToProgram( const LegacyProgramDesc& legacy_program) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); +#ifdef PADDLE_WITH_DNNL + ctx->GetOrRegisterDialect(); +#endif auto program = std::make_unique(ctx); translator::ProgramTranslator program_translator(&legacy_program, program.get()); diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc index ebba4428220f70..dbd85292974bf0 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.cc +++ b/paddle/fluid/ir_adaptor/translator/utils.cc @@ -23,6 +23,9 @@ #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/utils.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#endif namespace paddle { namespace dialect { @@ -94,6 +97,9 @@ std::vector CheckUnregisteredOperationInBlock( std::vector CheckUnregisteredOperation( pir::IrContext* ctx, const framework::ProgramDesc& legacy_program) { ctx->GetOrRegisterDialect(); +#ifdef PADDLE_WITH_DNNL + ctx->GetOrRegisterDialect(); +#endif std::vector unregistered_ops; for (size_t block_idx = 0; block_idx < legacy_program.Size(); block_idx++) { diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt index 2c812ccada69af..337841b2274971 100644 --- a/paddle/fluid/pir/dialect/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/CMakeLists.txt @@ -27,6 +27,7 @@ set(pir_op_fwd_src_yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops.yaml) set(pir_op_bwd_src_yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml) + set(pir_update_op_fwd_src_yaml ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml) set(parsed_op_dir @@ -108,6 +109,44 @@ set(generated_files_pd_op "${pir_bwd_op_source_file}" "${pir_update_op_source_file}") +if(WITH_MKLDNN) + set(pir_op_onednn_yaml ${parsed_op_dir}/onednn.parsed.yaml) + + set(pd_onednn_op_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/onednn.yaml) + + set(pd_ops_onednn_extra_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml + ) + + set(op_onednn_info_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op_info.cc) + set(op_onednn_info_file_tmp ${op_onednn_info_file}.tmp) + + set(onednn_op_namespace paddle,onednn,dialect) + set(onednn_dialect_name pd_onednn_op) + set(onednn_op_header_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.h) + set(onednn_op_source_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.cc) + set(onednn_op_header_file_tmp ${onednn_op_header_file}.tmp) + set(onednn_op_source_file_tmp ${onednn_op_source_file}.tmp) + + execute_process( + COMMAND ${PYTHON_EXECUTABLE} ${op_gen_parsed_yaml_file} --op_yaml_path + ${pd_onednn_op_yaml_file} --output_path ${pir_op_onednn_yaml}) + + execute_process( + COMMAND + ${PYTHON_EXECUTABLE} ${op_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --namespaces + ${onednn_op_namespace} --dialect_name ${onednn_dialect_name} + --op_def_h_file ${onednn_op_header_file_tmp} --op_info_file + ${op_onednn_info_file_tmp} --op_def_cc_file ${onednn_op_source_file_tmp} + --onednn_yaml_file ${pir_op_onednn_yaml} --ops_onednn_extra_yaml_file + ${pd_ops_onednn_extra_yaml_file}) + + set(generated_files_onednn_pd_op + "${onednn_op_header_file}" "${onednn_op_source_file}" + "${op_onednn_info_file}") +endif() set(api_gen_yaml_files ${op_fwd_yaml},${op_bwd_yaml},${pir_op_fwd_yaml},${pir_op_bwd_yaml},${pir_update_op_fwd_yaml} ) @@ -159,8 +198,10 @@ execute_process( set(generated_files_ops_api "${ops_api_source_file}") -set(generated_files_pir ${generated_files_pd_op} ${generated_files_pd_api} - ${generated_files_python_c} ${generated_files_ops_api}) +set(generated_files_pir + ${generated_files_pd_op} ${generated_files_onednn_pd_op} + ${generated_files_pd_api} ${generated_files_python_c} + ${generated_files_ops_api}) foreach(generated_file ${generated_files_pir}) if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}") execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different @@ -206,6 +247,10 @@ set(op_dialect_srcs ${pir_update_op_source_file} ${api_source_file}) +if(WITH_MKLDNN) + set(op_dialect_srcs ${op_dialect_srcs} ${onednn_op_source_file}) +endif() + set(op_dialect_deps phi common pir type_info string_helper) cc_library( @@ -222,6 +267,13 @@ set(op_dialect_vjp_srcs ${op_decomp_source_file} ${op_vjp_source_file} ${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/base/decomp_trans.cc) + +if(WITH_MKLDNN) + set(op_dialect_vjp_srcs + ${op_dialect_vjp_srcs} + ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_onednn_dialect.cc) +endif() + set(op_dialect_vjp_deps primitive_vjp_experimental op_dialect) cc_library( diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc index 95e77ff6169c68..ecf04d4411397b 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc @@ -122,7 +122,110 @@ void KernelDialect::PrintOperation(pir::Operation *op, } } +#ifdef PADDLE_WITH_DNNL +OneDNNKernelDialect::OneDNNKernelDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { + initialize(); +} + +void OneDNNKernelDialect::initialize() { + RegisterTypes(); + RegisterOps(); + RegisterAttributes(); +} + +void OneDNNKernelDialect::PrintType(pir::Type type, std::ostream &os) const { + if (type.isa()) { + AllocatedDenseTensorType tensor_type = + type.dyn_cast(); + + os << phi::AllocationTypeStr(tensor_type.place().GetType()) << "_"; + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } else if (type.isa()) { + AllocatedSelectedRowsType tensor_type = + type.dyn_cast(); + + os << phi::AllocationTypeStr(tensor_type.place().GetType()) << "_"; + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } else if (type.isa()) { + AllocatedDenseTensorArrayType tensor_array_type = + type.dyn_cast(); + + os << phi::AllocationTypeStr(tensor_array_type.place().GetType()) << "_"; + os << "tensor_array<"; + tensor_array_type.dtype().Print(os); + os << ">"; + } +} + +void OneDNNKernelDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { + phi::KernelKey kernel = attr.dyn_cast().data(); + + os << ""; +} + +void OneDNNKernelDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + if (op->dyn_cast() || op->dyn_cast()) { + auto &os = printer.os; + printer.PrintOpResult(op); + os << " ="; + if (auto phi_kernel_op = op->dyn_cast()) { + std::string kernel_name = phi_kernel_op.kernel_name(); + if (op->attributes().count("is_inplace") != 0 && + op->attributes() + .at("is_inplace") + .dyn_cast() + .data()) { + kernel_name = kernel_name + "_"; + } + os << " \"" << kernel_name << "(phi_kernel)\""; + } else { + auto legacy_kernel_op = op->dyn_cast(); + std::string kernel_name = legacy_kernel_op.kernel_name(); + if (op->attributes().count("is_inplace") != 0 && + op->attributes() + .at("is_inplace") + .dyn_cast() + .data()) { + kernel_name = kernel_name + "_"; + } + os << " \"" << kernel_name << "(legacy_kernel)\""; + } + printer.PrintOpOperands(op); + printer.PrintAttributeMap(op); + os << " :"; + printer.PrintOperandsType(op); + os << " -> "; + printer.PrintOpReturnType(op); + } else { + printer.PrintGeneralOperation(op); + } +} +#endif + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) +#ifdef PADDLE_WITH_DNNL +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect) +#endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h index d2fbcadaf8cf2a..fbdb53a40b183d 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h @@ -36,7 +36,29 @@ class KernelDialect : public pir::Dialect { void initialize(); }; +#ifdef PADDLE_WITH_DNNL +class OneDNNKernelDialect : public pir::Dialect { + public: + explicit OneDNNKernelDialect(pir::IrContext* context); + + static const char* name() { return "pd_onednn_kernel"; } + + void PrintType(pir::Type type, std::ostream& os) const override; + + void PrintAttribute(pir::Attribute attr, std::ostream& os) const override; + + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT + + private: + void initialize(); +}; +#endif + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) +#ifdef PADDLE_WITH_DNNL +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect) +#endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 8ad46bc8906adb..45f0a848fc174d 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -98,8 +98,135 @@ phi::KernelKey LegacyKernelOp::kernel_key() { return attributes().at("kernel_key").dyn_cast().data(); } +#ifdef PADDLE_WITH_DNNL +const char* OneDNNPhiKernelOp::attributes_name[attributes_num] = { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void OneDNNPhiKernelOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: OneDNNPhiKernelOp."; + + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string OneDNNPhiKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} +std::string OneDNNPhiKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} +phi::KernelKey OneDNNPhiKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} + +const char* OneDNNMixedPhiKernelOp::attributes_name[attributes_num] = + { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void OneDNNMixedPhiKernelOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: " + "OneDNNMixedPhiKernelOp."; + + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string OneDNNMixedPhiKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} +std::string OneDNNMixedPhiKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} +phi::KernelKey OneDNNMixedPhiKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} + +const char* OneDNNLegacyKernelOp::attributes_name[attributes_num] = { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void OneDNNLegacyKernelOp::VerifySig() { + VLOG(4) + << "Verifying inputs, outputs and attributes for: OneDNNLegacyKernelOp."; + + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string OneDNNLegacyKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} +std::string OneDNNLegacyKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} +phi::KernelKey OneDNNLegacyKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} +#endif + } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) +#ifdef PADDLE_WITH_DNNL +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNLegacyKernelOp) +#endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index a96aa5732d5806..df723158702085 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -44,8 +44,51 @@ class LegacyKernelOp : public pir::Op { void VerifySig(); }; +#ifdef PADDLE_WITH_DNNL +class OneDNNPhiKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_onednn_kernel.phi_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; + +class OneDNNMixedPhiKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_onednn_kernel.phi_mixed_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; + +class OneDNNLegacyKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_onednn_kernel.legacy_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; +#endif + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) +#ifdef PADDLE_WITH_DNNL +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNLegacyKernelOp) +#endif diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 7dd754e868f86d..4cb54ada152b82 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -30,6 +30,7 @@ from op_kerneltype_gen import gen_kernel_type_for_var_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str +from ops_onednn_extra_parser import parse_extra_args, parse_layout_transform from parse_kernel_key_gen import gen_parse_kernel_key_str from vjp_interface_black_list import vjp_interface_black_list @@ -67,6 +68,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/trait/onednn.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/core/infermeta_utils.h" @@ -213,6 +215,17 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); }} """ + +OP_INFO_ONEDNN_TEMPLATE = """ +OpInfoTuple {op_name}::GetOpInfo() {{ + std::vector inputs = {{ {inputs} }}; + std::vector attributes = {{ {attributes} }}; + std::vector outputs = {{ {outputs} }}; + paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, "{layout_transform_arg}", {{{layout_transform_inputs}}}, {is_onednn_only}, {dynamic_fallback}); + return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); +}} +""" + CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute}, {with_grad_semantic})""" CONSTRUCT_OUTPUT_INFO_TEMPLATE = """paddle::dialect::OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = """paddle::dialect::OpAttributeInfo("{name}", "{typename}", "{data_type}")""" @@ -420,7 +433,7 @@ def __init__(self, op_yaml_item, op_compat_item): self.non_mutable_attribute_data_type_list, self.non_mutable_attribute_build_arg_type_list, self.non_mutable_attribute_default_value_list, - ) = self.parse_non_nutable_attribute() + ) = self.parse_non_mutable_attribute() # parse infermeta && kernel self.infer_meta_map = self.parse_infer_meta_map() @@ -462,6 +475,18 @@ def __init__(self, op_yaml_item, op_compat_item): # parse interfaces list self.interfaces_list = self.parse_op_interfaces() + # OneDNN info + if "extra_args" in self.op_yaml_item: + self.onednn_extra_args = self.op_yaml_item["extra_args"] + self.onednn_layout_transform = self.op_yaml_item["layout_transform"] + self.is_onednn_only = self.op_yaml_item["is_onednn_only"] + self.dynamic_fallback = self.op_yaml_item["dynamic_fallback"] + else: + self.onednn_extra_args = [] + self.onednn_layout_transform = None + self.is_onednn_only = False + self.dynamic_fallback = False + def parse_op_traits(self): if 'traits' in self.op_yaml_item: return self.op_yaml_item['traits'] @@ -633,7 +658,7 @@ def parse_mutable_attribute(self): sorted_mutable_attribute_type_list, ) - def parse_non_nutable_attribute(self): + def parse_non_mutable_attribute(self): op_non_mutable_attribute_name_list = [] op_non_mutable_attribute_type_list = [] op_non_mutable_attribute_data_type_list = [] @@ -1112,17 +1137,21 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if ( op_info.backward_name and op_info.op_phi_name[0] not in vjp_interface_black_list + and dialect_name != "pd_onednn_op" ): op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"] # if op has custom vjp rule, then append a CustomVjpTrait to it - if op_info.op_phi_name[0] in custom_vjp_op_name_list: + if ( + op_info.op_phi_name[0] in custom_vjp_op_name_list + and dialect_name != "pd_onednn_op" + ): op_traits += ["paddle::dialect::CustomVjpTrait"] # check op inputs and mutable_attributes grad semantics @@ -1143,6 +1172,15 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if op_name[-1] == "_": op_traits += ["paddle::dialect::InplaceTrait"] + if dialect_name == "pd_onednn_op": + op_traits += ["paddle::dialect::OneDNNTrait"] + + if op_info.is_onednn_only: + op_traits += ["paddle::dialect::OneDNNOnlyTrait"] + + if op_info.dynamic_fallback: + op_traits += ["paddle::dialect::OneDNNDynamicFallbackTrait"] + op_traits_str = "" if len(op_traits) > 0: op_traits_str = "," + ",".join(op_traits) @@ -1158,6 +1196,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): if ( op_name in decomp_interface_declare_gen_op_list and kernel_func_name in decomp_interface_declare_gen_op_list + and dialect_name != "pd_onednn_op" ): op_interfaces = op_interfaces + [ "paddle::dialect::DecompInterface" @@ -1221,7 +1260,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_func_with_muta_attr_is_input = "" get_kernel_type_for_var_declare_str = "" - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": get_kernel_type_for_var_declare_str = ( get_kernel_type_for_var_declare_template ) @@ -1556,6 +1595,53 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): origin_op_name=op_info.op_yaml_item['name'], ) + if dialect_name == "pd_onednn_op": + if len(op_info.onednn_extra_args) > 0: + args_name = [] + for arg in op_info.onednn_extra_args: + args_name.append(arg["name"]) + + extra_args = '"' + '", "'.join(args_name) + '"' + else: + extra_args = "" + if op_info.onednn_layout_transform is None: + layout_transform_arg, layout_transform_inputs = ( + "", + "", + ) + else: + ( + layout_transform_arg, + layout_transform_inputs, + ) = op_info.onednn_layout_transform + layout_transform_inputs = ( + '"' + '", "'.join(layout_transform_inputs) + '"' + ) + + op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format( + op_name=op_class_name, + inputs=inputs_info_str, + attributes=attribute_info_str, + outputs=outputs_info_str, + infer_meta_func=infer_meta_func_str, + infer_meta_param=infer_meta_param_str, + kernel_func=kernel_func_str, + kernel_param=kernel_param_str, + kernel_key_dtype=kernel_key_dtype, + kernel_key_backend=kernel_key_backend, + inplace=inplace_str, + view=view_str, + origin_op_name=op_info.op_yaml_item['name'], + extra_args=extra_args, + layout_transform_arg=layout_transform_arg, + layout_transform_inputs=layout_transform_inputs, + is_onednn_only="true" + if op_info.is_onednn_only + else "false", + dynamic_fallback="true" + if op_info.dynamic_fallback + else "false", + ) # generate op verify function str op_verify_str = '' if not op_info.custom_verify: @@ -1600,7 +1686,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): # generate op GetKernelKeyForVar function str op_get_kernel_type_for_var_str = '' - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": op_get_kernel_type_for_var_str = ( gen_kernel_type_for_var_str( op_class_name, @@ -1629,6 +1715,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_info.backward_name and op_info.op_phi_name[0] not in vjp_interface_black_list + and dialect_name != "pd_onednn_op" ): op_vjp_str = gen_op_vjp_str( op_class_name, @@ -1659,7 +1746,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ops_defined_list.append(infer_symbolic_shape_define_str) # NOTE(chenxi67)skip if dialect_name==cinn - if dialect_name == "cinn": + if dialect_name == "cinn" or dialect_name == "pd_onednn_op": pass else: ops_vjp_defined_list.append(op_vjp_str) @@ -1741,6 +1828,8 @@ def OpGenerator( op_info_file, op_def_cc_file, op_vjp_cc_file, + onednn_yaml_file, + ops_onednn_extra_yaml_file, ): # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp if os.path.exists(op_def_h_file): @@ -1754,8 +1843,32 @@ def OpGenerator( # (2) parse yaml files op_compat_parser = OpCompatParser(op_compat_yaml_file) + if dialect_name == "pd_onednn_op": + with open(ops_onednn_extra_yaml_file, "r") as f: + ops_onednn_extra = yaml.safe_load(f) + ops_onednn_extra_map = {} + for op in ops_onednn_extra: + op_name = op['op'] + item = {} + item["is_onednn_only"] = False + item["extra_args"] = parse_extra_args(op_name, op['extra_args']) + if 'layout_transform' in op: + item["layout_transform"] = parse_layout_transform( + op_name, op['layout_transform'] + ) + else: + item["layout_transform"] = None + if 'dynamic_fallback' in op: + item["dynamic_fallback"] = op['dynamic_fallback'] + else: + item["dynamic_fallback"] = False + item["attrs"] = parse_extra_args(op_name, op['extra_args']) + ops_onednn_extra_map[op_name] = item + op_yaml_files.insert(0, onednn_yaml_file) + op_infos = [] all_op_info_items = {} + first_file = True for yaml_file in op_yaml_files: op_yaml_items = [] with open(yaml_file, "r") as f: @@ -1765,7 +1878,7 @@ def OpGenerator( op_info_items = {} for op in op_yaml_items: op_compat_item = None - if dialect_name == "pd_op": + if dialect_name == "pd_op" or dialect_name == "pd_onednn_op": op_compat_item = op_compat_parser.get_compat(op['name']) if ( @@ -1791,11 +1904,26 @@ def OpGenerator( ) = op_compat_parser.parse_support_tensor(op) op_compat_item['scalar'] = scalar_item op_compat_item['int_array'] = int_array_item - - op_info_items[op['name']] = OpInfoParser(op, op_compat_item) - all_op_info_items[op['name']] = OpInfoParser(op, op_compat_item) + if dialect_name == "pd_onednn_op": + if first_file: + first_file = False + op["is_onednn_only"] = True + elif op['name'] in ops_onednn_extra_map: + onednn_item = ops_onednn_extra_map[op['name']] + op["is_onednn_only"] = onednn_item["is_onednn_only"] + op["extra_args"] = onednn_item["extra_args"] + op["layout_transform"] = onednn_item["layout_transform"] + op["dynamic_fallback"] = onednn_item["dynamic_fallback"] + op["attrs"] = op["attrs"] + onednn_item["attrs"] + else: + continue + item = OpInfoParser(op, op_compat_item) + op_info_items[op['name']] = item + all_op_info_items[op['name']] = item op_infos.append(op_info_items) + if dialect_name == "pd_onednn_op": + op_infos = [all_op_info_items] # (3) auto code gen op_list_strs = [] @@ -1867,14 +1995,15 @@ def OpGenerator( else: op_to_multi_kernels_map_str = "" - op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( - op_declare=",".join(op_list_strs).replace("\n", ""), - op_to_multi_kernels_map=op_to_multi_kernels_map_str, - h_file=op_def_h_file[:-4], - ) + if op_info_file is not None: + op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( + op_declare=",".join(op_list_strs).replace("\n", ""), + op_to_multi_kernels_map=op_to_multi_kernels_map_str, + h_file=op_def_h_file[:-4], + ) - with open(op_info_file, 'w') as f: - f.write(op_info_str) + with open(op_info_file, 'w') as f: + f.write(op_info_str) # (6) write to files for xx_op.cc.tmp for id in range(len(op_def_cc_file)): @@ -1883,8 +2012,17 @@ def OpGenerator( source_file_str = NAMESPACE_GARD_TEMPLATE.format( namespace=name, input=source_file_str ) # Add namespaces + + if dialect_name == "pd_onednn_op": + op_def_h_file_tmp = ( + "paddle/fluid/pir/dialect/operator/ir/pd_op.h\"\n#include \"" + + op_def_h_file + ) + else: + op_def_h_file_tmp = op_def_h_file + source_file_str = CC_FILE_TEMPLATE.format( - h_file=op_def_h_file[:-4], + h_file=op_def_h_file_tmp[:-4], input=source_file_str, define_type_id=define_type_id_strs[id], ) @@ -1896,7 +2034,11 @@ def OpGenerator( # and vjp is only avaible for pd dialect. vjp_source_file_str = "\n".join(vjp_source_file_strs) vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str) - if dialect_name != 'cinn' and op_vjp_cc_file: + if ( + dialect_name != 'cinn' + and dialect_name != 'pd_onednn_op' + and op_vjp_cc_file + ): with open(op_vjp_cc_file, 'w') as f: f.write(vjp_source_file_str) @@ -1916,6 +2058,8 @@ def ParseArguments(): parser.add_argument('--op_info_file', type=str) parser.add_argument('--op_def_cc_file', type=str) parser.add_argument('--op_vjp_cc_file', type=str) + parser.add_argument('--onednn_yaml_file', type=str) + parser.add_argument('--ops_onednn_extra_yaml_file', type=str) return parser.parse_args() @@ -1935,6 +2079,8 @@ def ParseArguments(): op_info_file = args.op_info_file op_def_cc_files = args.op_def_cc_file.split(",") op_vjp_cc_file = args.op_vjp_cc_file + onednn_yaml_file = args.onednn_yaml_file + ops_onednn_extra_yaml_file = args.ops_onednn_extra_yaml_file # auto code generate OpGenerator( @@ -1946,4 +2092,6 @@ def ParseArguments(): op_info_file, op_def_cc_files, op_vjp_cc_file, + onednn_yaml_file, + ops_onednn_extra_yaml_file, ) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 9fd6bd4bfbd98c..0a834bc7b0c2cf 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -102,6 +102,7 @@ 'print', 'number_count', 'assign_value', + 'onednn_to_paddle_layout', ] NO_NEED_GEN_STATIC_ONLY_APIS = [ diff --git a/paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py b/paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py new file mode 100644 index 00000000000000..3296fa0d68829d --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Any, Dict, List, Tuple + + +def parse_plain_list(s: str, sep=",") -> List[str]: + if sep == ",": + patten = re.compile(r',(?![^{]*\})') # support "int[] a={1,2}" + items = re.split(patten, s.strip()) + items = [x.strip() for x in items] + return items + else: + return [item.strip() for item in s.strip().split(sep)] + + +def parse_arg(op_name: str, s: str) -> Dict[str, str]: + """parse an argument in following formats: + 1. typename name + 2. typename name = default_value + """ + typename, rest = (item.strip() for item in s.split(" ", 1)) + assert ( + len(typename) > 0 + ), f"The arg typename should not be empty. Please check the args of {op_name} in yaml." + + assert ( + rest.count("=") <= 1 + ), f"There is more than 1 = in an arg in {op_name}" + if rest.count("=") == 1: + name, default_value = (item.strip() for item in rest.split("=", 1)) + assert ( + len(name) > 0 + ), f"The arg name should not be empty. Please check the args of {op_name} in yaml." + assert ( + len(default_value) > 0 + ), f"The default value should not be empty. Please check the args of {op_name} in yaml." + return { + "typename": typename, + "name": name, + "default_value": default_value, + } + else: + name = rest.strip() + assert ( + len(name) > 0 + ), f"The arg name should not be empty. Please check the args of {op_name} in yaml." + return {"typename": typename, "name": name} + + +def parse_extra_args(op_name: str, arguments: str) -> List: + if arguments is None: + return [] + args_str = arguments.strip() + args = parse_plain_list(args_str) + + attrs = [] + + for arg in args: + item = parse_arg(op_name, arg) + typename = item["typename"] + name = item["name"] + attrs.append(item) + return attrs + + +def parse_layout_transform( + op_name: str, layout_transform: Dict[str, Any] +) -> Tuple[str, List]: + if layout_transform is None: + return "", [] + return layout_transform["arg_name"], parse_plain_list( + layout_transform["tensors"] + ) diff --git a/paddle/fluid/pir/dialect/operator/ir/onednn.yaml b/paddle/fluid/pir/dialect/operator/ir/onednn.yaml new file mode 100644 index 00000000000000..d7de4310d5781f --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/onednn.yaml @@ -0,0 +1,9 @@ +- op : quantize + args : (Tensor input, bool is_negative_input=false, float scale=1.0, float shift=0.0, str output_format="NHWC", bool bfloat16=false) + output : Tensor(output) + infer_meta : + func : UnchangedInferMeta + param : [input] + kernel : + func : quantize + data_type : input diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc new file mode 100644 index 00000000000000..0d65389cc4922b --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" +#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/interface_value.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_op.h" + +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#endif + +namespace paddle { +namespace dialect { + +OneDNNOperatorDialect::OneDNNOperatorDialect(pir::IrContext *ctx) + : pir::Dialect(name(), ctx, pir::TypeId::get()) { + initialize(); +} + +void OneDNNOperatorDialect::initialize() { + // NOTE(zhangbo9674): GET_OP_LIST is defined in pd_op.h which is + // generated by op_gen.py, see details in + // paddle/fluid/pir/dialect/CMakeLists.txt. + // NOTE(Ruting)GET_MANUAL_OP_LIST is define in manual_op.h" + // use RegisterOps when list has more than two ops. + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT + >(); +} + +void OneDNNOperatorDialect::PrintType(pir::Type type, std::ostream &os) const { + os << type.dialect().name(); + os << '.'; + if (auto tensor_type = type.dyn_cast()) { + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } else if (auto selected_rows_type = type.dyn_cast()) { + os << "selectedrows<"; + for (auto d : common::vectorize(selected_rows_type.dims())) { + os << d; + os << "x"; + } + selected_rows_type.dtype().Print(os); + os << ">"; + } else if (auto tensor_array_type = type.dyn_cast()) { + os << "tensor_array<"; + tensor_array_type.dtype().Print(os); + os << ">"; + } +} + +void OneDNNOperatorDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { + os << "(" << attr.dialect().name(); + os << '.'; + if (auto int_array_attr = attr.dyn_cast()) { + phi::IntArray data = int_array_attr.data(); + os << "IntArray)" + << "["; + const auto &inner_data = data.GetData(); + pir::PrintInterleave( + inner_data.begin(), + inner_data.end(), + [&os](int64_t i) { os << i; }, + [&os]() { os << ","; }); + os << "]"; + } else if (auto data_type_attr = attr.dyn_cast()) { + os << "DataType)" << data_type_attr.data(); + } else if (auto place_type_attr = attr.dyn_cast()) { + os << "Place)" << place_type_attr.data(); + } else if (auto data_layout_attr = attr.dyn_cast()) { + os << "DataLayout)" << data_layout_attr.data(); + } else { + os << "<#AttrNotImplemented>"; + } +} + +pir::Type OneDNNOperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT + parser.ConsumeAToken("pd_op.tensor"); + parser.ConsumeAToken("<"); + std::vector dim{}; + Token dim_token = parser.PeekToken(); + while (dim_token.token_type_ == DIGIT) { + dim_token = parser.ConsumeToken(); + dim.push_back(atoi(dim_token.val_.c_str())); + std::string peek_token_val = parser.PeekToken().val_; + if (peek_token_val[0] != 'x') { + break; + } + parser.ConsumeToken(); + parser.lexer->Unget(static_cast(peek_token_val.size() - 1)); + if (parser.PeekToken().token_type_ != DIGIT) { + break; + } + } + phi::DDim ddim = common::make_ddim(dim); + pir::Type dtype = parser.ParseType(); + std::vector> lod; + std::vector lodv; + lodv.push_back(0); + lod.push_back(lodv); + parser.ConsumeAToken(">"); + return DenseTensorType::get( + parser.ctx, dtype, ddim, phi::DataLayout::UNDEFINED, lod, 0); +} + +pir::Attribute OneDNNOperatorDialect::ParseAttribute( + pir::IrParser &parser) { // NOLINT + std::string type_name = parser.ConsumeToken().val_; + std::string attribute_name = + type_name.substr(type_name.find('.') + 1, std::string::npos); + parser.ConsumeAToken(")"); + if (attribute_name == "IntArray") { + return IntArrayAttribute::Parse(parser); + } else if (attribute_name == "DataType") { + return DataTypeAttribute::Parse(parser); + } else if (attribute_name == "Place") { + return PlaceAttribute::Parse(parser); + } else if (attribute_name == "DataLayout") { + return DataLayoutAttribute::Parse(parser); + } else { + IR_THROW("No function to parse " + attribute_name + " exists!" + + parser.GetErrorLocationInfo()); + } +} + +void OneDNNOperatorDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + if (auto if_op = op->dyn_cast()) { + if_op.Print(printer); + } else if (auto while_op = op->dyn_cast()) { + while_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} + +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOperatorDialect) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h new file mode 100644 index 00000000000000..ac6483d4d53ecb --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/dialect.h" + +namespace paddle { +namespace dialect { + +class OneDNNOperatorDialect : public pir::Dialect { + public: + explicit OneDNNOperatorDialect(pir::IrContext* context); + + static const char* name() { return "pd_onednn_op"; } + + pir::Type ParseType(pir::IrParser& parser) override; // NOLINT + pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT + + void PrintType(pir::Type type, std::ostream& os) const override; + void PrintAttribute(pir::Attribute type, std::ostream& os) const override; + + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT + + private: + void initialize(); +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOperatorDialect) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 57d7857a2498ce..0d571f8ef868a7 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1464,6 +1464,15 @@ func: number_count data_type: numbers +- op: onednn_to_paddle_layout + args: (Tensor x, int dst_layout) + output: Tensor(out) + infer_meta: + func : UnchangedInferMeta + param : [x] + kernel: + func: onednn_to_paddle_layout + - op: sparse_momentum args: (Tensor param, Tensor grad, Tensor velocity, Tensor index, Tensor learning_rate, Tensor master_param,float mu, Scalar axis=0, bool use_nesterov=false,str regularization_method="", float regularization_coeff=0.0f, bool multi_precision=false, float rescale_grad=1.0f) output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml new file mode 100644 index 00000000000000..58897216793dd6 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml @@ -0,0 +1,33 @@ + +- op : conv2d + extra_args : bool is_test=false + layout_transform : + arg_name: data_format + tensors: input + +- op : conv2d_grad + extra_args : bool is_test=false + layout_transform : + arg_name: data_format + tensors: input, out_grad +# - op : matmul +# extra_args : str mkldnn_data_type="float32" +# layout_transform : +# arg_name: cur_paddle_data_layout +# tensors: x, y + +# - op : pad3d +# extra_args : +# layout_transform : +# arg_name: data_format +# tensors: x +# dynamic_fallback : True + +# - op : batch_norm +# extra_args : bool fuse_with_relu=false +# layout_transform : +# arg_name: data_layout +# tensors: x + +# - op : prelu +# extra_args : bool is_test=false, str mkldnn_data_type="float32" diff --git a/paddle/fluid/pir/dialect/operator/trait/onednn.h b/paddle/fluid/pir/dialect/operator/trait/onednn.h new file mode 100644 index 00000000000000..df810c6707df12 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/trait/onednn.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_DNNL + +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +class OneDNNTrait : public pir::OpTraitBase { + public: + explicit OneDNNTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +class OneDNNOnlyTrait : public pir::OpTraitBase { + public: + explicit OneDNNOnlyTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +class OneDNNDynamicFallbackTrait + : public pir::OpTraitBase { + public: + explicit OneDNNDynamicFallbackTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOnlyTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNDynamicFallbackTrait) + +#endif diff --git a/paddle/fluid/pir/dialect/operator/trait/trait.cc b/paddle/fluid/pir/dialect/operator/trait/trait.cc index 2a5b7575959b9f..9d828570d389aa 100644 --- a/paddle/fluid/pir/dialect/operator/trait/trait.cc +++ b/paddle/fluid/pir/dialect/operator/trait/trait.cc @@ -14,6 +14,14 @@ #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" - +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/trait/onednn.h" +#endif IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomVjpTrait) + +#ifdef PADDLE_WITH_DNNL +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNOnlyTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNDynamicFallbackTrait) +#endif diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h index 637de470675eb1..662616bce773a0 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h @@ -93,6 +93,12 @@ struct OpRunTimeInfo { std::vector kernel_key_backend; std::vector> inplace; std::vector> view; + std::vector extra_args; + std::string layout_transform_arg; + std::vector layout_transform_inputs; + bool is_onednn_only; + bool dynamic_fallback; + OpRunTimeInfo(const std::string& infer_meta_func, const std::vector& infer_meta_param, const std::string& kernel_func, @@ -100,7 +106,12 @@ struct OpRunTimeInfo { const std::vector& dtype, const std::vector& backend, const std::vector>& inplace, - const std::vector>& view) + const std::vector>& view, + const std::vector& extra_args = {}, + const std::string& layout_transform_arg = "", + const std::vector& layout_transform_inputs = {}, + bool is_onednn_only = false, + bool dynamic_fallback = false) : infer_meta_func(infer_meta_func), infer_meta_param(infer_meta_param), kernel_func(kernel_func), @@ -108,7 +119,12 @@ struct OpRunTimeInfo { kernel_key_dtype(dtype), kernel_key_backend(backend), inplace(inplace), - view(view) {} + view(view), + extra_args(extra_args), + layout_transform_arg(layout_transform_arg), + layout_transform_inputs(layout_transform_inputs), + is_onednn_only(is_onednn_only), + dynamic_fallback(dynamic_fallback) {} }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 6782b2f8bfd7c3..722685fc3b5105 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -60,6 +60,7 @@ const std::unordered_set LegacyOpList = { SoftReluOp::name(), SoftReluGradOp::name()}; +const std::unordered_set OneDNNLegacyOpList = {}; enum class AttrType { UNDEFINED = 0, BOOL, @@ -220,6 +221,12 @@ VariantType GetAttributeData(const pir::Attribute& attr) { bool IsLegacyOp(const std::string& name) { return LegacyOpList.count(name); } +#ifdef PADDLE_WITH_DNNL +bool IsOneDNNLegacyOp(const std::string& name) { + return OneDNNLegacyOpList.count(name); +} +#endif + bool IsEmptyValue(const pir::Value& value) { return !value.impl() || !value.type(); } diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 1ebe7d244affdd..0e14077bb8559d 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -132,6 +132,10 @@ VariantType GetAttributeData(const pir::Attribute& attr); bool IsLegacyOp(const std::string& name); +#ifdef PADDLE_WITH_DNNL +bool IsOneDNNLegacyOp(const std::string& name); +#endif + bool IsEmptyValue(const pir::Value& value); std::vector GetInt64Vector(const pir::Attribute& attr); diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 91ca8a0d4b3f66..df7b8673d9ea80 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -17,6 +17,7 @@ #include #include "paddle/fluid/framework/op_kernel_type.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" @@ -44,6 +45,12 @@ #include "paddle/pir/dialect/control_flow/ir/cf_op.h" #include "paddle/utils/flags.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/trait/onednn.h" +#endif + PHI_DECLARE_bool(print_ir); namespace paddle { namespace dialect { @@ -337,6 +344,49 @@ static pir::OpResult AddPlaceTransferOp(pir::Value in, return new_in; } +#ifdef PADDLE_WITH_DNNL +static pir::OpResult AddOneDNN2PaddleLayoutTransferOp( + pir::Value in, const phi::DataLayout& dst_layout, pir::Block* block) { + pir::IrContext* ctx = pir::IrContext::Instance(); + auto in_alloc_type = in.type().dyn_cast(); + + phi::KernelKey kernel_key; + kernel_key.set_backend(phi::Backend::CPU); + kernel_key.set_layout(phi::DataLayout::ANY); + kernel_key.set_dtype(dialect::TransToPhiDataType(in_alloc_type.dtype())); + + std::unordered_map op_attribute; + op_attribute = { + {"op_name", pir::StrAttribute::get(ctx, "pd_op.onednn_to_paddle_layout")}, + {"kernel_name", pir::StrAttribute::get(ctx, "onednn_to_paddle_layout")}, + {"kernel_key", KernelAttribute::get(ctx, kernel_key)}, + {"dst_layout", + pir::Int32Attribute::get(ctx, static_cast(dst_layout))}}; + + auto out_type = AllocatedDenseTensorType::get(ctx, + in_alloc_type.place(), + in_alloc_type.dtype(), + in_alloc_type.dims(), + dst_layout, + in_alloc_type.lod(), + in_alloc_type.offset()); + + pir::OpInfo kernel_op_info = ctx->GetRegisteredOpInfo(PhiKernelOp::name()); + pir::Operation* op = + pir::Operation::Create({in}, op_attribute, {out_type}, kernel_op_info); + + auto in_op = in.dyn_cast().owner(); + if (in_op && in_op->HasAttribute(kAttrIsPersisable)) { + op->set_attribute(kAttrIsPersisable, in_op->attribute(kAttrIsPersisable)); + } + + block->push_back(op); + auto new_in = op->result(0); + + return new_in; +} +#endif + static bool NeedTransformDataType(const phi::DataType& l, const phi::DataType& r) { return l != phi::DataType::ALL_DTYPE && r != phi::DataType::ALL_DTYPE && @@ -424,6 +474,46 @@ static pir::Type BuildOutputType(pir::Type type, } } +#ifdef PADDLE_WITH_DNNL +template +static pir::Type create_type(pir::Type type, + const phi::Place& place, + const phi::DataLayout& layout, + pir::Type out_dtype, + pir::IrContext* ctx) { + auto input_type = type.dyn_cast(); + return IrType2::get(ctx, + place, + out_dtype, + input_type.dims(), + layout, + input_type.lod(), + input_type.offset()); +} + +static pir::Type BuildOutputType(pir::Type type, + const phi::Place& place, + const phi::DataLayout& layout, + pir::IrContext* ctx) { + if (type.isa()) { + auto out_dtype = type.dyn_cast().dtype(); + return create_type( + type, place, layout, out_dtype, ctx); + } else if (type.isa()) { + auto out_dtype = type.dyn_cast().dtype(); + return create_type( + type, place, layout, out_dtype, ctx); + } else if (type.isa()) { + auto array_type = type.dyn_cast(); + return AllocatedDenseTensorArrayType::get( + ctx, place, array_type.dtype(), layout); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "BuildOutputType only support DenseTensorType and SelectedRowsType")); + } +} +#endif + pir::OpResult AddDtypeTransferOp(pir::Value in, pir::Block* block, const phi::KernelKey& kernel_key, @@ -666,6 +756,49 @@ std::string GetKernelName(const OpYamlInfoParser* op_info_parser, return kernel_fn_str; } +#ifdef PADDLE_WITH_DNNL +bool SupportsMKLDNN(const std::string& kernel_name, + const phi::DataType data_type) { + auto phi_kernels = + phi::KernelFactory::Instance().SelectKernelMap(kernel_name); + auto has_phi_kernel = + std::any_of(phi_kernels.begin(), + phi_kernels.end(), + [data_type](phi::KernelKeyMap::const_reference kern_pair) { + return kern_pair.first.backend() == phi::Backend::ONEDNN && + kern_pair.first.dtype() == data_type; + }); + if (has_phi_kernel) { + return true; + } else { + auto op_kernel_iter = + paddle::framework::OperatorWithKernel::AllOpKernels().find( + phi::TransToFluidOpName(kernel_name)); + if (op_kernel_iter == + paddle::framework::OperatorWithKernel::AllOpKernels().end()) { + return false; + } else { + auto& op_kernels = op_kernel_iter->second; + return std::any_of( + op_kernels.begin(), + op_kernels.end(), + [data_type](std::unordered_map< + paddle::framework::OpKernelType, + std::function, + paddle::framework::OpKernelType::Hash>::const_reference + kern_pair) { + return platform::is_cpu_place(kern_pair.first.place_) && + kern_pair.first.library_type_ == + paddle::framework::LibraryType::kMKLDNN && + kern_pair.first.data_type_ == + paddle::framework::TransToProtoVarType(data_type); + }); + } + } +} +#endif + phi::KernelKey GetKernelKey( pir::Operation* op, const phi::Place& place, @@ -899,6 +1032,13 @@ phi::KernelKey GetKernelKey( "to GPU"; } +#ifdef PADDLE_WITH_DNNL + if (op->HasTrait() && res.backend() == phi::Backend::CPU && + SupportsMKLDNN(kernel_fn_str, res.dtype())) { + res.set_backend(phi::Backend::ONEDNN); + res.set_layout(phi::DataLayout::ONEDNN); + } +#endif return res; } @@ -1375,7 +1515,17 @@ std::vector BuildOutputs(pir::Operation* op_item, } else if (result_type.isa() || result_type.isa() || result_type.isa()) { +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + op_output_types.push_back(BuildOutputType( + result_type, out_place, phi::DataLayout::ONEDNN, ctx)); + } else { + op_output_types.push_back(BuildOutputType(result_type, out_place, ctx)); + } +#else op_output_types.push_back(BuildOutputType(result_type, out_place, ctx)); +#endif + } else if (result_type.isa()) { std::vector vec_inner_types; auto base_types = result_type.dyn_cast().data(); @@ -1383,8 +1533,18 @@ std::vector BuildOutputs(pir::Operation* op_item, if (base_type) { if (base_type.isa() || base_type.isa()) { +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + vec_inner_types.push_back(BuildOutputType( + base_type, out_place, phi::DataLayout::ONEDNN, ctx)); + } else { + vec_inner_types.push_back( + BuildOutputType(base_type, out_place, ctx)); + } +#else vec_inner_types.push_back( BuildOutputType(base_type, out_place, ctx)); +#endif } else { PADDLE_THROW(phi::errors::Unimplemented( "only support dense tensor and selected rows in vector type " @@ -1395,6 +1555,11 @@ std::vector BuildOutputs(pir::Operation* op_item, pir::Type fp32_dtype = pir::Float32Type::get(ctx); phi::DDim dims = {}; phi::DataLayout data_layout = phi::DataLayout::NCHW; +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + data_layout = phi::DataLayout::ONEDNN; + } +#endif phi::LoD lod = {{}}; size_t offset = 0; auto dense_tensor_dtype = DenseTensorType::get( @@ -1463,7 +1628,21 @@ std::vector BuildInputs( } } - // 1.backend transfer + // 1. layout transfer(only for onednn) +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() != phi::Backend::ONEDNN) { + auto new_in_type = new_in.type(); + if (new_in_type.isa()) { + if (new_in_type.dyn_cast().data_layout() == + phi::DataLayout::ONEDNN) { + new_in = AddOneDNN2PaddleLayoutTransferOp( + new_in, phi::DataLayout::ANY, block); + } + } + } +#endif + + // 2.backend transfer bool check_place_transfer = (op_item->isa<::pir::SetParameterOp>()) || (kernel.IsValid() && (!UnchangeOutputOps.count(op_item->name()))); @@ -1664,7 +1843,7 @@ std::vector BuildInputs( } } - // 2. dtype transfer + // 3. dtype transfer if (op_info_parser != nullptr) { std::string var_name = op_info_parser->InputNames()[i]; auto fake_tensors = PrepareFakeTensors(new_in); @@ -1694,6 +1873,7 @@ std::vector BuildInputs( } } } + vec_inputs.push_back(new_in); } return vec_inputs; @@ -1773,18 +1953,76 @@ pir::Operation* BuildKernelOp( op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true)); } - pir::OpInfo phi_kernel_op_info = - ctx->GetRegisteredOpInfo(PhiKernelOp::name()); - - pir::OpInfo legacy_kernel_op_info = - ctx->GetRegisteredOpInfo(LegacyKernelOp::name()); pir::Operation* op = nullptr; - if (IsLegacyOp(op_item->name())) { - op = pir::Operation::Create( - vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); - } else { - op = pir::Operation::Create( - vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); +#ifdef PADDLE_WITH_DNNL + if (op_item->HasTrait()) { + if (IsOneDNNLegacyOp(op_item->name())) { + VLOG(4) << "choose OneDNNLegacyKernelOp"; + pir::OpInfo legacy_kernel_op_info = + ctx->GetRegisteredOpInfo(OneDNNLegacyKernelOp::name()); + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); + } else { + auto op_info_parser = GetOpYamlInfoParser(op_item); + std::vector extra_args; + for (auto& arg : op_info_parser->OpRuntimeInfo().extra_args) { + extra_args.push_back(pir::StrAttribute::get(ctx, arg)); + } + op_attribute.emplace( + "extra_args", + pir::ArrayAttribute::get(pir::IrContext::Instance(), extra_args)); + op_attribute.emplace( + "layout_transform_arg", + pir::StrAttribute::get( + ctx, op_info_parser->OpRuntimeInfo().layout_transform_arg)); + std::vector layout_transform_inputs; + for (auto& input : + op_info_parser->OpRuntimeInfo().layout_transform_inputs) { + layout_transform_inputs.push_back(pir::StrAttribute::get(ctx, input)); + } + op_attribute.emplace("layout_transform_inputs", + pir::ArrayAttribute::get(pir::IrContext::Instance(), + layout_transform_inputs)); + op_attribute.emplace( + "is_onednn_only", + pir::BoolAttribute::get( + ctx, op_info_parser->OpRuntimeInfo().is_onednn_only)); + op_attribute.emplace( + "dynamic_fallback", + pir::BoolAttribute::get( + ctx, op_info_parser->OpRuntimeInfo().dynamic_fallback)); + if (op_item->HasTrait()) { + VLOG(4) << "choose OneDNNMixedPhiKernelOp"; + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(OneDNNMixedPhiKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } else { + VLOG(4) << "choose OneDNNPhiKernelOp"; + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(OneDNNPhiKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } + } + } else // NOLINT +#endif + { + if (IsLegacyOp(op_item->name())) { + pir::OpInfo legacy_kernel_op_info = + ctx->GetRegisteredOpInfo(LegacyKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); + } else { + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(PhiKernelOp::name()); + + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } } (*map_op_pair)[op_item] = op; @@ -1809,10 +2047,11 @@ void ProcessBlock( std::unordered_map* map_value_pair) { auto inputs_by_data_op = GetInputsByDataOp(block); - for (auto& op_item : *block) { - VLOG(6) << "op name " << op_item.name(); - if ((op_item.isa()) && - inputs_by_data_op.count(op_item.attributes() + for (auto iter = block->begin(); iter != block->end(); ++iter) { + pir::Operation* op_item = &(*iter); + VLOG(6) << "op name " << op_item->name(); + if ((op_item->isa()) && + inputs_by_data_op.count(op_item->attributes() .at("name") .dyn_cast() .AsString())) { @@ -1821,24 +2060,55 @@ void ProcessBlock( } // HandleSpecialOp - if (SpecialLowerOps.count(op_item.name())) { - VLOG(6) << "Handle Special Op: [" << op_item.name() + if (SpecialLowerOps.count(op_item->name())) { + VLOG(6) << "Handle Special Op: [" << op_item->name() << "] while lowering to kernel pass"; HandleForSpecialOp( - place, &op_item, new_block, ctx, map_op_pair, map_value_pair); + place, op_item, new_block, ctx, map_op_pair, map_value_pair); continue; } - auto op_info_parser = GetOpYamlInfoParser(&op_item); - auto kernel_name = GetKernelName(op_info_parser.get(), &op_item); + auto op_info_parser = GetOpYamlInfoParser(op_item); + auto kernel_name = GetKernelName(op_info_parser.get(), op_item); auto kernel_key = GetKernelKey( - &op_item, place, kernel_name, *map_value_pair, op_info_parser.get()); + op_item, place, kernel_name, *map_value_pair, op_info_parser.get()); VLOG(6) << "kernel type " << kernel_key; +#ifdef PADDLE_WITH_DNNL + if (op_item->HasTrait() && + kernel_key.backend() != phi::Backend::ONEDNN) { + std::vector op_item_inner_output_types; + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_item_inner_output_types.push_back(op_item->result_type(i)); + } + } + std::string target_op_name = op_item->name(); + target_op_name.replace(0, 12, "pd_op"); + auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW("Ctx should have corresponding OpInfo %s", target_op_name); + } + pir::Operation* op_item_inner = + pir::Operation::Create(op_item->operands_source(), + op_item->attributes(), + op_item_inner_output_types, + op_info); + op_item->ReplaceAllUsesWith(op_item_inner->results()); + for (auto iter = block->begin(); iter != block->end(); ++iter) { + if (*iter == *op_item) { + block->Assign(iter, op_item_inner); + break; + } + } + op_item = op_item_inner; + op_info_parser = GetOpYamlInfoParser(op_item_inner); + } +#endif // build output type - auto op_output_types = BuildOutputs(&op_item, kernel_name, kernel_key, ctx); + auto op_output_types = BuildOutputs(op_item, kernel_name, kernel_key, ctx); // build input - auto vec_inputs = BuildInputs(&op_item, + auto vec_inputs = BuildInputs(op_item, kernel_name, kernel_key, place, @@ -1853,14 +2123,14 @@ void ProcessBlock( kernel_key, vec_inputs, op_output_types, - &op_item, + op_item, new_block, ctx, map_op_pair, map_value_pair); AddShadowFeedOpForDataOrFeed( - place, &op_item, op, new_block, ctx, map_op_pair, map_value_pair); + place, op_item, op, new_block, ctx, map_op_pair, map_value_pair); } } @@ -1877,7 +2147,10 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - +#ifdef PADDLE_WITH_DNNL + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); +#endif std::unordered_map map_op_pair; std::unordered_map map_value_pair; diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index dd3166f05c3ef9..e0509fa8582ae2 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -178,6 +178,11 @@ inline bool NeedTransformPlace(const phi::Place& src_place, (target != Backend::ALL_BACKEND && phi::TransToPhiBackend(src_place) != (target != Backend::GPUDNN ? target : Backend::GPU)); +#ifdef PADDLE_WITH_DNNL + if (target == Backend::ONEDNN) { + ret = src_place.GetType() != AllocationType::CPU; + } +#endif return ret; } diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 556a713fdac302..d69e290bdbd144 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2462,6 +2462,15 @@ outputs : {q : Q, r : R} +- op : quantize + backward : quantize_grad + inputs : + input : Input + outputs : + output : Output + attrs : + {scale : Scale, shift : Shift, include_self: Include_self} + - op : quantize_linear extra : attrs : [float moving_rate = 0.9] diff --git a/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc b/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc new file mode 100644 index 00000000000000..eba8b2b61f4d27 --- /dev/null +++ b/paddle/phi/kernels/cpu/onednn_to_paddle_layout_kernel.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/onednn_to_paddle_layout_kernel.h" + +#include +#include + +#include "glog/logging.h" + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/memcpy_kernel.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/phi/backends/onednn/onednn_helper.h" +#endif +namespace phi { + +template +void OneDNN2PaddleLayout(const Context& dev_ctx, + const DenseTensor& x, + int dst_layout, + DenseTensor* out) { +#ifdef PADDLE_WITH_DNNL + DataLayout src_layout = x.layout(); + VLOG(10) << "TransDataLayout from " << static_cast(src_layout) + << " -> " << static_cast(dst_layout); + + auto print_tensor_meta = [](const DenseTensor& x) { + std::ostringstream oss; + + oss << "["; + oss << "layout:" << x.layout() << " ,"; + oss << "dims:" << x.dims() << " ,"; + if (x.IsInitialized()) oss << "place:" << x.place(); + oss << "]"; + + return oss.str(); + }; + VLOG(10) << " x: " << print_tensor_meta(x); + VLOG(10) << " out: " << print_tensor_meta(*out) << " " << out; + + if (src_layout != DataLayout::ONEDNN) { + out->ShareDataWith(x); + out->ShareInplaceVersionCounterWith(x); + out->set_layout(static_cast(dst_layout)); + return; + } + + DataLayout tmp_layout = static_cast(dst_layout); + if (static_cast(dst_layout) == DataLayout::ANY) { + tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + } + + if (tmp_layout == DataLayout::ANY) { + tmp_layout = phi::OneDNNContext::tls().get_cur_paddle_data_layout(); + } + + // NOTE(zhiqiu): to handle the special case in ApplyDataTransform() in + // data_transfer.cc + if (!x.IsInitialized() && src_layout == DataLayout::ONEDNN && + tmp_layout == DataLayout::NHWC) { + VLOG(4) << src_layout << "->" << tmp_layout << " " << x.layout(); + out->Resize(x.dims()); + out->set_layout(tmp_layout); + funcs::MatchShapeToLayout(out, src_layout, tmp_layout); + return; + } + + funcs::TransDataLayoutFromOneDNN( + src_layout, tmp_layout, x, out, dev_ctx.GetPlace()); +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL_FOR_ALL_DTYPE(onednn_to_paddle_layout, + CPU, + ALL_LAYOUT, + phi::OneDNN2PaddleLayout) {} diff --git a/paddle/phi/kernels/onednn_to_paddle_layout_kernel.h b/paddle/phi/kernels/onednn_to_paddle_layout_kernel.h new file mode 100644 index 00000000000000..a6ddc280c4e3c8 --- /dev/null +++ b/paddle/phi/kernels/onednn_to_paddle_layout_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/empty_kernel.h" + +namespace phi { + +template +void OneDNN2PaddleLayout(const Context& dev_ctx, + const DenseTensor& x, + int dst_layout, + DenseTensor* out); +} // namespace phi diff --git a/test/mkldnn/test_conv2d_mkldnn_op.py b/test/mkldnn/test_conv2d_mkldnn_op.py index 3c77581acf80db..2d6cafdbc3734b 100644 --- a/test/mkldnn/test_conv2d_mkldnn_op.py +++ b/test/mkldnn/test_conv2d_mkldnn_op.py @@ -17,6 +17,9 @@ import numpy as np from op_test import OpTest, skip_check_grad_ci from test_conv2d_op import TestConv2DOp, TestConv2DOp_v2 +from utils import compare_legacy_with_pt + +from paddle.base import core def conv2d_bias_naive(out, bias): @@ -113,6 +116,94 @@ def setUp(self): self.outputs['Output'] = output +class TestConv2DMKLDNNOp2(TestConv2DOp): + def init_group(self): + self.groups = 1 + + def init_kernel_type(self): + self.data_format = "NCHW" + self.use_mkldnn = True + self._cpu_only = True + self.dtype = np.float32 + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + + def setUp(self): + self.fuse_bias = False + self.bias_size = None + self.fuse_activation = "" + self.fuse_alpha = 0 + self.fuse_beta = 0 + self.fuse_residual_connection = False + self.input_residual_size = None + + TestConv2DOp.setUp(self) + + output = self.outputs['Output'] + + # mkldnn only support either conv-sum-relu, or conv-relu. + if self.fuse_bias and self.bias_size is not None: + bias = np.random.random(self.bias_size).astype(self.dtype) + output = conv2d_bias_naive(output, bias) + output = output.astype(self.dtype) + self.attrs['fuse_bias'] = self.fuse_bias + self.inputs['Bias'] = OpTest.np_dtype_to_base_dtype(bias) + + if ( + self.fuse_residual_connection + and self.input_residual_size is not None + ): + input_residual = np.random.random(self.input_residual_size).astype( + self.dtype + ) + output = conv2d_residual_naive(output, input_residual) + + self.attrs[ + 'fuse_residual_connection' + ] = self.fuse_residual_connection + self.inputs['ResidualData'] = OpTest.np_dtype_to_base_dtype( + input_residual + ) + + if self.fuse_activation == "relu": + output = np.maximum(output, 0).astype(self.dsttype) + + if self.fuse_activation == "relu6": + output = np.minimum(np.maximum(output, 0), self.fuse_beta).astype( + self.dsttype + ) + if ( + self.fuse_activation != "" + or self.fuse_bias + or self.fuse_residual_connection + ): + self.op_type = 'fused_conv2d' + + output = output.astype(self.dtype) + + self.attrs['fuse_bias'] = self.fuse_bias + self.attrs['fuse_activation'] = self.fuse_activation + self.attrs['fuse_alpha'] = self.fuse_alpha + self.attrs['fuse_beta'] = self.fuse_beta + self.attrs['fuse_residual_connection'] = self.fuse_residual_connection + + self.outputs['Output'] = output + + @compare_legacy_with_pt + def test_check_output(self): + place = core.CUDAPlace(0) if self.has_cuda() else core.CPUPlace() + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.check_output_with_place( + place, atol=1e-5, check_dygraph=(not self.use_mkldnn) + ) + + @skip_check_grad_ci( reason="Fusion is for inference only, check_grad is not required." ) From 1b5c02827f4cdc953110e6bda3d9cc6e52cf33e9 Mon Sep 17 00:00:00 2001 From: risemeup1 <62429225+risemeup1@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:33:33 +0800 Subject: [PATCH 014/142] =?UTF-8?q?[cmake=E6=B2=BB=E7=90=86]Cmake=20optimi?= =?UTF-8?q?zation=20framework/details=20(#59478)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cmake optimization * cmake optimization * cmake optimization * cmake optimization --- paddle/fluid/eager/CMakeLists.txt | 2 +- paddle/fluid/framework/CMakeLists.txt | 14 +- paddle/fluid/framework/details/CMakeLists.txt | 363 ++++-------------- .../framework/details/build_strategy_test.cc | 308 --------------- paddle/fluid/framework/ir/CMakeLists.txt | 3 +- .../ir/memory_optimize_pass/CMakeLists.txt | 23 +- .../multi_devices_graph_pass/CMakeLists.txt | 29 +- paddle/fluid/imperative/CMakeLists.txt | 2 - paddle/fluid/inference/CMakeLists.txt | 2 +- paddle/fluid/pybind/CMakeLists.txt | 2 +- .../fluid/framework/details/CMakeLists.txt | 9 +- .../details/reduce_op_handle_test.cc | 0 12 files changed, 101 insertions(+), 656 deletions(-) delete mode 100644 paddle/fluid/framework/details/build_strategy_test.cc rename {paddle => test/cpp}/fluid/framework/details/reduce_op_handle_test.cc (100%) diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index dde3ed71bc8c9e..5667a86876e19e 100755 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -45,7 +45,7 @@ endif() cc_library( eager_nan_inf_utils SRCS nan_inf_utils.cc - DEPS phi common nan_inf_utils enforce) + DEPS phi common enforce) cc_library( grad_node_info SRCS grad_node_info.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e5aac824b753bf..338130c64d9a06 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -340,7 +340,7 @@ if(WITH_XPU) op_kernel_type op_call_stack unused_var_check - nan_inf_utils + detail_op_handle phi_utils infershape_utils phi @@ -367,7 +367,7 @@ else() op_kernel_type op_call_stack unused_var_check - nan_inf_utils + detail_op_handle phi_utils infershape_utils phi @@ -873,15 +873,7 @@ target_link_libraries( cc_library( parallel_executor SRCS parallel_executor.cc - DEPS threaded_ssa_graph_executor - scope_buffered_ssa_graph_executor - parallel_ssa_graph_executor - async_ssa_graph_executor - graph - build_strategy - bind_threaded_ssa_graph_executor - collective_helper - fast_threaded_ssa_graph_executor + DEPS ssa_graph_executor graph build_strategy collective_helper variable_helper) cc_library( diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 2ee0da89fe980d..d771a12411adbe 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -1,55 +1,3 @@ -cc_library( - var_handle - SRCS var_handle.cc - DEPS place framework_proto node) -cc_library( - op_handle_base - SRCS op_handle_base.cc - DEPS var_handle device_context lod_tensor) - -cc_library( - scale_loss_grad_op_handle - SRCS scale_loss_grad_op_handle.cc - DEPS op_handle_base scope lod_tensor phi common) -cc_library( - fetch_op_handle - SRCS fetch_op_handle.cc - DEPS op_handle_base scope lod_tensor phi common) -cc_library( - fetch_async_op_handle - SRCS fetch_async_op_handle.cc - DEPS op_handle_base scope lod_tensor phi common) - -cc_library( - share_tensor_buffer_functor - SRCS share_tensor_buffer_functor.cc - DEPS framework_proto scope place operator op_registry) -cc_library( - computation_op_handle - SRCS computation_op_handle.cc - DEPS framework_proto scope place operator op_registry) -cc_library( - share_tensor_buffer_op_handle - SRCS share_tensor_buffer_op_handle.cc - DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor) -cc_library( - rpc_op_handle - SRCS rpc_op_handle.cc - DEPS framework_proto scope place operator op_registry) -cc_library( - fetch_barrier_op_handle - SRCS fetch_barrier_op_handle.cc - DEPS framework_proto scope place operator op_registry) -cc_library( - multi_devices_helper - SRCS multi_devices_helper.cc - DEPS graph graph_helper) - -cc_library( - variable_visitor - SRCS variable_visitor.cc - DEPS lod_tensor selected_rows_utils) - if(WITH_PSCORE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor" @@ -67,260 +15,101 @@ if(WITH_PSCORE) ${DISTRIBUTE_COMPILE_FLAGS}) endif() -if(WITH_GPU) - nv_library( - nan_inf_utils - SRCS nan_inf_utils_detail.cc - DEPS framework_proto scope place phi common) - nv_library( - all_reduce_op_handle - SRCS all_reduce_op_handle.cc - DEPS variable_visitor - op_handle_base - scope - lod_tensor - phi - common - fluid_memory - dynload_cuda) - nv_library( - fused_all_reduce_op_handle - SRCS fused_all_reduce_op_handle.cc - DEPS all_reduce_op_handle - variable_visitor - op_handle_base - scope - lod_tensor - phi - common - dynload_cuda - place) - nv_library( - grad_merge_all_reduce_op_handle - SRCS grad_merge_all_reduce_op_handle.cc - DEPS fused_all_reduce_op_handle - op_handle_base - scope - lod_tensor - phi - common - dynload_cuda - variable_visitor - place - all_reduce_op_handle) +set(op_handle_srcs + nan_inf_utils_detail.cc + all_reduce_op_handle.cc + fused_all_reduce_op_handle.cc + grad_merge_all_reduce_op_handle.cc + reduce_op_handle.cc + broadcast_op_handle.cc + fused_broadcast_op_handle.cc + var_handle.cc + op_handle_base.cc + scale_loss_grad_op_handle.cc + fetch_op_handle.cc + fetch_async_op_handle.cc + share_tensor_buffer_functor.cc + computation_op_handle.cc + share_tensor_buffer_op_handle.cc + rpc_op_handle.cc + fetch_barrier_op_handle.cc + multi_devices_helper.cc + variable_visitor.cc + gather_op_handle.cc + eager_deletion_op_handle.cc) + +if(WITH_DGC) + set(op_handle_srcs ${op_handle_srcs} sparse_all_reduce_op_handle.cc) +endif() - if(WITH_DGC) - nv_library( - sparse_all_reduce_op_handle - SRCS sparse_all_reduce_op_handle.cc - DEPS op_handle_base - scope - lod_tensor - phi - common - dynload_cuda - variable_visitor - dgc - all_reduce_op_handle) - endif() +set(op_handle_deps + pass + operator + place + framework_proto + node + device_context + op_registry + lod_tensor + selected_rows_utils + reference_count_pass_helper) - if(WITH_DISTRIBUTE) - nv_library( - reduce_op_handle - SRCS reduce_op_handle.cc - DEPS op_handle_base variable_visitor scope phi common dynload_cuda) - else() - nv_library( - reduce_op_handle - SRCS reduce_op_handle.cc - DEPS op_handle_base variable_visitor scope phi common dynload_cuda) - endif() - nv_library( - broadcast_op_handle - SRCS broadcast_op_handle.cc - DEPS op_handle_base scope phi common variable_visitor dynload_cuda) +if(WITH_MKLDNN) + set(op_handle_deps ${op_handle_deps} mkldnn) +endif() + +if(WITH_DGC) + set(op_handle_deps ${op_handle_deps} dgc) +endif() + +if(WITH_GPU) nv_library( - fused_broadcast_op_handle - SRCS fused_broadcast_op_handle.cc - DEPS broadcast_op_handle) + detail_op_handle + SRCS ${op_handle_srcs} + DEPS ${op_handle_deps}) elseif(WITH_ROCM) hip_library( - nan_inf_utils - SRCS nan_inf_utils_detail.cc - DEPS framework_proto scope place phi common) - hip_library( - all_reduce_op_handle - SRCS all_reduce_op_handle.cc - DEPS op_handle_base - scope - lod_tensor - phi - common - dynload_cuda - variable_visitor) - hip_library( - fused_all_reduce_op_handle - SRCS fused_all_reduce_op_handle.cc - DEPS all_reduce_op_handle - op_handle_base - variable_visitor - scope - lod_tensor - phi - common - dynload_cuda - place) - hip_library( - grad_merge_all_reduce_op_handle - SRCS grad_merge_all_reduce_op_handle.cc - DEPS fused_all_reduce_op_handle - op_handle_base - scope - lod_tensor - phi - common - dynload_cuda - variable_visitor - place - all_reduce_op_handle) - - if(WITH_DISTRIBUTE) - hip_library( - reduce_op_handle - SRCS reduce_op_handle.cc - DEPS op_handle_base variable_visitor scope phi common dynload_cuda) - else() - hip_library( - reduce_op_handle - SRCS reduce_op_handle.cc - DEPS op_handle_base variable_visitor scope phi common dynload_cuda) - endif() - hip_library( - broadcast_op_handle - SRCS broadcast_op_handle.cc - DEPS op_handle_base scope phi common variable_visitor dynload_cuda) - hip_library( - fused_broadcast_op_handle - SRCS fused_broadcast_op_handle.cc - DEPS broadcast_op_handle) + detail_op_handle + SRCS ${op_handle_srcs} + DEPS ${op_handle_deps}) else() cc_library( - nan_inf_utils - SRCS nan_inf_utils_detail.cc - DEPS framework_proto scope place phi common) - cc_library( - all_reduce_op_handle - SRCS all_reduce_op_handle.cc - DEPS op_handle_base scope lod_tensor phi common variable_visitor) - cc_library( - fused_all_reduce_op_handle - SRCS fused_all_reduce_op_handle.cc - DEPS all_reduce_op_handle - op_handle_base - scope - lod_tensor - phi - common - variable_visitor - place) - cc_library( - grad_merge_all_reduce_op_handle - SRCS grad_merge_all_reduce_op_handle.cc - DEPS fused_all_reduce_op_handle - op_handle_base - scope - lod_tensor - phi - common - variable_visitor - place - all_reduce_op_handle) - if(WITH_DISTRIBUTE) - cc_library( - reduce_op_handle - SRCS reduce_op_handle.cc - DEPS op_handle_base variable_visitor scope phi common) - else() - cc_library( - reduce_op_handle - SRCS reduce_op_handle.cc - DEPS op_handle_base variable_visitor scope phi common) - endif() - cc_library( - broadcast_op_handle - SRCS broadcast_op_handle.cc - DEPS op_handle_base scope phi common variable_visitor) - cc_library( - fused_broadcast_op_handle - SRCS fused_broadcast_op_handle.cc - DEPS broadcast_op_handle) + detail_op_handle + SRCS ${op_handle_srcs} + DEPS ${op_handle_deps}) endif() -cc_library( - gather_op_handle - SRCS gather_op_handle.cc - DEPS op_handle_base scope phi common variable_visitor) - -cc_library( - eager_deletion_op_handle - SRCS eager_deletion_op_handle.cc - DEPS lod_tensor selected_rows_utils reference_count_pass_helper) - +add_dependencies(detail_op_handle framework_proto auto_parallel_proto xxhash) + +set(ssa_graph_executor_srcs + ssa_graph_executor.cc + threaded_ssa_graph_executor.cc + parallel_ssa_graph_executor.cc + async_ssa_graph_executor.cc + bind_threaded_ssa_graph_executor.cc + fast_threaded_ssa_graph_executor.cc + scope_buffered_ssa_graph_executor.cc + scope_buffered_monitor.cc) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto - multi_devices_helper + detail_op_handle reference_count_pass eager_deletion_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass inplace_addto_op_pass - set_reader_device_info_utils) -cc_library( - ssa_graph_executor NOT_FOR_INFER - SRCS ssa_graph_executor.cc - DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) - -cc_library( - threaded_ssa_graph_executor - SRCS threaded_ssa_graph_executor.cc - DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool - device_context) + set_reader_device_info_utils + scope + simple_threadpool + device_context + profiler + selected_rows_utils) cc_library( - parallel_ssa_graph_executor - SRCS parallel_ssa_graph_executor.cc - DEPS threaded_ssa_graph_executor) - -set(ASYNC_SSA_GRAPH_EXECUTOR_DEPS threaded_ssa_graph_executor) - -cc_library( - async_ssa_graph_executor - SRCS async_ssa_graph_executor.cc - DEPS ${ASYNC_SSA_GRAPH_EXECUTOR_DEPS}) -cc_library( - scope_buffered_monitor - SRCS scope_buffered_monitor.cc - DEPS scope profiler selected_rows_utils) -cc_library( - scope_buffered_ssa_graph_executor - SRCS scope_buffered_ssa_graph_executor.cc - DEPS ssa_graph_executor scope_buffered_monitor) -cc_library( - bind_threaded_ssa_graph_executor - SRCS bind_threaded_ssa_graph_executor.cc - DEPS fetch_op_handle - phi - common - ssa_graph_executor - scope - simple_threadpool - device_context) -cc_library( - fast_threaded_ssa_graph_executor - SRCS fast_threaded_ssa_graph_executor.cc - DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool - device_context) + ssa_graph_executor + SRCS ${ssa_graph_executor_srcs} + DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) set(IR_PASS_DEPS graph_viz_pass diff --git a/paddle/fluid/framework/details/build_strategy_test.cc b/paddle/fluid/framework/details/build_strategy_test.cc deleted file mode 100644 index dc6a7e33e4f2f0..00000000000000 --- a/paddle/fluid/framework/details/build_strategy_test.cc +++ /dev/null @@ -1,308 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/build_strategy.h" - -#include -#include -#include -#include -#include -#include - -#include "gtest/gtest-message.h" -#include "gtest/gtest-test-part.h" -#include "gtest/gtest.h" -#include "gtest/gtest_pred_impl.h" -#include "paddle/fluid/framework/op_proto_maker.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/var_type_inference.h" -#include "paddle/fluid/platform/place.h" - -PD_DECLARE_bool(convert_all_blocks); - -namespace paddle { -namespace framework { - -class SumOpMaker : public OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "").AsDuplicable(); - AddOutput("Out", "").AsDuplicable(); - AddComment(""); - } -}; - -class SumOpWithKernel : public OperatorWithKernel { - public: - using OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override {} - phi::KernelKey GetExpectedKernelType( - const ExecutionContext &ctx) const override { - return phi::KernelKey(proto::VarType::FP32, - ctx.Input("X")->place()); - } -}; - -} // namespace framework -} // namespace paddle - -REGISTER_OP_WITHOUT_GRADIENT(fake_sum, - paddle::framework::SumOpWithKernel, - paddle::framework::SumOpMaker); - -namespace paddle { -namespace framework { -namespace details { - -static std::vector CreatePlaces(size_t num, bool use_cuda) { - std::vector result; - result.reserve(num); - for (size_t i = 0; i < num; ++i) { - if (use_cuda) { - result.emplace_back(platform::CUDAPlace(static_cast(i))); - } else { - result.emplace_back(platform::CPUPlace()); - } - } - return result; -} - -void BuildStrategyApply(BuildStrategy *build_strategy, ir::Graph *graph) { - std::string loss_name = ""; - Scope scope; - std::vector scopes = {&scope}; - - auto places = CreatePlaces(1, false); - auto device = platform::Place2DeviceType(places[0]); - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - platform::NCCLCommunicator ctxs; -#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL) - platform::BKCLCommunicator ctxs; -#endif - - build_strategy->Apply(graph, - places, - loss_name, - scopes, - 1, -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - device, - &ctxs); -#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL) - device, - &ctxs); -#else - device); -#endif -} - -std::unique_ptr CreateGraph() { - ProgramDesc prog; - auto *op = prog.MutableBlock(0)->AppendOp(); - op->SetType("fake_sum"); - op->SetInput("X", {"a1"}); - op->SetOutput("Out", {"b1"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(0)->Var("a1")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(0)->Var("b1")->SetType(proto::VarType::LOD_TENSOR); - - std::unique_ptr g(new ir::Graph(prog)); - return g; -} - -std::unique_ptr CreateMultiGraph() { - ProgramDesc prog; - prog.AppendBlock(prog.Block(0)); - prog.AppendBlock(prog.Block(0)); - - // Set contents in block_0. - auto *op = prog.MutableBlock(0)->AppendOp(); - op->SetType("fake_sum"); - op->SetInput("X", {"test_a", "test_b", "test_c"}); - op->SetOutput("Out", {"test_out"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); - prog.MutableBlock(0)->Var("test_out"); - op->InferVarType(prog.MutableBlock(0)); - - prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); - op->InferVarType(prog.MutableBlock(0)); - - // Set contents in block_1. - op = prog.MutableBlock(1)->AppendOp(); - op->SetType("fake_sum"); - op->SetInput("X", {"a1"}); - op->SetOutput("Out", {"b1"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(1)->Var("a1")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(1)->Var("b1")->SetType(proto::VarType::LOD_TENSOR); - - // Set contents in block_2. - op = prog.MutableBlock(2)->AppendOp(); - op->SetType("fake_sum"); - op->SetInput("X", {"a2"}); - op->SetOutput("Out", {"b2"}); - op->SetAttr("op_role", 1); - - prog.MutableBlock(2)->Var("a2")->SetType(proto::VarType::LOD_TENSOR); - prog.MutableBlock(2)->Var("b2")->SetType(proto::VarType::LOD_TENSOR); - - std::unique_ptr g(new ir::Graph(prog)); - return g; -} - -inline bool CheckSubGraphSame(ir::Graph *g1, ir::Graph *g2) { - const auto &g1_nodes_set = g1->Nodes(); - const auto &g2_nodes_set = g2->Nodes(); - - if (g1_nodes_set.size() != g2_nodes_set.size()) return false; - - std::vector g1_nodes(g1_nodes_set.begin(), g1_nodes_set.end()); - std::vector g2_nodes(g2_nodes_set.begin(), g2_nodes_set.end()); - - auto comp = [](ir::Node *n1, ir::Node *n2) { - return n1->Name() > n2->Name(); - }; - std::stable_sort(g1_nodes.begin(), g1_nodes.end(), comp); - std::stable_sort(g2_nodes.begin(), g2_nodes.end(), comp); - - for (size_t i = 0; i < g1_nodes.size(); ++i) { - const auto &n1 = g1_nodes[i]; - const auto &n2 = g2_nodes[i]; - - if (n1->NodeType() != n2->NodeType()) return false; - if (n1->Name() != n2->Name()) return false; - - auto n1_inputs = n1->inputs; - auto n2_inputs = n2->inputs; - if (n1_inputs.size() != n2_inputs.size()) return false; - - std::stable_sort(n1_inputs.begin(), n1_inputs.end(), comp); - std::stable_sort(n2_inputs.begin(), n2_inputs.end(), comp); - for (size_t i = 0; i < n1_inputs.size(); ++i) { - if (n1_inputs[i]->Name() != n2_inputs[i]->Name()) return false; - } - - auto n1_outputs = n1->outputs; - auto n2_outputs = n2->outputs; - if (n1_outputs.size() != n2_outputs.size()) return false; - - std::stable_sort(n1_outputs.begin(), n1_outputs.end(), comp); - std::stable_sort(n2_outputs.begin(), n2_outputs.end(), comp); - for (size_t i = 0; i < n1_outputs.size(); ++i) { - if (n1_outputs[i]->Name() != n2_outputs[i]->Name()) return false; - } - - if (n1->IsVar()) { - const auto &var1 = n1->Var(); - const auto &var2 = n2->Var(); - if ((var1 == nullptr) != (var2 == nullptr)) return false; - } - - if (n1->IsOp()) { - const auto &op1 = n1->Op(); - const auto &op2 = n2->Op(); - if ((op1 == nullptr) != (op2 == nullptr)) return false; - - const auto &op1_input = op1->InputNames(); - const auto &op2_input = op2->InputNames(); - if (op1_input.size() != op2_input.size()) return false; - if (op1_input != op2_input) return false; - - for (size_t i = 0; i < op1_input.size(); ++i) { - if (op1->Input(op1_input[i]) != op2->Input(op2_input[i])) return false; - } - - const auto &op1_output = op1->OutputNames(); - const auto &op2_output = op2->OutputNames(); - if (op1_output.size() != op2_output.size()) return false; - if (op1_output != op2_output) return false; - - for (size_t i = 0; i < op1_output.size(); ++i) { - if (op1->Output(op1_output[i]) != op2->Output(op2_output[i])) - return false; - } - } - } - return true; -} - -inline bool CheckGraphSame(ir::Graph *g1, ir::Graph *g2) { - if (g1 == nullptr || g2 == nullptr) return true; - - if (FLAGS_convert_all_blocks) { - if (g1->SubGraphsSize() != g2->SubGraphsSize()) return false; - - for (size_t i = 0; i < g1->SubGraphsSize(); ++i) { - if (!CheckSubGraphSame(g1->GetSubGraph(i), g2->GetSubGraph(i))) - return false; - } - } else { - if (!CheckSubGraphSame(g1, g2)) return false; - } - return true; -} - -TEST(BuildStrategy, Basic) { - BuildStrategy build_strategy; - - ProgramDesc prog; - ir::Graph old_graph(prog), graph(prog); - - BuildStrategyApply(&build_strategy, &graph); - - ASSERT_TRUE(CheckGraphSame(&old_graph, &graph)); -} - -TEST(BuildStrategy, TestSingleGraph) { - BuildStrategy build_strategy; - auto graph = CreateGraph(); - ir::Graph old_graph(graph->OriginProgram()); - - BuildStrategyApply(&build_strategy, graph.get()); - - // graph should not change for no pass here - ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get())); -} - -TEST(BuildStrategy, TestMultiGraph) { - // Set FLAGS_convert_all_blocks to true to make sure this test works. - bool flag_temp = FLAGS_convert_all_blocks; - FLAGS_convert_all_blocks = true; - - BuildStrategy build_strategy; - auto graph = CreateMultiGraph(); - ir::Graph old_graph(graph->OriginProgram()); - - BuildStrategyApply(&build_strategy, graph.get()); - - // graph should not change for no pass here - ASSERT_TRUE(CheckGraphSame(&old_graph, graph.get())); - - // Recover FLAGS_convert_all_blocks. - FLAGS_convert_all_blocks = flag_temp; -} - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index fa6c8a25834536..3c7560b69e3323 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -19,8 +19,7 @@ cc_library( cc_library( graph_helper SRCS graph_helper.cc - DEPS graph program_utils scale_loss_grad_op_handle - grad_merge_all_reduce_op_handle collective_helper) + DEPS graph program_utils collective_helper) # cc_library( pass SRCS pass.cc diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index d0618616619037..85923aafc23a74 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -1,36 +1,33 @@ cc_library( op_graph_view SRCS op_graph_view.cc - DEPS op_handle_base) + DEPS detail_op_handle) cc_library( conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc - DEPS conditional_block_op_helper graph_helper pass computation_op_handle) + DEPS conditional_block_op_helper graph_helper pass) cc_library( pylayer_op_eager_deletion_pass SRCS pylayer_op_eager_deletion_pass.cc - DEPS pylayer_op_helper graph_helper pass computation_op_handle) + DEPS pylayer_op_helper graph_helper pass) cc_library( while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc - DEPS while_op_helper graph_helper pass computation_op_handle) + DEPS while_op_helper graph_helper pass) cc_library( recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc - DEPS recurrent_op_helper graph_helper pass computation_op_handle) + DEPS recurrent_op_helper graph_helper pass) cc_library( reference_count_pass_helper SRCS reference_count_pass_helper.cc - DEPS garbage_collector computation_op_handle var_handle) + DEPS garbage_collector) # cc_library( reference_count_pass SRCS reference_count_pass.cc - DEPS computation_op_handle graph graph_helper pass op_graph_view - reference_count_pass_helper) + DEPS graph graph_helper pass op_graph_view reference_count_pass_helper) set(EAGER_DELETETION_PASS_DEPS - computation_op_handle - eager_deletion_op_handle graph graph_helper pass @@ -43,8 +40,7 @@ if(WITH_CINN) cc_library( share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc - DEPS pass enforce common graph_helper computation_op_handle - eager_deletion_op_handle) + DEPS pass enforce common graph_helper) cc_test( share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc @@ -61,8 +57,7 @@ cc_library( cc_library( memory_reuse_pass SRCS memory_reuse_pass.cc - DEPS computation_op_handle reference_count_pass_helper - share_tensor_buffer_op_handle graph pass multi_devices_helper) + DEPS reference_count_pass_helper graph pass) cc_library( buffer_shared_inplace_op_pass diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt index e97331bc87a453..2aa76a8eb2214b 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt @@ -1,36 +1,25 @@ cc_library( modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc - DEPS computation_op_handle scale_loss_grad_op_handle op_graph_view - multi_devices_helper) + DEPS detail_op_handle op_graph_view) cc_library( multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc - DEPS multi_devices_helper) + DEPS detail_op_handle) cc_library( multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc - DEPS multi_devices_helper) + DEPS detail_op_handle) -set(ALL_REDUCE_OP_HANDLES all_reduce_op_handle) -set(ALL_REDUCE_OP_HANDLES grad_merge_all_reduce_op_handle) if(WITH_GPU AND WITH_DGC) - list(APPEND ALL_REDUCE_OP_HANDLES sparse_all_reduce_op_handle) + list(APPEND ALL_REDUCE_OP_HANDLES detail_op_handle) endif() cc_library( multi_devices_graph_pass SRCS multi_devices_graph_pass.cc - DEPS multi_devices_helper - computation_op_handle - scale_loss_grad_op_handle - rpc_op_handle - fetch_barrier_op_handle - ${ALL_REDUCE_OP_HANDLES} - reduce_op_handle - broadcast_op_handle - fused_broadcast_op_handle) + DEPS detail_op_handle ${ALL_REDUCE_OP_HANDLES}) cc_library( sequential_execution_pass SRCS sequential_execution_pass.cc @@ -43,12 +32,11 @@ cc_library( cc_library( fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc - DEPS graph graph_helper fused_all_reduce_op_handle - grad_merge_all_reduce_op_handle) + DEPS graph graph_helper) cc_library( all_reduce_deps_pass SRCS all_reduce_deps_pass.cc - DEPS all_reduce_op_handle graph graph_helper pass) + DEPS graph graph_helper pass) cc_library( backward_optimizer_op_deps_pass SRCS backward_optimizer_op_deps_pass.cc @@ -60,5 +48,4 @@ cc_library( cc_library( fix_op_run_order_pass SRCS fix_op_run_order_pass.cc - DEPS graph graph_helper multi_devices_helper pass op_handle_base - eager_deletion_op_handle) + DEPS graph graph_helper pass) diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index b6d846e9a0c12d..7a764f5302021b 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -19,7 +19,6 @@ if(WITH_XPU) var_type_traits op_kernel_type data_transform - nan_inf_utils phi common var_helper @@ -37,7 +36,6 @@ else() var_type_traits op_kernel_type data_transform - nan_inf_utils phi common var_helper diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 3f4e7a9344a30c..1f353e2ba8409c 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -107,7 +107,7 @@ set(SHARED_INFERENCE_SRCS list(REMOVE_ITEM fluid_modules cinn_op_dialect) # NOTE(Aurelisu84): Remove pir dialect related target DEPS for inference # shared library to prune library size. -list(REMOVE_ITEM fluid_modules ${not_infer_modules}) +# list(REMOVE_ITEM fluid_modules ${not_infer_modules}) set(SHARED_INFERENCE_DEPS phi common ${fluid_modules} analysis_predictor ${utils_modules}) diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 8965bcfbf234ee..52eada6c5482ff 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -44,7 +44,7 @@ set(PYBIND_DEPS new_profiler fluid_jit prim_utils - gather_op_handle + detail_op_handle static_tensor_operants type_info auto_parallel) diff --git a/test/cpp/fluid/framework/details/CMakeLists.txt b/test/cpp/fluid/framework/details/CMakeLists.txt index cb430109e286ff..4a02fd08e0815e 100644 --- a/test/cpp/fluid/framework/details/CMakeLists.txt +++ b/test/cpp/fluid/framework/details/CMakeLists.txt @@ -3,14 +3,7 @@ paddle_test(broadcast_op_test SRCS broadcast_op_handle_test.cc) cc_test( gather_op_test SRCS gather_op_handle_test.cc - DEPS var_handle - op_handle_base - scope - phi - common - fluid_memory - device_context - gather_op_handle) + DEPS detail_op_handle scope phi common fluid_memory device_context) paddle_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc) paddle_test(exception_holder_test SRCS exception_holder_test.cc) diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/test/cpp/fluid/framework/details/reduce_op_handle_test.cc similarity index 100% rename from paddle/fluid/framework/details/reduce_op_handle_test.cc rename to test/cpp/fluid/framework/details/reduce_op_handle_test.cc From 51d25ec7d3b9f3c9a8534bc0a37c88de3bb1193e Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:40:27 +0800 Subject: [PATCH 015/142] fiox (#60404) --- test/dygraph_to_static/test_cache_program.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py index 2ad05f56b41e79..34744a6567cf00 100644 --- a/test/dygraph_to_static/test_cache_program.py +++ b/test/dygraph_to_static/test_cache_program.py @@ -18,9 +18,6 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, - IrMode, - ToStaticMode, - disable_test_case, enable_to_static_guard, test_ast_only, test_legacy_and_pt_and_pir, @@ -175,7 +172,6 @@ def sum_under_while(limit): return ret_sum -@disable_test_case((ToStaticMode.AST, IrMode.PT)) class TestToOutputWithCache(Dy2StTestBase): def test_output(self): ret = paddle.jit.to_static(sum_even_until_limit)(80, 10) From 61cf6b12537006f3efcbdde7e556104f3f0364c5 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Thu, 28 Dec 2023 10:54:43 +0800 Subject: [PATCH 016/142] sub block trace run in inference (#60419) * sub block trace run in inference * update --- .../new_executor/instruction/if_instruction.cc | 13 ++++++++----- .../new_executor/instruction/if_instruction.h | 3 ++- .../instruction/while_instruction.cc | 18 +++++++++++------- .../instruction/while_instruction.h | 3 ++- .../interpreter/interpreter_util.cc | 4 ++++ .../framework/new_executor/pir_interpreter.cc | 6 ++---- .../new_executor/program_interpreter.cc | 4 ++++ .../controlflow/conditional_block_op.cc | 3 +++ 8 files changed, 36 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc index c43eba69ed1f58..57146acdfb5df5 100644 --- a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.cc @@ -43,7 +43,7 @@ IfInstruction::IfInstruction(size_t id, const platform::Place& place, pir::Operation* op, ValueExecutionInfo* value_exec_info, - const std::set& skip_gc_vars) + interpreter::ExecutionConfig execution_config) : InstructionBase(id, place) { PADDLE_ENFORCE( op->isa(), @@ -124,12 +124,15 @@ IfInstruction::IfInstruction(size_t id, VLOG(6) << "finish process inputs outputs index"; Scope* true_scope = &(value_exec_info->GetScope()->NewScope()); + auto skip_gc_vars = execution_config.skip_gc_vars; + execution_config.skip_gc_vars.clear(); + execution_config.create_local_scope = true; true_branch_inter_ = new PirInterpreter(place, {}, &true_branch_block, true_scope, value_exec_info->NewChild(true_scope), - {}); + execution_config); std::set true_skip_gc_names_set; for (auto value : GetYiedOpInputs(&true_branch_block)) { @@ -143,7 +146,7 @@ IfInstruction::IfInstruction(size_t id, true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); } - for (auto var_name : skip_gc_vars) { + for (const auto& var_name : skip_gc_vars) { true_skip_gc_names_.push_back(var_name); true_skip_gc_names_set.insert(var_name); } @@ -157,7 +160,7 @@ IfInstruction::IfInstruction(size_t id, &if_op.false_block(), false_scope, value_exec_info->NewChild(false_scope), - {}); + execution_config); std::set false_skip_gc_names_set; for (auto value : GetYiedOpInputs(&false_branch_block)) { false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); @@ -168,7 +171,7 @@ IfInstruction::IfInstruction(size_t id, false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); } - for (auto var_name : skip_gc_vars) { + for (const auto& var_name : skip_gc_vars) { false_skip_gc_names_.push_back(var_name); false_skip_gc_names_set.insert(var_name); } diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.h b/paddle/fluid/framework/new_executor/instruction/if_instruction.h index e6d1fc4723c5d6..b7b3ed6ac8f174 100644 --- a/paddle/fluid/framework/new_executor/instruction/if_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/if_instruction.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" +#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" namespace ir { class Operation; @@ -33,7 +34,7 @@ class IfInstruction : public InstructionBase { const platform::Place& place, ::pir::Operation* op, ValueExecutionInfo* value_exe_info, - const std::set& skip_gc_vars); + interpreter::ExecutionConfig execution_config); ~IfInstruction(); diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc index f2a6e92e2f4b22..b281e2b8a6cbe4 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -40,11 +40,12 @@ namespace paddle { namespace framework { -WhileInstruction::WhileInstruction(size_t id, - const platform::Place& place, - pir::Operation* op, - ValueExecutionInfo* parent_exe_info, - const std::set& skip_gc_vars) +WhileInstruction::WhileInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + ValueExecutionInfo* parent_exe_info, + interpreter::ExecutionConfig execution_config) : InstructionBase(id, place) { op_ = op; VLOG(6) << "finish process dist attributes"; @@ -108,8 +109,11 @@ WhileInstruction::WhileInstruction(size_t id, body_scope->Var(var_name); body_exe_info->Add(body_block_->arg(i), var_name); } + auto skip_gc_vars = execution_config.skip_gc_vars; + execution_config.skip_gc_vars.clear(); + execution_config.create_local_scope = true; body_inter_ = std::unique_ptr(new PirInterpreter( - place, {}, body_block_, body_scope, body_exe_info, {})); + place, {}, body_block_, body_scope, body_exe_info, execution_config)); std::set body_skip_gc_names_set; auto body_block_outputs = GetYiedOpInputs(body_block_); @@ -122,7 +126,7 @@ WhileInstruction::WhileInstruction(size_t id, body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); } - for (auto var_name : skip_gc_vars) { + for (const auto& var_name : skip_gc_vars) { body_skip_gc_names_.push_back(var_name); body_skip_gc_names_set.insert(var_name); } diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.h b/paddle/fluid/framework/new_executor/instruction/while_instruction.h index ae27c89b0051a7..f8a98d3b03d6b1 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" +#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" namespace ir { class Operation; @@ -39,7 +40,7 @@ class WhileInstruction : public InstructionBase { const platform::Place& place, ::pir::Operation* op, ValueExecutionInfo* parent_exe_info, - const std::set& skip_gc_vars); + interpreter::ExecutionConfig execution_config); void Run() override; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 006e9d5fc4603d..614b97c26b7b07 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -641,6 +641,10 @@ void BuildOpFuncList(const platform::Place& place, auto runtime_attrs = op->RuntimeAttrs(); runtime_attrs.insert(std::make_pair("used_for_inference", true)); op->SetRuntimeAttributeMap(runtime_attrs); + } else if (op->Type() == "conditional_block") { + auto runtime_attrs = op->RuntimeAttrs(); + runtime_attrs.insert(std::make_pair("used_for_inference", true)); + op->SetRuntimeAttributeMap(runtime_attrs); } } diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 1cd1117d0ea1d2..82bf2973345ad5 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -686,9 +686,8 @@ void PirInterpreter::BuildInstruction() { } } else if (op.dialect()->name() == "pd_op") { if (op.isa()) { - auto skip_gc_vars = execution_config_.skip_gc_vars; vec_instruction_base_.emplace_back(std::make_unique( - op_idx++, place_, &op, value_exe_info_.get(), skip_gc_vars)); + op_idx++, place_, &op, value_exe_info_.get(), execution_config_)); sub_blocks_.insert( {&op.dyn_cast().true_block(), dynamic_cast(vec_instruction_base_.back().get()) @@ -698,9 +697,8 @@ void PirInterpreter::BuildInstruction() { dynamic_cast(vec_instruction_base_.back().get()) ->FalseBranchInterpreter()}); } else if (op.isa()) { - auto skip_gc_vars = execution_config_.skip_gc_vars; vec_instruction_base_.emplace_back(std::make_unique( - op_idx++, place_, &op, value_exe_info_.get(), skip_gc_vars)); + op_idx++, place_, &op, value_exe_info_.get(), execution_config_)); sub_blocks_.insert( {&op.dyn_cast().body(), dynamic_cast(vec_instruction_base_.back().get()) diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 442112033608d5..9434e4fd81af60 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -921,6 +921,10 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) { auto runtime_attrs = op->RuntimeAttrs(); runtime_attrs.insert(std::make_pair("used_for_inference", true)); op->SetRuntimeAttributeMap(runtime_attrs); + } else if (op->Type() == "conditional_block") { + auto runtime_attrs = op->RuntimeAttrs(); + runtime_attrs.insert(std::make_pair("used_for_inference", true)); + op->SetRuntimeAttributeMap(runtime_attrs); } } diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 2ce18593461404..58e0114045db4b 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -103,6 +103,9 @@ class ConditionalBlockOp : public ConditionalOp { dev_place); framework::interpreter::ExecutionConfig execution_config; + if (HasAttr("used_for_inference") && Attr("used_for_inference")) { + execution_config.used_for_inference = true; + } execution_config.create_local_scope = false; execution_config.used_for_control_flow_op = true; execution_config.skip_gc_vars = From 7d9f2a2c9efce4dfad797562d9e84c56aa795be2 Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Thu, 28 Dec 2023 10:55:20 +0800 Subject: [PATCH 017/142] [PIR] support while api for python api. (#60364) --- paddle/fluid/pybind/pir.cc | 6 +- paddle/pir/core/block.cc | 5 +- paddle/pir/core/block.h | 4 +- paddle/pir/core/region.cc | 12 ++-- python/paddle/static/nn/control_flow.py | 18 ++++-- python/paddle/tensor/math.py | 11 ++-- test/legacy_test/test_while_op.py | 76 +++++++++++++++++++------ 7 files changed, 91 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 7e1d46b3364c8d..2103e7b7b660e2 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -317,11 +317,15 @@ void BindBlock(py::module *m) { The constructor of Block should not be invoked directly. You can use `Program.block()` to get a block. )DOC"); - block + block.def("empty", &Block::empty) .def( "front", [](Block &self) { return &self.front(); }, return_value_policy::reference) + .def( + "back", + [](Block &self) { return &self.back(); }, + return_value_policy::reference) .def_property_readonly( "parent_op", [](Block &self) { return self.GetParentOp(); }, diff --git a/paddle/pir/core/block.cc b/paddle/pir/core/block.cc index 49389454545d10..e52e09258ab390 100644 --- a/paddle/pir/core/block.cc +++ b/paddle/pir/core/block.cc @@ -73,10 +73,7 @@ Operation *Block::Take(Operation *op) { return op; } -void Block::SetParent(Region *parent, Region::Iterator position) { - parent_ = parent; - position_ = position; -} +void Block::SetParent(Region *parent) { parent_ = parent; } Block::UseIterator Block::use_begin() const { return first_use_; } diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h index 373f97e12c51ef..3d7774f0be3753 100644 --- a/paddle/pir/core/block.h +++ b/paddle/pir/core/block.h @@ -73,7 +73,6 @@ class IR_API Block { Iterator insert(ConstIterator iterator, Operation *op); Iterator erase(ConstIterator position); void clear(); - operator Region::Iterator() { return position_; } // Assign the operation underlying in position with parameter op, // meanwhile, destroy the original operation. @@ -145,7 +144,7 @@ class IR_API Block { // Allow access to 'SetParent'. friend class Region; - void SetParent(Region *parent, Region::Iterator position); + void SetParent(Region *parent); // Take out corresponding Operation and its ownershipe. friend class Operation; @@ -154,7 +153,6 @@ class IR_API Block { static bool TopoOrderCheck(const OpListType &op_list); private: - Region::Iterator position_; BlockOperand first_use_; OpListType ops_; // owned ArgListType arguments_; // owned diff --git a/paddle/pir/core/region.cc b/paddle/pir/core/region.cc index 21a09198f1d791..552df868611677 100644 --- a/paddle/pir/core/region.cc +++ b/paddle/pir/core/region.cc @@ -32,7 +32,7 @@ void Region::push_front(Block *block) { insert(blocks_.begin(), block); } Region::Iterator Region::insert(ConstIterator position, Block *block) { Region::Iterator iter = blocks_.insert(position, block); - block->SetParent(this, iter); + block->SetParent(this); return iter; } @@ -54,7 +54,7 @@ void Region::TakeBody(Region &&other) { clear(); blocks_.swap(other.blocks_); for (auto iter = blocks_.begin(); iter != blocks_.end(); ++iter) { - (*iter)->SetParent(this, iter); + (*iter)->SetParent(this); } } @@ -72,11 +72,11 @@ void Region::clear() { void Region::swap(Region &&other) { blocks_.swap(other.blocks_); - for (auto iter = begin(); iter != end(); ++iter) { - iter->SetParent(this, iter); + for (auto &block : *this) { + block.SetParent(this); } - for (auto iter = other.begin(); iter != other.end(); ++iter) { - iter->SetParent(&other, iter); + for (auto &block : other) { + block.SetParent(&other); } } diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 3d2f9858a1feb5..a6a2027ac9a3dc 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -153,14 +153,21 @@ class WhileGuard(BlockGuard): def __init__(self, while_op): if not isinstance(while_op, While): raise TypeError("WhileGuard takes a while op") - super().__init__(while_op.helper.main_program) + if not in_pir_mode(): + super().__init__(while_op.helper.main_program) self.while_op = while_op def __enter__(self): + if in_pir_mode(): + self.block = build_while_op(self.while_op.cond_var, []).body() + return self.block.__enter__() self.while_op.status = While.IN_WHILE_BLOCK return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): + if in_pir_mode(): + cf_yield([self.while_op.cond_var]) + return self.block.__exit__(exc_type, exc_val, exc_tb) if exc_type is not None: return False self.while_op.status = While.AFTER_WHILE_BLOCK @@ -509,8 +516,7 @@ class While: AFTER_WHILE_BLOCK = 2 def __init__(self, cond, is_test=False, name=None): - self.helper = LayerHelper("while", name=name) - self.status = While.BEFORE_WHILE_BLOCK + self.cond_var = cond check_variable_and_dtype(cond, 'cond', ['bool'], 'static.nn.While') if reduce(lambda a, b: a * b, cond.shape, 1) != 1: raise TypeError( @@ -518,7 +524,10 @@ def __init__(self, cond, is_test=False, name=None): list(cond.shape) ) ) - self.cond_var = cond + if in_pir_mode(): + return + self.status = While.BEFORE_WHILE_BLOCK + self.helper = LayerHelper("while", name=name) self.is_test = is_test def block(self): @@ -1870,5 +1879,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.inside_scope = False if exc_type is not None: return False # re-raise exception - return True diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 49fd425726cb5e..acaa0905ce6f40 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4719,12 +4719,15 @@ def increment(x, value=1.0, name=None): [1.]) """ - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + return _C_ops.increment_(x, value) + + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'int32', 'int64'], 'increment' + ) + if in_pir_mode(): return _C_ops.increment_(x, value) else: - check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'int32', 'int64'], 'increment' - ) helper = LayerHelper("increment", **locals()) helper.append_op( type='increment', diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 5ff7698b6b2bc1..63affc80d7cf4a 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy from utils import compare_legacy_with_pt import paddle -from paddle import base +from paddle import base, set_flags from paddle.base import core from paddle.base.backward import append_backward from paddle.base.executor import Executor +from paddle.base.framework import in_pir_mode from paddle.incubate.layers.nn import shuffle_batch +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -70,7 +73,7 @@ def simple_net(self): prev2 = paddle.tensor.array_read(array=mem_array, i=j) result2 = paddle.add_n([d2, prev2]) - j = paddle.increment(x=j) + paddle.increment(x=j) paddle.tensor.array_write(result2, i=j, array=mem_array) paddle.assign(paddle.less_than(x=j, y=array_len2), cond2) @@ -79,7 +82,8 @@ def simple_net(self): loss = paddle.mean(sum_result) return loss, sum_result - # TODO(zhangbo): Support pir test(support write_to_array and read_from_array, support while_grad). + # TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False). + @test_with_pir_api def test_simple_net(self): main_program = base.Program() startup_program = base.Program() @@ -88,6 +92,14 @@ def test_simple_net(self): append_backward(loss) + if in_pir_mode(): + flag_1 = "FLAGS_enable_pir_in_executor_trace_run" + flag_2 = "FLAGS_new_executor_serial_run" + os.environ[flag_1] = 'True' + os.environ[flag_2] = 'True' + set_flags({flag_1: True}) + set_flags({flag_2: True}) + cpu = core.CPUPlace() exe = Executor(cpu) d = [] @@ -99,15 +111,24 @@ def test_simple_net(self): feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}, fetch_list=[sum_result], ) + if in_pir_mode(): + del os.environ[flag_1] + del os.environ[flag_2] + set_flags({flag_1: False}) + set_flags({flag_2: False}) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) - # TODO(zhangbo): Support pir test(support write_to_array and read_from_array) + # TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False). + @test_with_pir_api def test_simple_net_forward(self): main_program = base.Program() startup_program = base.Program() with base.program_guard(main_program, startup_program): self.simple_net() - binary = base.compiler.CompiledProgram(main_program) + if in_pir_mode(): + binary = main_program + else: + binary = base.compiler.CompiledProgram(main_program) cpu = core.CPUPlace() exe = Executor(cpu) d = [] @@ -115,10 +136,23 @@ def test_simple_net_forward(self): for i in range(3): d.append(numpy.random.random(size=[10]).astype('float32')) + if in_pir_mode(): + flag_1 = "FLAGS_enable_pir_in_executor_trace_run" + flag_2 = "FLAGS_new_executor_serial_run" + os.environ[flag_1] = 'True' + os.environ[flag_2] = 'True' + set_flags({flag_1: True}) + set_flags({flag_2: True}) for _ in range(2): exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) + if in_pir_mode(): + del os.environ[flag_1] + del os.environ[flag_2] + set_flags({flag_1: False}) + set_flags({flag_2: False}) @compare_legacy_with_pt + @test_with_pir_api def test_exceptions(self): i = paddle.zeros(shape=[2], dtype='int64') array_len = paddle.tensor.fill_constant( @@ -134,6 +168,7 @@ def test_exceptions(self): class BadInputTest(unittest.TestCase): @compare_legacy_with_pt + @test_with_pir_api def test_error(self): with base.program_guard(base.Program()): @@ -158,8 +193,9 @@ def body_func(i, ten, batch_info, origin_seq): x = paddle.static.data(name='x', shape=[-1, 1, 4], dtype='float32') y = paddle.static.data(name='y', shape=[-1, 1, 1], dtype='float32') - x.desc.set_need_check_feed(False) - y.desc.set_need_check_feed(False) + if not in_pir_mode(): + x.desc.set_need_check_feed(False) + y.desc.set_need_check_feed(False) temp = paddle.concat([x, y], axis=-1) i = paddle.tensor.fill_constant(shape=[1], value=0, dtype='int32') @@ -190,6 +226,7 @@ def body_func(i, ten, batch_info, origin_seq): class TestOutputsMustExistsInputs(unittest.TestCase): @compare_legacy_with_pt + @test_with_pir_api def test_outputs_exists_inputs(self): """ We guarantee that the output tensor must be in the input tensor, so that the output and input can correspond to each other, but the input can be greater than the number of outputs. It's required in paddle2onnx. @@ -218,17 +255,20 @@ def body(i, s, x): paddle.enable_static() x = paddle.static.data(shape=[-1], name='x', dtype='float32') func(x) - for op in main_program.block(0).ops: - if op.type == "while": - for out_name in op.output("Out"): - if out_name in op.input("Condition"): - continue - self.assertTrue( - out_name in op.input("X"), - "In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition.".format( - out_name - ), - ) + + # NOTE(winter-wang): The while_op in pir mode doesn't need following constrait, so hre only check when in non-pir mode. + if not in_pir_mode(): + for op in main_program.block(0).ops: + if op.type == "while": + for out_name in op.output("Out"): + if out_name in op.input("Condition"): + continue + self.assertTrue( + out_name in op.input("X"), + "In while op, the variable in output(`Out`) must exists in inputs(`X`), but the variable with name `{}` not meet the precondition.".format( + out_name + ), + ) if __name__ == '__main__': From b989f8a16edfb0260e7aaaf21519768d022e7829 Mon Sep 17 00:00:00 2001 From: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com> Date: Thu, 28 Dec 2023 11:04:20 +0800 Subject: [PATCH 018/142] [compilation opt]change_cc_test (#60392) * change * update --- test/cpp/prim/CMakeLists.txt | 5 +---- test/legacy_test/CMakeLists.txt | 10 ++++++---- test/xpu/cpp/CMakeLists.txt | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/test/cpp/prim/CMakeLists.txt b/test/cpp/prim/CMakeLists.txt index 5be98e0a3b33dd..cb9e2cdeae8888 100644 --- a/test/cpp/prim/CMakeLists.txt +++ b/test/cpp/prim/CMakeLists.txt @@ -31,10 +31,7 @@ endif() # skip win32 since wget is not installed by default on windows machine. if(NOT WIN32) - cc_test( - test_vjp_pir - SRCS test_vjp.cc - DEPS op_dialect_vjp pir) + paddle_test(test_vjp_pir SRCS test_vjp.cc) endif() if(WITH_ONNXRUNTIME AND WIN32) # Copy onnxruntime for some c++ test in Windows, since the test will diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 45bd253a5aa596..824d50d8a6aaf7 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -857,10 +857,12 @@ if(WITH_HETERPS) endif() if(WIN32) - cc_test( - cc_imp_py_test - SRCS cc_imp_py_test.cc - DEPS python) + paddle_test(cc_imp_py_test SRCS cc_imp_py_test.cc) + if(WITH_ONNXRUNTIME) + # Copy onnxruntime for some c++ test in Windows, since the test will + # be build only in CI, so suppose the generator in Windows is Ninja. + copy_onnx(cc_imp_py_test) + endif() endif() set_tests_properties( diff --git a/test/xpu/cpp/CMakeLists.txt b/test/xpu/cpp/CMakeLists.txt index 7fd9278bfa7b48..8d1576446e9f34 100644 --- a/test/xpu/cpp/CMakeLists.txt +++ b/test/xpu/cpp/CMakeLists.txt @@ -1 +1 @@ -cc_test(enforce_xpu_test SRCS enforce_xpu_test.cc) +paddle_test(enforce_xpu_test SRCS enforce_xpu_test.cc) From 8d4fb21e532b096139057d8655b2b723e7b9568a Mon Sep 17 00:00:00 2001 From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Date: Thu, 28 Dec 2023 11:15:22 +0800 Subject: [PATCH 019/142] Use DimExpr and change InferSymbolicShapeInterface (#60371) * Use DimExpr and change InferSymbolicShapeInterface * static infer lib --- paddle/cinn/hlir/dialect/operator/ir/ops.yaml | 1 + paddle/fluid/inference/CMakeLists.txt | 5 +- .../op_generator/infer_symbolic_shape_gen.py | 6 +- .../fluid/pir/dialect/op_generator/op_gen.py | 4 +- .../interface/infer_symbolic_shape.cc | 273 ++++++++------ .../operator/interface/infer_symbolic_shape.h | 85 +++-- .../pir/dialect/operator/ir/op_dialect.cc | 31 ++ paddle/fluid/pir/dialect/operator/ir/ops.yaml | 3 + .../pir/transforms/shape_optimization_pass.cc | 333 ++++-------------- paddle/phi/api/yaml/ops.yaml | 3 + paddle/pir/dialect/shape/utils/shape_utils.h | 11 +- 11 files changed, 335 insertions(+), 420 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 22006e1ae4570b..2e423237828399 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -74,6 +74,7 @@ func : SliceRawInferMeta kernel : func : slice + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : uniform_random args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0) diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 1f353e2ba8409c..295e72c43ce8f2 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -64,8 +64,9 @@ set(KERNEL_LIST # shared inference library deps list(REMOVE_DUPLICATES fluid_modules) -#windows GPU static library over the limit, so not create_static_lib, and cc_library is dummy -if(WIN32 AND WITH_GPU) +# windows static library(both CPU and GPU)over the limit, so no longer create_static_lib, +# and cc_library is dummy +if(WIN32) cc_library(paddle_inference DEPS ${fluid_modules} ${STATIC_INFERENCE_API} ${utils_modules}) else() diff --git a/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py b/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py index d85ed967418d54..ff2094a3df0093 100644 --- a/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/infer_symbolic_shape_gen.py @@ -13,11 +13,9 @@ # limitations under the License. OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ -bool {op_name}::InferSymbolicShape(pir::Builder &builder, - const std::vector &operands, - std::vector &reified_return_shapes) {{ +bool {op_name}::InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis) {{ VLOG(4) << "Infer symbolic shape for op: {op_name}"; - return {op_name}InferSymbolicShape(builder, operands, reified_return_shapes); + return {op_name}InferSymbolicShape(this->operation(), shape_analysis); }} """ diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 4cb54ada152b82..d29982d22e5f77 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -133,9 +133,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ """ infer_symbolic_shape_template = """ - static bool InferSymbolicShape(pir::Builder &builder, - const std::vector &operands, - std::vector &reified_return_shapes); + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis* shape_analysis); """ # ===================================== diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 676e4b9d574b9b..1b9ca43b7d9f10 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -13,16 +13,15 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" namespace paddle::dialect { bool InferSymbolicShapeInterface::InferSymbolicShape( - pir::Builder &builder, - const std::vector &operands, - std::vector &reified_return_shapes) { - return impl_->infer_symbolic_shapes( - operation(), builder, operands, reified_return_shapes); + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return impl_->infer_symbolic_shapes(operation(), shape_analysis); } } // namespace paddle::dialect @@ -30,124 +29,176 @@ namespace paddle::dialect { namespace { -bool DeriveShapeFromOperand(pir::Builder *builder, - pir::Value operand, - std::vector *reified_return_shapes) { - auto shaped_type = operand.type().dyn_cast(); - if (!shaped_type) return false; - reified_return_shapes->assign( - {builder->Build(operand).result(0)}); +bool InferSymbolicShapeAllEqualUnary( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + shape_analysis->value_id_to_shapeordata_[res_id] = + shape_analysis->value_id_to_shapeordata_[operand_source_id]; return true; } -// Returns a new scalar integer value having type `type`. -// Here `type` must be an integer or index type. -pir::Value MaybeCastTo(pir::Builder &builder, // NOLINT - pir::Value value, - pir::Type type) { - if (type == value.type()) return value; - // if (!type.IsIndex() && !value.type().IsIndex()) { - // Value casted = - // builder.Build(builder.index_type(), value) - // .result(0); - // return builder.Build(type, casted).result(0); - // } - // return builder.Build(type, value).result(0); +bool InferSymbolicShapeAllEqualBinary( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + shape_analysis->value_id_to_shapeordata_[res_id] = + shape_analysis->value_id_to_shapeordata_[operand_source_id]; + return true; } + } // namespace -bool AbsOpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - return DeriveShapeFromOperand( - &builder, operands.front().source(), &reified_return_shapes); -} - -bool Abs_OpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - return DeriveShapeFromOperand( - &builder, operands.front().source(), &reified_return_shapes); -} - -bool TransposeOpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - // auto operand_type = operands[0].type().dyn_cast(); - // // Currently not support unranked type. - // if (!operand_type) return false; - // std::vector permutation = this->permutation(); - // std::vector shape_values(permutation.size()); - // Type shape_scalar_type = builder.index_type(); - // auto to_shape_scalar_type = [&](Value v) { - // return MaybeCastTo(builder, v, shape_scalar_type); - // }; - // auto shaped_type = operand_type.dyn_cast(); - // auto shape_vector = shaped_type.GetDyShape(); - // for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; - // element != shape_vector.end(); - // ++idx, ++element) { - // auto it = std::find(permutation.begin(), permutation.end(), idx); - // // TODO(zhangbopd): Need BuildOrFold - // Value value_dim = to_shape_scalar_type( - // builder.Build(operands[0].source(), - // idx).result(0)); - // shape_values[std::distance(permutation.begin(), it)] = value_dim; - // } - // Value output_shape = - // builder.Build(shape_values).result(0); - // reified_return_shapes.push_back(output_shape); +bool AbsOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool Abs_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool CastOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool Cast_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool ExpOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} +bool Exp_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualUnary(op, shape_analysis); +} + +bool SubtractOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualBinary(op, shape_analysis); +} + +bool Subtract_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return InferSymbolicShapeAllEqualBinary(op, shape_analysis); +} + +bool ShapeOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } -bool ConcatOpInferSymbolicShape( - pir::Builder &builder, // NOLINT - const std::vector &operands, - std::vector &reified_return_shapes) { // NOLINT - // std::vector inputs = {x()}; - // auto operand_type = inputs[0].type().dyn_cast(); - // // Currently not support unranked type. - // if (!operand_type) return false; - // Type shapeScalarType = builder.index_type(); - // auto to_shape_scalar_type = [&](Value v) { - // return MaybeCastTo(builder, v, shapeScalarType); - // }; - // std::vector> all_shape_values; - // for (size_t inputId = 0; inputId < inputs.size(); ++inputId) { - // Value operand = inputs[inputId]; - // auto operand_type = operand.type().dyn_cast(); - // if (!operand_type) return false; - // std::vector shape_values; - // auto shaped_type = operand_type.dyn_cast(); - // auto shape_vector = shaped_type.GetDyShape(); - // for (auto [idx, element] = std::tuple{0, shape_vector.begin()}; - // element != shape_vector.end(); - // ++idx, ++element) { - // Value value_dim = to_shape_scalar_type( - // builder.Build(operand, idx).result(0)); - // shape_values.push_back(value_dim); - // } - // all_shape_values.emplace_back(std::move(shape_values)); - // } - // [[maybe_unused]] int axis = this->dimension(); - // auto &shape_values = all_shape_values[0]; - // for (size_t vecId = 1; vecId < all_shape_values.size(); ++vecId) { - // auto &otherShapeValues = all_shape_values[vecId]; - // if (otherShapeValues.size() != shape_values.size()) return false; - // TODO(zhangbopd): AddIOp - // shape_values[axis] = - // builder.Build(shape_values[axis], - // otherShapeValues[axis]); - // } - // Value output_shape = - // builder.Build(shape_values).result(0); - // reified_return_shapes.push_back(output_shape); +bool ShapeSrOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ShapeOpInferSymbolicShape(op, shape_analysis); +} + +bool StackOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + symbol::ShapeOrDataDimExprs shape_data; + shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id]; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + return true; +} + +bool ReshapeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source_1 = op->operand_source(1); + std::string operand_source_1_id = pir::GetValueId(&operand_source_1); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + symbol::ShapeOrDataDimExprs shape_data; + + shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } +bool Reshape_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ReshapeOpInferSymbolicShape(op, shape_analysis); +} + } // namespace paddle::dialect +namespace cinn::dialect { + +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + // pir::AttributeMap attributes = op->attributes(); + + // auto attr_starts = + // attributes["starts"].dyn_cast().AsVector(); + // auto start = attr_starts[0].dyn_cast().data(); + + // auto attr_ends = + // attributes["ends"].dyn_cast().AsVector(); + // auto end = attr_ends[0].dyn_cast().data(); + + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + return true; +} + +} // namespace cinn::dialect + IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index 46ccf56183b2ac..b1c72e3111df23 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/pir/core/op_base.h" +#include "paddle/pir/dialect/shape/utils/shape_utils.h" // Type inference is currently modelled executionally for operation creation // using the `InferMetaInterface`. While `InferSymbolicShapeInterface` is used @@ -31,54 +32,82 @@ class InferSymbolicShapeInterface /// Defined these methods with the interface. struct Concept { explicit Concept(bool (*infer_symbolic_shapes)( - pir::Operation* op, - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes)) // NOLINT + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis)) : infer_symbolic_shapes(infer_symbolic_shapes) {} bool (*infer_symbolic_shapes)( - pir::Operation* op, - pir::Builder& builder, - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); }; template struct Model : public Concept { static inline bool InferSymbolicShape( - pir::Operation* op, - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes) { // NOLINT - return op->dyn_cast().InferSymbolicShape( - builder, operands, reified_return_shapes); + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return op->dyn_cast().InferSymbolicShape(shape_analysis); } Model() : Concept(InferSymbolicShape) {} }; /// Constructor - InferSymbolicShapeInterface(pir::Operation* op, Concept* impl) + InferSymbolicShapeInterface(pir::Operation *op, Concept *impl) : pir::OpInterfaceBase(op), impl_(impl) {} - bool InferSymbolicShape( - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); private: - Concept* impl_; + Concept *impl_; }; -bool AbsOpInferSymbolicShape( - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT -bool Abs_OpInferSymbolicShape( - pir::Builder& builder, // NOLINT - const std::vector& operands, - std::vector& reified_return_shapes); // NOLINT +} // namespace paddle::dialect + +namespace paddle::dialect { + +bool AbsOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Abs_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool CastOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Cast_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ExpOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Exp_OpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool SubtractOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Subtract_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ShapeOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ShapeSrOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool StackOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool ReshapeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + +bool Reshape_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); } // namespace paddle::dialect +namespace cinn::dialect { + +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + +} + IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::InferSymbolicShapeInterface) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 7b5959a542e7af..6e2e105d9c18a0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -29,6 +29,32 @@ namespace paddle { namespace dialect { +struct CombineOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + symbol::ShapeOrDataDimExprs value_shape; + + // for (auto operand_source : op->operands_source()) { + // std::string operand_source_id = pir::GetValueId(&operand_source); + // auto source_shape_vec = + // shape_analysis->value_id_to_shapeordata_[operand_source_id]; + // for (int i = 0; i < source_shape_vec.size(); i++) { + // value_shape.second.emplace_back(source_shape_vec[i]); + // } + // } + + auto res = op->result(0); + auto res_id = pir::GetValueId(&res); + + shape_analysis->value_id_to_shapeordata_[res_id] = value_shape; + return true; + } + + CombineOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + OperatorDialect::OperatorDialect(pir::IrContext *ctx) : pir::Dialect(name(), ctx, pir::TypeId::get()) { initialize(); @@ -36,6 +62,11 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx) auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name()); info.AttachInterface(std::move( pir::InterfaceValue::Get())); + + info = ctx->GetRegisteredOpInfo(pir::CombineOp::name()); + info.AttachInterface(std::move( + pir::InterfaceValue::Get())); } void OperatorDialect::initialize() { diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 0d571f8ef868a7..ec68a17c9cb13b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -265,6 +265,7 @@ data_type : x inplace: (x -> out) backward : cast_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : channel_shuffle args : (Tensor x, int groups, str data_format="NCHW") @@ -1044,6 +1045,7 @@ view: (x -> out) intermediate : xshape backward: reshape_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : rnn args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false) @@ -1214,6 +1216,7 @@ func : subtract inplace : (x -> out) backward : subtract_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : sum args : (Tensor x, IntArray axis={}, DataType dtype=DataType::UNDEFINED, bool keepdim=false) diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index a7d32c6577906b..5c6481110034e2 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -111,8 +111,8 @@ class InferSymbolicShapePass : public pir::Pass { if (it != infer_sym_shape_map.end()) { it->second(op, shape_analysis_); } else { - VLOG(3) << "[" << op.name() - << "] is not supported for infer_symbolic_shape pass."; + LOG(WARNING) << "[" << op.name() + << "] is not supported for infer_symbolic_shape pass."; } } @@ -206,7 +206,7 @@ struct ExpandShapeOfOpPattern : public OpRewritePattern { bool MatchAndRewrite(shape::ShapeOfOp op, PatternRewriter& rewriter) const override { - VLOG(5) << "Apply ExpandShapeOfOpPattern..."; + VLOG(3) << "Apply ExpandShapeOfOpPattern..."; auto type = op.out().type().dyn_cast(); @@ -233,44 +233,6 @@ struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; bool MatchAndRewrite(OpTy dim_op, PatternRewriter& rewriter) const override { - OpResult dim_value = dim_op.source().template dyn_cast(); - if (!dim_value) return false; - - auto shaped_type_op = - dim_value.owner() - ->dyn_cast(); - if (!shaped_type_op) return false; - - std::optional dim_index = dim_op.GetConstantIndex(); - if (!dim_index) return false; - - std::vector reified_result_shapes; - if (!shaped_type_op.InferSymbolicShape( - rewriter, shaped_type_op->operands(), reified_result_shapes)) - return false; - - if (reified_result_shapes.size() != shaped_type_op->num_results()) - return false; - - Value result_shape = reified_result_shapes[dim_value.index()]; - auto result_shape_type = result_shape.type().dyn_cast(); - auto shaped_type = result_shape_type.dyn_cast(); - if (!result_shape_type || !shaped_type.GetElementType().IsIntOrIndex()) - return false; - - // TODO(zhangbopd): BuildOrFold required. - std::vector indices; - indices.push_back(rewriter.Build(*dim_index).out()); - - Value new_value = - rewriter.Build(result_shape, indices).out(); - - if (!new_value.type().isa()) - new_value = - rewriter.Build(rewriter.index_type(), new_value) - .out(); - - rewriter.ReplaceOp(dim_op, {new_value}); return true; } }; @@ -349,19 +311,6 @@ bool ShapeComputationIRAnalysis::Run() { // Make sure only run once. if (initialized_) return false; initialized_ = true; - // auto build_shape_func = - // std::bind(&ShapeComputationIRAnalysis::BuildShapeOnOperation, - // this, - // std::placeholders::_1); - // if (!RunOnRegion(&(m_->region(0)), build_shape_func)) return false; - // auto apply_op_constraint_func = - // std::bind(&ShapeComputationIRAnalysis::ApplyOpConstraint, - // this, - // std::placeholders::_1); - // // TODO(zhangbopd): Delete the following 1 line and fix UT - // // `shape_optimization_test` - // return true; - // if (!RunOnRegion(&(m_->region(0)), apply_op_constraint_func)) return false; return true; } @@ -508,220 +457,81 @@ bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { return true; } -void print_program(pir::ModuleOp m, std::string mgs) { +void PrintProgram(pir::ModuleOp m, std::string mgs) { std::ostringstream print_stream; print_stream << "\n\n"; m.program()->Print(print_stream); print_stream << "\n\n"; - VLOG(5) << "===================== " << mgs << "\n" << print_stream.str(); -} - -bool IsShapeSpecialOp(const pir::Operation& op) { - auto name = op.name(); - if (name == "pd_op.shape" || name == "cinn_op.slice") { - return true; - } - - return false; -} - -bool IsAllEqualUnaryOp(const pir::Operation& op) { - auto name = op.name(); - if (name == "pd_op.exp" || name == "pd_op.cast") { - return true; - } - - return false; -} - -void InferSymbolicShapeAllEqualUnary( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - shape_analysis->value_to_valueshape_expr_[rst_id] = - shape_analysis->value_to_valueshape_expr_[operand_source_id]; -} - -bool IsAllEqualBinaryOp(const pir::Operation& op) { - auto name = op.name(); - if (name == "pd_op.subtract") { - return true; - } - - return false; -} - -void InferSymbolicShapeAllEqualBinary( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - shape_analysis->value_to_valueshape_expr_[rst_id] = - shape_analysis->value_to_valueshape_expr_[operand_source_id]; -} - -void InferSymbolicShapePdShape(pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - std::pair, std::vector> value_shape; - - auto type = rst.type(); - auto tensor_type = type.dyn_cast(); - auto ddim_vec = common::vectorize(tensor_type.dims()); - for (auto dim : ddim_vec) { - std::string sym_name = ""; - if (dim == -1) { - sym_name = shape_analysis->GetNextSymName(); - } else { - sym_name = std::to_string(dim); - } - value_shape.first.emplace_back(sym_name); - } - - value_shape.second = - shape_analysis->value_to_valueshape_expr_[operand_source_id].first; - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void InferSymbolicShapeCinnSlice( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - std::pair, std::vector> value_shape; - - auto type = rst.type(); - auto tensor_type = type.dyn_cast(); - auto ddim_vec = common::vectorize(tensor_type.dims()); - for (auto dim : ddim_vec) { - std::string sym_name = ""; - if (dim == -1) { - sym_name = shape_analysis->GetNextSymName(); - } else { - sym_name = std::to_string(dim); - } - value_shape.first.emplace_back(sym_name); - } - - auto attributes = op->attributes(); - - auto attr_starts = attributes["starts"].dyn_cast().AsVector(); - auto start = attr_starts[0].dyn_cast().data(); - - auto attr_ends = attributes["ends"].dyn_cast().AsVector(); - auto end = attr_ends[0].dyn_cast().data(); - - auto source_shape_vec = - shape_analysis->value_to_valueshape_expr_[operand_source_id].second; - for (int i = start; i < end; i++) { - value_shape.second.emplace_back(source_shape_vec[i]); - } - - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; + VLOG(3) << "===================== " << mgs << " =====================\n" + << print_stream.str(); } -void InferSymbolicShapeBuiltinCombine( - pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - std::pair, std::vector> value_shape; - for (auto operand_source : op->operands_source()) { - auto operand_source_id = pir::GetValueId(&operand_source); - auto source_shape_vec = - shape_analysis->value_to_valueshape_expr_[operand_source_id].second; - for (int i = 0; i < source_shape_vec.size(); i++) { - value_shape.second.emplace_back(source_shape_vec[i]); - } - } - - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void InferSymbolicShapeStack(pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source = op->operand_source(0); - auto operand_source_id = pir::GetValueId(&operand_source); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - std::pair, std::vector> value_shape; - - value_shape.second = - shape_analysis->value_to_valueshape_expr_[operand_source_id].second; - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void InferSymbolicShapeReshape(pir::Operation* op, - pir::ShapeConstraintIRAnalysis* shape_analysis) { - auto operand_source_1 = op->operand_source(1); - auto operand_source_1_id = pir::GetValueId(&operand_source_1); - auto rst = op->result(0); - auto rst_id = pir::GetValueId(&rst); - - std::pair, std::vector> value_shape; - - value_shape.first = - shape_analysis->value_to_valueshape_expr_[operand_source_1_id].second; - shape_analysis->value_to_valueshape_expr_[rst_id] = value_shape; -} - -void debug_print_op_info( +void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - VLOG(5) << op->name() << ", num_operands: " << op->num_operands(); - for (auto& rst : op->results()) { - auto type = rst.type(); - auto value_id = pir::GetValueId(&rst); + VLOG(0) << op->name() << ", num_operands: " << op->num_operands(); + for (auto& res : op->results()) { + auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; - print_stream << ">>>> result(" << rst.index() << ") 's ID: " << value_id; - if (shape_analysis != nullptr) { - auto value_shape = shape_analysis->value_to_valueshape_expr_[value_id]; - print_stream << ", value_shape.first: ["; - for (auto str : value_shape.first) { - print_stream << str << ", "; + print_stream << ">>>> result(" << res.index() << ") 's ID: " << value_id; + if (shape_analysis != nullptr) { + auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id]; + print_stream << ", ShapeOrData.shape: ["; + + for (auto str : shape_data.shape()) { + int64_t* i = std::get_if(&str); + std::string* s = std::get_if(&str); + if (i) { + print_stream << *i << ", "; + } else if (s) { + print_stream << *s << ", "; + } } - print_stream << "], second: ["; - for (auto str : value_shape.second) { - print_stream << str << ", "; + + print_stream << "], ShapeOrData.data: ["; + if (shape_data.data().has_value()) { + for (auto str : shape_data.data().value()) { + int64_t* i = std::get_if(&str); + std::string* s = std::get_if(&str); + if (i) { + print_stream << *i << ", "; + } else if (s) { + print_stream << *s << ", "; + } + } } print_stream << "]\n"; } - VLOG(5) << print_stream.str(); + VLOG(0) << print_stream.str(); } } -void InferSymExprForAllValues(pir::ModuleOp module_op) { - auto shape_analysis_mgr = pir::ShapeAnalysisManager::Instance(); - pir::ShapeConstraintIRAnalysis& shape_analysis = +void InferSymExprForAllValues(ModuleOp module_op) { + auto shape_analysis_mgr = ShapeAnalysisManager::Instance(); + ShapeConstraintIRAnalysis& shape_analysis = shape_analysis_mgr.Get(module_op.program()); for (int i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { if (op.num_operands() == 0) { - // Need new syms for -1s - for (auto& rst : op.results()) { - auto value_id = pir::GetValueId(&rst); - std::pair, std::vector> - value_shape; - auto type = rst.type(); - auto tensor_type = type.dyn_cast(); - auto ddim_vec = common::vectorize(tensor_type.dims()); - for (auto dim : ddim_vec) { - std::string sym_name = ""; + for (auto& res : op.results()) { + auto value_id = pir::GetValueId(&res); + + std::vector dims = common::vectorize( + res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; if (dim == -1) { - sym_name = shape_analysis.GetNextSymName(); + symbol::DimExpr res_dim_expr(shape_analysis.GetNextSymName()); + dim_expr = res_dim_expr; } else { - sym_name = std::to_string(dim); + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; } - value_shape.first.emplace_back(sym_name); + shapes.push_back(dim_expr); } if (op.name() == "pd_op.full_int_array") { @@ -730,28 +540,23 @@ void InferSymExprForAllValues(pir::ModuleOp module_op) { auto arr = attr.dyn_cast(); const auto& vec = arr.AsVector(); for (auto item : vec) { - auto i = item.dyn_cast(); - value_shape.second.emplace_back(std::to_string(i.data())); + int64_t i = item.dyn_cast().data(); + shapes.push_back(symbol::DimExpr(i)); } } - shape_analysis.value_to_valueshape_expr_[value_id] = value_shape; + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; + } + } else { + auto infer_symbolic_shape_interface = + op.dyn_cast(); + if (infer_symbolic_shape_interface) { + PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( + &shape_analysis)); } - } else if (IsAllEqualUnaryOp(op)) { - InferSymbolicShapeAllEqualUnary(&op, &shape_analysis); - } else if (IsAllEqualBinaryOp(op)) { - InferSymbolicShapeAllEqualBinary(&op, &shape_analysis); - } else if (op.name() == "pd_op.shape") { - InferSymbolicShapePdShape(&op, &shape_analysis); - } else if (op.name() == "cinn_op.slice") { - InferSymbolicShapeCinnSlice(&op, &shape_analysis); - } else if (op.name() == "builtin.combine") { - InferSymbolicShapeBuiltinCombine(&op, &shape_analysis); - } else if (op.name() == "pd_op.stack") { - InferSymbolicShapeStack(&op, &shape_analysis); - } else if (op.name() == "pd_op.reshape") { - InferSymbolicShapeReshape(&op, &shape_analysis); } - debug_print_op_info(&op, &shape_analysis); + + DebugPrintOpInfo(&op, &shape_analysis); } } } @@ -762,11 +567,11 @@ class ShapeOptimizationPass : public pir::Pass { ShapeOptimizationPass() : pir::Pass("shape_optimization_pass", 0) {} void Run(pir::Operation* op) override { - VLOG(5) << "===================== ShapeOptimizationPass Run start... " + VLOG(3) << "===================== ShapeOptimizationPass Run start... " "============================="; auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); - print_program(module_op, "Origin Program:"); + PrintProgram(module_op, "Origin Program"); InferSymExprForAllValues(module_op); MaterializeShapeComputation(module_op); @@ -777,7 +582,7 @@ class ShapeOptimizationPass : public pir::Pass { // if (!OptimizeShapeComputation(module_op, runner)) { // return; // } - VLOG(5) << "===================== ShapeOptimizationPass Run End. " + VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; } diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index de7c49250ea16e..de4d700cdf80ee 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -841,6 +841,7 @@ func : exp inplace : (x -> out) backward : exp_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : expand args : (Tensor x, IntArray shape = {}) @@ -2355,6 +2356,7 @@ shape_sr {selected_rows -> dense} data_transform: skip_transform : input + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : shard_index args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value=-1) @@ -2538,6 +2540,7 @@ kernel : func : stack backward : stack_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : stanh args : (Tensor x, float scale_a=0.67f, float scale_b=1.7159f) diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 717b05eb8fede6..ac72c0bae88c78 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -76,11 +76,6 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { Value rhs, std::vector rhs_dim_idxs) override; - std::unordered_map< - std::string, - std::pair, std::vector>> - value_to_valueshape_expr_; - inline const std::string GetNextSymName() { return "S" + std::to_string(next_sym_idx_++); } @@ -89,6 +84,9 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { symbol::DimExprBuilder CreateDimExprBuilder() override; + std::unordered_map + value_id_to_shapeordata_; + private: // The operation this analysis runs on. ModuleOp m_; @@ -99,9 +97,6 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { std::unordered_map> value_to_sym_dims_; - std::unordered_map - value_id_to_shapeordata; - int64_t next_sym_idx_ = 0; std::vector constraints_; From 4551f5f8de03a2886eb8d9b3d68df4d855f7bf6e Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Thu, 28 Dec 2023 11:34:32 +0800 Subject: [PATCH 020/142] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.39?= =?UTF-8?q?=E3=80=91=20fix=20test=5Fc=5Freduce=5Fmin=5Ftranslate=20(#60236?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add test_c_reduce_min_translate * fix * fix * fix * fix * fix * fix * fix --- .../pir/dialect/op_generator/ops_api_gen.py | 2 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 10 ++++ .../fluid/pir/dialect/operator/utils/utils.cc | 3 +- paddle/phi/api/yaml/op_compat.yaml | 6 +++ test/ir/pir/CMakeLists.txt | 2 + test/ir/pir/translator/CMakeLists.txt | 15 ++++++ .../translator/test_c_reduce_min_translate.py | 42 ++++++++++++++++ test/ir/pir/translator/test_op_transcriber.py | 48 +++++++++++++++++++ 8 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 test/ir/pir/translator/CMakeLists.txt create mode 100644 test/ir/pir/translator/test_c_reduce_min_translate.py create mode 100644 test/ir/pir/translator/test_op_transcriber.py diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 0a834bc7b0c2cf..d541f34a890dc2 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -136,6 +136,8 @@ 'sparse_momentum', 'soft_relu', 'uniform_random_batch_size_like', + 'c_reduce_min', + 'c_reduce_min_', ] diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index ec68a17c9cb13b..5bdcadc3cca03f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -213,6 +213,16 @@ func : c_identity inplace : (x -> out) +- op : c_reduce_min + args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param : [x] + kernel : + func : c_reduce_min + inplace : (x -> out) + - op : c_reduce_sum args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) output : Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 722685fc3b5105..ebc1615a16d51a 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -58,7 +58,8 @@ const std::unordered_set LegacyOpList = { RowConvOp::name(), RowConvGradOp::name(), SoftReluOp::name(), - SoftReluGradOp::name()}; + SoftReluGradOp::name(), + CReduceMinOp::name()}; const std::unordered_set OneDNNLegacyOpList = {}; enum class AttrType { diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index d69e290bdbd144..e605dab1543371 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3376,6 +3376,12 @@ outputs : out: Out +- op: c_reduce_min + inputs : + x : X + outputs : + out: Out + - op: c_reduce_sum inputs : x : X diff --git a/test/ir/pir/CMakeLists.txt b/test/ir/pir/CMakeLists.txt index 61d69ee4816f32..0b8d91aed17618 100644 --- a/test/ir/pir/CMakeLists.txt +++ b/test/ir/pir/CMakeLists.txt @@ -39,3 +39,5 @@ py_test_modules( FLAGS_pir_subgraph_saving_dir=${CMAKE_CURRENT_SOURCE_DIR}) add_subdirectory(fused_pass) + +add_subdirectory(translator) diff --git a/test/ir/pir/translator/CMakeLists.txt b/test/ir/pir/translator/CMakeLists.txt new file mode 100644 index 00000000000000..108615b0c204e5 --- /dev/null +++ b/test/ir/pir/translator/CMakeLists.txt @@ -0,0 +1,15 @@ +file( + GLOB TEST_INTERP_CASES + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") + +set(DISTRIBUTED_OP_TRANSLATION_TEST test_c_reduce_min_translate) + +if(NOT WITH_DISTRIBUTE) + list(REMOVE_ITEM TEST_INTERP_CASES ${DISTRIBUTED_OP_TRANSLATION_TEST}) +endif() + +foreach(target ${TEST_INTERP_CASES}) + py_test_modules(${target} MODULES ${target}) +endforeach() diff --git a/test/ir/pir/translator/test_c_reduce_min_translate.py b/test/ir/pir/translator/test_c_reduce_min_translate.py new file mode 100644 index 00000000000000..63c4e8271c2e15 --- /dev/null +++ b/test/ir/pir/translator/test_c_reduce_min_translate.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import test_op_transcriber + +import paddle +from paddle.base.layer_helper import LayerHelper + + +class TestCReduceMinOpTranscriber(test_op_transcriber.TestOpTranscriber): + def append_op(self): + self.op_type = "c_reduce_min" + x = paddle.ones(shape=(100, 2, 3), dtype='float32') + y = paddle.ones(shape=(100, 2, 3), dtype='float32') + attrs = {'ring_id': 0, 'root_id': 0, 'use_calc_stream': False} + helper = LayerHelper(self.op_type) + helper.append_op( + type=self.op_type, + inputs={"X": x}, + outputs={"Out": y}, + attrs=attrs, + ) + + def test_translator(self): + self.check() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/pir/translator/test_op_transcriber.py b/test/ir/pir/translator/test_op_transcriber.py new file mode 100644 index 00000000000000..dfb8fa63a18705 --- /dev/null +++ b/test/ir/pir/translator/test_op_transcriber.py @@ -0,0 +1,48 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +from paddle import pir +from paddle.base import core + +paddle.enable_static() + + +class TestOpTranscriber(unittest.TestCase): + def setUp(self): + self.place = core.Place() + self.place.set_place(paddle.CPUPlace()) + self.new_scope = paddle.static.Scope() + self.main_program = paddle.static.Program() + + def append_op(self): + raise Exception("Define the op to be tested here!") + + def build_model(self): + with paddle.static.scope_guard(self.new_scope): + with paddle.static.program_guard(self.main_program): + self.append_op() + + def check(self): + self.build_model() + l = pir.translate_to_pir(self.main_program.desc) + assert hasattr(self, "op_type"), "Op_type should be specified!" + assert self.op_type in str(l), ( + self.op_type + + " should be translated to pd_op." + + self.op_type + + '!' + ) From 4c975499456ca37cbeafda232a94fbfb97daf854 Mon Sep 17 00:00:00 2001 From: RuohengMa <120699764+RuohengMa@users.noreply.github.com> Date: Thu, 28 Dec 2023 11:55:13 +0800 Subject: [PATCH 021/142] [PHI] add new supported datatype for tile and sigmoid_grad (#60119) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 3 ++- .../phi/kernels/xpu/activation_grad_kernel.cc | 8 +++++++- paddle/phi/kernels/xpu/tile_kernel.cc | 19 ------------------- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 253f0a8c1b87f3..31d16aaf5c0a38 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -835,7 +835,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BFLOAT16})}, {"sigmoid", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sigmoid_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"sign", XPUKernelSet({phi::DataType::FLOAT32})}, {"slice_grad", XPUKernelSet({phi::DataType::FLOAT32, diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc index 7cada9005c33eb..48ff73d2477203 100644 --- a/paddle/phi/kernels/xpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -734,6 +734,13 @@ PD_REGISTER_KERNEL(swish_grad, phi::dtype::float16, phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(sigmoid_grad, + XPU, + ALL_LAYOUT, + phi::SigmoidGradKernel, + float, + phi::dtype::float16) {} + PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) @@ -741,7 +748,6 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) diff --git a/paddle/phi/kernels/xpu/tile_kernel.cc b/paddle/phi/kernels/xpu/tile_kernel.cc index cce230c970bf97..d90232b6767e79 100644 --- a/paddle/phi/kernels/xpu/tile_kernel.cc +++ b/paddle/phi/kernels/xpu/tile_kernel.cc @@ -29,7 +29,6 @@ void TileKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& repeat_times_arr, DenseTensor* out) { - using XPUType = typename XPUTypeTrait::Type; auto rank = x.dims().size(); std::vector repeat_times = repeat_times_arr.GetData(); int repeat_times_size = repeat_times.size(); @@ -123,24 +122,6 @@ void TileKernel(const Context& dev_ctx, vec_in_dims, vec_out_dims); - } else if (std::is_same::value) { - float* x_t = RAII_GUARD.alloc_l3_or_gm(x.numel()); - float* y_t = RAII_GUARD.alloc_l3_or_gm(out->numel()); - int r = - xpu::cast(dev_ctx.x_context(), - reinterpret_cast(x.data()), - x_t, - x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); - ret = xpu::broadcast( - dev_ctx.x_context(), x_t, y_t, vec_in_dims, vec_out_dims); - PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast"); - r = xpu::cast(dev_ctx.x_context(), - y_t, - reinterpret_cast(out->data()), - out->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); - } else { ret = xpu::broadcast(dev_ctx.x_context(), x.data(), From cfa74f5a316117821a0e98f4845316e2d6083496 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 28 Dec 2023 12:57:07 +0800 Subject: [PATCH 022/142] Fix build bug for V100 (#60418) --- paddle/phi/kernels/funcs/weight_only_gemv.cu | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.cu b/paddle/phi/kernels/funcs/weight_only_gemv.cu index ff9285693b55fd..2a14c4c9fb9d85 100644 --- a/paddle/phi/kernels/funcs/weight_only_gemv.cu +++ b/paddle/phi/kernels/funcs/weight_only_gemv.cu @@ -649,6 +649,7 @@ struct WeightOnlyConverter { } }; +#ifdef PADDLE_CUDA_BF16 template <> struct WeightOnlyConverter<__nv_bfloat16, WeightOnlyQuantType::Int8b> { static __device__ inline void convert(__nv_bfloat16 halves[4], @@ -689,6 +690,7 @@ struct WeightOnlyConverter<__nv_bfloat16, WeightOnlyQuantType::Int8b> { #endif } }; +#endif template <> struct WeightOnlyConverter { @@ -766,6 +768,7 @@ struct WeightOnlyConverter { } }; +#ifdef PADDLE_CUDA_BF16 template <> struct WeightOnlyConverter<__nv_bfloat16, WeightOnlyQuantType::Int4b> { static __device__ inline void convert(__nv_bfloat16 halves[8], @@ -817,6 +820,7 @@ struct WeightOnlyConverter<__nv_bfloat16, WeightOnlyQuantType::Int4b> { #endif } }; +#endif template __device__ __forceinline__ void load(T0* dst, T1* src, size_t offset = 0) { @@ -1401,7 +1405,7 @@ template void WeightOnlyGemvWrapper(const phi::GPUContext& ctx, const std::string& weight_only_type, const std::string& act_method, phi::dtype::float16* output); - +#ifdef PADDLE_CUDA_BF16 template void WeightOnlyGemvWrapper(const phi::GPUContext& ctx, const phi::dtype::bfloat16* input, const int8_t* weight, @@ -1415,4 +1419,6 @@ template void WeightOnlyGemvWrapper(const phi::GPUContext& ctx, const std::string& weight_only_type, const std::string& act_method, phi::dtype::bfloat16* output); +#endif + } // namespace phi From db27fe4e38ef24af461516ea627bf66e33f6730a Mon Sep 17 00:00:00 2001 From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Date: Thu, 28 Dec 2023 13:58:25 +0800 Subject: [PATCH 023/142] elementwise_pow, square, sin and cos support bfloat16 for xpu (#60402) --- paddle/phi/backends/xpu/xpu3_op_list.cc | 18 +++-- .../kernels/legacy/xpu/elementwise_kernel.cc | 3 +- paddle/phi/kernels/xpu/activation_kernel.cc | 27 ++++++-- paddle/phi/kernels/xpu/elementwise_kernel.cc | 3 +- test/xpu/test_activation_op_xpu.py | 63 +++++++++--------- test/xpu/test_elementwise_pow_op_xpu.py | 65 ++++++++----------- 6 files changed, 96 insertions(+), 83 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 016e5ef917af57..20c649ee4ba978 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -296,7 +296,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT32, phi::DataType::INT64})}, {"elementwise_pow", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -891,7 +893,9 @@ XPUOpMap& get_kl3_ops() { {"square_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"square", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"squared_l2_norm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -1142,9 +1146,15 @@ XPUOpMap& get_kl3_ops() { phi::DataType::FLOAT32, phi::DataType::FLOAT16, phi::DataType::BFLOAT16})}, - {"sin", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sin", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"sin_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"cos", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"cos", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"cos_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"linspace", XPUKernelSet({phi::DataType::FLOAT32, diff --git a/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc b/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc index 2e4bf779d26cdd..96ad9bb1f56848 100644 --- a/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc @@ -153,4 +153,5 @@ PD_REGISTER_KERNEL(elementwise_pow_raw, ALL_LAYOUT, phi::ElementwisePowRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index e76fded263f7c6..449be30474193a 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -624,8 +624,13 @@ PD_REGISTER_KERNEL(sqrt, PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(square, + XPU, + ALL_LAYOUT, + phi::SquareKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {} @@ -633,10 +638,20 @@ PD_REGISTER_KERNEL( PD_REGISTER_KERNEL( relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - sin, XPU, ALL_LAYOUT, phi::SinKernel, float, phi::dtype::float16) {} -PD_REGISTER_KERNEL( - cos, XPU, ALL_LAYOUT, phi::CosKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(sin, + XPU, + ALL_LAYOUT, + phi::SinKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(cos, + XPU, + ALL_LAYOUT, + phi::CosKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) {} #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} diff --git a/paddle/phi/kernels/xpu/elementwise_kernel.cc b/paddle/phi/kernels/xpu/elementwise_kernel.cc index 83dce5437c9ecb..a4b1385393d69c 100644 --- a/paddle/phi/kernels/xpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_kernel.cc @@ -114,4 +114,5 @@ PD_REGISTER_KERNEL(elementwise_pow, ALL_LAYOUT, phi::ElementwisePowKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/test/xpu/test_activation_op_xpu.py b/test/xpu/test_activation_op_xpu.py index 9ea61f229822e2..3952217a301f21 100644 --- a/test/xpu/test_activation_op_xpu.py +++ b/test/xpu/test_activation_op_xpu.py @@ -521,6 +521,11 @@ def set_case(self): self.op_type = "square" self.dtype = self.in_type self.init_config() + if self.dtype == np.uint16: + # bfloat16 actually + self.x = convert_float_to_uint16(self.tmp_x) + else: + self.x = self.tmp_x.astype(self.dtype) out = np.square(self.x) self.attrs = {'use_xpu': True} @@ -528,27 +533,27 @@ def set_case(self): self.outputs = {'Out': out} def init_config(self): - self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + self.tmp_x = np.random.uniform(-1, 1, [11, 17]) class XPUTestSquare_ZeroDim(XPUTestSquare): def init_config(self): - self.x = np.random.uniform(-2, 2, []).astype(self.dtype) + self.tmp_x = np.random.uniform(-2, 2, []) class XPUTestSquare2(XPUTestSquare): def init_config(self): - self.x = np.random.uniform(-2, 2, [100]).astype(self.dtype) + self.tmp_x = np.random.uniform(-2, 2, [100]) class XPUTestSquare3(XPUTestSquare): def init_config(self): - self.x = np.random.uniform(-2, 2, [1, 15, 19]).astype(self.dtype) + self.tmp_x = np.random.uniform(-2, 2, [1, 15, 19]) class XPUTestSquare4(XPUTestSquare): def init_config(self): - self.x = np.random.uniform(-2, 2, [100, 10]).astype(self.dtype) + self.tmp_x = np.random.uniform(-2, 2, [100, 10]) class XPUTestSquare5(XPUTestSquare): def init_config(self): - self.x = np.random.uniform(-2, 2, [1, 2, 5, 17]).astype(self.dtype) + self.tmp_x = np.random.uniform(-2, 2, [1, 2, 5, 17]) support_types = get_xpu_op_support_types('square') @@ -1297,6 +1302,11 @@ def set_case(self): self.dtype = self.in_type self.init_config() + if self.dtype == np.uint16: + # bfloat16 actually + self.x = convert_float_to_uint16(self.tmp_x) + else: + self.x = self.tmp_x.astype(self.dtype) out = np.sin(self.x) self.inputs = {'X': self.x} @@ -1304,31 +1314,23 @@ def set_case(self): self.attrs = {'use_xpu': True} def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [11, 17]) class XPUTestSin_ZeroDim(XPUTestSinBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype) + self.tmp_x = np.random.uniform(-np.pi, np.pi, []) class XPUTestSin2(XPUTestSinBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [1024, 8]) class XPUTestSin3(XPUTestSinBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]) class XPUTestSin4(XPUTestSinBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]) support_types = get_xpu_op_support_types('sin') @@ -1347,6 +1349,11 @@ def set_case(self): self.dtype = self.in_type self.init_config() + if self.dtype == np.uint16: + # bfloat16 actually + self.x = convert_float_to_uint16(self.tmp_x) + else: + self.x = self.tmp_x.astype(self.dtype) out = np.cos(self.x) self.inputs = {'X': self.x} @@ -1354,31 +1361,23 @@ def set_case(self): self.attrs = {'use_xpu': True} def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [11, 17]) class XPUTestCos_ZeroDim(XPUTestCosBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype) + self.tmp_x = np.random.uniform(-np.pi, np.pi, []) class XPUTestCos2(XPUTestCosBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [1024, 8]) class XPUTestCos3(XPUTestCosBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]) class XPUTestCos4(XPUTestCosBase): def init_config(self): - self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype( - self.dtype - ) + self.tmp_x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]) support_types = get_xpu_op_support_types('cos') diff --git a/test/xpu/test_elementwise_pow_op_xpu.py b/test/xpu/test_elementwise_pow_op_xpu.py index ddcf64fb9d4051..a63e403ca50d5c 100644 --- a/test/xpu/test_elementwise_pow_op_xpu.py +++ b/test/xpu/test_elementwise_pow_op_xpu.py @@ -20,7 +20,7 @@ create_test_class, get_xpu_op_support_types, ) -from op_test import OpTest, skip_check_grad_ci +from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci from op_test_xpu import XPUOpTest import paddle @@ -40,14 +40,23 @@ def setUp(self): self.dtype = self.in_type self.__class__.no_need_check_grad = True self.compute_input_output() - - def compute_input_output(self): + if self.dtype == np.uint16: + # bfloat16 actually + self.x = convert_float_to_uint16(self.tmp_x) + self.y = convert_float_to_uint16(self.tmp_y) + else: + self.x = self.tmp_x.astype(self.dtype) + self.y = self.tmp_y.astype(self.dtype) self.inputs = { - 'X': np.random.uniform(1, 2, [20, 5]).astype(self.dtype), - 'Y': np.random.uniform(1, 2, [20, 5]).astype(self.dtype), + 'X': self.x, + 'Y': self.y, } self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + def compute_input_output(self): + self.tmp_x = np.random.uniform(1, 2, [20, 5]) + self.tmp_y = np.random.uniform(1, 2, [20, 5]) + def test_check_output(self): if paddle.is_compiled_with_xpu(): place = paddle.XPUPlace(0) @@ -55,58 +64,36 @@ def test_check_output(self): class TestElementwisePowOp_big_shape_1(TestElementwisePowOp): def compute_input_output(self): - self.inputs = { - 'X': np.random.uniform(1, 2, [10, 10]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype), - } - self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + self.tmp_x = np.random.uniform(1, 2, [10, 10]) + self.tmp_y = np.random.uniform(0.1, 1, [10, 10]) class TestElementwisePowOp_big_shape_2(TestElementwisePowOp): def compute_input_output(self): - self.inputs = { - 'X': np.random.uniform(1, 2, [10, 10]).astype(self.dtype), - 'Y': np.random.uniform(0.2, 2, [10, 10]).astype(self.dtype), - } - self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + self.tmp_x = np.random.uniform(1, 2, [10, 10]) + self.tmp_y = np.random.uniform(0.2, 2, [10, 10]) @skip_check_grad_ci( reason="[skip shape check] Use y_shape(1) to test broadcast." ) class TestElementwisePowOp_scalar(TestElementwisePowOp): def compute_input_output(self): - self.inputs = { - 'X': np.random.uniform(0.1, 1, [3, 3, 4]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [1]).astype(self.dtype), - } - self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + self.tmp_x = np.random.uniform(0.1, 1, [3, 3, 4]) + self.tmp_y = np.random.uniform(0.1, 1, [1]) class TestElementwisePowOp_tensor(TestElementwisePowOp): def compute_input_output(self): - self.inputs = { - 'X': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - 'Y': np.random.uniform(1, 3, [100]).astype(self.dtype), - } - self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + self.tmp_x = np.random.uniform(0.1, 1, [100]) + self.tmp_y = np.random.uniform(1, 3, [100]) class TestElementwisePowOp_broadcast_0(TestElementwisePowOp): def compute_input_output(self): - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 1, 100]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - } - self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + self.tmp_x = np.random.uniform(0.1, 1, [2, 1, 100]) + self.tmp_y = np.random.uniform(0.1, 1, [100]) class TestElementwisePowOp_broadcast_4(TestElementwisePowOp): def compute_input_output(self): - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 10, 3, 5]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [2, 10, 1, 5]).astype( - self.dtype - ), - } - self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + self.tmp_x = np.random.uniform(0.1, 1, [2, 10, 3, 5]) + self.tmp_y = np.random.uniform(0.1, 1, [2, 10, 1, 5]) class TestElementwisePowOpInt(OpTest): def setUp(self): From bae368752ab884300dfe5f55524b8df26ff26d3f Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Thu, 28 Dec 2023 14:10:30 +0800 Subject: [PATCH 024/142] [Dy2St] Replace all astor usage with `ast_to_source_code` (#60302) --- .../jit/dy2static/transformers/basic_api_transformer.py | 7 +++---- python/paddle/jit/dy2static/utils.py | 7 +++---- python/paddle/jit/dy2static/utils_helper.py | 5 ++--- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py index 1d9c865bf75b28..0902a3558b2b0d 100644 --- a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py @@ -13,11 +13,10 @@ # limitations under the License. -import astor - from paddle.utils import gast from .. import utils +from ..ast_utils import ast_to_source_code from .base import BaseTransformer __all__ = [] @@ -63,7 +62,7 @@ def visit_Expr(self, node): def _visit_Call(self, node): assert isinstance(node, gast.Call) - func_name = astor.to_source(gast.gast_to_ast(node.func)) + func_name = ast_to_source_code(node.func) if self._is_dygraph_forward(func_name): class_node = self._get_class_node(func_name) @@ -91,7 +90,7 @@ def _update_class_node_dict(self, node): return False utils.update_args_of_func(node_value, node_value, "__init__") - target_str = astor.to_source(gast.gast_to_ast(node.targets[0])) + target_str = ast_to_source_code(node.targets[0]) self.class_node_dict[target_str] = node_value return True # TODO: node.value is not dygraph class diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 8079a9a5271689..fc18ee5883e9ce 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -27,7 +27,6 @@ import warnings from importlib.machinery import SourceFileLoader -import astor import numpy as np import paddle @@ -320,7 +319,7 @@ def in_white_list(module, func_name): def _delete_keywords_from(node): assert isinstance(node, gast.Call) - func_src = astor.to_source(gast.gast_to_ast(node.func)) + func_src = ast_to_source_code(node.func) full_args = eval(f"inspect.getfullargspec({func_src})") full_args_name = full_args[0] @@ -398,7 +397,7 @@ def update_args_of_func(node, dygraph_node, method_name): "The method name of class to update args should be '__init__' or 'forward'" ) - class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func)) + class_src = ast_to_source_code(dygraph_node.func) if method_name == "__init__" or eval( f"issubclass({class_src}, paddle.nn.Layer)" @@ -454,7 +453,7 @@ def get_attribute_full_name(node): assert isinstance( node, gast.Attribute ), "Input non-Attribute node to get attribute full name" - return astor.to_source(gast.gast_to_ast(node)).strip() + return ast_to_source_code(node).strip() def generate_name_node(name_ids, ctx=gast.Load(), gen_tuple_if_single=False): diff --git a/python/paddle/jit/dy2static/utils_helper.py b/python/paddle/jit/dy2static/utils_helper.py index 5f9c8c506aca7c..4f1ae014507394 100644 --- a/python/paddle/jit/dy2static/utils_helper.py +++ b/python/paddle/jit/dy2static/utils_helper.py @@ -15,7 +15,6 @@ import inspect -import astor import numpy as np # noqa: F401 import paddle @@ -62,7 +61,7 @@ def is_api_in_module(node, module_prefix): while isinstance(func_node, gast.Call): func_node = func_node.func - func_str = astor.to_source(gast.gast_to_ast(func_node)).strip() + func_str = ast_to_source_code(func_node).strip() try: import paddle.jit.dy2static as _jst # noqa: F401 from paddle import to_tensor # noqa: F401 @@ -80,7 +79,7 @@ def _is_api_in_module_helper(obj, module_prefix): # Is numpy_api cannot reuse is_api_in_module because of numpy module problem def is_numpy_api(node): assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api" - func_str = astor.to_source(gast.gast_to_ast(node.func)) + func_str = ast_to_source_code(node.func) try: module_result = eval( "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy") From 54ee802ee4757a7681a54b943960d50332bf741a Mon Sep 17 00:00:00 2001 From: tianshuo78520a <707759223@qq.com> Date: Thu, 28 Dec 2023 14:18:23 +0800 Subject: [PATCH 025/142] test test_dist_fuse_resunit_pass (#60393) --- tools/gpups_test.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 883604ef6685ed..91cc6627dd7e29 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -27,10 +27,11 @@ function collect_failed_tests() { done } -# disable test: test_dist_fuse_resunit_pass +# disable test: serial_list="^test_conv2d_op$|\ ^test_conv2d_transpose_op$|\ +^test_dist_fuse_resunit_pass$|\ ^test_dygraph_dataparallel_bf16$|\ ^test_dygraph_sharding_stage1_fp16$|\ ^test_dygraph_sharding_stage1_bf16$|\ From 8710cb794f4e149ff93a045f1c3fcbc04ead03ed Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:31:56 +0800 Subject: [PATCH 026/142] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.271?= =?UTF-8?q?=E3=80=91Migrate=20LogNormal=20to=20pir=20(#60318)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/distribution/distribution.py | 7 +++-- python/paddle/distribution/normal.py | 20 +++++++++++-- .../test_distribution_lognormal_static.py | 30 +++++++++++++++++-- 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/python/paddle/distribution/distribution.py b/python/paddle/distribution/distribution.py index 0bc1a70a4c8547..130a5c300a64d8 100644 --- a/python/paddle/distribution/distribution.py +++ b/python/paddle/distribution/distribution.py @@ -150,7 +150,7 @@ def _validate_args(self, *args): is_variable = False is_number = False for arg in args: - if isinstance(arg, Variable): + if isinstance(arg, (Variable, paddle.pir.Value)): is_variable = True else: is_number = True @@ -176,7 +176,10 @@ def _to_tensor(self, *args): tmp = 0.0 for arg in args: - if not isinstance(arg, (float, list, tuple, np.ndarray, Variable)): + if not isinstance( + arg, + (float, list, tuple, np.ndarray, Variable, paddle.pir.Value), + ): raise TypeError( "Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format( type(arg) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 53155c49287e65..aacf8ffa635a26 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -99,13 +99,29 @@ def __init__(self, loc, scale, name=None): check_type( loc, 'loc', - (int, float, np.ndarray, Variable, list, tuple), + ( + int, + float, + np.ndarray, + Variable, + paddle.pir.Value, + list, + tuple, + ), 'Normal', ) check_type( scale, 'scale', - (int, float, np.ndarray, Variable, list, tuple), + ( + int, + float, + np.ndarray, + Variable, + paddle.pir.Value, + list, + tuple, + ), 'Normal', ) diff --git a/test/distribution/test_distribution_lognormal_static.py b/test/distribution/test_distribution_lognormal_static.py index b2d61e6ddc68c4..ac4b4d428cfc90 100644 --- a/test/distribution/test_distribution_lognormal_static.py +++ b/test/distribution/test_distribution_lognormal_static.py @@ -33,9 +33,10 @@ ('one-dim', xrand((2,)), xrand((2,)), xrand((2,))), ('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3))), ], + test_pir=True, ) class TestLogNormal(unittest.TestCase): - def setUp(self): + def run_program(self): paddle.enable_static() startup_program = paddle.static.Program() main_program = paddle.static.Program() @@ -67,6 +68,13 @@ def setUp(self): self.log_prob, ] = executor.run(main_program, feed=self.feeds, fetch_list=fetch_list) + def setUp(self): + if self.test_pir: + with paddle.pir_utils.IrGuard(): + self.run_program() + else: + self.run_program() + def test_mean(self): np_mean = self.np_lognormal.mean self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype) @@ -122,9 +130,10 @@ def test_log_prob(self): @parameterize_cls( (TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand((4,)), xrand((4,), min=0, max=1))], + test_pir=True, ) class TestLogNormalSample(unittest.TestCase): - def setUp(self): + def run_program(self): paddle.enable_static() startup_program = paddle.static.Program() main_program = paddle.static.Program() @@ -150,6 +159,13 @@ def setUp(self): main_program, feed=self.feeds, fetch_list=fetch_list ) + def setUp(self): + if self.test_pir: + with paddle.pir_utils.IrGuard(): + self.run_program() + else: + self.run_program() + def test_sample(self): samples_mean = self.samples.mean(axis=0) samples_var = self.samples.var(axis=0) @@ -196,9 +212,10 @@ def _kstest(self, loc, scale, samples): xrand((2, 2)), ), ], + test_pir=True, ) class TestLogNormalKL(unittest.TestCase): - def setUp(self): + def run_program(self): paddle.enable_static() startup_program = paddle.static.Program() main_program = paddle.static.Program() @@ -236,6 +253,13 @@ def setUp(self): main_program, feed=self.feeds, fetch_list=fetch_list ) + def setUp(self): + if self.test_pir: + with paddle.pir_utils.IrGuard(): + self.run_program() + else: + self.run_program() + def test_kl_divergence(self): np.testing.assert_allclose( self.kl0, From a773f32ddaf10565dc33d5e481b25115a7b9b1ee Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 28 Dec 2023 14:34:09 +0800 Subject: [PATCH 027/142] [auto parallel] Lazy init with random control. (#60316) --- paddle/fluid/pybind/eager.cc | 15 ++++- .../paddle/distributed/auto_parallel/api.py | 38 ++++++++++-- .../distributed/auto_parallel/random.py | 40 +++++++++++- python/paddle/nn/initializer/Bilinear.py | 3 + python/paddle/nn/initializer/assign.py | 3 + python/paddle/nn/initializer/dirac.py | 3 + python/paddle/nn/initializer/initializer.py | 12 +++- python/paddle/nn/initializer/kaiming.py | 3 + python/paddle/nn/initializer/normal.py | 3 + python/paddle/nn/initializer/orthogonal.py | 3 + python/paddle/nn/initializer/uniform.py | 3 + python/paddle/nn/initializer/xavier.py | 17 +++-- test/auto_parallel/CMakeLists.txt | 4 ++ .../semi_auto_parallel_lazy_init.py | 62 +++++++++++++++++++ .../test_semi_auto_parallel_lazy_init.py | 44 +++++++++++++ 15 files changed, 236 insertions(+), 17 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_lazy_init.py create mode 100644 test/auto_parallel/test_semi_auto_parallel_lazy_init.py diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 99ceed6b2b3092..3cb3ccf964ec81 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -244,9 +244,18 @@ void InitDistTensorWithTensor(TensorObject* self, std::make_shared(tensor, process_mesh, placements)); VLOG(4) << "Same place, do ShareDataWith for DistTensor."; } else { - std::shared_ptr tensor = - std::static_pointer_cast( - src.copy_to(place, true).impl()); + std::shared_ptr tensor; + if (src.initialized()) { + tensor = std::static_pointer_cast( + src.copy_to(place, true).impl()); + } else { + // lazy init branch. The src tensor is on undefined place. + PADDLE_ENFORCE( + src.place().GetType() == phi::AllocationType::UNDEFINED, + phi::errors::InvalidArgument("Only undefined place is support for " + "uninitialized input tensor.")); + tensor = std::static_pointer_cast(src.impl()); + } self->tensor.set_impl( std::make_shared(tensor, process_mesh, placements)); VLOG(4) << "Different place, do TensorCopy for DistTensor."; diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index f8eb3f71f89b9f..d3f19baded5e6b 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -45,6 +45,7 @@ from paddle.framework import core from .placement_type import check_placements_equal, get_shard_spec +from .random import determinate_rng, rng_state # There are the auto parallel API of the unified version of dynamic and static mode. # Some APIs have the same name with the previous APIs implementation, which are @@ -171,19 +172,48 @@ def shard_tensor( # `paddle.to_tensor` supports both dynamic and static mode if stop_gradient is None: stop_gradient = getattr(data, "stop_gradient", True) - tensor = paddle.to_tensor( - data, dtype=dtype, place=place, stop_gradient=stop_gradient - ) + if isinstance(data, EagerParamBase) and not data._is_initialized(): + assert ( + data._init_func is not None + ), "Get an uninitialized param with an unregistered init_func." + tensor = data + else: + tensor = paddle.to_tensor( + data, dtype=dtype, place=place, stop_gradient=stop_gradient + ) if paddle.in_dynamic_mode(): # here the dist tensor is deep copy constructed if isinstance(data, EagerParamBase): - return EagerParamBase.from_tensor( + + def lazy_init_hook(param, origin_hook): + # lazy init hook with randomness controlling + def _init_func(var, block): + # get the unique rng name + rng_name = determinate_rng( + dist.get_rank(), + process_mesh=param.process_mesh, + placements=param.placements, + ) + # real call the init function + with rng_state(rng_name): + origin_hook(var, block) + + return _init_func + + dist_param = EagerParamBase.from_tensor( tensor, process_mesh=mesh, placements=placements, **tensor.__dict__, ) + if tensor._init_func is not None: + origin_init_func = tensor._init_func + dist_param.set_init_func( + lazy_init_hook(dist_param, origin_init_func) + ) + + return dist_param else: return paddle.Tensor( tensor, process_mesh=mesh, placements=placements, place=place diff --git a/python/paddle/distributed/auto_parallel/random.py b/python/paddle/distributed/auto_parallel/random.py index d79f94e166524f..4f27d3f7cc5edb 100644 --- a/python/paddle/distributed/auto_parallel/random.py +++ b/python/paddle/distributed/auto_parallel/random.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import logging import paddle @@ -22,6 +23,7 @@ _logger = get_logger(logging.INFO) _rng_name_to_seed = {} +_rng_name_to_states = {} _inited_rng_name_to_seed = {} _enable_random_control = False _basic_seed = 42 @@ -71,7 +73,16 @@ def parallel_manual_seed(seed, name=""): _basic_name = name -def determinate_rng(rank, dims_mapping, process_mesh): +def determinate_rng( + rank, dims_mapping=None, process_mesh=None, placements=None +): + assert process_mesh is not None, "Must provide process mesh" + assert ( + dims_mapping is not None or placements is not None + ), "Must provide one of dims mapping or placements." + assert not ( + dims_mapping is not None and placements is not None + ), "Cannot provide dims mapping and placements at same time." # TODO(JZ-LIANG) Support Mesh with any high rank # use a string to unique integer hashing algorithm for seed computation. # instead of using offsets to coodinate seed across devices. @@ -98,7 +109,9 @@ def determinate_rng(rank, dims_mapping, process_mesh): seed_ += _mesh_offset * (unique_id + 1) for i in range(len(process_mesh.shape)): - if i not in dims_mapping: + if (dims_mapping is not None and i not in dims_mapping) or ( + placements is not None and not placements[i].is_shard() + ): relative_idx = -1 else: relative_idx = _get_idx_in_axis( @@ -112,6 +125,7 @@ def determinate_rng(rank, dims_mapping, process_mesh): seed_ += _dim_offsets[i] * (relative_idx + 1) global _rng_name_to_seed + global _rng_name_to_states if sharding_expr in _rng_name_to_seed: assert _rng_name_to_seed[sharding_expr] == seed_ else: @@ -121,10 +135,30 @@ def determinate_rng(rank, dims_mapping, process_mesh): seed_, sharding_expr, _rng_name_to_seed ) _rng_name_to_seed[sharding_expr] = seed_ - + if paddle.in_dynamic_mode(): + # for dygraph, just init the seed when meeting a new seed + orig_rng_state = paddle.get_rng_state() + paddle.seed(seed_) + _rng_name_to_states[sharding_expr] = paddle.get_rng_state() + paddle.set_rng_state(orig_rng_state) return sharding_expr +@contextlib.contextmanager +def rng_state(name): + global _rng_name_to_states + assert ( + name in _rng_name_to_states + ), f"The rng state name {name} haven't been init. " + orig_rng_state = paddle.get_rng_state() + paddle.set_rng_state(_rng_name_to_states[name]) + try: + yield + finally: + _rng_name_to_states[name] = paddle.get_rng_state() + paddle.set_rng_state(orig_rng_state) + + def init_auto_parallel_rng(): if not is_enable_auto_rand_ctrl(): return diff --git a/python/paddle/nn/initializer/Bilinear.py b/python/paddle/nn/initializer/Bilinear.py index a12393e2e28722..cfb18dac02c2a8 100644 --- a/python/paddle/nn/initializer/Bilinear.py +++ b/python/paddle/nn/initializer/Bilinear.py @@ -89,6 +89,9 @@ def forward(self, var, block=None): Returns: The initialization op """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, Bilinear initializer not support lazy init for dist param." block = self._check_block(block) if not isinstance(var, (framework.Variable, pir.core.ParameterMeta)): diff --git a/python/paddle/nn/initializer/assign.py b/python/paddle/nn/initializer/assign.py index 62cbcf6179f9aa..9274ff5275df09 100644 --- a/python/paddle/nn/initializer/assign.py +++ b/python/paddle/nn/initializer/assign.py @@ -56,6 +56,9 @@ def forward(self, var, block=None): Returns: The initialization op """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, assign initializer not support lazy init for dist param." block = self._check_block(block) assert isinstance( diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index 8ec63f64bbc028..7da5cd15b54f7e 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -106,6 +106,9 @@ def __call__(self, var, block=None): Returns: The most critical OP(scatter) in this initializer, which contains 7~8 ops in total. """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, dirac initializer not support lazy init for dist param." block = self._check_block(block) assert isinstance(var, (framework.Variable, pir.core.ParameterMeta)) assert isinstance(block, (framework.Block, pir.Block)) diff --git a/python/paddle/nn/initializer/initializer.py b/python/paddle/nn/initializer/initializer.py index 6f37e95a79816e..7b3901613f9e3a 100644 --- a/python/paddle/nn/initializer/initializer.py +++ b/python/paddle/nn/initializer/initializer.py @@ -17,7 +17,11 @@ import numpy as np -from ...base.framework import default_main_program, in_dygraph_mode +from ...base.framework import ( + EagerParamBase, + default_main_program, + in_dygraph_mode, +) from .lazy_init import lazy_init_helper __all__ = [] @@ -86,7 +90,11 @@ def _compute_fans(self, var): Returns: tuple of two integers (fan_in, fan_out). """ - shape = var.shape + shape = ( + var._local_shape + if (isinstance(var, EagerParamBase) and var.is_dist()) + else var.shape + ) if not shape or len(shape) == 0: fan_in = fan_out = 1 elif len(shape) == 1: diff --git a/python/paddle/nn/initializer/kaiming.py b/python/paddle/nn/initializer/kaiming.py index 14e3d726c87368..39329acaf7da13 100644 --- a/python/paddle/nn/initializer/kaiming.py +++ b/python/paddle/nn/initializer/kaiming.py @@ -91,6 +91,9 @@ def forward(self, var, block=None): Returns: The initialization op. """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, kaiming initializer not support lazy init for dist param." block = self._check_block(block) assert isinstance( var, (framework.Variable, paddle.pir.core.ParameterMeta) diff --git a/python/paddle/nn/initializer/normal.py b/python/paddle/nn/initializer/normal.py index 3983f270e60a69..4ca0a0902246c6 100644 --- a/python/paddle/nn/initializer/normal.py +++ b/python/paddle/nn/initializer/normal.py @@ -56,6 +56,9 @@ def forward(self, var, block=None): Returns: The initialization op. """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, normal initializer not support lazy init for dist param." block = self._check_block(block) assert isinstance(block, (framework.Block, pir.Block)) diff --git a/python/paddle/nn/initializer/orthogonal.py b/python/paddle/nn/initializer/orthogonal.py index 0dc2bd2aede474..486a68bcd5d0fc 100644 --- a/python/paddle/nn/initializer/orthogonal.py +++ b/python/paddle/nn/initializer/orthogonal.py @@ -81,6 +81,9 @@ def __call__(self, var, block=None): Returns: The last initialization op, it contain 8 ops in orthogonal initializer. """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, orthogonal initializer not support lazy init for dist param." block = self._check_block(block) assert isinstance(var, (framework.Variable, pir.core.ParameterMeta)) assert isinstance(block, (framework.Block, pir.Block)) diff --git a/python/paddle/nn/initializer/uniform.py b/python/paddle/nn/initializer/uniform.py index 86ef5aedbf1af2..f30ef1b38402d6 100644 --- a/python/paddle/nn/initializer/uniform.py +++ b/python/paddle/nn/initializer/uniform.py @@ -73,6 +73,9 @@ def forward(self, var, block=None): Returns: The initialization op """ + assert not ( + isinstance(var, framework.EagerParamBase) and var.is_dist() + ), "Currently, uniform initializer not support lazy init for dist param." block = self._check_block(block) assert isinstance(block, (framework.Block, pir.Block)) diff --git a/python/paddle/nn/initializer/xavier.py b/python/paddle/nn/initializer/xavier.py index 13a2c8cdce28fa..58d73d21dfe865 100644 --- a/python/paddle/nn/initializer/xavier.py +++ b/python/paddle/nn/initializer/xavier.py @@ -114,7 +114,9 @@ def forward(self, var, block=None): name=unique_name.generate( ".".join(['xavier_init', var.name, 'tmp']) ), - shape=var.shape, + shape=var._local_shape + if (isinstance(var, framework.EagerParamBase) and var.is_dist()) + else var.shape, dtype=out_dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, @@ -151,10 +153,15 @@ def forward(self, var, block=None): if var.dtype == core.VarDesc.VarType.FP16 or ( var.dtype == core.VarDesc.VarType.BF16 and not self._uniform ): - var_tmp = _C_ops.cast(out_var, var.dtype) - var_tmp._share_underline_tensor_to(var) - else: - out_var._share_underline_tensor_to(var) + out_var = _C_ops.cast(out_var, var.dtype) + if isinstance(var, framework.EagerParamBase) and var.is_dist(): + # lazy init for dist tensor + out_var = ( + paddle.distributed.auto_parallel.api.dtensor_from_local( + out_var, var.process_mesh, var.placements + ) + ) + out_var._share_underline_tensor_to(var) return None elif in_pir_mode(): if self._uniform: diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 04d6219c5946e3..774dc3d2023b93 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -162,6 +162,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_semi_auto_parallel_single_strategy) set_tests_properties(test_semi_auto_parallel_single_strategy PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 400) + py_test_modules(test_semi_auto_parallel_lazy_init MODULES + test_semi_auto_parallel_lazy_init) + set_tests_properties(test_semi_auto_parallel_lazy_init + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_semi_auto_parallel_in_framework MODULES test_semi_auto_parallel_in_framework) set_tests_properties(test_semi_auto_parallel_in_framework diff --git a/test/auto_parallel/semi_auto_parallel_lazy_init.py b/test/auto_parallel/semi_auto_parallel_lazy_init.py new file mode 100644 index 00000000000000..52016c358ea357 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_lazy_init.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle +import paddle.distributed as dist +from paddle import LazyGuard + + +class TestSemiAutoParallelLazyInit: + def __init__(self): + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def test_replicate(self): + paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) + with LazyGuard(): + linear = paddle.nn.Linear(10, 10) + linear.weight = dist.shard_tensor( + linear.weight, self._mesh, [dist.Replicate()] + ) + linear.bias = dist.shard_tensor( + linear.bias, self._mesh, [dist.Replicate()] + ) + for param in linear.parameters(): + assert not param._is_initialized() + param.initialize() + assert param._is_initialized() + + local_weight_md5 = linear.weight._local_value()._md5sum() + mesh0 = dist.ProcessMesh([0], dim_names=["x"]) + mesh1 = dist.ProcessMesh([1], dim_names=["x"]) + tmp = paddle.distributed.auto_parallel.api.dtensor_from_local( + linear.weight._local_value(), + mesh0 if dist.get_rank() == 0 else mesh1, + [dist.Replicate()], + ) + tmp = dist.reshard( + tmp, mesh1 if dist.get_rank() == 0 else mesh0, [dist.Replicate()] + ) + tmp_md5 = tmp._local_value()._md5sum() + assert local_weight_md5 == tmp_md5 + + def run_test_case(self): + self.test_replicate() + + +if __name__ == '__main__': + TestSemiAutoParallelLazyInit().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_lazy_init.py b/test/auto_parallel/test_semi_auto_parallel_lazy_init.py new file mode 100644 index 00000000000000..d0c09749af53d5 --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_lazy_init.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelLazyInit(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp( + num_of_devices=2, + timeout=120, + ) + self._default_envs = { + "dtype": "float32", + "seed": "2023", + } + self._changeable_envs = {"backend": ["cpu", "gpu"]} + + def test_lazy_init(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_lazy_init.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() From 180ded554f73baa1e8a401a7979bf9e4b9038492 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Thu, 28 Dec 2023 14:35:16 +0800 Subject: [PATCH 028/142] [Dy2St] Unify PT flags in dy2st and run PT in AST (#60410) --- .../eager/to_static/run_program_op_node.h | 120 ++++++++++++------ paddle/fluid/framework/executor_cache.cc | 4 +- paddle/fluid/framework/executor_cache.h | 18 ++- .../paddle/jit/dy2static/partial_program.py | 64 ++++------ test/custom_runtime/CMakeLists.txt | 8 +- .../test_custom_cpu_to_static.py | 4 +- test/dygraph_to_static/CMakeLists.txt | 2 +- 7 files changed, 130 insertions(+), 90 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 257b249e51600d..b409c0f7067e56 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -488,8 +488,11 @@ inline void PirRunProgramAPI( paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/false)) { + if (!interpretercore_info_cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/true)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -555,8 +558,12 @@ inline void PirRunProgramAPI( 1); VLOG(2) << "Get interpretercore cache by program:" << program_id; // Step 1. get cache interpretercore - auto &cached_value = interpretercore_info_cache.GetMutable( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/false); + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/true); interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScopeByValue( @@ -631,6 +638,12 @@ inline void RunProgramAPI( int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); auto place = egr::Controller::Instance().GetExpectedPlace(); + bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st; + if (attrs.count("in_pir_pt_mode")) { + in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode")); + } + in_pir_pt_mode = in_pir_pt_mode || FLAGS_enable_pir_in_executor; + // NOTE(chenweihang): In order not to add new variable type, use vector // here. Originally, here can use scope directly. auto *out_scope_vec = &step_scope; @@ -688,8 +701,11 @@ inline void RunProgramAPI( paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/false)) { + if (!interpretercore_info_cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -702,12 +718,7 @@ inline void RunProgramAPI( details::ShareTensorsIntoScope(params, global_inner_scope); // Step 2. create new interpretercore - bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st; - if (attrs.count("in_pir_pt_mode")) { - in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode")); - } - - if (FLAGS_enable_pir_in_executor || in_pir_pt_mode) { + if (in_pir_pt_mode) { // build new ir program auto ir_program = paddle::framework::ConstructFowardIrProgram(forward_global_block, @@ -765,6 +776,7 @@ inline void RunProgramAPI( global_inner_scope, place_hash_key, false, + in_pir_pt_mode, skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); } else { @@ -774,8 +786,12 @@ inline void RunProgramAPI( 1); VLOG(2) << "Get interpretercore cahce by program:" << program_id; // Step 1. get cache interpretercore - auto &cached_value = interpretercore_info_cache.GetMutable( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/false); + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode); interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope); @@ -840,6 +856,12 @@ inline void RunProgramGradAPI( int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); + bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st; + if (attrs.count("in_pir_pt_mode")) { + in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode")); + } + in_pir_pt_mode = in_pir_pt_mode || FLAGS_enable_pir_in_executor; + auto place = egr::Controller::Instance().GetExpectedPlace(); VLOG(2) << "RunProgramGradOp use interpretercore to execute program."; @@ -858,8 +880,11 @@ inline void RunProgramGradAPI( paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/true)) { + if (!interpretercore_info_cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/in_pir_pt_mode)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -869,12 +894,7 @@ inline void RunProgramGradAPI( << program_id; details::ShareTensorsIntoScope(out_grad, global_inner_scope); - bool in_pir_pt_mode = FLAGS_enable_pir_with_pt_in_dy2st; - if (attrs.count("in_pir_pt_mode")) { - in_pir_pt_mode = PADDLE_GET_CONST(bool, attrs.at("in_pir_pt_mode")); - } - - if (FLAGS_enable_pir_in_executor || in_pir_pt_mode) { + if (in_pir_pt_mode) { auto res = paddle::framework::ConstructBackwardIrProgram(backward_global_block, out_grad, @@ -904,14 +924,19 @@ inline void RunProgramGradAPI( // share threadpool // NOTE(zhiqiu): this only works interpreter_core is executed strictly // after the related fwd_interpreter_core. - if (interpretercore_info_cache.Has( - program_id, global_inner_scope, place_hash_key, false)) { - auto fwd_interpreter_core = interpretercore_info_cache - .GetMutable(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false) - .core_; + if (interpretercore_info_cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode)) { + auto fwd_interpreter_core = + interpretercore_info_cache + .GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode) + .core_; interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core); VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to " << interpreter_core.get(); @@ -938,6 +963,7 @@ inline void RunProgramGradAPI( global_inner_scope, place_hash_key, /*is_grad=*/true, + in_pir_pt_mode, skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); } else { @@ -946,8 +972,12 @@ inline void RunProgramGradAPI( paddle::platform::TracerEventType::UserDefined, 1); VLOG(2) << "Get interpretercore cahce by program:" << program_id; - auto &cached_value = interpretercore_info_cache.GetMutable( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/true); + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/in_pir_pt_mode); interpreter_core = cached_value.core_; // update scope @@ -1054,8 +1084,11 @@ inline void PirRunProgramGradAPI( paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/true)) { + if (!interpretercore_info_cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/true)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -1080,13 +1113,17 @@ inline void PirRunProgramGradAPI( // share threadpool // NOTE(zhiqiu): this only works interpreter_core is executed strictly // after the related fwd_interpreter_core. - if (interpretercore_info_cache.Has( - program_id, global_inner_scope, place_hash_key, false)) { + if (interpretercore_info_cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/true)) { auto fwd_interpreter_core = interpretercore_info_cache .GetMutable(program_id, global_inner_scope, place_hash_key, - /*is_grad=*/false) + /*is_grad=*/false, + /*in_pir_mode=*/true) .core_; interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core); VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to " @@ -1107,6 +1144,7 @@ inline void PirRunProgramGradAPI( global_inner_scope, place_hash_key, /*is_grad=*/true, + /*in_pir_mode=*/true, skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); details::print_collection(skip_eager_delete_vars); @@ -1116,8 +1154,12 @@ inline void PirRunProgramGradAPI( paddle::platform::TracerEventType::UserDefined, 1); VLOG(2) << "Get interpretercore cahce by program:" << program_id; - auto &cached_value = interpretercore_info_cache.GetMutable( - program_id, global_inner_scope, place_hash_key, /*is_grad=*/true); + auto &cached_value = + interpretercore_info_cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/true); interpreter_core = cached_value.core_; if (interpreter_core->GetVariableScope()->GetMutableScope() != diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 6af74433583613..97e4d386ea9aa8 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -326,7 +326,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( place, program_desc.Block(0), scope, execution_config)); auto &cached_value = interpretercore_info_cache.GetMutable( - program_id, scope, place_hash_key, is_grad); + program_id, scope, place_hash_key, is_grad, /*in_pir_mode=*/false); cached_value.core_ = core; return core; } @@ -355,7 +355,7 @@ std::shared_ptr CreatePirInterpreterCoreInfoToCache( place, {}, ir_program->block(), scope, execution_config)); auto &cached_value = interpretercore_info_cache.GetMutable( - program_id, scope, place_hash_key, is_grad); + program_id, scope, place_hash_key, is_grad, /*in_pir_mode=*/true); cached_value.core_ = core; cached_value.ir_prog_ = std::move(ir_program); return core; diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index 57d9b06d92b0ee..bd8b82180cbac0 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -196,8 +196,9 @@ class InterpreterCoreInfoCache { bool Has(int64_t program_id, const framework::Scope* scope, const int64_t& place_hash_key, - bool is_grad) { - if (FLAGS_enable_pir_in_executor || FLAGS_enable_pir_with_pt_in_dy2st) { + bool is_grad, + bool in_pir_mode) { + if (in_pir_mode) { int64_t scope_i = reinterpret_cast(scope); program_id = hash_with_seed(program_id, scope_i); program_id = hash_with_seed(program_id, place_hash_key); @@ -209,8 +210,9 @@ class InterpreterCoreInfoCache { InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id, const framework::Scope* scope, const int64_t& place_hash_key, - bool is_grad) { - if (FLAGS_enable_pir_in_executor || FLAGS_enable_pir_with_pt_in_dy2st) { + bool is_grad, + bool in_pir_mode) { + if (in_pir_mode) { int64_t scope_i = reinterpret_cast(scope); program_id = hash_with_seed(program_id, scope_i); program_id = hash_with_seed(program_id, place_hash_key); @@ -222,16 +224,20 @@ class InterpreterCoreInfoCache { const framework::Scope* scope, const int64_t& place_hash_key, bool is_grad, + bool in_pir_mode, const std::set& skip_vars) { - auto& cached_value = GetMutable(program_id, scope, place_hash_key, is_grad); + auto& cached_value = + GetMutable(program_id, scope, place_hash_key, is_grad, in_pir_mode); cached_value.skip_eager_delete_vars_ = std::move(skip_vars); } std::set& GetSkipEagerDeleteVars(int64_t program_id, const framework::Scope* scope, const int64_t& place_hash_key, + bool in_pir_mode, bool is_grad) { - auto& cached_value = GetMutable(program_id, scope, place_hash_key, is_grad); + auto& cached_value = + GetMutable(program_id, scope, place_hash_key, is_grad, in_pir_mode); return cached_value.skip_eager_delete_vars_; } diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index ef567d193b85c7..84719c3eee7928 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from copy import deepcopy import numpy as np @@ -229,10 +228,7 @@ def __call__(self, inputs): in_vars, in_var_names = self._prepare_inputs(inputs) out_vars = self._prepare_outputs() self._cast_fp16_if_pure_fp16(in_vars) - # TODO(dev): Currently AST + PT has some issues in control flow, so we only - # enable SOT + PT in 2.6, we will fix it later. - is_dy2st_test = os.environ.get("DY2ST_TEST", None) == "True" - attrs = self._prepare_attributes(force_not_use_pt=(not is_dy2st_test)) + attrs = self._prepare_attributes() attrs.extend(["x_names", in_var_names]) self._sync_lr_value_with_scheduler() @@ -259,7 +255,7 @@ def sot_call(self, inputs): """ out_vars = self._prepare_outputs() self._cast_fp16_if_pure_fp16(inputs) - attrs = self._prepare_attributes(force_not_use_pt=False) + attrs = self._prepare_attributes() attrs.extend(["x_names", self._in_var_names]) self._sync_lr_value_with_scheduler() @@ -296,14 +292,7 @@ def set_hooker(self, hooker): self._hooker = hooker def _get_scope(self, program_id=None, use_scope_cache=False): - if ( - get_flags('FLAGS_enable_pir_in_executor')[ - 'FLAGS_enable_pir_in_executor' - ] - or get_flags('FLAGS_enable_pir_with_pt_in_dy2st')[ - 'FLAGS_enable_pir_with_pt_in_dy2st' - ] - ): + if self._in_pir_pt_mode or self._enable_pir_in_executor: _scope_cache = self._pir_scope_cache else: _scope_cache = self._legacy_scope_cache @@ -768,7 +757,28 @@ def _cast_fp16_if_pure_fp16(self, in_vars): in_vars[i] = var.astype('float16') in_vars[i].name = name - def _prepare_attributes(self, force_not_use_pt=False): + @property + def _in_pir_pt_mode(self): + pir_dy2st_flag = 'FLAGS_enable_pir_with_pt_in_dy2st' + in_pir_pt_mode = get_flags(pir_dy2st_flag)[pir_dy2st_flag] + is_prim_enabled = ( + core._is_fwd_prim_enabled() or core._is_bwd_prim_enabled() + ) + in_cinn_backend = self._backend == "CINN" + is_cinn_enabled = self._build_strategy.build_cinn_pass + if is_prim_enabled or in_cinn_backend or is_cinn_enabled: + in_pir_pt_mode = False + return in_pir_pt_mode + + @property + def _enable_pir_in_executor(self): + enable_pir_in_executor_flag = 'FLAGS_enable_pir_in_executor' + enable_pir_in_executor = get_flags(enable_pir_in_executor_flag)[ + enable_pir_in_executor_flag + ] + return enable_pir_in_executor + + def _prepare_attributes(self): attrs = [ 'forward_global_block', self.forward_program.desc.block(0), @@ -804,17 +814,7 @@ def _prepare_attributes(self, force_not_use_pt=False): ) ) - pir_dy2st_flag = 'FLAGS_enable_pir_with_pt_in_dy2st' - in_pir_pt_mode = get_flags(pir_dy2st_flag)[pir_dy2st_flag] - is_prim_enabled = ( - core._is_fwd_prim_enabled() or core._is_bwd_prim_enabled() - ) - in_cinn_backend = self._backend == "CINN" - is_cinn_enabled = self._build_strategy.build_cinn_pass - if is_prim_enabled or in_cinn_backend or is_cinn_enabled: - in_pir_pt_mode = False - if force_not_use_pt: - in_pir_pt_mode = False + in_pir_pt_mode = self._in_pir_pt_mode attrs.extend(['in_pir_pt_mode', in_pir_pt_mode]) return attrs @@ -901,21 +901,13 @@ def _apply_inplace_pass(self, forward_program, backward_program): forward_program, backward_program ) backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program) - in_pir_pt_mode = ( - get_flags('FLAGS_enable_pir_in_executor')[ - 'FLAGS_enable_pir_in_executor' - ] - or get_flags('FLAGS_enable_pir_with_pt_in_dy2st')[ - 'FLAGS_enable_pir_with_pt_in_dy2st' - ] - ) if forward_program: attrs = { "use_cuda": use_cuda, "mem_opt_skip_vars": forward_mem_opt_skip_vars, "for_partial_block": True, } - if not in_pir_pt_mode: + if not (self._in_pir_pt_mode or self._enable_pir_in_executor): _apply_pass( forward_program, empty_startup_program, @@ -929,7 +921,7 @@ def _apply_inplace_pass(self, forward_program, backward_program): "mem_opt_skip_vars": backward_mem_opt_skip_vars, "for_partial_block": True, } - if not in_pir_pt_mode: + if not (self._in_pir_pt_mode or self._enable_pir_in_executor): _apply_pass( backward_program, empty_startup_program, diff --git a/test/custom_runtime/CMakeLists.txt b/test/custom_runtime/CMakeLists.txt index e8b14445278be8..b0b162c19d6ed1 100644 --- a/test/custom_runtime/CMakeLists.txt +++ b/test/custom_runtime/CMakeLists.txt @@ -9,9 +9,11 @@ if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU) string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") foreach(TEST_OP ${TEST_OPS}) - py_test(${TEST_OP} - SRCS ${TEST_OP}.py ENVS FLAGS_allocator_strategy=naive_best_fit - PLUGIN_URL=${PLUGIN_URL} PLUGIN_TAG=${PLUGIN_TAG}) + py_test( + ${TEST_OP} + SRCS ${TEST_OP}.py ENVS FLAGS_allocator_strategy=naive_best_fit + PLUGIN_URL=${PLUGIN_URL} PLUGIN_TAG=${PLUGIN_TAG} + FLAGS_enable_pir_with_pt_in_dy2st=False) endforeach() bash_test_modules( diff --git a/test/custom_runtime/test_custom_cpu_to_static.py b/test/custom_runtime/test_custom_cpu_to_static.py index 78978e9175310e..b365f8ab39811e 100644 --- a/test/custom_runtime/test_custom_cpu_to_static.py +++ b/test/custom_runtime/test_custom_cpu_to_static.py @@ -164,9 +164,7 @@ def forward(self, x): # convert to static model build_strategy = paddle.static.BuildStrategy() - mnist = paddle.jit.to_static( - model, build_strategy=build_strategy, full_graph=True - ) + mnist = paddle.jit.to_static(model, build_strategy=build_strategy) # data loader transform = paddle.vision.transforms.Compose( diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index e9ae745681017c..f54bd5f714b9ea 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -4,7 +4,7 @@ file( "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") set(SOT_ENVS SOT_LOG_LEVEL=0 COST_MODEL=False MIN_GRAPH_SIZE=0 - STRICT_MODE=False DY2ST_TEST=True) + STRICT_MODE=False) set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) list(REMOVE_ITEM TEST_OPS test_lac) From 875fbfb4b733856e3a4a452358e6d7d6047dbbc0 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Thu, 28 Dec 2023 15:15:55 +0800 Subject: [PATCH 029/142] [Dy2St] Use `ShadowOutputOp` to get dy2st output (#60363) --- .../pir_adaptor/pir_adaptor_util.cc | 4 ++ paddle/fluid/pybind/pir.cc | 44 +++++++++---------- .../jit/dy2static/pir_partial_program.py | 35 ++++++++------- .../jit/pir_dy2static/parameter_recorder.py | 2 +- .../test_tensor_memcpy_on_cpu.py | 3 +- 5 files changed, 48 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 7f110b49b218fd..a06abb197de5fe 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -542,6 +542,10 @@ void HandleForSpecialOp(pir::Operation* op, // change opreand name to param_name auto orig_name = value_exe_info->GetValue2VarName().at(value); + if (var_name == orig_name) { + return; + } + if (value_exe_info->GetScope()->FindVar(var_name) != nullptr) { const_cast(value_exe_info->GetScope())->EraseVars({var_name}); VLOG(1) << "var " << var_name << " has been removed from scope"; diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 2103e7b7b660e2..9e87a3f39459df 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1057,14 +1057,14 @@ std::pair, OpResultMap> CloneProgram( std::make_pair(associated_array_key, associated_array_value)); } -void AppendSetParameter(Program *forward_program, +void AppendShadowOutput(Program *forward_program, const pir::OpResult &result, const std::string &name, size_t start_point) { pir::IrContext *ctx = pir::IrContext::Instance(); - auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); pir::AttributeMap attribute_map = { - {"parameter_name", pir::StrAttribute::get(ctx, name)}, + {"output_name", pir::StrAttribute::get(ctx, name)}, }; pir::Operation *operation = pir::Operation::Create({result}, attribute_map, {}, op_info); @@ -1077,7 +1077,7 @@ void AppendSetParameter(Program *forward_program, } } -int AppendSetParameters(Program *forward_program, +int AppendShadowOutputs(Program *forward_program, const std::vector &outputs_op_result, int start_point, std::string name_prefix) { @@ -1086,9 +1086,9 @@ int AppendSetParameters(Program *forward_program, for (const auto &result : outputs_op_result) { if (!added_op_result.count(result) || IsFakeOpResult(result)) { - std::string parameter_name = name_prefix + std::to_string(counter); - AppendSetParameter( - forward_program, result, parameter_name, start_point + counter); + std::string shadow_output_name = name_prefix + std::to_string(counter); + AppendShadowOutput( + forward_program, result, shadow_output_name, start_point + counter); counter += 1; added_op_result.insert(result); } @@ -1204,20 +1204,20 @@ SplitedResult SplitForwardBackward( if (v.impl() == nullptr) { return; } - // NOTE(Aurelius84): we should skip insert SetParameterOp repeatly by + // NOTE(Aurelius84): we should skip insert ShadowOutputOp repeatly by // calling SplitForwardBackward multi-times. - std::string parameter_name = + std::string shadow_output_name = std::string("output_") + std::to_string(counter); std::unordered_set inserted_value; for (auto it = forward_program->block()->rbegin(); it != forward_program->block()->rend(); ++it) { - if (it->isa()) { + if (it->isa()) { auto out_name = - it->attribute("parameter_name").AsString(); - if (out_name == parameter_name) { + it->attribute("output_name").AsString(); + if (out_name == shadow_output_name) { VLOG(4) << out_name - << " has been inserted SetParameterOp, skip it now."; + << " has been inserted ShadowOutputOp, skip it now."; return; } @@ -1228,9 +1228,9 @@ SplitedResult SplitForwardBackward( if (inserted_value.count(forward_value_map[v])) { return; } - auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); pir::AttributeMap attribute_map = { - {"parameter_name", pir::StrAttribute::get(ctx, parameter_name)}, + {"output_name", pir::StrAttribute::get(ctx, shadow_output_name)}, }; pir::Operation *operation = pir::Operation::Create( {forward_value_map[v]}, attribute_map, {}, op_info); @@ -1245,9 +1245,9 @@ SplitedResult SplitForwardBackward( if (v.impl() == nullptr) { return; } - auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); pir::AttributeMap attribute_map = { - {"parameter_name", + {"output_name", pir::StrAttribute::get( ctx, std::string("output_") + std::to_string(counter))}, }; @@ -1372,10 +1372,10 @@ pir::Type CreateSelectedRowsTypeByDenseTensor(pir::Type dense_tensor_type) { } } -void ResetParameterName(pir::Operation *op, const std::string &name) { +void ResetShadowOutputName(pir::Operation *op, const std::string &name) { pir::IrContext *ctx = pir::IrContext::Instance(); - if (op->isa()) { - op->set_attribute("parameter_name", pir::StrAttribute::get(ctx, name)); + if (op->isa()) { + op->set_attribute("output_name", pir::StrAttribute::get(ctx, name)); } } @@ -1410,9 +1410,9 @@ std::map GetOpInplaceInfo(const pir::Operation *op) { void BindUtils(pybind11::module *m) { m->def("clone_program", CloneProgram); m->def("get_op_inplace_info", GetOpInplaceInfo); - m->def("reset_parameter_name", ResetParameterName); + m->def("reset_shadow_output_name", ResetShadowOutputName); m->def("split_program", SplitForwardBackward); - m->def("append_set_parameters", AppendSetParameters); + m->def("append_shadow_outputs", AppendShadowOutputs); m->def("fake_op_result", FakeOpResult); m->def("is_fake_op_result", IsFakeOpResult); m->def("get_current_insertion_point", []() -> PyInsertionPoint { diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 2b1f6c6b478746..a5858df1886e8f 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -103,7 +103,7 @@ def union(self, x, y): self.father[father_x] = father_y def find_root(self, x): - if not self.father.__contains__(x): + if x not in self.father: self.father[x] = x if self.father[x].is_same(x): return x @@ -135,24 +135,29 @@ def _get_value_name_map_from_program(cls, program): ret = ValueDict() ret[fake_op_result()] = "FakeVar" for op in program.global_block().ops: - if op.name() == "pd_op.data": - ret[op.result(0)] = op.attrs()["name"] if op.name() == "builtin.set_parameter": ret[op.operand(0).source()] = op.attrs()["parameter_name"] - if op.name() == "builtin.parameter": + elif op.name() == "builtin.parameter": ret[op.result(0)] = op.attrs()["parameter_name"] + elif op.name() == "builtin.shadow_output": + ret[op.operand(0).source()] = op.attrs()["output_name"] + elif op.name() == "pd_op.data": + ret[op.result(0)] = op.attrs()["name"] return ret @classmethod def _get_name_defining_op(cls, program, value): for op in program.global_block().ops: - if op.name() == "pd_op.data": + if op.name() == "builtin.set_parameter": + if value.is_same(op.operand(0).source()): + return op + elif op.name() == "builtin.parameter": if value.is_same(op.result(0)): return op - if op.name() == "builtin.set_parameter": + elif op.name() == "builtin.shadow_output": if value.is_same(op.operand(0).source()): return op - if op.name() == "builtin.parameter": + elif op.name() == "pd_op.data": if value.is_same(op.result(0)): return op return None @@ -291,7 +296,7 @@ def _forward_backward_program(self): def program_attr(self): assert ( self.finish_pass is False - ), "program_attr() is called by PartialProgramLayer, don't call it matually, use program_name_attr instead." + ), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead." # can't apply pass after call this function. self.finish_pass = True fwd_map = { @@ -346,7 +351,7 @@ def has_name(value): if has_name(ufset.find_root(value)): name_defining_op = self._get_name_defining_op(program, value) if name_defining_op: - paddle.core.pir.reset_parameter_name( + paddle.core.pir.reset_shadow_output_name( name_defining_op, value2name[ufset.find_root(value)] ) @@ -384,8 +389,8 @@ class PirPassContext: """ INPUT_OP_NAME = "pd_op.data" - PARM_OP_NAME = "builtin.parameter" - OUTPUT_OP_NAME = "builtin.set_parameter" + PARAM_OP_NAME = "builtin.parameter" + OUTPUT_OP_NAME = "builtin.shadow_output" @classmethod def apply(cls, runable_program, build_strategy): @@ -419,7 +424,7 @@ def _prepare_attr(cls, program): op_name = op.name() if op_name == cls.INPUT_OP_NAME: inputs.append(op.result(0)) - elif op_name == cls.PARM_OP_NAME: + elif op_name == cls.PARAM_OP_NAME: params.append(op.result(0)) elif op_name == cls.OUTPUT_OP_NAME: outputs.append(op.operand(0).source()) @@ -546,7 +551,7 @@ def origin_runable_program(self): inputs = list(self._inputs.var_list) outputs = list(self._outputs.var_list) params = self._param_values - paddle.base.libpaddle.pir.append_set_parameters( + paddle.base.libpaddle.pir.append_shadow_outputs( self._origin_main_program, outputs, len(self._origin_main_program.global_block().ops), @@ -796,7 +801,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram): dtype=out_op_result.dtype, ) forward_outputs_grads.append(value) - paddle.base.libpaddle.pir.append_set_parameters( + paddle.base.libpaddle.pir.append_shadow_outputs( program, forward_outputs_grads, len(program.global_block().ops), @@ -861,7 +866,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram): ) ) backward_end_op_index = len(program.global_block().ops) - paddle.base.libpaddle.pir.append_set_parameters( + paddle.base.libpaddle.pir.append_shadow_outputs( program, output_grads_to_append, backward_end_op_index, diff --git a/python/paddle/jit/pir_dy2static/parameter_recorder.py b/python/paddle/jit/pir_dy2static/parameter_recorder.py index 565dad78f394d1..538ec04f265a96 100644 --- a/python/paddle/jit/pir_dy2static/parameter_recorder.py +++ b/python/paddle/jit/pir_dy2static/parameter_recorder.py @@ -81,7 +81,7 @@ def get(self, program, value): return None root_var = inplace_dict[value] saved = [] - while inplace_dict.__contains__(root_var): + while root_var in inplace_dict: saved.append(root_var) root_var = inplace_dict[root_var] for var in saved: diff --git a/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py b/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py index 0b92fae0556bbb..ccf0b35ee4d296 100644 --- a/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py +++ b/test/dygraph_to_static/test_tensor_memcpy_on_cpu.py @@ -18,7 +18,6 @@ from dygraph_to_static_utils import ( Dy2StTestBase, enable_to_static_guard, - test_legacy_and_pt, test_legacy_and_pt_and_pir, ) @@ -69,7 +68,7 @@ def _run(self): x2 = paddle.jit.to_static(tensor_copy_to_cuda)(x1) return x1.place, x2.place, x2.numpy() - @test_legacy_and_pt + @test_legacy_and_pt_and_pir def test_tensor_cuda_on_default_cpu(self): if not paddle.is_compiled_with_cuda(): return From beba862cd2aa4dd2b14cdd0c6c4c08be33df62f2 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Thu, 28 Dec 2023 15:39:18 +0800 Subject: [PATCH 030/142] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.25?= =?UTF-8?q?=E3=80=91add=20`gammaln`=20api=20(#59311)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/api/yaml/backward.yaml | 10 ++ paddle/phi/api/yaml/ops.yaml | 10 ++ paddle/phi/kernels/cpu/gammaln_grad_kernel.cc | 22 +++ paddle/phi/kernels/cpu/gammaln_kernel.cc | 22 +++ paddle/phi/kernels/gammaln_grad_kernel.h | 27 +++ paddle/phi/kernels/gammaln_kernel.h | 26 +++ paddle/phi/kernels/gpu/gammaln_grad_kernel.cu | 30 ++++ paddle/phi/kernels/gpu/gammaln_kernel.cu | 29 ++++ .../kernels/impl/gammaln_grad_kernel_impl.h | 92 ++++++++++ paddle/phi/kernels/impl/gammaln_kernel_impl.h | 49 ++++++ python/paddle/__init__.py | 4 + python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 45 +++++ test/legacy_test/test_gammaln_op.py | 160 ++++++++++++++++++ test/legacy_test/test_inplace.py | 8 + 15 files changed, 538 insertions(+) create mode 100644 paddle/phi/kernels/cpu/gammaln_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/gammaln_kernel.cc create mode 100644 paddle/phi/kernels/gammaln_grad_kernel.h create mode 100644 paddle/phi/kernels/gammaln_kernel.h create mode 100644 paddle/phi/kernels/gpu/gammaln_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/gammaln_kernel.cu create mode 100644 paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/gammaln_kernel_impl.h create mode 100644 test/legacy_test/test_gammaln_op.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 938ea9d5000460..d5748145ffe49d 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -922,6 +922,16 @@ kernel : func : frame_grad +- backward_op : gammaln_grad + forward : gammaln(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : gammaln_grad + - backward_op : gather_grad forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index de4d700cdf80ee..dc545b7a2da546 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1042,6 +1042,16 @@ data_type : dtype backend : place +- op : gammaln + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : gammaln + inplace: (x -> out) + backward : gammaln_grad + - op : gather args : (Tensor x, Tensor index, Scalar axis=0) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc new file mode 100644 index 00000000000000..c52ee8b3848e9a --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaln_grad, CPU, ALL_LAYOUT, phi::GammalnGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gammaln_kernel.cc b/paddle/phi/kernels/cpu/gammaln_kernel.cc new file mode 100644 index 00000000000000..ff62f86d2522fd --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaln_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaln, CPU, ALL_LAYOUT, phi::GammalnKernel, float, double) {} diff --git a/paddle/phi/kernels/gammaln_grad_kernel.h b/paddle/phi/kernels/gammaln_grad_kernel.h new file mode 100644 index 00000000000000..440dca72a9d469 --- /dev/null +++ b/paddle/phi/kernels/gammaln_grad_kernel.h @@ -0,0 +1,27 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammalnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& d_out, + DenseTensor* d_x); +} // namespace phi diff --git a/paddle/phi/kernels/gammaln_kernel.h b/paddle/phi/kernels/gammaln_kernel.h new file mode 100644 index 00000000000000..db3015c4a747db --- /dev/null +++ b/paddle/phi/kernels/gammaln_kernel.h @@ -0,0 +1,26 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammalnKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu new file mode 100644 index 00000000000000..b2513d9e3f25ca --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(gammaln_grad, + GPU, + ALL_LAYOUT, + phi::GammalnGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gammaln_kernel.cu b/paddle/phi/kernels/gpu/gammaln_kernel.cu new file mode 100644 index 00000000000000..3d57be7b277335 --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaln_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" + +PD_REGISTER_KERNEL(gammaln, + GPU, + ALL_LAYOUT, + phi::GammalnKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h new file mode 100644 index 00000000000000..50c73cff27ce4a --- /dev/null +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -0,0 +1,92 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +template +HOSTDEVICE T digamma(T x) { + static T c = T{8.5}; + static T euler_mascheroni = T{0.57721566490153286060}; + T r; + T value; + T x2; + + if (x <= T{0.0}) { + value = T{0.0}; + return value; + } + + if (x <= T{0.000001}) { + value = -euler_mascheroni - T{1.0} / x + T{1.6449340668482264365} * x; + return value; + } + + value = T{0.0}; + x2 = x; + while (x2 < c) { + value = value - T{1.0} / x2; + x2 = x2 + T{1.0}; + } + + r = T{1.0} / x2; + value = value + std::log(x2) - T{0.5} * r; + + r = r * r; + + value = value - + r * (T{1.0} / T{12.0} - + r * (T{1.0} / T{120.0} - + r * (T{1.0} / T{252.0} - + r * (T{1.0} / T{240.0} - r * (T{1.0} / T{132.0}))))); + + return value; +} + +template +struct GammalnGradFunctor { + GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel) + : dout_(dout), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_dout = static_cast(dout_[idx]); + const MT mp_x = static_cast(x_[idx]); + output_[idx] = static_cast(mp_dout * digamma(mp_x)); + } + + private: + const T* dout_; + const T* x_; + T* output_; + int64_t numel_; +}; +template +void GammalnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& d_out, + DenseTensor* d_x) { + auto numel = d_out.numel(); + auto* dout_data = d_out.data(); + auto* x_data = x.data(); + auto* dx_data = + dev_ctx.template Alloc(d_x, static_cast(numel * sizeof(T))); + phi::funcs::ForRange for_range(dev_ctx, numel); + GammalnGradFunctor functor(dout_data, x_data, dx_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h new file mode 100644 index 00000000000000..38385610de0de6 --- /dev/null +++ b/paddle/phi/kernels/impl/gammaln_kernel_impl.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +template +struct GammalnFunctor { + GammalnFunctor(const T* x, T* output, int64_t numel) + : x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_x = static_cast(x_[idx]); + output_[idx] = static_cast(std::lgamma(mp_x)); + } + + private: + const T* x_; + T* output_; + int64_t numel_; +}; + +template +void GammalnKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = dev_ctx.template Alloc(out); + phi::funcs::ForRange for_range(dev_ctx, numel); + GammalnFunctor functor(x_data, out_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index fc7b2a3533f892..1f0017562ebade 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -398,6 +398,8 @@ frac, frac_, frexp, + gammaln, + gammaln_, gcd, gcd_, heaviside, @@ -773,6 +775,8 @@ 'square_', 'divide', 'divide_', + 'gammaln', + 'gammaln_', 'ceil', 'atan', 'atan_', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b26798892a2b2f..b718910348d8ff 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -278,6 +278,8 @@ frac, frac_, frexp, + gammaln, + gammaln_, gcd, gcd_, heaviside, @@ -668,6 +670,8 @@ 'real', 'imag', 'is_floating_point', + 'gammaln', + 'gammaln_', 'digamma', 'digamma_', 'diagonal', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index acaa0905ce6f40..6d75d41b4949ca 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5003,6 +5003,51 @@ def conj(x, name=None): return out +def gammaln(x, name=None): + r""" + Calculates the logarithm of the absolute value of the gamma function elementwisely. + + Args: + x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, bfloat16. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, The values of the logarithm of the absolute value of the gamma at the given tensor x. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.arange(1.5, 4.5, 0.5) + >>> out = paddle.gammaln(x) + >>> print(out) + Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.12078224, 0. , 0.28468287, 0.69314718, 1.20097363, + 1.79175949]) + """ + if in_dynamic_or_pir_mode(): + return _C_ops.gammaln(x) + else: + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'bfloat16'], 'gammaln' + ) + helper = LayerHelper('gammaln', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='gammaln', inputs={'x': x}, outputs={'out': out}) + return out + + +@inplace_apis_in_dygraph_only +def gammaln_(x, name=None): + r""" + Inplace version of ``gammaln`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_gammaln`. + """ + if in_dynamic_mode(): + return _C_ops.gammaln_(x) + + def digamma(x, name=None): r""" Calculates the digamma of the given input tensor, element-wise. diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py new file mode 100644 index 00000000000000..50331af5c7a34c --- /dev/null +++ b/test/legacy_test/test_gammaln_op.py @@ -0,0 +1,160 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest, convert_float_to_uint16 +from scipy import special + +import paddle +from paddle.base import core + + +def ref_gammaln(x): + return special.gammaln(x) + + +def ref_gammaln_grad(x, dout): + return dout * special.polygamma(0, x) + + +class TestGammalnOp(OpTest): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.init_dtype_type() + self.shape = (3, 40) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.inputs = {'x': self.x} + out = ref_gammaln(self.x) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + self.check_grad(['x'], 'out', check_pir=True) + + +class TestGammalnOpFp32(TestGammalnOp): + def init_dtype_type(self): + self.dtype = np.float32 + + +class TestGammalnFP16Op(TestGammalnOp): + def init_dtype_type(self): + self.dtype = np.float16 + + +class TestGammalnBigNumberOp(TestGammalnOp): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.init_dtype_type() + self.shape = (100, 1) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.x[:5, 0] = np.array([1e5, 1e10, 1e20, 1e40, 1e80]) + self.inputs = {'x': self.x} + out = ref_gammaln(self.x) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_grad(self): + d_out = self.outputs['out'] + d_x = ref_gammaln_grad(self.x, d_out) + self.check_grad( + ['x'], + 'out', + user_defined_grads=[ + d_x, + ], + user_defined_grad_outputs=[ + d_out, + ], + check_pir=True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestGammalnBF16Op(OpTest): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.dtype = np.uint16 + self.shape = (5, 30) + x = np.random.random(self.shape).astype("float32") + 1 + self.inputs = {'x': convert_float_to_uint16(x)} + out = ref_gammaln(x) + self.outputs = {'out': convert_float_to_uint16(out)} + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + + def test_check_grad(self): + self.check_grad_with_place( + core.CUDAPlace(0), ['x'], 'out', check_pir=True + ) + + +class TestGammalnOpApi(unittest.TestCase): + def setUp(self): + self.shape = [2, 3, 4, 5] + self.init_dtype_type() + self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_dtype_type(self): + self.dtype = "float64" + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype) + out = paddle.gammaln(x) + exe = paddle.static.Executor(self.place) + (res,) = exe.run(feed={'x': self.x_np}, fetch_list=[out]) + out_ref = ref_gammaln(self.x_np) + np.testing.assert_allclose(out_ref, res, rtol=1e-5, atol=1e-5) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out = paddle.gammaln(x) + out_ref = ref_gammaln(self.x_np) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5) + paddle.enable_static() + + +class TestGammalnOpApiFp32(TestGammalnOpApi): + def init_dtype_type(self): + self.dtype = "float32" + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 42f9a46cfb9100..38fbac0357d6df 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -869,6 +869,14 @@ def test_leaf_inplace_var_error(self): pass +class TestDygraphInplaceGammaln(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.gammaln_(var) + + def non_inplace_api_processing(self, var): + return paddle.gammaln(var) + + class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous): def inplace_api_processing(self, var): return paddle.neg_(var) From b03482a24c8b5e6f1e44329c0f6a397d370f6061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Thu, 28 Dec 2023 15:47:30 +0800 Subject: [PATCH 031/142] complete dsl test case for dynamic schedule primitive (#60428) --- .../ir/schedule/impl/loop_transformation.cc | 3 +- .../ir/test_llir_schedule_cache_read_write.py | 61 +++++++++++- test/cinn/ir/test_llir_schedule_compute_at.py | 93 +++++++++++++++++++ .../ir/test_llir_schedule_compute_inline.py | 77 +++++++++++++++ test/cinn/ir/test_llir_schedule_fuse_split.py | 84 +++++++++++++++++ test/cinn/ir/test_llir_schedule_reorder.py | 66 +++++++++++++ 6 files changed, 379 insertions(+), 5 deletions(-) diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index 4577db7770a738..c3a3ad448f5362 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -114,7 +114,8 @@ std::vector DyScheduleImpl::Split(const Expr& loop, for (auto factor : factors) prod_size = prod_size * Expr(factor); std::for_each(factors.begin(), factors.end(), [&](int factor) { if (factor == -1) { - process_factors.push_back(tot_extent / prod_size + Expr(1)); + process_factors.push_back( + cinn::common::AutoSimplify(tot_extent / prod_size + Expr(1))); } else { process_factors.push_back(Expr(factor)); } diff --git a/test/cinn/ir/test_llir_schedule_cache_read_write.py b/test/cinn/ir/test_llir_schedule_cache_read_write.py index 41f1fc8d342ab7..7dd8cb488e918f 100644 --- a/test/cinn/ir/test_llir_schedule_cache_read_write.py +++ b/test/cinn/ir/test_llir_schedule_cache_read_write.py @@ -28,6 +28,7 @@ def elementwise_add_cache_read( Y: DataArray((128, 128)), A: DataArray((128, 128)), A_local_temp_buffer: DataArray((128, 128)), + N: ir.Var(), ): for i in range(128): for j in range(128): @@ -49,6 +50,7 @@ def elementwise_add_cache_read( Y: DataArray((128, 128)), A: DataArray((128, 128)), A_local_temp_buffer: DataArray((128, 128)), + N: ir.Var(), ): for i in range(128): for j in range(128): @@ -68,10 +70,6 @@ def elementwise_add_cache_read( i1, j1 = ir.AxisMap("SS", [i3, j3]) Y[i1, j1] = -A_local_temp_buffer[i1, j1] + 3.0 - assert str(origin.elementwise_add_cache_read) == str( - expected.elementwise_add_cache_read - ) - def test_cache_write_elementwise(): @to_cinn_llir @@ -98,6 +96,61 @@ def elementwise_add_cache_write( # assert_llir_equal(elementwise_add_cache_write, elementwise_add_cache_write) +def test_cache_read_elementwise_dynamic(): + class origin: + @to_cinn_llir + def elementwise_add_cache_read( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + A_local_temp_buffer: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(N): + for j3 in range(128): + with ir.ScheduleBlockContext("B") as B_block: + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + cached_b = sch.cache_read(B_block.block, 0, "local") + + class expected: + @to_cinn_llir + def elementwise_add_cache_read( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + A_local_temp_buffer: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for cache_ax0 in range(N): + for cache_ax1 in range(128): + with ir.ScheduleBlockContext( + "A_local_temp_buffer" + ) as A_local_temp_buffer_block: + v0, v1 = ir.AxisMap("SS", [cache_ax0, cache_ax1]) + A_local_temp_buffer[v0, v1] = A[v0, v1] + for i3 in range(N): + for j3 in range(128): + with ir.ScheduleBlockContext("B") as B_block: + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A_local_temp_buffer[i1, j1] + 3.0 + + assert str(origin.elementwise_add_cache_read) == str( + expected.elementwise_add_cache_read + ) + + if __name__ == "__main__": test_cache_read_elementwise() test_cache_write_elementwise() diff --git a/test/cinn/ir/test_llir_schedule_compute_at.py b/test/cinn/ir/test_llir_schedule_compute_at.py index 0f82786935b411..4c96ff23436ae4 100644 --- a/test/cinn/ir/test_llir_schedule_compute_at.py +++ b/test/cinn/ir/test_llir_schedule_compute_at.py @@ -106,6 +106,99 @@ def reverse_compute_at_tiled_gt( assert_llir_equal(reverse_compute_at_tiled, reverse_compute_at_tiled_gt) +def test_compute_at_elementwise_dynamic(): + @to_cinn_llir + def elementwise_add( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i, j]) + sch.compute_at(A_block.block, i, False) + Y[i1, j1] = A[i1, j1] + 2.0 + + @to_cinn_llir + def elementwise_add_gt( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A"): + i1, j1 = ir.AxisMap("SS", [i, 0 + j]) + A[i1, j1] = X[i1, j1] * 2.0 + for k in range(128): + with ir.ScheduleBlockContext("Y"): + i2, k1 = ir.AxisMap("SS", [i, k]) + Y[i2, k1] = A[i2, k1] + 2.0 + + assert_llir_equal(elementwise_add, elementwise_add_gt) + + +def test_reverse_compute_at_dynamic(): + @to_cinn_llir + def reverse_compute_at_tiled( + A: DataArray((-1, 128)), + B: DataArray((-1, 128)), + C: DataArray((-1, 128)), + N: ir.Var(), + ): + for i0 in range(N / 16): + for j0 in range(8): + for i1 in range(16): + for j1 in range(16): + with ir.ScheduleBlockContext("B") as B_block: + vi, vj = ir.AxisMap( + "SS", [i0 * 16 + i1, j0 * 16 + j1] + ) + B[vi, vj] = A[vi, vj] * 2.0 + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("C") as C_block: + vi, vj = ir.AxisMap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch.reverse_compute_at(C_block.block, B_block.i1) + + @to_cinn_llir + def reverse_compute_at_tiled_gt( + A: DataArray((-1, 128)), + B: DataArray((-1, 128)), + C: DataArray((-1, 128)), + N: ir.Var(), + ): + for i0 in range(N / 16): + for j0 in range(8): + for i1 in range(16): + for j1 in range(16): + with ir.ScheduleBlockContext("B") as B_block: + vi, vj = ir.AxisMap( + "SS", [i0 * 16 + i1, j0 * 16 + j1] + ) + B[vi, vj] = A[vi, vj] * 2.0 + for j2 in range(16): + with ir.ScheduleBlockContext("C") as C_block: + vi, vj = ir.AxisMap( + "SS", [16 * i0 + i1, 16 * j0 + j2] + ) + C[vi, vj] = B[vi, vj] + 1.0 + + assert_llir_equal(reverse_compute_at_tiled, reverse_compute_at_tiled_gt) + + if __name__ == '__main__': test_compute_at_elementwise() test_reverse_compute_at() + test_compute_at_elementwise_dynamic() + test_reverse_compute_at_dynamic() diff --git a/test/cinn/ir/test_llir_schedule_compute_inline.py b/test/cinn/ir/test_llir_schedule_compute_inline.py index a95d1dd8174495..113c0b7dfe6216 100644 --- a/test/cinn/ir/test_llir_schedule_compute_inline.py +++ b/test/cinn/ir/test_llir_schedule_compute_inline.py @@ -90,6 +90,83 @@ def elementwise_add_inline_gt( assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt) +def test_compute_inline_elementwise_dynamic(): + @to_cinn_llir + def elementwise_add_inline( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(N): + for j3 in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + block_a = sch.get_block("A") + sch.compute_inline(block_a) + + @to_cinn_llir + def elementwise_add_inline_gt( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i, j]) + Y[i1, j1] = -(X[i1, j1] * 2.0) + 3.0 + + assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt) + + +def test_reverse_compute_inline_elementwise_dynamic(): + @to_cinn_llir + def elementwise_add_inline( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i3 in range(-1): + for j3 in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1 = ir.AxisMap("SS", [i3, j3]) + Y[i1, j1] = -A[i1, j1] + 3.0 + + sch.reverse_compute_inline(Y_block.block) + + @to_cinn_llir + def elementwise_add_inline_gt( + X: DataArray((-1, 128)), + Y: DataArray((-1, 128)), + A: DataArray((-1, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + with ir.ScheduleBlockContext("A"): + i1, j1 = ir.AxisMap("SS", [i, j]) + Y[i1, j1] = -(X[i1, j1] * 2.0) + 3.0 + + assert_llir_equal(elementwise_add_inline, elementwise_add_inline_gt) + + if __name__ == "__main__": test_compute_inline_elementwise() test_reverse_compute_inline_elementwise() + test_compute_inline_elementwise_dynamic() + test_reverse_compute_inline_elementwise_dynamic() diff --git a/test/cinn/ir/test_llir_schedule_fuse_split.py b/test/cinn/ir/test_llir_schedule_fuse_split.py index 07712590b9ac16..362cb81f87b964 100644 --- a/test/cinn/ir/test_llir_schedule_fuse_split.py +++ b/test/cinn/ir/test_llir_schedule_fuse_split.py @@ -125,7 +125,91 @@ def elementwise_split_predicate_gt( ) +def test_fuse_dynamic(): + class origin: + @to_cinn_llir + def elementwise_fuse_assign_loop( + X: DataArray((-1, 128, 128)), + Y: DataArray((-1, 128, 128)), + N: ir.Var(), + ): + for i in range(N): + for j in range(128): + for k in range(128): + with ir.ScheduleBlockContext("Y") as block_y: + sch.fuse([i, j, k]) + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + class expected: + @to_cinn_llir + def elementwise_fuse_assign_loop( + X: DataArray((-1, 128, 128)), + Y: DataArray((-1, 128, 128)), + N: ir.Var(), + ): + for i_j_k_fused in range(((1 * N) * 128) * 128): + with ir.ScheduleBlockContext("Y") as block_y: + i1, j1, k1 = ir.AxisMap( + "SSS", + [ + (i_j_k_fused / 128) / 128, + (i_j_k_fused / 128) % 128, + i_j_k_fused % 128, + ], + ) + Y[i1, j1, k1] = 2.0 * X[i1, j1, k1] + + assert str(origin.elementwise_fuse_assign_loop) == str( + expected.elementwise_fuse_assign_loop + ) + + +def test_split_dynamic(): + class origin: + @to_cinn_llir + def elementwise_split( + X: DataArray((128, 128, -1)), + Y: DataArray((128, 128, -1)), + N: ir.Var(), + ): + for i in range(128): + for j in range(128): + for k in range(N): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + sch.split(Y_block.k, factors=[16, -1]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + class expected: + @to_cinn_llir + def elementwise_split( + X: DataArray((128, 128, -1)), + Y: DataArray((128, 128, -1)), + N: ir.Var(), + ): + for i in range(128): + for j in range(128): + for k_7 in range(16): + for k_8 in range((N / 16) + 1): + if (((N / 16) * k_7) + (k_7 + k_8)) < N: + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap( + "SSS", + [ + i, + j, + (((N / 16) * k_7) + (k_7 + k_8)), + ], + ) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + assert_llir_equal(origin.elementwise_split, expected.elementwise_split) + + if __name__ == "__main__": test_fuse() test_split() test_split_predicate() + test_fuse_dynamic() + test_split_dynamic() diff --git a/test/cinn/ir/test_llir_schedule_reorder.py b/test/cinn/ir/test_llir_schedule_reorder.py index 00ca99388ba941..254197beb222a1 100644 --- a/test/cinn/ir/test_llir_schedule_reorder.py +++ b/test/cinn/ir/test_llir_schedule_reorder.py @@ -75,6 +75,72 @@ def reorder_overlapped_gt(X: DataArray((28, 8)), Y: DataArray((28, 8))): assert_llir_equal(reorder_overlapped, reorder_overlapped_gt) +def test_reorder_elementwise_dynamic(): + @to_cinn_llir + def reorder_elementwise( + X: DataArray((-1, 64, 64, 64)), + Y: DataArray((-1, 64, 64, 64)), + N: ir.Var(), + ): + for i in range(N): + for j in range(64): + for k in range(64): + for l in range(8): + with ir.ScheduleBlockContext("Y") as Y_block: + vi, vj, vk, vl = ir.AxisMap( + "SSSS", [i, j, k, 8 * l] + ) + Y[vi, vj, vk, vl] = X[vi, vj, vk, vl] * 2.0 + sch.reorder([Y_block.k, Y_block.l, Y_block.i]) + + @to_cinn_llir + def reorder_elementwise_gt( + X: DataArray((-1, 64, 64, 64)), + Y: DataArray((-1, 64, 64, 64)), + N: ir.Var(), + ): + for k in range(64): + for j in range(64): + for l in range(8): + for i in range(N): + with ir.ScheduleBlockContext("Y"): + vi, vj, vk, vl = ir.AxisMap( + "SSSS", [i, j, k, 8 * l] + ) + Y[vi, vj, vk, vl] = X[vi, vj, vk, vl] * 2.0 + + assert_llir_equal(reorder_elementwise, reorder_elementwise_gt) + + +def test_reorder_overlapped_dynamic(): + @to_cinn_llir + def reorder_overlapped( + X: DataArray((-1, 8)), Y: DataArray((-1, 8)), N: ir.Var() + ): + for i in range(N / 4): + for j in range(4): + for k in range(4): + with ir.ScheduleBlockContext("Y"): + vi, vj = ir.AxisMap("SS", [i, j]) + sch.reorder([i, k, j]) + Y[vi, vj] = X[vi, vj] + 1.0 + + @to_cinn_llir + def reorder_overlapped_gt( + X: DataArray((-1, 8)), Y: DataArray((-1, 8)), N: ir.Var() + ): + for i in range(N / 4): + for k in range(4): + for j in range(4): + with ir.ScheduleBlockContext("Y"): + vi, vj = ir.AxisMap("SS", [i, j]) + Y[vi, vj] = X[vi, vj] + 1.0 + + assert_llir_equal(reorder_overlapped, reorder_overlapped_gt) + + if __name__ == '__main__': test_reorder_elementwise() test_reorder_overlapped() + test_reorder_elementwise_dynamic() + test_reorder_overlapped_dynamic() From 65e2d934caf47a4ecb81730ff6172d57982fbe6a Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Thu, 28 Dec 2023 16:28:54 +0800 Subject: [PATCH 032/142] [Dy2St] Remove `NodeVarType` (#60381) --- python/paddle/jit/dy2static/__init__.py | 2 +- .../paddle/jit/dy2static/static_analysis.py | 62 ++++---- .../transformers/loop_transformer.py | 14 +- python/paddle/jit/dy2static/utils.py | 1 - python/paddle/jit/dy2static/utils_helper.py | 132 +++++++----------- .../dygraph_to_static/test_static_analysis.py | 72 +++++----- 6 files changed, 115 insertions(+), 168 deletions(-) diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index 83535ac17aee67..d2c90a2c852dbf 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -30,7 +30,7 @@ unpack_by_structure as Unpack, ) from .program_translator import convert_to_static # noqa: F401 -from .static_analysis import NodeVarType, StaticAnalysisVisitor # noqa: F401 +from .static_analysis import StaticAnalysisVisitor # noqa: F401 from .transformers import DygraphToStaticAst # noqa: F401 from .utils import UndefinedVar, ast_to_source_code, saw # noqa: F401 from .variable_trans_func import ( # noqa: F401 diff --git a/python/paddle/jit/dy2static/static_analysis.py b/python/paddle/jit/dy2static/static_analysis.py index 81bfa589b018f5..c239e8aaacf489 100644 --- a/python/paddle/jit/dy2static/static_analysis.py +++ b/python/paddle/jit/dy2static/static_analysis.py @@ -15,11 +15,12 @@ from paddle.utils import gast from .utils_helper import ( - NodeVarType, + binary_op_output_type, index_in_list, is_dygraph_api, is_numpy_api, is_paddle_api, + type_from_annotation, ) __all__ = [] @@ -37,7 +38,7 @@ def __init__(self, node): self.node = node self.parent = None self.children = [] - self.node_var_type = {NodeVarType.UNKNOWN} + self.node_var_type = {"UNKNOWN"} class StaticAnalysisVisitor: @@ -87,7 +88,7 @@ def get_node_to_wrapper_map(self): return self.node_to_wrapper_map def is_tensor_node(self, node): - tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} + tensor_types = {"TENSOR", "PADDLE_RETURN_TYPES"} node_wrapper = self.node_to_wrapper_map.get(node, None) if node_wrapper is None: return False @@ -101,17 +102,17 @@ def _get_constant_node_type(self, node): ) # singleton: None, True or False if node.value is None: - return {NodeVarType.NONE} + return {"NONE"} if isinstance(node.value, bool): - return {NodeVarType.BOOLEAN} + return {"BOOLEAN"} if isinstance(node.value, int): - return {NodeVarType.INT} + return {"INT"} if isinstance(node.value, float): - return {NodeVarType.FLOAT} + return {"FLOAT"} if isinstance(node.value, str): - return {NodeVarType.STRING} + return {"STRING"} - return {NodeVarType.UNKNOWN} + return {"UNKNOWN"} def _get_node_var_type(self, cur_wrapper): node = cur_wrapper.node @@ -119,14 +120,14 @@ def _get_node_var_type(self, cur_wrapper): return self._get_constant_node_type(node) if isinstance(node, gast.BoolOp): - return {NodeVarType.BOOLEAN} + return {"BOOLEAN"} if isinstance(node, gast.Compare): - return {NodeVarType.BOOLEAN} + return {"BOOLEAN"} if isinstance(node, gast.Dict): - return {NodeVarType.DICT} + return {"DICT"} if isinstance(node, gast.Set): - return {NodeVarType.SET} + return {"SET"} if isinstance(node, gast.UnaryOp): return self.node_to_wrapper_map[node.operand].node_var_type @@ -137,7 +138,7 @@ def _get_node_var_type(self, cur_wrapper): result_type = set() for l in left_type: for r in right_type: - result_type.add(NodeVarType.binary_op_output_type(l, r)) + result_type.add(binary_op_output_type(l, r)) return result_type if isinstance(node, gast.Assign): @@ -157,16 +158,13 @@ def _get_node_var_type(self, cur_wrapper): if isinstance(node, gast.AnnAssign): # TODO(0x45f): To determine whether need to support assignment statements # like `self.x: float = 2.1`. - ret_type = {NodeVarType.type_from_annotation(node.annotation)} + ret_type = {type_from_annotation(node.annotation)} # if annotation and value(Constant) are diffent type, we use value type if node.value: node_value_type = self.node_to_wrapper_map[ node.value ].node_var_type - if not ( - node_value_type - & {NodeVarType.UNKNOWN, NodeVarType.STATEMENT} - ): + if not (node_value_type & {"UNKNOWN", "STATEMENT"}): ret_type = node_value_type if isinstance(node.target, gast.Name): self.node_to_wrapper_map[node.target].node_var_type = ret_type @@ -174,9 +172,9 @@ def _get_node_var_type(self, cur_wrapper): if isinstance(node, gast.Name): if node.id == "None": - return {NodeVarType.NONE} + return {"NONE"} if node.id in {"True", "False"}: - return {NodeVarType.BOOLEAN} + return {"BOOLEAN"} # If node is child of functionDef.arguments parent_node_wrapper = cur_wrapper.parent if parent_node_wrapper and isinstance( @@ -184,33 +182,33 @@ def _get_node_var_type(self, cur_wrapper): ): return self._get_func_argument_type(parent_node_wrapper, node) - return {NodeVarType.UNKNOWN} + return {"UNKNOWN"} if isinstance(node, gast.Return): # If return nothing: if node.value is None: - return {NodeVarType.NONE} + return {"NONE"} - return {NodeVarType.UNKNOWN} + return {"UNKNOWN"} if isinstance(node, gast.Call): if is_dygraph_api(node): if isinstance(node.func, gast.Attribute): if node.func.attr == "to_variable": - return {NodeVarType.TENSOR} + return {"TENSOR"} if is_paddle_api(node): - return {NodeVarType.PADDLE_RETURN_TYPES} + return {"PADDLE_RETURN_TYPES"} if is_numpy_api(node): # In this simple version we assume numpy api returns nd-array - return {NodeVarType.NUMPY_NDARRAY} + return {"NUMPY_NDARRAY"} if isinstance(node.func, gast.Name): - return {NodeVarType.UNKNOWN} + return {"UNKNOWN"} if isinstance(node, gast.Subscript): if self.is_tensor_node(node.value): - return {NodeVarType.TENSOR} + return {"TENSOR"} - return {NodeVarType.STATEMENT} + return {"STATEMENT"} def _get_func_argument_type(self, parent_node_wrapper, node): """ @@ -232,9 +230,9 @@ def _get_func_argument_type(self, parent_node_wrapper, node): assert isinstance(node, gast.Name) parent_node = parent_node_wrapper.node - var_type = {NodeVarType.UNKNOWN} + var_type = {"UNKNOWN"} if node.annotation is not None: - var_type = {NodeVarType.type_from_annotation(node.annotation)} + var_type = {type_from_annotation(node.annotation)} # if annotation and value(Constant) are diffent type, we use value type if parent_node.defaults: diff --git a/python/paddle/jit/dy2static/transformers/loop_transformer.py b/python/paddle/jit/dy2static/transformers/loop_transformer.py index 42c2a40a5ca988..2d2cfee1f97b0a 100644 --- a/python/paddle/jit/dy2static/transformers/loop_transformer.py +++ b/python/paddle/jit/dy2static/transformers/loop_transformer.py @@ -18,7 +18,7 @@ from paddle.base import unique_name from paddle.utils import gast -from ..static_analysis import NodeVarType, StaticAnalysisVisitor +from ..static_analysis import StaticAnalysisVisitor from ..utils import ( FOR_BODY_PREFIX, FOR_CONDITION_PREFIX, @@ -344,18 +344,6 @@ def _var_node_to_name(self, node): elif isinstance(node, gast.Attribute): return get_attribute_full_name(node) - def _node_var_type_is_basic(self, node_var_type): - basic_types = { - NodeVarType.BOOLEAN, - NodeVarType.INT, - NodeVarType.FLOAT, - NodeVarType.STRING, - } - for t in node_var_type: - if t in basic_types: - return True - return False - def _is_call_func_name_node(self, node): parent_node = self._get_parent_node(node) if isinstance(parent_node, gast.Call) and parent_node.func == node: diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index fc18ee5883e9ce..3061e9f47b7e80 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -42,7 +42,6 @@ DYGRAPH_MODULE_PREFIX, DYGRAPH_TO_STATIC_MODULE_PREFIX, PADDLE_MODULE_PREFIX, - NodeVarType, _is_api_in_module_helper, index_in_list, is_api_in_module, diff --git a/python/paddle/jit/dy2static/utils_helper.py b/python/paddle/jit/dy2static/utils_helper.py index 4f1ae014507394..9a55f23cf46db4 100644 --- a/python/paddle/jit/dy2static/utils_helper.py +++ b/python/paddle/jit/dy2static/utils_helper.py @@ -97,91 +97,53 @@ def is_paddle_api(node): return is_api_in_module(node, PADDLE_MODULE_PREFIX) -class NodeVarType: - """ - Enum class of python variable types. We have to know some variable types - during compile time to transfer AST. For example, a string variable and a - tensor variable in if clause may lead to different conversion from dygraph - to static graph. - """ - - ERROR = -1 # Returns when static analysis gets error - UNKNOWN = 0 # Reserve for AST nodes have not known the type - STATEMENT = 1 # For nodes representing statement (non-variable type) - CALLABLE = 2 - - # python data types - NONE = 100 - BOOLEAN = 101 - INT = 102 - FLOAT = 103 - STRING = 104 - TENSOR = 105 - NUMPY_NDARRAY = 106 - - # python collections - LIST = 200 - SET = 201 - DICT = 202 - - PADDLE_DYGRAPH_API = 300 - PADDLE_CONTROL_IF = 301 - PADDLE_CONTROL_WHILE = 302 - PADDLE_CONTROL_FOR = 303 - # Paddle API may not be visible to get source code. - # We use this enum value to denote the type return by a Paddle API - PADDLE_RETURN_TYPES = 304 - - # If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent. - TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES} - - Annotation_map = { - "Tensor": TENSOR, - "paddle.Tensor": TENSOR, - "int": INT, - "float": FLOAT, - "bool": BOOLEAN, - "str": STRING, - } - - @staticmethod - def binary_op_output_type(in_type1, in_type2): - if in_type1 == in_type2: - return in_type1 - - if in_type1 == NodeVarType.UNKNOWN: - return in_type2 - if in_type2 == NodeVarType.UNKNOWN: - return in_type1 - - supported_types = [ - NodeVarType.BOOLEAN, - NodeVarType.INT, - NodeVarType.FLOAT, - NodeVarType.NUMPY_NDARRAY, - NodeVarType.TENSOR, - NodeVarType.PADDLE_RETURN_TYPES, - ] - - if in_type1 not in supported_types: - return NodeVarType.UNKNOWN - if in_type2 not in supported_types: - return NodeVarType.UNKNOWN - - forbidden_types = [NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR] - if in_type1 in forbidden_types and in_type2 in forbidden_types: - return NodeVarType.UNKNOWN - return max(in_type1, in_type2) - - @staticmethod - def type_from_annotation(annotation): - annotation_str = ast_to_source_code(annotation).strip() - if annotation_str in NodeVarType.Annotation_map: - return NodeVarType.Annotation_map[annotation_str] - - # raise warning if not found - warn("Currently we don't support annotation: %s" % annotation_str) - return NodeVarType.UNKNOWN +def binary_op_output_type(in_type1, in_type2): + if in_type1 == in_type2: + return in_type1 + + if in_type1 == "UNKNOWN": + return in_type2 + if in_type2 == "UNKNOWN": + return in_type1 + + supported_types = [ + "BOOLEAN", + "INT", + "FLOAT", + "NUMPY_NDARRAY", + "TENSOR", + "PADDLE_RETURN_TYPES", + ] + + if in_type1 not in supported_types: + return "UNKNOWN" + if in_type2 not in supported_types: + return "UNKNOWN" + + forbidden_types = ["NUMPY_NDARRAY", "TENSOR"] + if in_type1 in forbidden_types and in_type2 in forbidden_types: + return "UNKNOWN" + return max(in_type1, in_type2) + + +Annotation_map = { + "Tensor": "TENSOR", + "paddle.Tensor": "TENSOR", + "int": "INT", + "float": "FLOAT", + "bool": "BOOLEAN", + "str": "STRING", +} + + +def type_from_annotation(annotation): + annotation_str = ast_to_source_code(annotation).strip() + if annotation_str in Annotation_map: + return Annotation_map[annotation_str] + + # raise warning if not found + warn("Currently we don't support annotation: %s" % annotation_str) + return "UNKNOWN" def set_dynamic_shape(variable, shape_list): diff --git a/test/dygraph_to_static/test_static_analysis.py b/test/dygraph_to_static/test_static_analysis.py index ea44992a048449..889bf183d079c0 100644 --- a/test/dygraph_to_static/test_static_analysis.py +++ b/test/dygraph_to_static/test_static_analysis.py @@ -19,7 +19,7 @@ import paddle from paddle import base -from paddle.jit.dy2static import NodeVarType, StaticAnalysisVisitor +from paddle.jit.dy2static import StaticAnalysisVisitor from paddle.utils import gast @@ -42,7 +42,7 @@ def func_to_test2(x): return x -result_var_type2 = {'m': {NodeVarType.INT}} +result_var_type2 = {'m': {"INT"}} def func_to_test3(): @@ -61,18 +61,18 @@ def func_to_test3(): result_var_type3 = { - 'a': {NodeVarType.INT}, - 'b': {NodeVarType.FLOAT}, - 'c': {NodeVarType.FLOAT}, - 'd': {NodeVarType.FLOAT}, - 'e': {NodeVarType.BOOLEAN}, - 'f': {NodeVarType.INT}, - 'g': {NodeVarType.STRING}, - 'h': {NodeVarType.NONE}, - 'i': {NodeVarType.BOOLEAN}, - 'j': {NodeVarType.UNKNOWN}, - 'k': {NodeVarType.FLOAT}, - 'l': {NodeVarType.PADDLE_RETURN_TYPES}, + 'a': {"INT"}, + 'b': {"FLOAT"}, + 'c': {"FLOAT"}, + 'd': {"FLOAT"}, + 'e': {"BOOLEAN"}, + 'f': {"INT"}, + 'g': {"STRING"}, + 'h': {"NONE"}, + 'i': {"BOOLEAN"}, + 'j': {"UNKNOWN"}, + 'k': {"FLOAT"}, + 'l': {"PADDLE_RETURN_TYPES"}, } @@ -85,10 +85,10 @@ def func_to_test4(): result_var_type4 = { - 'a': {NodeVarType.NUMPY_NDARRAY}, - 'b': {NodeVarType.NUMPY_NDARRAY}, - 'c': {NodeVarType.TENSOR}, - 'd': {NodeVarType.TENSOR}, + 'a': {"NUMPY_NDARRAY"}, + 'b': {"NUMPY_NDARRAY"}, + 'c': {"TENSOR"}, + 'd': {"TENSOR"}, } @@ -112,13 +112,13 @@ def inner_unknown_func(x): result_var_type5 = { - 'a': {NodeVarType.INT}, - 'b': {NodeVarType.FLOAT, NodeVarType.BOOLEAN}, - 'c': {NodeVarType.UNKNOWN}, - 'd': {NodeVarType.PADDLE_RETURN_TYPES}, - 'inner_int_func': {NodeVarType.INT}, - 'inner_bool_float_func': {NodeVarType.FLOAT, NodeVarType.BOOLEAN}, - 'inner_unknown_func': {NodeVarType.UNKNOWN}, + 'a': {"INT"}, + 'b': {"FLOAT", "BOOLEAN"}, + 'c': {"UNKNOWN"}, + 'd': {"PADDLE_RETURN_TYPES"}, + 'inner_int_func': {"INT"}, + 'inner_bool_float_func': {"FLOAT", "BOOLEAN"}, + 'inner_unknown_func': {"UNKNOWN"}, } @@ -136,10 +136,10 @@ def add(x, y): result_var_type6 = { - 'i': {NodeVarType.INT}, - 'x': {NodeVarType.INT}, - 'y': {NodeVarType.INT}, - 'add': {NodeVarType.INT}, + 'i': {"INT"}, + 'x': {"INT"}, + 'y': {"INT"}, + 'add': {"INT"}, } @@ -150,13 +150,13 @@ def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float = 'diff'): result_var_type7 = { - 'a': {NodeVarType.BOOLEAN}, - 'b': {NodeVarType.FLOAT}, - 'c': {NodeVarType.TENSOR}, - 'd': {NodeVarType.STRING}, - 'e': {NodeVarType.PADDLE_RETURN_TYPES}, - 'f': {NodeVarType.PADDLE_RETURN_TYPES}, - 'g': {NodeVarType.TENSOR}, + 'a': {"BOOLEAN"}, + 'b': {"FLOAT"}, + 'c': {"TENSOR"}, + 'd': {"STRING"}, + 'e': {"PADDLE_RETURN_TYPES"}, + 'f': {"PADDLE_RETURN_TYPES"}, + 'g': {"TENSOR"}, } test_funcs = [ From 76ce9bb2de84ee6ea052cfde6ce269be1e4d8baf Mon Sep 17 00:00:00 2001 From: pangengzheng <117730991+pangengzheng@users.noreply.github.com> Date: Thu, 28 Dec 2023 16:52:06 +0800 Subject: [PATCH 033/142] support save load optimizer master_weights (#60027) * exclude xpu * dedup tensor in state_dict * polish * support flatten and unflatten state_dict * test flatten * rename test * fix dedup tensor test * fix test * fix load state dict * rename * fix test * support save load optimizer master weights * add comment --- .../distributed/checkpoint/load_state_dict.py | 36 +++--- .../paddle/distributed/checkpoint/metadata.py | 1 + .../distributed/checkpoint/save_state_dict.py | 57 ++++++++-- python/paddle/distributed/checkpoint/utils.py | 44 +++++++- python/paddle/optimizer/optimizer.py | 30 +---- test/auto_parallel/CMakeLists.txt | 3 + ...i_auto_parallel_checkpoint_dedup_tensor.py | 68 ++++++++++++ ...uto_parallel_checkpoint_flatten_mapping.py | 74 ++++++++++++ .../semi_auto_parallel_shard_optimizer_api.py | 55 +++++++++ .../test_dist_checkpoint_utils.py | 105 ++++++++++++++++++ 10 files changed, 416 insertions(+), 57 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py create mode 100644 test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py create mode 100644 test/auto_parallel/test_dist_checkpoint_utils.py diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index fda6b6f9174b5f..4ae82398713aee 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -405,9 +405,9 @@ def load_state_dict( assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." - state_dict = flatten_state_dict(state_dict) - if len(state_dict) > 0: - for val in state_dict.values(): + flat_state_dict, mapping = flatten_state_dict(state_dict) + if len(flat_state_dict) > 0: + for val in flat_state_dict.values(): assert isinstance( val, paddle.Tensor ), f"Only support dygraph Tensor now, but is {val}" @@ -423,7 +423,7 @@ def load_state_dict( paddle.distributed.barrier(process_group) rank_to_files = get_rank_to_files( - path, state_dict, process_group, use_dist + path, flat_state_dict, process_group, use_dist ) if len(rank_to_files) <= 0: return @@ -434,16 +434,18 @@ def load_state_dict( ) # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. - read_items = get_read_items(path, state_dict, process_group, use_dist) + read_items = get_read_items( + path, flat_state_dict, process_group, use_dist + ) storage_file_to_state_dict = {} logger.debug( - f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" + f"before load, state_dict:{flat_state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" ) state_dict_in_cpu = [] - for k, v in state_dict.items(): + for k, v in flat_state_dict.items(): if v.place.is_cpu_place(): state_dict_in_cpu.append(k) - state_dict[k] = v.cuda() + flat_state_dict[k] = v.cuda() for item in read_items: assert ( item.local_tensor_index in load_infos @@ -484,15 +486,17 @@ def load_state_dict( # The read item rank need to be assigned if item.rank == paddle.distributed.get_rank(): assert ( - item.local_tensor_index.tensor_key in state_dict - ), f"item:{item}, state_dict:{state_dict}" + item.local_tensor_index.tensor_key in flat_state_dict + ), f"item:{item}, state_dict:{flat_state_dict}" cur_local_tensor = ( - state_dict[ + flat_state_dict[ item.local_tensor_index.tensor_key ]._local_value() if use_dist - and state_dict[item.local_tensor_index.tensor_key].is_dist() - else state_dict[item.local_tensor_index.tensor_key] + and flat_state_dict[ + item.local_tensor_index.tensor_key + ].is_dist() + else flat_state_dict[item.local_tensor_index.tensor_key] ) cur_offsets = item.cur_offset cur_lengths = item.lengths @@ -513,7 +517,9 @@ def load_state_dict( else: cur_chunk_tensor = paddle.zeros( item.lengths, - dtype=state_dict[item.local_tensor_index.tensor_key].dtype, + dtype=flat_state_dict[ + item.local_tensor_index.tensor_key + ].dtype, ) if src_rank == item.rank: @@ -530,6 +536,6 @@ def load_state_dict( cur_chunk_tensor, src=src_rank, group=process_group ) - for k, v in state_dict.items(): + for k, v in flat_state_dict.items(): if k in state_dict_in_cpu: state_dict[k] = v.cpu() diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py index 4eb5d559a9c0c4..d1f3a3fdb66c07 100644 --- a/python/paddle/distributed/checkpoint/metadata.py +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -40,3 +40,4 @@ class LocalTensorIndex: class Metadata: state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None storage_metadata: Dict[LocalTensorIndex, str] = None + flat_mapping: Dict[str, Tuple[str]] = None diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index b2c380c66ba2f4..86047e637e3609 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -13,7 +13,6 @@ # limitations under the License. import os -from typing import List import paddle from paddle.distributed.communication.group import is_initialized @@ -50,7 +49,7 @@ def check_file_name(file_name, process_group): def merge_state_dict_metadata(global_state_dict_metadata): assert isinstance( - global_state_dict_metadata, List + global_state_dict_metadata, list ), "The global_state_dict should be a list." out = {} for state_dict in global_state_dict_metadata: @@ -64,7 +63,7 @@ def merge_state_dict_metadata(global_state_dict_metadata): return out -def dedup_storage_metadata(global_storage_metadata): +def dedup_key_in_dict(global_storage_metadata): out = {} for storage_metadata in global_storage_metadata: for key, val in storage_metadata.items(): @@ -74,6 +73,34 @@ def dedup_storage_metadata(global_storage_metadata): return out +def dedup_tensor( + local_state_dict, local_storage_metadata, global_storage_metadata +): + """ + Dedup the replicated tensor in local state_dict. + + Args: + local_state_dict(Dict[str, paddle.Tensor]): The state_dict of current rank. + local_storage_metadata(Dict[LocalTensorIndex, str]): The storage metadata of current rank. + global_storage_metadata(Dict[LocalTensorIndex, str]): The final storage metadata of all ranks. + + Examples: + In rank0, local_state_dict:{"w1": t1_0, "w2": t2}, local_storage_metadata:{LocalTensorIndex("w1", (0,0)): "0_0.distcp", LocalTensorIndex("w2", (0,0)): "0_0.distcp"}, + in rank1, local_state_dict:{"w1": t1_1, "w2": t2}, local_storage_metadata:{LocalTensorIndex("w1", (1,0)): "1_0.distcp", LocalTensorIndex("w2", (0,0)): "1_0.distcp"}, + global_storage_metadata:{LocalTensorIndex("w1", (0,0)): "0_0.distcp", LocalTensorIndex("w1", (1,0)): "1_0.distcp", LocalTensorIndex("w2", (0, 0)): "0_0.distcp"}. + w2 is replicated in rank0 and rank1. We save it in rank0 as default thus need to remove it in other ranks. + Finally, the local_state_dict:{"w1": t1_1, "w2": t2} in rank1 update to {"w1": t1_1}. + """ + + for tensor_index, file_name in global_storage_metadata.items(): + rank = int(file_name.split(".")[0].split("_")[0]) + if ( + tensor_index in local_storage_metadata + and rank != paddle.distributed.get_rank() + ): + local_state_dict.pop(tensor_index.tensor_key) + + def save_state_dict( state_dict, path, @@ -107,9 +134,9 @@ def save_state_dict( assert isinstance( state_dict, dict ), "The state_dict should be a dictionary." - state_dict = flatten_state_dict(state_dict) - if len(state_dict) > 0: - for val in state_dict.values(): + flat_state_dict, mapping = flatten_state_dict(state_dict) + if len(flat_state_dict) > 0: + for val in flat_state_dict.values(): assert isinstance( val, paddle.Tensor ), "Only support dygraph Tensor now, support static DistributedTensor later" @@ -134,12 +161,12 @@ def save_state_dict( if use_dist: check_file_name(file_name, process_group) # the parameter_name and order in state_dict should be the same - check_state_dict(state_dict, process_group) + check_state_dict(flat_state_dict, process_group) metadata = Metadata() local_state_dict = {} local_state_dict_metadata = {} local_storage_metadata = {} - for key, val in state_dict.items(): + for key, val in flat_state_dict.items(): if isinstance(val, paddle.Tensor): # Case1: not initialized means this tensor is placed in another mesh which do not contain this rank if not val._is_initialized(): @@ -178,6 +205,7 @@ def save_state_dict( ] = file_name global_state_dict_metadata = [] global_storage_metadata = [] + global_flatten_mapping = [] if use_dist: paddle.distributed.all_gather_object( global_state_dict_metadata, @@ -187,19 +215,24 @@ def save_state_dict( paddle.distributed.all_gather_object( global_storage_metadata, local_storage_metadata, process_group ) + paddle.distributed.all_gather_object( + global_flatten_mapping, mapping, process_group + ) else: global_state_dict_metadata.append(local_state_dict_metadata) global_storage_metadata.append(local_storage_metadata) + global_flatten_mapping.append(mapping) metadata.state_dict_metadata = merge_state_dict_metadata( global_state_dict_metadata ) - metadata.storage_metadata = dedup_storage_metadata( - global_storage_metadata - ) + metadata.storage_metadata = dedup_key_in_dict(global_storage_metadata) + metadata.flat_mapping = dedup_key_in_dict(global_flatten_mapping) if coordinator_rank == paddle.distributed.get_rank(): logger.debug(f"metadata:{metadata}") paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) logger.debug(f"local_state_dict:{local_state_dict}") - # TODO(pangengzheng): del the replicated tensor in local_state_dict, now different might save the replicated tensor + dedup_tensor( + local_state_dict, local_storage_metadata, metadata.storage_metadata + ) paddle.save(local_state_dict, os.path.join(path, file_name)) diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py index cb0f069984c3a2..d592d6ebcb97b7 100644 --- a/python/paddle/distributed/checkpoint/utils.py +++ b/python/paddle/distributed/checkpoint/utils.py @@ -63,5 +63,47 @@ def compute_local_shape_and_global_offset( def flatten_state_dict(state_dict): - # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} + """ + Flatten the nested dict to a flat dict. + {"model": {"w0": xxx}} -> {model.w0: xxx} + """ + flatten_state_dict = {} + mapping = {} + + def _flatten(key, value): + if isinstance(value, dict): + for k, v in value.items(): + assert isinstance(k, str), f"The key should be str, but is {k}" + _flatten(key + (k,), v) + elif isinstance(value, paddle.Tensor): + flatten_key_str = ".".join(key) + flatten_state_dict[flatten_key_str] = value + mapping[flatten_key_str] = key + else: + raise ValueError( + f"The value should be dict or paddle.Tensor, but is {value}" + ) + + _flatten((), state_dict) + + return flatten_state_dict, mapping + + +def unflatten_state_dict(flat_state_dict, mapping): + """ + Unflatten the flat dict to a nested dict. + {model.w0: xxx} -> {"model": {"w0": xxx}} + """ + state_dict = {} + for key, value in flat_state_dict.items(): + key_tuple = mapping[key] + assert isinstance( + key_tuple, tuple + ), f"The key should be tuple, but is {key_tuple}" + tmp = state_dict + for i in range(len(key_tuple) - 1): + key = key_tuple[i] + tmp = tmp.setdefault(key, {}) + tmp[key_tuple[-1]] = value + return state_dict diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 3a64f2095f30a7..134b164409a95f 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -406,35 +406,7 @@ def set_state_dict(self, state_dict): tensor.set_xpu_scale_value( state_dict.get(var_tmp.name + ".SCALE_VALUE", -1.0) ) - - model_np = np.array(tensor) - - load_para = state_dict[var_tmp.name] - - if isinstance(load_para, Variable): - load_para_np = np.array(load_para) - elif isinstance(load_para, core.eager.Tensor): - load_para_np = np.array(load_para) - elif isinstance(load_para, np.ndarray): - load_para_np = load_para - else: - raise RuntimeError( - f"State dict type {str(type(load_para))} not supprt" - ) - - assert ( - model_np.shape == load_para_np.shape - ), "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format( - model_np.name, model_np.shape, load_para_np.shape - ) - - assert ( - model_np.dtype == load_para_np.dtype - ), "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( - model_np.name, model_np.dtype, load_para_np.dtype - ) - - tensor.set(load_para_np, framework._current_expected_place()) + var.set_value(state_dict[var_tmp.name]) def get_opti_var_name_list(self): return self._opti_name_list diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 774dc3d2023b93..a735762cce6581 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -194,6 +194,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_gpt_with_prim MODULES test_gpt_with_prim) set_tests_properties(test_gpt_with_prim PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 200) + py_test_modules(test_dist_checkpoint_utils MODULES test_dist_checkpoint_utils) + set_tests_properties(test_dist_checkpoint_utils + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_semi_auto_parallel_unshard_dtensor MODULES test_semi_auto_parallel_unshard_dtensor) set_tests_properties(test_semi_auto_parallel_unshard_dtensor diff --git a/test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py b/test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py new file mode 100644 index 00000000000000..7f8884156aa7ef --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_checkpoint_dedup_tensor.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestSaveStateDict: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + + def test_dedup_tesnor(self): + w1 = paddle.arange(32).reshape([4, 8]) + w2 = paddle.arange(32, 36).reshape([2, 2]) + mesh = dist.ProcessMesh([0, 1]) + dist_w1 = dist.shard_tensor(w1, mesh, [dist.Replicate()]) + dist_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)]) + state_dict = {"w1": dist_w1, "w2": dist_w2} + # w1 is replicated in rank0 and ran1, it will only save in rank0. + # Therefore, rank0 save state_dict:{"w1": dist_w1, "w2": dist_w2}, rank1 save state_dict:{"w2": dist_w2} + dist.save_state_dict(state_dict, self._ckpt_path) + paddle.distributed.barrier() + # check + expect_local_state_dict = {} + for k, v in state_dict.items(): + if k == "w1" and paddle.distributed.get_rank() != 0: + continue + expect_local_state_dict[k] = v._local_value() + data_file_path = os.path.join( + self._ckpt_path, f"{paddle.distributed.get_rank()}_0.distcp" + ) + metadata_file_path = os.path.join(self._ckpt_path, "0.metadata") + assert os.path.exists(data_file_path) and os.path.exists( + metadata_file_path + ) + local_state_dict = paddle.load(data_file_path) + metadata = paddle.load(metadata_file_path) + + for k, local_tensor in local_state_dict.items(): + assert k in expect_local_state_dict + expect_tensor = expect_local_state_dict[k] + np.testing.assert_equal(expect_tensor.numpy(), local_tensor.numpy()) + for tensor_index, file_name in metadata.storage_metadata.items(): + rank = int(file_name.split(".")[0].split("_")[0]) + if tensor_index.tensor_key == "w1": + assert rank == 0 + + def run_test_case(self): + self.test_dedup_tesnor() + + +if __name__ == '__main__': + TestSaveStateDict().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py b/test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py new file mode 100644 index 00000000000000..c8cfdb22d85987 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_checkpoint_flatten_mapping.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle +import paddle.distributed as dist + + +class TestSemiautoSaveLoad: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + + def test_flatten_mapping(self): + if paddle.distributed.get_rank() == 0: + state_dict = { + "model": { + "a": paddle.to_tensor([1, 2]), + "b": paddle.to_tensor([3, 4]), + }, + "optimizer": { + "c": paddle.to_tensor([5, 6]), + "d": paddle.to_tensor([7, 8]), + }, + } + else: + state_dict = { + "model": { + "a": paddle.to_tensor([10, 20]), + "b": paddle.to_tensor([30, 40]), + }, + "optimizer": { + "c": paddle.to_tensor([50, 60]), + "d": paddle.to_tensor([70, 80]), + }, + } + expected_mapping = { + "model.a": ("model", "a"), + "model.b": ("model", "b"), + "optimizer.c": ("optimizer", "c"), + "optimizer.d": ("optimizer", "d"), + } + dist.save_state_dict(state_dict, self._ckpt_path) + metadata_path = os.path.join(self._ckpt_path, "0.metadata") + assert os.path.exists(metadata_path) + metadata = paddle.load(metadata_path) + assert len(metadata.flat_mapping) == len( + expected_mapping + ), f"expect {len(expected_mapping)}, but got {len(metadata.flat_mapping)}" + for key in metadata.flat_mapping: + assert ( + key in expected_mapping + ), f"expect {key} in flatten_mapping, but not found" + assert ( + metadata.flat_mapping[key] == expected_mapping[key] + ), f"expect {metadata.flat_mapping[key]} == {expected_mapping[key]}, but not equal" + + def run_test_case(self): + self.test_flatten_mapping() + + +if __name__ == '__main__': + TestSemiautoSaveLoad().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py index f4d22a16c41bd0..0153d3bd21216e 100644 --- a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py +++ b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py @@ -179,6 +179,61 @@ def test_shard_optimizer_master_params(self): assert v.is_dist() assert v.shape[-1] == v._local_shape[-1] * 2 + # save load + ckpt_state_dict = opt.state_dict() + dist.save_state_dict(ckpt_state_dict, self._ckpt_path) + paddle.distributed.barrier() + expected_local_state_dict = {} + expected_local_state_dict.setdefault("master_weights", {}) + need_load_state_dict = {} + need_load_state_dict.setdefault("master_weights", {}) + for k, v in ckpt_state_dict.items(): + if k == "LR_Scheduler": + continue + elif k == "master_weights": + assert isinstance(v, dict), v + for mk, mv in v.items(): + expected_local_state_dict[k][mk] = mv._local_value().clone() + need_load_state_dict[k][mk] = paddle.zeros_like(mv) + else: + expected_local_state_dict[k] = v._local_value().clone() + need_load_state_dict[k] = paddle.zeros_like(v) + opt.set_state_dict(need_load_state_dict) + after_set_state_dict = opt.state_dict() + for k, v in after_set_state_dict.items(): + if k == "master_weights": + assert isinstance(v, dict), v + for mk, mv in v.items(): + assert ( + mv.numpy().sum() == 0.0 + ), f"state_dict {k} in master_weights is not zero" + assert ( + need_load_state_dict[k][mk].numpy().sum() == 0.0 + ), f"state_dict {k} in master_weights is not zero" + else: + assert v.numpy().sum() == 0.0, f"state_dict {k} is not zero" + assert k in need_load_state_dict, f"state_dict {k} is not found" + assert ( + need_load_state_dict[k].numpy().sum() == 0.0 + ), f"state_dict {k} is not zero" + dist.load_state_dict(need_load_state_dict, self._ckpt_path) + opt.set_state_dict(need_load_state_dict) + new_state_dict = opt.state_dict() + assert "master_weights" in new_state_dict, new_state_dict + for k, v in new_state_dict.items(): + assert k in expected_local_state_dict + if k == "master_weights": + for mk, mv in v.items(): + np.testing.assert_equal( + mv._local_value().numpy(), + expected_local_state_dict[k][mk].numpy(), + ) + else: + np.testing.assert_equal( + v._local_value().numpy(), + expected_local_state_dict[k].numpy(), + ) + def test_shard_optimizer_params_group(self): paddle.seed(self._seed) linear = paddle.nn.Linear(10, 10) diff --git a/test/auto_parallel/test_dist_checkpoint_utils.py b/test/auto_parallel/test_dist_checkpoint_utils.py new file mode 100644 index 00000000000000..5a51f73f0fa56c --- /dev/null +++ b/test/auto_parallel/test_dist_checkpoint_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import collective.test_communication_api_base as test_base +import numpy as np + +import paddle +from paddle.distributed.checkpoint.utils import ( + flatten_state_dict, + unflatten_state_dict, +) + + +class TestDistCheckpointUtils(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=120, nnode=1) + self._default_envs = {} + self._changeable_envs = {"backend": ["gpu"]} + + def test_flatten_mapping(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + ckpt_path_tmp = tempfile.TemporaryDirectory() + ckpt_path = ckpt_path_tmp.name + envs["ckpt_path"] = ckpt_path + self.run_test_case( + "semi_auto_parallel_checkpoint_flatten_mapping.py", + user_defined_envs=envs, + ) + ckpt_path_tmp.cleanup() + + def test_dedup_tensor(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + ckpt_path_tmp = tempfile.TemporaryDirectory() + ckpt_path = ckpt_path_tmp.name + envs["ckpt_path"] = ckpt_path + self.run_test_case( + "semi_auto_parallel_checkpoint_dedup_tensor.py", + user_defined_envs=envs, + ) + ckpt_path_tmp.cleanup() + + def test_flatten_state_dict(self): + state_dict = { + "model": { + "a.0": paddle.to_tensor([1, 2]), + "b": paddle.to_tensor([3, 4]), + }, + "optimizer": { + "c": paddle.to_tensor([5, 6]), + "d.2": paddle.to_tensor([7, 8]), + }, + } + expected_flat_state_dict = { + "model.a.0": paddle.to_tensor([1, 2]), + "model.b": paddle.to_tensor([3, 4]), + "optimizer.c": paddle.to_tensor([5, 6]), + "optimizer.d.2": paddle.to_tensor([7, 8]), + } + flat_state_dict, mapping = flatten_state_dict(state_dict) + self.assertTrue(len(expected_flat_state_dict) == len(flat_state_dict)) + for k, v in flat_state_dict.items(): + self.assertTrue(isinstance(v, paddle.Tensor)) + self.assertTrue(k in expected_flat_state_dict) + np.testing.assert_equal( + v.numpy(), expected_flat_state_dict[k].numpy() + ) + recover_state_dict = unflatten_state_dict(flat_state_dict, mapping) + + def check_state_dict(d1, d2): + self.assertTrue(len(d1) == len(d2)) + self.assertTrue(type(d1) == type(d2)) + if isinstance(d1, dict): + for k in d1: + self.assertTrue(k in d2) + check_state_dict(d1[k], d2[k]) + elif isinstance(d1, paddle.Tensor): + np.testing.assert_equal(d1.numpy(), d2.numpy()) + else: + raise ValueError(f"Invalid type of state_dict:{d1} != {d2}") + + check_state_dict(recover_state_dict, state_dict) + + +if __name__ == "__main__": + unittest.main() From c1d78603ca6d818aec775733521e04db9c145716 Mon Sep 17 00:00:00 2001 From: zyt1024 <42999008+zyt1024@users.noreply.github.com> Date: Thu, 28 Dec 2023 17:14:26 +0800 Subject: [PATCH 034/142] =?UTF-8?q?=E3=80=90Complex=20op=E3=80=91add=20com?= =?UTF-8?q?plex=20support=20for=20assign=5Fvalue=20=20(#59536)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support_complex_for_assign_value * add test complex test for test_program_converter * add complex test for assign_value xpu * solve conflict * fix timeout * fix CE infer bug * fix program convert bug * fix program convert bug for assign_value --------- Co-authored-by: zyt1024 <1522064645@qq.com> --- paddle/fluid/framework/op_version_proto.cc | 1 + paddle/fluid/framework/program_converter.cc | 85 +++++++- .../ir_adaptor/translator/op_translator.cc | 17 +- .../ops_signature/assign_value_sig.cc | 26 +-- .../pir/dialect/operator/ir/op_attribute.cc | 4 + .../pir/dialect/operator/ir/op_attribute.h | 4 +- .../fluid/pir/dialect/operator/utils/utils.h | 6 + paddle/fluid/pybind/op_function_common.cc | 21 +- paddle/phi/api/yaml/op_version.yaml | 15 ++ paddle/phi/api/yaml/static_ops.yaml | 2 +- paddle/phi/kernels/assign_kernel.cc | 12 +- paddle/pir/core/builder.h | 5 + paddle/pir/core/builtin_attribute.cc | 10 + paddle/pir/core/builtin_attribute.h | 23 +++ paddle/pir/core/builtin_attribute_storage.h | 40 ++++ paddle/pir/core/builtin_dialect.cc | 4 +- paddle/pir/core/ir_printer.cc | 4 + python/paddle/nn/initializer/Bilinear.py | 2 +- python/paddle/nn/initializer/assign.py | 6 +- python/paddle/nn/initializer/dirac.py | 4 +- python/paddle/tensor/creation.py | 58 ++---- .../test_program_translator.py | 14 +- test/ir/inference/CMakeLists.txt | 4 +- test/ir/inference/test_mul_gru_fuse_pass.py | 2 +- test/ir/inference/test_mul_lstm_fuse_pass.py | 2 +- .../inference/test_seq_concat_fc_fuse_pass.py | 4 +- test/legacy_test/test_assign_value_op.py | 101 +++++++-- test/legacy_test/test_initializer.py | 3 +- test/legacy_test/test_initializer_nn.py | 4 +- test/legacy_test/test_program_converter.py | 193 ++++++++++++++++++ test/xpu/test_assign_value_op_xpu.py | 61 +++++- 31 files changed, 614 insertions(+), 123 deletions(-) diff --git a/paddle/fluid/framework/op_version_proto.cc b/paddle/fluid/framework/op_version_proto.cc index 2a93e755b085bf..8be9323098c971 100644 --- a/paddle/fluid/framework/op_version_proto.cc +++ b/paddle/fluid/framework/op_version_proto.cc @@ -21,6 +21,7 @@ namespace pb { const std::unordered_map& GetLegacyOpVersions() { static std::unordered_map op_versions = { {"not_equal", 1}, + {"assign_value", 0}, {"fake_channel_wise_dequantize_max_abs", 2}, {"yolo_box", 1}, {"data_norm", 1}, diff --git a/paddle/fluid/framework/program_converter.cc b/paddle/fluid/framework/program_converter.cc index fc60a0abf676e1..82739e788bba36 100644 --- a/paddle/fluid/framework/program_converter.cc +++ b/paddle/fluid/framework/program_converter.cc @@ -117,6 +117,41 @@ void ConvertSetValueOp(OpDesc* op) { } } +void ConvertAssignValueOp(OpDesc* op) { + std::vector values = PADDLE_GET_CONST( + std::vector, op->GetAttr("values", false)); + op->RemoveAttr("values"); + op->SetAttr("bool_values", std::vector()); + op->SetAttr("fp32_values", std::vector()); + op->SetAttr("int32_values", std::vector()); + op->SetAttr("int64_values", std::vector()); + + phi::DataType dtype = phi::DataType::FLOAT32; + if (values.size()) { + dtype = values.at(0).dtype(); + } + + switch (dtype) { + case phi::DataType::BOOL: + op->SetAttr("bool_values", ExtractPlainVector(values)); + break; + case phi::DataType::FLOAT32: + op->SetAttr("fp32_values", ExtractPlainVector(values)); + break; + case phi::DataType::FLOAT64: + op->SetAttr("fp32_values", ExtractPlainVector(values)); + break; + case phi::DataType::INT32: + op->SetAttr("int32_values", ExtractPlainVector(values)); + break; + case phi::DataType::INT64: + op->SetAttr("int64_values", ExtractPlainVector(values)); + break; + default: + PD_THROW("Invalid data type `", dtype, "`."); + } +} + void ConvertProgram(ProgramDesc* program) { PADDLE_ENFORCE_NOT_NULL( program, @@ -144,6 +179,9 @@ void ConvertProgram(ProgramDesc* program) { if (op_type == "set_value" || op_type == "set_value_grad") { ConvertSetValueOp(op); } + if (op_type == "assign_value") { + ConvertAssignValueOp(op); + } } } } @@ -204,6 +242,45 @@ void ConvertSetValueOp(OpDesc* op) { op->SetAttr("values", values); } +void ConvertAssignValueOp(OpDesc* op) { + VLOG(3) << "convert old assign value op to new"; + std::vector values; + + if (op->HasAttr("bool_values")) { + std::vector bool_values = + PADDLE_GET_CONST(std::vector, op->GetAttr("bool_values", false)); + if (bool_values.size()) { + values = WrapAsScalars(bool_values); + } + op->RemoveAttr("bool_values"); + } + if (op->HasAttr("fp32_values")) { + std::vector fp32_values = + PADDLE_GET_CONST(std::vector, op->GetAttr("fp32_values", false)); + if (fp32_values.size()) { + values = WrapAsScalars(fp32_values); + } + op->RemoveAttr("fp32_values"); + } + if (op->HasAttr("int32_values")) { + std::vector int32_values = + PADDLE_GET_CONST(std::vector, op->GetAttr("int32_values", false)); + if (int32_values.size()) { + values = WrapAsScalars(int32_values); + } + op->RemoveAttr("int32_values"); + } + if (op->HasAttr("int64_values")) { + std::vector int64_values = PADDLE_GET_CONST( + std::vector, op->GetAttr("int64_values", false)); + if (int64_values.size()) { + values = WrapAsScalars(int64_values); + } + op->RemoveAttr("int64_values"); + } + op->SetAttr("values", values); +} + void ConvertProgram(ProgramDesc* program) { PADDLE_ENFORCE_NOT_NULL( program, @@ -214,6 +291,7 @@ void ConvertProgram(ProgramDesc* program) { const std::unordered_map& legacy_op_versions = legacy_op_results.second; + VLOG(3) << "is_legacy_program : " << is_legacy_program; if (!is_legacy_program) return; VLOG(3) << "Updating Program Version and OpVersionMap"; @@ -232,10 +310,15 @@ void ConvertProgram(ProgramDesc* program) { for (size_t j = 0; j < num_ops; j++) { OpDesc* op = block->Op(static_cast(j)); const std::string op_type = op->Type(); + + if (op_type == "assign_value") { + VLOG(3) << "Converting program from old to new, op_type=" << op_type; + ConvertAssignValueOp(op); + } if (!legacy_op_versions.count(op_type)) { continue; } - + VLOG(3) << "Converting program from old to new, op_type=" << op_type; if (op_type == "set_value" || op_type == "set_value_grad") { ConvertSetValueOp(op); } diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 626073d143e3e3..c64004c7191dd9 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -972,19 +972,20 @@ struct AssignValueOpTranscriber : public OpTranscriber { ctx, phi::Place(phi::AllocationType::UNDEFINED)); attribute_map["place"] = attr_place; - int dtype = paddle::get(op_desc.GetAttr("dtype")); - - if (dtype == /*BOOL*/ 0) { + if (op_desc.HasAttr("bool_values")) { legacy_attr = op_desc.GetAttr("bool_values"); - } else if (dtype == /*INT32*/ 2) { - legacy_attr = op_desc.GetAttr("int32_values"); - } else if (dtype == /*FP32*/ 5) { + } else if (op_desc.HasAttr("fp32_values")) { legacy_attr = op_desc.GetAttr("fp32_values"); - } else if (dtype == /*INT64*/ 3) { + } else if (op_desc.HasAttr("int32_values")) { + legacy_attr = op_desc.GetAttr("int32_values"); + } else if (op_desc.HasAttr("int64_values")) { legacy_attr = op_desc.GetAttr("int64_values"); + } else if (op_desc.HasAttr("values")) { + legacy_attr = op_desc.GetAttr("values"); } else { IR_THROW( - "Op assign_value should have attribute `**_values` but not find"); + "Op assign_value should have attribute `**_values` or `values` but " + "not find"); } pir::Attribute attr_values = attribute_translator( diff --git a/paddle/fluid/operators/ops_signature/assign_value_sig.cc b/paddle/fluid/operators/ops_signature/assign_value_sig.cc index 977c2260e59b9c..ae14c5a9d7879b 100644 --- a/paddle/fluid/operators/ops_signature/assign_value_sig.cc +++ b/paddle/fluid/operators/ops_signature/assign_value_sig.cc @@ -18,30 +18,8 @@ namespace phi { KernelSignature AssignValueOpArgumentMapping( const ArgumentMappingContext& ctx) { - // Here we must use `dtype` attr to determine which attr to use, we can't - // judge by whether the attr is empty, some unittests will failed - int dtype = paddle::any_cast(ctx.Attr("dtype")); - // heer we can't depend on the fluid proto::VarType, so we use the dtype enum - // value directly, If the enum value is updated, the code also needs to be - // updated here, but the probability of updating the enum value is very low - if (dtype == /*BOOL*/ 0) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "bool_values"}, {"Out"}); - } else if (dtype == /*INT32*/ 2) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "int32_values"}, {"Out"}); - } else if (dtype == /*FP32*/ 5) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "fp32_values"}, {"Out"}); - } else if (dtype == /*FP64*/ 6) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "fp64_values"}, {"Out"}); - } else if (dtype == /*INT64*/ 3) { - return KernelSignature( - "assign_value", {}, {"shape", "dtype", "int64_values"}, {"Out"}); - } else { - return KernelSignature("unregistered", {}, {}, {}); - } + return KernelSignature( + "assign_value", {}, {"shape", "dtype", "values"}, {"Out"}); } } // namespace phi diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc index 3134214cf9029b..10ae5a77d9f4a1 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc @@ -43,6 +43,10 @@ phi::Scalar ScalarAttribute::data() { return phi::Scalar(dyn_cast().data()); } else if (isa()) { return phi::Scalar(dyn_cast().AsString()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir attribute when casting it into " diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h index 0b0973a5205c85..f58803fa200025 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h @@ -50,7 +50,9 @@ class ScalarAttribute : public pir::Attribute { (val.type_id() == pir::Int32Attribute::type_id()) || (val.type_id() == pir::IndexAttribute::type_id()) || (val.type_id() == pir::Int64Attribute::type_id()) || - (val.type_id() == pir::StrAttribute::type_id()); + (val.type_id() == pir::StrAttribute::type_id()) || + (val.type_id() == pir::Complex64Attribute::type_id()) || + (val.type_id() == pir::Complex128Attribute::type_id()); } static pir::Attribute get(pir::IrContext *ctx, phi::Scalar scalar) { diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 0e14077bb8559d..7a8a5083a3daed 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -120,6 +120,12 @@ static inline pir::Attribute TransToIrAttribute(phi::Scalar scalar, return pir::Int64Attribute::get(ctx, scalar.to()); case phi::DataType::BOOL: return pir::BoolAttribute::get(ctx, scalar.to()); + case phi::DataType::COMPLEX64: + return pir::Complex64Attribute::get( + ctx, scalar.to>()); + case phi::DataType::COMPLEX128: + return pir::Complex128Attribute::get( + ctx, scalar.to>()); default: PADDLE_THROW(phi::errors::Unimplemented( "Unsupported phi data type `%s` when casting it into " diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 489b25f35867c8..0555724a49cfaa 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -77,7 +77,7 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) { } if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT - .find("numpy") != std::string::npos) { + .find("numpy.int") != std::string::npos) { auto to = PyNumber_Long(*obj); if (to) { *obj = to; @@ -95,8 +95,12 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; } - if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT - .find("numpy") != std::string::npos) { + auto type_name = + std::string(reinterpret_cast((*obj)->ob_type)->tp_name); + VLOG(4) << "type_name: " << type_name; + + if (type_name.find("numpy") != std::string::npos && + type_name.find("numpy.complex") == std::string::npos) { auto to = PyNumber_Float(*obj); if (to) { *obj = to; @@ -107,11 +111,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { } bool PyObject_CheckComplexOrToComplex(PyObject** obj) { - if (PyComplex_Check(*obj) || PyLong_Check(*obj) || PyFloat_Check(*obj) || + if (PyComplex_Check(*obj) || PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT PyObject_TypeCheck(*obj, p_tensor_type)) { // NOLINT return true; } + if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT + .find("numpy.complex") != std::string::npos) { + return true; + } // consider numpy cfloat & numpy cdouble? return false; } @@ -242,10 +250,15 @@ double CastPyArg2Double(PyObject* obj, phi::dtype::complex CastPyArg2Complex(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { + PyTypeObject* type = obj->ob_type; + auto type_name = std::string(type->tp_name); if (PyComplex_Check(obj)) { double real = PyComplex_RealAsDouble(obj); double imag = PyComplex_ImagAsDouble(obj); return phi::dtype::complex(real, imag); // NOLINT + } else if (type_name == "numpy.complex64") { + Py_complex v = PyComplex_AsCComplex(obj); + return phi::dtype::complex(v.real, v.imag); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index 7c9618f52b17b8..2bd09abd311aed 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -55,6 +55,21 @@ - delete_attr : atol comment : The attribute 'atol' is deleted. The reason why it is deleted is that attributes do not support a float64 value and it is changed to a tensor. +- op : assign_value + version : + - checkpoint : Upgrade assign_value, remove plain attributes in favor of generic attribute. + action : + - add_attr : values + comment : replace generic types with scalar. + default : std::vector() + - delete_attr : bool_values + comment : remove plain attributes. + - delete_attr : fp32_values + comment : remove plain attributes. + - delete_attr : int32_values + comment : remove plain attributes. + - delete_attr : int64_values + comment : remove plain attributes. - op : auc version : diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 5fe9ea4260d402..6ff2bfe427122c 100755 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -90,7 +90,7 @@ backward : assign_grad - op : assign_value - args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, double[] fp64_values = {}, int[] int32_values = {}, int64_t[] int64_values = {}) + args : (int[] shape, DataType dtype, Scalar[] values = {}) output : Tensor(out) infer_meta : func : AssignValueInferMeta diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index b4504f83818d77..f54dfec2f6ad2f 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -137,7 +137,9 @@ PD_REGISTER_KERNEL(assign_value, float, double, int8_t, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_REGISTER_KERNEL_FOR_ALL_DTYPE(assign, @@ -165,7 +167,9 @@ PD_REGISTER_KERNEL(assign_value, float, double, int8_t, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} #endif #ifdef PADDLE_WITH_XPU @@ -193,5 +197,7 @@ PD_REGISTER_KERNEL(assign_value, int, float, double, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} #endif diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index c5e3472bb070ad..158d82f3fbcbe4 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/common/complex.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/operation.h" @@ -44,6 +45,8 @@ class Int64Attribute; class ArrayAttribute; class PointerAttribute; class TensorNameAttribute; +class Complex64Attribute; +class Complex128Attribute; using InsertionPoint = std::pair; /// @@ -150,6 +153,8 @@ class Builder { IR_API ArrayAttribute array_attr(const std::vector &value); IR_API PointerAttribute pointer_attr(void *value); IR_API TensorNameAttribute tensor_name_attr(const std::string &value); + IR_API Complex64Attribute complex64_attr(phi::dtype::complex value); + IR_API Complex128Attribute complex128_attr(phi::dtype::complex value); private: Operation *Insert(Operation *op); diff --git a/paddle/pir/core/builtin_attribute.cc b/paddle/pir/core/builtin_attribute.cc index a817fb48c55fcf..32136371d5780f 100644 --- a/paddle/pir/core/builtin_attribute.cc +++ b/paddle/pir/core/builtin_attribute.cc @@ -32,6 +32,14 @@ void* PointerAttribute::data() const { return storage()->data(); } Type TypeAttribute::data() const { return storage()->data(); } +phi::dtype::complex Complex64Attribute::data() const { + return storage()->data(); +} + +phi::dtype::complex Complex128Attribute::data() const { + return storage()->data(); +} + bool StrAttribute::operator<(const StrAttribute& right) const { return storage() < right.storage(); } @@ -109,3 +117,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::PointerAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::TypeAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(pir::TensorNameAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex64Attribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex128Attribute) diff --git a/paddle/pir/core/builtin_attribute.h b/paddle/pir/core/builtin_attribute.h index a1751a8c248b80..59345c9e1b4f67 100644 --- a/paddle/pir/core/builtin_attribute.h +++ b/paddle/pir/core/builtin_attribute.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/complex.h" #include "paddle/pir/core/attribute.h" #include "paddle/pir/core/builtin_attribute_storage.h" #include "paddle/pir/core/utils.h" @@ -28,6 +29,26 @@ class IR_API BoolAttribute : public Attribute { bool data() const; }; +class IR_API Complex64Attribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Complex64Attribute, + Complex64AttributeStorage); + + phi::dtype::complex data() const; +}; + +class IR_API Complex128Attribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Complex128Attribute, + Complex128AttributeStorage); + + phi::dtype::complex data() const; +}; + class IR_API FloatAttribute : public Attribute { public: using Attribute::Attribute; @@ -157,3 +178,5 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PointerAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TypeAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TensorNameAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex64Attribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex128Attribute) diff --git a/paddle/pir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h index 533b0a4ad03e98..9e66fb6b010c9c 100644 --- a/paddle/pir/core/builtin_attribute_storage.h +++ b/paddle/pir/core/builtin_attribute_storage.h @@ -19,6 +19,7 @@ #include #include "paddle/common/enforce.h" +#include "paddle/phi/common/complex.h" #include "paddle/pir/core/attribute.h" #include "paddle/pir/core/attribute_base.h" #include "paddle/pir/core/type.h" @@ -149,4 +150,43 @@ struct ArrayAttributeStorage : public AttributeStorage { const size_t size_; }; +struct Complex64AttributeStorage : public AttributeStorage { + using ParamKey = phi::dtype::complex; + explicit Complex64AttributeStorage(const ParamKey &key) { data_ = key; } + static Complex64AttributeStorage *Construct(const ParamKey &key) { + return new Complex64AttributeStorage(key); + } + static std::size_t HashValue(const ParamKey &key) { + std::stringstream complex_str; + complex_str << key.real << "+" << key.imag << "i"; + return std::hash{}(complex_str.str()); + } + + bool operator==(ParamKey key) const { return data_ == key; } + + phi::dtype::complex data() const { return data_; } + + private: + phi::dtype::complex data_; +}; + +struct Complex128AttributeStorage : public AttributeStorage { + using ParamKey = phi::dtype::complex; + explicit Complex128AttributeStorage(const ParamKey &key) { data_ = key; } + static Complex128AttributeStorage *Construct(const ParamKey &key) { + return new Complex128AttributeStorage(key); + } + static std::size_t HashValue(const ParamKey &key) { + std::stringstream complex_str; + complex_str << key.real << "+" << key.imag << "i"; + return std::hash{}(complex_str.str()); + } + + bool operator==(ParamKey key) const { return data_ == key; } + + phi::dtype::complex data() const { return data_; } + + private: + phi::dtype::complex data_; +}; } // namespace pir diff --git a/paddle/pir/core/builtin_dialect.cc b/paddle/pir/core/builtin_dialect.cc index 4bba7185384a32..91835c3029dc76 100644 --- a/paddle/pir/core/builtin_dialect.cc +++ b/paddle/pir/core/builtin_dialect.cc @@ -50,7 +50,9 @@ void BuiltinDialect::initialize() { Int64Attribute, ArrayAttribute, TypeAttribute, - TensorNameAttribute>(); + TensorNameAttribute, + Complex64Attribute, + Complex128Attribute>(); RegisterOps()) { os << "(Pointer)" << p.data(); + } else if (auto p = attr.dyn_cast()) { + os << "(Complex64)" << p.data(); + } else if (auto p = attr.dyn_cast()) { + os << "(Complex128)" << p.data(); } else if (auto arr = attr.dyn_cast()) { const auto& vec = arr.AsVector(); os << "["; diff --git a/python/paddle/nn/initializer/Bilinear.py b/python/paddle/nn/initializer/Bilinear.py index cfb18dac02c2a8..1da82cbeee970f 100644 --- a/python/paddle/nn/initializer/Bilinear.py +++ b/python/paddle/nn/initializer/Bilinear.py @@ -148,7 +148,7 @@ def forward(self, var, block=None): out_var = var if out_dtype in (core.VarDesc.VarType.FP32, core.DataType.FLOAT32): - value_name = "fp32_values" + value_name = "values" values = [float(v) for v in weight.flat] else: raise TypeError("Unsupported dtype %s", var.dtype) diff --git a/python/paddle/nn/initializer/assign.py b/python/paddle/nn/initializer/assign.py index 9274ff5275df09..3988f9f14859d9 100644 --- a/python/paddle/nn/initializer/assign.py +++ b/python/paddle/nn/initializer/assign.py @@ -89,13 +89,13 @@ def forward(self, var, block=None): np_value = self._value if out_dtype in (core.VarDesc.VarType.FP32, core.DataType.FLOAT32): - value_name = "fp32_values" + value_name = "values" values = [float(v) for v in np_value.flat] elif out_dtype in (core.VarDesc.VarType.FP64, core.DataType.FLOAT64): - value_name = "fp64_values" + value_name = "values" values = [float(v) for v in np_value.flat] elif out_dtype in (core.VarDesc.VarType.INT32, core.DataType.INT32): - value_name = "int32_values" + value_name = "values" values = [int(v) for v in np_value.flat] elif out_dtype in ( core.VarDesc.VarType.INT8, diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py index 7da5cd15b54f7e..4aea131684f212 100644 --- a/python/paddle/nn/initializer/dirac.py +++ b/python/paddle/nn/initializer/dirac.py @@ -255,7 +255,7 @@ def __call__(self, var, block=None): attrs={ 'dtype': VarDesc.VarType.INT64, 'shape': [len(idx_list)], - 'int64_values': idx_list, + 'values': idx_list, }, stop_gradient=True, ) @@ -298,7 +298,7 @@ def __call__(self, var, block=None): attrs={ 'dtype': VarDesc.VarType.FP32, 'shape': [len(value_list)], - 'fp32_values': value_list, + 'values': value_list, }, stop_gradient=True, ) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 1fb067edcbb6e1..5fbf1f0fbc468c 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -16,7 +16,6 @@ import math import re -import warnings import numpy as np @@ -2361,6 +2360,8 @@ def assign(x, output=None): 'uint8', 'int8', 'bool', + 'complex64', + 'complex128', ], 'assign', '(When the type of input in assign is Variable.)', @@ -2408,44 +2409,23 @@ def convert_scalar(x): ) dtype = convert_np_dtype_to_dtype_(input.dtype) - if dtype == core.VarDesc.VarType.FP64: - # Setting FP64 numpy data is not supported in Paddle, so we - # use FP32 here - warnings.warn( - "paddle.assign doesn't support float64 input now due " - "to current platform protobuf data limitation, we convert " - "it to float32" - ) - dtype = core.VarDesc.VarType.FP32 - - if dtype == core.DataType.FLOAT64: - # Setting FP64 numpy data is not supported in Paddle, so we - # use FP32 here - warnings.warn( - "paddle.assign doesn't support float64 input now due " - "to current platform protobuf data limitation, we convert " - "it to float32" - ) - dtype = core.DataType.FLOAT32 - - if dtype in [core.VarDesc.VarType.BOOL, core.DataType.BOOL]: - value_name = "bool_values" - values = [int(v) for v in input.flat] - elif dtype in [core.VarDesc.VarType.FP32, core.DataType.FLOAT32]: - value_name = "fp32_values" - values = [float(v) for v in input.flat] - elif dtype in [core.VarDesc.VarType.INT32, core.DataType.INT32]: - value_name = "int32_values" - values = [int(v) for v in input.flat] - elif dtype in [core.VarDesc.VarType.INT64, core.DataType.INT64]: - value_name = "int64_values" - values = [int(v) for v in input.flat] - else: - raise TypeError( - "When the type of 'input' in assign is numpy.ndarray, " - "the data type of 'input' must be bool, float32, int32 or int64, but " - "received %s." % convert_dtype(dtype) - ) + check_dtype( + dtype, + 'input', + [ + 'float32', + 'float64', + 'int32', + 'int64', + 'bool', + 'complex64', + 'complex128', + ], + 'assign', + '(When the type of input in assign is numpy array.)', + ) + value_name = "values" + values = input.ravel().tolist() if input.size > 1024 * 1024: raise ValueError( "The size of input is too big. Please consider " diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index d384c7ad649d9b..d6addfe3400bc7 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -314,14 +314,24 @@ def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1) out = static_func() - np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) + if isinstance(out, paddle.Tensor): + np.testing.assert_allclose( + paddle.to_tensor(answer), out, rtol=1e-05 + ) + elif isinstance(out, tuple): + np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) @disable_test_case((ToStaticMode.AST, IrMode.PT)) def test_ifelse_early_return2(self): answer = np.zeros([2, 2]) + 3 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2) out = static_func() - np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) + if isinstance(out, paddle.Tensor): + np.testing.assert_allclose( + paddle.to_tensor(answer), out, rtol=1e-05 + ) + elif isinstance(out, tuple): + np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) class TestRemoveCommentInDy2St(Dy2StTestBase): diff --git a/test/ir/inference/CMakeLists.txt b/test/ir/inference/CMakeLists.txt index 020b84b4fd32a2..185ca22f897f69 100755 --- a/test/ir/inference/CMakeLists.txt +++ b/test/ir/inference/CMakeLists.txt @@ -168,8 +168,8 @@ if(NOT WITH_MKLDNN set_tests_properties(${target} PROPERTIES LABELS "RUN_TYPE=INFER") endforeach() - set_tests_properties(test_mul_lstm_fuse_pass PROPERTIES TIMEOUT 300) - set_tests_properties(test_mul_gru_fuse_pass PROPERTIES TIMEOUT 300) + set_tests_properties(test_mul_lstm_fuse_pass PROPERTIES TIMEOUT 1000) + set_tests_properties(test_mul_gru_fuse_pass PROPERTIES TIMEOUT 600) endif() if(WITH_GPU AND TENSORRT_FOUND) diff --git a/test/ir/inference/test_mul_gru_fuse_pass.py b/test/ir/inference/test_mul_gru_fuse_pass.py index 91c8058c54ec55..0ccbe467246083 100644 --- a/test/ir/inference/test_mul_gru_fuse_pass.py +++ b/test/ir/inference/test_mul_gru_fuse_pass.py @@ -134,7 +134,7 @@ def sample_predictor_configs(self, program_config): def test(self): self.run_and_statis( - quant=False, max_duration=300, passes=["mul_gru_fuse_pass"] + quant=False, max_duration=600, passes=["mul_gru_fuse_pass"] ) diff --git a/test/ir/inference/test_mul_lstm_fuse_pass.py b/test/ir/inference/test_mul_lstm_fuse_pass.py index f6304404c36945..fec34311604eea 100644 --- a/test/ir/inference/test_mul_lstm_fuse_pass.py +++ b/test/ir/inference/test_mul_lstm_fuse_pass.py @@ -120,7 +120,7 @@ def sample_predictor_configs(self, program_config): def test(self): self.run_and_statis( - quant=False, max_duration=300, passes=["mul_lstm_fuse_pass"] + quant=False, max_duration=1000, passes=["mul_lstm_fuse_pass"] ) diff --git a/test/ir/inference/test_seq_concat_fc_fuse_pass.py b/test/ir/inference/test_seq_concat_fc_fuse_pass.py index 4f1a0cbb7af835..68e446c5a64691 100644 --- a/test/ir/inference/test_seq_concat_fc_fuse_pass.py +++ b/test/ir/inference/test_seq_concat_fc_fuse_pass.py @@ -140,7 +140,9 @@ def teller1(program_config, predictor_config): ) def test(self): - self.run_and_statis(quant=False, passes=["seq_concat_fc_fuse_pass"]) + self.run_and_statis( + quant=False, passes=["seq_concat_fc_fuse_pass"], max_duration=1000 + ) if __name__ == "__main__": diff --git a/test/legacy_test/test_assign_value_op.py b/test/legacy_test/test_assign_value_op.py index 6ff4282d9fc553..10ff186e2e966d 100644 --- a/test/legacy_test/test_assign_value_op.py +++ b/test/legacy_test/test_assign_value_op.py @@ -22,24 +22,24 @@ from paddle.base import framework -def assign_value_wrapper( - shape=[], dtype=base.core.VarDesc.VarType.FP32, values=0.0 -): - if paddle.framework.in_dynamic_mode(): - tensor = paddle.Tensor() - else: - np_type = paddle.base.data_feeder._PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] - tensor = paddle.zeros(list(shape), np_type) - dtype = paddle.pir.core.convert_np_dtype_to_dtype_(np_type) - return paddle._C_ops.assign_value_( - tensor, shape, dtype, values, framework._current_expected_place() - ) +def wrap_assign_value_wrapper(dtype=base.core.VarDesc.VarType.FP32): + def assign_value_wrapper(shape=[], dtype=dtype, values=0.0): + if paddle.framework.in_dynamic_mode(): + tensor = paddle.Tensor() + else: + np_type = paddle.base.data_feeder._PADDLE_DTYPE_2_NUMPY_DTYPE[dtype] + tensor = paddle.zeros(list(shape), np_type) + dtype = paddle.pir.core.convert_np_dtype_to_dtype_(np_type) + return paddle._C_ops.assign_value_( + tensor, shape, dtype, values, framework._current_expected_place() + ) + + return assign_value_wrapper class TestAssignValueOp(op_test.OpTest): def setUp(self): self.op_type = "assign_value" - self.python_api = assign_value_wrapper self.inputs = {} self.attrs = {} self.init_data() @@ -47,11 +47,12 @@ def setUp(self): self.attrs["dtype"] = framework.convert_np_dtype_to_dtype_( self.value.dtype ) + self.python_api = wrap_assign_value_wrapper(self.attrs["dtype"]) self.outputs = {"Out": self.value} def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.float32) - self.attrs["fp32_values"] = [float(v) for v in self.value.flat] + self.attrs["values"] = [float(v) for v in self.value.flat] def test_forward(self): self.check_output(check_cinn=True, check_pir=True) @@ -60,13 +61,13 @@ def test_forward(self): class TestAssignValueOp2(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int32) - self.attrs["int32_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp3(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int64) - self.attrs["int64_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp4(TestAssignValueOp): @@ -74,7 +75,29 @@ def init_data(self): self.value = np.random.choice(a=[False, True], size=(2, 5)).astype( np.bool_ ) - self.attrs["bool_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] + + +class TestAssignValueOp5(TestAssignValueOp): + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.float64) + self.attrs["values"] = [float(v) for v in self.value.flat] + + +class TestAssignValueOp6(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex64) + self.attrs["values"] = list(self.value.flat) + + +class TestAssignValueOp7(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex128) + self.attrs["values"] = list(self.value.flat) class TestAssignApi(unittest.TestCase): @@ -97,8 +120,7 @@ def test_assign(self): with op_test.paddle_static_guard(): main_program = base.Program() with base.program_guard(main_program): - x = paddle.tensor.create_tensor(dtype=self.dtype) - paddle.assign(self.value, output=x) + x = paddle.assign(self.value) exe = base.Executor(self.place) [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x]) @@ -145,5 +167,46 @@ def init_dtype(self): self.dtype = "bool" +class TestAssignApi5(TestAssignApi): + def init_dtype(self): + self.dtype = "float64" + + +class TestAssignApi6(TestAssignApi): + def setUp(self): + with op_test.paddle_static_guard(): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex64) + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + + def init_dtype(self): + self.dtype = "complex64" + + +class TestAssignApi7(TestAssignApi): + def setUp(self): + with op_test.paddle_static_guard(): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex128) + self.place = ( + base.CUDAPlace(0) + if base.is_compiled_with_cuda() + else base.CPUPlace() + ) + + def init_dtype(self): + self.dtype = "complex128" + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_initializer.py b/test/legacy_test/test_initializer.py index 51702072844599..ac612d2b2bee30 100644 --- a/test/legacy_test/test_initializer.py +++ b/test/legacy_test/test_initializer.py @@ -1354,7 +1354,8 @@ def test_numpy_array_initializer(self, dtype="float32"): self.assertEqual(len(block.ops), num_ops) init_op = block.ops[0] self.assertEqual(init_op.type, 'assign_value') - assert (init_op.attr('fp32_values') == np_array).all() + values = framework.extract_plain_list(init_op.attr('values')) + assert values == np_array.ravel().tolist() return block def test_numpy_array_initializer_fp16(self): diff --git a/test/legacy_test/test_initializer_nn.py b/test/legacy_test/test_initializer_nn.py index 95c64ac6482905..1d9d8b08cf16de 100644 --- a/test/legacy_test/test_initializer_nn.py +++ b/test/legacy_test/test_initializer_nn.py @@ -664,8 +664,8 @@ def test_assign_initializer(self, dtype="float32"): self.assertEqual(len(block.ops), num_ops) init_op = block.ops[0] self.assertEqual(init_op.type, 'assign_value') - assert (init_op.attr('fp32_values') == np_array).all() - + values = framework.extract_plain_list(init_op.attr('values')) + assert values == np_array.ravel().tolist() paddle.disable_static() return block diff --git a/test/legacy_test/test_program_converter.py b/test/legacy_test/test_program_converter.py index 3894ca930ee0fc..3ba1e7f33ad577 100644 --- a/test/legacy_test/test_program_converter.py +++ b/test/legacy_test/test_program_converter.py @@ -301,3 +301,196 @@ def test_complex128(self): legacy_program_bytes = mp._get_desc().serialize_to_string( legacy_format=True ) + + +class TestAssignValue(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def _test_for_new_program_format(self, program_bytes): + restored_prog_as_is = framework_pb2.ProgramDesc.FromString( + program_bytes + ) + for block in restored_prog_as_is.blocks: + for op in block.ops: + if op.type in ("assign_value"): + attr_names = [attr.name for attr in op.attrs] + self.assertTrue("values" in attr_names) + self.assertFalse("bool_values" in attr_names) + self.assertFalse("int32_values" in attr_names) + self.assertFalse("int64_values" in attr_names) + self.assertFalse("fp32_values" in attr_names) + + def _test_for_legacy_program_format(self, program_bytes): + restored_prog_as_is = framework_pb2.ProgramDesc.FromString( + program_bytes + ) + for block in restored_prog_as_is.blocks: + for op in block.ops: + if op.type in ("set_value", "set_value_grad"): + attr_names = [attr.name for attr in op.attrs] + self.assertFalse("values" in attr_names) + self.assertTrue("bool_values" in attr_names) + self.assertTrue("int32_values" in attr_names) + self.assertTrue("int64_values" in attr_names) + self.assertTrue("fp32_values" in attr_names) + + def _test_equivalence( + self, + new_program_bytes, + legacy_program_bytes, + fetch_list, + expected_outputs, + ): + normal_program = paddle.static.io.deserialize_program(new_program_bytes) + converted_back_program = paddle.static.io.deserialize_program( + legacy_program_bytes + ) + exe = paddle.static.Executor(paddle.CPUPlace()) + out = exe.run(normal_program, fetch_list=fetch_list) + np.testing.assert_allclose(out[0], expected_outputs[0]) + out = exe.run(converted_back_program, fetch_list=fetch_list) + np.testing.assert_allclose(out[0], expected_outputs[0]) + + def test_int32(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.array([[1, 1], [3, 4], [1, 3]]).astype(np.int32) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_int64(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.array([[1, 1], [3, 4], [1, 3]]).astype(np.int64) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_float32(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.random.random(size=(2, 5)).astype(np.float32) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_float64(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.random.random(size=(2, 5)).astype(np.float64) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_bool(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = np.random.choice(a=[False, True], size=(2, 5)).astype(np.bool_) + out = paddle.assign(x) + + normal_program_bytes = mp._get_desc().serialize_to_string() + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + self.assertNotEqual(normal_program_bytes, legacy_program_bytes) + self._test_for_new_program_format(normal_program_bytes) + self._test_for_legacy_program_format(legacy_program_bytes) + self._test_equivalence( + normal_program_bytes, + legacy_program_bytes, + fetch_list=[out.name], + expected_outputs=[x], + ) + + def test_complex64(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex64) + out = paddle.assign(x) + + with self.assertRaisesRegex(RuntimeError, "Invalid data type"): + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + def test_complex128(self): + mp = paddle.static.Program() + sp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex128) + out = paddle.assign(x) + + with self.assertRaisesRegex(RuntimeError, "Invalid data type"): + legacy_program_bytes = mp._get_desc().serialize_to_string( + legacy_format=True + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/xpu/test_assign_value_op_xpu.py b/test/xpu/test_assign_value_op_xpu.py index f6d2d2ec96ae39..e4414cdaafc050 100644 --- a/test/xpu/test_assign_value_op_xpu.py +++ b/test/xpu/test_assign_value_op_xpu.py @@ -53,7 +53,7 @@ def setUp(self): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.float32) - self.attrs["fp32_values"] = [float(v) for v in self.value.flat] + self.attrs["values"] = [float(v) for v in self.value.flat] def test_forward(self): self.check_output_with_place(self.place) @@ -61,19 +61,40 @@ def test_forward(self): class TestAssignValueOp2(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int32) - self.attrs["int32_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp3(TestAssignValueOp): def init_data(self): self.value = np.random.random(size=(2, 5)).astype(np.int64) - self.attrs["int64_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] class TestAssignValueOp4(TestAssignValueOp): def init_data(self): self.value = np.random.choice(a=[False, True], size=(2, 5)).astype( np.bool_ ) - self.attrs["bool_values"] = [int(v) for v in self.value.flat] + self.attrs["values"] = [int(v) for v in self.value.flat] + + class TestAssignValueOp5(TestAssignValueOp): + def init_data(self): + self.value = np.random.random(size=(2, 5)).astype(np.float64) + self.attrs["values"] = [float(v) for v in self.value.flat] + + class TestAssignValueOp6(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex64) + self.attrs["values"] = list(self.value.flat) + + class TestAssignValueOp7(TestAssignValueOp): + def init_data(self): + self.value = ( + np.random.random(size=(2, 5)) + + 1j * np.random.random(size=(2, 5)) + ).astype(np.complex128) + self.attrs["values"] = list(self.value.flat) class TestAssignApi(unittest.TestCase): @@ -90,8 +111,7 @@ def init_dtype(self): def test_assign(self): main_program = base.Program() with base.program_guard(main_program): - x = paddle.tensor.create_tensor(dtype=self.dtype) - paddle.assign(self.value, output=x) + x = paddle.assign(self.value) exe = base.Executor(self.place) [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x]) @@ -121,6 +141,35 @@ def init_dtype(self): self.dtype = "bool" +class TestAssignApi5(TestAssignApi): + def init_dtype(self): + self.dtype = "float64" + + +class TestAssignApi6(TestAssignApi): + def setUp(self): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex64) + self.place = base.XPUPlace(0) + + def init_dtype(self): + self.dtype = "complex64" + + +class TestAssignApi7(TestAssignApi): + def setUp(self): + self.init_dtype() + self.value = ( + np.random.random(size=(2, 5)) + 1j * (np.random.random(size=(2, 5))) + ).astype(np.complex128) + self.place = base.XPUPlace(0) + + def init_dtype(self): + self.dtype = "complex128" + + support_types = get_xpu_op_support_types('assign_value') for stype in support_types: create_test_class(globals(), XPUTestAssignValueOp, stype) From d72ed8aeb0fcd0343e6fe15651af9c920d048964 Mon Sep 17 00:00:00 2001 From: xuxinyi389 <104957571+xuxinyi389@users.noreply.github.com> Date: Thu, 28 Dec 2023 18:52:14 +0800 Subject: [PATCH 035/142] =?UTF-8?q?Revert=20"=E3=80=90Hackathon=205th=20No?= =?UTF-8?q?.25=E3=80=91add=20`gammaln`=20api=20(#59311)"=20(#60450)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit beba862cd2aa4dd2b14cdd0c6c4c08be33df62f2. --- paddle/phi/api/yaml/backward.yaml | 10 -- paddle/phi/api/yaml/ops.yaml | 10 -- paddle/phi/kernels/cpu/gammaln_grad_kernel.cc | 22 --- paddle/phi/kernels/cpu/gammaln_kernel.cc | 22 --- paddle/phi/kernels/gammaln_grad_kernel.h | 27 --- paddle/phi/kernels/gammaln_kernel.h | 26 --- paddle/phi/kernels/gpu/gammaln_grad_kernel.cu | 30 ---- paddle/phi/kernels/gpu/gammaln_kernel.cu | 29 ---- .../kernels/impl/gammaln_grad_kernel_impl.h | 92 ---------- paddle/phi/kernels/impl/gammaln_kernel_impl.h | 49 ------ python/paddle/__init__.py | 4 - python/paddle/tensor/__init__.py | 4 - python/paddle/tensor/math.py | 45 ----- test/legacy_test/test_gammaln_op.py | 160 ------------------ test/legacy_test/test_inplace.py | 8 - 15 files changed, 538 deletions(-) delete mode 100644 paddle/phi/kernels/cpu/gammaln_grad_kernel.cc delete mode 100644 paddle/phi/kernels/cpu/gammaln_kernel.cc delete mode 100644 paddle/phi/kernels/gammaln_grad_kernel.h delete mode 100644 paddle/phi/kernels/gammaln_kernel.h delete mode 100644 paddle/phi/kernels/gpu/gammaln_grad_kernel.cu delete mode 100644 paddle/phi/kernels/gpu/gammaln_kernel.cu delete mode 100644 paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h delete mode 100644 paddle/phi/kernels/impl/gammaln_kernel_impl.h delete mode 100644 test/legacy_test/test_gammaln_op.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index d5748145ffe49d..938ea9d5000460 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -922,16 +922,6 @@ kernel : func : frame_grad -- backward_op : gammaln_grad - forward : gammaln(Tensor x) -> Tensor(out) - args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param: [x] - kernel : - func : gammaln_grad - - backward_op : gather_grad forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index dc545b7a2da546..de4d700cdf80ee 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1042,16 +1042,6 @@ data_type : dtype backend : place -- op : gammaln - args : (Tensor x) - output : Tensor(out) - infer_meta : - func : UnchangedInferMeta - kernel : - func : gammaln - inplace: (x -> out) - backward : gammaln_grad - - op : gather args : (Tensor x, Tensor index, Scalar axis=0) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc deleted file mode 100644 index c52ee8b3848e9a..00000000000000 --- a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/gammaln_grad_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" - -PD_REGISTER_KERNEL( - gammaln_grad, CPU, ALL_LAYOUT, phi::GammalnGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gammaln_kernel.cc b/paddle/phi/kernels/cpu/gammaln_kernel.cc deleted file mode 100644 index ff62f86d2522fd..00000000000000 --- a/paddle/phi/kernels/cpu/gammaln_kernel.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/gammaln_kernel.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" - -PD_REGISTER_KERNEL( - gammaln, CPU, ALL_LAYOUT, phi::GammalnKernel, float, double) {} diff --git a/paddle/phi/kernels/gammaln_grad_kernel.h b/paddle/phi/kernels/gammaln_grad_kernel.h deleted file mode 100644 index 440dca72a9d469..00000000000000 --- a/paddle/phi/kernels/gammaln_grad_kernel.h +++ /dev/null @@ -1,27 +0,0 @@ - -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void GammalnGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& d_out, - DenseTensor* d_x); -} // namespace phi diff --git a/paddle/phi/kernels/gammaln_kernel.h b/paddle/phi/kernels/gammaln_kernel.h deleted file mode 100644 index db3015c4a747db..00000000000000 --- a/paddle/phi/kernels/gammaln_kernel.h +++ /dev/null @@ -1,26 +0,0 @@ - -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void GammalnKernel(const Context& dev_ctx, - const DenseTensor& x, - DenseTensor* out); -} // namespace phi diff --git a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu deleted file mode 100644 index b2513d9e3f25ca..00000000000000 --- a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/gammaln_grad_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/core/kernel_registry.h" - -#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" - -PD_REGISTER_KERNEL(gammaln_grad, - GPU, - ALL_LAYOUT, - phi::GammalnGradKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gammaln_kernel.cu b/paddle/phi/kernels/gpu/gammaln_kernel.cu deleted file mode 100644 index 3d57be7b277335..00000000000000 --- a/paddle/phi/kernels/gpu/gammaln_kernel.cu +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/gammaln_kernel.h" - -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" - -PD_REGISTER_KERNEL(gammaln, - GPU, - ALL_LAYOUT, - phi::GammalnKernel, - float, - double, - phi::dtype::float16, - phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h deleted file mode 100644 index 50c73cff27ce4a..00000000000000 --- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/kernels/funcs/for_range.h" - -namespace phi { -template -HOSTDEVICE T digamma(T x) { - static T c = T{8.5}; - static T euler_mascheroni = T{0.57721566490153286060}; - T r; - T value; - T x2; - - if (x <= T{0.0}) { - value = T{0.0}; - return value; - } - - if (x <= T{0.000001}) { - value = -euler_mascheroni - T{1.0} / x + T{1.6449340668482264365} * x; - return value; - } - - value = T{0.0}; - x2 = x; - while (x2 < c) { - value = value - T{1.0} / x2; - x2 = x2 + T{1.0}; - } - - r = T{1.0} / x2; - value = value + std::log(x2) - T{0.5} * r; - - r = r * r; - - value = value - - r * (T{1.0} / T{12.0} - - r * (T{1.0} / T{120.0} - - r * (T{1.0} / T{252.0} - - r * (T{1.0} / T{240.0} - r * (T{1.0} / T{132.0}))))); - - return value; -} - -template -struct GammalnGradFunctor { - GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel) - : dout_(dout), x_(x), output_(output), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - using MT = typename phi::dtype::MPTypeTrait::Type; - const MT mp_dout = static_cast(dout_[idx]); - const MT mp_x = static_cast(x_[idx]); - output_[idx] = static_cast(mp_dout * digamma(mp_x)); - } - - private: - const T* dout_; - const T* x_; - T* output_; - int64_t numel_; -}; -template -void GammalnGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& d_out, - DenseTensor* d_x) { - auto numel = d_out.numel(); - auto* dout_data = d_out.data(); - auto* x_data = x.data(); - auto* dx_data = - dev_ctx.template Alloc(d_x, static_cast(numel * sizeof(T))); - phi::funcs::ForRange for_range(dev_ctx, numel); - GammalnGradFunctor functor(dout_data, x_data, dx_data, numel); - for_range(functor); -} -} // namespace phi diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h deleted file mode 100644 index 38385610de0de6..00000000000000 --- a/paddle/phi/kernels/impl/gammaln_kernel_impl.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/kernels/funcs/for_range.h" - -namespace phi { -template -struct GammalnFunctor { - GammalnFunctor(const T* x, T* output, int64_t numel) - : x_(x), output_(output), numel_(numel) {} - - HOSTDEVICE void operator()(int64_t idx) const { - using MT = typename phi::dtype::MPTypeTrait::Type; - const MT mp_x = static_cast(x_[idx]); - output_[idx] = static_cast(std::lgamma(mp_x)); - } - - private: - const T* x_; - T* output_; - int64_t numel_; -}; - -template -void GammalnKernel(const Context& dev_ctx, - const DenseTensor& x, - DenseTensor* out) { - auto numel = x.numel(); - auto* x_data = x.data(); - auto* out_data = dev_ctx.template Alloc(out); - phi::funcs::ForRange for_range(dev_ctx, numel); - GammalnFunctor functor(x_data, out_data, numel); - for_range(functor); -} -} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 1f0017562ebade..fc7b2a3533f892 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -398,8 +398,6 @@ frac, frac_, frexp, - gammaln, - gammaln_, gcd, gcd_, heaviside, @@ -775,8 +773,6 @@ 'square_', 'divide', 'divide_', - 'gammaln', - 'gammaln_', 'ceil', 'atan', 'atan_', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b718910348d8ff..b26798892a2b2f 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -278,8 +278,6 @@ frac, frac_, frexp, - gammaln, - gammaln_, gcd, gcd_, heaviside, @@ -670,8 +668,6 @@ 'real', 'imag', 'is_floating_point', - 'gammaln', - 'gammaln_', 'digamma', 'digamma_', 'diagonal', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6d75d41b4949ca..acaa0905ce6f40 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5003,51 +5003,6 @@ def conj(x, name=None): return out -def gammaln(x, name=None): - r""" - Calculates the logarithm of the absolute value of the gamma function elementwisely. - - Args: - x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, bfloat16. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - - Returns: - Tensor, The values of the logarithm of the absolute value of the gamma at the given tensor x. - - Examples: - .. code-block:: python - - >>> import paddle - - >>> x = paddle.arange(1.5, 4.5, 0.5) - >>> out = paddle.gammaln(x) - >>> print(out) - Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True, - [-0.12078224, 0. , 0.28468287, 0.69314718, 1.20097363, - 1.79175949]) - """ - if in_dynamic_or_pir_mode(): - return _C_ops.gammaln(x) - else: - check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'bfloat16'], 'gammaln' - ) - helper = LayerHelper('gammaln', **locals()) - out = helper.create_variable_for_type_inference(x.dtype) - helper.append_op(type='gammaln', inputs={'x': x}, outputs={'out': out}) - return out - - -@inplace_apis_in_dygraph_only -def gammaln_(x, name=None): - r""" - Inplace version of ``gammaln`` API, the output Tensor will be inplaced with input ``x``. - Please refer to :ref:`api_paddle_gammaln`. - """ - if in_dynamic_mode(): - return _C_ops.gammaln_(x) - - def digamma(x, name=None): r""" Calculates the digamma of the given input tensor, element-wise. diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py deleted file mode 100644 index 50331af5c7a34c..00000000000000 --- a/test/legacy_test/test_gammaln_op.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from op_test import OpTest, convert_float_to_uint16 -from scipy import special - -import paddle -from paddle.base import core - - -def ref_gammaln(x): - return special.gammaln(x) - - -def ref_gammaln_grad(x, dout): - return dout * special.polygamma(0, x) - - -class TestGammalnOp(OpTest): - def setUp(self): - self.op_type = 'gammaln' - self.python_api = paddle.gammaln - self.init_dtype_type() - self.shape = (3, 40) - self.x = np.random.random(self.shape).astype(self.dtype) + 1 - self.inputs = {'x': self.x} - out = ref_gammaln(self.x) - self.outputs = {'out': out} - - def init_dtype_type(self): - self.dtype = np.float64 - - def test_check_output(self): - self.check_output(check_pir=True) - - def test_check_grad(self): - self.check_grad(['x'], 'out', check_pir=True) - - -class TestGammalnOpFp32(TestGammalnOp): - def init_dtype_type(self): - self.dtype = np.float32 - - -class TestGammalnFP16Op(TestGammalnOp): - def init_dtype_type(self): - self.dtype = np.float16 - - -class TestGammalnBigNumberOp(TestGammalnOp): - def setUp(self): - self.op_type = 'gammaln' - self.python_api = paddle.gammaln - self.init_dtype_type() - self.shape = (100, 1) - self.x = np.random.random(self.shape).astype(self.dtype) + 1 - self.x[:5, 0] = np.array([1e5, 1e10, 1e20, 1e40, 1e80]) - self.inputs = {'x': self.x} - out = ref_gammaln(self.x) - self.outputs = {'out': out} - - def init_dtype_type(self): - self.dtype = np.float64 - - def test_check_grad(self): - d_out = self.outputs['out'] - d_x = ref_gammaln_grad(self.x, d_out) - self.check_grad( - ['x'], - 'out', - user_defined_grads=[ - d_x, - ], - user_defined_grad_outputs=[ - d_out, - ], - check_pir=True, - ) - - -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not compiled with CUDA or not support bfloat16", -) -class TestGammalnBF16Op(OpTest): - def setUp(self): - self.op_type = 'gammaln' - self.python_api = paddle.gammaln - self.dtype = np.uint16 - self.shape = (5, 30) - x = np.random.random(self.shape).astype("float32") + 1 - self.inputs = {'x': convert_float_to_uint16(x)} - out = ref_gammaln(x) - self.outputs = {'out': convert_float_to_uint16(out)} - - def test_check_output(self): - self.check_output_with_place(core.CUDAPlace(0), check_pir=True) - - def test_check_grad(self): - self.check_grad_with_place( - core.CUDAPlace(0), ['x'], 'out', check_pir=True - ) - - -class TestGammalnOpApi(unittest.TestCase): - def setUp(self): - self.shape = [2, 3, 4, 5] - self.init_dtype_type() - self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() - ) - - def init_dtype_type(self): - self.dtype = "float64" - - def test_static_api(self): - paddle.enable_static() - with paddle.static.program_guard(paddle.static.Program()): - x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype) - out = paddle.gammaln(x) - exe = paddle.static.Executor(self.place) - (res,) = exe.run(feed={'x': self.x_np}, fetch_list=[out]) - out_ref = ref_gammaln(self.x_np) - np.testing.assert_allclose(out_ref, res, rtol=1e-5, atol=1e-5) - - def test_dygraph_api(self): - paddle.disable_static(self.place) - x = paddle.to_tensor(self.x_np) - out = paddle.gammaln(x) - out_ref = ref_gammaln(self.x_np) - np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5) - paddle.enable_static() - - -class TestGammalnOpApiFp32(TestGammalnOpApi): - def init_dtype_type(self): - self.dtype = "float32" - - -if __name__ == "__main__": - paddle.enable_static() - unittest.main() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 38fbac0357d6df..42f9a46cfb9100 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -869,14 +869,6 @@ def test_leaf_inplace_var_error(self): pass -class TestDygraphInplaceGammaln(TestDygraphInplaceWithContinuous): - def inplace_api_processing(self, var): - return paddle.gammaln_(var) - - def non_inplace_api_processing(self, var): - return paddle.gammaln(var) - - class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous): def inplace_api_processing(self, var): return paddle.neg_(var) From 51dc03178d3b90f0fa84eb6da336d7cb1aaf02e5 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Thu, 28 Dec 2023 19:15:04 +0800 Subject: [PATCH 036/142] fix dead code elimination pass bug (#60430) --- paddle/fluid/pir/transforms/dead_code_elimination_pass.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc index 4c8fa32c6d6353..bc2421cfe1a869 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -58,6 +58,10 @@ class DeadCodeEliminationPass : public pir::Pass { } } } + + if (!deleted_ops.empty()) { + EraseOp(block, num_erasers); + } } }; From fdc38b2ba32a5b4f27557c36de09bbcee3b9d816 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Thu, 28 Dec 2023 20:47:36 +0800 Subject: [PATCH 037/142] [DRR] change namespace pir::drr:: to paddle::drr:: (#60432) --- .../operator/transforms/pd_to_cinn_pass.cc | 40 +-- .../op_generator/op_creator_drr_gen.py | 8 +- paddle/fluid/pir/drr/README.md | 24 +- paddle/fluid/pir/drr/README_cn.md | 24 +- paddle/fluid/pir/drr/api/drr_pattern_base.h | 6 +- .../fluid/pir/drr/api/drr_pattern_context.cc | 5 +- .../fluid/pir/drr/api/drr_pattern_context.h | 4 +- paddle/fluid/pir/drr/api/match_context.cc | 4 +- paddle/fluid/pir/drr/api/match_context.h | 4 +- paddle/fluid/pir/drr/api/tensor_interface.cc | 4 +- paddle/fluid/pir/drr/api/tensor_interface.h | 4 +- paddle/fluid/pir/drr/attr_type_uilts.h | 20 +- paddle/fluid/pir/drr/drr_rewrite_pattern.cc | 42 +-- paddle/fluid/pir/drr/drr_rewrite_pattern.h | 11 +- paddle/fluid/pir/drr/ir_operation.h | 4 +- paddle/fluid/pir/drr/ir_operation_factory.cc | 24 +- paddle/fluid/pir/drr/ir_operation_factory.h | 8 +- paddle/fluid/pir/drr/ir_value.h | 8 +- paddle/fluid/pir/drr/match_context_impl.h | 4 +- paddle/fluid/pir/drr/pattern_graph.cc | 4 +- paddle/fluid/pir/drr/pattern_graph.h | 4 +- .../transforms/fusion/attention_fuse_pass.cc | 50 ++-- .../transforms/fusion/conv2d_add_fuse_pass.cc | 18 +- .../fc_elementwise_layernorm_fuse_pass.cc | 32 ++- .../pir/transforms/fusion/fc_fuse_pass.cc | 33 +-- .../fusion/fc_with_special_op_fuse_pass.cc | 68 +++-- .../fused_dot_product_attention_pass.cc | 250 ++++++++++-------- .../fusion/fused_dropout_add_pass.cc | 16 +- .../fusion/fused_gemm_epilogue_pass.cc | 75 +++--- .../fused_linear_param_grad_add_pass.cc | 132 +++++---- .../fusion/fused_weight_only_linear_pass.cc | 64 ++--- .../fusion/matmul_scale_fuse_pass.cc | 24 +- .../pir/transforms/identity_op_clean_pass.cc | 62 ++--- .../drr_same_type_binding_test.cc | 8 +- test/cpp/pir/pattern_rewrite/drr_test.cc | 38 +-- 35 files changed, 597 insertions(+), 529 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 295c50b0eae00e..352fd9fdde322b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -31,11 +31,11 @@ namespace cinn { namespace dialect { namespace ir { -class SumOpPattern : public pir::drr::DrrPatternBase { +class SumOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -48,7 +48,7 @@ class SumOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_sum = res.Op(cinn::dialect::ReduceSumOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -57,11 +57,11 @@ class SumOpPattern : public pir::drr::DrrPatternBase { } }; -class MaxOpPattern : public pir::drr::DrrPatternBase { +class MaxOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -73,7 +73,7 @@ class MaxOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_max = res.Op(cinn::dialect::ReduceMaxOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -82,11 +82,11 @@ class MaxOpPattern : public pir::drr::DrrPatternBase { } }; -class MinOpPattern : public pir::drr::DrrPatternBase { +class MinOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -98,7 +98,7 @@ class MinOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_max = res.Op(cinn::dialect::ReduceMinOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -107,11 +107,11 @@ class MinOpPattern : public pir::drr::DrrPatternBase { } }; -class ProdOpPattern : public pir::drr::DrrPatternBase { +class ProdOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -123,7 +123,7 @@ class ProdOpPattern : public pir::drr::DrrPatternBase { pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array()); // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_reduce_max = res.Op(cinn::dialect::ReduceProdOp::name(), {{"dim", pattern.Attr("axis_info")}, @@ -552,11 +552,11 @@ class SplitWithNumOpPattern } }; -class UniformOpPattern : public pir::drr::DrrPatternBase { +class UniformOpPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pattern = ctx->SourcePattern(); + paddle::drr::SourcePattern pattern = ctx->SourcePattern(); const auto &full_int_array = pattern.Op(paddle::dialect::FullIntArrayOp::name(), {{"value", pattern.Attr("axis_info")}, @@ -585,7 +585,7 @@ class UniformOpPattern : public pir::drr::DrrPatternBase { // int64_t[] shape, float min, float max, int seed, DataType dtype, int // diag_num, int diag_step, float diag_val) // Result patterns - pir::drr::ResultPattern res = pattern.ResultPattern(); + paddle::drr::ResultPattern res = pattern.ResultPattern(); const auto &cinn_uniform = res.Op(cinn::dialect::UniformRandomOp::name(), {{"shape", pattern.Attr("axis_info")}, diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py index 9a40f74429e52b..18dc70f9fa7a7c 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -27,7 +27,7 @@ {op_header} #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" -namespace pir {{ +namespace paddle {{ namespace drr {{ void OperationFactory::Register{dialect}GeneratedOpCreator() {{ @@ -35,14 +35,14 @@ }} }} // namespace drr -}} // namespace pir +}} // namespace paddle """ NORMAL_FUNCTION_TEMPLATE = """ RegisterOperationCreator( "{op_name}", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) {{ return rewriter.Build<{namespace}::{op_class_name}>( @@ -53,7 +53,7 @@ MUTABLE_ATTR_FUNCTION_TEMPLATE = """ RegisterOperationCreator( "{op_name}", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) {{ // mutable_attr is tensor diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 4abdbb1b647179..6fbac0756ae865 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -10,9 +10,9 @@ Taking PASS to eliminate redundant CastOp as an example, the code example develo ~~~ c++ // 1. Inherit specialized template class from DrPatternBase class RemoveRedundentCastPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { // 2. Overload operator() - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute auto pat = ctx->SourcePattern(); @@ -55,7 +55,7 @@ Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern` DrrPatternBase
 virtual void operator()(
-        pir::drr::DrrPatternContext* ctx) const 
+ paddle::drr::DrrPatternContext* ctx) const Implement the entry function of DRR PASS ctx: Context parameters required to create Patten @@ -165,11 +165,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 Example Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public pir::drr::DrrPatternBase { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -179,10 +179,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); // Define ResultPattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); // Define Constrain const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -199,11 +199,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { Example 2: Full + Expand -> Full ~~~ c++ class FoldExpandToConstantPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full1 = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape_1")}, {"value", pat.Attr("value_1")}, @@ -218,7 +218,7 @@ class FoldExpandToConstantPattern pat.Tensor("ret") = expand(full1(), full_int_array1()); // Define ResultPattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &full2 = res.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("expand_shape_value")}, {"value", pat.Attr("value_1")}, diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index 456bf7921414bf..1291bec2954c48 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -10,9 +10,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P ~~~ c++ // 1. 继承 DrrPatternBase 的特化模板类 class RemoveRedundentCastPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { // 2. 重载 operator() - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern auto pat = ctx->SourcePattern(); @@ -56,7 +56,7 @@ DRR PASS 包含以下三个部分: DrrPatternBase
 virtual void operator()(
-        pir::drr::DrrPatternContext* ctx) const 
+ paddle::drr::DrrPatternContext* ctx) const 实现 DRR PASS 的入口函数 ctx: 创建 Patten 所需要的 Context 参数 @@ -168,11 +168,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 使用示例 Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public pir::drr::DrrPatternBase { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -182,10 +182,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); // 定义 Result Pattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); // 定义 Constrain const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -202,11 +202,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { Example 2: Full + Expand -> Full ~~~ c++ class FoldExpandToConstantPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full1 = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape_1")}, {"value", pat.Attr("value_1")}, @@ -221,7 +221,7 @@ class FoldExpandToConstantPattern pat.Tensor("ret") = expand(full1(), full_int_array1()); // 定义 Result Pattern Constrains: 本 Pass 无额外约束规则 - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &full2 = res.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("expand_shape_value")}, {"value", pat.Attr("value_1")}, diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h index 1a84c42800373b..18252d536869f7 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -17,7 +17,7 @@ #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" #include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" -namespace pir { +namespace paddle { namespace drr { template @@ -26,7 +26,7 @@ class DrrPatternBase { virtual ~DrrPatternBase() = default; // Define the Drr Pattern. - virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0; + virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0; std::unique_ptr Build( pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { @@ -39,4 +39,4 @@ class DrrPatternBase { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/api/drr_pattern_context.cc index 50e94c3458265c..7f98f0b34cbeb7 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.cc +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/pir/drr/pattern_graph.h" #include "paddle/phi/core/enforce.h" -namespace pir { +namespace paddle { namespace drr { DrrPatternContext::DrrPatternContext() { @@ -28,6 +28,7 @@ DrrPatternContext::DrrPatternContext() { drr::SourcePattern DrrPatternContext::SourcePattern() { return drr::SourcePattern(this); } + const Op& DrrPatternContext::SourceOpPattern( const std::string& op_type, const std::unordered_map& attributes) { @@ -167,4 +168,4 @@ void Tensor::operator=(const Tensor& other) const { // NOLINT } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/api/drr_pattern_context.h index 5c235215dd19ba..feb0e988aa8822 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.h @@ -24,7 +24,7 @@ #include "paddle/fluid/pir/drr/api/match_context.h" -namespace pir { +namespace paddle { namespace drr { class Op; @@ -334,4 +334,4 @@ class SourcePattern { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/api/match_context.cc index 35b28db13254ed..e5f15adf72e75e 100644 --- a/paddle/fluid/pir/drr/api/match_context.cc +++ b/paddle/fluid/pir/drr/api/match_context.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/pir/drr/ir_operation.h" #include "paddle/fluid/pir/drr/match_context_impl.h" -namespace pir { +namespace paddle { namespace drr { MatchContext::MatchContext(std::shared_ptr impl) @@ -46,4 +46,4 @@ template std::vector MatchContext::Attr>( const std::string&) const; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/api/match_context.h index a1699ccb5bddf6..762c86cf8a8e60 100644 --- a/paddle/fluid/pir/drr/api/match_context.h +++ b/paddle/fluid/pir/drr/api/match_context.h @@ -20,7 +20,7 @@ #include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/ir_operation.h" -namespace pir { +namespace paddle { namespace drr { class TensorInterface; @@ -40,4 +40,4 @@ class MatchContext final { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc index 03a35031f0d917..335f95214887a9 100644 --- a/paddle/fluid/pir/drr/api/tensor_interface.cc +++ b/paddle/fluid/pir/drr/api/tensor_interface.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/ir_value.h" -namespace pir { +namespace paddle { namespace drr { bool ShapeInterface::operator==(const ShapeInterface& other) const { @@ -33,4 +33,4 @@ bool DtypeInterface::operator==(const DtypeInterface& other) const { IrDtype DtypeInterface::get() const { return *(this->dtype_); } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h index 4684beba4ad844..24774f00d5a298 100644 --- a/paddle/fluid/pir/drr/api/tensor_interface.h +++ b/paddle/fluid/pir/drr/api/tensor_interface.h @@ -16,7 +16,7 @@ #include -namespace pir { +namespace paddle { namespace drr { class IrValue; @@ -60,4 +60,4 @@ class TensorInterface { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h index 4043aa3c643835..8904ed0e9ff6a7 100644 --- a/paddle/fluid/pir/drr/attr_type_uilts.h +++ b/paddle/fluid/pir/drr/attr_type_uilts.h @@ -19,7 +19,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/pir/core/builtin_attribute.h" -namespace pir { +namespace paddle { namespace drr { template @@ -32,11 +32,11 @@ struct CppTypeToIrAttribute; using type = ir_attr_type; \ }; -PD_SPECIALIZE_CppTypeToIrAttribute(bool, BoolAttribute); -PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute); -PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute); -PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute); -PD_SPECIALIZE_CppTypeToIrAttribute(std::string, StrAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(bool, pir::BoolAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, pir::Int32Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, pir::Int64Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(float, pir::FloatAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::string, pir::StrAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, paddle::dialect::DataTypeAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); @@ -61,7 +61,8 @@ struct IrAttrbuteCreator> { std::vector attr_vec; attr_vec.reserve(obj.size()); for (int32_t x : obj) { - attr_vec.push_back(Int32Attribute::get(pir::IrContext::Instance(), x)); + attr_vec.push_back( + pir::Int32Attribute::get(pir::IrContext::Instance(), x)); } return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); } @@ -73,7 +74,8 @@ struct IrAttrbuteCreator> { std::vector attr_vec; attr_vec.reserve(obj.size()); for (float x : obj) { - attr_vec.push_back(FloatAttribute::get(pir::IrContext::Instance(), x)); + attr_vec.push_back( + pir::FloatAttribute::get(pir::IrContext::Instance(), x)); } return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); } @@ -140,4 +142,4 @@ struct IrAttrTypeCast> { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc index d0c0d71a3feaab..d408c1aab13490 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc @@ -14,12 +14,12 @@ #include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" -namespace pir { +namespace paddle { namespace drr { bool DrrRewritePattern::MatchAndRewrite( pir::Operation* op, - PatternRewriter& rewriter) const { // NOLINT + pir::PatternRewriter& rewriter) const { // NOLINT std::shared_ptr src_match_ctx = std::make_shared(); if (PatternGraphMatch(op, src_match_ctx.get())) { @@ -41,8 +41,8 @@ bool DrrRewritePattern::PatternGraphMatch( return false; } std::vector drr_output_sequence; - std::vector ir_output_sequence; - std::unordered_map output_op_map; + std::vector ir_output_sequence; + std::unordered_map output_op_map; for (const auto& pair : bind_map) { drr_output_sequence.push_back(pair.first); } @@ -50,8 +50,8 @@ bool DrrRewritePattern::PatternGraphMatch( auto permute = [&](auto&& permute, size_t index) -> bool { if (index == drr_output_sequence.size()) { // avoiding duplicate binding of ir op - std::unordered_set ir_output_set; - for (Operation* op : ir_output_sequence) { + std::unordered_set ir_output_set; + for (pir::Operation* op : ir_output_sequence) { auto pr = ir_output_set.insert(op); if (pr.second == false) { return false; @@ -64,7 +64,7 @@ bool DrrRewritePattern::PatternGraphMatch( drr_output_sequence.end(), ir_output_sequence.begin(), std::inserter(output_op_map, output_op_map.end()), - [](const OpCall* drr_op, Operation* ir_op) { + [](const OpCall* drr_op, pir::Operation* ir_op) { return std::make_pair(drr_op, ir_op); }); if (MatchFromOutputToInput( @@ -214,12 +214,12 @@ void DrrRewritePattern::DfsVisitor( } bool DrrRewritePattern::MatchFromOutputToInput( - std::unordered_map output_op_map, + std::unordered_map output_op_map, const SourcePatternGraph& source_pattern_graph, const std::shared_ptr& source_pattern_match_ctx) const { VLOG(6) << "MatchFromOutputToInput Start"; std::unordered_set drr_visited; - std::unordered_set ir_visited; + std::unordered_set ir_visited; std::queue drr_q; std::queue ir_q; bool matched = true; @@ -385,8 +385,8 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } } - std::vector> temp_program; - std::unordered_map op_2_temp_program_index; + std::vector> temp_program; + std::unordered_map op_2_temp_program_index; for (auto& op : *rewriter.block()) { op_2_temp_program_index[&op] = temp_program.size(); temp_program.push_back({&op}); @@ -397,14 +397,14 @@ MatchContextImpl DrrRewritePattern::CreateOperations( graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { // set insert point size_t max_input_op_index = 0; - Operation* max_index_op = nullptr; + pir::Operation* max_index_op = nullptr; for (const Tensor* input : op_call.inputs()) { if (input->is_none()) { continue; } auto ir_val = res_match_ctx.GetIrValue(input->name()); if (ir_val) { - Operation* ir_input_op = ir_val.dyn_cast().owner(); + pir::Operation* ir_input_op = ir_val.dyn_cast().owner(); if (op_2_temp_program_index.count(ir_input_op) == 0) { max_input_op_index = 0UL; } else if (max_input_op_index < @@ -431,7 +431,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } if (max_input_op_index == 0UL) { VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; - Operation* source_patter_first_op = + pir::Operation* source_patter_first_op = src_match_ctx.Operation(source_pattern_graph.owned_op_call()[0].get()) .get(); max_input_op_index = op_2_temp_program_index[source_patter_first_op]; @@ -440,7 +440,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( rewriter.SetInsertionPointAfter(max_index_op); } - Operation* new_op = + pir::Operation* new_op = CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); op_2_temp_program_index[new_op] = max_input_op_index + 1; if (max_input_op_index + 1 >= temp_program.size()) { @@ -487,11 +487,11 @@ void DrrRewritePattern::DeleteSourcePatternOp( const ResultPatternGraph& result_pattern_graph, const MatchContextImpl& src_match_ctx, pir::PatternRewriter& rewriter) const { // NOLINT - std::queue delete_ops_que; - std::unordered_set delete_ops_set; + std::queue delete_ops_que; + std::unordered_set delete_ops_set; GraphTopo graph_topo_visit(&source_pattern_graph); graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { - Operation* op = src_match_ctx.Operation(&op_call).get(); + pir::Operation* op = src_match_ctx.Operation(&op_call).get(); VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op; if (delete_ops_set.count(op) == 0 && op->use_empty()) { delete_ops_que.push(op); @@ -500,9 +500,9 @@ void DrrRewritePattern::DeleteSourcePatternOp( }); while (!delete_ops_que.empty()) { - Operation* op = delete_ops_que.front(); + pir::Operation* op = delete_ops_que.front(); delete_ops_que.pop(); - std::vector inputs = op->operands_source(); + std::vector inputs = op->operands_source(); VLOG(5) << "Delete (" << op->name() << " @" << op << ") in source_pattern_graph."; rewriter.EraseOp(op); @@ -517,4 +517,4 @@ void DrrRewritePattern::DeleteSourcePatternOp( } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/drr_rewrite_pattern.h index 5d20a5947f13b0..6163c6d9d0193e 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.h +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.h @@ -31,7 +31,7 @@ #include "paddle/pir/core/type_name.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" -namespace pir { +namespace paddle { namespace drr { class DrrRewritePattern : public pir::RewritePattern { @@ -57,8 +57,9 @@ class DrrRewritePattern : public pir::RewritePattern { "source pattern definition code.")); } - bool MatchAndRewrite(pir::Operation* op, - PatternRewriter& rewriter) const override; // // NOLINT + bool MatchAndRewrite( + pir::Operation* op, + pir::PatternRewriter& rewriter) const override; // // NOLINT private: bool PatternGraphMatch(pir::Operation* op, @@ -78,7 +79,7 @@ class DrrRewritePattern : public pir::RewritePattern { output_op_bind_map) const; bool MatchFromOutputToInput( - std::unordered_map output_op_map, + std::unordered_map output_op_map, const SourcePatternGraph& source_pattern_graph, const std::shared_ptr& source_pattern_match_ctx) const; @@ -113,4 +114,4 @@ class DrrRewritePattern : public pir::RewritePattern { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation.h b/paddle/fluid/pir/drr/ir_operation.h index 2764bc92454170..a88bb3bfff97cf 100644 --- a/paddle/fluid/pir/drr/ir_operation.h +++ b/paddle/fluid/pir/drr/ir_operation.h @@ -16,7 +16,7 @@ #include "paddle/pir/core/operation.h" -namespace pir { +namespace paddle { namespace drr { class IrOperation { @@ -30,4 +30,4 @@ class IrOperation { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc index 6644026fabde01..bbc31e9df7c25b 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -24,13 +24,13 @@ #include "paddle/pir/core/operation.h" #include "paddle/pir/core/value.h" -namespace pir { +namespace paddle { namespace drr { void OperationFactory::RegisterManualOpCreator() { RegisterOperationCreator( "pd_op.fused_gemm_epilogue", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build( @@ -41,7 +41,7 @@ void OperationFactory::RegisterManualOpCreator() { }); RegisterOperationCreator( "pd_op.fused_gemm_epilogue_grad", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build( @@ -52,14 +52,14 @@ void OperationFactory::RegisterManualOpCreator() { attrs); }); RegisterOperationCreator("builtin.combine", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build(inputs); }); RegisterOperationCreator( "pd_op.scale", - [](const std::vector& inputs, + [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { return rewriter.Build( @@ -130,18 +130,18 @@ pir::AttributeMap CreateAttributeMap(const OpCall& op_call, return attr_map; } -Value GetIrValueByDrrTensor(const Tensor& tensor, - const MatchContextImpl& res_match_ctx) { +pir::Value GetIrValueByDrrTensor(const Tensor& tensor, + const MatchContextImpl& res_match_ctx) { if (tensor.is_none()) { - return Value{}; + return pir::Value{}; } return res_match_ctx.GetIrValue(tensor.name()).get(); } -std::vector GetIrValuesByDrrTensors( +std::vector GetIrValuesByDrrTensors( const std::vector& tensors, const MatchContextImpl& res_match_ctx) { - std::vector ir_values; + std::vector ir_values; ir_values.reserve(tensors.size()); for (const auto* tensor : tensors) { ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx)); @@ -167,7 +167,7 @@ pir::Operation* CreateOperation(const OpCall& op_call, MatchContextImpl* res_match_ctx) { VLOG(6) << "Drr create [" << op_call.name() << "] op..."; const auto& inputs = op_call.inputs(); - std::vector ir_values = + std::vector ir_values = GetIrValuesByDrrTensors(inputs, *res_match_ctx); pir::Operation* op = OperationFactory::Instance().CreateOperation( op_call.name(), @@ -180,4 +180,4 @@ pir::Operation* CreateOperation(const OpCall& op_call, } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h index adc76efb99b2de..40682904df62a8 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.h +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -20,7 +20,7 @@ #include "paddle/fluid/pir/drr/match_context_impl.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" -namespace pir { +namespace paddle { namespace drr { class OperationFactory { @@ -31,7 +31,7 @@ class OperationFactory { } using operation_create_fn = - std::function&, + std::function&, const pir::AttributeMap&, pir::PatternRewriter&)>; @@ -42,7 +42,7 @@ class OperationFactory { pir::Operation* CreateOperation( const std::string& op_name, - const std::vector& inputs, + const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) const { // NOLINT auto iter = op_creator_map.find(op_name); @@ -79,4 +79,4 @@ pir::Operation* CreateOperation(const OpCall& op_call, MatchContextImpl* res_match_ctx); } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h index 125f198dcc74c5..ae99fd8c1964e2 100644 --- a/paddle/fluid/pir/drr/ir_value.h +++ b/paddle/fluid/pir/drr/ir_value.h @@ -21,7 +21,7 @@ #include "paddle/pir/core/type.h" #include "paddle/pir/core/value.h" -namespace pir { +namespace paddle { namespace drr { class IrShape { @@ -101,10 +101,10 @@ class IrValue : public TensorInterface { } // Don't use it in drr pass! - const Value& get() const { return value_; } + const pir::Value& get() const { return value_; } private: - const Value value_; + const pir::Value value_; const IrShape shape_; const IrDtype dtype_; }; @@ -112,4 +112,4 @@ class IrValue : public TensorInterface { class IrAttr; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h index 37b06914cd2bdf..b1234d81299360 100644 --- a/paddle/fluid/pir/drr/match_context_impl.h +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -25,7 +25,7 @@ #include "paddle/fluid/pir/drr/ir_value.h" #include "paddle/pir/core/builtin_attribute.h" -namespace pir { +namespace paddle { namespace drr { class MatchContextImpl final { @@ -131,4 +131,4 @@ class MatchContextImpl final { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc index 7d732b6576f68c..58c79c65acf2f6 100644 --- a/paddle/fluid/pir/drr/pattern_graph.cc +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" #include "paddle/phi/core/enforce.h" -namespace pir { +namespace paddle { namespace drr { const drr::OpCall &PatternGraph::AddOpCall( @@ -238,4 +238,4 @@ std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { } } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/drr/pattern_graph.h b/paddle/fluid/pir/drr/pattern_graph.h index 63bd60eadf17f3..e5cd74b2fa2176 100644 --- a/paddle/fluid/pir/drr/pattern_graph.h +++ b/paddle/fluid/pir/drr/pattern_graph.h @@ -21,7 +21,7 @@ #include #include -namespace pir { +namespace paddle { namespace drr { class Constraint; @@ -105,4 +105,4 @@ class GraphTopo { }; } // namespace drr -} // namespace pir +} // namespace paddle diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index fbabf835390018..ab19247de4b26a 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -22,13 +22,13 @@ namespace { class MultiHeadMatmulFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // // Source Pattern. // - pir::drr::SourcePattern src = ctx->SourcePattern(); + paddle::drr::SourcePattern src = ctx->SourcePattern(); // The first path to matmul with scale (q). const auto &matmul_1 = src.Op("pd_op.matmul", @@ -115,7 +115,8 @@ class MultiHeadMatmulFusePattern // // Constraints. // - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { + src.RequireNativeCall([](const paddle::drr::MatchContext &match_ctx) + -> bool { const auto &softmax_axis = match_ctx.Attr("softmax_axis"); if (softmax_axis != -1 && softmax_axis != 3) return false; @@ -145,7 +146,7 @@ class MultiHeadMatmulFusePattern // // Result Pattern. // - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); // W combine. const auto &combine_1 = res.Op("builtin.combine"); combine_1({&res.Tensor("matmul_1_in_2"), @@ -153,11 +154,11 @@ class MultiHeadMatmulFusePattern &res.Tensor("matmul_3_in_2")}, {&res.Tensor("combine_1_out")}); const auto &concat_axis = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> int { return 0; }); + [](const paddle::drr::MatchContext &match_ctx) -> int { return 0; }); const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}}); res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); const auto &reshape_5_shape = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::vector { + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); return {-1, 3, matmul_1_in_2.at(1)}; }); @@ -175,7 +176,7 @@ class MultiHeadMatmulFusePattern const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}}); res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); const auto &reshape_6_shape = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::vector { + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { return {3, -1}; }); const auto &reshape_6 = @@ -184,28 +185,31 @@ class MultiHeadMatmulFusePattern {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); const auto &head_number = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { const auto &full_int_array_1_value = match_ctx.Attr>("full_int_array_1_value"); return full_int_array_1_value.at(2); }); const auto &alpha = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("full_1_value"); }); - const auto &multihead_matmul = res.Op( - "pd_op.multihead_matmul", - {{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) { - return false; - })}, - {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) { - return true; - })}, - {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) { - return false; - })}, - {"head_number", head_number}, - {"alpha", alpha}}); + const auto &multihead_matmul = + res.Op("pd_op.multihead_matmul", + {{"transpose_q", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return false; + })}, + {"transpose_k", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return true; + })}, + {"transpose_v", + res.Attr([](const paddle::drr::MatchContext &match_ctx) { + return false; + })}, + {"head_number", head_number}, + {"alpha", alpha}}); multihead_matmul({&res.Tensor("matmul_1_in_1"), &res.Tensor("reshape_5_out"), &res.Tensor("reshape_6_out"), diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index 86846508a519dc..e86dc04037fa01 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -29,10 +29,10 @@ namespace { class Conv2dAddFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &conv2d = pat.Op(paddle::dialect::Conv2dOp::name(), {{"strides", pat.Attr("strides")}, @@ -46,7 +46,7 @@ class Conv2dAddFusePattern {&pat.Tensor("conv2d_out")}); pat.Tensor("add_out") = add(pat.Tensor("conv2d_out"), pat.Tensor("bias")); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_conv2d_add_act = res.Op( paddle::dialect::FusedConv2dAddActOp::name(), @@ -58,21 +58,21 @@ class Conv2dAddFusePattern {"groups", pat.Attr("groups")}, {"data_format", pat.Attr("data_format")}, {"activation", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return "identity"; })}, {"split_channels", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::vector { return {}; })}, {"exhaustive_search", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return false; })}, {"workspace_size_MB", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> int { return 32; })}, {"fuse_alpha", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return 0.0f; })}, }}); diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index fdb4621fb350b6..7e5c4bbe8ea187 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -22,10 +22,10 @@ namespace { class FcElementwiseLayerNormFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fc = pat.Op(paddle::dialect::FcOp::name(), { @@ -47,7 +47,7 @@ class FcElementwiseLayerNormFusePattern &pat.Tensor("layernorm_mean"), &pat.Tensor("layernorm_variance")}); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; for (int i = match_ctx.Attr("begin_norm_axis"); i < match_ctx.Tensor("fc_out").Shape().size(); @@ -60,12 +60,16 @@ class FcElementwiseLayerNormFusePattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &x_num_col_dims_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::any { return 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &x_num_col_dims_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + return 1; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fused_fc_elementwise_op = res.Op(paddle::dialect::FusedFcElementwiseLayernormOp::name(), @@ -88,10 +92,10 @@ class FcElementwiseLayerNormFusePattern }; class FcElementwiseLayerNormFuse2Pattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fc = pat.Op(paddle::dialect::FcOp::name(), { @@ -113,7 +117,7 @@ class FcElementwiseLayerNormFuse2Pattern &pat.Tensor("layernorm_mean"), &pat.Tensor("layernorm_variance")}); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; for (int i = match_ctx.Attr("begin_norm_axis"); i < match_ctx.Tensor("fc_out").Shape().size(); @@ -126,7 +130,7 @@ class FcElementwiseLayerNormFuse2Pattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_fc_elementwise_op = res.Op(paddle::dialect::FusedFcElementwiseLayernormOp::name(), diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index 2a320b75d6cc31..b49ab9ff4ac77b 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -21,10 +21,10 @@ namespace { -class MatmulAddPattern : public pir::drr::DrrPatternBase { +class MatmulAddPattern : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, {"transpose_y", pat.Attr("transpose_y")}}); @@ -32,7 +32,7 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase { matmul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("matmul_out")}); pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("y")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { if (match_ctx.Tensor("w").Shape().size() != 2 || match_ctx.Tensor("x").Shape().size() < 2) { return false; @@ -56,21 +56,23 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase { return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &in_num_col_dims_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return match_ctx.Tensor("x").Shape().size() - 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); @@ -79,10 +81,11 @@ class MatmulAddPattern : public pir::drr::DrrPatternBase { } }; -class FcWithReluPattern : public pir::drr::DrrPatternBase { +class FcWithReluPattern + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fc = pat.Op(paddle::dialect::FcOp::name(), {{ @@ -96,18 +99,18 @@ class FcWithReluPattern : public pir::drr::DrrPatternBase { relu({&pat.Tensor("fc_out")}, {&pat.Tensor("relu_out")}); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return match_ctx.Attr("activation_type").empty(); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fc_with_relu = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", pat.Attr("in_num_col_dims")}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return "relu"; })}, {"padding_weights", pat.Attr("padding_weights")}, }}); diff --git a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc index 6bb2b3a6d512db..74dd21a0828fe9 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_with_special_op_fuse_pass.cc @@ -31,10 +31,10 @@ namespace { class SqueezeFcFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &squeeze_op = pat.Op(paddle::dialect::SqueezeOp::name()); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, @@ -46,7 +46,7 @@ class SqueezeFcFusePattern {&pat.Tensor("matmul_out")}); pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("bias")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { auto axis_type = match_ctx.Tensor("axis").Dtype().get(); if (axis_type.isa() && axis_type.dyn_cast().size() != 2) { @@ -87,19 +87,23 @@ class SqueezeFcFusePattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &in_num_col_dims_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::any { return 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &in_num_col_dims_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + return 1; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); @@ -109,10 +113,10 @@ class SqueezeFcFusePattern }; class ReshapeFcFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &reshape_op = pat.Op(paddle::dialect::ReshapeOp::name()); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, @@ -124,7 +128,7 @@ class ReshapeFcFusePattern {&pat.Tensor("matmul_out")}); add({&pat.Tensor("matmul_out"), &pat.Tensor("bias")}, {&pat.Tensor("add_out")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { if (match_ctx.Tensor("w").Shape().size() != 2 || match_ctx.Attr("transpose_x") == true || match_ctx.Attr("transpose_y") == true) { @@ -212,10 +216,10 @@ class ReshapeFcFusePattern } return true; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &in_num_col_dims_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { int i = match_ctx.Tensor("x").Shape().size() - 1; int target = match_ctx.Tensor("reshape_out") @@ -228,15 +232,17 @@ class ReshapeFcFusePattern } return i; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); @@ -246,10 +252,10 @@ class ReshapeFcFusePattern }; class FlattenFcFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &flatten_op = pat.Op(paddle::dialect::FlattenOp::name(), {{"start_axis", pat.Attr("start_axis")}, {"stop_axis", pat.Attr("stop_axis")}}); @@ -263,7 +269,7 @@ class FlattenFcFusePattern {&pat.Tensor("matmul_out")}); pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("bias")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { bool flatten_flag = false; if (match_ctx.Tensor("x").Shape().size() == 4 && @@ -295,19 +301,23 @@ class FlattenFcFusePattern return false; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &in_num_col_dims_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::any { return 1; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &in_num_col_dims_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { + return 1; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fc = res.Op(paddle::dialect::FcOp::name(), {{ {"in_num_col_dims", in_num_col_dims_attr}, {"activation_type", - res.Attr([](const pir::drr::MatchContext &match_ctx) + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::string { return ""; })}, {"padding_weights", false_attr}, }}); diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index 639c0e0e4b4140..9b2e7f2f3f2e74 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -22,10 +22,10 @@ namespace { class FusedDotProductAttentionPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -82,40 +82,45 @@ class FusedDotProductAttentionPattern src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out")); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); const auto &dropout_prob = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return static_cast(0.0); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), @@ -135,10 +140,10 @@ class FusedDotProductAttentionPattern }; class FusedDotProductAttentionGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -239,40 +244,45 @@ class FusedDotProductAttentionGradPattern {&src.Tensor("k_grad")}); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); const auto &dropout_prob = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return static_cast(0.0); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), @@ -307,11 +317,11 @@ class FusedDotProductAttentionGradPattern }; class FusedDotProductAttentionWithDropoutPattern - : public pir::drr::DrrPatternBase< + : public paddle::drr::DrrPatternBase< FusedDotProductAttentionWithDropoutPattern> { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -376,40 +386,45 @@ class FusedDotProductAttentionWithDropoutPattern src.Tensor("out") = o_transpose(src.Tensor("context_matmul_out")); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); const auto &dropout_prob = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return static_cast(0.0); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), @@ -429,11 +444,11 @@ class FusedDotProductAttentionWithDropoutPattern }; class FusedDotProductAttentionGradWithDropoutPattern - : public pir::drr::DrrPatternBase< + : public paddle::drr::DrrPatternBase< FusedDotProductAttentionGradWithDropoutPattern> { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // q[b, s, head, head_dim] -> transpose -> q[b, head, s, head_dim] -> scale const auto &q_transpose = src.Op("pd_op.transpose"); @@ -548,36 +563,41 @@ class FusedDotProductAttentionGradWithDropoutPattern {&src.Tensor("k_grad")}); // Constraints - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - const auto &softmax_axis = match_ctx.Attr("softmax_axis"); - if (softmax_axis != -1 && softmax_axis != 3) return false; - - bool qk_matmul_transpose_x = - match_ctx.Attr("qk_matmul_transpose_x"); - bool qk_matmul_transpose_y = - match_ctx.Attr("qk_matmul_transpose_y"); - if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; - - bool context_matmul_transpose_x = - match_ctx.Attr("context_matmul_transpose_x"); - bool context_matmul_transpose_y = - match_ctx.Attr("context_matmul_transpose_y"); - if (context_matmul_transpose_x || context_matmul_transpose_y) - return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool qk_matmul_transpose_x = + match_ctx.Attr("qk_matmul_transpose_x"); + bool qk_matmul_transpose_y = + match_ctx.Attr("qk_matmul_transpose_y"); + if (qk_matmul_transpose_x || !qk_matmul_transpose_y) return false; + + bool context_matmul_transpose_x = + match_ctx.Attr("context_matmul_transpose_x"); + bool context_matmul_transpose_y = + match_ctx.Attr("context_matmul_transpose_y"); + if (context_matmul_transpose_x || context_matmul_transpose_y) + return false; + + return true; + }); // Result pattern - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &scaling_factor = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("q_scale_value"); }); - const auto &is_training = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &is_causal_masking = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &is_training = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &is_causal_masking = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &dot_product_attention = res.Op(paddle::dialect::FusedDotProductAttentionOp::name(), diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index 35079d4f2cf1ca..df8b39cfc8676d 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -22,10 +22,10 @@ namespace { class FusedDropoutAddPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), {{"p", pat.Attr("p")}, {"is_test", pat.Attr("is_test")}, @@ -38,7 +38,7 @@ class FusedDropoutAddPattern {&pat.Tensor("dropout_out"), &pat.Tensor("mask")}); pat.Tensor("add_out") = add(pat.Tensor("dropout_out"), pat.Tensor("y")); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_dropout_add = res.Op(paddle::dialect::FusedDropoutAddOp::name(), {{{"p", pat.Attr("p")}, @@ -53,10 +53,10 @@ class FusedDropoutAddPattern }; class FusedDropoutGradAddGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &dropout = pat.Op(paddle::dialect::DropoutOp::name(), {{"p", pat.Attr("p")}, {"is_test", pat.Attr("is_test")}, @@ -81,7 +81,7 @@ class FusedDropoutGradAddGradPattern dropout_grad({&pat.Tensor("mask"), &pat.Tensor("dropout_out_grad")}, {&pat.Tensor("x_grad")}); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &fused_dropout_add = res.Op(paddle::dialect::FusedDropoutAddOp::name(), {{{"p", pat.Attr("p")}, diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 6bc15234efd31b..02a6b4744cdcb8 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -21,10 +21,11 @@ namespace { -class FusedLinearPattern : public pir::drr::DrrPatternBase { +class FusedLinearPattern + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -33,15 +34,15 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("w").Shape().size() == 2 && match_ctx.Tensor("x").Shape().size() >= 2 && match_ctx.Tensor("bias").Shape().size() == 1); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = @@ -56,10 +57,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase { }; class FusedLinearGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -76,15 +77,15 @@ class FusedLinearGradPattern matmul_grad({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("tmp_grad")}, {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("w").Shape().size() == 2 && match_ctx.Tensor("x").Shape().size() >= 2 && match_ctx.Tensor("bias").Shape().size() == 1); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "none"; }); const auto &fused_gemm_epilogue = @@ -111,10 +112,10 @@ class FusedLinearGradPattern }; class FusedLinearGeluPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); // Source pattern const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -128,14 +129,14 @@ class FusedLinearGeluPattern pat.Tensor("out") = gelu(pat.Tensor("fuse_out")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Attr("act") == "none"); }); // Result pattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "gelu"; }); const auto &fused_gemm_epilogue_gelu = @@ -149,10 +150,10 @@ class FusedLinearGeluPattern } }; class FusedLinearReluPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); // Source pattern const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), @@ -166,14 +167,14 @@ class FusedLinearReluPattern pat.Tensor("out") = relu(pat.Tensor("fuse_out")); // Constrains the activation is none - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Attr("act") == "none"); }); // Result pattern - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "relu"; }); const auto &fused_gemm_epilogue_relu = @@ -188,10 +189,10 @@ class FusedLinearReluPattern }; class FusedLinearGeluGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, @@ -218,14 +219,14 @@ class FusedLinearGeluGradPattern pat.Tensor("gelu_dx") = pat.Op(paddle::dialect::GeluGradOp::name())( pat.Tensor("fuse_out"), pat.Tensor("x1_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return match_ctx.Attr("act1") == "none" && match_ctx.Attr("act2") == "none"; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "gelu"; }); const auto &fused_gemm_epilogue_new = @@ -234,7 +235,7 @@ class FusedLinearGeluGradPattern {"trans_y", pat.Attr("trans_y1")}, {"activation", act_attr}}}); const auto &act_grad_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "gelu_grad"; }); const auto &fused_gemm_epilogue_grad_new = @@ -256,10 +257,10 @@ class FusedLinearGeluGradPattern }; class FusedLinearReluGradPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &fused_gemm_epilogue = pat.Op(paddle::dialect::FusedGemmEpilogueOp::name(), {{{"trans_x", pat.Attr("trans_x1")}, @@ -297,14 +298,14 @@ class FusedLinearReluGradPattern &pat.Tensor("w_grad"), &pat.Tensor("bias_grad")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return match_ctx.Attr("act1") == "relu" && match_ctx.Attr("act3") == "none"; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &act_grad_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "relu_grad"; }); const auto &res_fused_gemm_epilogue_grad1 = diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 7a3afec65f33fc..8c93ff98226754 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -22,10 +22,10 @@ namespace { // add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add class FusedMatmulAddGradAddPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul0 = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -48,7 +48,7 @@ class FusedMatmulAddGradAddPattern pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); return (match_ctx.Tensor("weight_grad").Shape() == @@ -58,17 +58,21 @@ class FusedMatmulAddGradAddPattern x_trans == false && y_trans == false); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &matmul = res.Op(paddle::dialect::MatmulOp::name(), @@ -89,10 +93,10 @@ class FusedMatmulAddGradAddPattern // matmul_grad + add_ -> matmul + fused_liner_param_gard_add class FusedMatmulGradAddPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul_grad = pat.Op(paddle::dialect::MatmulGradOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -104,7 +108,7 @@ class FusedMatmulGradAddPattern pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); return (match_ctx.Tensor("weight_grad").Shape() == @@ -112,18 +116,22 @@ class FusedMatmulGradAddPattern x_trans == false && y_trans == false); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &matmul = res.Op(paddle::dialect::MatmulOp::name(), @@ -145,10 +153,10 @@ class FusedMatmulGradAddPattern // matmul + 0 = add_(0,1) -> fused_liner_param_gard_add class FusedMatmulAddaPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -159,22 +167,26 @@ class FusedMatmulAddaPattern pat.Tensor("add_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), @@ -190,10 +202,10 @@ class FusedMatmulAddaPattern // matmul + 1 = add_(1,0) -> fused_liner_param_gard_add class FusedMatmulAddbPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -204,22 +216,26 @@ class FusedMatmulAddbPattern pat.Tensor("add_out") = add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); - const auto &false_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return false; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); + const auto &false_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return false; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), @@ -235,10 +251,10 @@ class FusedMatmulAddbPattern // add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add class FusedMatmulAddGradAddaPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -261,21 +277,23 @@ class FusedMatmulAddGradAddaPattern pat.Tensor("dweight_out") = add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape() && match_ctx.Tensor("out").Shape() == match_ctx.Tensor("dadd_out").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); @@ -290,10 +308,10 @@ class FusedMatmulAddGradAddaPattern // add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add class FusedMatmulAddGradAddbPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("trans_x")}, {"transpose_y", pat.Attr("trans_y")}}); @@ -316,21 +334,23 @@ class FusedMatmulAddGradAddbPattern pat.Tensor("dweight_out") = add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Tensor("weight_grad").Shape() == match_ctx.Tensor("dweight").Shape() && match_ctx.Tensor("out").Shape() == match_ctx.Tensor("dadd_out").Shape()); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(match_ctx.Tensor("dweight").Dtype() == match_ctx.Tensor("weight_grad").Dtype()); }); - const auto &true_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> bool { return true; }); + const auto &true_attr = + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + }); const auto &fused_linear_param_grad_add = res.Op( paddle::dialect::FusedLinearParamGradAddOp::name(), {{{"multi_precision", muti_precision_attr}, {"has_bias", true_attr}}}); diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index fa83418e562baf..82864f3d80e88f 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -36,13 +36,13 @@ int getSMVersion() { } class FusedWeightOnlyLinearPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // // Source Pattern. // - pir::drr::SourcePattern src = ctx->SourcePattern(); + paddle::drr::SourcePattern src = ctx->SourcePattern(); const auto &matmul = src.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", src.Attr("matmul_transpose_x")}, @@ -57,47 +57,49 @@ class FusedWeightOnlyLinearPattern // // Constraints. // - src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { - bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); - bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); - if (matmul_trans_x || matmul_trans_y) return false; - - if (!(match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1)) { - return false; - } - - auto w_dims = match_ctx.Tensor("w").Shape(); - if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; - - auto w_dtype = match_ctx.Tensor("w").Dtype().get(); - if (!w_dtype.isa() && !w_dtype.isa()) - return false; - - auto x_dims = match_ctx.Tensor("x").Shape(); - if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; - - return true; - }); + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); + bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); + if (matmul_trans_x || matmul_trans_y) return false; + + if (!(match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2 && + match_ctx.Tensor("bias").Shape().size() == 1)) { + return false; + } + + auto w_dims = match_ctx.Tensor("w").Shape(); + if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; + + auto w_dtype = match_ctx.Tensor("w").Dtype().get(); + if (!w_dtype.isa() && + !w_dtype.isa()) + return false; + + auto x_dims = match_ctx.Tensor("x").Shape(); + if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; + + return true; + }); // // Result Pattern. // - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); // quantize weight const auto &weight_only_int8_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "weight_only_int8"; }); const auto &arch_attr = - res.Attr([&](const pir::drr::MatchContext &match_ctx) -> int { + res.Attr([&](const paddle::drr::MatchContext &match_ctx) -> int { return getSMVersion(); }); const auto &group_size_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> int { return -1; }); + [](const paddle::drr::MatchContext &match_ctx) -> int { return -1; }); const auto &weight_quantize = res.Op(paddle::dialect::WeightQuantizeOp::name(), @@ -109,7 +111,7 @@ class FusedWeightOnlyLinearPattern &res.Tensor("weight_scale_tensor")}); const auto &weight_dtype_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any { return "int8"; }); diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 627c1cd516cc85..0bced0b8ec823f 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -28,10 +28,10 @@ namespace { class MatmulScaleFusePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, {"transpose_y", pat.Attr("transpose_y")}}); @@ -50,23 +50,23 @@ class MatmulScaleFusePattern scale_op({&pat.Tensor("matmul_out"), &full_op()}, {&pat.Tensor("scale_out")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return std::abs(match_ctx.Attr("bias")) <= 1e-6; }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &full_op_res = res.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape")}, {"value", pat.Attr("value")}, {"dtype", pat.Attr("dtype")}, {"place", pat.Attr("place")}}); - const auto &scale_op_res = - res.Op(paddle::dialect::ScaleOp::name(), - {{"bias", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { - return 0.0; - })}, - {"bias_after_scale", pat.Attr("bias_after_scale")}}); + const auto &scale_op_res = res.Op( + paddle::dialect::ScaleOp::name(), + {{"bias", + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { + return 0.0; + })}, + {"bias_after_scale", pat.Attr("bias_after_scale")}}); const auto &matmul_op_res = res.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index 377610196bf963..ac49d494d1c731 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -32,10 +32,10 @@ namespace { class RemoveUselessScalePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full_op = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape")}, {"value", pat.Attr("value")}, @@ -47,21 +47,21 @@ class RemoveUselessScalePattern {"bias_after_scale", pat.Attr("bias_after_scale")}}); scale_op({&pat.Tensor("x"), &full_op()}, {&pat.Tensor("scale_out")}); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { return (match_ctx.Attr("value") == 1.0 && match_ctx.Attr("bias") == 0.0); }); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); res.Tensor("scale_out").Assign(res.Tensor("x")); } }; class RemoveRedundentScalePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full_op_1 = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape_1")}, {"value", pat.Attr("value_1")}, @@ -84,10 +84,10 @@ class RemoveRedundentScalePattern scale_op_2({&pat.Tensor("scale_1_out"), &full_op_2()}, {&pat.Tensor("scale_2_out")}); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &bais_res = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { float res_bias_1 = 0.f; float res_bias_2 = 0.f; if (match_ctx.Attr("bias_after_scale_1")) { @@ -106,7 +106,7 @@ class RemoveRedundentScalePattern return res_bias_2; }); const auto &res_scale_input = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> float { return match_ctx.Attr("value_1") * match_ctx.Attr("value_2"); }); @@ -116,22 +116,22 @@ class RemoveRedundentScalePattern {"value", res_scale_input}, {"dtype", pat.Attr("dtype_1")}, {"place", pat.Attr("place_1")}}); - const auto &scale_op_res = - res.Op("pd_op.scale", - {{"bias", bais_res}, - {"bias_after_scale", - res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool { - return true; - })}}); + const auto &scale_op_res = res.Op( + "pd_op.scale", + {{"bias", bais_res}, + {"bias_after_scale", + res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { + return true; + })}}); scale_op_res({&res.Tensor("x"), &full_op_res()}, {&res.Tensor("scale_2_out")}); } }; class RemoveUselessCastPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0")); pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype()); @@ -141,16 +141,16 @@ class RemoveUselessCastPattern }; class RemoveUselessConcatPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); const auto &combine = pat.Op(pir::CombineOp::name()); combine({&pat.Tensor("x")}, {&pat.Tensor("combine_out")}); pat.Tensor("out") = pat.Op(paddle::dialect::ConcatOp::name())( pat.Tensor("combine_out"), pat.Tensor("axis")); - pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { - auto combine_out = dynamic_cast( + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto combine_out = dynamic_cast( match_ctx.Tensor("combine_out")); return combine_out.type_isa() && combine_out.type_dyn_cast().size() == 1; @@ -161,8 +161,8 @@ class RemoveUselessConcatPattern }; class RemoveRedundentCastPattern - : public pir::drr::DrrPatternBase { - void operator()(pir::drr::DrrPatternContext *ctx) const override { + : public paddle::drr::DrrPatternBase { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = pat.Op( "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0")); @@ -175,10 +175,10 @@ class RemoveRedundentCastPattern }; class RemoveRedundentTransposePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &transpose1 = pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}}); const auto &transpose2 = @@ -186,9 +186,9 @@ class RemoveRedundentTransposePattern pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose"))); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &new_perm_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::vector { + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { const auto &perm1 = match_ctx.Attr>("perm_1"); const auto &perm2 = match_ctx.Attr>("perm_2"); std::vector new_perm; diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc index b550212ad3654e..1a938e7f600b78 100644 --- a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -53,10 +53,10 @@ class SameTypeBindingTestPattern // This class is for test cases of the same type of OP. // (without considering the computational logic between OPs, // only focusing on the process of matching and replacing) - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern src = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern src = ctx->SourcePattern(); // path 1 const auto &transpose_1 = @@ -141,7 +141,7 @@ class SameTypeBindingTestPattern const auto &relu_2 = src.Op("pd_op.relu"); src.Tensor("output6") = relu_2(src.Tensor("add_2_out")); - pir::drr::ResultPattern res = src.ResultPattern(); + paddle::drr::ResultPattern res = src.ResultPattern(); const auto &transpose_7 = res.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); res.Tensor("output0") = transpose_7(res.Tensor("input_1")); diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index fc0e7ae94f05f9..54b5ff2025e49d 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -24,11 +24,11 @@ #include "paddle/pir/pass/pass_manager.h" class RemoveRedundentReshapePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source patterns - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &reshape1 = pat.Op("pd_op.reshape"); const auto &reshape2 = pat.Op("pd_op.reshape"); @@ -38,18 +38,18 @@ class RemoveRedundentReshapePattern {&pat.Tensor("ret"), &pat.Tensor("xshape_1")}); // Result patterns - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, {&res.Tensor("ret"), &res.Tensor("xshape_1")}); } }; class FoldExpandToConstantPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern - pir::drr::SourcePattern pat = ctx->SourcePattern(); + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &full1 = pat.Op("pd_op.full", {{"shape", pat.Attr("shape_1")}, {"value", pat.Attr("value_1")}, @@ -64,9 +64,9 @@ class FoldExpandToConstantPattern pat.Tensor("ret") = expand(full1(), full_int_array1()); // Result patterns - pir::drr::ResultPattern res = pat.ResultPattern(); - const auto &new_perm_attr = - res.Attr([](const pir::drr::MatchContext &match_ctx) -> phi::IntArray { + paddle::drr::ResultPattern res = pat.ResultPattern(); + const auto &new_perm_attr = res.Attr( + [](const paddle::drr::MatchContext &match_ctx) -> phi::IntArray { auto shape = match_ctx.Attr>("expand_shape_value"); @@ -82,10 +82,10 @@ class FoldExpandToConstantPattern }; class RemoveRedundentTransposePattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { - pir::drr::SourcePattern pat = ctx->SourcePattern(); + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &transpose1 = pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}}); const auto &transpose2 = @@ -93,9 +93,9 @@ class RemoveRedundentTransposePattern pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose"))); - pir::drr::ResultPattern res = pat.ResultPattern(); + paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &new_perm_attr = res.Attr( - [](const pir::drr::MatchContext &match_ctx) -> std::vector { + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { const auto &perm1 = match_ctx.Attr>("perm_1"); const auto &perm2 = match_ctx.Attr>("perm_2"); std::vector new_perm; @@ -112,8 +112,8 @@ class RemoveRedundentTransposePattern }; class RemoveRedundentCastPattern - : public pir::drr::DrrPatternBase { - void operator()(pir::drr::DrrPatternContext *ctx) const override { + : public paddle::drr::DrrPatternBase { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = pat.Op( "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0")); @@ -126,9 +126,9 @@ class RemoveRedundentCastPattern }; class RemoveUselessCastPattern - : public pir::drr::DrrPatternBase { + : public paddle::drr::DrrPatternBase { public: - void operator()(pir::drr::DrrPatternContext *ctx) const override { + void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0")); pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype()); From 4cb084c3af2af42f567e5c24aa74fa08b2a2d21b Mon Sep 17 00:00:00 2001 From: Android zhang <53324261+zade23@users.noreply.github.com> Date: Fri, 29 Dec 2023 10:27:31 +0800 Subject: [PATCH 038/142] =?UTF-8?q?=E3=80=90CMake=20opt=20No.2=E3=80=91rm?= =?UTF-8?q?=20some=20DEPS=20of=20`test/cpp/auto=5Fparallel/CMakeLists.txt`?= =?UTF-8?q?=20(#60348)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update CMakeLists.txt * fix conflict * add DEPS --- test/cpp/auto_parallel/CMakeLists.txt | 50 ++++++--------------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt index 5911712dffdf26..311958d2e10310 100644 --- a/test/cpp/auto_parallel/CMakeLists.txt +++ b/test/cpp/auto_parallel/CMakeLists.txt @@ -15,49 +15,21 @@ if(WITH_DISTRIBUTE) SRCS dist_tensor_test.cc DEPS phi common) - paddle_test( - spmd_rule_test - SRCS - spmd_rule_test.cc - DEPS - spmd_rule_test_util - spmd_rules - phi) - paddle_test( - softmax_grad_spmd_rule_test - SRCS - softmax_grad_spmd_rule_test.cc - DEPS - spmd_rule_test_util - spmd_rules - phi) + paddle_test(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule_test_util) - paddle_test( - tile_spmd_rule_test - SRCS - tile_spmd_rule_test.cc - DEPS - spmd_rule_test_util - spmd_rules - phi) + paddle_test(softmax_grad_spmd_rule_test SRCS softmax_grad_spmd_rule_test.cc + DEPS spmd_rule_test_util) - paddle_test( - fused_linear_param_grad_add_spmd_rule_test - SRCS - fused_linear_param_grad_add_spmd_rule_test.cc - DEPS - spmd_rule_test_util - spmd_rules - phi) + paddle_test(tile_spmd_rule_test SRCS tile_spmd_rule_test.cc DEPS + spmd_rule_test_util) paddle_test( - cross_entropy_softmax_spmd_rule_test - SRCS - cross_entropy_softmax_spmd_rule_test.cc - DEPS - spmd_rule_test_util - spmd_rules - phi) + fused_linear_param_grad_add_spmd_rule_test SRCS + fused_linear_param_grad_add_spmd_rule_test.cc DEPS spmd_rule_test_util) + + paddle_test(cross_entropy_softmax_spmd_rule_test SRCS + cross_entropy_softmax_spmd_rule_test.cc DEPS spmd_rule_test_util) + endif() cc_test( From 7909f768418b767708e14fa1c5bf6685b66e0ce4 Mon Sep 17 00:00:00 2001 From: chen2016013 <111894720+chen2016013@users.noreply.github.com> Date: Fri, 29 Dec 2023 10:31:06 +0800 Subject: [PATCH 039/142] [PIR] open test in test_ifelse for PIR (#60372) * open test * update * update * update as comment --- paddle/common/ddim.cc | 2 +- .../dialect/operator/ir/control_flow_op.cc | 18 +++++--- test/dygraph_to_static/ifelse_simple_func.py | 34 +++++++++++++-- test/dygraph_to_static/test_ifelse.py | 42 ++++++++++++++----- 4 files changed, 76 insertions(+), 20 deletions(-) diff --git a/paddle/common/ddim.cc b/paddle/common/ddim.cc index 7394dd03bfd8d2..75eb1423cce8a4 100644 --- a/paddle/common/ddim.cc +++ b/paddle/common/ddim.cc @@ -267,7 +267,7 @@ DDim DDim::transpose(const std::vector& axis) const { DDim ComputeCompatibleDim(const DDim& dim1, const DDim& dim2) { IR_ENFORCE(dim1.size() == dim2.size(), - "Does not support rank inconsistency: dim1=%d, dim2=%d", + "Does not support rank inconsistency: rank1=%d, rank2=%d", dim1.size(), dim2.size()); std::vector result; diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 040fbb28377115..30d5ce5a1b685e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -102,12 +102,18 @@ void IfOp::Build(pir::Builder &builder, // NOLINT "The dtype in output[%d] of " "true_block&false_block must be equal.", i)); - PADDLE_ENFORCE_EQ(l_type.data_layout(), - r_type.data_layout(), - phi::errors::PreconditionNotMet( - "The date_layout in output[%d] of " - "true_block&false_block must be equal.", - i)); + if (l_type.data_layout() != phi::DataLayout::UNDEFINED && + r_type.data_layout() != phi::DataLayout::UNDEFINED) { + PADDLE_ENFORCE_EQ( + l_type.data_layout(), + r_type.data_layout(), + phi::errors::PreconditionNotMet( + "The data_layout in output[%d] of " + "true_block (%s) & false_block (%s) must be equal.", + i, + l_type.data_layout(), + r_type.data_layout())); + } PADDLE_ENFORCE_EQ(l_type.lod(), r_type.lod(), phi::errors::PreconditionNotMet( diff --git a/test/dygraph_to_static/ifelse_simple_func.py b/test/dygraph_to_static/ifelse_simple_func.py index d7767a3cfbefb9..b011989fed709a 100644 --- a/test/dygraph_to_static/ifelse_simple_func.py +++ b/test/dygraph_to_static/ifelse_simple_func.py @@ -60,9 +60,12 @@ def dyfunc_with_if_else2(x, col=100): # TODO: Don't support return non-Tensor in Tensor-dependent `if` statement currently. # `x` is Tensor, `col` is not Tensor, and `col` is the return value of `true_fn` after transformed. # col = -1 - col = paddle.tensor.fill_constant(shape=[1], value=-1, dtype="int64") + col = paddle.tensor.fill_constant(shape=[], value=-1, dtype="int64") + else: + col = paddle.tensor.fill_constant(shape=[], value=1, dtype="int64") if paddle.mean(x).numpy() > x.numpy()[row][col]: - y = paddle.nn.functional.relu(x) + x_pow = paddle.pow(x, 2) + y = paddle.nn.functional.relu(x_pow) else: x_pow = paddle.pow(x, 2) y = paddle.tanh(x_pow) @@ -100,9 +103,12 @@ def false_fn_0(q, x, y): x = x + 1 z = x + 2 q = x + 3 + m = x + 2 + n = x + 3 else: y = y + 1 z = x - 2 + q = x + 3 m = x + 2 n = x + 3 @@ -165,6 +171,22 @@ def nested_if_else(x_v): tmp = y * w y = paddle.nn.functional.relu(tmp) if paddle.mean(y).numpy() < batch_size: + tmp = paddle.tensor.fill_constant( + y.shape, dtype='float32', value=-1 + ) + y = paddle.abs(y) + else: + tmp = paddle.tensor.fill_constant( + y.shape, dtype='float32', value=-1 + ) + y = y - tmp + else: + tmp = y * w + y = paddle.nn.functional.relu(tmp) + if paddle.mean(y).numpy() < batch_size: + tmp = paddle.tensor.fill_constant( + y.shape, dtype='float32', value=-1 + ) y = paddle.abs(y) else: tmp = paddle.tensor.fill_constant( @@ -173,6 +195,11 @@ def nested_if_else(x_v): y = y - tmp else: y = x_v - bias + w = paddle.tensor.fill_constant([feat_size], dtype='float32', value=10) + tmp = y * w + y = paddle.nn.functional.relu(tmp) + tmp = paddle.tensor.fill_constant(y.shape, dtype='float32', value=-1) + y = paddle.abs(y) return y @@ -223,12 +250,14 @@ def nested_if_else_3(x): ) # `z` is created in above code block. z = y + 1 + out = x - 1 else: res = paddle.tensor.fill_constant( value=3, shape=x.shape, dtype="int32" ) # `out` is a new var. out = x + 1 + z = y - 1 return res @@ -378,7 +407,6 @@ def __init__(self): def if_tensor_case(x): x = base.dygraph.to_variable(x) - mean = paddle.mean(x) # It is equivalent to `if mean != 0` if mean: diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 5f50780597e814..7f2262fca3ea64 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -22,6 +22,7 @@ disable_test_case, enable_to_static_guard, test_ast_only, + test_legacy_and_pt_and_pir, test_legacy_only, ) from ifelse_simple_func import ( @@ -69,6 +70,7 @@ def setUp(self): self.error = "Your if/else have different number of return value." @test_ast_only + @test_legacy_and_pt_and_pir def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -76,6 +78,13 @@ def test_error(self): self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x)) +class TestDy2StIfElseRetInt2(TestDy2staticException): + def setUp(self): + self.x = np.random.random([5]).astype('float32') + self.error = "Your if/else have different number of return value." + self.dyfunc = dyfunc_ifelse_ret_int2 + + class TestDygraphIfElse(Dy2StTestBase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') @@ -104,7 +113,6 @@ def setUp(self): # TODO(dev): fix AST mode @disable_test_case((ToStaticMode.AST, IrMode.PT)) - @disable_test_case((ToStaticMode.AST, IrMode.LEGACY_IR)) def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -143,6 +151,10 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = dyfunc_with_if_else_with_list_generator + @test_legacy_and_pt_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphNestedIfElse(Dy2StTestBase): def setUp(self): @@ -172,6 +184,10 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else_2 + @test_legacy_and_pt_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphNestedIfElse3(Dy2StTestBase): def setUp(self): @@ -269,12 +285,20 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = if_with_and_or_2 + @test_legacy_and_pt_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphIfElseWithAndOr3(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = if_with_and_or_3 + @test_legacy_and_pt_and_pir + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphIfElseWithAndOr4(TestDygraphIfElse): def setUp(self): @@ -439,11 +463,14 @@ def init_net(self): def _run(self, mode, to_static): with enable_to_static_guard(to_static): - net = paddle.jit.to_static(self.Net(mode)) + if to_static: + net = paddle.jit.to_static(self.Net(mode)) + else: + net = self.Net(mode) ret = net(self.x, self.y) - return ret.numpy() + @test_legacy_and_pt_and_pir def test_train_mode(self): self.assertTrue( ( @@ -452,6 +479,7 @@ def test_train_mode(self): ).all() ) + @test_legacy_and_pt_and_pir def test_infer_mode(self): self.assertTrue( ( @@ -467,6 +495,7 @@ def init_net(self): class TestNewVarCreateInOneBranch(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_var_used_in_another_for(self): def case_func(training): # targets and targets_list is dynamically defined by training @@ -510,13 +539,6 @@ def test_ast_to_func(self): self.assertIsInstance(self.out[1], int) -class TestDy2StIfElseRetInt2(TestDy2staticException): - def setUp(self): - self.x = np.random.random([5]).astype('float32') - self.error = "Your if/else have different number of return value." - self.dyfunc = dyfunc_ifelse_ret_int2 - - class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1): def setUp(self): self.x = np.random.random([5]).astype('float32') From 23bf65ac06647b3a92daaac1d2cefac4330d90f3 Mon Sep 17 00:00:00 2001 From: lanxianghit <47554610+lanxianghit@users.noreply.github.com> Date: Fri, 29 Dec 2023 10:33:50 +0800 Subject: [PATCH 040/142] add Get&Set APIs for value2sym_expr map (#60301) att, add Get&Set APIs for value2sym_expr map --- paddle/pir/dialect/shape/utils/shape_utils.cc | 12 ++++++++++++ paddle/pir/dialect/shape/utils/shape_utils.h | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index faa5c498bb1f94..05bbb76db8937c 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -175,4 +175,16 @@ std::string GetValueId(Value* val) { return "op_" + std::to_string(op_id) + "_rst_" + std::to_string(val_idx); } +const symbol::ShapeOrDataDimExprs& +ShapeConstraintIRAnalysis::GetShapeOrDataForValue(Value* val) { + auto val_id = GetValueId(val); + return value_id_to_shapeordata[val_id]; +} + +void ShapeConstraintIRAnalysis::SetShapeOrDataForValue( + Value* val, const symbol::ShapeOrDataDimExprs& shape_or_data) { + auto val_id = GetValueId(val); + value_id_to_shapeordata[val_id] = shape_or_data; +} + } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index ac72c0bae88c78..8f383f3ad6e05a 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -80,6 +80,11 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { return "S" + std::to_string(next_sym_idx_++); } + const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value* val); + + void SetShapeOrDataForValue(Value* val, + const symbol::ShapeOrDataDimExprs& shape_or_data); + // const symbol::ShapeOrData& GetShapeOrDataForValue() const; symbol::DimExprBuilder CreateDimExprBuilder() override; From f02f3d6c4f80aba59076b66ef7b87e572a157b8a Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 29 Dec 2023 10:53:58 +0800 Subject: [PATCH 041/142] adapt pir api (#60416) --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 3 + paddle/fluid/pybind/pir.cc | 12 ++ .../incubate/optimizer/functional/utils.py | 2 +- test/legacy_test/test_lr_scheduler.py | 32 ++++ test/legacy_test/test_lrn_op.py | 181 +++++++++--------- 6 files changed, 141 insertions(+), 90 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index d541f34a890dc2..79cbad13c0f56c 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -102,6 +102,7 @@ 'print', 'number_count', 'assign_value', + 'share_data', 'onednn_to_paddle_layout', ] diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 5bdcadc3cca03f..b926b055daa6a2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1153,6 +1153,9 @@ - op : share_data args : (Tensor x) output : Tensor(out) + infer_meta: + func: UnchangedInferMeta + param: [x] kernel: func: share_data param: [x] diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 9e87a3f39459df..e2471842c07291 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -777,6 +777,18 @@ void BindValue(py::module *m) { .def("apply", &apply) .def("is_same", &Value::operator==) .def("hash", [](Value self) { return std::hash{}(self); }) + .def("detach", + [](Value self) { + auto share_data_op = + ApiBuilder::Instance() + .GetBuilder() + ->Build(self); + auto out = share_data_op.out(); + out.set_attribute( + kAttrStopGradients, + BoolAttribute::get(pir::IrContext::Instance(), false)); + return out; + }) .def("__repr__", &Value2String); } diff --git a/python/paddle/incubate/optimizer/functional/utils.py b/python/paddle/incubate/optimizer/functional/utils.py index c6a6f1c6b405a3..6fce7ef1703f5d 100644 --- a/python/paddle/incubate/optimizer/functional/utils.py +++ b/python/paddle/incubate/optimizer/functional/utils.py @@ -23,7 +23,7 @@ def check_input_type(input, name, op_name): if not isinstance(input, paddle.Tensor): raise ValueError(f"The input: {input} must be tensor.") else: - check_type(input, name, Variable, op_name) + check_type(input, name, (Variable, paddle.pir.Value), op_name) def check_initial_inverse_hessian_estimate(H0): diff --git a/test/legacy_test/test_lr_scheduler.py b/test/legacy_test/test_lr_scheduler.py index 3db40ea291342c..1109d29fa3214c 100644 --- a/test/legacy_test/test_lr_scheduler.py +++ b/test/legacy_test/test_lr_scheduler.py @@ -1231,6 +1231,38 @@ def test_linear_warmp(self): natural_lr.step() natural_lr_warmup.step() + def test_pir_linear_warmup_lr(self): + params = { + 'learning_rate': 0.5, + 'warmup_steps': 10, + 'start_lr': 0, + 'end_lr': 0.5, + } + scheduler = paddle.optimizer.lr.LinearWarmup(**params) + adam = paddle.optimizer.Adam(learning_rate=scheduler) + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[3, 4, 5]) + loss = paddle.mean(x) + adam.minimize(loss) + lr_var = adam._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(5): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={'x': np.random.randn(3, 4, 5).astype('float32')}, + fetch_list=[lr_var], + ) + self.assertEqual( + out, np.array(linear_warmup_lr(epoch, **params)) + ) + scheduler.step() + if __name__ == '__main__': paddle.enable_static() diff --git a/test/legacy_test/test_lrn_op.py b/test/legacy_test/test_lrn_op.py index 34ceff298ec3d6..c97e8e7dd8536b 100644 --- a/test/legacy_test/test_lrn_op.py +++ b/test/legacy_test/test_lrn_op.py @@ -19,7 +19,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestLRNOp(OpTest): @@ -115,98 +116,96 @@ def setUp(self): self.places.append(base.CUDAPlace(0)) def check_static_3d_input(self, place): - with paddle_static_guard(): - with base.program_guard(base.Program(), base.Program()): - in_np1 = np.random.random([3, 40, 40]).astype("float32") - in_np2 = np.transpose(in_np1, (0, 2, 1)) - - input1 = paddle.static.data( - name="input1", shape=[3, 40, 40], dtype="float32" - ) - input2 = paddle.static.data( - name="input2", shape=[3, 40, 40], dtype="float32" - ) - res1 = paddle.nn.functional.local_response_norm( - x=input1, size=5, data_format='NCL' - ) - res2 = paddle.nn.functional.local_response_norm( - x=input2, size=5, data_format='NLC' - ) - exe = base.Executor(place) - fetches = exe.run( - base.default_main_program(), - feed={"input1": in_np1, "input2": in_np2}, - fetch_list=[res1, res2], - ) - - fetches1_tran = np.transpose(fetches[1], (0, 2, 1)) - np.testing.assert_allclose( - fetches[0], fetches1_tran, rtol=1e-05 - ) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + in_np1 = np.random.random([3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 1)) + + input1 = paddle.static.data( + name="input1", shape=[3, 40, 40], dtype="float32" + ) + input2 = paddle.static.data( + name="input2", shape=[3, 40, 40], dtype="float32" + ) + res1 = paddle.nn.functional.local_response_norm( + x=input1, size=5, data_format='NCL' + ) + res2 = paddle.nn.functional.local_response_norm( + x=input2, size=5, data_format='NLC' + ) + exe = base.Executor(place) + fetches = exe.run( + paddle.static.default_main_program(), + feed={"input1": in_np1, "input2": in_np2}, + fetch_list=[res1, res2], + ) + + fetches1_tran = np.transpose(fetches[1], (0, 2, 1)) + np.testing.assert_allclose(fetches[0], fetches1_tran, rtol=1e-05) def check_static_4d_input(self, place): - with paddle_static_guard(): - with base.program_guard(base.Program(), base.Program()): - input1 = paddle.static.data( - name="input1", shape=[3, 3, 40, 40], dtype="float32" - ) - input2 = paddle.static.data( - name="input2", shape=[3, 40, 40, 3], dtype="float32" - ) - - res1 = paddle.nn.functional.local_response_norm( - x=input1, size=5, data_format='NCHW' - ) - res2 = paddle.nn.functional.local_response_norm( - x=input2, size=5, data_format='NHWC' - ) - - in_np1 = np.random.random([3, 3, 40, 40]).astype("float32") - in_np2 = np.transpose(in_np1, (0, 2, 3, 1)) - - exe = base.Executor(place) - fetches = exe.run( - base.default_main_program(), - feed={"input1": in_np1, "input2": in_np2}, - fetch_list=[res1, res2], - ) - - fetches1_tran = np.transpose(fetches[1], (0, 3, 1, 2)) - np.testing.assert_allclose( - fetches[0], fetches1_tran, rtol=1e-05 - ) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input1 = paddle.static.data( + name="input1", shape=[3, 3, 40, 40], dtype="float32" + ) + input2 = paddle.static.data( + name="input2", shape=[3, 40, 40, 3], dtype="float32" + ) + + res1 = paddle.nn.functional.local_response_norm( + x=input1, size=5, data_format='NCHW' + ) + res2 = paddle.nn.functional.local_response_norm( + x=input2, size=5, data_format='NHWC' + ) + + in_np1 = np.random.random([3, 3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 3, 1)) + + exe = base.Executor(place) + fetches = exe.run( + paddle.static.default_main_program(), + feed={"input1": in_np1, "input2": in_np2}, + fetch_list=[res1, res2], + ) + + fetches1_tran = np.transpose(fetches[1], (0, 3, 1, 2)) + np.testing.assert_allclose(fetches[0], fetches1_tran, rtol=1e-05) def check_static_5d_input(self, place): - with paddle_static_guard(): - with base.program_guard(base.Program(), base.Program()): - input1 = paddle.static.data( - name="input1", shape=[3, 3, 3, 40, 40], dtype="float32" - ) - input2 = paddle.static.data( - name="input2", shape=[3, 3, 40, 40, 3], dtype="float32" - ) - res1 = paddle.nn.functional.local_response_norm( - x=input1, size=5, data_format='NCDHW' - ) - res2 = paddle.nn.functional.local_response_norm( - x=input2, size=5, data_format='NDHWC' - ) - - in_np1 = np.random.random([3, 3, 3, 40, 40]).astype("float32") - in_np2 = np.transpose(in_np1, (0, 2, 3, 4, 1)) - - exe = base.Executor(place) - fetches = exe.run( - base.default_main_program(), - feed={"input1": in_np1, "input2": in_np2}, - fetch_list=[res1, res2], - ) - - fetches1_tran = np.transpose(fetches[1], (0, 4, 1, 2, 3)) - np.testing.assert_allclose( - fetches[0], fetches1_tran, rtol=1e-05 - ) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input1 = paddle.static.data( + name="input1", shape=[3, 3, 3, 40, 40], dtype="float32" + ) + input2 = paddle.static.data( + name="input2", shape=[3, 3, 40, 40, 3], dtype="float32" + ) + res1 = paddle.nn.functional.local_response_norm( + x=input1, size=5, data_format='NCDHW' + ) + res2 = paddle.nn.functional.local_response_norm( + x=input2, size=5, data_format='NDHWC' + ) + + in_np1 = np.random.random([3, 3, 3, 40, 40]).astype("float32") + in_np2 = np.transpose(in_np1, (0, 2, 3, 4, 1)) + + exe = base.Executor(place) + fetches = exe.run( + paddle.static.default_main_program(), + feed={"input1": in_np1, "input2": in_np2}, + fetch_list=[res1, res2], + ) + + fetches1_tran = np.transpose(fetches[1], (0, 4, 1, 2, 3)) + np.testing.assert_allclose(fetches[0], fetches1_tran, rtol=1e-05) + @test_with_pir_api def test_static(self): with paddle_static_guard(): for place in self.places: @@ -276,9 +275,12 @@ def test_dygraph(self): class TestLocalResponseNormFAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): with paddle_static_guard(): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): def test_Variable(): # the input of lrn must be Variable. @@ -346,6 +348,7 @@ def test_dygraph(self): res2_tran = np.transpose(res2.numpy(), (0, 3, 1, 2)) np.testing.assert_allclose(res1.numpy(), res2_tran, rtol=1e-05) + @test_with_pir_api def test_static_fp16_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) From 4ac428075bfa169a35a42a4b2b4e10dc78f913c5 Mon Sep 17 00:00:00 2001 From: ming1753 <61511741+ming1753@users.noreply.github.com> Date: Fri, 29 Dec 2023 11:34:44 +0800 Subject: [PATCH 042/142] support optimized update shape_range_info_path (#60457) --- .../inference/analysis/ir_passes/tensorrt_subgraph_pass.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 2cabfe567b5d97..851e7863d7af25 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -193,6 +193,12 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( std::vector repetitive_params; std::vector engine_names; for (auto *node : graph->Nodes()) { + // load optimized model may update shape_range_info_path + auto shape_range_info_path = Get("trt_shape_range_info_path"); + if (node->IsOp() && node->Op()->Type() == "tensorrt_engine" && + !shape_range_info_path.empty()) { + node->Op()->SetAttr("shape_range_info_path", shape_range_info_path); + } if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) { engine_names.push_back(CreateTensorRTOp( node, graph, graph_param_names, &repetitive_params, use_cuda_graph)); From 567cd55d9b303d2c83a4c2b409f4d63515200c02 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Fri, 29 Dec 2023 11:50:53 +0800 Subject: [PATCH 043/142] remove assert for sharding and mp hybrid parallel. (#60455) --- .../dygraph_optimizer/dygraph_sharding_optimizer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index b6b4c3c01842f6..605c08039d534e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -537,9 +537,6 @@ def __init__(self, optimizer, hcg): self.pp_overlap = pp_config.sharding_comm_overlap self.pp_release_grads = pp_config.release_gradients - # TODO(liuzhenhai):support it latter - assert not self.comm_overlap, "not supported yet" - self._build_comm_buffers(acc_steps) # NOTE(shenliang03): Sort the comm_buffers by dst rank, # it will improve the performance in reduce communicate. Default From 21ee5780d519d1c21c8bf0bc3d12b34fb3c70d1c Mon Sep 17 00:00:00 2001 From: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> Date: Fri, 29 Dec 2023 12:37:17 +0800 Subject: [PATCH 044/142] bug fix (#60461) --- paddle/pir/dialect/shape/utils/shape_utils.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 8f383f3ad6e05a..7e4eafa6722763 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -105,6 +105,9 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { int64_t next_sym_idx_ = 0; std::vector constraints_; + std::unordered_map + value_id_to_shapeordata; + public: explicit ShapeConstraintIRAnalysis(std::shared_ptr&& program) : ShapeConstraintIRAnalysis(program->module_op()) { From c15a6d3d44fbdda97b6e113f3a33797370c8d720 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Fri, 29 Dec 2023 13:59:47 +0800 Subject: [PATCH 045/142] [XPU] avoid pre-allocating gm buffer (#60387) --- paddle/fluid/distributed/collective/process_group_bkcl.cc | 2 ++ paddle/phi/backends/xpu/xpu_context.cc | 3 +++ 2 files changed, 5 insertions(+) diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index 8b306e29f52b32..cdc31cf9a64890 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -207,6 +207,8 @@ void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, platform::DeviceContextPool::Instance().Get(place)); // must use XPUDeviceContext here to make sure XPUContext::Init() is called auto comm_ctx = std::make_unique(place); + // comm_ctx does not require a pre-allocated GM buffer + comm_ctx->x_context()->set_option("XPUAPI_DEFAULT_SIZE", "1"); auto bkcl_comm_ctx = this->GetCommContext(); comm_ctx->SetBkclContext(bkcl_comm_ctx->GetBKCLComm()); diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc index ad0047b4e9ad63..9de9744393d4a5 100644 --- a/paddle/phi/backends/xpu/xpu_context.cc +++ b/paddle/phi/backends/xpu/xpu_context.cc @@ -200,6 +200,9 @@ struct XPUContext::Impl { << tname << " currently " << context_map_.size() << " contexts existing"; xpu::Context* ctx_t = xpu::create_context(); + // DataLoader does not require a pre-allocated GM buffer + // to avoid xpu_wait calls + ctx_t->set_option("XPUAPI_DEFAULT_SIZE", "1"); context_map_[tname] = ctx_t; } } From 16710f72e6696a2c45afac52360e4e21f05b047b Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 29 Dec 2023 14:08:21 +0800 Subject: [PATCH 046/142] [PIR] Fix some pir interpreter bug and refine some code (#60420) * fix * fix * fix * fix * fix * fix * fix --- .../{ => control_flow}/assert_instruction.cc | 2 +- .../{ => control_flow}/assert_instruction.h | 0 .../has_elements_instruction.cc | 2 +- .../has_elements_instruction.h | 0 .../{ => control_flow}/if_instruction.cc | 2 +- .../{ => control_flow}/if_instruction.h | 0 .../select_input_instruction.cc | 2 +- .../select_input_instruction.h | 0 .../tuple_pop_instruction.cc | 2 +- .../tuple_pop_instruction.h | 0 .../tuple_push_instruction.cc | 2 +- .../tuple_push_instruction.h | 0 .../{ => control_flow}/while_instruction.cc | 2 +- .../{ => control_flow}/while_instruction.h | 0 .../framework/new_executor/pir_interpreter.cc | 14 ++++---- .../translator/program_translator.cc | 32 ++++++++++++++++++- .../translator/program_translator.h | 2 ++ test/dygraph_to_static/test_for_enumerate.py | 3 -- test/legacy_test/test_cond.py | 7 +++- 19 files changed, 53 insertions(+), 19 deletions(-) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/assert_instruction.cc (97%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/assert_instruction.h (100%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/has_elements_instruction.cc (96%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/has_elements_instruction.h (100%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/if_instruction.cc (99%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/if_instruction.h (100%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/select_input_instruction.cc (98%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/select_input_instruction.h (100%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/tuple_pop_instruction.cc (98%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/tuple_pop_instruction.h (100%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/tuple_push_instruction.cc (97%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/tuple_push_instruction.h (100%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/while_instruction.cc (99%) rename paddle/fluid/framework/new_executor/instruction/{ => control_flow}/while_instruction.h (100%) diff --git a/paddle/fluid/framework/new_executor/instruction/assert_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc similarity index 97% rename from paddle/fluid/framework/new_executor/instruction/assert_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc index 96d1fcc57b9438..d2835dd65ccad1 100644 --- a/paddle/fluid/framework/new_executor/instruction/assert_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/instruction/assert_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" diff --git a/paddle/fluid/framework/new_executor/instruction/assert_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/assert_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/has_elements_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.cc similarity index 96% rename from paddle/fluid/framework/new_executor/instruction/has_elements_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.cc index 958daf2239eaf0..900667071091b3 100644 --- a/paddle/fluid/framework/new_executor/instruction/has_elements_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" diff --git a/paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc similarity index 99% rename from paddle/fluid/framework/new_executor/instruction/if_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc index 57146acdfb5df5..ef856c7fc01627 100644 --- a/paddle/fluid/framework/new_executor/instruction/if_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" diff --git a/paddle/fluid/framework/new_executor/instruction/if_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/if_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.cc similarity index 98% rename from paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.cc index 893915f841d7fc..987edeb97eda02 100644 --- a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" diff --git a/paddle/fluid/framework/new_executor/instruction/select_input_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/select_input_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc similarity index 98% rename from paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc index a3a8f4461865e7..1cb27abb3e2d92 100644 --- a/paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc @@ -14,7 +14,7 @@ #include -#include "paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" diff --git a/paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.cc similarity index 97% rename from paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.cc index bb01125bf3ecaf..3f0082a4af5c8f 100644 --- a/paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" diff --git a/paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc similarity index 99% rename from paddle/fluid/framework/new_executor/instruction/while_instruction.cc rename to paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc index b281e2b8a6cbe4..a9f23fd60e176f 100644 --- a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/new_executor/instruction/while_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h similarity index 100% rename from paddle/fluid/framework/new_executor/instruction/while_instruction.h rename to paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 82bf2973345ad5..2afdfb5e9717ad 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -48,16 +48,16 @@ #include "paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h" #endif -#include "paddle/fluid/framework/new_executor/instruction/assert_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/builtin_combine_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/has_elements_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/if_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/assert_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/has_elements_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/select_input_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/select_input_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/tuple_pop_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/tuple_push_instruction.h" -#include "paddle/fluid/framework/new_executor/instruction/while_instruction.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 5eaaa5052f457f..7eca5767750b9a 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -131,6 +131,8 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, void ProgramTranslator::Translate() { GetParameterForSingleBlock(legacy_program_->Block(0)); + InsertDataOpForSingleBlock(legacy_program_->Block(0)); + TranslateBlock(legacy_program_->Block(0), 0, legacy_program_->Block(0).OpSize(), @@ -155,7 +157,7 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, uint64_t end_id, TranslationContext* translation_ctx, pir::Block* dst_block) { - VLOG(8) << "=============>start to translate a block"; + VLOG(8) << "=============>start to translate a block: " << &src_block; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), platform::errors::NotFound( @@ -419,6 +421,34 @@ inline pir::Operation* InsertSetParamaterOp(pir::IrContext* ctx, return operation; } +void ProgramTranslator::InsertDataOpForSingleBlock(const BlockDesc& block) { + std::unordered_set all_var_names; + for (auto& var : block.AllVars()) { + all_var_names.insert(var->Name()); + } + + std::unordered_set inner_outputs; + for (auto op_desc : block.AllOps()) { + for (const auto& n : op_desc->Inputs()) { + const auto& input_var_names = n.second; + for (const auto& var_name : input_var_names) { + if (param_map_.count(var_name) != 0) continue; + if (no_cast_var_names.count(var_name) != 0) continue; + if (all_var_names.count(var_name) == 0) continue; + if (inner_outputs.count(var_name) == 0) { + CreateUndefinedVariable(var_name, block); + } + } + } + for (const auto& n : op_desc->Outputs()) { + const auto& output_var_names = n.second; + for (const auto& var_name : output_var_names) { + inner_outputs.insert(var_name); + } + } + } +} + void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { for (auto& var : block.AllVars()) { if (!var->Persistable()) continue; diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index 5fce6b08c2648b..cff7684226c520 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -120,6 +120,8 @@ class ProgramTranslator { void TranslateGeneralOperation(const OpDesc* src_op, TranslationContext* translation_ctx, pir::Block* dst_block); + + void InsertDataOpForSingleBlock(const BlockDesc& block); void GetParameterForSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block); diff --git a/test/dygraph_to_static/test_for_enumerate.py b/test/dygraph_to_static/test_for_enumerate.py index a540cef2e387bb..7b754fb1343eae 100644 --- a/test/dygraph_to_static/test_for_enumerate.py +++ b/test/dygraph_to_static/test_for_enumerate.py @@ -19,7 +19,6 @@ import numpy as np from dygraph_to_static_utils import ( Dy2StTestBase, - compare_legacy_with_pt, enable_to_static_guard, test_legacy_and_pt_and_pir, ) @@ -495,7 +494,6 @@ class TestForIterVarList(TestForInRangeConfig): def set_test_func(self): self.dygraph_func = for_iter_var_list - @compare_legacy_with_pt def test_transformed_result_compare(self): self.set_test_func() self.transformed_result_compare() @@ -505,7 +503,6 @@ class TestForEnumerateVarList(TestForInRangeConfig): def set_test_func(self): self.dygraph_func = for_enumerate_var_list - @compare_legacy_with_pt def test_transformed_result_compare(self): self.set_test_func() self.transformed_result_compare() diff --git a/test/legacy_test/test_cond.py b/test/legacy_test/test_cond.py index 1323d7caa6eaec..3dcd127e51c4b8 100644 --- a/test/legacy_test/test_cond.py +++ b/test/legacy_test/test_cond.py @@ -32,6 +32,7 @@ class TestCondInputOutput(unittest.TestCase): @compare_legacy_with_pt + @test_with_pir_api def test_return_single_var(self): """ pseudocode: @@ -73,7 +74,11 @@ def false_func(): else base.CPUPlace() ) exe = base.Executor(place) - (ret,) = exe.run(main_program, fetch_list=[out.name]) + if paddle.framework.in_pir_mode(): + (ret,) = exe.run(main_program, fetch_list=[out]) + else: + (ret,) = exe.run(main_program, fetch_list=[out.name]) + np.testing.assert_allclose( np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05 ) From 5828223b6f13ced50b36439ec26248f29bb58d99 Mon Sep 17 00:00:00 2001 From: enzodechine Date: Fri, 29 Dec 2023 14:47:12 +0800 Subject: [PATCH 047/142] correct the unittest for bf16 op (#60415) --- .../legacy/xpu/elementwise_divide_kernel.cc | 1 + .../legacy/xpu/elementwise_multiply_kernel.cc | 1 + .../legacy/xpu/elementwise_subtract_kernel.cc | 1 + test/xpu/test_elementwise_div_op_xpu.py | 354 ++++-------------- test/xpu/test_elementwise_mul_op_xpu.py | 205 +++++----- test/xpu/test_elementwise_sub_op_xpu.py | 175 ++++----- test/xpu/test_reduce_sum_op_xpu.py | 39 +- 7 files changed, 278 insertions(+), 498 deletions(-) diff --git a/paddle/phi/kernels/legacy/xpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/legacy/xpu/elementwise_divide_kernel.cc index 5318cb464001f8..ccdfcd750f091d 100644 --- a/paddle/phi/kernels/legacy/xpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/elementwise_divide_kernel.cc @@ -50,4 +50,5 @@ PD_REGISTER_KERNEL(divide_raw, ALL_LAYOUT, phi::DivideRawKernel, phi::dtype::float16, + phi::dtype::bfloat16, float) {} diff --git a/paddle/phi/kernels/legacy/xpu/elementwise_multiply_kernel.cc b/paddle/phi/kernels/legacy/xpu/elementwise_multiply_kernel.cc index 790bd72b240914..2986e555cda705 100644 --- a/paddle/phi/kernels/legacy/xpu/elementwise_multiply_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/elementwise_multiply_kernel.cc @@ -50,6 +50,7 @@ PD_REGISTER_KERNEL(multiply_raw, ALL_LAYOUT, phi::MultiplyRawKernel, phi::dtype::float16, + phi::dtype::bfloat16, float, int, int64_t) {} diff --git a/paddle/phi/kernels/legacy/xpu/elementwise_subtract_kernel.cc b/paddle/phi/kernels/legacy/xpu/elementwise_subtract_kernel.cc index 421a30a240a434..7fb4144d7705bc 100644 --- a/paddle/phi/kernels/legacy/xpu/elementwise_subtract_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/elementwise_subtract_kernel.cc @@ -45,4 +45,5 @@ PD_REGISTER_KERNEL(subtract_raw, phi::SubtractRawKernel, float, phi::dtype::float16, + phi::dtype::bfloat16, int64_t) {} diff --git a/test/xpu/test_elementwise_div_op_xpu.py b/test/xpu/test_elementwise_div_op_xpu.py index 52e2e62e067d2e..ca190b7eb12307 100644 --- a/test/xpu/test_elementwise_div_op_xpu.py +++ b/test/xpu/test_elementwise_div_op_xpu.py @@ -20,7 +20,10 @@ create_test_class, get_xpu_op_support_types, ) -from op_test import skip_check_grad_ci +from op_test import ( + convert_float_to_uint16, + skip_check_grad_ci, +) from op_test_xpu import XPUOpTest import paddle @@ -28,6 +31,8 @@ paddle.enable_static() +INT_GROUP = [np.int32, np.int64] + class XPUTestElementwiseDivOp(XPUOpTestWrapper): def __init__(self): @@ -40,6 +45,7 @@ def setUp(self): self.dtype = self.in_type self.init_dtype() self.use_xpu = True + self.init_shape() self.init_input_output() """ Warning CPU gradient check error! @@ -47,20 +53,40 @@ def setUp(self): 'Y': np.random.random((32,84)).astype("float32") """ + def gen_data_depend_on_dtype(self, shape): + if self.dtype in INT_GROUP: + return np.random.randint(1, 100, size=shape) + else: + return np.random.uniform(-1, 1, size=shape) + + def reshape_y_depend_on_x(self): + if len(self.x_shape) <= len(self.y_shape) or self.y_shape == (): + return self.y + reshape_dims = [ + 1 if i not in self.y_shape else i for i in self.x_shape + ] + return np.reshape(self.y, reshape_dims) + def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: + self.x = self.gen_data_depend_on_dtype(self.x_shape) + self.y = self.gen_data_depend_on_dtype(self.y_shape) + reshaped_y = self.reshape_y_depend_on_x() + if self.dtype == np.uint16: + self.outputs = {'Out': np.divide(self.x, reshaped_y)} self.inputs = { - 'X': np.random.randint(1, 100, [13, 17]).astype(self.dtype), - 'Y': np.random.randint(1, 100, [13, 17]).astype(self.dtype), + 'X': convert_float_to_uint16(self.x), + 'Y': convert_float_to_uint16(self.y), } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} else: self.inputs = { - 'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype), + 'X': self.x.astype(self.dtype), + 'Y': self.y.astype(self.dtype), } + reshaped_y.astype(self.dtype) self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) + 'Out': self.inputs['X'] // reshaped_y + if self.dtype in INT_GROUP + else np.divide(self.inputs['X'], reshaped_y) } def test_check_output(self): @@ -100,306 +126,80 @@ def test_check_grad_ingore_y(self): def init_dtype(self): pass + def init_shape(self): + self.x_shape = [13, 17] + self.y_shape = [13, 17] + class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, []).astype(self.dtype), - 'Y': np.random.randint(1, 100, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(-1, 1, []).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']} + def init_shape(self): + self.x_shape = [] + self.y_shape = [] class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [13, 17]).astype(self.dtype), - 'Y': np.random.randint(1, 100, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']} + def init_shape(self): + self.x_shape = [13, 17] + self.y_shape = [] @skip_check_grad_ci( reason="[skip shape check] Use y_shape(1) to test broadcast." ) class TestElementwiseDivOp_scalar(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [20, 3, 4]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [1]).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [20, 3, 4]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [1]).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']} + def init_shape(self): + self.x_shape = [20, 3, 4] + self.y_shape = [1] class TestElementwiseDivOp_Vector(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [100]).astype(self.dtype), - 'Y': np.random.randint(1, 100, [100]).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) - } + def init_shape(self): + self.x_shape = [100] + self.y_shape = [100] class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [100, 3, 4]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': self.inputs['X'] - // self.inputs['Y'].reshape(100, 1, 1) - } - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [100, 3, 4]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': np.divide( - self.inputs['X'], self.inputs['Y'].reshape(100, 1, 1) - ) - } - + def init_shape(self): + self.x_shape = [100, 3, 4] + self.y_shape = [100] self.attrs = {'axis': 0} class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [2, 100, 4]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': self.inputs['X'] - // self.inputs['Y'].reshape(1, 100, 1) - } - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 100, 4]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': np.divide( - self.inputs['X'], self.inputs['Y'].reshape(1, 100, 1) - ) - } - + def init_shape(self): + self.x_shape = [2, 100, 4] + self.y_shape = [100] self.attrs = {'axis': 1} class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [2, 3, 100]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': self.inputs['X'] - // self.inputs['Y'].reshape(1, 1, 100) - } - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [100]).astype(self.dtype), - } - self.outputs = { - 'Out': np.divide( - self.inputs['X'], self.inputs['Y'].reshape(1, 1, 100) - ) - } + def init_shape(self): + self.x_shape = [2, 3, 100] + self.y_shape = [100] class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [2, 10, 12, 5]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [10, 12]).astype(self.dtype), - } - self.outputs = { - 'Out': self.inputs['X'] - // self.inputs['Y'].reshape(1, 10, 12, 1) - } - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 10, 12, 5]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [10, 12]).astype(self.dtype), - } - self.outputs = { - 'Out': np.divide( - self.inputs['X'], self.inputs['Y'].reshape(1, 10, 12, 1) - ) - } - + def init_shape(self): + self.x_shape = [2, 10, 12, 5] + self.y_shape = [10, 12] self.attrs = {'axis': 1} class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [2, 3, 50]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [2, 1, 50]).astype( - self.dtype - ), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 3, 50]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [2, 1, 50]).astype( - self.dtype - ), - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) - } + def init_shape(self): + self.x_shape = [2, 3, 50] + self.y_shape = [2, 1, 50] class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [2, 3, 4, 20]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [2, 3, 1, 20]).astype( - self.dtype - ), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 3, 4, 20]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [2, 3, 1, 20]).astype( - self.dtype - ), - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) - } + def init_shape(self): + self.x_shape = [2, 3, 4, 20] + self.y_shape = [2, 3, 1, 20] class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [2, 3, 100]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [1, 1, 100]).astype( - self.dtype - ), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [1, 1, 100]).astype( - self.dtype - ), - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) - } + def init_shape(self): + self.x_shape = [2, 3, 100] + self.y_shape = [1, 1, 100] class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [30, 3, 1, 5]).astype( - self.dtype - ), - 'Y': np.random.randint(1, 100, [30, 1, 4, 1]).astype( - self.dtype - ), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [30, 3, 1, 5]).astype( - self.dtype - ), - 'Y': np.random.uniform(0.1, 1, [30, 1, 4, 1]).astype( - self.dtype - ), - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) - } + def init_shape(self): + self.x_shape = [30, 3, 1, 5] + self.y_shape = [30, 1, 4, 1] class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp): - def init_input_output(self): - if self.dtype == np.int32 or self.dtype == np.int64: - self.inputs = { - 'X': np.random.randint(1, 100, [10, 12]).astype(self.dtype), - 'Y': np.random.randint(1, 100, [2, 3, 10, 12]).astype( - self.dtype - ), - } - self.outputs = {'Out': self.inputs['X'] // self.inputs['Y']} - else: - self.inputs = { - 'X': np.random.uniform(0.1, 1, [10, 12]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [2, 3, 10, 12]).astype( - self.dtype - ), - } - self.outputs = { - 'Out': np.divide(self.inputs['X'], self.inputs['Y']) - } - + def init_shape(self): + self.x_shape = [10, 12] + self.y_shape = [2, 3, 10, 12] self.attrs = {'axis': 2} class TestElementwiseDivBroadcast(unittest.TestCase): diff --git a/test/xpu/test_elementwise_mul_op_xpu.py b/test/xpu/test_elementwise_mul_op_xpu.py index 6bd604df07e40a..b8fda9a5b6217a 100644 --- a/test/xpu/test_elementwise_mul_op_xpu.py +++ b/test/xpu/test_elementwise_mul_op_xpu.py @@ -20,7 +20,10 @@ create_test_class, get_xpu_op_support_types, ) -from op_test import OpTest, skip_check_grad_ci +from op_test import ( + convert_float_to_uint16, + skip_check_grad_ci, +) from op_test_xpu import XPUOpTest import paddle @@ -40,13 +43,34 @@ def init_kernel_type(self): def setUp(self): self.op_type = 'elementwise_mul' self.use_xpu = True + self.cal_x = None + self.cal_y = None self.dtype = self.in_type self.axis = -1 - self.init_dtype() + self.init_data() + self.gen_output() self.init_input_output() self.init_kernel_type() self.init_axis() + def gen_output(self): + if self.cal_x is None: + self.cal_x = self.x + if self.cal_y is None: + self.cal_y = self.y + if self.dtype == np.uint16: + self.out = np.multiply(self.cal_x, self.cal_y) + else: + self.out = np.multiply( + self.cal_x.astype(self.dtype), self.cal_y.astype(self.dtype) + ) + + def gen_data_depend_on_dtype(self, shape): + if self.dtype == np.int32 or self.dtype == np.int64: + return np.random.randint(1, 100, size=shape) + else: + return np.random.uniform(0.1, 1, size=shape) + def test_check_output(self): if paddle.is_compiled_with_xpu(): place = paddle.XPUPlace(0) @@ -84,158 +108,109 @@ def test_check_grad_ingore_y(self): check_dygraph=False, ) + def init_data(self): + self.x = self.gen_data_depend_on_dtype([13, 17]) + self.y = self.gen_data_depend_on_dtype([13, 17]) + def init_input_output(self): - self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) - self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) - self.out = np.multiply(self.x, self.y) + if self.dtype == np.uint16: + self.x = convert_float_to_uint16(self.x) + self.y = convert_float_to_uint16(self.y) + else: + self.x = self.x.astype(self.dtype) + self.y = self.y.astype(self.dtype) + self.inputs = { - 'X': OpTest.np_dtype_to_base_dtype(self.x), - 'Y': OpTest.np_dtype_to_base_dtype(self.y), + 'X': self.x, + 'Y': self.y, } self.outputs = {'Out': self.out} self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} - def init_dtype(self): - pass - def init_axis(self): pass class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.uniform(-1, 1, []).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([]) + self.y = self.gen_data_depend_on_dtype([]) class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([13, 17]) + self.y = self.gen_data_depend_on_dtype([]) class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.uniform(-1, 1, []).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([]) + self.y = self.gen_data_depend_on_dtype([13, 17]) @skip_check_grad_ci( reason="[skip shape check] Use y_shape(1) to test broadcast." ) class TestElementwiseMulOp_scalar(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 3, 4).astype(self.dtype), - 'Y': np.random.rand(1).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([10, 3, 4]) + self.y = self.gen_data_depend_on_dtype([1]) class TestElementwiseMulOp_Vector(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.random((100,)).astype(self.dtype), - 'Y': np.random.random((100,)).astype(self.dtype), - } - self.outputs = { - 'Out': np.multiply(self.inputs['X'], self.inputs['Y']) - } + def init_data(self): + self.x = self.gen_data_depend_on_dtype([100]) + self.y = self.gen_data_depend_on_dtype([100]) class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(100, 2, 3).astype(self.dtype), - 'Y': np.random.rand(100).astype(self.dtype), - } - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(100, 1, 1) - } - self.attrs = {'axis': 0} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([100, 2, 3]) + self.y = self.gen_data_depend_on_dtype([100]) + self.cal_y = self.y.reshape(100, 1, 1) + self.axis = 0 class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 100, 3).astype(self.dtype), - 'Y': np.random.rand(100).astype(self.dtype), - } - - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 100, 1) - } + def init_data(self): + self.x = self.gen_data_depend_on_dtype([2, 100, 3]) + self.y = self.gen_data_depend_on_dtype([100]) + self.cal_y = self.y.reshape(1, 100, 1) + self.axis = 1 class TestElementwiseMulOp_broadcast_2(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 3, 100).astype(self.dtype), - 'Y': np.random.rand(100).astype(self.dtype), - } - - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 1, 100) - } + def init_data(self): + self.x = self.gen_data_depend_on_dtype([2, 3, 100]) + self.y = self.gen_data_depend_on_dtype([100]) + self.cal_y = self.y.reshape(1, 1, 100) class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 10, 12, 3).astype(self.dtype), - 'Y': np.random.rand(10, 12).astype(self.dtype), - } - - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 10, 12, 1) - } + def init_data(self): + self.x = self.gen_data_depend_on_dtype([2, 10, 12, 3]) + self.y = self.gen_data_depend_on_dtype([10, 12]) + self.cal_y = self.y.reshape(1, 10, 12, 1) + self.axis = 1 class TestElementwiseMulOp_broadcast_4(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 2, 11).astype(self.dtype), - 'Y': np.random.rand(10, 1, 11).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([10, 2, 11]) + self.y = self.gen_data_depend_on_dtype([10, 1, 11]) class TestElementwiseMulOp_broadcast_5(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 4, 2, 3).astype(self.dtype), - 'Y': np.random.rand(10, 4, 1, 3).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([10, 4, 2, 3]) + self.y = self.gen_data_depend_on_dtype([10, 4, 1, 3]) class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 3, 100).astype(self.dtype), - 'Y': np.random.rand(1, 1, 100).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([2, 3, 100]) + self.y = self.gen_data_depend_on_dtype([1, 1, 100]) class TestElementwiseMulOp_commonuse_2(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(30, 3, 1, 5).astype(self.dtype), - 'Y': np.random.rand(30, 1, 4, 1).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']} + def init_data(self): + self.x = self.gen_data_depend_on_dtype([30, 3, 1, 5]) + self.y = self.gen_data_depend_on_dtype([30, 1, 4, 1]) class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 10).astype(self.dtype), - 'Y': np.random.rand(2, 2, 10, 10).astype(self.dtype), - } - - self.attrs = {'axis': 2} - - self.outputs = { - 'Out': self.inputs['X'].reshape(1, 1, 10, 10) * self.inputs['Y'] - } + def init_data(self): + self.x = self.gen_data_depend_on_dtype([10, 10]) + self.y = self.gen_data_depend_on_dtype([2, 2, 10, 10]) + self.cal_x = self.x.reshape(1, 1, 10, 10) + self.axis = 2 support_types = get_xpu_op_support_types('elementwise_mul') diff --git a/test/xpu/test_elementwise_sub_op_xpu.py b/test/xpu/test_elementwise_sub_op_xpu.py index 8e595932eae29d..3cb440f05de063 100644 --- a/test/xpu/test_elementwise_sub_op_xpu.py +++ b/test/xpu/test_elementwise_sub_op_xpu.py @@ -21,13 +21,18 @@ create_test_class, get_xpu_op_support_types, ) -from op_test import skip_check_grad_ci +from op_test import ( + convert_float_to_uint16, + skip_check_grad_ci, +) from op_test_xpu import XPUOpTest import paddle paddle.enable_static() +INT_GROUP = [np.int32, np.int64] + class XPUTestElementwiseSubOp(XPUOpTestWrapper): def __init__(self): @@ -39,14 +44,43 @@ def setUp(self): self.op_type = "elementwise_sub" self.use_xpu = True self.dtype = self.in_type + self.init_shape() self.init_input_output() + def reshape_data(self, x, y): + if len(x.shape) < len(y.shape): + reshape_dims = [1 if i not in x.shape else i for i in y.shape] + return np.reshape(x, reshape_dims) + else: + return x + + def gen_data_depend_on_dtype(self, shape): + if self.dtype in INT_GROUP: + return np.random.randint(1, 100, size=shape) + else: + return np.random.uniform(-1, 1, size=shape) + def init_input_output(self): + self.x = self.gen_data_depend_on_dtype(self.x_shape) + self.y = self.gen_data_depend_on_dtype(self.y_shape) + if self.dtype == np.uint16: + tmp_x = self.reshape_data(self.x, self.y) + tmp_y = self.reshape_data(self.y, self.x) + self.outputs = {'Out': tmp_x - tmp_y} + self.x = convert_float_to_uint16(self.x) + self.y = convert_float_to_uint16(self.y) + else: + tmp_x = self.reshape_data(self.x, self.y).astype(self.dtype) + tmp_y = self.reshape_data(self.y, self.x).astype(self.dtype) + self.outputs = {'Out': tmp_x - tmp_y} self.inputs = { - 'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype), - 'Y': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype), + 'X': self.x, + 'Y': self.y, } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [2, 3, 4, 5] def test_check_output(self): if paddle.is_compiled_with_xpu(): @@ -81,132 +115,77 @@ def test_check_grad_ingore_y(self): ) class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.uniform(-1, 1, []).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [] + self.y_shape = [] class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, []).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [13, 17] + self.y_shape = [] class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.uniform(-1, 1, []).astype(self.dtype), - 'Y': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [] + self.y_shape = [13, 17] @skip_check_grad_ci( reason="[skip shape check] Use y_shape(1) to test broadcast." ) class TestElementwiseSubOp_scalar(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 3, 4).astype(self.dtype), - 'Y': np.random.rand(1).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [10, 3, 4] + self.y_shape = [1] class TestElementwiseSubOp_Vector(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.random((100,)).astype(self.dtype), - 'Y': np.random.random((100,)).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [100] + self.y_shape = [100] class TestElementwiseSubOp_broadcast_0(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(100, 3, 2).astype(self.dtype), - 'Y': np.random.rand(100).astype(self.dtype), - } - + def init_shape(self): + self.x_shape = [100, 3, 2] + self.y_shape = [100] self.attrs = {'axis': 0} - self.outputs = { - 'Out': self.inputs['X'] - self.inputs['Y'].reshape(100, 1, 1) - } class TestElementwiseSubOp_broadcast_1(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 100, 3).astype(self.dtype), - 'Y': np.random.rand(100).astype(self.dtype), - } - + def init_shape(self): + self.x_shape = [2, 100, 3] + self.y_shape = [100] self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 100, 1) - } class TestElementwiseSubOp_broadcast_2(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 3, 100).astype(self.dtype), - 'Y': np.random.rand(100).astype(self.dtype), - } - - self.outputs = { - 'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 1, 100) - } + def init_shape(self): + self.x_shape = [2, 3, 100] + self.y_shape = [100] class TestElementwiseSubOp_broadcast_3(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 10, 12, 3).astype(self.dtype), - 'Y': np.random.rand(10, 12).astype(self.dtype), - } - + def init_shape(self): + self.x_shape = [2, 10, 12, 3] + self.y_shape = [10, 12] self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 10, 12, 1) - } class TestElementwiseSubOp_broadcast_4(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 5, 3, 12).astype(self.dtype), - 'Y': np.random.rand(2, 5, 1, 12).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [2, 5, 3, 12] + self.y_shape = [2, 5, 1, 12] class TestElementwiseSubOp_commonuse_1(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(2, 3, 100).astype(self.dtype), - 'Y': np.random.rand(1, 1, 100).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [2, 3, 100] + self.y_shape = [1, 1, 100] class TestElementwiseSubOp_commonuse_2(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 3, 1, 4).astype(self.dtype), - 'Y': np.random.rand(10, 1, 12, 1).astype(self.dtype), - } - self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} + def init_shape(self): + self.x_shape = [10, 3, 1, 4] + self.y_shape = [10, 1, 12, 1] class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp): - def init_input_output(self): - self.inputs = { - 'X': np.random.rand(10, 12).astype(self.dtype), - 'Y': np.random.rand(2, 3, 10, 12).astype(self.dtype), - } - + def init_shape(self): + self.x_shape = [10, 12] + self.y_shape = [2, 3, 10, 12] self.attrs = {'axis': 2} - self.outputs = { - 'Out': self.inputs['X'].reshape(1, 1, 10, 12) - self.inputs['Y'] - } - support_types = get_xpu_op_support_types('elementwise_sub') for stype in support_types: diff --git a/test/xpu/test_reduce_sum_op_xpu.py b/test/xpu/test_reduce_sum_op_xpu.py index 06c62d29fb263d..cbf144c923bcba 100644 --- a/test/xpu/test_reduce_sum_op_xpu.py +++ b/test/xpu/test_reduce_sum_op_xpu.py @@ -20,6 +20,7 @@ create_test_class, get_xpu_op_support_types, ) +from op_test import convert_float_to_uint16 from op_test_xpu import XPUOpTest import paddle @@ -38,6 +39,16 @@ def setUp(self): self.init_case() self.set_case() + def gen_data_depend_on_dtype(self, shape): + if ( + self.dtype == np.int32 + or self.dtype == np.int64 + or self.dtype == np.uint8 + ): + return np.random.randint(1, 100, size=shape) + else: + return np.random.uniform(-1, 1, size=shape) + def set_case(self): self.op_type = 'reduce_sum' self.attrs = { @@ -46,17 +57,29 @@ def set_case(self): 'keep_dim': self.keep_dim, 'dim': self.axis, } - self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)} - if self.attrs['reduce_all']: - self.outputs = {'Out': self.inputs['X'].sum()} + tmp_x = self.gen_data_depend_on_dtype(self.shape) + if self.dtype == np.uint16: + tmp_out = ( + tmp_x.sum() + if self.attrs['reduce_all'] + else tmp_x.sum( + axis=self.axis, keepdims=self.attrs['keep_dim'] + ) + ) + self.outputs = {'Out': tmp_out} + tmp_x = convert_float_to_uint16(tmp_x) + self.inputs = {'X': tmp_x} else: - self.outputs = { - 'Out': self.inputs['X'].sum( + tmp_x = tmp_x.astype(self.dtype) + self.inputs = {'X': tmp_x} + tmp_out = ( + tmp_x.sum() + if self.attrs['reduce_all'] + else tmp_x.sum( axis=self.axis, keepdims=self.attrs['keep_dim'] ) - } - if self.dtype == np.uint16: - self.outputs['Out'] = self.outputs['Out'].astype(np.uint16) + ) + self.outputs = {'Out': tmp_out} def init_case(self): self.shape = (5, 6, 10) From b7c36ed6aff14810920ce7fdaa7f1cd98340b53b Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Fri, 29 Dec 2023 14:53:09 +0800 Subject: [PATCH 048/142] [XPU][Phi Kernel] xpu::nonzero support simulator XPUSIM_SKIP_RUN mode (#60388) --- paddle/phi/kernels/xpu/masked_select_kernel.cc | 10 +++++++++- paddle/phi/kernels/xpu/nonzero_kernel.cc | 5 ++--- ...gmoid_cross_entropy_with_logits_grad_kernel.cc | 10 ++++++++++ .../sigmoid_cross_entropy_with_logits_kernel.cc | 11 ++++++++++- test/xpu/test_masked_select_op_xpu.py | 15 +++++++++++++++ 5 files changed, 46 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/xpu/masked_select_kernel.cc b/paddle/phi/kernels/xpu/masked_select_kernel.cc index 62803fde27aa5c..85687c19f6c06d 100644 --- a/paddle/phi/kernels/xpu/masked_select_kernel.cc +++ b/paddle/phi/kernels/xpu/masked_select_kernel.cc @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/masked_select_kernel.h" +#include "glog/logging.h" + #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" @@ -54,7 +56,13 @@ void MaskedSelectKernel(const Context& dev_ctx, mask.place(), static_cast(out_size), sizeof(int32_t)); - + if (std::getenv("XPUSIM_SKIP_RUN") && + std::strcmp(std::getenv("XPUSIM_SKIP_RUN"), "1") == 0) { + VLOG(3) << "WARNING: In the simulator mode, the variable out_size_cpu " + "stores an uninitialized value. To avoid allocating a memory of " + "random size, we assign numel to out_size_cpu"; + out_size_cpu = mask.numel(); + } DDim out_dim{out_size_cpu}; out->Resize(out_dim); auto out_data = reinterpret_cast(dev_ctx.template Alloc(out)); diff --git a/paddle/phi/kernels/xpu/nonzero_kernel.cc b/paddle/phi/kernels/xpu/nonzero_kernel.cc index e2a1339504bae1..8dfd7734cff521 100644 --- a/paddle/phi/kernels/xpu/nonzero_kernel.cc +++ b/paddle/phi/kernels/xpu/nonzero_kernel.cc @@ -46,9 +46,8 @@ void NonZeroKernel(const Context& dev_ctx, std::strcmp(std::getenv("XPUSIM_SKIP_RUN"), "1") == 0) { VLOG(3) << "WARNING: In the simulator mode, the variable true_num_cpu " "stores an uninitialized value. To avoid allocating a memory of " - "random size, we limit the value of true_num_cpu to the range 0 " - "<= true_num_cpu < numel"; - true_num_cpu = std::min(std::max(true_num_cpu, 0), static_cast(numel)); + "random size, we assign numel to true_num_cpu"; + true_num_cpu = numel; } out->Resize(common::make_ddim({static_cast(true_num_cpu), rank})); diff --git a/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_grad_kernel.cc b/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_grad_kernel.cc index cf383439a77e9e..9ee967b5e57252 100644 --- a/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_grad_kernel.cc @@ -16,6 +16,8 @@ #include "paddle/phi/kernels/sigmoid_cross_entropy_with_logits_grad_kernel.h" +#include "glog/logging.h" + #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -79,6 +81,14 @@ void SigmoidCrossEntropyWithLogitsGradKernel( dev_ctx.GetPlace(), static_cast(non_zero), sizeof(int)); + if (std::getenv("XPUSIM_SKIP_RUN") && + std::strcmp(std::getenv("XPUSIM_SKIP_RUN"), "1") == 0) { + VLOG(3) + << "WARNING: In the simulator mode, the variable non_zero_cpu " + "stores an uninitialized value. To avoid allocating a memory of " + "random size, we assign numel to true_num_cpu"; + non_zero_cpu = x.numel(); + } r = xpu::scale(dev_ctx.x_context(), reinterpret_cast(in_grad->data()), reinterpret_cast(in_grad->data()), diff --git a/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_kernel.cc b/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_kernel.cc index c189c143adb747..fa2b6f24c173a7 100644 --- a/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_kernel.cc +++ b/paddle/phi/kernels/xpu/sigmoid_cross_entropy_with_logits_kernel.cc @@ -16,6 +16,8 @@ #include "paddle/phi/kernels/sigmoid_cross_entropy_with_logits_kernel.h" +#include "glog/logging.h" + #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" @@ -75,7 +77,14 @@ void SigmoidCrossEntropyWithLogitsKernel( dev_ctx.GetPlace(), static_cast(non_zero), sizeof(int)); - + if (std::getenv("XPUSIM_SKIP_RUN") && + std::strcmp(std::getenv("XPUSIM_SKIP_RUN"), "1") == 0) { + VLOG(3) + << "WARNING: In the simulator mode, the variable non_zero_cpu " + "stores an uninitialized value. To avoid allocating a memory of " + "random size, we assign numel to non_zero_cpu"; + non_zero_cpu = x.numel(); + } r = xpu::scale(dev_ctx.x_context(), reinterpret_cast(out->data()), reinterpret_cast(out->data()), diff --git a/test/xpu/test_masked_select_op_xpu.py b/test/xpu/test_masked_select_op_xpu.py index f2ed82cd1e8d76..30b91f38b66d6b 100644 --- a/test/xpu/test_masked_select_op_xpu.py +++ b/test/xpu/test_masked_select_op_xpu.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np @@ -108,6 +109,20 @@ def test_static_mode(self): ) self.assertEqual(np.allclose(res, np_out), True) + def test_simulator_skip_run_mode(self): + os.environ['XPUSIM_SKIP_RUN'] = '1' + paddle.disable_static(paddle.XPUPlace(0)) + shape = (88, 6, 8) + np_x = np.random.random(shape).astype('float32') + np_mask = np.array(np.random.randint(2, size=shape, dtype=bool)) + x = paddle.to_tensor(np_x) + mask = paddle.to_tensor(np_mask) + out = paddle.masked_select(x, mask) + # only check the numel of output + np.testing.assert_equal(out.numpy().size, np_x.size) + paddle.enable_static() + del os.environ['XPUSIM_SKIP_RUN'] + class TestMaskedSelectError(unittest.TestCase): def test_error(self): From ebc859a69b7732a59ced3d68e8d7788f45cfaf50 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Fri, 29 Dec 2023 14:54:31 +0800 Subject: [PATCH 049/142] [DimExpr] Add utils for DimExpr, convert between DimExpr and pir::Attribute (#60390) * convert between DimExpr and pir::Attribute * add two helper functions: SubstituteDimExpr and MakeGetterDimExpr4SymbolName * Fix compile bug and add unittest * Fix bug * Add IR_API * Fix windows compile error --- .../pir/dialect/shape/utils/dim_expr_util.cc | 362 ++++++++++++++++++ .../pir/dialect/shape/utils/dim_expr_util.h | 42 ++ test/cpp/pir/shape_dialect/CMakeLists.txt | 3 + .../symbol_dim_expr_util_test.cc | 99 +++++ 4 files changed, 506 insertions(+) create mode 100644 paddle/pir/dialect/shape/utils/dim_expr_util.cc create mode 100644 paddle/pir/dialect/shape/utils/dim_expr_util.h create mode 100644 test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc diff --git a/paddle/pir/dialect/shape/utils/dim_expr_util.cc b/paddle/pir/dialect/shape/utils/dim_expr_util.cc new file mode 100644 index 00000000000000..8421f500c23daa --- /dev/null +++ b/paddle/pir/dialect/shape/utils/dim_expr_util.cc @@ -0,0 +1,362 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/dialect/shape/utils/dim_expr_util.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace symbol { + +namespace { + +template +std::string GetSerializedTag(); + +template <> +std::string GetSerializedTag>() { + return "Negative"; +} + +template <> +std::string GetSerializedTag>() { + return "Reciprocal"; +} + +template <> +std::string GetSerializedTag>() { + return "Add"; +} + +template <> +std::string GetSerializedTag>() { + return "Mul"; +} + +template <> +std::string GetSerializedTag>() { + return "Max"; +} + +template <> +std::string GetSerializedTag>() { + return "Min"; +} + +template <> +std::string GetSerializedTag>() { + return "Broadcast"; +} + +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, + const std::int64_t& dim_expr) { + return builder->int64_attr(dim_expr); +} + +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, + const std::string& dim_expr) { + return builder->str_attr(dim_expr); +} + +template +::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::Builder* builder, + const T& dim_expr) { + std::vector<::pir::Attribute> attr_vecs{}; + attr_vecs.push_back(builder->str_attr(GetSerializedTag())); + const auto& operand = dim_expr->data; + attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand)); + return builder->array_attr(attr_vecs); +} + +::pir::Attribute ConvertDimExprToAttributeImpl( + ::pir::Builder* builder, const Negative& dim_expr) { + return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr); +} + +::pir::Attribute ConvertDimExprToAttributeImpl( + ::pir::Builder* builder, const Reciprocal& dim_expr) { + return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr); +} + +template +::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::Builder* builder, + const T& dim_expr) { + std::vector<::pir::Attribute> attr_vecs{}; + attr_vecs.push_back(builder->str_attr(GetSerializedTag())); + const auto& operands = *(dim_expr.operands); + for (const auto& operand : operands) { + attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand)); + } + return builder->array_attr(attr_vecs); +} + +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, + const Add& dim_expr) { + return ConvertVariadicDimExprToAttribute(builder, dim_expr); +} + +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, + const Mul& dim_expr) { + return ConvertVariadicDimExprToAttribute(builder, dim_expr); +} + +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, + const Max& dim_expr) { + return ConvertVariadicDimExprToAttribute(builder, dim_expr); +} + +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, + const Min& dim_expr) { + return ConvertVariadicDimExprToAttribute(builder, dim_expr); +} + +::pir::Attribute ConvertDimExprToAttributeImpl( + ::pir::Builder* builder, const Broadcast& dim_expr) { + return ConvertVariadicDimExprToAttribute(builder, dim_expr); +} + +std::optional ConvertInt64AttributeToDimExpr( + const ::pir::Int64Attribute& attribute) { + return DimExpr{attribute.data()}; +} + +std::optional ConvertStrAttributeToDimExpr( + const ::pir::StrAttribute& attribute) { + return DimExpr{attribute.AsString()}; +} + +template +std::optional ConvertArrayAttributeToUnaryDimExpr( + const ::pir::ArrayAttribute& attribute) { + if (attribute.size() != 2) { + return std::nullopt; + } + std::optional operand = ConvertAttributeToDimExpr(attribute.at(1)); + if (!operand.has_value()) { + return std::nullopt; + } + return T{operand.value()}; +} + +template +std::optional ConvertArrayAttributeToVariadicDimExpr( + const ::pir::ArrayAttribute& attribute) { + if (attribute.size() < 2) { + return std::nullopt; + } + List operands{}; + for (std::size_t i = 1; i < attribute.size(); ++i) { + std::optional operand = ConvertAttributeToDimExpr(attribute.at(i)); + if (!operand.has_value()) { + return std::nullopt; + } + operands->push_back(operand.value()); + } + return T{operands}; +} + +typedef std::optional (*ArrayAttributeConverterT)( + const ::pir::ArrayAttribute& attribute); + +std::optional GetArrayAttributeConverter( + const std::string& tag) { + static std::unordered_map map{ + {GetSerializedTag>(), + &ConvertArrayAttributeToUnaryDimExpr>}, + {GetSerializedTag>(), + &ConvertArrayAttributeToUnaryDimExpr>}, + {GetSerializedTag>(), + &ConvertArrayAttributeToVariadicDimExpr>}, + {GetSerializedTag>(), + &ConvertArrayAttributeToVariadicDimExpr>}, + {GetSerializedTag>(), + &ConvertArrayAttributeToVariadicDimExpr>}, + {GetSerializedTag>(), + &ConvertArrayAttributeToVariadicDimExpr>}, + {GetSerializedTag>(), + &ConvertArrayAttributeToVariadicDimExpr>}, + }; + const auto& iter = map.find(tag); + if (iter == map.end()) { + return std::nullopt; + } + return iter->second; +} + +std::optional ConvertArrayAttributeToDimExpr( + const ::pir::ArrayAttribute& attribute) { + if (attribute.empty()) { + return std::nullopt; + } + if (!attribute.at(0).isa<::pir::StrAttribute>()) { + return std::nullopt; + } + const auto& tag = attribute.at(0).dyn_cast<::pir::StrAttribute>().AsString(); + auto opt_func = GetArrayAttributeConverter(tag); + if (!opt_func.has_value()) { + return std::nullopt; + } + return opt_func.value()(attribute); +} + +} // namespace + +::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder, + const DimExpr& dim_expr) { + return std::visit( + [&](const auto& impl) { + return ConvertDimExprToAttributeImpl(builder, impl); + }, + dim_expr.variant()); +} + +std::optional ConvertAttributeToDimExpr(::pir::Attribute attribute) { + if (attribute.isa<::pir::Int64Attribute>()) { + return ConvertInt64AttributeToDimExpr( + attribute.dyn_cast<::pir::Int64Attribute>()); + } + if (attribute.isa<::pir::StrAttribute>()) { + return ConvertStrAttributeToDimExpr( + attribute.dyn_cast<::pir::StrAttribute>()); + } + if (attribute.isa<::pir::ArrayAttribute>()) { + return ConvertArrayAttributeToDimExpr( + attribute.dyn_cast<::pir::ArrayAttribute>()); + } + return std::nullopt; +} + +class SubstituteDimExprHelper final { + public: + using DimExpr4SymbolNameT = + std::function(const std::string& symbol_name)>; + + explicit SubstituteDimExprHelper( + const DimExpr4SymbolNameT& DimExpr4SymbolName) + : DimExpr4SymbolName_(DimExpr4SymbolName) {} + + std::optional Substitute(const DimExpr& dim_expr) { + return std::visit([&](const auto& impl) { return SubstituteImpl(impl); }, + dim_expr.variant()); + } + + private: + std::optional SubstituteImpl(const std::int64_t& dim_expr) { + return dim_expr; + } + std::optional SubstituteImpl(const std::string& dim_expr) { + return DimExpr4SymbolName_(dim_expr); + } + + std::optional SubstituteImpl(const Negative& dim_expr) { + return SubstituteUnary(dim_expr); + } + std::optional SubstituteImpl(const Reciprocal& dim_expr) { + return SubstituteUnary(dim_expr); + } + + template + std::optional SubstituteUnary(const T& dim_expr) { + const auto& operand = dim_expr->data; + const auto& substituted_operand = Substitute(operand); + if (!substituted_operand.has_value()) { + return std::nullopt; + } + return T{substituted_operand.value()}; + } + + std::optional SubstituteImpl(const Add& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Mul& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Max& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Min& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Broadcast& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + template + std::optional SubstituteVariadic(const T& dim_expr) { + const auto& operands = *(dim_expr.operands); + List substituted_operands{}; + for (const auto& operand : operands) { + const auto& substituted_operand = Substitute(operand); + if (!substituted_operand.has_value()) { + return std::nullopt; + } + substituted_operands->push_back(substituted_operand.value()); + } + return T{substituted_operands}; + } + + DimExpr4SymbolNameT DimExpr4SymbolName_; +}; + +std::optional SubstituteDimExpr( + const DimExpr& dim_expr, + const std::function(const std::string& symbol_name)>& + DimExpr4SymbolName) { + return SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr); +} + +std::function(const std::string& symbol_name)> +MakeGetterDimExpr4SymbolName( + const std::vector>& symbol_bindings, + const std::function( + int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim) { + std::unordered_map>> + symbol_name2in_tensor_dim_pos; + for (const auto& tuple : symbol_bindings) { + const auto& [symbol_name, in_tensor_idx, in_tensor_dim_idx] = tuple; + symbol_name2in_tensor_dim_pos[symbol_name].emplace_back( + std::pair{in_tensor_idx, in_tensor_dim_idx}); + } + return [map = std::move(symbol_name2in_tensor_dim_pos), DimExpr4InputDim]( + const std::string& symbol_name) -> std::optional { + const auto& iter = map.find(symbol_name); + if (iter == map.end()) { + return std::nullopt; + } + const auto& positions = iter->second; + std::optional ret = std::nullopt; + for (const auto& [in_tensor_idx, in_tensor_dim_idx] : positions) { + const auto& current = DimExpr4InputDim(in_tensor_idx, in_tensor_dim_idx); + if (!current.has_value()) { + return std::nullopt; + } + if (ret.has_value()) { + // Same names, same DimExprs. + if (ret.value() != current.value()) { + return std::nullopt; + } + } else { + ret = current; + } + } + return ret; + }; +} + +} // namespace symbol diff --git a/paddle/pir/dialect/shape/utils/dim_expr_util.h b/paddle/pir/dialect/shape/utils/dim_expr_util.h new file mode 100644 index 00000000000000..3ed4550c2248d5 --- /dev/null +++ b/paddle/pir/dialect/shape/utils/dim_expr_util.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/dll_decl.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + +namespace symbol { + +IR_API ::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder, + const DimExpr& dim_expr); +IR_API std::optional ConvertAttributeToDimExpr( + ::pir::Attribute attribute); + +IR_API std::optional SubstituteDimExpr( + const DimExpr& dim_expr, + const std::function(const std::string& symbol_name)>& + DimExpr4SymbolName); + +IR_API std::function(const std::string& symbol_name)> +MakeGetterDimExpr4SymbolName( + const std::vector>& symbol_bindings, + const std::function( + int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim); + +} // namespace symbol diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index f508efb56947e9..5c3aa2b9f43449 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -4,6 +4,9 @@ paddle_test(shape_struct_test SRCS shape_struct_test.cc DEPS gtest) paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc DEPS gtest) +paddle_test(symbol_dim_expr_util_test SRCS symbol_dim_expr_util_test.cc DEPS + gtest) + if(WITH_CINN) paddle_test( shape_optimization_test diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc new file mode 100644 index 00000000000000..0893a6d5027055 --- /dev/null +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/pir/dialect/shape/utils/dim_expr_builder.h" +#include "paddle/pir/dialect/shape/utils/dim_expr_util.h" + +#include "test/cpp/pir/tools/test_pir_utils.h" + +namespace symbol { + +namespace { +DimExpr CreateExampleDimExpr() { + DimExprBuilder dim_expr_builder{nullptr}; + DimExpr sym0 = DimExpr("S0"); + DimExpr sym1 = DimExpr("S1"); + DimExpr constant = DimExpr(2); + DimExpr expr1 = (sym0 - sym1) * constant / sym0; + DimExpr expr2 = dim_expr_builder.Max(expr1, sym0); + DimExpr output = dim_expr_builder.Min(expr2, sym1); + return output; +} +} // namespace + +TEST(DimExprUtil, Convert) { + pir::IrContext* ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + + DimExpr dim_expr = CreateExampleDimExpr(); + ::pir::Attribute attr = ConvertDimExprToAttribute(&builder, dim_expr); + std::optional opt_expr = ConvertAttributeToDimExpr(attr); + ASSERT_TRUE(opt_expr.has_value()); + ASSERT_EQ(opt_expr.value(), dim_expr); +} + +TEST(DimExprUtil, Substitute) { + DimExpr dim_expr = CreateExampleDimExpr(); + const auto& opt_expr = SubstituteDimExpr( + dim_expr, [](const std::string& str) -> std::optional { + if (str == "S0") { + return DimExpr("symbol0"); + } else if (str == "S1") { + return DimExpr("symbol1"); + } else { + return std::nullopt; + } + }); + ASSERT_TRUE(opt_expr.has_value()); + const auto& ret_expr = SubstituteDimExpr( + opt_expr.value(), [](const std::string& str) -> std::optional { + if (str == "symbol0") { + return DimExpr("S0"); + } else if (str == "symbol1") { + return DimExpr("S1"); + } else { + return std::nullopt; + } + }); + ASSERT_TRUE(ret_expr.has_value()); + ASSERT_EQ(ret_expr.value(), dim_expr); +} + +TEST(DimExprUtil, MakeGetterDimExpr4SymbolName) { + std::vector> + symbol_bindings{}; + symbol_bindings.push_back(std::make_tuple("Symbol", 0, 0)); + const auto& dim_expr = CreateExampleDimExpr(); + const auto& DimExpr4SymbolName = MakeGetterDimExpr4SymbolName( + symbol_bindings, + [dim_expr](int in_tensor_idx, + int in_tensor_dim_idx) -> std::optional { + if (in_tensor_idx == 0 && in_tensor_dim_idx == 0) { + return dim_expr; + } else { + return std::nullopt; + } + }); + const auto& opt_dim_expr = DimExpr4SymbolName("Symbol"); + ASSERT_TRUE(opt_dim_expr.has_value()); + ASSERT_EQ(opt_dim_expr.value(), dim_expr); +} + +} // namespace symbol From e4b39bb56a4e55213383e96daf262f4f72c1811d Mon Sep 17 00:00:00 2001 From: lijin23 <41257772+lj970926@users.noreply.github.com> Date: Fri, 29 Dec 2023 15:05:25 +0800 Subject: [PATCH 050/142] [XPU][PHI Kernels] refine bf16 test for fused_rope (#60439) * refine fuesd_rope bf16 test * format code --- ..._fused_rotary_position_embedding_op_xpu.py | 94 ++++++++++++++----- 1 file changed, 71 insertions(+), 23 deletions(-) diff --git a/test/xpu/test_fused_rotary_position_embedding_op_xpu.py b/test/xpu/test_fused_rotary_position_embedding_op_xpu.py index 0fe25194c1633d..6aac9d828cc037 100644 --- a/test/xpu/test_fused_rotary_position_embedding_op_xpu.py +++ b/test/xpu/test_fused_rotary_position_embedding_op_xpu.py @@ -196,6 +196,7 @@ def get_forward_backward( fw.append(out_q) fw.append(out_k) fw.append(out_v) + paddle.seed(seed + 1) out_gq = paddle.randn(out_q.shape, self.dtype) out_gk = paddle.randn(out_q.shape, self.dtype) out_gv = paddle.randn(out_q.shape, self.dtype) @@ -203,9 +204,9 @@ def get_forward_backward( paddle.autograd.backward( [out_q, out_k, out_v], [out_gq, out_gk, out_gv], True ) - bw.append(tensor_q) - bw.append(tensor_k) - bw.append(tensor_v) + bw.append(tensor_q.grad) + bw.append(tensor_k.grad) + bw.append(tensor_v.grad) return fw, bw @@ -368,28 +369,28 @@ def setUp(self): self.shape = [2, 8, 2, 16] def test_api(self): - q_fp32 = paddle.rand(self.shape, dtype="float32") - k_fp32 = paddle.rand(self.shape, dtype="float32") - v_fp32 = paddle.rand(self.shape, dtype="float32") - sin_fp32 = paddle.rand( - [1, self.shape[1], 1, self.shape[3]], dtype="float32" + paddle.disable_static() + q_bf16 = paddle.randn(self.shape, dtype="bfloat16") + k_bf16 = paddle.randn(self.shape, dtype="bfloat16") + v_bf16 = paddle.randn(self.shape, dtype="bfloat16") + sin_bf16 = paddle.randn( + [1, self.shape[1], 1, self.shape[3]], dtype="bfloat16" ) - cos_fp32 = paddle.rand( - [1, self.shape[1], 1, self.shape[3]], dtype="float32" + cos_bf16 = paddle.randn( + [1, self.shape[1], 1, self.shape[3]], dtype="bfloat16" ) - q_bf16 = paddle.to_tensor(q_fp32, dtype="bfloat16") - k_bf16 = paddle.to_tensor(k_fp32, dtype="bfloat16") - v_bf16 = paddle.to_tensor(v_fp32, dtype="bfloat16") - sin_bf16 = paddle.to_tensor(sin_fp32, dtype="bfloat16") - cos_bf16 = paddle.to_tensor(cos_fp32, dtype="bfloat16") - - out_fp32 = fused_rotary_position_embedding( - q_fp32, - k_fp32, - v_fp32, - sin_fp32, - cos_fp32, - use_neox_rotary_style=False, + q_bf16.stop_gradient = False + k_bf16.stop_gradient = False + v_bf16.stop_gradient = False + q_fp32 = paddle.to_tensor(q_bf16, dtype="float32", stop_gradient=False) + k_fp32 = paddle.to_tensor(k_bf16, dtype="float32", stop_gradient=False) + v_fp32 = paddle.to_tensor(v_bf16, dtype="float32", stop_gradient=False) + sin_fp32 = paddle.to_tensor(sin_bf16, dtype="float32") + cos_fp32 = paddle.to_tensor(cos_bf16, dtype="float32") + + position_ids = paddle.arange(0, self.shape[1], dtype="int64") + position_ids = paddle.stack( + [position_ids for _ in range(self.shape[0])], axis=0 ) out_bf16 = fused_rotary_position_embedding( q_bf16, @@ -397,13 +398,60 @@ def test_api(self): v_bf16, sin_bf16, cos_bf16, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + + grad_out_q_bf16 = paddle.randn(self.shape, dtype="bfloat16") + grad_out_k_bf16 = paddle.randn(self.shape, dtype="bfloat16") + grad_out_v_bf16 = paddle.randn(self.shape, dtype="bfloat16") + + paddle.autograd.backward( + out_bf16, [grad_out_q_bf16, grad_out_k_bf16, grad_out_v_bf16], True + ) + grad_bf16 = [q_bf16.grad, k_bf16.grad, v_bf16.grad] + + out_fp32 = paddle_fused_rotary_position_embedding( + q_fp32, + k_fp32, + v_fp32, + sin_fp32, + cos_fp32, + position_ids=position_ids, use_neox_rotary_style=False, ) + + grad_out_q_fp32 = paddle.to_tensor(grad_out_q_bf16, dtype="float32") + grad_out_k_fp32 = paddle.to_tensor(grad_out_k_bf16, dtype="float32") + grad_out_v_fp32 = paddle.to_tensor(grad_out_v_bf16, dtype="float32") + paddle.autograd.backward( + out_fp32, [grad_out_q_fp32, grad_out_k_fp32, grad_out_v_fp32], True + ) + grad_fp32 = [q_fp32.grad, k_fp32.grad, v_fp32.grad] + for fp32_val, bf16_val in zip(out_fp32, out_bf16): bf16_val = convert_uint16_to_float(bf16_val.numpy()) np.testing.assert_allclose( fp32_val.numpy(), bf16_val, rtol=1e-2, atol=1e-2 ) + for grad_fp32_val, grad_bf16_val in zip(grad_fp32, grad_bf16): + grad_bf16_val = convert_uint16_to_float(grad_bf16_val.numpy()) + np.testing.assert_allclose( + grad_fp32_val.numpy(), grad_bf16_val, rtol=1e-2, atol=1e-2 + ) + + +class XPUTestFusedRotaryPositionEmbeddingBf16_2( + XPUTestFusedRotaryPositionEmbeddingBf16_1 +): + def setUp(self): + self.shape = [2, 2048, 16, 128] + + +# too long for CI +# class XPUTestFusedRotaryPositionEmbeddingBf16_3(XPUTestFusedRotaryPositionEmbeddingBf16_1): +# def setUp(self): +# self.shape = [2, 8192, 8, 128] if __name__ == '__main__': From 63776cfae91119b8f169536691f5f3aa1b23f1b8 Mon Sep 17 00:00:00 2001 From: xysheng-baidu <121540080+xysheng-baidu@users.noreply.github.com> Date: Fri, 29 Dec 2023 15:40:44 +0800 Subject: [PATCH 051/142] [auto config] Resume from history csv file (#60417) --- python/paddle/distributed/auto_tuner/tuner.py | 77 +++++++++++++++++++ python/paddle/distributed/launch/main.py | 71 +++++++++++++++++ 2 files changed, 148 insertions(+) diff --git a/python/paddle/distributed/auto_tuner/tuner.py b/python/paddle/distributed/auto_tuner/tuner.py index b3b6cbf3cdc528..6a6a0ba4e082ff 100644 --- a/python/paddle/distributed/auto_tuner/tuner.py +++ b/python/paddle/distributed/auto_tuner/tuner.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import csv +import os from .utils import default_candidates, gbs_default_candidates @@ -54,6 +56,8 @@ def __init__(self, tuner_cfg): raise NotImplementedError() self.history_cfgs = [] + self.resume_cfgs = [] + self.tuner_cfg = tuner_cfg def search_once(self): """Return a new task config.""" @@ -67,3 +71,76 @@ def search_once(self): def add_cfg(self, cfg): """Add cfg into history cfgs""" self.history_cfgs.append(cfg) + + def resume_form_history(self, history_csv_path="./history.csv"): + """Resume form history csv file""" + # The breakpoint resume function does not start when the resume csv file does not exist. + if not os.path.exists(history_csv_path): + return + resume_csv_path = os.path.join( + os.path.dirname(history_csv_path), + f'{os.path.basename(history_csv_path).split(".")[0]}_copy.csv', + ) + with open(history_csv_path, "r") as fread: + reader = csv.reader(fread) + data_list = list(reader) + with open(resume_csv_path, "w") as fwrite: + writer = csv.writer(fwrite) + for row in data_list: + writer.writerow(row) + # chang str type to real type + for row in data_list: + for i, value in enumerate(row): + try: + row[i] = int(value) + except ValueError: + try: + row[i] = float(value) + except ValueError: + pass + + data_dict = [] + keys = data_list[0] + values = data_list[1:] + for val in values: + val = [x if x != '' else None for x in val] + val = [True if x == 'True' else x for x in val] + val = [False if x == 'False' else x for x in val] + dictionary = dict(zip(keys, val)) + time_val = -1 + target_key = self.tuner_cfg["metric_cfg"]["name"] + if dictionary[target_key]: + time_val = dictionary[target_key] + dictionary["time"] = time_val + data_dict.append(dictionary) + self.resume_cfgs = data_dict + + def get_cfg_from_resume(self, cur_cfg): + """Get cfg from resume cfgs""" + keys_to_compare = [ + 'mp_degree', + 'sharding_degree', + 'pp_degree', + 'dp_degree', + 'sharding_stage', + 'micro_batch_size', + 'vpp_degree', + 'use_recompute', + 'recompute_granularity', + 'num_gpus', + 'nodes', + 'global_batch_size', + 'sharding_overlap', + 'acc_steps', + ] + for cfg in self.resume_cfgs: + ret_is_same = True + for key in keys_to_compare: + if not cfg.get(key) and not cur_cfg.get(key): + continue + else: + is_same = str(cfg.get(key)) == str(cur_cfg.get(key)) + ret_is_same = ret_is_same and is_same + if ret_is_same: + return cfg + return None diff --git a/python/paddle/distributed/launch/main.py b/python/paddle/distributed/launch/main.py index 0869ac7bbfcd95..40caf7f223677b 100644 --- a/python/paddle/distributed/launch/main.py +++ b/python/paddle/distributed/launch/main.py @@ -587,6 +587,10 @@ def launch(): logger.info( f"Launch {len(auto_tuner.algo.all_tasks)} tasks by auto tuner: " ) + resume_csv_file_path = tuner_cfg.get( + "resume_csv_file_path", history_file_path + ) + auto_tuner.resume_form_history(resume_csv_file_path) cur_cfg = auto_tuner.search_once() auto_tuner.add_cfg(cur_cfg) assert cur_cfg is not None, "No config can run." @@ -658,6 +662,73 @@ def launch(): ) logger.info(f"Launch task: job_id {task_job_id}, log_dir {log_dir}") + cur_resume_cfg = auto_tuner.get_cfg_from_resume(cur_cfg) + if cur_resume_cfg: + cur_cfg = cur_resume_cfg + cur_cfg['job_id'] = job_id + auto_tuner.history_cfgs.pop(-1) + auto_tuner.add_cfg(cur_cfg) + recorder.add_cfg(**cur_cfg) + cur_best_cfgs, err = recorder.get_best( + metric=tuner_cfg['metric_cfg']['name'], + direction=tuner_cfg['metric_cfg']['OptimizationDirection'], + ) + if not err: + ctx.logger.info(f"Current best config: {cur_best_cfgs}") + logger.info(f"Current best config: {cur_best_cfgs}") + else: + ctx.logger.info( + "Get best config failed. Currently no config can be run." + ) + logger.info( + "Get best config failed. Currently no config can be run." + ) + if ( + "sharding_overlap" in cur_cfg + and cur_cfg["sharding_overlap"] + ): + add_overlap_performance( + cur_cfg, tuner_cfg, recorder.history + ) + + if cur_cfg["error_info"]: + error_task_nums += 1 + error_info = cur_cfg["error_info"] + task_nums = len(auto_tuner.algo.all_tasks) + cur_task_id = auto_tuner.algo.idx + ctx.logger.info( + "Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format( + cur_task_id, + task_nums, + cur_task_id - job_id, + error_task_nums, + error_info, + round( + (task_nums - cur_task_id) * max_time_per_task / 60, + 2, + ), + ) + ) + logger.info( + "Auto Tuner Schedule: [{}/{}], Pruned nums {}, Error nums {}, Error info {}, Remaining time {} min".format( + cur_task_id, + task_nums, + cur_task_id - job_id, + error_task_nums, + error_info, + round( + (task_nums - cur_task_id) * max_time_per_task / 60, + 2, + ), + ) + ) + recorder.store_history(history_file_path) + # generate a new config + new_cfg = auto_tuner.search_once() + cur_cfg = copy.deepcopy(new_cfg) + auto_tuner.add_cfg(cur_cfg) + continue + # in single dp estimation scene, just some nodes not all nodes run ctx = gen_new_ctx(ctx, cur_cfg, tuner_cfg) actual_nnodes = int(ctx.args.nnodes.split(":")[0]) From 10b352e32dbaa804ffa54830f953a43225b2e0c8 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Fri, 29 Dec 2023 17:08:28 +0800 Subject: [PATCH 052/142] [XPU] update XHPC to 20231229 (#60421) --- cmake/external/xpu.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 64e9154f9f8e39..c0aea597308329 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE) set(XPU_BASE_DATE "20231203") endif() if(NOT DEFINED XPU_XHPC_BASE_DATE) - set(XPU_XHPC_BASE_DATE "20231226") + set(XPU_XHPC_BASE_DATE "20231229") endif() set(XPU_XCCL_BASE_VERSION "1.1.8.1") if(NOT DEFINED XPU_XFT_BASE_VERSION) From 044dec73f552136757e23f67b73e58fa1dcf305b Mon Sep 17 00:00:00 2001 From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com> Date: Fri, 29 Dec 2023 19:59:00 +0800 Subject: [PATCH 053/142] refine shard_layer api (#60468) --- .../paddle/distributed/auto_parallel/api.py | 3 +- test/auto_parallel/test_shard_layer_api.py | 118 +++++++++++++++--- 2 files changed, 105 insertions(+), 16 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index d3f19baded5e6b..c012d7a59d1c61 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -525,8 +525,7 @@ def replicate_layer_params_and_buffers( else: # TODO(chenweihang): Support static mode branch later. raise NotImplementedError( - "`paddle.distributed.shard_layer` only supports dynamic graph mode " - "now. It will be supported for static graph mode later." + "`paddle.distributed.shard_layer` only supports dynamic graph mode." ) diff --git a/test/auto_parallel/test_shard_layer_api.py b/test/auto_parallel/test_shard_layer_api.py index fb0476303cd6e9..20e3d13946056c 100644 --- a/test/auto_parallel/test_shard_layer_api.py +++ b/test/auto_parallel/test_shard_layer_api.py @@ -14,6 +14,8 @@ import unittest +import numpy as np + import paddle import paddle.distributed as dist from paddle import nn @@ -43,6 +45,33 @@ def forward(self, x): return self.seq(x) +def shard_fn(layer_name, layer, process_mesh): + if isinstance(layer, nn.Linear): + for name, param in layer.named_parameters(): + if 'weight' in name: + dist_param = dist.shard_tensor( + param, process_mesh, [dist.Replicate()] + ) + else: + dist_param = dist.shard_tensor( + param, process_mesh, [dist.Replicate()] + ) + layer.add_parameter(name, dist_param) + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, images, labels, num_samples): + self.images = images + self.labels = labels + self.num_samples = num_samples + + def __getitem__(self, idx): + return self.images[idx], self.labels[idx] + + def __len__(self): + return self.num_samples + + class TestShardLayer(unittest.TestCase): def setUp(self): self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) @@ -52,19 +81,6 @@ def setUp(self): def test_shard_layer_base(self): layer = MyLayer(self.num_features, self.num_layers) - def shard_fn(layer_name, layer, process_mesh): - if isinstance(layer, nn.Linear): - for name, param in layer.named_parameters(): - if 'weight' in name: - dist_param = dist.shard_tensor( - param, process_mesh, [dist.Replicate()] - ) - else: - dist_param = dist.shard_tensor( - param, process_mesh, [dist.Replicate()] - ) - layer.add_parameter(name, dist_param) - # test shard parameters sharded_params_layer = dist.shard_layer(layer, self.mesh, shard_fn) @@ -155,11 +171,85 @@ def test_shard_layer_static_mode(self): dist.shard_layer(layer, self.mesh) except NotImplementedError as ex: self.assertIn( - "`paddle.distributed.shard_layer` only supports dynamic graph mode now", + "`paddle.distributed.shard_layer` only supports dynamic graph mode.", str(ex), ) exception = ex self.assertIsNotNone(exception) + paddle.disable_static() + + def create_data_loader(self): + batch_size = 4 + hidden_size = self.num_features + images = np.random.rand(batch_size, hidden_size).astype('float32') + labels = np.random.rand(batch_size, hidden_size).astype('float32') + dataset = RandomDataset(images, labels, batch_size) + loader = paddle.io.DataLoader(dataset, batch_size=batch_size) + return loader + + def test_shard_layer_to_static(self): + def input_fn(inputs, process_mesh): + return dist.shard_tensor( + inputs[0], process_mesh, [dist.Replicate()] + ) + + def output_fn(outputs, process_mesh): + return dist.shard_tensor(outputs, process_mesh, [dist.Shard(0)]) + + layer = MyLayer(self.num_features, self.num_layers) + + sharded_layer = dist.shard_layer( + layer, self.mesh, shard_fn, input_fn=input_fn, output_fn=output_fn + ) + + loader = self.create_data_loader() + + dist_model, dist_loader = dist.to_static(sharded_layer, loader) + + serial_main_program = dist_model.serial_main_program() + for param in serial_main_program.all_parameters(): + self.assertTrue(param.dist_attr.is_annotated("dims_mapping")) + self.assertEqual(param.dist_attr.dims_mapping, [-1, -1]) + + input_var = serial_main_program.global_block().var("input0") + output_var = serial_main_program.global_block().var( + "matmul_v2_19.tmp_0" + ) + self.assertListEqual(input_var.dist_attr.dims_mapping, [-1, -1]) + self.assertListEqual(output_var.dist_attr.dims_mapping, [0, -1]) + + paddle.disable_static() + + def test_shard_layer_to_static_with_buffer(self): + layer = MyLayer(self.num_features, self.num_layers) + test_buffer0 = paddle.randn([3]) + layer.register_buffer("test_buffer0", test_buffer0, persistable=True) + test_buffer1 = paddle.randn([10]) + layer.register_buffer("test_buffer1", test_buffer1, persistable=True) + layer.test_buffer1 = dist.shard_tensor( + layer.test_buffer1, self.mesh, [dist.Shard(0)] + ) + sharded_buffers_layer = dist.shard_layer(layer, self.mesh, shard_fn) + + loader = self.create_data_loader() + dist_model, dist_loader = dist.to_static(sharded_buffers_layer, loader) + + serial_main_program = dist_model.serial_main_program() + for param in serial_main_program.all_parameters(): + self.assertTrue(param.dist_attr.is_annotated("dims_mapping")) + self.assertEqual(param.dist_attr.dims_mapping, [-1, -1]) + + buffer_vars = [ + var + for var in serial_main_program.list_vars() + if var.name.startswith("generated") + ] + buffer0_var = buffer_vars[1] + buffer1_var = buffer_vars[0] + self.assertTrue(buffer0_var.dist_attr.is_annotated("dims_mapping")) + self.assertEqual(buffer0_var.dist_attr.dims_mapping, [-1]) + self.assertTrue(buffer1_var.dist_attr.is_annotated("dims_mapping")) + self.assertEqual(buffer1_var.dist_attr.dims_mapping, [0]) if __name__ == '__main__': From c4bc9e15e9293f4d38afa2231e1106a36205bacf Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 29 Dec 2023 20:24:47 +0800 Subject: [PATCH 054/142] [PirPass]Refine ApplyPirPass logic in @to_static (#59702) * [PirPass]Refine ApplyPirPass logic in @to_static * fix shared_ptr * fix code * del usless code * fix codestyle * fix UT --- .../group_merge/cinn_group_lowering_pass.cc | 4 +- .../hlir/framework/pir/op_lowering_impl.cc | 2 + paddle/fluid/pybind/pir.cc | 27 +++--- .../jit/dy2static/pir_partial_program.py | 95 +++++-------------- test/ir/pir/cinn/test_cinn_sub_graph.py | 7 +- 5 files changed, 45 insertions(+), 90 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc index f11613ead1bfc9..f4aa34bbc72638 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc @@ -211,7 +211,7 @@ class GroupOpPattern : public pir::OpRewritePattern { } private: - std::shared_ptr shape_analysis_; + std::shared_ptr shape_analysis_{nullptr}; }; class CinnGroupLoweringPass : public pir::PatternRewritePass { @@ -237,7 +237,7 @@ class CinnGroupLoweringPass : public pir::PatternRewritePass { } private: - const std::shared_ptr& shape_analysis_; + std::shared_ptr shape_analysis_{nullptr}; }; } // namespace diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 1255a05825bab6..643e4ed294b4cd 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -520,6 +520,7 @@ std::vector OpLowererImpl::LowerOps( auto& strategy = Operator::GetAttrs("CINNStrategy"); std::vector func_bodies; for (auto* op : ops) { + VLOG(4) << "start lowering op:" << op->name(); // 1.Select Op impl std::vector op_func_arg_tensors = CollectInputTensor(group, op, group_func_arg_tensors, tensor_map); @@ -891,6 +892,7 @@ ir::Tensor OpLowererImpl::GetTensor(const GroupPtr& group, auto in_shape = ::common::vectorize(type_info.dims()); auto dtype = type_info.dtype(); std::string input_id = ValueName(value); + VLOG(3) << "group->shape_analysis:" << group->shape_analysis; if (group->shape_analysis != nullptr) { auto sym_vec = group->shape_analysis->GetOrCreateSymbolicDimsForRankedValue(value); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index e2471842c07291..8813ff59de53e4 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1607,37 +1607,36 @@ static bool HasDynamicShape(const Program &program) { return false; } -void ApplyPirPass(Program &forward_program) { // NOLINT +void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT + Program &program) { // NOLINT #ifdef PADDLE_WITH_CINN pir::IrContext *ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); - bool has_dynamic_shape = HasDynamicShape(forward_program); + bool has_dynamic_shape = HasDynamicShape(program); auto shape_analysis = has_dynamic_shape ? std::make_shared(ctx) : nullptr; - pir::PassManager pass_manager(ctx); - pass_manager.AddPass(pir::CreateShapeOptimizationPass()); - cinn::dialect::ir::PdOp2CinnOpConverter(&forward_program); + pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + cinn::dialect::ir::PdOp2CinnOpConverter(&program); - pass_manager.AddPass( + pass_manager->AddPass( std::make_unique()); - pass_manager.AddPass(pir::CreateDeadCodeEliminationPass()); - pass_manager.AddPass(pir::CreateBuildCinnPass()); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->AddPass(pir::CreateBuildCinnPass()); if (has_dynamic_shape) { - pass_manager.AddPass(pir::CreateInferSymbolicShapePass(shape_analysis)); + pass_manager->AddPass(pir::CreateInferSymbolicShapePass(shape_analysis)); } - pass_manager.AddPass( + pass_manager->AddPass( cinn::dialect::ir::CreateCinnGroupLoweringPass(shape_analysis)); - - pass_manager.Run(&forward_program); - VLOG(3) << "after BuildCinnPass, forward_program:\n" << forward_program; + VLOG(4) << "has_dynamic_shape :" << has_dynamic_shape + << ", shape_analysis: " << shape_analysis; #else PADDLE_THROW(platform::errors::Unimplemented( "Currently we only support CINN Pass for Pir under @to_static, please " @@ -1645,7 +1644,7 @@ void ApplyPirPass(Program &forward_program) { // NOLINT #endif } void BindIrPass(pybind11::module *m) { - m->def("apply_pir_pass", ApplyPirPass); + m->def("add_cinn_pass", AddCinnPass); py::class_> pass(*m, "Pass", diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index a5858df1886e8f..88b51f827581c9 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -280,11 +280,12 @@ def apply_pir_program_pass(self, pass_fn): def pass_fn(forward_program, backward_program): return forward_program, backward_program """ - program_name_attr = self.program_name_attr origin_fwd = self.forward_program origin_bwd = self.backward_program + # NOTE(dev): Add this line to trigger program_name_attr logic + program_name_attr = self.program_name_attr self.forward_program, self.backward_program = pass_fn( - self.forward_program, self.backward_program, program_name_attr + origin_fwd, origin_bwd ) # cached property can ensure program is splited only once. @@ -382,55 +383,6 @@ def backward_program(self): return self._forward_backward_program[0][1] -class PirPassContext: - """ - PirPassContext is a class that only has staticmethod currently. - It will create a new RunableProgram after calling apply method. - """ - - INPUT_OP_NAME = "pd_op.data" - PARAM_OP_NAME = "builtin.parameter" - OUTPUT_OP_NAME = "builtin.shadow_output" - - @classmethod - def apply(cls, runable_program, build_strategy): - # TODO(Aurelius84): Currently only support infer mode, - # and we just use forward_program because backward_program - # is empty. - if not build_strategy.build_cinn_pass: - return runable_program - elif not paddle.is_compiled_with_cinn(): - raise RuntimeError( - "Please install PaddlePaddle compiled with CINN while setting build_strategy.build_cinn_pass = True." - ) - fwd_program, _ = paddle.base.libpaddle.pir.clone_program( - runable_program.forward_program - ) - paddle.base.libpaddle.pir.apply_pir_pass(fwd_program) - in_out_values = cls._prepare_attr(fwd_program) - return RunableProgram(fwd_program, in_out_values) - - @classmethod - def _prepare_attr(cls, program): - """ - After applying Pass, we need to update the Input/Parameter/Output Value - that refer to the new program. - - NOTE: We assume that Inputs come from INPUT_OP, Params come from - PARM_OP and Output come from OUTPUT_OP. - """ - inputs, params, outputs = [], [], [] - for op in program.global_block().ops: - op_name = op.name() - if op_name == cls.INPUT_OP_NAME: - inputs.append(op.result(0)) - elif op_name == cls.PARAM_OP_NAME: - params.append(op.result(0)) - elif op_name == cls.OUTPUT_OP_NAME: - outputs.append(op.operand(0).source()) - return inputs, params, outputs - - class PartialProgramLayerHook: def before_append_backward(self, forward_program, src_vars): ... @@ -596,13 +548,19 @@ def _get_scope(self, program_id=None, use_scope_cache=False): @switch_to_static_graph def _create_program(self, is_infer_mode=False): if is_infer_mode: + + def pass_fn(forward_program, backward_program): + pm = paddle.base.libpaddle.pir.PassManager() + if self._build_strategy.build_cinn_pass: + paddle.base.libpaddle.pir.add_cinn_pass(pm, forward_program) + pm.run(forward_program) + return forward_program, backward_program + # TODO(xiongkun) who to transfer the pruning program? infer_program = self.origin_runable_program.clone() if self._hooker: self._hooker.after_infer(infer_program) - infer_program = PirPassContext.apply( - infer_program, self._build_strategy - ) + infer_program.apply_pir_program_pass(pass_fn) return infer_program else: train_program: RunableProgram = self.origin_runable_program.clone() @@ -610,23 +568,20 @@ def _create_program(self, is_infer_mode=False): # Note: Only set grad type once after initializing train program. So we put it here. self._set_grad_type(self._params, train_program) - # (NOTE:@xiongkun) HOW TO APPLY PASS: this is a example for forward/backward clone pass, just replace with your cases. - def pass_fn(forward_program, backward_program, name_attr): - fwd, _ = paddle.base.libpaddle.pir.clone_program( - forward_program - ) - - if self._build_strategy.build_cinn_pass: - paddle.base.libpaddle.pir.apply_pir_pass(fwd) - - bwd, _ = paddle.base.libpaddle.pir.clone_program( - backward_program - ) + def pass_fn(forward_program, backward_program): + fwd_pm = paddle.base.libpaddle.pir.PassManager() + bwd_pm = paddle.base.libpaddle.pir.PassManager() if self._build_strategy.build_cinn_pass: - paddle.base.libpaddle.pir.apply_pir_pass(bwd) - - return fwd, bwd + paddle.base.libpaddle.pir.add_cinn_pass( + fwd_pm, forward_program + ) + paddle.base.libpaddle.pir.add_cinn_pass( + bwd_pm, backward_program + ) + fwd_pm.run(forward_program) + bwd_pm.run(backward_program) + return forward_program, backward_program train_program.apply_pir_program_pass(pass_fn) return train_program @@ -748,7 +703,7 @@ def _insert_aggregation_ops_for_var(target_program, var): shape=var.shape, ) # step2: rename the var.name@GRAD to var.name@GRAD@dy2static - for idx, op in finded_ops: + for _, op in finded_ops: op._rename_input(var_grad_name, new_grad_name) op._rename_output(var_grad_name, new_grad_name) # step3: insert sum op to aggregate the gradient. diff --git a/test/ir/pir/cinn/test_cinn_sub_graph.py b/test/ir/pir/cinn/test_cinn_sub_graph.py index 32b0bd5779dd92..ad4c65d3d35413 100644 --- a/test/ir/pir/cinn/test_cinn_sub_graph.py +++ b/test/ir/pir/cinn/test_cinn_sub_graph.py @@ -203,10 +203,9 @@ def test_forward(self): cinn_out = self.train(use_cinn=True) dy_out = self.train(use_cinn=False) - # TODO(zhangliujie) fix precision error - # np.testing.assert_allclose( - # cinn_out.numpy(), dy_out.numpy(), atol=1e-8, rtol=1e-4 - # ) + np.testing.assert_allclose( + cinn_out.numpy(), dy_out.numpy(), atol=1e-8, rtol=1e-4 + ) class TestCinnDropout(TestCinnSubGraphBase): From 5bc7a5926308d2c1a22d1c696d98a199bf60ff3c Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Fri, 29 Dec 2023 23:17:01 +0800 Subject: [PATCH 055/142] [CodeStyle][ruff] clean I001 ignore - Part 2 (#60466) --- pyproject.toml | 1 - python/paddle/distribution/__init__.py | 60 ++++++++++++++++---------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 64727e39f1d643..eaf239288bd39a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,5 @@ known-first-party = ["paddle"] "python/paddle/distributed/launch/controllers/__init__.py" = ["I001"] "python/paddle/distributed/passes/__init__.py" = ["I001"] "python/paddle/distributed/rpc/__init__.py" = ["I001"] -"python/paddle/distribution/__init__.py" = ["I001"] "python/paddle/incubate/distributed/fleet/__init__.py" = ["I001"] "python/paddle/incubate/distributed/fleet/parameter_server/pslib/__init__.py" = ["I001"] diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index c56da5805ad668..446c75aeaea700 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -12,29 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.distribution import transform -from paddle.distribution.bernoulli import Bernoulli -from paddle.distribution.beta import Beta -from paddle.distribution.categorical import Categorical -from paddle.distribution.cauchy import Cauchy -from paddle.distribution.continuous_bernoulli import ContinuousBernoulli -from paddle.distribution.dirichlet import Dirichlet -from paddle.distribution.distribution import Distribution -from paddle.distribution.gumbel import Gumbel -from paddle.distribution.exponential_family import ExponentialFamily -from paddle.distribution.independent import Independent -from paddle.distribution.kl import kl_divergence, register_kl -from paddle.distribution.lognormal import LogNormal -from paddle.distribution.multinomial import Multinomial -from paddle.distribution.multivariate_normal import MultivariateNormal -from paddle.distribution.normal import Normal -from paddle.distribution.transform import * # noqa: F403 -from paddle.distribution.transformed_distribution import TransformedDistribution -from paddle.distribution.uniform import Uniform -from paddle.distribution.laplace import Laplace -from paddle.distribution.geometric import Geometric -from paddle.distribution.binomial import Binomial -from paddle.distribution.poisson import Poisson +from . import transform +from .bernoulli import Bernoulli +from .beta import Beta +from .binomial import Binomial +from .categorical import Categorical +from .cauchy import Cauchy +from .continuous_bernoulli import ContinuousBernoulli +from .dirichlet import Dirichlet +from .distribution import Distribution +from .exponential_family import ExponentialFamily +from .geometric import Geometric +from .gumbel import Gumbel +from .independent import Independent +from .kl import kl_divergence, register_kl +from .laplace import Laplace +from .lognormal import LogNormal +from .multinomial import Multinomial +from .multivariate_normal import MultivariateNormal +from .normal import Normal +from .poisson import Poisson +from .transform import ( # noqa:F401 + AbsTransform, + AffineTransform, + ChainTransform, + ExpTransform, + IndependentTransform, + PowerTransform, + ReshapeTransform, + SigmoidTransform, + SoftmaxTransform, + StackTransform, + StickBreakingTransform, + TanhTransform, + Transform, +) +from .transformed_distribution import TransformedDistribution +from .uniform import Uniform __all__ = [ 'Bernoulli', From 3177d59b2915ce345964449a5189e16a8e0ca544 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Sat, 30 Dec 2023 20:58:49 +0800 Subject: [PATCH 056/142] [CustomDevice] release all xccl_comm in DeviceManager::Release (#60465) --- paddle/phi/backends/device_manager.cc | 4 +++ .../phi/core/distributed/xccl_comm_context.cc | 27 +++++++++++++++++++ .../phi/core/distributed/xccl_comm_context.h | 3 +++ 3 files changed, 34 insertions(+) diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 1e57fb736b7c26..87a163b2cb4fa2 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -14,6 +14,7 @@ #include "paddle/phi/backends/device_manager.h" #include "paddle/phi/common/complex.h" +#include "paddle/phi/core/distributed/xccl_comm_context.h" #if !defined(_WIN32) #include @@ -699,6 +700,9 @@ DeviceManager& DeviceManager::Instance() { void DeviceManager::Release() { event::Event::ReleaseAll(); stream::Stream::ReleaseAll(); +#ifdef PADDLE_WITH_CUSTOM_DEVICE + phi::distributed::XCCLCommContext::ReleaseAll(); +#endif Instance().device_map_.clear(); Instance().device_impl_map_.clear(); } diff --git a/paddle/phi/core/distributed/xccl_comm_context.cc b/paddle/phi/core/distributed/xccl_comm_context.cc index ba7e24ab06b9e1..3e3608e4d88a59 100644 --- a/paddle/phi/core/distributed/xccl_comm_context.cc +++ b/paddle/phi/core/distributed/xccl_comm_context.cc @@ -14,6 +14,8 @@ #include "paddle/phi/core/distributed/xccl_comm_context.h" +#include + #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" @@ -25,6 +27,29 @@ namespace phi { namespace distributed { +std::list g_xccl_comm_contexts; +std::mutex g_xccl_comm_contexts_mutex; + +void XCCLCommContext::ReleaseAll() { + std::unique_lock lock(g_xccl_comm_contexts_mutex); + for (auto xccl_comm_ctx : g_xccl_comm_contexts) { + phi::DeviceManager::CCLDestroyComm(xccl_comm_ctx->GetDeviceType(), + xccl_comm_ctx->GetXcclComm()); + xccl_comm_ctx->xccl_comm_ = nullptr; + } + g_xccl_comm_contexts.clear(); +} + +XCCLCommContext::~XCCLCommContext() { + std::unique_lock lock(g_xccl_comm_contexts_mutex); + if (phi::DeviceManager::HasDeviceType(this->GetDeviceType()) && + xccl_comm_ != nullptr) { + phi::DeviceManager::CCLDestroyComm(this->GetDeviceType(), xccl_comm_); + xccl_comm_ = nullptr; + } + g_xccl_comm_contexts.remove(this); +} + XCCLCommContext::XCCLCommContext(const phi::Place& place, int rank, int size, @@ -38,6 +63,8 @@ XCCLCommContext::XCCLCommContext(const phi::Place& place, &xccl_comm_); stream_ = std::make_shared(); stream_->Init(place_); + std::unique_lock lock(g_xccl_comm_contexts_mutex); + g_xccl_comm_contexts.push_back(this); } void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, diff --git a/paddle/phi/core/distributed/xccl_comm_context.h b/paddle/phi/core/distributed/xccl_comm_context.h index 0c253eb925bb4d..8cdc7e4153d767 100644 --- a/paddle/phi/core/distributed/xccl_comm_context.h +++ b/paddle/phi/core/distributed/xccl_comm_context.h @@ -28,6 +28,9 @@ class XCCLCommContext final : public CommContext { int rank, int size, const ccl::CCLRootId& xccl_id); + ~XCCLCommContext(); + + static void ReleaseAll(); ccl::CCLComm GetXcclComm() const { return xccl_comm_; } From 77d7638a75bf527c1db3b7df3688102b4210d74f Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 2 Jan 2024 10:27:53 +0800 Subject: [PATCH 057/142] [DimExpr] DimExpr support hash (#60471) --- paddle/pir/dialect/shape/utils/dim_expr.cc | 52 +++++++++++++++++++ paddle/pir/dialect/shape/utils/dim_expr.h | 13 +++++ .../pir/shape_dialect/symbol_dim_expr_test.cc | 34 +++++++++--- 3 files changed, 93 insertions(+), 6 deletions(-) diff --git a/paddle/pir/dialect/shape/utils/dim_expr.cc b/paddle/pir/dialect/shape/utils/dim_expr.cc index 0d9b6ece23245c..61f7a582cb5a56 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/dialect/shape/utils/dim_expr.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include "paddle/pir/core/utils.h" namespace symbol { @@ -184,4 +185,55 @@ std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) { return stream; } +namespace { + +std::size_t GetHashValueImpl(const std::int64_t& dim_expr) { return dim_expr; } + +std::size_t GetHashValueImpl(const std::string& dim_expr) { + return std::hash()(dim_expr); +} + +std::size_t GetHashValueImpl(const Negative& dim_expr) { + return -GetHashValue(dim_expr->data); +} + +std::size_t GetHashValueImpl(const Reciprocal& dim_expr) { + return pir::hash_combine(1, -GetHashValue(dim_expr->data)); +} + +std::size_t GetHashValueImpl(const List& exprs) { + std::size_t ret = 0; + for (const auto& expr : *exprs) { + ret = pir::hash_combine(ret, GetHashValue(expr)); + } + return ret; +} + +std::size_t GetHashValueImpl(const Add& dim_expr) { + return pir::hash_combine(1, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Mul& dim_expr) { + return pir::hash_combine(2, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Max& dim_expr) { + return pir::hash_combine(3, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Min& dim_expr) { + return pir::hash_combine(4, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Broadcast& dim_expr) { + return pir::hash_combine(5, GetHashValueImpl(dim_expr.operands)); +} + +} // namespace + +std::size_t GetHashValue(const DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return GetHashValueImpl(impl); }, + dim_expr.variant()); +} + } // namespace symbol diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 277a6febe66ed7..a65390200cd062 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -253,4 +253,17 @@ IR_API std::string ToString(const DimExpr& dim_expr); IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); +IR_API std::size_t GetHashValue(const DimExpr& dim_expr); + } // namespace symbol + +namespace std { + +template <> +struct hash { + std::size_t operator()(const symbol::DimExpr& dim_expr) const { + return symbol::GetHashValue(dim_expr); + } +}; + +} // namespace std diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 6157850e3842c3..3aebb367d1a272 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -22,7 +22,7 @@ namespace symbol::test { // Construct DimExpr by overloaded operator(+, - , *, /) -TEST(DimExpr, dim_expr_naive) { +TEST(DimExpr, DimExprNaive) { DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); DimExpr constant1 = DimExpr(1); @@ -30,7 +30,7 @@ TEST(DimExpr, dim_expr_naive) { } // Construct DimExpr by DimExprBuilder -TEST(DimExpr, dim_expr_builder) { +TEST(DimExpr, DimExprBuilder) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); @@ -40,7 +40,7 @@ TEST(DimExpr, dim_expr_builder) { } // Add constraints by DimExprBuilder -TEST(DimExpr, constraint) { +TEST(DimExpr, Constraint) { std::vector constraints{}; DimExprBuilder builder(&constraints); DimExpr sym0 = DimExpr("S0"); @@ -55,7 +55,7 @@ TEST(DimExpr, constraint) { extend_x = x.shape out = pd.reshape(y, extend_x) */ -TEST(DimExpr, data_shape_expr) { +TEST(DimExpr, DataShapeExpr) { // Show ideal ShapeOrDataDimExprs of each pir::Value std::vector x_shapes{DimExpr("S0"), DimExpr(2)}; std::vector y_shapes{DimExpr(1), DimExpr("S1"), DimExpr(2)}; @@ -80,7 +80,7 @@ TEST(Simplify, NumberArithmetic) { ASSERT_EQ((mul_div.Get()), 1); } -TEST(DimExpr, equal) { +TEST(DimExpr, Equal) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); @@ -111,7 +111,7 @@ TEST(DimExpr, equal) { builder.Broadcast(DimExpr("S0"), constant1)); } -TEST(DimExpr, print) { +TEST(DimExpr, Print) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); @@ -124,4 +124,26 @@ TEST(DimExpr, print) { ASSERT_EQ((ToString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)"); } +TEST(DimExpr, Hash) { + DimExprBuilder builder{nullptr}; + DimExpr sym0 = DimExpr("S0"); + DimExpr sym1 = DimExpr("S1"); + ASSERT_EQ((std::hash()(sym0 + sym1)), + (std::hash()(sym0 + sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym1 + sym0))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym0 - sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym0 * sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym0 / sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(builder.Max(sym0, sym1)))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(builder.Min(sym0, sym1)))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(builder.Broadcast(sym0, sym1)))); +} + } // namespace symbol::test From 8fcf35b724cddc2fd327f0656850e70efbb74ddb Mon Sep 17 00:00:00 2001 From: ooo oo <106524776+ooooo-create@users.noreply.github.com> Date: Tue, 2 Jan 2024 10:38:25 +0800 Subject: [PATCH 058/142] open warning with `paddle.utils.deprecated` (#60458) * open_warning * update unittest * update * fix typos * fix warning in test runner * uncomment * cleanup todo * using VisibleDeprecationWarning * update comment * fix typo * fix indentation * fix * fix * fix indent level and test * update --------- Co-authored-by: SigureMo --- python/paddle/utils/deprecated.py | 23 ++++++-- test/legacy_test/CMakeLists.txt | 2 - test/legacy_test/test_deprecated_decorator.py | 59 ++++++++----------- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/python/paddle/utils/deprecated.py b/python/paddle/utils/deprecated.py index 873c6b3a6a9fce..39b1f737480987 100755 --- a/python/paddle/utils/deprecated.py +++ b/python/paddle/utils/deprecated.py @@ -16,6 +16,7 @@ """ import functools +import inspect import sys import warnings @@ -24,6 +25,18 @@ __all__ = [] +class VisibleDeprecationWarning(UserWarning): + """Visible deprecation warning. + + Since Python 3.7, Python only show the DeprecationWarning if the module + is __main__. So we use this warning to make the deprecation warning visible. + + See more details from https://peps.python.org/pep-0565/ + """ + + ... + + def deprecated(update_to="", since="", reason="", level=0): """Decorate a function to signify its deprecation. @@ -47,8 +60,6 @@ def deprecated(update_to="", since="", reason="", level=0): """ def decorator(func): - # TODO(zhiqiu): temporally disable the warnings - return func """construct warning message, and return a decorated function or class.""" assert isinstance(update_to, str), 'type of "update_to" must be str.' assert isinstance(since, str), 'type of "since" must be str.' @@ -75,9 +86,11 @@ def decorator(func): ) msg += f' Please use "{_update_to}" instead.' if len(_reason) > 0: - msg += f"\nreason: {_reason}" + msg += f"\n Reason: {_reason}" if func.__doc__: - func.__doc__ = ('\n\nWarning: ' + msg + '\n') + func.__doc__ + func.__doc__ = ( + '\n\nWarning:\n ' + msg + '\n\n' + ) + inspect.cleandoc(func.__doc__) if level == 0: return func @@ -110,7 +123,7 @@ def wrapper(*args, **kwargs): or v_current >= v_since ): warnings.warn( - warningmsg, category=DeprecationWarning, stacklevel=2 + warningmsg, category=VisibleDeprecationWarning, stacklevel=2 ) return func(*args, **kwargs) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 824d50d8a6aaf7..ed0f40f982d23c 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -118,8 +118,6 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) list(REMOVE_ITEM TEST_OPS test_fleet_executor_cond_interceptor) endif() -list(REMOVE_ITEM TEST_OPS test_deprecated_decorator) - if(WIN32) list(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception) list(REMOVE_ITEM TEST_OPS test_trainer_desc) diff --git a/test/legacy_test/test_deprecated_decorator.py b/test/legacy_test/test_deprecated_decorator.py index 81ae80a1f9bf78..0c7ce0f0625900 100755 --- a/test/legacy_test/test_deprecated_decorator.py +++ b/test/legacy_test/test_deprecated_decorator.py @@ -19,23 +19,20 @@ import numpy as np import paddle -from paddle import _legacy_C_ops from paddle.utils import deprecated LOWEST_WARNING_POSTION = 3 ERROR_WARNING_POSTION = sys.maxsize # custom paddle version -paddle.version.major = '1' -paddle.version.minor = '8' +paddle.version.major = '0' +paddle.version.minor = '0' paddle.version.patch = '0' paddle.version.rc = '0' -paddle.__version__ = '1.8.0' -paddle.version.full_version = '1.8.0' +paddle.__version__ = '0.0.0' +paddle.version.full_version = '0.0.0' print("current paddle version: ", paddle.__version__) -paddle.disable_static() - def get_warning_index(api): """ @@ -49,22 +46,25 @@ def get_warning_index(api): index (int): the index of the Warinng information in its doc string if exists. """ - doc_lst = api.__doc__.splitlines() - for idx, val in enumerate(doc_lst): + doc_list = api.__doc__.splitlines() + if len(doc_list) < 2: + return ERROR_WARNING_POSTION + for idx, (current_line, next_line) in enumerate( + zip(doc_list[:-1], doc_list[1:]) + ): if ( - val.startswith("Warning: ") - and val.endswith(" instead.") - and "and will be removed in future versions." in val + current_line == "Warning:" + and next_line.endswith(" instead.") + and "and will be removed in future versions." in next_line ): return idx return ERROR_WARNING_POSTION -class TestDeprecatedDocorator(unittest.TestCase): +class TestDeprecatedDecorator(unittest.TestCase): """ - tests for paddle's Deprecated Docorator. + tests for paddle's deprecated decorator. test_new_multiply: test for new api, which should not insert warning information. - test_ops_elementwise_mul: test for C++ elementwise_mul op, which should not insert warning information. """ def test_new_multiply(self): @@ -87,26 +87,15 @@ def test_new_multiply(self): # testting self.assertLess(expected, captured) - def test_ops_elementwise_mul(self): - """ - Test for new C++ elementwise_op, expected result should be True, - because not matter what base.layers.elementwise_mul is deprecated. - """ - - a = np.random.uniform(0.1, 1, [51, 76]).astype(np.float32) - b = np.random.uniform(0.1, 1, [51, 76]).astype(np.float32) - x = paddle.to_tensor(a) - y = paddle.to_tensor(b) - res = _legacy_C_ops.elementwise_mul(x, y) - - # expected - expected = LOWEST_WARNING_POSTION - - # captured - captured = get_warning_index(paddle.multiply) - - # testting - self.assertGreater(expected, captured) + def test_indent_level(self): + # test for different indent_level + dataset = paddle.base.DatasetFactory().create_dataset("InMemoryDataset") + with warnings.catch_warnings(record=True): + dataset.set_merge_by_lineid() + assert ( + '\nSet merge by' + in paddle.base.InMemoryDataset.set_merge_by_lineid.__doc__ + ) def test_tensor_gradient(self): paddle.__version__ = '2.1.0' From a08580e25c6db807e7ba7318550f566db55ac1f8 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Tue, 2 Jan 2024 10:45:08 +0800 Subject: [PATCH 059/142] [AutoParallel] Auto Trans PP to VPP (#60467) * [AutoParallel] Auto Trans PP to VPP * add comment --- .../auto_parallel/static/completion.py | 189 ++++++++++++++---- .../distributed/auto_parallel/static/utils.py | 2 +- .../pipeline_scheduler_vpp_unittest.py | 80 ++++++-- 3 files changed, 207 insertions(+), 64 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/completion.py b/python/paddle/distributed/auto_parallel/static/completion.py index 76c6a9d181766d..692d02b7563c6e 100644 --- a/python/paddle/distributed/auto_parallel/static/completion.py +++ b/python/paddle/distributed/auto_parallel/static/completion.py @@ -1057,22 +1057,43 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id): dist_op = self._dist_context.get_dist_op_for_program(op) dist_op.dist_attr.chunk_id = chunk_id for name in op.input_arg_names + op.output_arg_names: - var = block._find_var_recursive(name) if "lod_tensor_blocking_queue" in name: continue if name not in var_to_chunk_id: - op_dist_attr = ( - self._dist_context.get_op_dist_attr_for_program(op) + var = block._find_var_recursive(name) + dist_tensor = ( + self._dist_context.get_dist_tensor_for_program(var) ) - tensor_dist_attr = ( - self._dist_context.get_tensor_dist_attr_for_program(var) + if ( + dist_op.dist_attr.process_mesh + == dist_tensor.dist_attr.process_mesh + ): + dist_tensor.dist_attr.chunk_id = chunk_id + var_to_chunk_id[var.name] = chunk_id + + def set_process_mesh(block, op, process_mesh, var_to_process_mesh): + dist_op = self._dist_context.get_dist_op_for_program(op) + for name in op.input_arg_names: + if name not in var_to_process_mesh: + var = block._find_var_recursive(name) + dist_tensor = ( + self._dist_context.get_dist_tensor_for_program(var) ) if ( - op_dist_attr.process_mesh - == tensor_dist_attr.process_mesh + dist_op.dist_attr.process_mesh + == dist_tensor.dist_attr.process_mesh ): - tensor_dist_attr.chunk_id = op_dist_attr.chunk_id - var_to_chunk_id[var.name] = op_dist_attr.chunk_id + dist_tensor.dist_attr.process_mesh = process_mesh + var_to_process_mesh[var.name] = process_mesh + for name in op.output_arg_names: + if name not in var_to_process_mesh: + var = block._find_var_recursive(name) + dist_tensor = ( + self._dist_context.get_dist_tensor_for_program(var) + ) + dist_tensor.dist_attr.process_mesh = process_mesh + var_to_process_mesh[var.name] = process_mesh + dist_op.dist_attr.process_mesh = process_mesh if ( not self._dist_context.strategy @@ -1080,7 +1101,7 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id): ): return - pp_degree = get_pp_degree(self._dist_context) + pp_degree, sub_process_meshes = get_pp_degree(self._dist_context) vpp_degree = self._dist_context.strategy.pipeline.vpp_degree seg_method = self._dist_context.strategy.pipeline.vpp_seg_method schedule_mode = self._dist_context.strategy.pipeline.schedule_mode @@ -1099,8 +1120,11 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id): block = serial_main_program.global_block() ops = block.ops - # 1. search seg_method in op's struct_name, and get all ops of segments - seg_op_deps = collections.OrderedDict() + # Step1: search seg_method in op's struct_name + # 1. get op_idx of each segment + # 2. get process_mesh or each segment + seg_op_deps = collections.OrderedDict() # struct_name -> [idx] + seg_op_mesh = collections.OrderedDict() # struct_name -> process_mesh regex = re.compile(seg_method, re.IGNORECASE) for i, op in enumerate(ops): struct_name = op.struct_name @@ -1109,59 +1133,93 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id): continue struct_name = struct_name[m.start(0) :].split("/")[0] + dist_op = self._dist_context.get_dist_op_for_program(op) if struct_name not in seg_op_deps: seg_op_deps[struct_name] = [i] + seg_op_mesh[struct_name] = dist_op.dist_attr.process_mesh else: assert ( seg_op_deps[struct_name][-1] + 1 == i ), "The segment's ops should be continuous." - pre_op = ops[seg_op_deps[struct_name][-1]] - pre_dist_op = self._dist_context.get_dist_op_for_program(pre_op) - dist_op = self._dist_context.get_dist_op_for_program(op) + pre_mesh = seg_op_mesh[struct_name] assert ( - pre_dist_op.dist_attr.process_mesh - == dist_op.dist_attr.process_mesh + pre_mesh == dist_op.dist_attr.process_mesh ), "The segment's ops should have same process_mesh." seg_op_deps[struct_name].extend([i]) - # the num of chunk is equal to vpp_degree - num_parts = pp_degree * vpp_degree + num_chunks = pp_degree * vpp_degree assert ( - len(seg_op_deps.keys()) % num_parts == 0 - ), "number of layers[{}] ({}) should be devided by part number ({}).".format( - seg_method, len(seg_op_deps.keys()), num_parts + len(seg_op_deps) % num_chunks == 0 + ), "The number of layers[{}] ({}) should be devided by part number ({}).".format( + seg_method, len(seg_op_deps), num_chunks ) - part_size = len(seg_op_deps.keys()) // vpp_degree + # Step2: analysis whether the pp_stage is non-decreasing among segments + # 1. if non_decreasing is True, the ops' process_mesh will be changed by vpp strategy + # 2. if non_decreasing is False, the ops's process_mesh will not be changed. + non_decreasing = True + seg_pp_stages = [-1] + for seg_pm in seg_op_mesh.values(): + assert seg_pm in sub_process_meshes + pp_stage = sub_process_meshes.index(seg_pm) + if seg_pp_stages[-1] > pp_stage: + non_decreasing = False + break + seg_pp_stages.append(pp_stage) - # 2. get boundary index of each chunk - results = [0] * (vpp_degree + 1) - memory_counter = 0 - result_idx = 1 - for struct_name, idxs in seg_op_deps.items(): + if not non_decreasing: + _logger.info("Cannot Use Auto VPP") + else: + _logger.info("Using Auto VPP") + + # Step3: Get op index boundary, pp_stage, chunk_id, struct_names of each segment + seg_pp_stages = [i % pp_degree for i in range(num_chunks)] + seg_chunk_ids = [i // pp_degree for i in range(num_chunks)] + part_size = len(seg_op_deps) // num_chunks + segment_struct_names = [] + segment_parts = [0] * (num_chunks + 1) + memory_counter, seg_idx = 0, 1 + struct_name = [] + for name, idxs in seg_op_deps.items(): + struct_name.append(name) memory_counter += 1 if memory_counter == part_size: - results[result_idx] = idxs[-1] + 1 - result_idx += 1 - memory_counter = 0 - results[vpp_degree] = len(ops) + segment_parts[seg_idx] = idxs[-1] + 1 + memory_counter, seg_idx = 0, seg_idx + 1 + segment_struct_names.append(struct_name) + struct_name = [] + segment_parts[num_chunks] = len(ops) - # 3. set right chunk_id for each op + # Step4: set right chunk_id and process_mesh for each op and var var_to_chunk_id = {} - for chunk_id in range(len(results) - 1): - start_idx = results[chunk_id] - end_idx = results[chunk_id + 1] + var_to_process_mesh = {} + for seg_id in range(len(segment_parts) - 1): + start_idx = segment_parts[seg_id] + end_idx = segment_parts[seg_id + 1] + pp_stage = seg_pp_stages[seg_id] + chunk_id = seg_chunk_ids[seg_id] + process_mesh = sub_process_meshes[pp_stage] + struct_names = segment_struct_names[seg_id] + seg_op_idx = [] + for name in struct_names: + seg_op_idx.extend(seg_op_deps[name]) + _logger.info( - "[chunk_{}] start op: [{}]: [{}] [{}]".format( + "stage=[{}], chunk_id=[{}], layer_name=[{}]".format( + pp_stage, chunk_id, + struct_names, + ) + ) + _logger.info( + "start op: [{}]: [{}] [{}]".format( ops[start_idx].type, ops[start_idx].input_arg_names, ops[start_idx].output_arg_names, ) ) _logger.info( - "[chunk_{}] end op: [{}]: [{}] [{}]".format( - chunk_id, + "end op: [{}]: [{}] [{}]".format( ops[end_idx - 1].type, ops[end_idx - 1].input_arg_names, ops[end_idx - 1].output_arg_names, @@ -1173,9 +1231,28 @@ def set_chunk_id(block, op, chunk_id, var_to_chunk_id): if op.has_attr("sub_block"): block_id = op.attr('sub_block').id sub_block = serial_main_program.blocks[block_id] - for op in sub_block.ops: - set_chunk_id(sub_block, op, chunk_id, var_to_chunk_id) + if non_decreasing and idx in seg_op_idx: + set_process_mesh( + block, op, process_mesh, var_to_process_mesh + ) + set_chunk_id(block, op, chunk_id, var_to_chunk_id) + + for sub_op in sub_block.ops: + if non_decreasing and idx in seg_op_idx: + set_process_mesh( + sub_block, + sub_op, + process_mesh, + var_to_process_mesh, + ) + set_chunk_id( + sub_block, sub_op, chunk_id, var_to_chunk_id + ) else: + if non_decreasing and idx in seg_op_idx: + set_process_mesh( + block, op, process_mesh, var_to_process_mesh + ) set_chunk_id(block, op, chunk_id, var_to_chunk_id) def _update_dist_attr_for_dp(self): @@ -1915,8 +1992,34 @@ def infer_backward_op_partial_status( grad_op_dist_attr.set_output_dims_mapping( output_name, ref_fwd_dims_mapping ) - grad_op_dist_attr.process_mesh = ref_fwd_process_mesh - grad_op_dist_attr.chunk_id = ref_fwd_chunk_id + # NOTE(zhaoyingli): + # The sum op is used to accmulate the grads' value of the same forward var, + # sum op's chunk_id is same with the last op which generate the grad. + ref_chunk_id = None + ref_process_mesh = None + for pre_idx in range( + idx - 1, first_backward_op_idx + 1, -1 + ): + pre_grad_op = ops[pre_idx] + inter_arg_name = list( + set(pre_grad_op.output_arg_names) + & set(grad_op.input_arg_names) + ) + if len(inter_arg_name) > 0: + pre_op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_program( + pre_grad_op + ) + ) + ref_chunk_id = pre_op_dist_attr.chunk_id + ref_process_mesh = pre_op_dist_attr.process_mesh + break + assert ( + ref_chunk_id is not None + and ref_process_mesh is not None + ) + grad_op_dist_attr.process_mesh = ref_process_mesh + grad_op_dist_attr.chunk_id = ref_chunk_id self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr ) diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index 296196230d086b..359767c7345e87 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -2335,7 +2335,7 @@ def get_pp_degree(dist_context): for idx in reversed(global_pm_idx): process_meshes.pop(idx) - return len(process_meshes) + return len(process_meshes), process_meshes def get_pp_stage(dist_context, rank): diff --git a/test/auto_parallel/pipeline_scheduler_vpp_unittest.py b/test/auto_parallel/pipeline_scheduler_vpp_unittest.py index 431e782cb073e2..bed72232a05ca4 100644 --- a/test/auto_parallel/pipeline_scheduler_vpp_unittest.py +++ b/test/auto_parallel/pipeline_scheduler_vpp_unittest.py @@ -37,8 +37,8 @@ class MyLinear(nn.Layer): def __init__( self, - hidden_size=1024, - intermediate_size=4 * 1024, + hidden_size=784, + intermediate_size=4 * 784, dropout_ratio=0.1, weight_attr=None, ): @@ -64,10 +64,11 @@ def forward(self, input): class MLPLayer(nn.Layer): def __init__( self, - hidden_size=1024, - intermediate_size=4 * 1024, + hidden_size=784, + intermediate_size=4 * 784, dropout_ratio=0.1, initializer_range=0.02, + manual=True, ): super().__init__() @@ -86,7 +87,10 @@ def __init__( self.linear = nn.Linear(hidden_size, 1, weight_attr, bias_attr=None) self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5) - self.layer_to_mesh = [PP_MESH_0, PP_MESH_1, PP_MESH_0, PP_MESH_1] + if manual: + self.layer_to_mesh = [PP_MESH_0, PP_MESH_1, PP_MESH_0, PP_MESH_1] + else: + self.layer_to_mesh = [PP_MESH_0, PP_MESH_0, PP_MESH_1, PP_MESH_1] def forward(self, input): out = self.norm(input) @@ -99,6 +103,11 @@ def forward(self, input): return out +def loss_fn(pred, label): + loss = F.l1_loss(pred, label) + return loss + + def apply_pass(schedule_mode, acc_step): strategy = auto.Strategy() strategy.auto_mode = "semi" @@ -126,8 +135,8 @@ def __init__(self, num_samples): self.num_samples = num_samples def __getitem__(self, index): - input = np.random.uniform(size=1024).astype("float32") - label = np.random.randint(0, 9, dtype="int64") + input = np.random.uniform(size=784).astype("float32") + label = np.random.uniform(size=1).astype("float32") return input, label def __len__(self): @@ -136,8 +145,6 @@ def __len__(self): class TestVPPPass(unittest.TestCase): def setUp(self): - self.rtol = 1e-5 - self.atol = 1e-8 self.batch_size = 4 self.batch_num = 10 self.clip_norm = 0.2 @@ -151,23 +158,50 @@ def init(self, engine): place = paddle.base.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) - def get_engine(self, schedule_mode, acc_step): + def get_engine(self, schedule_mode, acc_step, manual=True): reset_prog() strategy = apply_pass(schedule_mode, acc_step) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) - model = MLPLayer() - loss = paddle.nn.CrossEntropyLoss() + model = MLPLayer(manual=manual) - engine = auto.Engine(model, loss, opt, strategy=strategy) + engine = auto.Engine(model, loss_fn, opt, strategy=strategy) self.init(engine) return engine def test_pp_pass(self): - # pp2-vpp - engine = self.get_engine(schedule_mode="VPP", acc_step=4) - engine.fit(self.dataset, batch_size=self.batch_size, log_freq=1) + # pp2-vpp-manual + engine = self.get_engine(schedule_mode="VPP", acc_step=4, manual=True) + out_manual = engine.fit( + self.dataset, batch_size=self.batch_size, log_freq=1 + ) + assert engine._strategy.pipeline.schedule_mode == "VPP" + + fw_chunk_ids = [] + bw_chunk_ids = [] + for op in engine.main_program.global_block().ops: + if is_optimize_op(op): + break + + dist_op = engine.dist_context.get_dist_op_for_program(op) + if is_forward_op(op): + fw_chunk_ids.append(dist_op.dist_attr.chunk_id) + if is_backward_op(op): + bw_chunk_ids.append(dist_op.dist_attr.chunk_id) + + if paddle.distributed.get_rank() == 0: + self.assertEqual(sum(fw_chunk_ids), 8) + self.assertEqual(sum(bw_chunk_ids), 13) + else: + self.assertEqual(sum(fw_chunk_ids), 12) + self.assertEqual(sum(bw_chunk_ids), 19) + + # pp2-vpp-auto + engine = self.get_engine(schedule_mode="VPP", acc_step=4, manual=False) + out_auto = engine.fit( + self.dataset, batch_size=self.batch_size, log_freq=1 + ) assert engine._strategy.pipeline.schedule_mode == "VPP" fw_chunk_ids = [] @@ -183,11 +217,17 @@ def test_pp_pass(self): bw_chunk_ids.append(dist_op.dist_attr.chunk_id) if paddle.distributed.get_rank() == 0: - assert sum(fw_chunk_ids) == 8 - assert sum(bw_chunk_ids) == 13 + self.assertEqual(sum(fw_chunk_ids), 9) + self.assertEqual(sum(bw_chunk_ids), 13) else: - assert sum(fw_chunk_ids) == 12 - assert sum(bw_chunk_ids) == 18 + self.assertEqual(sum(fw_chunk_ids), 13) + self.assertEqual(sum(bw_chunk_ids), 19) + + if paddle.distributed.get_rank() == 1: + self.assertEqual( + np.mean(out_manual.history["loss"][0]), + np.mean(out_auto.history["loss"][0]), + ) if __name__ == "__main__": From b56d14015432c84cb0064c93b5b4e9b269d6488e Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Tue, 2 Jan 2024 11:18:18 +0800 Subject: [PATCH 060/142] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.23?= =?UTF-8?q?=E3=80=91=20fix=20test=5Fdistribute=5Ffpn=5Fproposals=5Fop=20(#?= =?UTF-8?q?60335)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * fix --- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 2 +- test/white_list/pir_op_test_white_list | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index b926b055daa6a2..b992c139b8543a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -361,7 +361,7 @@ kernel : func : distribute_fpn_proposals data_type : fpn_rois - optional : rois_num + optional : rois_num, multi_level_rois_num - op : divide args : (Tensor x, Tensor y) diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 2bf69d7d82fafb..a0db4b97520691 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -80,6 +80,7 @@ test_diagonal_op test_digamma_op test_dirichlet_op test_dist_op +test_distribute_fpn_proposals_op test_dot_op test_dpsgd_op test_edit_distance_op From 7041276a875c670848a45d9626cce3df2f5ff36c Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Tue, 2 Jan 2024 11:21:46 +0800 Subject: [PATCH 061/142] fix test_lookup_table_v2_bf16_op (#60332) --- .../test_lookup_table_v2_bf16_op.py | 22 +++++++++---------- test/white_list/pir_op_test_white_list | 1 + 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/test/legacy_test/test_lookup_table_v2_bf16_op.py b/test/legacy_test/test_lookup_table_v2_bf16_op.py index 44a2f1881b086a..ff9fcebef4f65b 100644 --- a/test/legacy_test/test_lookup_table_v2_bf16_op.py +++ b/test/legacy_test/test_lookup_table_v2_bf16_op.py @@ -15,14 +15,8 @@ import unittest import numpy as np +import test_lookup_table_bf16_op from op_test import convert_uint16_to_float -from test_lookup_table_bf16_op import ( - TestLookupTableBF16Op, - TestLookupTableBF16OpIds4D, - TestLookupTableBF16OpWIsSelectedRows, - TestLookupTableBF16OpWIsSelectedRows4DIds, - _lookup, -) import paddle from paddle import base @@ -30,7 +24,7 @@ from paddle.pir_utils import test_with_pir_api -class TestLookupTableV2BF16Op(TestLookupTableBF16Op): +class TestLookupTableV2BF16Op(test_lookup_table_bf16_op.TestLookupTableBF16Op): def init_test(self): self.op_type = "lookup_table_v2" self.python_api = paddle.nn.functional.embedding @@ -38,7 +32,9 @@ def init_test(self): self.mkldnn_data_type = "bfloat16" -class TestLookupTableV2BF16OpIds4D(TestLookupTableBF16OpIds4D): +class TestLookupTableV2BF16OpIds4D( + test_lookup_table_bf16_op.TestLookupTableBF16OpIds4D +): def init_test(self): self.op_type = "lookup_table_v2" self.python_api = paddle.nn.functional.embedding @@ -47,7 +43,7 @@ def init_test(self): class TestLookupTableV2BF16OpWIsSelectedRows( - TestLookupTableBF16OpWIsSelectedRows + test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows ): def init_test(self): self.op_type = "lookup_table_v2" @@ -56,7 +52,7 @@ def init_test(self): class TestLookupTableV2BF16OpWIsSelectedRows4DIds( - TestLookupTableBF16OpWIsSelectedRows4DIds + test_lookup_table_bf16_op.TestLookupTableBF16OpWIsSelectedRows4DIds ): def init_test(self): self.op_type = "lookup_table_v2" @@ -134,7 +130,9 @@ def test_embedding_weights(self): @test_with_pir_api def test_lookup_results(self): lookup_result = convert_uint16_to_float(self.result[1]) - lookup_ref = _lookup(self.w_fp32, self.ids, self.flat_ids, self.op_type) + lookup_ref = test_lookup_table_bf16_op._lookup( + self.w_fp32, self.ids, self.flat_ids, self.op_type + ) np.testing.assert_array_equal(lookup_result, lookup_ref) diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index a0db4b97520691..045e8b4df94595 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -180,6 +180,7 @@ test_logcumsumexp_op test_logit_op test_logspace test_logsumexp +test_lookup_table_v2_bf16_op test_lookup_table_v2_op test_lookup_table_v2_op_static_build test_lrn_mkldnn_op From 1761931b9c30966c11e49b1c50ab834ad716b186 Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 2 Jan 2024 12:07:21 +0800 Subject: [PATCH 062/142] Fix shape error in combined-indexing setitem (#60447) * add ut * fix shape error in combine-indexing * fix ut --- paddle/fluid/pybind/eager_method.cc | 16 ++- paddle/fluid/pybind/slice_utils.h | 43 ++++----- python/paddle/base/variable_index.py | 45 ++++++--- test/indexing/test_setitem.py | 139 +++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 5a102f3c75cc5f..b3898533f965c1 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1375,7 +1375,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self, // step3: Dealing with advanced indexing std::vector transed_index; - std::vector trans_back_dim; + std::vector trans_back_dim, trans_dim; int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1; paddle::Tensor transed_tensor = dealWithAdvancedIndex(out, @@ -1385,7 +1385,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self, &transed_index, &trans_back_dim, &pos_of_new_dim, - &rank_of_new_dim); + &rank_of_new_dim, + &trans_dim); if (transed_index.size() == 1 && transed_index[0].dtype() == phi::DataType::BOOL) { @@ -1679,9 +1680,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &use_strided_slice); std::vector transed_index; - std::vector trans_back_dim; + std::vector trans_back_dim, trans_dim; - int pos_of_new_dim = 0, rank_of_new_dim = 0; + int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1; paddle::Tensor transed_sub_tensor = dealWithAdvancedIndex(sub_tensor, @@ -1691,7 +1692,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &transed_index, &trans_back_dim, &pos_of_new_dim, - &rank_of_new_dim); + &rank_of_new_dim, + &trans_dim); // Release gil and do tracing py::gil_scoped_release release; @@ -1714,6 +1716,10 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, } } + if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) { + value_tensor = transpose_ad_func(value_tensor, trans_dim); + } + // TODO(zoooo0820) 1.Using inplace version index_put // 2.Remove following code after backward bug fixed. transed_sub_tensor = assign_ad_func(transed_sub_tensor); diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index f4eef3af16bcf1..e60ab9406396a2 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -398,9 +398,8 @@ static paddle::Tensor dealWithAdvancedIndex( std::vector* transed_index, std::vector* trans_back_dim, int* pos_of_new_dim, - int* rank_of_new_dim) { - std::vector trans_dim; - + int* rank_of_new_dim, + std::vector* trans_dim) { int p = 0; for (size_t i = 0; i < advanced_index_dim->size(); ++i) { auto index_dim = (*advanced_index_dim)[i]; @@ -409,30 +408,28 @@ static paddle::Tensor dealWithAdvancedIndex( // advanced_index_dim auto index = (*advanced_index)[p++]; - if (!is_for_setitem) { - if (index_dim == 0) { - // case 1: advanced indices at axis 0, the new dim will be at first. - *pos_of_new_dim = 0; - } else if (index_dim > 0 && trans_dim.size() > 0 && - trans_dim[trans_dim.size() - 1] != index_dim - 1) { - // case 2: there are not adjacent advanced indices, the new dim will - // be at first. - *pos_of_new_dim = 0; - } else { - *pos_of_new_dim = std::min(index_dim, *pos_of_new_dim); - } - *rank_of_new_dim = - std::max(*rank_of_new_dim, static_cast(index.shape().size())); + if (index_dim == 0) { + // case 1: advanced indices at axis 0, the new dim will be at first. + *pos_of_new_dim = 0; + } else if (index_dim > 0 && trans_dim->size() > 0 && + (*trans_dim)[trans_dim->size() - 1] != index_dim - 1) { + // case 2: there are not adjacent advanced indices, the new dim will + // be at first. + *pos_of_new_dim = 0; + } else { + *pos_of_new_dim = std::min(index_dim, *pos_of_new_dim); } + *rank_of_new_dim = + std::max(*rank_of_new_dim, static_cast(index.shape().size())); - trans_dim.push_back(index_dim); + trans_dim->push_back(index_dim); transed_index->push_back(std::move(index)); } } for (size_t i = 0; i < tensor.shape().size(); ++i) { if ((*advanced_index_dim)[i] == -1) { - trans_dim.push_back(i); + trans_dim->push_back(i); } } @@ -442,19 +439,19 @@ static paddle::Tensor dealWithAdvancedIndex( std::vector original_dim_order(tensor.shape().size()); std::iota(original_dim_order.begin(), original_dim_order.end(), 0); - if (original_dim_order == trans_dim) { + if (original_dim_order == *trans_dim) { transed_tensor = tensor; } else { - transed_tensor = transpose_ad_func(tensor, trans_dim); + transed_tensor = transpose_ad_func(tensor, *trans_dim); } if (is_for_setitem) { - trans_back_dim->resize(trans_dim.size()); + trans_back_dim->resize(trans_dim->size()); std::iota(trans_back_dim->begin(), trans_back_dim->end(), 0); std::sort(trans_back_dim->begin(), trans_back_dim->end(), [&trans_dim](int left, int right) { - return trans_dim[left] < trans_dim[right]; + return (*trans_dim)[left] < (*trans_dim)[right]; }); } return transed_tensor; diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index ca3a107765dcb7..efbb5eb40edc7b 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -170,7 +170,7 @@ def _setitem_for_tensor_array(var, item, value): ) -def deal_advanced_index(ori_tensor, indices, is_for_setitem): +def deal_advanced_index(ori_tensor, indices, is_for_setitem, values): """ Transpose origin Tensor and advanced indices to the front. @@ -180,6 +180,7 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__. pos_of_new_dim (int): axis of new dim in the result. Only used in __getitem__. rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__. + transed_value_tensor (Tensor): value tensor transed to the front. Only used in __setitem__. """ transed_dim = [] transed_index = [] @@ -191,16 +192,15 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): for i, indice in enumerate(indices): if indice is not None: - if not is_for_setitem: - if i == 0: - # case 1: advanced indices at axis 0, the new dim will be at first. - pos_of_new_dim = 0 - if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1: - # case 2: there are not adjacent advanced indices, the new dim will be at first. - pos_of_new_dim = 0 - else: - pos_of_new_dim = min(pos_of_new_dim, i) - rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim) + if i == 0: + # case 1: advanced indices at axis 0, the new dim will be at first. + pos_of_new_dim = 0 + if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1: + # case 2: there are not adjacent advanced indices, the new dim will be at first. + pos_of_new_dim = 0 + else: + pos_of_new_dim = min(pos_of_new_dim, i) + rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim) transed_dim.append(i) transed_index.append(indice[1]) for i in range(ori_tensor.ndim): @@ -210,12 +210,22 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else [] + transed_value_tensor = None + if is_for_setitem: + if values.ndim > 1 and pos_of_new_dim != 0: + # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was + # transposed at 1st dim, the value tensor should be transposed too. + transed_value_tensor = values.transpose(transed_dim) + else: + transed_value_tensor = values + return ( transed_tensor, transed_index, trans_back_dim, pos_of_new_dim, rank_of_new_dim, + transed_value_tensor, ) @@ -577,6 +587,11 @@ def _setitem_static(x, indices, values): # 3. assign values to the sliced result by index_put OP; # 4. transpose back and assign the result to original tensor by set_value OP. + if not isinstance( + values, (Variable, paddle.pir.Value, paddle.pir.OpResult) + ): + values = paddle.assign(values).astype(x.dtype) + sub_tensor = get_tensor_with_basic_indexing( x, axes, @@ -593,9 +608,8 @@ def _setitem_static(x, indices, values): transback_dim, _, _, - ) = deal_advanced_index(sub_tensor, advanced_index, True) - if not isinstance(values, (Variable, paddle.pir.Value)): - values = paddle.assign(values).astype(transed_sub_tensor.dtype) + values, + ) = deal_advanced_index(sub_tensor, advanced_index, True, values) if values.dtype != transed_sub_tensor.dtype: values = values.astype(transed_sub_tensor.dtype) @@ -818,7 +832,8 @@ def _getitem_static(x, indices): _, pos_of_new_dim, rank_of_new_dim, - ) = deal_advanced_index(out, advanced_index, False) + _, + ) = deal_advanced_index(out, advanced_index, False, None) # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently if ( diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index ed948d635256f5..350176d1acb03a 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -229,6 +229,79 @@ def test_indexing_is_boolean_false(self): np.testing.assert_allclose(x.numpy(), np_data) + def test_combined_indexing_and_value_is_tensor_1(self): + # value is tensor with same shape to getitem and index will be adjusted + np_data = np.ones((3, 3)).astype(self.ndtype) + value_data = np.array([-1, -1, -1]).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + value_data = convert_uint16_to_float( + convert_float_to_uint16(value_data) + ) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + value_data = value_data + 1j * value_data + + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(value_data, dtype=self.dtype) + + np_data[:, [0, 2]] = np_data[:, [0, 2]] + np.expand_dims(value_data, -1) + x[:, [0, 2]] = x[:, [0, 2]] + v.unsqueeze(-1) + + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_indexing_and_value_is_tensor_2(self): + # value is tensor needed to broadcast and index will be adjusted + np_data = np.ones((3, 4, 5, 6)).astype(self.ndtype) + value_data = np.arange(3 * 4 * 2 * 1).reshape((3, 4, 2, 1)) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + value_data = convert_uint16_to_float( + convert_float_to_uint16(value_data) + ) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + value_data = value_data + 1j * value_data + + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(value_data, dtype=self.dtype) + x[..., [1, 4], ::2] = v + + np_data[..., [1, 4], ::2] = value_data + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_indexing_and_value_is_tensor_3(self): + # value is tensor and index will be adjusted + # and the value rank is less than original tensor + np_data = np.ones((3, 4, 5, 6)).astype(self.ndtype) + value_data = np.arange(2 * 3 * 5).reshape((2, 3, 5)) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + value_data = convert_uint16_to_float( + convert_float_to_uint16(value_data) + ) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + value_data = value_data + 1j * value_data + + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(value_data, dtype=self.dtype) + x[:, [1, 3], :, [3, 4]] = v + + np_data[:, [1, 3], :, [3, 4]] = value_data + + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + np.testing.assert_allclose(x.numpy(), np_data) + def test_inplace_with_stride(self): np_v = np.random.randn(3, 1).astype(self.ndtype) if self.dtype == 'bfloat16': @@ -588,6 +661,72 @@ def test_indexing_is_boolean_false(self): np.testing.assert_allclose(res[0], np_data) + @test_with_pir_api + def test_combined_indexing_and_value_is_tensor_1(self): + # value is tensor with same shape to getitem and index will be adjusted + np_data = np.ones((3, 3), dtype='int32') + value_data = np.array([-1, -1, -1]) + np_data[:, [0, 2]] = np_data[:, [0, 2]] * np.expand_dims(value_data, -1) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 3), dtype='int32') + v = paddle.to_tensor([-1, -1, -1]) + y = _setitem_static( + x, + (slice(None), [0, 2]), + x[:, [0, 2]] * v.unsqueeze(-1), + ) + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_data) + + @test_with_pir_api + def test_combined_indexing_and_value_is_tensor_2(self): + # value is tensor needed to broadcast and index will be adjusted + np_data = np.ones((3, 4, 5, 6), dtype='int32') + value_data = np.arange(3 * 4 * 2 * 1).reshape((3, 4, 2, 1)) + np_data[..., [1, 4], ::2] = value_data + + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + v = paddle.arange(3 * 4 * 2 * 1).reshape((3, 4, 2, 1)) + + y = _setitem_static( + x, + (..., [1, 4], slice(None, None, 2)), + v, + ) + + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_data) + + @test_with_pir_api + def test_combined_indexing_and_value_is_tensor_3(self): + # value is tensor and index will be adjusted + # and the value rank is less than original tensor + np_data = np.ones((3, 4, 5, 6), dtype='int32') + value_data = np.arange(2 * 3 * 5).reshape((2, 3, 5)) + np_data[:, [1, 3], :, [3, 4]] = value_data + + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + v = paddle.arange(2 * 3 * 5).reshape((2, 3, 5)) + y = _setitem_static( + x, + (slice(None), [1, 3], slice(None), [3, 4]), + v, + ) + + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_data) + if __name__ == '__main__': unittest.main() From c0d6d7d09ad7c82618ae9ab723e65ae53668eb51 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 2 Jan 2024 14:44:08 +0800 Subject: [PATCH 063/142] [auto parallel] Add pp lazy init, bug fix for xavier (#60441) --- paddle/fluid/pybind/eager_method.cc | 17 ++-- .../paddle/distributed/auto_parallel/api.py | 6 ++ python/paddle/nn/initializer/xavier.py | 18 +++-- .../semi_auto_parallel_lazy_init.py | 81 ++++++++++++++----- .../test_semi_auto_parallel_lazy_init.py | 5 +- 5 files changed, 95 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index b3898533f965c1..c9b3b106597448 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1078,12 +1078,17 @@ static PyObject* tensor__share_underline_tensor_to(TensorObject* self, EAGER_TRY paddle::Tensor* src_ptr = &(reinterpret_cast(PyTuple_GET_ITEM(args, 0))->tensor); - PADDLE_ENFORCE_EQ(self->tensor.initialized(), - true, - platform::errors::InvalidArgument( - "Tensor %s has not been initialized! please initialize " - "src tensor before share_buffer_with to other.", - self->tensor.name())); + if (!self->tensor.initialized()) { + PADDLE_ENFORCE(self->tensor.is_dist_tensor() && + !phi::distributed::IsCurRankInMesh( + static_cast( + self->tensor.impl().get()) + ->process_mesh()), + platform::errors::InvalidArgument( + "Tensor %s has not been initialized! Please initialize " + "src tensor before share_buffer_with to other.", + self->tensor.name())); + } src_ptr->set_impl(self->tensor.impl()); RETURN_PY_NONE diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index c012d7a59d1c61..7e734bd95b1b10 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -187,6 +187,12 @@ def shard_tensor( if isinstance(data, EagerParamBase): def lazy_init_hook(param, origin_hook): + for placement in param.placements: + assert not placement.is_partial(), ( + "Lazy init not support partial reshard. Notice that: shard a param to partial " + "won't save any memory, but will increase the communication cost!" + ) + # lazy init hook with randomness controlling def _init_func(var, block): # get the unique rng name diff --git a/python/paddle/nn/initializer/xavier.py b/python/paddle/nn/initializer/xavier.py index 58d73d21dfe865..e455ca455cd004 100644 --- a/python/paddle/nn/initializer/xavier.py +++ b/python/paddle/nn/initializer/xavier.py @@ -105,6 +105,11 @@ def forward(self, var, block=None): if self._seed == 0: self._seed = block.program.random_seed + out_var_shape = ( + var._local_shape + if (isinstance(var, framework.EagerParamBase) and var.is_dist()) + else var.shape + ) # to be compatible of fp16 initalizers if var.dtype == core.VarDesc.VarType.FP16 or ( var.dtype == core.VarDesc.VarType.BF16 and not self._uniform @@ -114,9 +119,7 @@ def forward(self, var, block=None): name=unique_name.generate( ".".join(['xavier_init', var.name, 'tmp']) ), - shape=var._local_shape - if (isinstance(var, framework.EagerParamBase) and var.is_dist()) - else var.shape, + shape=out_var_shape, dtype=out_dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, @@ -135,7 +138,7 @@ def forward(self, var, block=None): if self._uniform: limit = math.sqrt(6.0 / float(fan_in + fan_out)) out_var = _C_ops.uniform( - out_var.shape, + out_var_shape, out_dtype, -limit, limit, @@ -147,7 +150,12 @@ def forward(self, var, block=None): place = _current_expected_place() out_var = _C_ops.gaussian( - out_var.shape, 0.0, std, self._seed, out_dtype, place + out_var_shape, + 0.0, + std, + self._seed, + out_dtype, + place, ) if var.dtype == core.VarDesc.VarType.FP16 or ( diff --git a/test/auto_parallel/semi_auto_parallel_lazy_init.py b/test/auto_parallel/semi_auto_parallel_lazy_init.py index 52016c358ea357..cfeff65b2733a1 100644 --- a/test/auto_parallel/semi_auto_parallel_lazy_init.py +++ b/test/auto_parallel/semi_auto_parallel_lazy_init.py @@ -22,40 +22,81 @@ class TestSemiAutoParallelLazyInit: def __init__(self): self._backend = os.getenv("backend") + self._placements_type = os.getenv("_placements_type") self._seed = eval(os.getenv("seed")) - self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + if self._placements_type == "DP": + self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._placements_weight = [dist.Replicate()] + self._placements_bias = [dist.Replicate()] + elif self._placements_type == "PP": + self._mesh_weight = dist.ProcessMesh([0], dim_names=["x"]) + self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"]) + self._placements_weight = [dist.Replicate()] + self._placements_bias = [dist.Replicate()] - def test_replicate(self): + def test_different_xavier(self): + paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) + weight_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.XavierNormal() + ) + bias_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.XavierUniform() + ) + with LazyGuard(): + linear = paddle.nn.Linear( + 10, 10, weight_attr=weight_attr, bias_attr=bias_attr + ) + linear.weight = dist.shard_tensor( + linear.weight, self._mesh_weight, self._placements_weight + ) + linear.bias = dist.shard_tensor( + linear.bias, self._mesh_bias, self._placements_bias + ) + + def test_placements(self): paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) with LazyGuard(): linear = paddle.nn.Linear(10, 10) linear.weight = dist.shard_tensor( - linear.weight, self._mesh, [dist.Replicate()] + linear.weight, self._mesh_weight, self._placements_weight ) linear.bias = dist.shard_tensor( - linear.bias, self._mesh, [dist.Replicate()] + linear.bias, self._mesh_bias, self._placements_bias ) for param in linear.parameters(): assert not param._is_initialized() param.initialize() - assert param._is_initialized() - - local_weight_md5 = linear.weight._local_value()._md5sum() - mesh0 = dist.ProcessMesh([0], dim_names=["x"]) - mesh1 = dist.ProcessMesh([1], dim_names=["x"]) - tmp = paddle.distributed.auto_parallel.api.dtensor_from_local( - linear.weight._local_value(), - mesh0 if dist.get_rank() == 0 else mesh1, - [dist.Replicate()], - ) - tmp = dist.reshard( - tmp, mesh1 if dist.get_rank() == 0 else mesh0, [dist.Replicate()] - ) - tmp_md5 = tmp._local_value()._md5sum() - assert local_weight_md5 == tmp_md5 + + if self._placements_type == "DP": + assert linear.weight._is_initialized() + assert linear.bias._is_initialized() + local_weight_md5 = linear.weight._local_value()._md5sum() + mesh0 = dist.ProcessMesh([0], dim_names=["x"]) + mesh1 = dist.ProcessMesh([1], dim_names=["x"]) + tmp = paddle.distributed.auto_parallel.api.dtensor_from_local( + linear.weight._local_value(), + mesh0 if dist.get_rank() == 0 else mesh1, + [dist.Replicate()], + ) + tmp = dist.reshard( + tmp, + mesh1 if dist.get_rank() == 0 else mesh0, + [dist.Replicate()], + ) + tmp_md5 = tmp._local_value()._md5sum() + assert local_weight_md5 == tmp_md5 + elif self._placements_type == "PP": + if dist.get_rank() == 0: + assert linear.weight._is_initialized() + assert not linear.bias._is_initialized() + else: + assert not linear.weight._is_initialized() + assert linear.bias._is_initialized() def run_test_case(self): - self.test_replicate() + self.test_placements() + self.test_different_xavier() if __name__ == '__main__': diff --git a/test/auto_parallel/test_semi_auto_parallel_lazy_init.py b/test/auto_parallel/test_semi_auto_parallel_lazy_init.py index d0c09749af53d5..b55423184b9188 100644 --- a/test/auto_parallel/test_semi_auto_parallel_lazy_init.py +++ b/test/auto_parallel/test_semi_auto_parallel_lazy_init.py @@ -27,7 +27,10 @@ def setUp(self): "dtype": "float32", "seed": "2023", } - self._changeable_envs = {"backend": ["cpu", "gpu"]} + self._changeable_envs = { + "backend": ["cpu", "gpu"], + "_placements_type": ["DP", "PP"], + } def test_lazy_init(self): envs_list = test_base.gen_product_envs_list( From df4ca858521eb398cd65a6910b92a932ccac539b Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Tue, 2 Jan 2024 15:25:45 +0800 Subject: [PATCH 064/142] [PIR] add slice_array_dense api (#60433) * fix * fix --- .../pir/dialect/operator/ir/manual_api.cc | 7 ++ .../pir/dialect/operator/ir/manual_api.h | 2 + .../pir/dialect/operator/ir/manual_op.cc | 67 ++++++++++++++++++- .../fluid/pir/dialect/operator/ir/manual_op.h | 5 ++ .../fluid/pybind/manual_static_op_function.h | 41 ++++++++++++ paddle/phi/infermeta/unary.cc | 3 +- 6 files changed, 123 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index d9c5debe92ee66..33fecafdbb0258 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -165,5 +165,12 @@ std::tuple array_to_tensor(pir::Value x, return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1)); } +pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) { + auto op = ApiBuilder::Instance() + .GetBuilder() + ->Build(input, starts); + return op.result(0); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index 680cd5b54ab905..347e10494696c0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -72,5 +72,7 @@ std::tuple array_to_tensor(pir::Value x, int axis, bool use_stack); +pir::OpResult slice_array_dense(pir::Value input, pir::Value starts); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index d3d8c46111bbb2..0a60b4c7d7d819 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -1807,7 +1807,7 @@ OpInfoTuple SliceArrayDenseOp::GetOpInfo() { paddle::dialect::OpOutputInfo( "out", "paddle::dialect::DenseTensorType", false, false)}; paddle::dialect::OpRunTimeInfo run_time_info = - paddle::dialect::OpRunTimeInfo("SliceArrayInferMeta", + paddle::dialect::OpRunTimeInfo("SliceArrayDenseInferMeta", {"input", "starts"}, "slice_array_dense", {"input", "starts"}, @@ -1855,6 +1855,71 @@ void SliceArrayDenseOp::VerifySig() { VLOG(4) << "End Verifying for: SliceArrayOp."; } +void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + pir::Value starts) { + VLOG(4) << "Start build SliceArrayDenseOp"; + VLOG(4) << "Builder construction inputs"; + argument.AddInputs({input, starts}); + VLOG(4) << "Builder construction attributes"; + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType input_type = + input.type().dyn_cast(); + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(input_type.dtype()), + {}, + input_type.data_layout(), + {}); + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + phi::IntArray starts_list; + if (starts.dyn_cast() + .owner() + ->isa()) { + starts_list = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + starts.dyn_cast() + .owner() + ->dyn_cast() + .attribute("value")))); + } else if (starts.type().isa()) { + size_t starts_size = starts.type().dyn_cast().size(); + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else if (starts.type().isa()) { + common::DDim starts_dim = + starts.type().dyn_cast().dims(); + size_t starts_size = common::product(starts_dim); + if (common::contain_unknown_dim(starts_dim)) { + starts_size = 1; + } + starts_list = + std::move(phi::IntArray(std::vector(starts_size, -1))); + starts_list.SetFromTensor(true); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support VectorType or DenseTensorType")); + } + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::SliceArrayDenseInferMeta( + meta_input, starts_list, &meta_out, phi::MetaConfig(false, false)); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); +} + void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::SliceArrayDenseInferMeta); fn(infer_meta); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 4d001206409512..121c95dee169aa 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -334,6 +334,11 @@ class SliceArrayDenseOp static OpInfoTuple GetOpInfo(); void VerifySig(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value input, + pir::Value starts); + static phi::DataType GetKernelTypeForVar( const std::string &var_name, const phi::DataType &tensor_dtype, diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index 21285163dd64f1..dc09d539f39ffb 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -274,6 +274,43 @@ static PyObject *static_api_array_to_tensor(PyObject *self, } } +static PyObject *static_api_slice_array_dense(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add slice_array_dense op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *input_obj = PyTuple_GET_ITEM(args, 0); + auto input = CastPyArg2Value(input_obj, "slice_array_dense", 0); + + PyObject *starts_obj = PyTuple_GET_ITEM(args, 1); + pir::Value starts; + if (PyObject_CheckIROpResult(starts_obj)) { + starts = CastPyArg2Value(starts_obj, "slice_array_dense", 1); + } else if (PyObject_CheckIRVectorOfOpResult(starts_obj)) { + std::vector starts_tmp = + CastPyArg2VectorOfValue(starts_obj, "slice_array_dense", 1); + starts = paddle::dialect::stack(starts_tmp, /*axis*/ 0); + + } else { + std::vector starts_tmp = + CastPyArg2Longs(starts_obj, "slice_array_dense", 1); + starts = paddle::dialect::full_int_array( + starts_tmp, phi::DataType::INT64, phi::CPUPlace()); + } + + // Call ir static api + auto static_api_out = paddle::dialect::slice_array_dense(input, starts); + + return ToPyObject(static_api_out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + static PyMethodDef ManualOpsAPI[] = { {"set_parameter", (PyCFunction)(void (*)(void))static_api_set_parameter, @@ -303,6 +340,10 @@ static PyMethodDef ManualOpsAPI[] = { (PyCFunction)(void (*)(void))static_api_array_to_tensor, METH_VARARGS | METH_KEYWORDS, "C++ interface function for array_to_tensor."}, + {"slice_array_dense", + (PyCFunction)(void (*)(void))static_api_slice_array_dense, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for slice_array_dense."}, {nullptr, nullptr, 0, nullptr}}; } // namespace pybind diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 90987398057fe9..a75cd4170e2785 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3687,7 +3687,8 @@ void SliceArrayDenseInferMeta(const MetaTensor& input, if (config.is_runtime) { return; } - out->set_dims(input.dims()); + // out->set_dims(input.dims()); + out->set_dtype(input.dtype()); } void SliceRawInferMeta(const MetaTensor& input, From 7c7446f55dc88c5dbeb33387627ccd45c72526e5 Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 2 Jan 2024 15:28:16 +0800 Subject: [PATCH 065/142] Set value with scalar (#60452) * set_value with scalar * fix ut --- paddle/fluid/pybind/eager_method.cc | 92 +++++++++------- paddle/fluid/pybind/slice_utils.h | 101 ++++++++++++++++++ .../base/dygraph/tensor_patch_methods.py | 11 +- 3 files changed, 157 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index c9b3b106597448..feaf7ccd1a2f68 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1613,12 +1613,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &use_strided_slice); // step2: Parse values - PADDLE_ENFORCE( - PyCheckTensor(value_obj), - platform::errors::InvalidArgument("The value must be a Tensor")); - + std::vector values; paddle::Tensor value_tensor = - reinterpret_cast(value_obj)->tensor; + dealWithValues(tensor, value_obj, &values, has_advanced_index); if (!has_advanced_index) { // use set_value OP if there is no advanced index @@ -1626,45 +1623,60 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, // Release gil and do tracing py::gil_scoped_release release; // use inplace set_value_ operator - if (value_tensor.initialized() && - (self->tensor.dtype() != value_tensor.dtype())) { - if (egr::Controller::Instance().GetAMPLevel() != - paddle::imperative::AmpLevel::O0) { - paddle::small_vector, - egr::kSlotSmallVectorSize> - tmps = {{self->tensor}, {value_tensor}}; - auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); - self->tensor = egr::EagerAmpAutoCast( - self->tensor.name(), self->tensor, amp_dtype, "set_value"); - value_tensor = egr::EagerAmpAutoCast( - value_tensor.name(), value_tensor, amp_dtype, "set_value"); - } + if (value_tensor.initialized()) { if (self->tensor.dtype() != value_tensor.dtype()) { - value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + paddle::small_vector, + egr::kSlotSmallVectorSize> + tmps = {{self->tensor}, {value_tensor}}; + auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); + self->tensor = egr::EagerAmpAutoCast( + self->tensor.name(), self->tensor, amp_dtype, "set_value"); + value_tensor = egr::EagerAmpAutoCast( + value_tensor.name(), value_tensor, amp_dtype, "set_value"); + } + if (self->tensor.dtype() != value_tensor.dtype()) { + value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + } } - } - // step3.1: Only basic indexing, use OP set_value. - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) { - ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor); - } - self->tensor = set_value_with_tensor__ad_func(self->tensor, - value_tensor, - slice_starts, - slice_ends, - slice_strides, - slice_axes, - decrease_axis, - none_axes); - if (PyCheckTensor(value_obj)) { - // pass the stop_gradient from value to tensor. - // pass stop gradient should be done after CheckInplace in - // set_value__dygraph_function. - if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() && - egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) { - egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false); + // step3.1: Only basic indexing, use OP set_value. + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) { + ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor); + } + self->tensor = set_value_with_tensor__ad_func(self->tensor, + value_tensor, + slice_starts, + slice_ends, + slice_strides, + slice_axes, + decrease_axis, + none_axes); + if (PyCheckTensor(value_obj)) { + // pass the stop_gradient from value to tensor. + // pass stop gradient should be done after CheckInplace in + // set_value__dygraph_function. + if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() && + egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) { + egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false); + } + } + } else { + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self->tensor)) { + ConvertAllInputsToDistTensor(mesh, self->tensor); } + self->tensor = set_value__ad_func(self->tensor, + slice_starts, + slice_ends, + slice_strides, + slice_axes, + decrease_axis, + none_axes, + {1}, + values); } } else { // step3.2: Case for there are advanced indexing. diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index e60ab9406396a2..82bdcc80562c45 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -27,9 +27,11 @@ #include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -531,5 +533,104 @@ static void ParseBoolAndBroadcastIndices( } } +static paddle::Tensor dealWithValues(const paddle::Tensor& tensor, + PyObject* value_obj, + std::vector* values, + const bool trans_to_tensor) { + paddle::Tensor value_tensor; + if (PyCheckTensor(value_obj)) { + value_tensor = reinterpret_cast(value_obj)->tensor; + } else if (py::isinstance(value_obj)) { + paddle::Tensor value_tensor_tmp( + std::make_shared(), + egr::Controller::Instance().GenerateUniqueName()); + py::object value_obj_tmp(py::handle(value_obj), true); + py::object value = value_obj_tmp; + if (tensor.dtype() == phi::DataType::FLOAT32) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::FLOAT64) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::INT32) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::INT64) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::BOOL) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::COMPLEX64) { + if (!py::isinstance>>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray>( + value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::COMPLEX128) { + if (!py::isinstance>>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray>( + value_obj_tmp); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "When assign a numpy.np value to a paddle.Tensor, " + "the data type of the paddle.Tensor must be bool, " + "float32, float64, complex64, complex128, int32 or int64, " + "please check the type of tensor.")); + } + SetTensorFromPyArray( + static_cast(value_tensor_tmp.impl().get()), + value, + tensor.place(), + false); + value_tensor = value_tensor_tmp; + } else { + py::object value_obj_tmp(py::handle(value_obj), true); + // convert the value to self data type + if (py::isinstance(value_obj_tmp) || + py::isinstance(value_obj_tmp) || + py::isinstance(value_obj_tmp) || + PyComplex_Check(value_obj)) { + if (tensor.dtype() == phi::DataType::FLOAT32 || + tensor.dtype() == phi::DataType::FLOAT16 || + tensor.dtype() == phi::DataType::BFLOAT16) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::FLOAT64) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::INT32 || + tensor.dtype() == phi::DataType::INT16 || + tensor.dtype() == phi::DataType::INT8 || + tensor.dtype() == phi::DataType::UINT8) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::INT64) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::BOOL) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::COMPLEX64) { + values->push_back(value_obj_tmp.cast>()); + } else if (tensor.dtype() == phi::DataType::COMPLEX128) { + values->push_back(value_obj_tmp.cast>()); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Value type error. The assign value allows " + "Tensor, numpy.ndarray, integer, float, complex or bool, " + "but received %s.", + Py_TYPE(value_obj))); + } + + if (trans_to_tensor) { + value_tensor = + full_ad_func({1}, (*values)[0], tensor.dtype(), tensor.place()); + } + } + return value_tensor; +} + } // namespace pybind } // namespace paddle diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py index a6d1f90df4fa48..aed4833188d6c1 100644 --- a/python/paddle/base/dygraph/tensor_patch_methods.py +++ b/python/paddle/base/dygraph/tensor_patch_methods.py @@ -975,7 +975,7 @@ def __array__(self, dtype=None): array = array.astype(dtype) return array - def pre_deal_index_and_value(self, item, value=None): + def pre_deal_index(self, item): # since in pybind there is no effiency way to transfer Py_Tuple/Py_List/Py_Range to Tensor # we call this function in python level. item = list(item) if isinstance(item, tuple) else [item] @@ -985,17 +985,14 @@ def pre_deal_index_and_value(self, item, value=None): elif isinstance(slice_item, range): item[i] = paddle.to_tensor(list(slice_item)) - if value is not None and not isinstance(value, Variable): - value = paddle.to_tensor(value, dtype=self.dtype) - - return tuple(item), value + return tuple(item) def __getitem__(self, item): - item, _ = pre_deal_index_and_value(self, item) + item = pre_deal_index(self, item) return self._getitem_dygraph(item) def __setitem__(self, item, value): - item, value = pre_deal_index_and_value(self, item, value) + item = pre_deal_index(self, item) return self._setitem_dygraph(item, value) @framework.dygraph_only From cfad7d2abdee711cdc02b7e13ee38eaf6b0afdb2 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 2 Jan 2024 15:38:44 +0800 Subject: [PATCH 066/142] [PIR]Support custom op in PIR (#59790) * support custom op in pir * fix compile bugs * fix bugs * delete code * fix windows bugs * fix windows bugs * add symbol to paddle lib * fix windows bugs * revert code * fix bugs * fix bugs * perfect code according comment * fix py3 * revert third party * fix bugs * fix bug * fix compile bugs * fix windows --- paddle/common/hash_funcs.h | 42 ++ .../fluid/framework/custom_operator_utils.h | 28 +- .../instruction/custom_kernel_instruction.cc | 499 ++++++++++++++++++ .../instruction/custom_kernel_instruction.h | 78 +++ .../instruction/legacy_kernel_instruction.cc | 3 +- .../onednn/onednn_phi_kernel_instruction.cc | 2 +- .../instruction/phi_kernel_instruction.cc | 3 +- .../interpreter/interpreter_util.cc | 4 +- .../pir_adaptor/pir_adaptor_util.cc | 2 +- .../framework/new_executor/pir_interpreter.cc | 5 + .../fluid/inference/api/analysis_predictor.cc | 1 - .../inference/api/demo_ci/CMakeLists.txt | 13 + .../inference/api/demo_ci/custom_op_demo.cc | 64 +++ .../inference/api/demo_ci/custom_relu_op.cc | 105 ++++ .../inference/api/demo_ci/custom_relu_op.cu | 71 +++ paddle/fluid/inference/api/demo_ci/run.sh | 33 ++ paddle/fluid/inference/api/helper.cc | 20 + .../ir_adaptor/translator/op_translator.cc | 39 +- .../pir/dialect/kernel/ir/kernel_dialect.cc | 119 +++-- .../pir/dialect/kernel/ir/kernel_dialect.h | 18 + .../fluid/pir/dialect/kernel/ir/kernel_op.cc | 42 ++ .../fluid/pir/dialect/kernel/ir/kernel_op.h | 13 + .../dialect/operator/interface/op_yaml_info.h | 14 +- .../pir/dialect/operator/ir/op_dialect.cc | 306 +++++++++-- .../pir/dialect/operator/ir/op_dialect.h | 35 ++ paddle/fluid/pir/transforms/inplace_pass.cc | 10 +- .../pir/transforms/pd_op_to_kernel_pass.cc | 257 ++++++--- paddle/fluid/pybind/pir.cc | 2 +- paddle/phi/kernels/autotune/cache_base.h | 28 +- paddle/pir/core/builtin_type_storage.h | 18 +- paddle/pir/core/operation.cc | 2 - test/cpp/pir/shape_dialect/CMakeLists.txt | 6 +- 32 files changed, 1624 insertions(+), 258 deletions(-) create mode 100644 paddle/common/hash_funcs.h create mode 100644 paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc create mode 100644 paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h create mode 100644 paddle/fluid/inference/api/demo_ci/custom_op_demo.cc create mode 100755 paddle/fluid/inference/api/demo_ci/custom_relu_op.cc create mode 100644 paddle/fluid/inference/api/demo_ci/custom_relu_op.cu diff --git a/paddle/common/hash_funcs.h b/paddle/common/hash_funcs.h new file mode 100644 index 00000000000000..e4a905ff539b98 --- /dev/null +++ b/paddle/common/hash_funcs.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +inline void HashCombine(std::size_t* seed) {} + +// combine hash value +// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x +template +inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + *seed *= 0x00000100000001B3; + HashCombine(seed, rest...); +} + +// custom specialization of std::hash can be injected in namespace std +// ref: https://en.cppreference.com/w/cpp/utility/hash +namespace std { +template +struct hash> { + std::size_t operator()(std::vector const& vec) const noexcept { + std::size_t seed = 0xcbf29ce484222325; + for (auto val : vec) { + HashCombine(&seed, val); + } + return seed; + } +}; +} // namespace std diff --git a/paddle/fluid/framework/custom_operator_utils.h b/paddle/fluid/framework/custom_operator_utils.h index ec00e8b9d0d6bc..bf1750dfdbbb50 100644 --- a/paddle/fluid/framework/custom_operator_utils.h +++ b/paddle/fluid/framework/custom_operator_utils.h @@ -19,10 +19,11 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/core/enforce.h" namespace paddle { namespace framework { - +constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT namespace detail { // dynamic lib load func @@ -81,6 +82,31 @@ inline static bool IsMemberOf(const std::vector& vec, return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); } +inline static const OpMetaInfo& GetOpInfoByPirName( + const std::string& pir_op_name) { + auto custom_name = pir_op_name.substr(strlen(kCustomDialectPrefix)); + int pos = custom_name.length(); + if (custom_name.find("_grad_grad") != custom_name.npos) { + pos = custom_name.find("_grad_grad") + 1; + } else if (custom_name.find("_grad") != custom_name.npos) { + pos = custom_name.find("_grad") + 1; + } + auto custom_name_prefix = custom_name.substr(0, pos); + auto map_iter = + paddle::OpMetaInfoMap::Instance().GetMap().find(custom_name_prefix); + if (map_iter == paddle::OpMetaInfoMap::Instance().GetMap().end()) { + PADDLE_THROW("The info of custom op : " + custom_name + " is not exists!"); + } + const auto& vec_op_meta = map_iter->second; + if (custom_name.find("_grad_grad") != custom_name.npos) { + return vec_op_meta[2]; + } else if (custom_name.find("_grad") != custom_name.npos) { + return vec_op_meta[1]; + } else { + return vec_op_meta[0]; + } +} + } // namespace detail } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc new file mode 100644 index 00000000000000..a585976fd6b9af --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc @@ -0,0 +1,499 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h" +#include "paddle/fluid/framework/custom_operator_utils.h" +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +namespace paddle { +namespace framework { + +void CustomKernelInstruction::BuildCustomContext( + const paddle::dialect::OpYamlInfoParser& op_yaml_info) { + Scope* inner_scope = value_exec_info_.GetScope(); + VLOG(6) << "Build custom op infermeta param inner_scope[" << inner_scope + << "]"; + + auto attr_map = op_->attributes(); + + // EmplaceBackInputs + auto& vec_input_tensor_params = op_yaml_info.TensorParams(true); + auto& name2id = op_yaml_info.InputName2Id(); + for (auto& t : vec_input_tensor_params) { + PADDLE_ENFORCE_EQ( + name2id.count(t), + true, + phi::errors::NotFound("param [%s] MUST in name2id map", t)); + + pir::Value ptr = op_->operand_source(op_yaml_info.InputName2Id().at(t)); + + if (!IsInvalid(ptr)) { + if (op_yaml_info.GetInputType(op_yaml_info.InputName2Id().at(t)) == + "pir::VectorType") { + vec_input_shapes_.emplace_back(); + vec_input_dtypes_.emplace_back(); + // NOTE(YuanRisheng): In dygraph mode, we can not distinguish Tensor and + // vector when user inputs None, so dygraph mode appends one + // un-initialized Tensor to CustomOpKernelContext. To be compatible with + // dygraph mode, `custom_vec_in` also emplace_back one un-initialized + // tensor here. + std::vector custom_vec_in; + custom_vec_in.emplace_back(paddle::Tensor()); + custom_kernel_ctx_.EmplaceBackInputs(std::move(custom_vec_in)); + } else { + input_shapes_.emplace_back(); + input_dtypes_.emplace_back(); + custom_kernel_ctx_.EmplaceBackInput(std::move(paddle::Tensor())); + } + VLOG(8) << "ctx->EmplaceBackInput : an optioanl input " << t; + continue; + } + + auto in_var_name = value_exec_info_.GetVarName(ptr); + VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name; + + PADDLE_ENFORCE_NOT_NULL(inner_scope->FindVar(in_var_name), + phi::errors::PreconditionNotMet( + "can not find var[%s] in scope", in_var_name)); + auto var = inner_scope->FindVar(in_var_name); + if (var->IsType()) { + auto dense_tensor_in = var->GetMutable(); + std::shared_ptr tensor_in( + dense_tensor_in, [](phi::DenseTensor* ptr) { + VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; + }); + input_shapes_.push_back(phi::vectorize(tensor_in->dims())); + input_dtypes_.push_back(tensor_in->dtype()); + paddle::Tensor custom_in; + custom_in.set_impl(tensor_in); + custom_kernel_ctx_.EmplaceBackInput(std::move(custom_in)); + } else if (var->IsType()) { + std::vector> vec_input_shape; + std::vector vec_input_dtype; + std::vector vec_custom_in; + auto& variable_array = var->Get(); + for (size_t i = 0; i < variable_array.size(); ++i) { + if (variable_array[i]->IsType()) { + phi::DenseTensor* dense_tensor_in = const_cast( + &(variable_array[i]->Get())); + std::shared_ptr tensor_in( + dense_tensor_in, [](phi::DenseTensor* ptr) { + VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; + }); + vec_input_shape.push_back(phi::vectorize(tensor_in->dims())); + vec_input_dtype.push_back(tensor_in->dtype()); + paddle::Tensor custom_in; + custom_in.set_impl(tensor_in); + vec_custom_in.push_back(std::move(custom_in)); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support Vector and vector now, " + "not support vector<%d>.", + variable_array[i]->Type())); + } + } + vec_input_shapes_.push_back(vec_input_shape); + vec_input_dtypes_.push_back(vec_input_dtype); + custom_kernel_ctx_.EmplaceBackInputs(vec_custom_in); + } else { + PADDLE_THROW(phi::errors::Unimplemented("Not support var type [%d] ", + var->Type())); + } + } + + // EmplaceBackAttributes + auto& vec_attr_params = op_yaml_info.AttrParams(true); + for (auto& t : vec_attr_params) { + PADDLE_ENFORCE_NE( + attr_map.find(t), + attr_map.end(), + phi::errors::NotFound("Not found %s in attr_map, it maybe need mapping " + "it in OpTranslator.", + t)); + auto& attr_type_name = op_yaml_info.AttrTypeName(t); + if (attr_type_name == "pir::Int32Attribute") { + custom_attrs_.push_back( + attr_map[t].dyn_cast().data()); + custom_kernel_ctx_.EmplaceBackAttr( + attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::Int64Attribute") { + custom_attrs_.push_back( + attr_map[t].dyn_cast().data()); + custom_kernel_ctx_.EmplaceBackAttr( + attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::FloatAttribute") { + custom_attrs_.push_back( + attr_map[t].dyn_cast().data()); + custom_kernel_ctx_.EmplaceBackAttr( + attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::DoubleAttribute") { + custom_attrs_.push_back( + attr_map[t].dyn_cast().data()); + custom_kernel_ctx_.EmplaceBackAttr( + attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::BoolAttribute") { + custom_attrs_.push_back( + attr_map[t].dyn_cast().data()); + custom_kernel_ctx_.EmplaceBackAttr( + attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::StrAttribute") { + custom_attrs_.push_back( + attr_map[t].dyn_cast().AsString()); + custom_kernel_ctx_.EmplaceBackAttr( + attr_map[t].dyn_cast().AsString()); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ( + array_list[0].isa(), + true, + phi::errors::Unimplemented( + "the 0th elementwise MUST be pir::Int32Attribute")); + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back( + array_list[i].dyn_cast().data()); + } + } + custom_attrs_.push_back(vec_res); + custom_kernel_ctx_.EmplaceBackAttr(vec_res); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + if (array_list[0].isa()) { + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back( + array_list[i].dyn_cast().data()); + } + + } else { + PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", + attr_type_name)); + } + } + custom_attrs_.push_back(vec_res); + custom_kernel_ctx_.EmplaceBackAttr(vec_res); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); + + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ( + array_list[0].isa(), + true, + phi::errors::PreconditionNotMet( + "Element in array list MUST be pir::Int64Attribute ")); + + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back( + array_list[i].dyn_cast().data()); + } + } + custom_attrs_.push_back(vec_res); + custom_kernel_ctx_.EmplaceBackAttr(vec_res); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); + + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ( + array_list[0].isa(), + true, + phi::errors::PreconditionNotMet( + "Element in array list MUST be pir::StrAttribute ")); + + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back( + array_list[i].dyn_cast().AsString()); + } + } + custom_attrs_.push_back(vec_res); + custom_kernel_ctx_.EmplaceBackAttr(vec_res); + + } else { + PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ", + attr_type_name)); + } + VLOG(6) << "ctx->EmplaceBackAttr: " << t; + } + + // EmplaceBackOutputs + VLOG(8) << "ctx->EmplaceBackOutput: "; + for (size_t i = 0; i < op_->num_results(); ++i) { + pir::Value out_ptr = op_->result(i); + if (!IsInvalid(out_ptr)) { + if (op_yaml_info.GetOutputType(i) == + "pir::VectorType") { + std::vector custom_vec_out; + custom_vec_out.emplace_back(); + cache_out_ptrs_.emplace_back(nullptr); + custom_kernel_ctx_.EmplaceBackOutputs(std::move(custom_vec_out)); + } else { + cache_out_ptrs_.emplace_back(nullptr); + custom_kernel_ctx_.EmplaceBackOutput(std::move(paddle::Tensor())); + } + VLOG(8) << "ctx->EmplaceBackOutput : an optioanl output"; + continue; + } + + if (out_ptr.type().isa()) { + auto dense_tensor_out = + inner_scope->FindVar(value_exec_info_.GetVarName(out_ptr)) + ->GetMutable(); + cache_out_ptrs_.push_back(dense_tensor_out); + std::shared_ptr tensor_out( + dense_tensor_out, [](phi::DenseTensor* ptr) { + VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; + }); + paddle::Tensor custom_out; + // here only can copy the output tensor into context + custom_out.set_impl(tensor_out); + + custom_kernel_ctx_.EmplaceBackOutput(std::move(custom_out)); + VLOG(8) << "ctx->EmplaceBackOutput DenseTensor: " + << value_exec_info_.GetVarName(out_ptr); + } else if (out_ptr.type().isa()) { + std::vector vec_custom_out; + auto& variable_array = + inner_scope->FindVar(value_exec_info_.GetVarName(out_ptr)) + ->Get(); + std::vector custom_vec_out; + for (size_t i = 0; i < variable_array.size(); ++i) { + if (variable_array[i]->IsType()) { + auto dense_tensor_out = const_cast( + &(variable_array[i]->Get())); + cache_out_ptrs_.emplace_back(dense_tensor_out); + std::shared_ptr tensor_out( + dense_tensor_out, [](phi::DenseTensor* ptr) { + VLOG(6) << ptr << " ptr will not be deleted by shared_ptr"; + }); + paddle::Tensor custom_out; + custom_out.set_impl(tensor_out); + custom_vec_out.push_back(std::move(custom_out)); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support Vector and vector now, " + "not support vector<%d>.", + variable_array[i]->Type())); + } + } + VLOG(8) << "ctx->EmplaceBackOutput VariableRefArray: " + << value_exec_info_.GetVarName(out_ptr); + custom_kernel_ctx_.EmplaceBackOutputs(custom_vec_out); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("only support DenseTensor and vector ")); + } + } + auto& op_inputs = OpMetaInfoHelper::GetInputs(*custom_op_meta_); + auto& op_outputs = OpMetaInfoHelper::GetOutputs(*custom_op_meta_); + auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(*custom_op_meta_); + // handle inplace map + custom_kernel_ctx_.UpdatePlainOutputs(op_inputs, op_outputs, op_inplace_map); + VLOG(6) << "Done build custom context"; +} + +CustomKernelInstruction::CustomKernelInstruction( + size_t id, + const platform::Place& place, + pir::Operation* op, + const ValueExecutionInfo& value_exec_info) + : InstructionBase(id, place), value_exec_info_(value_exec_info) { + auto op_attributes = op->attributes(); + auto op_name = + op_attributes.at("op_name").dyn_cast().AsString(); + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + op_ = op; + custom_op_name_ = op_name; + VLOG(6) << "construct custom kernel instruction for: " << custom_op_name_; + + VLOG(6) << "finish process dist attributes"; + + SetKernelType(AnalyseOpFuncType(op, place)); + VLOG(6) << "finish process analyse kernel type"; + + auto yaml_interface = + op_info.GetInterfaceImpl(); + PADDLE_ENFORCE_NOT_NULL( + yaml_interface, + phi::errors::PreconditionNotMet( + "can not find OpYamlInfoInterface from [%s]", custom_op_name_)); + paddle::dialect::OpYamlInfoParser yaml_info_parser( + yaml_interface->get_op_info_(custom_op_name_), + paddle::dialect::IsLegacyOp(custom_op_name_)); + VLOG(6) << "finish process yaml_info_parser"; + + const auto& op_meta = + paddle::framework::detail::GetOpInfoByPirName(custom_op_name_); + custom_op_meta_ = &op_meta; + infershape_func_ = OpMetaInfoHelper::GetInferShapeFn(op_meta); + inferdtype_func_ = OpMetaInfoHelper::GetInferDtypeFn(op_meta); + kernel_func_ = OpMetaInfoHelper::GetKernelFn(op_meta); + BuildCustomContext(yaml_info_parser); + VLOG(6) << "finish process custom context"; + auto kernel_key = op_attributes.at("kernel_key") + .dyn_cast() + .data(); + SetDeviceContext( + ParseDeviceContext(op, + phi::DeviceContextPool::Instance().Get( + phi::TransToPhiPlace(kernel_key.backend())), + place, + GetExecutionStream(), + GetStreamPriority())); + VLOG(6) << "finish process device context"; + + InitInputsOutputsIds(op, value_exec_info_); + VLOG(6) << "finish process inputs outputs index"; + + auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); + std::unordered_set no_need_buffer_values; + for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { + no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id])); + } + SetNoNeedBuffer(no_need_buffer_values); + VLOG(6) << "finish process no need buffer"; +} + +void CustomKernelInstruction::UpdateOutputMeta( + const std::vector>& output_shapes, + const std::vector& output_dtypes) { + PADDLE_ENFORCE_EQ( + output_shapes.size(), + cache_out_ptrs_.size(), + phi::errors::InvalidArgument( + "The number of output shapes after running custom operator's " + "InferShapeFunc is wrong, " + "expected contains %d Tensors' shape, but actually contains %d " + "Tensors' shape", + cache_out_ptrs_.size(), + output_shapes.size())); + + PADDLE_ENFORCE_EQ( + output_dtypes.size(), + cache_out_ptrs_.size(), + phi::errors::InvalidArgument( + "The number of output dtypes after running custom operator's " + "InferDtypeFunc is wrong, " + "expected contains %d Tensors' dtype, but actually contains %d " + "Tensors' dtype", + cache_out_ptrs_.size(), + output_dtypes.size())); + + for (size_t i = 0; i < cache_out_ptrs_.size(); ++i) { + auto out_in_scope = cache_out_ptrs_.at(i); + // update dims and dtype + auto out_meta = phi::DenseTensorUtils::GetMutableMeta(out_in_scope); + out_meta->dims = phi::make_ddim(output_shapes[i]); + out_meta->dtype = output_dtypes[i]; + } +} + +void CustomKernelInstruction::Run() { + VLOG(3) << "Custom Operator: InferShape - calc output ddim."; + std::vector> output_shapes; + std::vector output_dtypes; + if (infershape_func_) { + output_shapes = + infershape_func_(input_shapes_, vec_input_shapes_, custom_attrs_); + } else { + PADDLE_ENFORCE_EQ( + OpMetaInfoHelper::GetInputs(*custom_op_meta_).size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + PADDLE_ENFORCE_EQ( + OpMetaInfoHelper::GetOutputs(*custom_op_meta_).size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. " + "At this time, the input shape will be directly set to " + "the output shape.\n" + "Please set the InferShapeFn of custom " + "operator by .SetInferShapeFn(PD_INFER_SHAPE(...))")); + + VLOG(3) << "Custom Operator: Default InferShape - share ddim."; + if (input_shapes_.size() == 1) { + output_shapes = input_shapes_; + } else if (vec_input_shapes_.size() == 1) { + output_shapes = vec_input_shapes_[0]; + } else { + PADDLE_THROW(phi::errors::Unavailable( + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferShapeFn. ")); + } + } + + if (inferdtype_func_) { + output_dtypes = + inferdtype_func_(input_dtypes_, vec_input_dtypes_, custom_attrs_); + } else { + PADDLE_ENFORCE_EQ( + OpMetaInfoHelper::GetInputs(*custom_op_meta_).size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple inputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + PADDLE_ENFORCE_EQ( + OpMetaInfoHelper::GetOutputs(*custom_op_meta_).size(), + 1UL, + phi::errors::Unavailable( + "Your custom operator contains multiple outputs. " + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. " + "At this time, the input dtype will be directly set to " + "the output dtype.\n" + "Please set the InferDtypeFn of custom " + "operator by `.SetInferDtypeFn(PD_INFER_DTYPE(...))`")); + + VLOG(3) << "Custom Operator: InferDtype - share dtype."; + if (input_dtypes_.size() == 1) { + output_dtypes = input_dtypes_; + } else if (vec_input_dtypes_.size() == 1) { + output_dtypes = vec_input_dtypes_[0]; + } else { + PADDLE_THROW(phi::errors::Unavailable( + "We only allow a custom operator that contains only one input " + "and only one output without setting the InferDtypeFn. ")); + } + } + UpdateOutputMeta(output_shapes, output_dtypes); + + VLOG(6) << "Run custom op " << custom_op_name_ << " kernel."; + kernel_func_(&custom_kernel_ctx_); + custom_kernel_ctx_.AssignInplaceOutputs(); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h new file mode 100644 index 00000000000000..6c6a7d90ae8f0f --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/phi/api/ext/op_meta_info.h" + +namespace pir { +class Operation; +} // namespace pir + +namespace paddle { +namespace framework { +class Scope; + +class CustomKernelInstruction : public InstructionBase { + public: + CustomKernelInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo& value_exec_info); + + ::pir::Operation* Operation() const override { return op_; } + + void Run() override; + + const std::string& Name() const override { return custom_op_name_; } + + void clear(); + + private: + void BuildCustomContext( + const paddle::dialect::OpYamlInfoParser& op_yaml_info); + + void UpdateOutputMeta(const std::vector>& output_shapes, + const std::vector& output_dtypes); + + paddle::CustomOpKernelContext custom_kernel_ctx_; + + paddle::InferShapeFunc infershape_func_ = nullptr; + paddle::InferDtypeFunc inferdtype_func_ = nullptr; + paddle::KernelFunc kernel_func_ = nullptr; + + // use for runing infershape + std::vector> input_shapes_; + std::vector>> vec_input_shapes_; + std::vector custom_attrs_; + + // use for runing inferdtype + std::vector input_dtypes_; + std::vector> vec_input_dtypes_; + + // use for update output + std::vector cache_out_ptrs_; + + std::string custom_op_name_; + + ::pir::Operation* op_{nullptr}; // not owned + + const paddle::OpMetaInfo* custom_op_meta_; // not owned + const ValueExecutionInfo& value_exec_info_; // not owned +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc index 1dc779b1d43a96..812f86704ee507 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc @@ -106,7 +106,8 @@ LegacyKernelInstruction::LegacyKernelInstruction( phi::errors::PreconditionNotMet( "can not find OpYamlInfoInterface from [%s]", legacy_op_name_)); paddle::dialect::OpYamlInfoParser yaml_info_parser( - yaml_interface->get_op_info_(), paddle::dialect::IsLegacyOp(op_name)); + yaml_interface->get_op_info_(op_name), + paddle::dialect::IsLegacyOp(op_name)); VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc index 71385619cb958b..fb8407a1a7ea34 100644 --- a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_phi_kernel_instruction.cc @@ -216,7 +216,7 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction( phi::errors::PreconditionNotMet( "can not find OpYamlInfoInterface from [%s]", phi_op_name_)); paddle::dialect::OpYamlInfoParser yaml_info_parser( - yaml_interface->get_op_info_(), + yaml_interface->get_op_info_(op_name), paddle::dialect::IsOneDNNLegacyOp(op_name)); VLOG(6) << "finish process yaml_info_parser"; diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index 798735f24058db..ed5bee9ce87772 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -110,7 +110,8 @@ PhiKernelInstruction::PhiKernelInstruction( phi::errors::PreconditionNotMet( "can not find OpYamlInfoInterface from [%s]", phi_op_name_)); paddle::dialect::OpYamlInfoParser yaml_info_parser( - yaml_interface->get_op_info_(), paddle::dialect::IsLegacyOp(op_name)); + yaml_interface->get_op_info_(op_name), + paddle::dialect::IsLegacyOp(op_name)); VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 614b97c26b7b07..0a111922d4409b 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -1304,7 +1304,7 @@ std::vector GetOriginInputNames(const std::string& op_name) { if (op_info.GetInterfaceImpl()) { paddle::dialect::OpYamlInfoParser yaml_parser( op_info.GetInterfaceImpl() - ->get_op_info_()); + ->get_op_info_(op_name)); ret = yaml_parser.InputNames(); } return ret; @@ -1317,7 +1317,7 @@ std::vector GetOriginOutputNames(const std::string& op_name) { if (op_info.GetInterfaceImpl()) { paddle::dialect::OpYamlInfoParser yaml_parser( op_info.GetInterfaceImpl() - ->get_op_info_()); + ->get_op_info_(op_name)); ret = yaml_parser.OutputNames(); } return ret; diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index a06abb197de5fe..be32e1f473a1b5 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -658,7 +658,7 @@ void HandleForInplaceOp(pir::Operation* op, pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); paddle::dialect::OpYamlInfoParser yaml_parser( op_info.GetInterfaceImpl() - ->get_op_info_(), + ->get_op_info_(op_name), paddle::dialect::IsLegacyOp(op_name)); for (size_t i = 0; i < op->num_results(); ++i) { diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 2afdfb5e9717ad..19e3d6e86ebdeb 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -56,6 +56,7 @@ #include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/tuple_push_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" @@ -749,6 +750,10 @@ void PirInterpreter::BuildInstruction() { } else if (op.dialect()->name() == "cinn_runtime") { CREATE_INSTR(CinnJitInstruction); #endif + } else if (op.dialect()->name() == "custom_kernel") { + vec_instruction_base_.emplace_back( + std::make_unique( + op_idx++, place_, &op, *(value_exe_info_.get()))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Now only support pd_kernel and cinn dialect.")); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 4af55a7c6c9337..4b52ceb58ff777 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -783,7 +783,6 @@ bool AnalysisPredictor::PrepareExecutor() { auto output_names = GetOutputNames(); execution_config.skip_gc_vars.insert(output_names.begin(), output_names.end()); - if (FLAGS_enable_pir_in_executor) { pir_program_ = std::move( paddle::TranslateLegacyProgramToProgram(*inference_program_)); diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 0cca4532a0ce6b..778ce2055e0b5d 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -7,6 +7,7 @@ option(WITH_STATIC_LIB option(USE_TENSORRT "Compile demo with TensorRT." OFF) option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) option(WITH_SHARED_PHI "Compile demo with phi shared lib" ON) +option(CUSTOM_OPERATOR_FILES "List of file names for custom operators" "") if(NOT WITH_STATIC_LIB) add_definitions("-DPADDLE_WITH_SHARED_LIB") @@ -252,6 +253,18 @@ if(WITH_GPU) endif() endif() +if(CUSTOM_OPERATOR_FILES) + if(WITH_GPU AND NOT APPLE) + add_definitions("-DPADDLE_WITH_CUDA") + enable_language(CUDA) + find_package(CUDA REQUIRED) + include_directories("${CUDA_INCLUDE_DIRS}") + endif() + add_library(pd_infer_custom_op SHARED ${CUSTOM_OPERATOR_FILES}) + target_link_libraries(pd_infer_custom_op ${DEPS}) + set(DEPS ${DEPS} pd_infer_custom_op) +endif() + add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) target_link_libraries(${DEMO_NAME} ${DEPS}) if(WIN32) diff --git a/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc b/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc new file mode 100644 index 00000000000000..b4c8cccb8e7906 --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle_inference_api.h" //NOLINT + +DEFINE_string(modeldir, "", "Directory of the inference model."); + +using paddle_infer::Config; +using paddle_infer::CreatePredictor; +using paddle_infer::Predictor; + +void run(Predictor *predictor, + const std::vector &input, + const std::vector &input_shape, + std::vector *out_data) { + auto input_names = predictor->GetInputNames(); + auto input_t = predictor->GetInputHandle(input_names[0]); + input_t->Reshape(input_shape); + input_t->CopyFromCpu(input.data()); + + CHECK(predictor->Run()); + + auto output_names = predictor->GetOutputNames(); + auto output_t = predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_t->shape(); + int out_num = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); + + out_data->resize(out_num); + output_t->CopyToCpu(out_data->data()); +} + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + paddle::AnalysisConfig config; + config.EnableUseGpu(100, 0); + config.SetModel(FLAGS_modeldir + "/custom_relu.pdmodel", + FLAGS_modeldir + "/custom_relu.pdiparams"); + config.EnableNewExecutor(true); + auto predictor{paddle_infer::CreatePredictor(config)}; + std::vector input_shape = {1, 1, 28, 28}; + std::vector input_data(1 * 1 * 28 * 28, 1); + std::vector out_data; + run(predictor.get(), input_data, input_shape, &out_data); + for (auto e : out_data) { + LOG(INFO) << e << '\n'; + } + return 0; +} diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc new file mode 100755 index 00000000000000..e55b943a5568f8 --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/extension.h" + +template +void relu_cpu_forward_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + for (int i = 0; i < x_numel; ++i) { + out_data[i] = std::max(static_cast(0.), x_data[i]); + } +} + +template +void relu_cpu_backward_kernel(const data_t* grad_out_data, + const data_t* out_data, + data_t* grad_x_data, + int64_t out_numel) { + for (int i = 0; i < out_numel; ++i) { + grad_x_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +std::vector relu_cpu_forward(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel( + x.data(), out.mutable_data(x.place()), x.size()); + })); + + return {out}; +} + +std::vector relu_cpu_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, x.shape()); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x) { + // TODO(chenweihang): Check Input + if (x.place() == paddle::PlaceType::kCPU) { + return relu_cpu_forward(x); + } else if (x.place() == paddle::PlaceType::kGPU) { + return relu_cuda_forward(x); + } else { + throw std::runtime_error("Not implemented."); + } +} + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + // TODO(chenweihang): Check Input + if (x.place() == paddle::PlaceType::kCPU) { + return relu_cpu_backward(x, out, grad_out); + } else if (x.place() == paddle::PlaceType::kGPU) { + return relu_cuda_backward(x, out, grad_out); + } else { + throw std::runtime_error("Not implemented."); + } +} + +PD_BUILD_OP(custom_relu) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu new file mode 100644 index 00000000000000..a4b7fcf06bce6c --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_op.cu @@ -0,0 +1,71 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +template +__global__ void relu_cuda_forward_kernel(const data_t* x, + data_t* y, + const int num) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = max(x[i], static_cast(0.)); + } +} + +template +__global__ void relu_cuda_backward_kernel(const data_t* dy, + const data_t* y, + data_t* dx, + const int num) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.); + } +} + +std::vector relu_cuda_forward(const paddle::Tensor& x) { + auto out = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + + int numel = x.size(); + int block = 512; + int grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out.mutable_data(x.place()), numel); + })); + + return {out}; +} + +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, x.shape()); + + int numel = out.size(); + int block = 512; + int grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel + <<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + numel); + })); + + return {grad_x}; +} diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 50112b20f29a02..795b414258b560 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -102,6 +102,17 @@ else wget -q http://paddle-inference-dist.bj.bcebos.com/word2vec.inference.model.tar.gz tar xzf *.tar.gz fi +cd .. + +#download custom_op_demo data +mkdir -p custom_op +cd custom_op +if [[ -e "custom_relu_infer_model.tgz" ]]; then + echo "custom_relu_infer_model.tgz has been downloaded." +else + wget -q https://paddle-inference-dist.bj.bcebos.com/inference_demo/custom_operator/custom_relu_infer_model.tgz + tar xzf *.tgz +fi # compile and test the demo cd $current_dir @@ -275,6 +286,28 @@ for WITH_STATIC_LIB in ON OFF; do EXIT_CODE=1 fi fi + + # --------custom op demo on linux/mac------ + if [ $TEST_GPU_CPU == ON -a $WITH_STATIC_LIB == OFF ]; then + rm -rf * + CUSTOM_OPERATOR_FILES="custom_relu_op.cc;custom_relu_op.cu" + cmake .. -DPADDLE_LIB=${inference_install_dir} \ + -DWITH_MKL=$TURN_ON_MKL \ + -DDEMO_NAME=custom_op_demo \ + -DWITH_GPU=$TEST_GPU_CPU \ + -DWITH_STATIC_LIB=OFF \ + -DUSE_TENSORRT=$USE_TENSORRT \ + -DTENSORRT_ROOT=$TENSORRT_ROOT_DIR \ + -DCUSTOM_OPERATOR_FILES=$CUSTOM_OPERATOR_FILES \ + -DWITH_ONNXRUNTIME=$WITH_ONNXRUNTIME + make -j$(nproc) + FLAGS_enable_pir_in_executor=1 ./custom_op_demo \ + --modeldir=$DATA_DIR/custom_op/custom_relu_infer_model + if [ $? -ne 0 ]; then + echo "custom_op_demo runs failed " >> ${current_dir}/test_summary.txt + EXIT_CODE=1 + fi + fi fi done diff --git a/paddle/fluid/inference/api/helper.cc b/paddle/fluid/inference/api/helper.cc index 3fd8ed490fe458..44d7a75cae21aa 100644 --- a/paddle/fluid/inference/api/helper.cc +++ b/paddle/fluid/inference/api/helper.cc @@ -16,8 +16,13 @@ #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/platform/init.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/core/flags.h" +#include "paddle/pir/core/ir_context.h" + +PHI_DECLARE_bool(enable_pir_in_executor); namespace paddle { namespace inference { @@ -49,6 +54,21 @@ void RegisterAllCustomOperator() { auto &op_meta_info_map = OpMetaInfoMap::Instance(); const auto &meta_info_map = op_meta_info_map.GetMap(); for (auto &pair : meta_info_map) { + if (FLAGS_enable_pir_in_executor) { + ::pir::IrContext *ctx = ::pir::IrContext::Instance(); + auto *custom_dialect = + ctx->GetOrRegisterDialect(); + if (custom_dialect->HasRegistered(pair.first)) { + LOG(INFO) << "The operator `" << pair.first + << "` has been registered. " + "Therefore, we will not repeat the registration here."; + continue; + } + for (const auto &meta_info : pair.second) { + LOG(INFO) << "register pir custom op :" << pair.first; + custom_dialect->RegisterCustomOp(meta_info); + } + } const auto &all_op_kernels{framework::OperatorWithKernel::AllOpKernels()}; if (all_op_kernels.find(pair.first) == all_op_kernels.end()) { framework::RegisterOperatorWithMetaInfo(pair.second); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index c64004c7191dd9..68e9a89cefb76a 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -83,6 +83,7 @@ constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT #ifdef PADDLE_WITH_DNNL constexpr char kOneDNNTargetDialectPrefix[] = "pd_onednn_op."; // NOLINT #endif +constexpr char kCustomOpDialectPrefix[] = "custom_op."; constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT static const std::unordered_set SpecialNonInplaceOps = {}; @@ -229,16 +230,27 @@ inline pir::Operation* InsertCreateArrayOp(pir::IrContext* ctx, return create_array_op.operation(); } +inline bool HasOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc, + std::string prefix) { + std::string target_op_name = prefix + OpNameCompatibleMapping(op_desc.Type()); + if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { + target_op_name += "_"; + } + auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (op_info) { + return true; + } + return false; +} + inline std::string GetPrefix(pir::IrContext* ctx, const OpDesc& op_desc) { + if (HasOpInfo(ctx, op_desc, kCustomOpDialectPrefix)) { + return kCustomOpDialectPrefix; + } #ifdef PADDLE_WITH_DNNL if (op_desc.GetAttrIfExists("use_mkldnn")) { - std::string target_op_name = - kOneDNNTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); - if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { - target_op_name += "_"; - } - auto op_info = ctx->GetRegisteredOpInfo(target_op_name); - if (!op_info) { + if (!HasOpInfo(ctx, op_desc, kOneDNNTargetDialectPrefix)) { VLOG(3) << op_desc.Type() << "'s use_mkldnn == True, but PIR not support OneDNN for this " "op right now."; @@ -284,7 +296,7 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) = - op_info_concept->get_op_info_(); + op_info_concept->get_op_info_(op_info.name()); auto& op_normalizer = OpNameNormalizer::instance(); std::vector need_inputs_sig; @@ -355,9 +367,6 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } - VLOG(6) << "[op name normalizing]: " << op_desc.Type() << " to " - << target_op_name; - op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW("Op %d should have corresponding OpInfo %d", op_desc.Type(), @@ -792,8 +801,9 @@ pir::Operation* OpTranscriber::operator()(pir::IrContext* ctx, OpInputInfoList input_infos; OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; + std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) = - op_info_concept->get_op_info_(); + op_info_concept->get_op_info_(op_info.name()); this->InsertSliceOperationForInput( ctx, param_map, op_desc, input_infos, block); @@ -810,7 +820,6 @@ pir::Operation* OpTranscriber::operator()(pir::IrContext* ctx, this->TranslateOpAttribute(ctx, op_info.name(), attr_infos, op_desc); TranslateOpDistAttribute(op_desc, &attribute_map); VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; - pir::Operation* operation = pir::Operation::Create( op_inputs, attribute_map, op_output_types, op_info); VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; @@ -940,7 +949,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) = - op_info_concept->get_op_info_(); + op_info_concept->get_op_info_(op_info.name()); std::unordered_map attr_info_maps; for (auto const& info : attr_infos) { attr_info_maps.insert({info.name, info}); @@ -1274,7 +1283,7 @@ struct FetchOpTranscriber : public OpTranscriber { OpAttributeInfoList attr_infos; OpOutputInfoList output_infos; std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) = - op_info_concept->get_op_info_(); + op_info_concept->get_op_info_(op_info.name()); this->InsertSliceOperationForInput( ctx, param_map, op_desc, input_infos, block); diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc index ecf04d4411397b..63e2a83a7dbe97 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc @@ -26,20 +26,7 @@ REGISTER_FILE_SYMBOLS(kernel_dialect); namespace paddle { namespace dialect { -KernelDialect::KernelDialect(pir::IrContext *context) - : pir::Dialect(name(), context, pir::TypeId::get()) { - initialize(); -} - -void KernelDialect::initialize() { - RegisterTypes(); - RegisterOps(); - RegisterAttributes(); -} - -void KernelDialect::PrintType(pir::Type type, std::ostream &os) const { +void PrintKernelType(pir::Type type, std::ostream &os) { if (type.isa()) { AllocatedDenseTensorType tensor_type = type.dyn_cast(); @@ -75,14 +62,35 @@ void KernelDialect::PrintType(pir::Type type, std::ostream &os) const { } } -void KernelDialect::PrintAttribute(pir::Attribute attr, - std::ostream &os) const { +void PrintKernelAttribute(pir::Attribute attr, std::ostream &os) { phi::KernelKey kernel = attr.dyn_cast().data(); os << ""; } +KernelDialect::KernelDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { + initialize(); +} + +void KernelDialect::initialize() { + RegisterTypes(); + RegisterOps(); + RegisterAttributes(); +} + +void KernelDialect::PrintType(pir::Type type, std::ostream &os) const { + PrintKernelType(type, os); +} + +void KernelDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { + PrintKernelAttribute(attr, os); +} + void KernelDialect::PrintOperation(pir::Operation *op, pir::IrPrinter &printer) const { if (op->dyn_cast() || op->dyn_cast()) { @@ -122,6 +130,45 @@ void KernelDialect::PrintOperation(pir::Operation *op, } } +CustomKernelDialect::CustomKernelDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { + initialize(); +} + +void CustomKernelDialect::initialize() { + RegisterTypes(); + RegisterOps(); + RegisterAttributes(); +} + +void CustomKernelDialect::PrintType(pir::Type type, std::ostream &os) const { + PrintKernelType(type, os); +} + +void CustomKernelDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { + PrintKernelAttribute(attr, os); +} + +void CustomKernelDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + auto &os = printer.os; + printer.PrintOpResult(op); + os << " ="; + auto custom_kernel_op = op->dyn_cast(); + std::string kernel_name = custom_kernel_op.kernel_name(); + if (op->attributes().count("is_inplace") != 0 && + op->attributes().at("is_inplace").dyn_cast().data()) { + kernel_name = kernel_name + "_"; + } + os << " \"" << kernel_name << "(custom_kernel)\""; + printer.PrintOpOperands(op); + printer.PrintAttributeMap(op); + os << " :"; + printer.PrintOperandsType(op); + os << " -> "; + printer.PrintOpReturnType(op); +} #ifdef PADDLE_WITH_DNNL OneDNNKernelDialect::OneDNNKernelDialect(pir::IrContext *context) : pir::Dialect(name(), context, pir::TypeId::get()) { @@ -139,47 +186,12 @@ void OneDNNKernelDialect::initialize() { } void OneDNNKernelDialect::PrintType(pir::Type type, std::ostream &os) const { - if (type.isa()) { - AllocatedDenseTensorType tensor_type = - type.dyn_cast(); - - os << phi::AllocationTypeStr(tensor_type.place().GetType()) << "_"; - os << "tensor<"; - for (auto d : common::vectorize(tensor_type.dims())) { - os << d; - os << "x"; - } - tensor_type.dtype().Print(os); - os << ">"; - } else if (type.isa()) { - AllocatedSelectedRowsType tensor_type = - type.dyn_cast(); - - os << phi::AllocationTypeStr(tensor_type.place().GetType()) << "_"; - os << "tensor<"; - for (auto d : common::vectorize(tensor_type.dims())) { - os << d; - os << "x"; - } - tensor_type.dtype().Print(os); - os << ">"; - } else if (type.isa()) { - AllocatedDenseTensorArrayType tensor_array_type = - type.dyn_cast(); - - os << phi::AllocationTypeStr(tensor_array_type.place().GetType()) << "_"; - os << "tensor_array<"; - tensor_array_type.dtype().Print(os); - os << ">"; - } + PrintKernelType(type, os); } void OneDNNKernelDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { - phi::KernelKey kernel = attr.dyn_cast().data(); - - os << ""; + PrintKernelAttribute(attr, os); } void OneDNNKernelDialect::PrintOperation(pir::Operation *op, @@ -226,6 +238,7 @@ void OneDNNKernelDialect::PrintOperation(pir::Operation *op, } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelDialect) #ifdef PADDLE_WITH_DNNL IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect) #endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h index fbdb53a40b183d..ad198cb25296df 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h @@ -36,6 +36,23 @@ class KernelDialect : public pir::Dialect { void initialize(); }; +class CustomKernelDialect : public pir::Dialect { + public: + explicit CustomKernelDialect(pir::IrContext* context); + + static const char* name() { return "custom_kernel"; } + + void PrintType(pir::Type type, std::ostream& os) const override; + + void PrintAttribute(pir::Attribute attr, std::ostream& os) const override; + + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT + + private: + void initialize(); +}; + #ifdef PADDLE_WITH_DNNL class OneDNNKernelDialect : public pir::Dialect { public: @@ -59,6 +76,7 @@ class OneDNNKernelDialect : public pir::Dialect { } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelDialect) #ifdef PADDLE_WITH_DNNL IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNKernelDialect) #endif diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 45f0a848fc174d..c5095046ff8aef 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -98,6 +98,46 @@ phi::KernelKey LegacyKernelOp::kernel_key() { return attributes().at("kernel_key").dyn_cast().data(); } +const char* CustomKernelOp::attributes_name[attributes_num] = { // NOLINT + "op_name", + "kernel_name", + "kernel_key"}; + +void CustomKernelOp::VerifySig() { + VLOG(4) << "Verifying inputs, outputs and attributes for: CustomKernelOp."; + auto& attributes = this->attributes(); + + PADDLE_ENFORCE(attributes.count("op_name") > 0 && + attributes.at("op_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: op_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && + attributes.at("kernel_name").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_name is not right.")); + + PADDLE_ENFORCE(attributes.count("kernel_key") > 0 && + attributes.at("kernel_key").isa(), + phi::errors::PreconditionNotMet( + "Type of attribute: kernel_key is not right.")); +} + +std::string CustomKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().AsString(); +} + +std::string CustomKernelOp::kernel_name() { + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); +} + +phi::KernelKey CustomKernelOp::kernel_key() { + return attributes().at("kernel_key").dyn_cast().data(); +} + #ifdef PADDLE_WITH_DNNL const char* OneDNNPhiKernelOp::attributes_name[attributes_num] = { // NOLINT "op_name", @@ -134,6 +174,7 @@ std::string OneDNNPhiKernelOp::kernel_name() { .dyn_cast() .AsString(); } + phi::KernelKey OneDNNPhiKernelOp::kernel_key() { return attributes().at("kernel_key").dyn_cast().data(); } @@ -225,6 +266,7 @@ phi::KernelKey OneDNNLegacyKernelOp::kernel_key() { IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelOp) #ifdef PADDLE_WITH_DNNL IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp) diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index df723158702085..0fcaeb20807424 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -44,6 +44,18 @@ class LegacyKernelOp : public pir::Op { void VerifySig(); }; +class CustomKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "custom_kernel"; } + static constexpr uint32_t attributes_num = 3; + static const char *attributes_name[attributes_num]; + std::string op_name(); + std::string kernel_name(); + phi::KernelKey kernel_key(); + void VerifySig(); +}; + #ifdef PADDLE_WITH_DNNL class OneDNNPhiKernelOp : public pir::Op { public: @@ -87,6 +99,7 @@ class OneDNNLegacyKernelOp : public pir::Op { IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PhiKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::LegacyKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomKernelOp) #ifdef PADDLE_WITH_DNNL IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNPhiKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OneDNNMixedPhiKernelOp) diff --git a/paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h b/paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h index 24a7622fa99b07..0f045cb97a0ec0 100644 --- a/paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h +++ b/paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h @@ -28,26 +28,28 @@ namespace dialect { class OpYamlInfoInterface : public pir::OpInterfaceBase { public: struct Concept { - explicit Concept(OpInfoTuple (*get_op_info)()) + explicit Concept(OpInfoTuple (*get_op_info)(const std::string& op_name)) : get_op_info_(get_op_info) {} - OpInfoTuple (*get_op_info_)(); + OpInfoTuple (*get_op_info_)(const std::string& op_name); }; template struct Model : public Concept { - static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); } + static OpInfoTuple GetOpInfo(const std::string& op_name) { + return ConcreteOp::GetOpInfo(); + } Model() : Concept(GetOpInfo) {} }; /// Constructor - OpYamlInfoInterface(pir::Operation *op, Concept *impl) + OpYamlInfoInterface(pir::Operation* op, Concept* impl) : pir::OpInterfaceBase(op), impl_(impl) {} - OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); } + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation_->name()); } private: - Concept *impl_; + Concept* impl_; }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 6e2e105d9c18a0..80f6e598f967c2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" @@ -29,10 +30,20 @@ namespace paddle { namespace dialect { +static std::unordered_map kCustomTypeMap = { + {"bool", "pir::BoolAttribute"}, + {"int", "pir::Int32Attribute"}, + {"float", "pir::FloatAttribute"}, + {"int64_t", "pir::Int64Attribute"}, + {"std::string", "pir::StrAttribute"}, + {"std::vector", "pir::ArrayAttribute"}, + {"std::vector", "pir::ArrayAttribute"}, + {"std::vector", "pir::ArrayAttribute"}, + {"std::vector", "pir::ArrayAttribute"}}; struct CombineOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { symbol::ShapeOrDataDimExprs value_shape; // for (auto operand_source : op->operands_source()) { @@ -55,7 +66,7 @@ struct CombineOpInferSymbolicShapeInterfaceModel : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} }; -OperatorDialect::OperatorDialect(pir::IrContext *ctx) +OperatorDialect::OperatorDialect(pir::IrContext* ctx) : pir::Dialect(name(), ctx, pir::TypeId::get()) { initialize(); ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); @@ -69,40 +80,7 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx) CombineOpInferSymbolicShapeInterfaceModel>())); } -void OperatorDialect::initialize() { - RegisterTypes(); - - RegisterAttributes(); - - // NOTE(zhangbo9674): GET_OP_LIST is defined in pd_op.h which is - // generated by op_gen.py, see details in - // paddle/fluid/pir/dialect/CMakeLists.txt. - // NOTE(Ruting)GET_MANUAL_OP_LIST is define in manual_op.h" - // use RegisterOps when list has more than two ops. - RegisterOps< -#define GET_OP_LIST -#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT - >(); - - RegisterOps< -#define GET_OP_LIST -#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT - >(); - - RegisterOps< -#define GET_OP_LIST -#include "paddle/fluid/pir/dialect/operator/ir/manual_op.cc" // NOLINT - >(); - - RegisterInterfaces(); -} - -void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const { +void PrintTypeImpl(pir::Type type, std::ostream& os) { os << type.dialect().name(); os << '.'; if (auto tensor_type = type.dyn_cast()) { @@ -127,16 +105,14 @@ void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const { os << ">"; } } - -void OperatorDialect::PrintAttribute(pir::Attribute attr, - std::ostream &os) const { +void PrintAttributeImpl(pir::Attribute attr, std::ostream& os) { os << "(" << attr.dialect().name(); os << '.'; if (auto int_array_attr = attr.dyn_cast()) { phi::IntArray data = int_array_attr.data(); os << "IntArray)" << "["; - const auto &inner_data = data.GetData(); + const auto& inner_data = data.GetData(); pir::PrintInterleave( inner_data.begin(), inner_data.end(), @@ -154,7 +130,60 @@ void OperatorDialect::PrintAttribute(pir::Attribute attr, } } -pir::Type OperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT +void PrintOperationImpl(pir::Operation* op, + pir::IrPrinter& printer) { // NOLINT + if (auto if_op = op->dyn_cast()) { + if_op.Print(printer); + } else if (auto while_op = op->dyn_cast()) { + while_op.Print(printer); + } else { + printer.PrintGeneralOperation(op); + } +} + +void OperatorDialect::initialize() { + RegisterTypes(); + + RegisterAttributes(); + + // NOTE(zhangbo9674): GET_OP_LIST is defined in pd_op.h which is + // generated by op_gen.py, see details in + // paddle/fluid/pir/dialect/CMakeLists.txt. + // NOTE(Ruting)GET_MANUAL_OP_LIST is define in manual_op.h" + // use RegisterOps when list has more than two ops. + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT + >(); + + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT + >(); + + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.cc" // NOLINT + >(); + + RegisterInterfaces(); +} + +void OperatorDialect::PrintType(pir::Type type, std::ostream& os) const { + PrintTypeImpl(type, os); +} + +void OperatorDialect::PrintAttribute(pir::Attribute attr, + std::ostream& os) const { + PrintAttributeImpl(attr, os); +} + +pir::Type OperatorDialect::ParseType(pir::IrParser& parser) { // NOLINT parser.ConsumeAToken("pd_op.tensor"); parser.ConsumeAToken("<"); std::vector dim{}; @@ -184,7 +213,7 @@ pir::Type OperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT } pir::Attribute OperatorDialect::ParseAttribute( - pir::IrParser &parser) { // NOLINT + pir::IrParser& parser) { // NOLINT std::string type_name = parser.ConsumeToken().val_; std::string attribute_name = type_name.substr(type_name.find('.') + 1, std::string::npos); @@ -203,18 +232,195 @@ pir::Attribute OperatorDialect::ParseAttribute( } } -void OperatorDialect::PrintOperation(pir::Operation *op, - pir::IrPrinter &printer) const { - if (auto if_op = op->dyn_cast()) { - if_op.Print(printer); - } else if (auto while_op = op->dyn_cast()) { - while_op.Print(printer); - } else { - printer.PrintGeneralOperation(op); +void OperatorDialect::PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const { + PrintOperationImpl(op, printer); +} + +class IdManager { + public: + static IdManager& Instance() { + static IdManager instance; + return instance; + } + + ~IdManager() { + for (auto id : ids_) { + delete id; + } + ids_.clear(); + } + + pir::TypeId CreateId() { + pir::detail::UniqueingId* unique_id = new pir::detail::UniqueingId(); + ids_.push_back(unique_id); + return ids_.back()->id(); } + + private: + std::vector ids_; +}; + +class AttributeManager { + public: + static AttributeManager& Instance() { + static AttributeManager instance; + return instance; + } + + ~AttributeManager() { + for (size_t i = 0; i < char_pointers_.size(); i++) { + for (size_t j = 0; j < pointers_size_[i]; j++) { + delete char_pointers_[i][j]; + } + delete char_pointers_[i]; + } + char_pointers_.clear(); + pointers_size_.clear(); + } + + const char** ToCharPointers(const std::vector& attr_names) { + const char** char_pointers = new const char*[attr_names.size()]; + for (size_t i = 0; i < attr_names.size(); i++) { + const std::string& attr_name = attr_names[i]; + char* ptr = new char[attr_name.size() + 1]; + snprintf(ptr, attr_name.size() + 1, "%s", attr_name.c_str()); + char_pointers[i] = ptr; + } + pointers_size_.push_back(attr_names.size()); + char_pointers_.push_back(char_pointers); + return char_pointers; + } + + private: + std::vector char_pointers_; + std::vector pointers_size_; +}; + +struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { + static OpInfoTuple GetPirOpInfo(const std::string& pir_op_name) { + const auto& op_meta = + paddle::framework::detail::GetOpInfoByPirName(pir_op_name); + std::vector inputs_info; + std::vector attributes_info; + std::vector outputs_info; + std::vector param_names; + // translate input info + auto& op_input_names = OpMetaInfoHelper::GetInputs(op_meta); + for (const auto& input_name : op_input_names) { + param_names.push_back(input_name); + bool is_optional = false; + std::string input_type = "paddle::dialect::DenseTensorType"; + if (paddle::framework::detail::IsOptionalVar(input_name)) { + is_optional = true; + } + if (paddle::framework::detail::IsDuplicableVar(input_name)) { + input_type = "pir::VectorType"; + } + // Now, we only support dense tensor as input. + inputs_info.push_back(paddle::dialect::OpInputInfo{ + input_name, input_type, is_optional, false, false, false}); + } + + // translate attr info + auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta); + for (const auto& op_attr : op_attrs) { + auto attr_name_and_type = paddle::ParseAttrStr(op_attr); + auto attr_name = attr_name_and_type[0]; + auto attr_type_str = attr_name_and_type[1]; + param_names.push_back(attr_name); + if (kCustomTypeMap.find(attr_type_str) == kCustomTypeMap.end()) { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported `%s` type value as custom attribute now. " + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector`, " + "`std::vector`, Please check whether " + "the attribute data type and data type string are matched.", + attr_type_str)); + } + std::string attr_pir_type = kCustomTypeMap[attr_type_str]; + attributes_info.push_back( + paddle::dialect::OpAttributeInfo{attr_name, attr_pir_type, ""}); + } + + // translate output info + auto& op_output_names = OpMetaInfoHelper::GetOutputs(op_meta); + for (const auto& output_name : op_output_names) { + bool is_optional = false; + if (paddle::framework::detail::IsOptionalVar(output_name)) { + is_optional = true; + } + // Now, we only support dense tensor as output. + outputs_info.push_back(paddle::dialect::OpOutputInfo{ + output_name, "paddle::dialect::DenseTensorType", is_optional, false}); + } + + // we only need kernel params name in run_time_info + paddle::dialect::OpRunTimeInfo run_time_info = + paddle::dialect::OpRunTimeInfo("", {}, "", param_names, {}, {}, {}, {}); + return std::make_tuple( + inputs_info, attributes_info, outputs_info, run_time_info, ""); + } + + CustomOpInfoInterfaceModel() : OpYamlInfoInterface::Concept(GetPirOpInfo) {} +}; + +CustomOpDialect::CustomOpDialect(pir::IrContext* context) + : pir::Dialect(name(), context, pir::TypeId::get()) {} + +void CustomOpDialect::PrintType(pir::Type type, std::ostream& os) const { + PrintTypeImpl(type, os); } +void CustomOpDialect::PrintAttribute(pir::Attribute attr, + std::ostream& os) const { + PrintAttributeImpl(attr, os); +} + +void CustomOpDialect::PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const { + PrintOperationImpl(op, printer); +} + +void CustomOpDialect::RegisterCustomOp(const paddle::OpMetaInfo& op_meta) { + pir::TypeId id = IdManager::Instance().CreateId(); + std::string op_name = paddle::framework::kCustomDialectPrefix + + OpMetaInfoHelper::GetOpName(op_meta); + op_names_.push_back(op_name); + + auto& op_attrs = OpMetaInfoHelper::GetAttrs(op_meta); + std::vector attr_names; + for (const auto& op_attr : op_attrs) { + auto attr_name_and_type = paddle::ParseAttrStr(op_attr); + auto attr_name = attr_name_and_type[0]; + attr_names.push_back(attr_name); + } + const char** attr_name = + AttributeManager::Instance().ToCharPointers(attr_names); + uint32_t attr_num = attr_names.size(); + + std::vector traits; + std::set interface_values; + pir::InterfaceValue op_info_interface = + pir::InterfaceValue::Get(); + interface_values.insert(std::move(op_info_interface)); + // Currently we set empty verify function and will reset it if it is used in + // future. + pir::VerifyPtr verify_func = [](pir::Operation* op) {}; + ir_context()->RegisterOpInfo(this, + id, + op_names_.back().c_str(), + std::move(interface_values), + traits, + attr_num, + attr_name, + verify_func, + verify_func); +} } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h index 8a61f6cb9615be..d6626f999ffd1a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h @@ -14,7 +14,10 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/operation.h" #include "paddle/utils/test_macros.h" namespace paddle { @@ -39,7 +42,39 @@ class TEST_API OperatorDialect : public pir::Dialect { void initialize(); }; +inline bool IsCustomOp(pir::Operation* op) { + std::string op_name = op->name(); + return op_name.find("custom_op") != op_name.npos; +} + +class CustomOpDialect : public pir::Dialect { + public: + explicit CustomOpDialect(pir::IrContext* context); + + static const char* name() { return "custom_op"; } + + void PrintType(pir::Type type, std::ostream& os) const override; + void PrintAttribute(pir::Attribute type, std::ostream& os) const override; + + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT + + void RegisterCustomOp(const paddle::OpMetaInfo& op_meta); + + bool HasRegistered(const std::string& op_name) { + if (std::find(op_names_.begin(), op_names_.end(), op_name) != + op_names_.end()) { + return true; + } + return false; + } + + private: + std::vector op_names_; +}; + } // namespace dialect } // namespace paddle IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomOpDialect) diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index b836617321f8cf..56d767180c15ad 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -181,7 +181,8 @@ bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { op_info.GetInterfaceImpl(); if (info_interface) { paddle::dialect::OpYamlInfoParser info_parser( - info_interface->get_op_info_(), paddle::dialect::IsLegacyOp(op_name)); + info_interface->get_op_info_(op_name), + paddle::dialect::IsLegacyOp(op_name)); auto& no_need_buffer_ids = info_parser.NoNeedBufferIds(); for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { if (value == op->operand_source(no_need_buffer_ids[id])) { @@ -274,23 +275,19 @@ void GetEagerDelValueOfOp( std::unordered_map> GetEagerDeletionValues(const pir::Block& block) { std::unordered_set skip_dels = GetSkipDeletionValues(block); - std::unordered_map del_value_2_op; GetEagerDelValueOfOp(block, skip_dels, &del_value_2_op); - std::unordered_map> eager_dels; for (auto& kv : del_value_2_op) { eager_dels[kv.second].insert(kv.first); } - return eager_dels; } std::unordered_map GetInplaceOps( const pir::Block& block) { const auto eager_dels = GetEagerDeletionValues(block); - std::unordered_map inplace_ops; std::unordered_set visited_values; @@ -312,7 +309,6 @@ std::unordered_map GetInplaceOps( } continue; } - auto upper_op_attrs = op.attributes(); auto upper_op_name = upper_op_attrs.at("op_name").dyn_cast().AsString(); @@ -389,7 +385,7 @@ std::unordered_map GetInplaceOps( phi::errors::PreconditionNotMet( "can not find OpYamlInfoInterface from [%s]", upper_op_name + "_")); paddle::dialect::OpYamlInfoParser upper_inplace_op_info_parser( - upper_inplace_op_interface->get_op_info_()); + upper_inplace_op_interface->get_op_info_(upper_op_name + "_")); std::unordered_map inplace_out_2_in = upper_inplace_op_info_parser.GetInplaceIdMap(); diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index df7b8673d9ea80..165a1d3fde4fc7 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -1475,6 +1475,178 @@ void HandleForSpecialOp( VLOG(6) << "Deep copy a new special op: " << op_item->name(); } +void PushBackOutputTypes(pir::IrContext* ctx, + pir::Operation* op_item, + const phi::Place& out_place, + const phi::KernelKey& kernel_key, + std::vector* op_output_types, + size_t index) { + auto result_type = op_item->result(index).type(); + if (!result_type) { + op_output_types->push_back(result_type); + } else if (result_type.isa() || + result_type.isa() || + result_type.isa()) { +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + op_output_types->push_back(BuildOutputType( + result_type, out_place, phi::DataLayout::ONEDNN, ctx)); + } else { + op_output_types->push_back(BuildOutputType(result_type, out_place, ctx)); + } +#else + op_output_types->push_back(BuildOutputType(result_type, out_place, ctx)); +#endif + + } else if (result_type.isa()) { + std::vector vec_inner_types; + auto base_types = result_type.dyn_cast().data(); + for (auto& base_type : base_types) { + if (base_type) { + if (base_type.isa() || + base_type.isa()) { +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + vec_inner_types.push_back(BuildOutputType( + base_type, out_place, phi::DataLayout::ONEDNN, ctx)); + } else { + vec_inner_types.push_back( + BuildOutputType(base_type, out_place, ctx)); + } +#else + vec_inner_types.push_back(BuildOutputType(base_type, out_place, ctx)); +#endif + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support dense tensor and selected rows in vector type " + "for now")); + } + } else { + // NOTE(phlrain), kernel not support a nullptr in output + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DDim dims = {}; + phi::DataLayout data_layout = phi::DataLayout::NCHW; +#ifdef PADDLE_WITH_DNNL + if (kernel_key.backend() == phi::Backend::ONEDNN) { + data_layout = phi::DataLayout::ONEDNN; + } +#endif + phi::LoD lod = {{}}; + size_t offset = 0; + auto dense_tensor_dtype = DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset); + auto allocated_dense_tensor_dtype = + AllocatedDenseTensorType::get(ctx, out_place, dense_tensor_dtype); + vec_inner_types.push_back(allocated_dense_tensor_dtype); + } + } + + pir::Type t1 = pir::VectorType::get(ctx, vec_inner_types); + op_output_types->push_back(t1); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Result type only support DenseTensorType, SelectedRowType and " + "VectorType")); + } +} + +void HandleForCustomOp( + pir::IrContext* ctx, + pir::Operation* op_item, + const phi::KernelKey& kernel_key, + const phi::Place place, + const OpYamlInfoParser* op_info_parser, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair, + pir::Block* block) { + // Prepare output types + std::vector op_output_types; + + for (size_t i = 0; i < op_item->num_results(); ++i) { + phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); + PushBackOutputTypes( + ctx, op_item, out_place, kernel_key, &op_output_types, i); + } + + // Prepare input + std::vector vec_inputs; + + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + PADDLE_ENFORCE_EQ( + map_value_pair->count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", i, op_item->name())); + + auto new_in = map_value_pair->at(cur_in); + auto new_in_type = new_in.type(); + + if (new_in_type.isa()) { + auto in_place = new_in_type.dyn_cast().place(); + // GPU_PINNED -> GPU, refer to PR#41972 + if (phi::AllocationType::GPUPINNED == place.GetType()) { + VLOG(6) << "need trans from GPUPINNED to GPU"; + // build memcopy op + auto out_place = phi::TransToPhiPlace(phi::Backend::GPU); + auto new_in_alloc_type = + new_in_type.dyn_cast(); + auto out_type = + AllocatedDenseTensorType::get(ctx, + out_place, + new_in_alloc_type.dtype(), + new_in_alloc_type.dims(), + new_in_alloc_type.data_layout(), + new_in_alloc_type.lod(), + new_in_alloc_type.offset()); + new_in = AddPlaceTransferOp( + new_in, out_type, in_place, out_place, kernel_key, block); + } + } + + vec_inputs.push_back(new_in); + } + + // Prepare attr + std::unordered_map op_attribute{ + {"op_name", pir::StrAttribute::get(ctx, op_item->name())}, + {"kernel_name", pir::StrAttribute::get(ctx, op_item->name())}, + {"kernel_key", KernelAttribute::get(ctx, kernel_key)}}; + auto op_attr_map = op_item->attributes(); + + for (auto& map_item : op_attr_map) { + op_attribute.emplace(map_item.first, map_item.second); + } + + if (op_item->HasTrait()) { + op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true)); + } + + VLOG(6) << "Lower custom op: " << op_item->name() + << " to : " << CustomKernelOp::name(); + + pir::OpInfo custom_kernel_op_info = + ctx->GetRegisteredOpInfo(CustomKernelOp::name()); + + pir::Operation* op = nullptr; + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, custom_kernel_op_info); + + (*map_op_pair)[op_item] = op; + + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = op->result(i); + } + } + block->push_back(op); +} + std::vector BuildOutputs(pir::Operation* op_item, const std::string& kernel_fn_str, const phi::KernelKey& kernel_key, @@ -1508,75 +1680,8 @@ std::vector BuildOutputs(pir::Operation* op_item, (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { out_place = phi::TransToPhiPlace(output_defs[i].backend); } - - auto result_type = op_item->result(i).type(); - if (!result_type) { - op_output_types.push_back(result_type); - } else if (result_type.isa() || - result_type.isa() || - result_type.isa()) { -#ifdef PADDLE_WITH_DNNL - if (kernel_key.backend() == phi::Backend::ONEDNN) { - op_output_types.push_back(BuildOutputType( - result_type, out_place, phi::DataLayout::ONEDNN, ctx)); - } else { - op_output_types.push_back(BuildOutputType(result_type, out_place, ctx)); - } -#else - op_output_types.push_back(BuildOutputType(result_type, out_place, ctx)); -#endif - - } else if (result_type.isa()) { - std::vector vec_inner_types; - auto base_types = result_type.dyn_cast().data(); - for (auto& base_type : base_types) { - if (base_type) { - if (base_type.isa() || - base_type.isa()) { -#ifdef PADDLE_WITH_DNNL - if (kernel_key.backend() == phi::Backend::ONEDNN) { - vec_inner_types.push_back(BuildOutputType( - base_type, out_place, phi::DataLayout::ONEDNN, ctx)); - } else { - vec_inner_types.push_back( - BuildOutputType(base_type, out_place, ctx)); - } -#else - vec_inner_types.push_back( - BuildOutputType(base_type, out_place, ctx)); -#endif - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "only support dense tensor and selected rows in vector type " - "for now")); - } - } else { - // NOTE(phlrain), kernel not support a nullptr in output - pir::Type fp32_dtype = pir::Float32Type::get(ctx); - phi::DDim dims = {}; - phi::DataLayout data_layout = phi::DataLayout::NCHW; -#ifdef PADDLE_WITH_DNNL - if (kernel_key.backend() == phi::Backend::ONEDNN) { - data_layout = phi::DataLayout::ONEDNN; - } -#endif - phi::LoD lod = {{}}; - size_t offset = 0; - auto dense_tensor_dtype = DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset); - auto allocated_dense_tensor_dtype = - AllocatedDenseTensorType::get(ctx, out_place, dense_tensor_dtype); - vec_inner_types.push_back(allocated_dense_tensor_dtype); - } - } - - pir::Type t1 = pir::VectorType::get(ctx, vec_inner_types); - op_output_types.push_back(t1); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Result type only support DenseTensorType, SelectedRowType and " - "VectorType")); - } + PushBackOutputTypes( + ctx, op_item, out_place, kernel_key, &op_output_types, i); } return op_output_types; @@ -2074,6 +2179,18 @@ void ProcessBlock( op_item, place, kernel_name, *map_value_pair, op_info_parser.get()); VLOG(6) << "kernel type " << kernel_key; + if (paddle::dialect::IsCustomOp(op_item)) { + HandleForCustomOp(ctx, + op_item, + kernel_key, + place, + op_info_parser.get(), + map_op_pair, + map_value_pair, + new_block); + continue; + } + #ifdef PADDLE_WITH_DNNL if (op_item->HasTrait() && kernel_key.backend() != phi::Backend::ONEDNN) { @@ -2147,6 +2264,8 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + #ifdef PADDLE_WITH_DNNL ctx->GetOrRegisterDialect(); ctx->GetOrRegisterDialect(); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 8813ff59de53e4..1c398cf7cdf975 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1406,7 +1406,7 @@ std::map GetOpInplaceInfo(const pir::Operation *op) { pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); paddle::dialect::OpYamlInfoParser yaml_parser( op_info.GetInterfaceImpl() - ->get_op_info_(), + ->get_op_info_(op_name), paddle::dialect::IsLegacyOp(op_name)); for (size_t i = 0; i < op->num_results(); ++i) { diff --git a/paddle/phi/kernels/autotune/cache_base.h b/paddle/phi/kernels/autotune/cache_base.h index 82af1ccbb71325..37f6106b1baa8e 100644 --- a/paddle/phi/kernels/autotune/cache_base.h +++ b/paddle/phi/kernels/autotune/cache_base.h @@ -19,38 +19,12 @@ #include #include "paddle/common/errors.h" +#include "paddle/common/hash_funcs.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" PHI_DECLARE_int32(search_cache_max_number); -inline void HashCombine(std::size_t* seed UNUSED) {} - -// combine hash value -// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x -template -inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { - std::hash hasher; - *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); - *seed *= 0x00000100000001B3; - HashCombine(seed, rest...); -} - -// custom specialization of std::hash can be injected in namespace std -// ref: https://en.cppreference.com/w/cpp/utility/hash -namespace std { -template -struct hash> { - std::size_t operator()(std::vector const& vec) const noexcept { - std::size_t seed = 0xcbf29ce484222325; - for (auto val : vec) { - HashCombine(&seed, val); - } - return seed; - } -}; -} // namespace std - namespace phi { namespace autotune { diff --git a/paddle/pir/core/builtin_type_storage.h b/paddle/pir/core/builtin_type_storage.h index d8361658f9e85b..77c3383f797981 100644 --- a/paddle/pir/core/builtin_type_storage.h +++ b/paddle/pir/core/builtin_type_storage.h @@ -16,28 +16,12 @@ #include "paddle/common/ddim.h" #include "paddle/common/dim.h" +#include "paddle/common/hash_funcs.h" #include "paddle/common/layout.h" #include "paddle/pir/core/type.h" #include "paddle/pir/core/type_base.h" #include "paddle/pir/core/utils.h" -namespace std { -/// -/// \brief Enable hashing std::vector instances. -/// -template -struct hash> { - std::size_t operator()(const std::vector& dim) const { - std::size_t seed = 0; - for (size_t i = 0; i < dim.size(); ++i) { - seed ^= std::hash()(dim[i]) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; - -} // namespace std - namespace pir { /// /// \brief Define Parametric TypeStorage for DenseTensorType. diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index fc670d4e9e44e4..c0ce8842155ab6 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -117,7 +117,6 @@ Operation *Operation::Create(const std::vector &inputs, base_ptr += sizeof(detail::BlockOperandImpl); } } - // 3.5. Construct Regions if (num_regions > 0) { op->regions_ = reinterpret_cast(base_ptr); @@ -126,7 +125,6 @@ Operation *Operation::Create(const std::vector &inputs, base_ptr += sizeof(Region); } } - // 0. Verify if (op_info) { try { diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 5c3aa2b9f43449..decfc904088464 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -1,8 +1,8 @@ -paddle_test(shape_op_test SRCS shape_op_test.cc DEPS gtest) +paddle_test(shape_op_test SRCS shape_op_test.cc) -paddle_test(shape_struct_test SRCS shape_struct_test.cc DEPS gtest) +paddle_test(shape_struct_test SRCS shape_struct_test.cc) -paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc DEPS gtest) +paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc) paddle_test(symbol_dim_expr_util_test SRCS symbol_dim_expr_util_test.cc DEPS gtest) From 5e2a3dbf836cc30fb6c2b66002e8965b862c91a3 Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 2 Jan 2024 15:57:45 +0800 Subject: [PATCH 067/142] [Prim][PIR] support roll, gather, scatter, scatter_nd_add op backward in pir prim (#60481) * prim gather op backward * prim scatter op backward * prim roll op backward * prim scatter_nd op backward --- paddle/fluid/primitive/codegen/gen.py | 4 + paddle/fluid/primitive/rule/vjp/details.h | 100 ++++++++++++++++++++++ test/legacy_test/test_gather_op.py | 11 ++- test/legacy_test/test_roll_op.py | 18 +++- test/legacy_test/test_scatter_nd_op.py | 26 +++++- test/legacy_test/test_scatter_op.py | 38 ++++++-- 6 files changed, 183 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 01e760c2b33b21..005eae29593434 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -80,13 +80,17 @@ 'sum_grad', 'cast_grad', 'reshape_grad', + 'roll_grad', 'split_grad', 'transpose_grad', 'concat_grad', 'expand_grad', + 'gather_grad', 'gather_nd_grad', 'pad_grad', 'max_grad', + 'scatter_grad', + 'scatter_nd_add_grad', 'slice_grad', 'tile_grad', 'topk_grad', diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 60d51c60146270..1be68ba043e19f 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -243,6 +243,23 @@ void reshape_grad(const Tensor& xshape, } } +template +void roll_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& shifts, + const std::vector& axis, + Tensor* x_grad) { + if (x_grad) { + auto shifts_ = shifts.GetData(); + int64_t nums = shifts_.size(); + for (int64_t i = 0; i < nums; i++) { + shifts_[i] = 0 - shifts_[i]; + } + auto x_grad_output = roll(out_grad, shifts_, axis); + set_output(x_grad_output, x_grad); + } +} + template void transpose_grad(const Tensor& grad_out, const std::vector& perm, @@ -262,6 +279,43 @@ void transpose_grad(const Tensor& grad_out, } } +template +void scatter_grad(const Tensor& index, + const Tensor& updates, + const Tensor& out_grad, + bool overwrite, + Tensor* x_grad, + Tensor* updates_grad) { + if (x_grad) { + auto zero_tensor = + full(common::vectorize(updates.dims()), 0.0, updates.dtype()); + auto tmp_grad = scatter(out_grad, index, zero_tensor, false); + set_output(tmp_grad, x_grad); + } + + if (updates_grad) { + Scalar tmp_zero = 0; + auto tmp_updates_grad = gather(out_grad, index, tmp_zero); + set_output(tmp_updates_grad, updates_grad); + } +} + +template +void scatter_nd_add_grad(const Tensor& index, + const Tensor& updates, + const Tensor& out_grad, + Tensor* x_grad, + Tensor* updates_grad) { + if (x_grad) { + by_pass(out_grad, x_grad); + } + if (updates_grad) { + // Gradient by Gather: dUpdates = dO[Ids] + auto tmp_updates_grad = gather_nd(out_grad, index); + set_output(tmp_updates_grad, updates_grad); + } +} + template void sin_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { auto x_grad_tmp = cos(x) * out_grad; @@ -818,6 +872,52 @@ void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { } } +template +void gather_grad(const Tensor& x, + const Tensor& index, + const Tensor& out_grad, + const Scalar& axis, + Tensor* grad_x) { + auto zero_tensor = full(common::vectorize(x.dims()), 0.0, x.dtype()); + std::vector tmp_perm; + + // change axis to rank 0 + int axis_value = axis.to(); + tmp_perm.push_back(axis_value); + // make other ranks + for (int i = 0; i < x.dims().size(); ++i) { + if (i != axis_value) { + tmp_perm.push_back(i); + } + } + std::vector reverse_perm(tmp_perm); + // make origin ranks + for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { + if (tmp_perm[i] >= 0) { + reverse_perm[tmp_perm[i]] = i; + } else { + reverse_perm[tmp_perm[i] + tmp_perm.size()] = i; + } + } + + // transpose out_grad and zero grad to target rank. + auto tmp_zero_x_grad = zero_tensor; + auto tmp_out_grad = out_grad; + if (zero_tensor.dims().size() > 0) { + tmp_zero_x_grad = transpose(zero_tensor, tmp_perm); + } + if (out_grad.dims().size() > 0) { + tmp_out_grad = transpose(out_grad, tmp_perm); + } + // scatter grad to grad_x + auto tmp_grad_x = scatter(tmp_zero_x_grad, index, tmp_out_grad, false); + auto tmp_grad_x_tranposed = tmp_grad_x; + if (tmp_grad_x.dims().size() > 0) { + tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm); + } + set_output(tmp_grad_x_tranposed, grad_x); +} + template void gather_nd_grad(const Tensor& x, const Tensor& index, diff --git a/test/legacy_test/test_gather_op.py b/test/legacy_test/test_gather_op.py index 3ebb2de7b8560a..f37af3a62ddb95 100644 --- a/test/legacy_test/test_gather_op.py +++ b/test/legacy_test/test_gather_op.py @@ -45,7 +45,9 @@ def test_check_output(self): self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) + self.check_grad( + ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True + ) def config(self): """ @@ -119,7 +121,12 @@ def test_check_output(self): def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True + paddle.CUDAPlace(0), + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_roll_op.py b/test/legacy_test/test_roll_op.py index 5512e248acbb17..e6057705e4987f 100644 --- a/test/legacy_test/test_roll_op.py +++ b/test/legacy_test/test_roll_op.py @@ -52,7 +52,9 @@ def test_check_output(self): self.check_output(check_prim=True, check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) + self.check_grad( + ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True + ) class TestRollOpCase2(TestRollOp): @@ -139,7 +141,12 @@ def test_check_output(self): def test_check_grad_normal(self): self.check_grad_with_place( - self.place, ['X'], 'Out', check_prim=True, check_pir=True + self.place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -163,7 +170,12 @@ def test_check_output(self): def test_check_grad_normal(self): self.check_grad_with_place( - self.place, ['X'], 'Out', check_prim=True, check_pir=True + self.place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_scatter_nd_op.py b/test/legacy_test/test_scatter_nd_op.py index e9e541e09af670..6290c0b485c4fc 100644 --- a/test/legacy_test/test_scatter_nd_op.py +++ b/test/legacy_test/test_scatter_nd_op.py @@ -98,7 +98,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ['X', 'Updates'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -133,7 +137,12 @@ def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + place, + ['X', 'Updates'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -176,7 +185,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ['X', 'Updates'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -211,7 +224,12 @@ def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + place, + ['X', 'Updates'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_scatter_op.py b/test/legacy_test/test_scatter_op.py index d44982c6321d09..61b6b0b45f3088 100644 --- a/test/legacy_test/test_scatter_op.py +++ b/test/legacy_test/test_scatter_op.py @@ -57,7 +57,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ["X", "Updates"], "Out", check_prim=True, check_pir=True + ["X", "Updates"], + "Out", + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -92,6 +96,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -128,7 +133,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ["X", "Updates"], "Out", check_prim=True, check_pir=True + ["X", "Updates"], + "Out", + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -163,6 +172,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -202,7 +212,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ["X", "Updates"], "Out", check_prim=True, check_pir=True + ["X", "Updates"], + "Out", + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -237,6 +251,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -284,6 +299,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -356,6 +372,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -412,7 +429,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X', 'Updates'], 'Out', check_prim=True, check_pir=True + ['X', 'Updates'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -447,6 +468,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -494,6 +516,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) @@ -550,7 +573,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ["X", "Updates"], "Out", check_prim=True, check_pir=True + ["X", "Updates"], + "Out", + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -585,6 +612,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_pir=True, + check_prim_pir=True, ) From 8550d4c630f3d1d1dc00a917e9fe5ea898d91cae Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 2 Jan 2024 15:58:25 +0800 Subject: [PATCH 068/142] [PIR] delete dense_tensor mem_desc_ (#60024) * delete dense_tensor mem_desc_ --- paddle/phi/core/dense_tensor.cc | 10 ---------- paddle/phi/core/dense_tensor.h | 18 ------------------ paddle/phi/core/dense_tensor.inl | 12 +----------- paddle/phi/core/dense_tensor_impl.cc | 28 ++++++++++++++++++++++++---- 4 files changed, 25 insertions(+), 43 deletions(-) diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index c86a06bedef8d1..1181a812669762 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -59,10 +59,6 @@ DenseTensor::DenseTensor(const DenseTensor& other) { storage_properties_ = std::move(CopyStorageProperties(other.storage_properties_)); inplace_version_counter_ = other.inplace_version_counter_; - -#ifdef PADDLE_WITH_DNNL - mem_desc_ = other.mem_desc_; -#endif } DenseTensor& DenseTensor::operator=(const DenseTensor& other) { @@ -74,9 +70,6 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) { storage_properties_ = std::move(CopyStorageProperties(other.storage_properties_)); inplace_version_counter_ = other.inplace_version_counter_; -#ifdef PADDLE_WITH_DNNL - mem_desc_ = other.mem_desc_; -#endif return *this; } @@ -85,9 +78,6 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) noexcept { std::swap(holder_, other.holder_); storage_properties_ = std::move(other.storage_properties_); std::swap(inplace_version_counter_, other.inplace_version_counter_); -#ifdef PADDLE_WITH_DNNL - mem_desc_ = other.mem_desc_; -#endif return *this; } diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h index bcc2b07a89e3a3..b78cec14832722 100644 --- a/paddle/phi/core/dense_tensor.h +++ b/paddle/phi/core/dense_tensor.h @@ -22,12 +22,6 @@ limitations under the License. */ #include "paddle/phi/core/tensor_meta.h" #include "paddle/utils/test_macros.h" -/* @jim19930609: Move to MKLDNN_Tensor in the future - */ -#ifdef PADDLE_WITH_DNNL -#include "dnnl.hpp" // NOLINT -#endif - namespace phi { class DenseTensorUtils; @@ -290,18 +284,6 @@ class TEST_API DenseTensor : public TensorBase, std::shared_ptr inplace_version_counter_ = std::make_shared(); -/* @jim19930609: This is a hack -In general, it is badly designed to fuse MKLDNN-specific objects into a -generic Tensor. -We temporarily leave them here to unblock Tensor Unification progress. -In the final state, we should come up with a MKLDNN_Tensor and move the -following codes there. -*/ -#ifdef PADDLE_WITH_DNNL - /// \brief memory descriptor of tensor which have layout set as kMKLDNN - dnnl::memory::desc mem_desc_; -#endif - #ifndef PADDLE_WITH_CUSTOM_KERNEL #include "paddle/phi/core/dense_tensor.inl" #endif diff --git a/paddle/phi/core/dense_tensor.inl b/paddle/phi/core/dense_tensor.inl index 19101e7093f745..a8672b21711432 100644 --- a/paddle/phi/core/dense_tensor.inl +++ b/paddle/phi/core/dense_tensor.inl @@ -97,22 +97,12 @@ std::vector Split(int64_t split_size, int64_t axis) const; std::vector Chunk(int64_t chunks, int64_t axis) const; -/* @jim19930609: This is a hack -In general, it is badly designed to fuse MKLDNN-specific objects into a -generic Tensor. -We temporarily leave them here to unblock Tensor Unification progress. -In the final state, we should come up with a MKLDNN_Tensor and move the -following codes there. -*/ #ifdef PADDLE_WITH_DNNL public: const dnnl::memory::desc& mem_desc() const; -inline void set_mem_desc(const dnnl::memory::desc& mem_desc) { - mem_desc_ = mem_desc; - meta_.layout = DataLayout::ONEDNN; -} +void set_mem_desc(const dnnl::memory::desc& mem_desc); #endif diff --git a/paddle/phi/core/dense_tensor_impl.cc b/paddle/phi/core/dense_tensor_impl.cc index 5fa43647da19ce..39efb048e74320 100644 --- a/paddle/phi/core/dense_tensor_impl.cc +++ b/paddle/phi/core/dense_tensor_impl.cc @@ -377,7 +377,30 @@ std::vector DenseTensor::Chunk(int64_t chunks, } #ifdef PADDLE_WITH_DNNL -const dnnl::memory::desc& DenseTensor::mem_desc() const { return mem_desc_; } +const dnnl::memory::desc& DenseTensor::mem_desc() const { + if (storage_properties_ == nullptr) { + static dnnl::memory::desc undef_desc = dnnl::memory::desc(); + return undef_desc; + } + return this->storage_properties().mem_desc; +} + +void DenseTensor::set_mem_desc(const dnnl::memory::desc& mem_desc) { + if (storage_properties_ == nullptr) { + storage_properties_ = std::make_unique(); + static_cast(storage_properties_.get())->mem_desc = + mem_desc; + meta_.layout = DataLayout::ONEDNN; + } else if (OneDNNStorageProperties::classof(storage_properties_.get())) { + static_cast(storage_properties_.get())->mem_desc = + mem_desc; + meta_.layout = DataLayout::ONEDNN; + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "The actual type of storage_properties is inconsistent with the type " + "of the template parameter passed in.")); + } +} #endif // NOTE: For historical reasons, this interface has a special behavior, @@ -394,9 +417,6 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) { meta_.strides = src.meta_.strides; storage_properties_ = std::move(CopyStorageProperties(src.storage_properties_)); -#ifdef PADDLE_WITH_DNNL - mem_desc_ = src.mem_desc_; -#endif return *this; } From 1aa6851bc94e066a1b0cecf27f3a760df143162f Mon Sep 17 00:00:00 2001 From: kangguangli Date: Tue, 2 Jan 2024 16:33:47 +0800 Subject: [PATCH 069/142] [PIR] Complement op defs (#60475) * complement translation of legacy matmul * Complement op mappings in translation for deformable_conv_v1. --- paddle/fluid/ir_adaptor/translator/op_compat_gen.py | 10 ++++++++++ paddle/fluid/ir_adaptor/translator/op_translator.cc | 12 +++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 596bf8534bfe67..ea844c659554d6 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -127,6 +127,16 @@ def insert_new_mutable_attributes( ) # special mapping list + op_name_mappings["deformable_conv_v1"] = "deformable_conv" + op_name_mappings["deformable_conv_v1_grad"] = "deformable_conv_grad" + op_arg_name_mappings["deformable_conv_v1"] = { + "x": "Input", + "offset": "Offset", + "filter": "Filter", + "mask": "Mask", + "out": "Output", + } + op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD" op_arg_name_mappings["fetch"] = {"x": "X"} op_arg_name_mappings["elementwise_add_grad_grad"] = { diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 68e9a89cefb76a..0227091e0aa531 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -544,8 +544,9 @@ std::vector OpTranscriber::GenerateOperationInput( info.name, op_desc.Type()); IR_ENFORCE(param_map->count(legacy_input_vars[0]), - "Input [%s] of op [%s] not found in param map", + "Input [%s: %s] of op [%s] not found in param map", info.name, + legacy_input_vars[0], op_desc.Type()); auto defining_info = (*param_map)[legacy_input_vars[0]]; op_inputs.push_back(defining_info.value); @@ -2998,6 +2999,14 @@ struct LegacyMatmulOpTranscriber : public OpTranscriber { param_map->PushValue(output_vars[0], VariableDefiningInfo(scale_op.out(), false, -1)); } + + void HandleNonexistentAttribute(pir::IrContext* ctx, + pir::AttributeMap* attribute_map, + const OpAttributeInfo& info) override { + if (info.name == "transpose_x" || info.name == "transpose_y") { + (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false); + } + } }; struct CEmbeddingOpTranscriber : public OpTranscriber { @@ -3051,6 +3060,7 @@ OpTranslator::OpTranslator() { special_handlers["sum"] = AddNOpTranscriber(); special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); special_handlers["tril_triu_grad"] = TrilAndTriuGradOpTranscriber(); + special_handlers["matmul"] = LegacyMatmulOpTranscriber(); special_handlers["matrix_rank"] = MatrixRankOpTranscriber(); special_handlers["mul"] = MulOpTranscriber(); special_handlers["mul_grad"] = MulGradOpTranscriber(); From 617b8ad1108c5ac1e9414ebde08436cb4f19d598 Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com> Date: Tue, 2 Jan 2024 17:36:06 +0800 Subject: [PATCH 070/142] [pir]Supporting constant_folding_pass for train (#60355) * [pir]Supporting constant_folding_pass for train * fix * Update constant_folding_pass.cc --- .../pir/transforms/constant_folding_pass.cc | 114 +++++++++++++++--- .../pattern_rewrite/pattern_rewrite_test.cc | 31 ++++- 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 553cf3967dd68b..620a7c1c2fecc4 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { pir::PatternRewriter& rewriter) const override { // NOLINT VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name() << "] op"; - pir::Program new_program(rewriter.ir_context()); - auto output_var_names = - BuildProgramFromOperation(op, &new_program, rewriter); - - // execute program - for (auto output_var_name : output_var_names) { - exe_config_->skip_gc_vars.insert(output_var_name); - } - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(&new_program, place_); - paddle::framework::InterpreterCore core( - place_, {}, kernel_program->block(), scope_, *exe_config_); - - core.Run({}); + auto output_var_names = RunOp(op, rewriter); // ParameterOp and ConstantTensorOp should be created in the top-level block rewriter.SetInsertionPointToStart( @@ -236,6 +223,27 @@ class ConstantFoldingPattern : public pir::RewritePattern { return true; } + protected: + std::vector RunOp( + pir::Operation* op, + pir::PatternRewriter& rewriter) const { // NOLINT + pir::Program new_program(rewriter.ir_context()); + auto output_var_names = + BuildProgramFromOperation(op, &new_program, rewriter); + + // execute program + for (auto output_var_name : output_var_names) { + exe_config_->skip_gc_vars.insert(output_var_name); + } + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(&new_program, place_); + paddle::framework::InterpreterCore core( + place_, {}, kernel_program->block(), scope_, *exe_config_); + + core.Run({}); + return output_var_names; + } + std::vector BuildProgramFromOperation( pir::Operation* op, pir::Program* new_program, @@ -299,7 +307,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { return output_var_names; } - private: + protected: size_t* counter_; phi::Place place_; paddle::framework::Scope* scope_; @@ -307,6 +315,68 @@ class ConstantFoldingPattern : public pir::RewritePattern { std::vector* deleted_vars_; }; +class ConstantFoldingPatternForTrain : public ConstantFoldingPattern { + public: + ConstantFoldingPatternForTrain( + pir::IrContext* context, + size_t* counter, + const phi::Place& place, + paddle::framework::Scope* scope, + paddle::framework::interpreter::ExecutionConfig* exe_config, + std::vector* deleted_vars) + : ConstantFoldingPattern( + context, counter, place, scope, exe_config, deleted_vars) {} + + bool Match(pir::Operation* op) const override { + VLOG(4) << "constant_folding_pass applys match on [" << op->name() + << "] op"; + if (!ConstantFoldingPattern::Match(op)) { + return false; + } + for (uint32_t i = 0; i < op->num_operands(); i++) { + // inputs must come from or constant op + auto* prev_op = pir::GetDefiningOpForInput(op, i); + if (!prev_op || !prev_op->isa()) { + return false; + } + } + return true; + } + + void Rewrite(pir::Operation* op, + pir::PatternRewriter& rewriter) const override { // NOLINT + VLOG(4) << "constant_folding_pass for train applys rewrite on [" + << op->name() << "] op"; + + auto output_var_names = RunOp(op, rewriter); + + // ConstantTensorOp should be created in the top-level block + rewriter.SetInsertionPointToStart( + rewriter.block()->parent_program()->block()); + + for (uint32_t i = 0; i < op->num_results(); i++) { + if (!op->result(i) || !op->result(i).type()) { + continue; + } + std::string output_var_name = output_var_names[i]; + PADDLE_ENFORCE_NOT_NULL( + scope_->FindVar(output_var_name), + phi::errors::InvalidArgument("Parameter var [%s] not in scope.", + output_var_name)); + + auto constant_op = rewriter.Build( + rewriter.tensor_name_attr(output_var_name), op->result(i).type()); + constant_op->set_attribute( + kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)})); + + rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0)); + } + rewriter.EraseOp(op); + VLOG(4) << "constant_folding_pass for traun applied rewrite on [" + << op->name() << "] op"; + } +}; + class ConstantFoldingPass : public pir::Pass { public: ConstantFoldingPass() @@ -332,8 +402,18 @@ class ConstantFoldingPass : public pir::Pass { scope_, phi::errors::InvalidArgument("scope can not be nullptr")); pir::RewritePatternSet ps(context); - ps.Add( - context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); + + if (Has("train_mode") && Get("train_mode")) { + ps.Add(context, + &counter_, + phi::CPUPlace{}, + scope_, + &exe_config_, + &deleted_vars_); + } else { + ps.Add( + context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); + } patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); return true; } diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 93156a9d697ce9..1a87247dab35bd 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -445,8 +445,10 @@ void BuildConstantFoldingProgram(pir::Program *program, paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::CPUPlace()); - auto op1 = builder.Build("a", dense_tensor_dtype); - auto op2 = builder.Build("b", dense_tensor_dtype); + auto op1 = builder.Build(builder.tensor_name_attr("a"), + dense_tensor_dtype); + auto op2 = builder.Build(builder.tensor_name_attr("b"), + dense_tensor_dtype); auto op3 = builder.Build(op1->result(0), op2->result(0)); @@ -493,6 +495,31 @@ TEST(constant_folding, ConstantFolding) { EXPECT_EQ(program.block()->size(), 2u); } +TEST(constant_folding, ConstantFolding_Train) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + paddle::framework::Scope scope; + BuildConstantFoldingProgram(&program, ctx, &scope); + + pir::PassManager pm(ctx); + std::unique_ptr constant_folding_pass = + pir::CreateConstantFoldingPass(); + phi::Place place = phi::CPUPlace(); + constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place); + constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, &scope); + constant_folding_pass->Set("train_mode", new bool(true)); + + pm.AddPass(std::move(constant_folding_pass)); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 4u); +} + void BuildConcatProgram(pir::Program *program, pir::IrContext *ctx) { pir::Builder builder = pir::Builder(ctx, program->block()); auto x = builder From fea90ed51c76fd53bf3b196b5cdd22e01fba1b88 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 2 Jan 2024 18:31:45 +0800 Subject: [PATCH 071/142] [Dynamic Shape] Fuse shape ops into generate shape op pass (#60490) * add shape.generate_shape op * rename shape.generate_shape to cinn_op.generate_shape * refactor GenerateShapeOp::SymbolBinding * move GenerateShapeOp related helper functions into generate_shape_util.cc * minor fix * minor fix * backup * refine signature of ConvertDimExprToAttribute * minor fix for signature of ConvertDimExprToAttributes * remove SubstituteDimExpr from generate_shape_util.h * Fix compile error * Fix unittest compile error * Code format * Code format --- .../hlir/dialect/operator/ir/CMakeLists.txt | 4 +- .../operator/ir/generate_shape_util.cc} | 125 ++++-- .../operator/ir/generate_shape_util.h} | 31 +- .../hlir/dialect/operator/ir/manual_op.cc | 217 ++++++++-- .../cinn/hlir/dialect/operator/ir/manual_op.h | 38 ++ .../operator/transforms/CMakeLists.txt | 10 + ...e_shape_ops_into_generate_shape_op_pass.cc | 369 ++++++++++++++++++ ...se_shape_ops_into_generate_shape_op_pass.h | 42 ++ test/cpp/pir/cinn/CMakeLists.txt | 6 +- .../generate_shape_util_test.cc} | 11 +- test/cpp/pir/shape_dialect/CMakeLists.txt | 3 - 11 files changed, 778 insertions(+), 78 deletions(-) rename paddle/{pir/dialect/shape/utils/dim_expr_util.cc => cinn/hlir/dialect/operator/ir/generate_shape_util.cc} (69%) rename paddle/{pir/dialect/shape/utils/dim_expr_util.h => cinn/hlir/dialect/operator/ir/generate_shape_util.h} (53%) create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h rename test/cpp/pir/{shape_dialect/symbol_dim_expr_util_test.cc => cinn/generate_shape_util_test.cc} (92%) diff --git a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt index 0b54046cb6e6af..56f9ab3d5ebe72 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt @@ -67,10 +67,12 @@ if(NOT CINN_ONLY) op_dialect.cc ${cinn_op_source_file} ${cinn_op_info_file} + generate_shape_util.cc manual_op.cc op_attribute.cc DEPS - op_dialect_vjp) + op_dialect_vjp + pir) target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR}) endif() diff --git a/paddle/pir/dialect/shape/utils/dim_expr_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc similarity index 69% rename from paddle/pir/dialect/shape/utils/dim_expr_util.cc rename to paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index 8421f500c23daa..eef663585a4086 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/dialect/shape/utils/dim_expr_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" -namespace symbol { +namespace cinn::dialect { +using namespace symbol; // NOLINT namespace { @@ -58,71 +59,71 @@ std::string GetSerializedTag>() { return "Broadcast"; } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const std::int64_t& dim_expr) { - return builder->int64_attr(dim_expr); + return pir::Int64Attribute::get(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const std::string& dim_expr) { - return builder->str_attr(dim_expr); + return pir::StrAttribute::get(ctx, dim_expr); } template -::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx, const T& dim_expr) { std::vector<::pir::Attribute> attr_vecs{}; - attr_vecs.push_back(builder->str_attr(GetSerializedTag())); + attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag())); const auto& operand = dim_expr->data; - attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand)); - return builder->array_attr(attr_vecs); + attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand)); + return pir::ArrayAttribute::get(ctx, attr_vecs); } ::pir::Attribute ConvertDimExprToAttributeImpl( - ::pir::Builder* builder, const Negative& dim_expr) { - return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr); + ::pir::IrContext* ctx, const Negative& dim_expr) { + return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr); } ::pir::Attribute ConvertDimExprToAttributeImpl( - ::pir::Builder* builder, const Reciprocal& dim_expr) { - return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr); + ::pir::IrContext* ctx, const Reciprocal& dim_expr) { + return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr); } template -::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::Builder* builder, +::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::IrContext* ctx, const T& dim_expr) { std::vector<::pir::Attribute> attr_vecs{}; - attr_vecs.push_back(builder->str_attr(GetSerializedTag())); + attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag())); const auto& operands = *(dim_expr.operands); for (const auto& operand : operands) { - attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand)); + attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand)); } - return builder->array_attr(attr_vecs); + return pir::ArrayAttribute::get(ctx, attr_vecs); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Add& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Mul& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Max& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } -::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx, const Min& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } ::pir::Attribute ConvertDimExprToAttributeImpl( - ::pir::Builder* builder, const Broadcast& dim_expr) { - return ConvertVariadicDimExprToAttribute(builder, dim_expr); + ::pir::IrContext* ctx, const Broadcast& dim_expr) { + return ConvertVariadicDimExprToAttribute(ctx, dim_expr); } std::optional ConvertInt64AttributeToDimExpr( @@ -211,11 +212,11 @@ std::optional ConvertArrayAttributeToDimExpr( } // namespace -::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder, +::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx, const DimExpr& dim_expr) { return std::visit( [&](const auto& impl) { - return ConvertDimExprToAttributeImpl(builder, impl); + return ConvertDimExprToAttributeImpl(ctx, impl); }, dim_expr.variant()); } @@ -359,4 +360,66 @@ MakeGetterDimExpr4SymbolName( }; } -} // namespace symbol +namespace { + +std::optional GetDimExprBySymbolBindingImpl( + const GenerateShapeOp::DataSymbolBinding& symbol_binding, + const std::function& + DimExpr4InputDim) { + const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr = + DimExpr4InputDim(symbol_binding.input_tensor_idx); + if (!shape_or_data_dim_expr.data().has_value()) return std::nullopt; + int dim_idx = symbol_binding.input_tensor_dim_idx; + if (dim_idx >= shape_or_data_dim_expr.data().value().size()) + return std::nullopt; + return shape_or_data_dim_expr.data().value().at(dim_idx); +} + +std::optional GetDimExprBySymbolBindingImpl( + const GenerateShapeOp::ShapeSymbolBinding& symbol_binding, + const std::function& + DimExpr4InputDim) { + const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr = + DimExpr4InputDim(symbol_binding.input_tensor_idx); + int dim_idx = symbol_binding.input_tensor_dim_idx; + if (dim_idx >= shape_or_data_dim_expr.shape().size()) return std::nullopt; + return shape_or_data_dim_expr.shape().at(dim_idx); +} + +} // namespace + +std::function(const std::string& symbol_name)> +MakeGetterDimExpr4SymbolName( + const GenerateShapeOp::SymbolBindings& symbol_bindings, + const std::function& + DimExpr4InputDim) { + std::unordered_map> + symbol_name2symbol_bindins{}; + const auto& GetDimExpr = + [&](const GenerateShapeOp::SymbolBinding& symbol_binding) { + return std::visit( + [&](const auto& impl) { + return GetDimExprBySymbolBindingImpl(impl, DimExpr4InputDim); + }, + symbol_binding); + }; + return [map = std::move(symbol_name2symbol_bindins), GetDimExpr]( + const std::string& symbol_name) -> std::optional { + const auto& iter = map.find(symbol_name); + if (iter == map.end()) return std::nullopt; + std::optional ret = std::nullopt; + for (const auto& symbol_binding : iter->second) { + const auto& current = GetDimExpr(symbol_binding); + if (!current.has_value()) return std::nullopt; + if (ret.has_value()) { + // Same names, same DimExprs. + if (ret.value() != current.value()) return std::nullopt; + } else { + ret = current; + } + } + return ret; + }; +} + +} // namespace cinn::dialect diff --git a/paddle/pir/dialect/shape/utils/dim_expr_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h similarity index 53% rename from paddle/pir/dialect/shape/utils/dim_expr_util.h rename to paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index 3ed4550c2248d5..ee4ad3c129e6b4 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -15,28 +15,35 @@ #pragma once #include +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" -#include "paddle/pir/core/dll_decl.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" -namespace symbol { +namespace cinn::dialect { -IR_API ::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder, - const DimExpr& dim_expr); -IR_API std::optional ConvertAttributeToDimExpr( +::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx, + const symbol::DimExpr& dim_expr); + +std::optional ConvertAttributeToDimExpr( ::pir::Attribute attribute); -IR_API std::optional SubstituteDimExpr( - const DimExpr& dim_expr, - const std::function(const std::string& symbol_name)>& - DimExpr4SymbolName); +std::optional SubstituteDimExpr( + const symbol::DimExpr& dim_expr, + const std::function( + const std::string& symbol_name)>& DimExpr4SymbolName); -IR_API std::function(const std::string& symbol_name)> +std::function(const std::string& symbol_name)> MakeGetterDimExpr4SymbolName( const std::vector>& symbol_bindings, - const std::function( + const std::function( int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim); -} // namespace symbol +std::function(const std::string& symbol_name)> +MakeGetterDimExpr4SymbolName( + const GenerateShapeOp::SymbolBindings& symbol_bindings, + const std::function& + DimExpr4InputDim); + +} // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 68a09ad7a9868b..7bbcd74025a076 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -16,8 +16,12 @@ #include #include "glog/logging.h" +#include "paddle/common/ddim.h" #include "paddle/common/enforce.h" +#include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" +#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/op_base.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" @@ -25,25 +29,25 @@ namespace cinn { namespace dialect { -const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; -const char *ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"}; -const char *SplitOp::attributes_name[SplitOp::attributes_num] = { +const char* GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; +const char* ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"}; +const char* SplitOp::attributes_name[SplitOp::attributes_num] = { "num_or_sections", "axis"}; -void GroupOp::Build(pir::Builder &builder, - pir::OperationArgument &argument, - const std::vector &output_types) { +void GroupOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + const std::vector& output_types) { argument.AddRegion(nullptr); argument.output_types = output_types; } -void GroupOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - std::unique_ptr &&block) { +void GroupOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + std::unique_ptr&& block) { VLOG(4) << "Start build GroupOp"; if (block && !block->empty()) { IR_ENFORCE(block->back().isa()); - auto &op = block->back(); + auto& op = block->back(); for (size_t i = 0; i < op.num_operands(); ++i) { argument.AddOutput(op.operand(i).type()); } @@ -51,15 +55,15 @@ void GroupOp::Build(pir::Builder &builder, // NOLINT argument.AddRegion().push_back(block.release()); } -pir::Block *GroupOp::block() { - pir::Region ®ion = (*this)->region(0); +pir::Block* GroupOp::block() { + pir::Region& region = (*this)->region(0); if (region.empty()) region.emplace_back(); return ®ion.front(); } -std::vector GroupOp::ops() { - std::vector rt_ops; - for (auto &op : *block()) { +std::vector GroupOp::ops() { + std::vector rt_ops; + for (auto& op : *block()) { rt_ops.push_back(&op); } return rt_ops; @@ -67,8 +71,8 @@ std::vector GroupOp::ops() { void GroupOp::VerifySig() {} -void GroupOp::Print(pir::IrPrinter &printer) { - auto &os = printer.os; +void GroupOp::Print(pir::IrPrinter& printer) { + auto& os = printer.os; auto op = operation(); printer.PrintOpResult(op); os << " = " << name(); @@ -76,16 +80,16 @@ void GroupOp::Print(pir::IrPrinter &printer) { os << " -> "; printer.PrintOpReturnType(op); os << " {"; - for (auto &sub_op : ops()) { + for (auto& sub_op : ops()) { os << "\n"; printer.PrintOperation(sub_op); } os << " \n }"; } -void ConcatOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, +void ConcatOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + const std::vector& inputs, int axis) { VLOG(4) << "Start build ConcatOp"; @@ -131,10 +135,10 @@ void ConcatOp::Build(pir::Builder &builder, // NOLINT "axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis)); } -void SplitOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT +void SplitOp::Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT pir::Value input, - const std::vector §ions, + const std::vector& sections, int axis) { VLOG(4) << "Start build ConcatOp"; @@ -177,9 +181,174 @@ void SplitOp::Build(pir::Builder &builder, // NOLINT "axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis)); } +const char* GenerateShapeOp::attributes_name[attributes_num] = { + "output_dim_exprs", "symbol_bindings"}; + +void GenerateShapeOp::Build( + pir::Builder& builder, + pir::OperationArgument& argument, + const std::vector& inputs, + const std::vector& output_dim_exprs, + const GenerateShapeOp::SymbolBindings& symbol_bindings) { + CHECK(!inputs.empty()); + argument.AddInputs(inputs); + argument.AddAttribute("output_dim_exprs", + builder.array_attr(output_dim_exprs)); + argument.AddAttribute( + "symbol_bindings", + ConvertSymbolBindingsToAttribute(builder, symbol_bindings)); + argument.AddOutputs({[&]() { + auto* ctx = pir::IrContext::Instance(); + auto type = pir::Int64Type::get(ctx); + auto dim = + ::common::make_ddim({static_cast(output_dim_exprs.size())}); + return paddle::dialect::DenseTensorType::get(ctx, type, dim); + }()}); + ::pir::PassStopGradientsDefaultly(argument); +} + +namespace { + +const char* GetSymbolBindingTypeImpl( + const GenerateShapeOp::DataSymbolBinding& binding) { + return "DataSymbolBinding"; +} + +const char* GetSymbolBindingTypeImpl( + const GenerateShapeOp::ShapeSymbolBinding& binding) { + return "ShapeSymbolBinding"; +} + +const char* GetSymbolBindingType( + const GenerateShapeOp::SymbolBinding& binding) { + return std::visit( + [](const auto& impl) { return GetSymbolBindingTypeImpl(impl); }, binding); +} + +const GenerateShapeOp::SymbolBindingBase* GetSymbolBindingBaseImpl( + const GenerateShapeOp::DataSymbolBinding& binding) { + return &binding; +} + +const GenerateShapeOp::SymbolBindingBase* GetSymbolBindingBaseImpl( + const GenerateShapeOp::ShapeSymbolBinding& binding) { + return &binding; +} + +const GenerateShapeOp::SymbolBindingBase* GetSymbolBindingBase( + const GenerateShapeOp::SymbolBinding& binding) { + return std::visit( + [](const auto& impl) { return GetSymbolBindingBaseImpl(impl); }, binding); +} + +typedef GenerateShapeOp::SymbolBinding (*SymbolBindingConstructorT)( + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx); + +GenerateShapeOp::SymbolBinding MakeDataSymbolBinding( + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx) { + return GenerateShapeOp::DataSymbolBinding{ + symbol_name, input_tensor_idx, input_tensor_dim_idx}; +} + +GenerateShapeOp::SymbolBinding MakeShapeSymbolBinding( + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx) { + return GenerateShapeOp::ShapeSymbolBinding{ + symbol_name, input_tensor_idx, input_tensor_dim_idx}; +} + +std::optional GetMakerSymbolBinding( + const std::string& type) { + static std::map map{ + {GetSymbolBindingTypeImpl(GenerateShapeOp::DataSymbolBinding{}), + &MakeDataSymbolBinding}, + {GetSymbolBindingTypeImpl(GenerateShapeOp::ShapeSymbolBinding{}), + &MakeShapeSymbolBinding}, + }; + const auto& iter = map.find(type); + if (iter == map.end()) return std::nullopt; + return iter->second; +} + +std::optional MakeSymbolBinding( + const std::string& type, + const std::string& symbol_name, + int64_t input_tensor_idx, + int64_t input_tensor_dim_idx) { + auto opt_creator = GetMakerSymbolBinding(type); + if (!opt_creator.has_value()) return std::nullopt; + return opt_creator.value()( + symbol_name, input_tensor_idx, input_tensor_dim_idx); +} + +} // namespace + +pir::Attribute GenerateShapeOp::ConvertSymbolBindingsToAttribute( + pir::Builder& builder, + const GenerateShapeOp::SymbolBindings& symbol_bindings) { + const auto& ConvertSymbolBindingToAttr = [&](const SymbolBinding& binding) { + const auto* type = GetSymbolBindingType(binding); + const auto& [symbol_name, input_tensor_idx, input_tensor_dim_idx] = + *GetSymbolBindingBase(binding); + return builder.array_attr({ + builder.str_attr(type), + builder.str_attr(symbol_name), + builder.int64_attr(input_tensor_idx), + builder.int64_attr(input_tensor_dim_idx), + }); + }; + std::vector bindings_attr{}; + for (const auto& symbol_binding : symbol_bindings) { + bindings_attr.push_back(ConvertSymbolBindingToAttr(symbol_binding)); + } + return builder.array_attr(bindings_attr); +} + +std::optional +GenerateShapeOp::ConvertAttributeToSymbolBindings( + const pir::Attribute& symbol_bindings) { + if (!symbol_bindings.isa()) return std::nullopt; + const auto& symbol_bindings_array_attr = + symbol_bindings.dyn_cast(); + GenerateShapeOp::SymbolBindings ret{GenerateShapeOp::SymbolBindings{}}; + for (int i = 0; i < symbol_bindings_array_attr.size(); ++i) { + const auto& symbol_binding = symbol_bindings_array_attr.at(i); + if (!symbol_binding.isa()) return std::nullopt; + const auto& symbol_binding_array_attr = + symbol_binding.dyn_cast(); + if (symbol_binding_array_attr.size() != 4) return std::nullopt; + if (!symbol_binding_array_attr.at(0).isa()) + return std::nullopt; + if (!symbol_binding_array_attr.at(1).isa()) + return std::nullopt; + if (!symbol_binding_array_attr.at(2).isa()) + return std::nullopt; + if (!symbol_binding_array_attr.at(3).isa()) + return std::nullopt; + const auto& opt_symbol_binding = MakeSymbolBinding( + symbol_binding_array_attr.at(0) + .dyn_cast() + .AsString(), + symbol_binding_array_attr.at(1) + .dyn_cast() + .AsString(), + symbol_binding_array_attr.at(2).dyn_cast().data(), + symbol_binding_array_attr.at(3).dyn_cast().data()); + if (!opt_symbol_binding.has_value()) return std::nullopt; + ret.emplace_back(opt_symbol_binding.value()); + } + return std::move(ret); +} + } // namespace dialect } // namespace cinn IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) +IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index fbec6e32ee56b7..8a9acef15aa9d7 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/ir_printer.h" @@ -81,9 +82,46 @@ class IR_API SplitOp : public pir::Op { void VerifySig() const {} }; +class GenerateShapeOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "cinn_op.generate_shape"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; + + struct SymbolBindingBase { + std::string symbol_name; + int64_t input_tensor_idx; + int64_t input_tensor_dim_idx; + }; + + struct DataSymbolBinding : public SymbolBindingBase {}; + struct ShapeSymbolBinding : public SymbolBindingBase {}; + + using SymbolBinding = std::variant; + + using SymbolBindings = std::vector; + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const std::vector &output_dim_exprs, + const SymbolBindings &symbol_bindings); + + void VerifySig() {} + + pir::OpResult out() { return result(0); } + + static pir::Attribute ConvertSymbolBindingsToAttribute( + pir::Builder &builder, const SymbolBindings &symbol_bindings); // NOLINT + static std::optional ConvertAttributeToSymbolBindings( + const pir::Attribute &symbol_bindings); +}; + } // namespace dialect } // namespace cinn IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) +IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 18ce80a92baff4..6d76ccbec8adc1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -19,4 +19,14 @@ if(NOT CINN_ONLY) pir cinn_op_dialect op_dialect_vjp) + + cinn_cc_library( + fuse_shape_ops_into_generate_shape_op_pass + SRCS + fuse_shape_ops_into_generate_shape_op_pass.cc + DEPS + pir + cinn_op_dialect + op_dialect_vjp) + endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc new file mode 100644 index 00000000000000..48c7427b402a14 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -0,0 +1,369 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h" +#include +#include +#include "paddle/cinn/common/bfs_walker.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/pattern_applicator.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +namespace { + +using ShapeOrDataDimExprs4ValueT = + std::function; + +std::vector FindSourceDenseTensorOfDimTensor( + pir::Value shape, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + std::vector ret{}; + const auto& Emplace = [&](pir::Value value) { + if (std::find(ret.begin(), ret.end(), value) != ret.end()) return; + ret.emplace_back(value); + }; + const auto& ForEachInputValue = + [&](pir::Value value, const std::function& Visit) { + // find input dimension tensor; + pir::Operation* owner = value.defining_op(); + if (owner == nullptr) return; + for (int i = 0; i < owner->num_operands(); ++i) { + Visit(owner->operand_source(i)); + } + }; + const auto& IsDimTensor = [&](pir::Value value) -> bool { + return ShapeOrDataDimExprs4Value(value).data().has_value(); + }; + const auto& ForEachInputDimTensor = + [&](pir::Value value, const std::function& Visit) { + // find input dimension tensor; + ForEachInputValue(value, [&](pir::Value input) { + if (IsDimTensor(input)) { + Visit(input); + } + }); + }; + common::BfsWalker walker(ForEachInputDimTensor); + walker(shape, [&](pir::Value value) { + size_t input_cnt = 0; + ForEachInputValue(value, [&](pir::Value input) { + ++input_cnt; + if (IsDimTensor(input)) return; + Emplace(input); + }); + if (input_cnt == 0) { + // `value` is a result of a source op. + Emplace(value); + } + }); + return ret; +} + +bool IsConstant(const std::vector& dim_exprs) { + for (const auto& dim_expr : dim_exprs) { + if (dim_expr.isa()) continue; + return false; + } + return true; +} + +bool IsAtomicImpl(int64_t) { return true; } + +bool IsAtomicImpl(const std::string&) { return true; } + +bool IsAtomicImpl(const symbol::Negative&) { return false; } + +bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } + +bool IsAtomicImpl(const symbol::Add&) { return false; } + +bool IsAtomicImpl(const symbol::Mul&) { return false; } + +bool IsAtomicImpl(const symbol::Max&) { return false; } + +bool IsAtomicImpl(const symbol::Min&) { return false; } + +bool IsAtomicImpl(const symbol::Broadcast&) { return false; } + +bool IsAtomic(const symbol::DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, + dim_expr.variant()); +} + +bool InputDimExprsAllSupported( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + const auto& AllSupported = + [](const std::vector& dim_exprs) -> bool { + for (const auto& dim_expr : dim_exprs) { + if (!IsAtomic(dim_expr)) return false; + } + return true; + }; + for (const auto& input_tensor : input_tensors) { + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + if (!AllSupported(dim_exprs.shape())) return false; + if (dim_exprs.data().has_value()) { + if (!AllSupported(dim_exprs.data().value())) return false; + } + } + return true; +} + +void ConvertDimExprToAttributes(pir::IrContext* ir_context, + const std::vector& dim_exprs, + std::vector* attrs) { + attrs->clear(); + attrs->reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); + } +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret); + +void CollectSymbolNamesImpl(const int64_t& dim_expr, + std::set* ret) { + // do nothing. +} + +void CollectSymbolNamesImpl(const std::string& dim_expr, + std::set* ret) { + ret->insert(dim_expr); +} + +template +void CollectSymbolNamesImplForUnary(const T& dim_expr, + std::set* ret) { + const auto& [operand] = *dim_expr; + CollectSymbolNames(operand, ret); +} + +void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +template +void CollectSymbolNamesImplForVariadic(const T& dim_expr, + std::set* ret) { + const auto& operands = *(dim_expr.operands); + for (const auto& operand : operands) { + CollectSymbolNames(operand, ret); + } +} + +void CollectSymbolNamesImpl(const symbol::Add& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Max& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Min& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret) { + return std::visit( + [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, + dim_expr.variant()); +} + +void CollectSymbolNames(const std::vector& dim_exprs, + std::set* ret) { + for (const auto& dim_expr : dim_exprs) { + CollectSymbolNames(dim_expr, ret); + } +} + +template +void AppendSymbolBindings(const std::vector& dim_exprs, + const std::set& symbol_names, + int in_tensor_idx, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); + ++in_tensor_dim_idx) { + const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); + CHECK(IsAtomic(dim_expr)); + if (!dim_expr.isa()) continue; + const auto& sym_name = dim_expr.dyn_cast(); + if (symbol_names.find(sym_name) == symbol_names.end()) continue; + symbol_bindings->emplace_back(SymbolBindingsT{ + /*.symbol_name=*/sym_name, + /*.input_tensor_idx=*/in_tensor_idx, + /*.input_tensor_dim_idx=*/in_tensor_dim_idx, + }); + } +} + +void GenerateSymbolBindings( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + const std::set& symbol_names, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int i = 0; i < input_tensors.size(); ++i) { + const auto& input_tensor = input_tensors.at(i); + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + if (dim_exprs.data().has_value()) { + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + } + } +} + +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + pir::Value output_shape, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); + CHECK(shape_or_data_dim_exprs.data().has_value()); + const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); + if (IsConstant(out_dim_exprs)) return false; + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, input_tensors)) { + VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " + "they are handled by other passes"; + return false; + } + // generate output_dim_expr_attrs + ConvertDimExprToAttributes( + ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); + // generate symbol_bindings + std::set symbol_names_in_out_dim_exprs{}; + CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); + GenerateSymbolBindings(ShapeOrDataDimExprs4Value, + input_tensors, + symbol_names_in_out_dim_exprs, + /*out*/ symbol_bindings); + return true; +} + +std::optional GetOutOfRewritedGenerateShapeOp( + pir::Value shape, + pir::PatternRewriter* rewriter, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + std::vector input_tensors = + FindSourceDenseTensorOfDimTensor(shape, ShapeOrDataDimExprs4Value); + if (input_tensors.empty()) return std::nullopt; + std::vector output_dim_expr_attrs{}; + GenerateShapeOp::SymbolBindings symbol_bindings{}; + bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), + ShapeOrDataDimExprs4Value, + input_tensors, + shape, + &output_dim_expr_attrs, + &symbol_bindings); + if (!success) return std::nullopt; + return rewriter + ->Build( + input_tensors, output_dim_expr_attrs, symbol_bindings) + .out(); +} + +bool ProcessOp(paddle::dialect::ExpandOp op, + pir::PatternRewriter* rewriter, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + std::optional opt_generated_shape = + GetOutOfRewritedGenerateShapeOp( + op.shape(), rewriter, ShapeOrDataDimExprs4Value); + if (!opt_generated_shape.has_value()) return false; + op->operand(1).set_source(opt_generated_shape.value()); + return true; +} + +} // namespace + +template +class FuseShapeOpsIntoGenerateShapeOpPattern + : public pir::OpRewritePattern { + public: + FuseShapeOpsIntoGenerateShapeOpPattern( + pir::IrContext* context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) + : pir::OpRewritePattern(context), + ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {} + + bool MatchAndRewrite(OPTYPE op, + pir::PatternRewriter& rewriter) const override { + return ProcessOp(op, &rewriter, ShapeOrDataDimExprs4Value_); + } + + private: + ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_; +}; + +FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) + : pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1), + ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {} + +pir::RewritePatternSet FuseShapeOpsIntoGenerateShapeOpPass::InitializePatterns( + pir::IrContext* context) { + pir::RewritePatternSet ps(context); + // elementwise ops + ps.Add>( + context, ShapeOrDataDimExprs4Value_); + + return ps; +} + +bool FuseShapeOpsIntoGenerateShapeOpPass::CanApplyOn(pir::Operation* op) const { + return op->isa() && op->num_regions() > 0; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h new file mode 100644 index 00000000000000..393ae49825182a --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass { + public: + using ShapeOrDataDimExprs4ValueT = + std::function; + explicit FuseShapeOpsIntoGenerateShapeOpPass( + const ShapeOrDataDimExprs4ValueT &ShapeOrDataDimExprs4Value); + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override; + + bool CanApplyOn(pir::Operation *op) const override; + + private: + ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index b38edcbb62041d..ccf7c4d8ce686d 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -25,6 +25,9 @@ if(WITH_TESTING AND WITH_CINN) paddle_test(test_compilation_task SRCS compilation_task_test.cc) + paddle_test(test_generate_shape_util_test SRCS generate_shape_util_test.cc + DEPS cinn_op_dialect) + # DO NOT forget add test name here, otherwise it will not be executed in # CINN CI. set(cinn_unit_tests @@ -37,7 +40,8 @@ if(WITH_TESTING AND WITH_CINN) test_pir_all_path test_group_op test_pir_build_cinn_pass - test_compilation_task) + test_compilation_task + test_generate_shape_util_test) foreach(test_name ${cinn_unit_tests}) get_property( diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc b/test/cpp/pir/cinn/generate_shape_util_test.cc similarity index 92% rename from test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc rename to test/cpp/pir/cinn/generate_shape_util_test.cc index 0893a6d5027055..4fc69c877eb5f7 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_util_test.cc +++ b/test/cpp/pir/cinn/generate_shape_util_test.cc @@ -14,13 +14,14 @@ #include "gtest/gtest.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/pir/dialect/shape/utils/dim_expr_builder.h" -#include "paddle/pir/dialect/shape/utils/dim_expr_util.h" #include "test/cpp/pir/tools/test_pir_utils.h" -namespace symbol { +namespace cinn::dialect { +using namespace symbol; // NOLINT namespace { DimExpr CreateExampleDimExpr() { @@ -37,11 +38,9 @@ DimExpr CreateExampleDimExpr() { TEST(DimExprUtil, Convert) { pir::IrContext* ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - pir::Builder builder = pir::Builder(ctx, program.block()); DimExpr dim_expr = CreateExampleDimExpr(); - ::pir::Attribute attr = ConvertDimExprToAttribute(&builder, dim_expr); + ::pir::Attribute attr = ConvertDimExprToAttribute(ctx, dim_expr); std::optional opt_expr = ConvertAttributeToDimExpr(attr); ASSERT_TRUE(opt_expr.has_value()); ASSERT_EQ(opt_expr.value(), dim_expr); @@ -96,4 +95,4 @@ TEST(DimExprUtil, MakeGetterDimExpr4SymbolName) { ASSERT_EQ(opt_dim_expr.value(), dim_expr); } -} // namespace symbol +} // namespace cinn::dialect diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index decfc904088464..19e1f55dad7638 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -4,9 +4,6 @@ paddle_test(shape_struct_test SRCS shape_struct_test.cc) paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc) -paddle_test(symbol_dim_expr_util_test SRCS symbol_dim_expr_util_test.cc DEPS - gtest) - if(WITH_CINN) paddle_test( shape_optimization_test From 58ca933a1d13843392c8fdebbc9db5174ec54ddf Mon Sep 17 00:00:00 2001 From: co63oc Date: Tue, 2 Jan 2024 19:06:46 +0800 Subject: [PATCH 072/142] Fix _hiden_size to _hidden_size (#60485) --- .../dygraph_to_static/seq2seq_dygraph_model.py | 6 +++--- test/dygraph_to_static/test_fallback.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/dygraph_to_static/seq2seq_dygraph_model.py b/test/dygraph_to_static/seq2seq_dygraph_model.py index d2488c31b0ddb0..2359a7df502399 100644 --- a/test/dygraph_to_static/seq2seq_dygraph_model.py +++ b/test/dygraph_to_static/seq2seq_dygraph_model.py @@ -41,7 +41,7 @@ def __init__( ): super().__init__(dtype) - self._hiden_size = hidden_size + self._hidden_size = hidden_size self._param_attr = param_attr self._bias_attr = bias_attr self._gate_activation = gate_activation or paddle.nn.functional.sigmoid @@ -52,13 +52,13 @@ def __init__( self._weight = self.create_parameter( attr=self._param_attr, - shape=[self._input_size + self._hiden_size, 4 * self._hiden_size], + shape=[self._input_size + self._hidden_size, 4 * self._hidden_size], dtype=self._dtype, ) self._bias = self.create_parameter( attr=self._bias_attr, - shape=[4 * self._hiden_size], + shape=[4 * self._hidden_size], dtype=self._dtype, is_bias=True, ) diff --git a/test/dygraph_to_static/test_fallback.py b/test/dygraph_to_static/test_fallback.py index 25a89da29fc9c5..aca2cdbb507cec 100644 --- a/test/dygraph_to_static/test_fallback.py +++ b/test/dygraph_to_static/test_fallback.py @@ -36,7 +36,7 @@ def unsupport_func(x): return paddle.to_tensor(t) -class SuppportNet(paddle.nn.Layer): +class SupportNet(paddle.nn.Layer): def __init__(self): super().__init__() @@ -44,7 +44,7 @@ def forward(self, x): return support_func(x) -class UnsuppportNet(paddle.nn.Layer): +class UnsupportNet(paddle.nn.Layer): def __init__(self): super().__init__() @@ -76,8 +76,8 @@ def test_case_func_fallback(self): np.testing.assert_allclose(output.numpy(), unsupport_func(self.x)) def test_case_net_fallback(self): - s_net = SuppportNet() - u_net = UnsuppportNet() + s_net = SupportNet() + u_net = UnsupportNet() np.testing.assert_allclose( paddle.jit.to_static(s_net)(self.x).numpy(), 4 ) @@ -92,8 +92,8 @@ def test_case_net_fallback(self): @test_ast_only def test_case_net_error(self): - s_net = SuppportNet() - u_net = UnsuppportNet() + s_net = SupportNet() + u_net = UnsupportNet() np.testing.assert_allclose( paddle.jit.to_static(s_net)(self.x).numpy(), 4 ) @@ -111,7 +111,7 @@ def test_case_training(self): build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = True u_net = paddle.jit.to_static( - UnsuppportNet(), build_strategy=build_strategy + UnsupportNet(), build_strategy=build_strategy ) u_net.eval() np.testing.assert_allclose(u_net(self.x).numpy(), [1, 1]) @@ -122,7 +122,7 @@ def test_case_save_error(self): """ test the save will raise error. """ - u_net = UnsuppportNet() + u_net = UnsupportNet() u_net = paddle.jit.to_static( u_net, input_spec=[paddle.static.InputSpec(name='x', shape=[1])] ) @@ -133,7 +133,7 @@ def test_case_save_error_2(self): """ test the save will raise error. """ - u_net = UnsuppportNet() + u_net = UnsupportNet() build_strategy = paddle.static.BuildStrategy() build_strategy.build_cinn_pass = True u_net = paddle.jit.to_static(u_net, build_strategy=build_strategy) From 5376caa3cd172d2450e5fe4820f9050b286b996e Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 2 Jan 2024 19:51:20 +0800 Subject: [PATCH 073/142] [DimExpr] Add substitute DimExpr util (#60493) * add SubstituteDimExpr * Fix compile error * Code format * Polish DimExprUtilTest * Change namesapce * Fix unittest * Polish DimExprUtilTest --- paddle/cinn/common/CMakeLists.txt | 4 +- paddle/cinn/common/dim_expr_util.cc | 111 +++++++++++++++++++++++ paddle/cinn/common/dim_expr_util.h | 29 ++++++ paddle/cinn/common/dim_expr_util_test.cc | 43 +++++++++ 4 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 paddle/cinn/common/dim_expr_util.cc create mode 100644 paddle/cinn/common/dim_expr_util.h create mode 100644 paddle/cinn/common/dim_expr_util_test.cc diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt index b71055169945cc..ff024385d34795 100644 --- a/paddle/cinn/common/CMakeLists.txt +++ b/paddle/cinn/common/CMakeLists.txt @@ -23,7 +23,8 @@ gather_srcs( nvgpu_dev_info.cc integer_set.cc dim_expr_simplify.cc - dim_expr_converter.cc) + dim_expr_converter.cc + dim_expr_util.cc) cinn_cc_test(test_equation_graph_topo_walker SRCS equation_graph_topo_walker_test.cc DEPS gtest glog) @@ -48,6 +49,7 @@ if(WITH_CUDA) gtest glog) endif() if(NOT CINN_ONLY) + cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore) cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS cinncore) cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS diff --git a/paddle/cinn/common/dim_expr_util.cc b/paddle/cinn/common/dim_expr_util.cc new file mode 100644 index 00000000000000..0d0a9090429a05 --- /dev/null +++ b/paddle/cinn/common/dim_expr_util.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/dim_expr_util.h" + +namespace cinn::common { +using namespace symbol; // NOLINT + +namespace { + +class SubstituteDimExprHelper final { + public: + explicit SubstituteDimExprHelper( + const std::unordered_map& + pattern_to_replacement) + : pattern_to_replacement_(pattern_to_replacement) {} + + std::optional Substitute(const DimExpr& dim_expr) { + auto iter = pattern_to_replacement_.find(dim_expr); + if (iter != pattern_to_replacement_.end()) return iter->second; + return std::visit([&](const auto& impl) { return SubstituteImpl(impl); }, + dim_expr.variant()); + } + + private: + std::optional SubstituteImpl(const std::int64_t& value) { + // `Substitute` has handled the case that `value` is matched. + return std::nullopt; + } + std::optional SubstituteImpl(const std::string& value) { + // `Substitute` has handled the case that `value` is matched. + return std::nullopt; + } + + std::optional SubstituteImpl(const Negative& dim_expr) { + return SubstituteUnary(dim_expr); + } + std::optional SubstituteImpl(const Reciprocal& dim_expr) { + return SubstituteUnary(dim_expr); + } + + template + std::optional SubstituteUnary(const T& dim_expr) { + const auto& operand = dim_expr->data; + const auto& substituted_operand = Substitute(operand); + if (!substituted_operand.has_value()) return std::nullopt; + return T{substituted_operand.value()}; + } + + std::optional SubstituteImpl(const Add& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Mul& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Max& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Min& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Broadcast& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + template + std::optional SubstituteVariadic(const T& dim_expr) { + const auto& operands = *(dim_expr.operands); + List substituted_operands{}; + size_t replace_cnt = 0; + for (const auto& operand : operands) { + const auto& substituted_operand = Substitute(operand); + replace_cnt += substituted_operand.has_value(); + substituted_operands->push_back(substituted_operand.has_value() + ? substituted_operand.value() + : operand); + } + if (replace_cnt == 0) return std::nullopt; + return T{substituted_operands}; + } + + std::unordered_map pattern_to_replacement_; +}; + +} // namespace + +symbol::DimExpr SubstituteDimExpr( + const symbol::DimExpr& dim_expr, + const std::unordered_map& + pattern_to_replacement) { + const auto& opt_replaced = + SubstituteDimExprHelper(pattern_to_replacement).Substitute(dim_expr); + return opt_replaced.has_value() ? opt_replaced.value() : dim_expr; +} + +} // namespace cinn::common diff --git a/paddle/cinn/common/dim_expr_util.h b/paddle/cinn/common/dim_expr_util.h new file mode 100644 index 00000000000000..163aeb226ab0d1 --- /dev/null +++ b/paddle/cinn/common/dim_expr_util.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + +namespace cinn::common { + +symbol::DimExpr SubstituteDimExpr( + const symbol::DimExpr& dim_expr, + const std::unordered_map& + pattern_to_replacement); + +} diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc new file mode 100644 index 00000000000000..82b300fc5bfe2b --- /dev/null +++ b/paddle/cinn/common/dim_expr_util_test.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/dim_expr_util.h" + +#include "gtest/gtest.h" + +namespace cinn::common { +using namespace symbol; // NOLINT + +namespace { +DimExpr CreateExampleDimExpr() { + DimExpr sym0 = DimExpr("S0"); + DimExpr sym1 = DimExpr("S1"); + DimExpr constant = DimExpr(2); + return (sym0 - sym1) * constant / sym0; +} +} // namespace + +TEST(DimExprUtil, Substitute) { + DimExpr dim_expr = CreateExampleDimExpr(); + std::unordered_map naive_to_full_name{ + {DimExpr("S0"), DimExpr("symbol0")}, {DimExpr("S1"), DimExpr("symbol1")}}; + std::unordered_map full_name_to_naive{ + {DimExpr("symbol0"), DimExpr("S0")}, {DimExpr("symbol1"), DimExpr("S1")}}; + + const auto& mid_expr = SubstituteDimExpr(dim_expr, naive_to_full_name); + const auto& ret_expr = SubstituteDimExpr(mid_expr, full_name_to_naive); + ASSERT_EQ(ret_expr, dim_expr); +} + +} // namespace cinn::common From 290bf411734a43dcb92ae1078d98ba9bffaefa5d Mon Sep 17 00:00:00 2001 From: NeroLoh <745827440@qq.com> Date: Tue, 2 Jan 2024 20:28:15 +0800 Subject: [PATCH 074/142] [xpu]add sine_pos fuse pass and sine_pos xpu kernel (#60025) --- cmake/external/xpu.cmake | 2 +- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/xpu/sine_pos_fuse_pass.cc | 286 ++++++++++++++++++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/phi/api/yaml/fused_ops.yaml | 9 + paddle/phi/backends/xpu/xpu2_op_list.cc | 7 +- paddle/phi/infermeta/fusion.cc | 31 ++ paddle/phi/infermeta/fusion.h | 4 + .../kernels/fusion/xpu/sine_pos_xpu_kernel.cc | 55 ++++ .../kernels/legacy/xpu/reduce_max_kernel.cc | 8 +- paddle/phi/kernels/xpu/activation_kernel.cc | 13 +- paddle/phi/kernels/xpu/reduce_max_kernel.cc | 9 +- test/ir/inference/test_xpu_sine_pos_pass.py | 132 ++++++++ 13 files changed, 548 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc create mode 100644 paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc create mode 100644 test/ir/inference/test_xpu_sine_pos_pass.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index c0aea597308329..2b5c94872a36ca 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -26,7 +26,7 @@ set(XPU_XBLAS_LIB_NAME "libxpu_blas.so") set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so") if(NOT DEFINED XPU_BASE_DATE) - set(XPU_BASE_DATE "20231203") + set(XPU_BASE_DATE "20231218") endif() if(NOT DEFINED XPU_XHPC_BASE_DATE) set(XPU_XHPC_BASE_DATE "20231229") diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 3c7560b69e3323..35f5ba1522368e 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -318,6 +318,7 @@ if(WITH_XPU) ${XPU_PASS_DEPS}) pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) endif() cc_library( diff --git a/paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc new file mode 100644 index 00000000000000..6c398b775abf5b --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc @@ -0,0 +1,286 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +/* +fuse block in vis model to sine_pos_xpu op +------------------------------------------------------ +sub block: + x y + \ / + \ / + \ / + mul + / \ + / \ + / \ + slice slice + | | + | | + sin cos + \ / + \ / + \ / + stack + | + | + flatten + | + out +------------------------------------------------------ +After the pass is applied: + x y + \ / + \ / + \ / + sine_pos_xpu + | + | + out +*/ + +struct SinePosXPUPattern : public PatternBase { + SinePosXPUPattern(PDPattern* pattern, const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(ew_mul); + PATTERN_DECL_NODE(slice1); + PATTERN_DECL_NODE(slice2); + PATTERN_DECL_NODE(sin); + PATTERN_DECL_NODE(cos); + PATTERN_DECL_NODE(stack); + PATTERN_DECL_NODE(flatten); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(y); + PATTERN_DECL_NODE(ew_mul_out); + PATTERN_DECL_NODE(slice1_out); + PATTERN_DECL_NODE(slice2_out); + PATTERN_DECL_NODE(sin_out); + PATTERN_DECL_NODE(cos_out); + PATTERN_DECL_NODE(stack_out); + PATTERN_DECL_NODE(flatten_out); +}; + +SinePosXPUPattern::SinePosXPUPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto x = pattern->NewNode(x_repr()) + ->assert_is_op_input("elementwise_mul", "X") + ->assert_more([&](Node* node) { + auto x_shape = node->Var()->GetShape(); + size_t x_rank = x_shape.size(); + return x_rank == 3 && x_shape.back() == 1; + }); + auto y = pattern->NewNode(y_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_more([&](Node* node) { + auto x_shape = node->Var()->GetShape(); + size_t x_rank = x_shape.size(); + return x_rank == 1 && x_shape[0] % 2 == 0; + }); + auto* ew_mul = pattern->NewNode(ew_mul_repr()) + ->assert_is_op("elementwise_mul") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("axis") == -1; + }); + auto* ew_mul_out = pattern->NewNode(ew_mul_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("strided_slice", "Input"); + ew_mul->LinksFrom({x, y}).LinksTo({ew_mul_out}); + auto* slice1 = + pattern->NewNode(slice1_repr()) + ->assert_is_op("strided_slice") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>("axes") == + std::vector{2} && + op_desc->GetAttrIfExists>("starts") == + std::vector{0} && + op_desc->GetAttrIfExists>("strides") == + std::vector{2}; + }); + auto* slice1_out = pattern->NewNode(slice1_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_input("sin", "X"); + slice1->LinksFrom({ew_mul_out}).LinksTo({slice1_out}); + auto* sin = pattern->NewNode(sin_repr())->assert_is_op("sin"); + auto* sin_out = pattern->NewNode(sin_out_repr()) + ->assert_is_op_output("sin", "Out") + ->assert_is_op_nth_input("stack", "X", 0); + sin->LinksFrom({slice1_out}).LinksTo({sin_out}); + auto* slice2 = + pattern->NewNode(slice2_repr()) + ->assert_is_op("strided_slice") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>("axes") == + std::vector{2} && + op_desc->GetAttrIfExists>("starts") == + std::vector{1} && + op_desc->GetAttrIfExists>("strides") == + std::vector{2}; + }); + auto* slice2_out = pattern->NewNode(slice2_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_input("cos", "X"); + slice2->LinksFrom({ew_mul_out}).LinksTo({slice2_out}); + auto* cos = pattern->NewNode(cos_repr())->assert_is_op("cos"); + auto* cos_out = pattern->NewNode(cos_out_repr()) + ->assert_is_op_output("cos", "Out") + ->assert_is_op_nth_input("stack", "X", 1); + cos->LinksFrom({slice2_out}).LinksTo({cos_out}); + auto* stack = pattern->NewNode(stack_repr()) + ->assert_is_op("stack") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("axis") == 3; + }); + auto* stack_out = pattern->NewNode(stack_out_repr()) + ->assert_is_op_output("stack", "Y") + ->assert_is_op_input("flatten_contiguous_range", "X"); + stack->LinksFrom({sin_out, cos_out}).LinksTo({stack_out}); + + auto* flatten = + pattern->NewNode(flatten_repr()) + ->assert_is_op("flatten_contiguous_range") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("start_axis") == 2 && + op_desc->GetAttrIfExists("stop_axis") == 3; + }); + auto* flatten_out = + pattern->NewNode(flatten_out_repr()) + ->assert_is_op_output("flatten_contiguous_range", "Out") + ->AsOutput(); + flatten->LinksFrom({stack_out}).LinksTo({flatten_out}); +} + +} // namespace patterns + +class SinePosFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"sine_pos_fuse_pass"}; +}; + +void SinePosFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + GraphPatternDetector gpd; + patterns::SinePosXPUPattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle SinePosFusePass fuse"; + /* declare operator node's name */ + // declare operator node's name + GET_IR_NODE(ew_mul); + GET_IR_NODE(slice1); + GET_IR_NODE(slice2); + GET_IR_NODE(sin); + GET_IR_NODE(cos); + GET_IR_NODE(stack); + GET_IR_NODE(flatten); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(y); + GET_IR_NODE(ew_mul_out); + GET_IR_NODE(slice1_out); + GET_IR_NODE(slice2_out); + GET_IR_NODE(sin_out); + GET_IR_NODE(cos_out); + GET_IR_NODE(stack_out); + GET_IR_NODE(flatten_out); + auto* block = flatten->Op()->Block(); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + // Generate sine_pos_xpu fused op + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("sine_pos_xpu"); + // set attrs for fused op + fused_op_desc.SetInput("x", {x->Name()}); + fused_op_desc.SetInput("y", {y->Name()}); + + fused_op_desc.SetOutput("out", {flatten_out->Name()}); + // relink fused op + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + IR_NODE_LINK_TO(x, fused_op); + IR_NODE_LINK_TO(y, fused_op); + IR_NODE_LINK_TO(fused_op, flatten_out); + // delete useless node + std::unordered_set delete_nodes = {ew_mul, + ew_mul_out, + slice1, + slice1_out, + slice2, + slice2_out, + sin, + sin_out, + cos, + cos_out, + stack, + stack_out, + flatten}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(sine_pos_fuse_pass, paddle::framework::ir::SinePosFusePass); + +REGISTER_PASS_CAPABILITY(sine_pos_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "sin_pos_xpu", 0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 726e833fd515ac..0a0e6b591ef899 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -577,6 +577,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "yolo_box_xpu_fuse_pass", "fast_where_xpu_fuse_pass", "elementwise_mul_add_fuse_pass", + "sine_pos_fuse_pass", // "auto_mixed_precision_pass", "cast_mixed_precision_op_fuse_pass", "xpu_quantize_op_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index a31dee6a4c27d7..f1d253945139ed 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -431,6 +431,15 @@ func : self_dp_attention data_type : x +- op : sine_pos_xpu + args : (Tensor x, Tensor y) + output : Tensor(out) + infer_meta : + func : SinePosXPUInferMeta + kernel : + func : sine_pos_xpu + data_type : x + - op : skip_layernorm args : (Tensor x, Tensor y, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis) output : Tensor(out) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 31d16aaf5c0a38..1d388b2a47d5a5 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -680,7 +680,7 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"pool3d", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, - {"pow", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pow", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})}, {"prior_box", XPUKernelSet({phi::DataType::FLOAT32})}, @@ -707,7 +707,8 @@ XPUOpMap& get_kl2_ops() { {"reduce_max", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32, - phi::DataType::INT64})}, + phi::DataType::INT64, + phi::DataType::FLOAT16})}, {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32, @@ -1171,6 +1172,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"sine_pos_xpu", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, }; return s_xpu2_kernels; diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index f38ffe0f1fc9db..41329efaa86d53 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -3687,4 +3687,35 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, qkv_max->set_dtype(out_dtype); qkv_max->set_layout(q.layout()); } +void SinePosXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + auto x_dims = x.dims(); + auto x_dims_size = x_dims.size(); + PADDLE_ENFORCE_EQ( + x_dims_size, + 3, + phi::errors::InvalidArgument( + "x_dims_size should be 3, but received x_dims_size is %d", + x_dims_size)); + PADDLE_ENFORCE_EQ(x_dims[x_dims_size - 1], + 1, + phi::errors::InvalidArgument( + "x last dim size should be 1, but received is %d", + x_dims[x_dims_size - 1])); + auto y_dims = y.dims(); + auto y_dims_size = y_dims.size(); + PADDLE_ENFORCE_EQ( + y_dims_size, + 1, + phi::errors::InvalidArgument( + "x_dims_size should be 3, but received x_dims_size is %d", + y_dims_size)); + + phi::DDim out_dim = phi::make_ddim({x_dims[0], x_dims[1], y_dims[0]}); + + out->set_dims(out_dim); + out->set_dtype(x.dtype()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index ade4e38d457a61..e294e67aa1c951 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -834,4 +834,8 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, DataType out_dtype, MetaTensor* qkv, MetaTensor* qkv_max); +void SinePosXPUInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc new file mode 100644 index 00000000000000..0936f7be2f0ab7 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void SinePosXPUKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + + auto* x_data = reinterpret_cast(x.data()); + auto* y_data = reinterpret_cast(y.data()); + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + // fix precision of fp16 model + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + std::vector x_shape = phi::vectorize(x.dims()); + std::vector y_shape = phi::vectorize(y.dims()); + // yolo_box_coord only support fp32&&fp16 precision + int r = xpu::sine_pos_fusion( + /* baidu::xpu::api::Context* ctx */ ctx.x_context(), + /* const T* x */ x_data, + /* const T* y */ y_data, + /* T* out */ out_data, + /* int64_t batch */ x_shape[0], + /* int64_t n */ x_shape[1], + /* int64_t dim */ y_shape[0]); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sine_pos_xpu"); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(sine_pos_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::SinePosXPUKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc index cb9ff8f6bdb802..9e21dfd6ba30e0 100644 --- a/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc +++ b/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc @@ -49,4 +49,10 @@ void MaxRawKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(max_raw, XPU, ALL_LAYOUT, phi::MaxRawKernel, float, int) {} +PD_REGISTER_KERNEL(max_raw, + XPU, + ALL_LAYOUT, + phi::MaxRawKernel, + float, + int, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index 449be30474193a..4f82566ca45f19 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -195,15 +195,16 @@ void PowKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& factor, DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); - float pow_factor = factor.to(); - const T* x_data = x.data(); - T* y_data = out->data(); + T pow_factor = factor.to(); + const XPUType* x_data = reinterpret_cast(x.data()); + XPUType* y_data = reinterpret_cast(out->data()); auto xpu_context = dev_ctx.x_context(); // allocate temp memory for factor on xpu xpu::ctx_guard RAII_GUARD(xpu_context); - T* factor_data = RAII_GUARD.alloc_l3_or_gm(1); + XPUType* factor_data = RAII_GUARD.alloc_l3_or_gm(1); PADDLE_ENFORCE_NOT_NULL( factor_data, errors::External("XPU alloc_l3_or_gm returns nullptr")); memory_utils::Copy(dev_ctx.GetPlace(), @@ -653,6 +654,9 @@ PD_REGISTER_KERNEL(cos, phi::dtype::float16, phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL( + pow, XPU, ALL_LAYOUT, phi::PowKernel, float, phi::dtype::float16) {} + #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} @@ -660,7 +664,6 @@ PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) -PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel) diff --git a/paddle/phi/kernels/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_kernel.cc index 8842f86b0c9fb3..72ce736ddcad2a 100644 --- a/paddle/phi/kernels/xpu/reduce_max_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_max_kernel.cc @@ -57,4 +57,11 @@ void MaxKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(max, XPU, ALL_LAYOUT, phi::MaxKernel, float, int, int64_t) {} +PD_REGISTER_KERNEL(max, + XPU, + ALL_LAYOUT, + phi::MaxKernel, + float, + int, + int64_t, + phi::dtype::float16) {} diff --git a/test/ir/inference/test_xpu_sine_pos_pass.py b/test/ir/inference/test_xpu_sine_pos_pass.py new file mode 100644 index 00000000000000..8d8abbfdfb1843 --- /dev/null +++ b/test/ir/inference/test_xpu_sine_pos_pass.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestSinePosXPUFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["sine_pos_xpu"], (1e-3, 1e-3) + + def sample_program_config(self, draw): + x_shape = draw( + st.lists( + st.integers(min_value=1, max_value=10), min_size=3, max_size=3 + ) + ) + x_shape[1] = draw(st.integers(min_value=100, max_value=512)) + x_shape[2] = draw(st.integers(min_value=1, max_value=1)) + y_shape = draw( + st.lists( + st.integers(min_value=128, max_value=128), + min_size=1, + max_size=1, + ) + ) + + def generate_data(shape): + return np.random.random(shape).astype(np.float32) + + # Here we will compose a program + # Still has some risks that the program is invalid or cause bug while running + # Use function `is_program_valid` to filter the invalid programs before running + # Use function `add_skip_pass_case` to ignore the programs even if they cause bug while runing + mul_op = OpConfig( + "elementwise_mul", + inputs={"X": ["x"], "Y": ["y"]}, + outputs={"Out": ["mul_out"]}, + axis=-1, + ) + slice1_op = OpConfig( + "strided_slice", + inputs={"Input": ["mul_out"]}, + outputs={"Out": ["slice1_out"]}, + axes=[2], + starts=[0], + strides=[2], + ends=[128], + infer_flags=[1], + ) + sin_op = OpConfig( + "sin", + inputs={"X": ["slice1_out"]}, + outputs={"Out": ["sin_out"]}, + ) + slice2_op = OpConfig( + "strided_slice", + inputs={"Input": ["mul_out"]}, + outputs={"Out": ["slice2_out"]}, + axes=[2], + starts=[1], + strides=[2], + ends=[128], + infer_flags=[1], + ) + cos_op = OpConfig( + "cos", + inputs={"X": ["slice2_out"]}, + outputs={"Out": ["cos_out"]}, + ) + stack_op = OpConfig( + "stack", + inputs={"X": ["sin_out", "cos_out"]}, + outputs={"Y": ["stack_out"]}, + axis=3, + ) + flatten_op = OpConfig( + "flatten_contiguous_range", + inputs={"X": ["stack_out"]}, + outputs={"Out": ["flatten_out"]}, + start_axis=2, + stop_axis=3, + ) + + ops = [ + mul_op, + slice1_op, + slice2_op, + sin_op, + cos_op, + stack_op, + flatten_op, + ] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "x": TensorConfig(data_gen=partial(generate_data, x_shape)), + "y": TensorConfig(data_gen=partial(generate_data, y_shape)), + }, + weights={}, + outputs=ops[-1].outputs["Out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["sine_pos_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main() From bd29981ef25da15e158dfc71ee99389fe06a6166 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Tue, 2 Jan 2024 21:14:07 +0800 Subject: [PATCH 075/142] add split with variable in factors and rewrite vectorize,unroll,bind error handling mechanism (#60449) --- paddle/cinn/ir/schedule/impl/for_type.cc | 26 +++++-- paddle/cinn/ir/schedule/impl/ir_schedule.h | 2 + .../ir/schedule/impl/loop_transformation.cc | 67 ++++++++++++++++++- paddle/cinn/ir/schedule/ir_schedule.cc | 15 +++-- paddle/cinn/ir/schedule/schedule_base.h | 2 + 5 files changed, 102 insertions(+), 10 deletions(-) diff --git a/paddle/cinn/ir/schedule/impl/for_type.cc b/paddle/cinn/ir/schedule/impl/for_type.cc index 9c3b79cfe5f180..2060ef580a33cd 100644 --- a/paddle/cinn/ir/schedule/impl/for_type.cc +++ b/paddle/cinn/ir/schedule/impl/for_type.cc @@ -63,20 +63,37 @@ void DyScheduleImpl::Parallel(const Expr& loop) { } void DyScheduleImpl::Vectorize(const Expr& loop, int factor) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Vectorize"; + std::ostringstream os; CHECK_GT(factor, 0) << "vectorize factor should be more than 0"; - CHECK(loop.As()->extent.is_constant()) - << "The loop to be vectorized should be constant!\n"; + if (factor <= 0) { + os << "vectorize factor should be more than 0\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!loop.As()->extent.is_constant()) { + os << "The loop to be vectorized should be constant!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } MutateForType(loop, ForType::Vectorized, factor); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::Unroll(const Expr& loop) { - CHECK(loop.As()->extent.is_constant()) - << "The loop to be unrolled should be constant!\n"; + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Unroll"; + std::ostringstream os; + if (!loop.As()->extent.is_constant()) { + os << "The loop to be unrolled should be constant!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } MutateForType(loop, ForType::Unrolled); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) { #ifdef CINN_WITH_CUDA + CINN_IR_SCHEDULE_BEGIN(); std::string primitive = "Bind"; std::ostringstream os; @@ -117,6 +134,7 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) { } MutateForType(loop, ForType::GPUThread, offset); } + CINN_IR_SCHEDULE_END(this->err_msg_level_); #endif } } // namespace ir diff --git a/paddle/cinn/ir/schedule/impl/ir_schedule.h b/paddle/cinn/ir/schedule/impl/ir_schedule.h index 22d78ac38b73f6..3fe35854cb4aa0 100644 --- a/paddle/cinn/ir/schedule/impl/ir_schedule.h +++ b/paddle/cinn/ir/schedule/impl/ir_schedule.h @@ -49,6 +49,7 @@ class DyScheduleImpl : public ScheduleBase { std::vector GetChildBlocks(const Expr& expr) const; Expr GetBlock(const std::string& block_name) const; std::vector Split(const Expr& loop, const std::vector& factors); + std::vector Split(const Expr& loop, const std::vector& factors); std::vector SamplePerfectTile( utils::LinearRandomEngine::StateType* rand_seed, const Expr& loop, @@ -122,6 +123,7 @@ class StScheduleImpl : public ScheduleBase { std::vector GetChildBlocks(const Expr& expr) const; Expr GetBlock(const std::string& block_name) const; std::vector Split(const Expr& loop, const std::vector& factors); + std::vector Split(const Expr& loop, const std::vector& factors); std::vector SamplePerfectTile( utils::LinearRandomEngine::StateType* rand_seed, const Expr& loop, diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index c3a3ad448f5362..6f0c1b4f5ae824 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/cinn/common/macros.h" #include "paddle/cinn/ir/schedule/impl/ir_schedule.h" +#include "paddle/cinn/common/integer_set.h" +#include "paddle/cinn/common/macros.h" + /** \brief A macro that guards the beginning of each implementation of schedule */ #define CINN_IR_SCHEDULE_BEGIN() try { @@ -157,6 +159,63 @@ std::vector DyScheduleImpl::Split(const Expr& loop, return splited_loops; } +// TODO(@LiuYang): now -1 can't exsit in factors, +std::vector DyScheduleImpl::Split(const Expr& loop, + const std::vector& factors) { + CHECK(loop.As()) + << "Expr param of Split must be For node! Please check."; + auto* for_node = loop.As(); + CHECK(common::is_zero(for_node->min)) + << "The For node must start with 0! Please check."; + CHECK(!factors.empty()) + << "The factors param of Split should not be empty! Please check."; + CHECK(!loop.As()->extent.is_constant()) + << "Can't Split a loop with constant extent but with variable in " + "factors!"; + Expr tot_extent = for_node->extent; + + VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " + << tot_extent << ") to (" << cinn::utils::Join(factors, ", ") + << ") at loop:\n" + << loop; + + std::vector process_factors(factors); + Expr prod_size(1); + for (auto factor : factors) prod_size = prod_size * Expr(factor); + cinn::common::SymbolicExprAnalyzer analyzer({}); + CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false)) + << "Product of factors can't be proved to be equal to the extent of " + "current for loop!"; + + std::vector new_loop_vars; + Expr substitute_value(0); + for (int i = 0; i < process_factors.size(); ++i) { + Var temp_var(common::UniqName(for_node->loop_var->name)); + substitute_value = Expr(temp_var) + substitute_value * process_factors[i]; + new_loop_vars.push_back(temp_var); + } + substitute_value = cinn::common::AutoSimplify(substitute_value); + Expr new_node = ir::ir_utils::IRCopy(for_node->body); + ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value}); + std::vector splited_loops; + splited_loops.resize(process_factors.size()); + + for (int i = process_factors.size() - 1; i >= 0; i--) { + if (!new_node.As()) new_node = Block::Make({new_node}); + new_node = For::Make(new_loop_vars[i], + Expr(0), + process_factors[i], + for_node->for_type(), + for_node->device_api, + new_node); + splited_loops[i] = new_node; + } + + this->Replace(loop, new_node); + VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0); + return splited_loops; +} + Expr DyScheduleImpl::Fuse(const std::vector& loops) { VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); std::vector for_nodes; @@ -370,6 +429,12 @@ std::vector StScheduleImpl::Split(const Expr& loop, return splited_loops; } +std::vector StScheduleImpl::Split(const Expr& loop, + const std::vector& factors) { + CHECK(false) << "Static shape schedule don't support Split with some " + "variables in factors"; +} + Expr StScheduleImpl::Fuse(const std::vector& loops) { VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); std::vector for_nodes; diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index 153d050afa3fc0..fb151051f0b67f 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -405,11 +405,16 @@ std::vector IRSchedule::Split(const std::string& block_name, std::vector IRSchedule::Split(const Expr& loop, const std::vector& factors) { std::vector int_factors; - std::transform(factors.begin(), - factors.end(), - std::back_inserter(int_factors), - [](Expr x) { return x.as_int32(); }); - auto results = impl_->Split(loop, int_factors); + std::vector results; + std::for_each(factors.begin(), factors.end(), [&int_factors](const Expr& e) { + if (e.is_constant()) int_factors.push_back(e.as_int32()); + }); + if (int_factors.size() == factors.size()) { + results = impl_->Split(loop, int_factors); + } else { + results = impl_->Split(loop, factors); + } + trace_.Append(ScheduleDesc::Step( "Split", {{"loop", std::vector({loop})}, {"factors", factors}}, diff --git a/paddle/cinn/ir/schedule/schedule_base.h b/paddle/cinn/ir/schedule/schedule_base.h index e94cc8d0bf5d17..6ce5caaeaad12c 100644 --- a/paddle/cinn/ir/schedule/schedule_base.h +++ b/paddle/cinn/ir/schedule/schedule_base.h @@ -97,6 +97,8 @@ class ScheduleBase { virtual Expr GetBlock(const std::string& block_name) const = 0; virtual std::vector Split(const Expr& loop, const std::vector& factors) = 0; + virtual std::vector Split(const Expr& loop, + const std::vector& factors) = 0; virtual std::vector SamplePerfectTile( utils::LinearRandomEngine::StateType* rand_seed, const Expr& loop, From 8b2b95305d14d57845e9b403da0ec4bf29e45e27 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Tue, 2 Jan 2024 22:59:21 +0800 Subject: [PATCH 076/142] [CodeStyle] Fix regression of Ruff in sot (#60483) --- python/paddle/jit/sot/opcode_translator/__init__.py | 2 +- python/paddle/jit/sot/utils/__init__.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/__init__.py b/python/paddle/jit/sot/opcode_translator/__init__.py index dec41c8bba1721..6a86c46322f652 100644 --- a/python/paddle/jit/sot/opcode_translator/__init__.py +++ b/python/paddle/jit/sot/opcode_translator/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .skip_files import setup_skip_files from .eval_frame_callback import eval_frame_callback # noqa: F401 +from .skip_files import setup_skip_files setup_skip_files() diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 16e2cd5b1afe52..eb1c9ae58a7e57 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .call_ast_utils import get_static_function, try_ast_func # noqa: F401 from .envs import ( # noqa: F401 ENV_CLEAN_CODE, ENV_COST_MODEL, ENV_MIN_GRAPH_SIZE, ENV_SHOW_TRACKERS, ENV_SOT_LOG_LEVEL, - ENV_STRICT_MODE, ENV_SOT_WITH_CONTROL_FLOW, + ENV_STRICT_MODE, cost_model_guard, min_graph_size_guard, strict_mode_guard, @@ -51,8 +52,8 @@ count_if, current_tmp_name_records, execute_time, - flatten_extend, flatten, + flatten_extend, get_unbound_method, hashable, in_paddle_module, @@ -72,4 +73,3 @@ no_eval_frame, tmp_name_guard, ) -from .call_ast_utils import get_static_function, try_ast_func From 48b727960dbe2db2913ec9b82e7a169abedefe71 Mon Sep 17 00:00:00 2001 From: lzydev Date: Wed, 3 Jan 2024 10:29:43 +0800 Subject: [PATCH 077/142] support cast op from FP32 to low precision (#60385) --- .../distributed/passes/auto_parallel_amp.py | 5 +++++ .../distributed/passes/auto_parallel_fp16.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index ac533db098c619..a7782b6d8d130b 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -232,6 +232,11 @@ def build_state(self): return is_train def _mark_black_white_ops(self, op, ops, block): + # deal auto_cast info + if not op.amp_options.enable: + self._op_fp16_dict[op.desc.original_id()] = False + return + # ernie inference trick if op.type == "assign" and "array_" in op.input_arg_names[0]: self._op_fp16_dict[op.desc.original_id()] = False diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 2985b4da290f40..92259dee3ae057 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -209,6 +209,9 @@ def _build_state(self): for block in self.program.blocks: self.resolute_tensor_dtype(block) + for block in self.program.blocks: + self.resolute_cast_op(block) + # insert cast ops for block in self.program.blocks: self.cast_block(block) @@ -296,6 +299,19 @@ def set_var_to_fp16(self, var_name, block): if var.dtype == core.VarDesc.VarType.FP32: var.desc.set_dtype(__target_dtype__) + def resolute_cast_op(self, block): + """ + Deal the "cast_op" from "FP32" to "FP16" or "BF16" in the model. + """ + for op in block.ops: + if op.type == "cast": + in_name = op.input('X')[0] + out_name = op.output('Out')[0] + in_var = block._find_var_recursive(in_name) + out_var = block._find_var_recursive(out_name) + op._set_attr("in_dtype", in_var.dtype) + op._set_attr("out_dtype", out_var.dtype) + def resolute_tensor_dtype(self, block): for op in block.ops: # 'amp_options' flag has highest priority From 40fa8bc14aa5f9c7839aa988883c0cf984ab9e46 Mon Sep 17 00:00:00 2001 From: tianshuo78520a <707759223@qq.com> Date: Wed, 3 Jan 2024 10:41:37 +0800 Subject: [PATCH 078/142] test=document_fix (#60399) --- tools/dockerfile/ubuntu20_dev.sh | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tools/dockerfile/ubuntu20_dev.sh b/tools/dockerfile/ubuntu20_dev.sh index 85c46a1416e221..6078638035e6c3 100755 --- a/tools/dockerfile/ubuntu20_dev.sh +++ b/tools/dockerfile/ubuntu20_dev.sh @@ -15,7 +15,20 @@ # limitations under the License. function base_image(){ - if [[ ${ref_CUDA_MAJOR} == "11.2" ]];then + if [[ ${ref_CUDA_MAJOR} == "cpu" ]];then + dockerfile_name="Dockerfile-cpu" + sed "s##ubuntu:20.04#g" ./Dockerfile.ubuntu20 >${dockerfile_name} + sed -i "s###g" ${dockerfile_name} + sed -i "s#WITH_GPU:-ON#WITH_GPU:-OFF#g" ${dockerfile_name} + sed -i "s#RUN apt-key del 7fa2af80##g" ${dockerfile_name} + sed -i 's#RUN rm /etc/apt/sources.list.d/\*##g' ${dockerfile_name} + sed -i "s#RUN apt-key adv --fetch-keys https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub##g" ${dockerfile_name} + sed -i 's##RUN apt-get install -y gcc g++ make#g' ${dockerfile_name} + sed -i "s##WORKDIR /usr/bin ENV PATH=/usr/local/gcc-8.2/bin:\$PATH #g" ${dockerfile_name} + sed -i 's#RUN bash /build_scripts/install_trt.sh##g' ${dockerfile_name} + sed -i 's#RUN bash /build_scripts/install_cudnn.sh cudnn841##g' ${dockerfile_name} + sed -i 's#ENV CUDNN_VERSION=8.4.1##g' ${dockerfile_name} + elif [[ ${ref_CUDA_MAJOR} == "11.2" ]];then dockerfile_name="Dockerfile-112" sed "s##nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04#g" ./Dockerfile.ubuntu20 >${dockerfile_name} sed -i "s##ENV LD_LIBRARY_PATH=/usr/local/cuda-11.2/targets/x86_64-linux/lib:\$LD_LIBRARY_PATH #g" ${dockerfile_name} @@ -72,6 +85,8 @@ function base_image(){ } +export ref_CUDA_MAJOR=cpu +base_image export ref_CUDA_MAJOR=11.2 base_image export ref_CUDA_MAJOR=11.6 From 619ca112b8c0f5989ec1813c60a746e433048583 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Wed, 3 Jan 2024 10:48:17 +0800 Subject: [PATCH 079/142] [XPU] refine flash attention ut (#60474) * [XPU] refine flash attention ut * refine tolerance --- test/xpu/test_flash_attention_op_xpu.py | 113 +++++++++++++++--------- 1 file changed, 69 insertions(+), 44 deletions(-) diff --git a/test/xpu/test_flash_attention_op_xpu.py b/test/xpu/test_flash_attention_op_xpu.py index 7a8c2805561d13..ecb407c42ab20a 100644 --- a/test/xpu/test_flash_attention_op_xpu.py +++ b/test/xpu/test_flash_attention_op_xpu.py @@ -72,43 +72,45 @@ class TestFlashAttentionAPI(unittest.TestCase): def setUp(self): self.place = paddle.XPUPlace(0) self.shape = (1, 128, 2, 32) - self.dtype = 'float32' self.dropout = 0.0 self.causal = True self.return_softmax = False - self.rtol = 1e-3 - self.atol = 1e-3 def test_all(self): + self.run_case(dtype="float32", tolerance=5e-4, tolerance_dv=5e-4) + self.run_case(dtype="float16", tolerance=5e-4, tolerance_dv=1e-3) + self.run_case(dtype="bfloat16", tolerance=5e-3, tolerance_dv=1e-2) + + def run_case(self, dtype, tolerance, tolerance_dv): # TODO(houj04) remove debug codes after correctness check - print(f"Test case shape {self.shape} dtype {self.dtype}") + print(f"Test case shape {self.shape} dtype {dtype}") # test dynamic paddle.disable_static() np.random.seed(2023) - query = np.random.random(self.shape) - key = np.random.random(self.shape) - value = np.random.random(self.shape) + query = np.random.uniform(-1.0, 1.0, self.shape) + key = np.random.uniform(-1.0, 1.0, self.shape) + value = np.random.uniform(-1.0, 1.0, self.shape) q = paddle.to_tensor( - query, place=self.place, dtype=self.dtype, stop_gradient=False + query, place=self.place, dtype=dtype, stop_gradient=False ) k = paddle.to_tensor( - key, place=self.place, dtype=self.dtype, stop_gradient=False + key, place=self.place, dtype=dtype, stop_gradient=False ) v = paddle.to_tensor( - value, place=self.place, dtype=self.dtype, stop_gradient=False + value, place=self.place, dtype=dtype, stop_gradient=False ) q_ = paddle.to_tensor( - query, place=self.place, dtype=self.dtype, stop_gradient=False + query, place=self.place, dtype=dtype, stop_gradient=False ) k_ = paddle.to_tensor( - key, place=self.place, dtype=self.dtype, stop_gradient=False + key, place=self.place, dtype=dtype, stop_gradient=False ) v_ = paddle.to_tensor( - value, place=self.place, dtype=self.dtype, stop_gradient=False + value, place=self.place, dtype=dtype, stop_gradient=False ) out, _ = flash_attention( @@ -125,8 +127,17 @@ def test_all(self): float_out_ = paddle.cast(out_, "float32") np.testing.assert_allclose( - float_out, float_out_, rtol=self.rtol, atol=self.atol + float_out, float_out_, rtol=tolerance, atol=tolerance + ) + # TODO(houj04) remove debug codes after correctness check + max_diff_forward = np.max( + np.abs(float_out.numpy() - float_out_.numpy()) ) + mean_diff_forward = np.mean( + np.abs(float_out.numpy() - float_out_.numpy()) + ) + print("max_diff_forward:", max_diff_forward) + print("mean_diff_forward:", mean_diff_forward) # backward shape self.assertEqual(q.grad.shape, q.shape) @@ -173,40 +184,54 @@ def test_all(self): print("mean_diff_v_grad:", mean_diff_v_grad) np.testing.assert_allclose( - float_q_grad, float_q_grad_, rtol=self.rtol, atol=self.atol + float_q_grad, float_q_grad_, rtol=tolerance, atol=tolerance ) np.testing.assert_allclose( - float_k_grad, float_k_grad_, rtol=self.rtol, atol=self.atol + float_k_grad, float_k_grad_, rtol=tolerance, atol=tolerance ) np.testing.assert_allclose( - float_v_grad, float_v_grad_, rtol=self.rtol, atol=self.atol - ) - - -class TestFlashAttentionAPITestFP16(TestFlashAttentionAPI): - def setUp(self): - self.place = paddle.XPUPlace(0) - self.shape = (1, 128, 2, 32) - self.dtype = 'float16' - self.dropout = 0.0 - self.causal = True - self.return_softmax = False - # TODO(houj04) fix ut threshold after correctness check - self.rtol = 5e-3 - self.atol = 5e-3 - - -class TestFlashAttentionAPITestBF16(TestFlashAttentionAPI): - def setUp(self): - self.place = paddle.XPUPlace(0) - self.shape = (1, 128, 2, 32) - self.dtype = 'bfloat16' - self.dropout = 0.0 - self.causal = True - self.return_softmax = False - # TODO(houj04) fix ut threshold after correctness check - self.rtol = 1e-1 - self.atol = 1e-1 + float_v_grad, float_v_grad_, rtol=tolerance_dv, atol=tolerance_dv + ) + + +# TODO(houj04) un-comment following DEBUG cases after correctness check +# class TestFlashAttentionAPITest1(TestFlashAttentionAPI): +# def setUp(self): +# self.place = paddle.XPUPlace(0) +# self.shape = (2, 128, 1, 32) +# self.dropout = 0.0 +# self.causal = True +# self.return_softmax = False + + +# TODO(houj04) un-comment following REAL cases after correctness check +# class TestFlashAttentionAPITestEB(TestFlashAttentionAPI): +# def setUp(self): +# self.place = paddle.XPUPlace(0) +# self.shape = (4, 4096, 4, 128) +# self.dropout = 0.0 +# self.causal = True +# self.return_softmax = False + + +# TODO(houj04) un-comment following REAL cases after correctness check +# class TestFlashAttentionAPITestLlama7B(TestFlashAttentionAPI): +# def setUp(self): +# self.place = paddle.XPUPlace(0) +# self.shape = (2, 2048, 16, 128) +# self.dropout = 0.0 +# self.causal = True +# self.return_softmax = False + + +# TODO(houj04) un-comment following REAL cases after correctness check +# class TestFlashAttentionAPITestLlama65B(TestFlashAttentionAPI): +# def setUp(self): +# self.place = paddle.XPUPlace(0) +# self.shape = (2, 8192, 8, 128) +# self.dropout = 0.0 +# self.causal = True +# self.return_softmax = False if __name__ == '__main__': From c7a3f63d19ef81cad2aeffc6d935fd78ec971cfc Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 3 Jan 2024 11:22:25 +0800 Subject: [PATCH 080/142] [Inference] support collect shape in sub block (#60451) * support collect shape in sub block * udpate * udpate --- paddle/fluid/framework/naive_executor.cc | 3 +- .../interpreter/interpreter_util.cc | 6 +-- .../new_executor/program_interpreter.cc | 6 +-- .../fluid/inference/api/analysis_predictor.cc | 52 ++++++++++--------- .../controlflow/conditional_block_op.cc | 2 + 5 files changed, 33 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 3bfacc950325c1..925c04f658b0a7 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -101,8 +101,9 @@ void NaiveExecutor::Run() { func(op.get(), scope_); } - if (op->Type() == "while") { + if (op->Type() == "while" || op->Type() == "conditional_block") { op->SetOutputHooks(output_hookfuncs_); + op->SetInputHooks(input_hookfuncs_); } #ifdef PADDLE_WITH_NVTX diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 0a111922d4409b..9eb4559295649f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -635,16 +635,12 @@ void BuildOpFuncList(const platform::Place& place, hook(op, local_scope); } - if (op->Type() == "while") { + if (op->Type() == "while" || op->Type() == "conditional_block") { op->SetInputHooks(input_hookfuncs); op->SetOutputHooks(output_hookfuncs); auto runtime_attrs = op->RuntimeAttrs(); runtime_attrs.insert(std::make_pair("used_for_inference", true)); op->SetRuntimeAttributeMap(runtime_attrs); - } else if (op->Type() == "conditional_block") { - auto runtime_attrs = op->RuntimeAttrs(); - runtime_attrs.insert(std::make_pair("used_for_inference", true)); - op->SetRuntimeAttributeMap(runtime_attrs); } } diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 9434e4fd81af60..bc41742437ff9c 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -915,16 +915,12 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) { hook(op, local_scope); } - if (op->Type() == "while") { + if (op->Type() == "while" || op->Type() == "conditional_block") { op->SetInputHooks(input_hookfuncs_); op->SetOutputHooks(output_hookfuncs_); auto runtime_attrs = op->RuntimeAttrs(); runtime_attrs.insert(std::make_pair("used_for_inference", true)); op->SetRuntimeAttributeMap(runtime_attrs); - } else if (op->Type() == "conditional_block") { - auto runtime_attrs = op->RuntimeAttrs(); - runtime_attrs.insert(std::make_pair("used_for_inference", true)); - op->SetRuntimeAttributeMap(runtime_attrs); } } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 4b52ceb58ff777..e042f358c9874c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2391,9 +2391,15 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) { #endif void AnalysisPredictor::HookCollectShapeRangeInfo() { + if (config_.new_executor_enabled()) { + LOG_FIRST_N(WARNING, 1) + << "When collecting shapes, it is recommended to run multiple loops to " + "obtain more accurate shape information."; + } + auto hook = [&](const std::string &op_type, const std::string &input_name, - const paddle::Tensor &var) -> void { + const paddle::Tensor &input_tensor) -> void { paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); if (config_.use_gpu()) { @@ -2408,26 +2414,22 @@ void AnalysisPredictor::HookCollectShapeRangeInfo() { #endif } - auto *new_var = sub_scope_->GetVar(input_name); - if (!new_var) return; - if (!new_var->IsType()) { - return; - } - auto tensor = new_var->Get(); - if (!tensor.initialized()) return; - framework::DDim dim = tensor.dims(); + if (!input_tensor.is_dense_tensor()) return; + auto tensor = + std::dynamic_pointer_cast(input_tensor.impl()).get(); + framework::DDim dim = tensor->dims(); std::vector shape(dim.size()); for (int i = 0; i < static_cast(shape.size()); ++i) shape[i] = static_cast(dim[i]); if (!shape.empty()) { shape_info_[input_name].emplace_back(shape); - } else if (tensor.numel() > 0) { + } else if (tensor->numel() > 0) { // This must be a zero dimension tensor. - PADDLE_ENFORCE_EQ(tensor.numel(), + PADDLE_ENFORCE_EQ(tensor->numel(), 1UL, platform::errors::PreconditionNotMet( "This tensor must have one element, but got %ld.", - tensor.numel())); + tensor->numel())); std::vector zero_shape(1, 1); shape_info_[input_name].emplace_back(zero_shape); } @@ -2437,19 +2439,19 @@ void AnalysisPredictor::HookCollectShapeRangeInfo() { // assumption that all shape tensors in the model have numbers <= 8. // This is a simple method to identify all shape tensors with some // mistakes, but it doesn't matter. - auto is_shape_tensor = tensor.numel() <= 8 && tensor.numel() >= 1; - if ((tensor.dtype() == phi::DataType::INT32 || - tensor.dtype() == phi::DataType::INT64) && + auto is_shape_tensor = tensor->numel() <= 8 && tensor->numel() >= 1; + if ((tensor->dtype() == phi::DataType::INT32 || + tensor->dtype() == phi::DataType::INT64) && is_shape_tensor) { - std::vector int32_host(tensor.numel()); + std::vector int32_host(tensor->numel()); - if (platform::is_cpu_place(tensor.place())) { - auto &int32_tensor = tensor; - if (tensor.dtype() == phi::DataType::INT64) { + if (platform::is_cpu_place(tensor->place())) { + auto &int32_tensor = *tensor; + if (tensor->dtype() == phi::DataType::INT64) { auto *cpu_ctx = pool.Get(platform::CPUPlace()); int32_tensor = phi::funcs::TransDataType( reinterpret_cast(*cpu_ctx), - tensor, + *tensor, DataType::INT32); } paddle::memory::Copy(platform::CPUPlace(), @@ -2457,14 +2459,14 @@ void AnalysisPredictor::HookCollectShapeRangeInfo() { platform::CPUPlace(), int32_tensor.data(), int32_tensor.numel() * sizeof(int)); - } else if (platform::is_gpu_place(tensor.place())) { + } else if (platform::is_gpu_place(tensor->place())) { #if defined(PADDLE_WITH_CUDA) - auto *dev_ctx = pool.Get(tensor.place()); - auto &int32_tensor = tensor; - if (tensor.dtype() == phi::DataType::INT64) { + auto *dev_ctx = pool.Get(tensor->place()); + auto &int32_tensor = *tensor; + if (tensor->dtype() == phi::DataType::INT64) { int32_tensor = phi::funcs::TransDataType( reinterpret_cast(*dev_ctx), - tensor, + *tensor, DataType::INT32); } paddle::memory::Copy(platform::CPUPlace(), diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 58e0114045db4b..cdeb2319a280fc 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -116,6 +116,8 @@ class ConditionalBlockOp : public ConditionalOp { #endif core_.reset(new InterpreterCore( dev_place, *block, &cur_scope, execution_config)); + core_->SetOutputHooks(output_hookfuncs_); + core_->SetInputHooks(input_hookfuncs_); VLOG(10) << "[interpreterCore] created:" << core_; } else { BuildScopeForControlFlowOp(*core_, *block, &cur_scope); From ddd29d2e3423771ff29c2be6fcb501d9f450b7a1 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Wed, 3 Jan 2024 12:21:41 +0800 Subject: [PATCH 081/142] fix process mesh incorrect set in converter (#60504) --- paddle/fluid/pybind/eager_custom_python_api.h | 6 +- paddle/fluid/pybind/eager_utils.cc | 63 +++++++------------ 2 files changed, 27 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/pybind/eager_custom_python_api.h b/paddle/fluid/pybind/eager_custom_python_api.h index 8552f1e7208b8c..e829df0f4eb0cc 100644 --- a/paddle/fluid/pybind/eager_custom_python_api.h +++ b/paddle/fluid/pybind/eager_custom_python_api.h @@ -25,9 +25,9 @@ static PyObject *eager_api_linear(PyObject *self, PyObject *kwargs) { PyThreadState *tstate = nullptr; try { - auto x = GetTensorFromArgs("linear", "X", args, 0, false); - auto weight = GetTensorFromArgs("linear", "weight", args, 1, false); - auto bias = GetTensorFromArgs("linear", "Bias", args, 2, true); + auto &x = GetTensorFromArgs("linear", "X", args, 0, false); + auto &weight = GetTensorFromArgs("linear", "weight", args, 1, false); + auto &bias = GetTensorFromArgs("linear", "Bias", args, 2, true); tstate = PyEval_SaveThread(); diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index e23217feacb650..14824ed13907c2 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -120,8 +120,7 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) { if (x->is_dist_tensor()) { PADDLE_ENFORCE_EQ( std::dynamic_pointer_cast(x->impl()) - ->dist_attr() - .process_mesh(), + ->process_mesh(), *mesh, platform::errors::InvalidArgument( "Input %s has different mesh. However all inputs should " @@ -136,16 +135,18 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) { "Failed to convert input %s impl to phi::distributed::DistTensor " "as it's not phi::DenseTensor.", x->name())); - phi::distributed::TensorDistAttr dist_attr( - common::vectorize(x->impl()->dims())); - dist_attr.set_process_mesh(*mesh); + phi::distributed::Placements placements; + for (int64_t i = 0; i < mesh->ndim(); ++i) { + placements.emplace_back(std::make_shared()); + } + auto dense_t = std::static_pointer_cast(x->impl()); // auto parallel in dygraph doesn't support strided kernel. if (!dense_t->meta().is_contiguous()) { *dense_t = paddle::experimental::Trans2Contiguous(*dense_t); } - x->set_impl( - std::make_shared(dense_t, dist_attr)); + x->set_impl(std::make_shared( + dense_t, *mesh, placements)); } } @@ -393,8 +394,7 @@ std::vector CastPyArg2VectorOfTensor( local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -433,8 +433,7 @@ std::vector CastPyArg2VectorOfTensor( local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1431,8 +1430,7 @@ std::vector GetTensorListFromArgs( local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1465,8 +1463,7 @@ std::vector GetTensorListFromArgs( local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1539,8 +1536,7 @@ paddle::optional> GetOptionalTensorListFromArgs( local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1573,8 +1569,7 @@ paddle::optional> GetOptionalTensorListFromArgs( local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1679,8 +1674,7 @@ std::vector GetTensorPtrListFromArgs( local_mesh = &(std::dynamic_pointer_cast( tensor->impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1712,8 +1706,7 @@ std::vector GetTensorPtrListFromArgs( local_mesh = &(std::dynamic_pointer_cast( tensor->impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1760,8 +1753,7 @@ std::vector GetTensorPtrListFromPyObject(PyObject* obj) { local_mesh = &(std::dynamic_pointer_cast( tensor->impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1789,8 +1781,7 @@ std::vector GetTensorPtrListFromPyObject(PyObject* obj) { local_mesh = &(std::dynamic_pointer_cast( tensor->impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1832,8 +1823,7 @@ std::vector GetTensorListFromPyObject(PyObject* obj, local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -1871,8 +1861,7 @@ std::vector GetTensorListFromPyObject(PyObject* obj, local_mesh = &(std::dynamic_pointer_cast( tensor.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); mesh_start_index = i; } } @@ -2687,8 +2676,7 @@ PyMODINIT_FUNC PyInit__static_op_arg_pre_cast_hook() { void DistTensorTypeParser::operator()(const Tensor& x) { if (x.defined() && x.is_dist_tensor()) { *mesh = &(std::dynamic_pointer_cast(x.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); result = true; } } @@ -2698,8 +2686,7 @@ void DistTensorTypeParser::operator()(const paddle::optional& x) { if (x.get_ptr()->defined() && x.get_ptr()->is_dist_tensor()) { *mesh = &(std::dynamic_pointer_cast( x.get_ptr()->impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); result = true; } } @@ -2711,8 +2698,7 @@ void DistTensorTypeParser::operator()(const std::vector& x) { if (t.defined() && t.is_dist_tensor()) { *mesh = &(std::dynamic_pointer_cast(t.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); result = true; break; } @@ -2728,8 +2714,7 @@ void DistTensorTypeParser::operator()( if (t.defined() && t.is_dist_tensor()) { *mesh = &( std::dynamic_pointer_cast(t.impl()) - ->dist_attr() - .process_mesh()); + ->process_mesh()); result = true; break; } From be8bc1e53dd5430a94a8e91c625927be70953900 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Wed, 3 Jan 2024 13:52:57 +0800 Subject: [PATCH 082/142] =?UTF-8?q?=E3=80=90CMake=20opt=20No.13=E3=80=91Re?= =?UTF-8?q?move=20CINN=20DEPS=20in=20test/cpp/pir/shape=5Fdialect/CMakeLis?= =?UTF-8?q?ts.txt=09=20(#60517)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update CMakeLists.txt * Apply suggestions from code review * Apply suggestions from code review * Update CMakeLists.txt * Update CMakeLists.txt --- test/cpp/pir/shape_dialect/CMakeLists.txt | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 19e1f55dad7638..fc890d2f6b8c58 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -5,15 +5,7 @@ paddle_test(shape_struct_test SRCS shape_struct_test.cc) paddle_test(symbol_dim_expr_test SRCS symbol_dim_expr_test.cc) if(WITH_CINN) - paddle_test( - shape_optimization_test - SRCS - shape_optimization_test.cc - DEPS - gtest - op_dialect_vjp - pir - pir_transforms) + paddle_test(shape_optimization_test SRCS shape_optimization_test.cc) set_tests_properties( shape_optimization_test PROPERTIES ENVIRONMENT "FLAGS_enable_pir_in_executor=true") From deb539766f6e30aeaee2921e608294b629b0500c Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Wed, 3 Jan 2024 14:00:36 +0800 Subject: [PATCH 083/142] =?UTF-8?q?=E3=80=90pir=E3=80=91=20add=20tensorarr?= =?UTF-8?q?ay=20op=20createarrylike,=20add=5Fn=20(#60460)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * optimize backward * [PIR] add vjp interface for while op * [PIR] fix ci error. * modify while stopgradient * merge * modify while grad bug * modify while grad op * modify * increment vp * [PIR] add get_used_external_value interface for block. * while case * delete print * delete print * Update python/paddle/autograd/ir_backward.py * [PIR] add unit_test for get_used_external_value * modify while_loop * code_style * modofy ci bug * modify while api * modify ci * modify array * Update python/paddle/autograd/ir_backward.py * Update test/legacy_test/test_cond.py * update * modify array_write grad info * merge * add_n and createarraylike * conflict * modify exe bug * modify kernel choose --------- Co-authored-by: winter-wang <1030748926@qq.com> --- .../pir_adaptor/pir_adaptor_util.h | 6 +- .../pir/dialect/operator/ir/manual_api.cc | 17 ++ .../pir/dialect/operator/ir/manual_api.h | 4 + .../pir/dialect/operator/ir/manual_op.cc | 255 +++++++++++++++++- .../fluid/pir/dialect/operator/ir/manual_op.h | 41 +++ .../pir/transforms/pd_op_to_kernel_pass.cc | 20 +- .../fluid/pybind/manual_static_op_function.h | 55 ++++ paddle/fluid/pybind/pir.cc | 8 + paddle/phi/infermeta/multiary.cc | 15 +- paddle/phi/infermeta/unary.cc | 5 + paddle/phi/infermeta/unary.h | 2 + paddle/phi/kernels/array_kernel.cc | 44 +++ paddle/phi/kernels/array_kernel.h | 6 + test/legacy_test/test_array_read_write_op.py | 43 +++ test/legacy_test/test_while_loop_op.py | 6 +- 15 files changed, 508 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index cd1ca07bbe23d3..6bac089335b8ee 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -213,9 +213,13 @@ void BuildPhiContext(pir::Operation* op, } else if (variable_array[i]->IsType()) { inputs.emplace_back(InType(const_cast( &(variable_array[i]->Get())))); + } else if (variable_array[i]->IsType()) { + inputs.emplace_back(InType(const_cast( + &(variable_array[i]->Get())))); } else { PADDLE_THROW(phi::errors::Unimplemented( - "Only support Vector and vector now, " + "Only support Vector and vector " + "and vector now " "not support vector<%d>.", variable_array[i]->Type())); } diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index 33fecafdbb0258..b4edb817521e06 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -133,6 +133,14 @@ pir::OpResult create_array(phi::DataType dtype) { return create_array_op.out(); } +pir::OpResult create_array_like(pir::Value input, float value) { + auto create_array_like_op = + ApiBuilder::Instance() + .GetBuilder() + ->Build(input, value); + return create_array_like_op.out(); +} + pir::OpResult array_length(pir::Value x) { auto array_length_op = ApiBuilder::Instance() .GetBuilder() @@ -165,6 +173,15 @@ std::tuple array_to_tensor(pir::Value x, return std::make_tuple(array_to_tensor.result(0), array_to_tensor.result(1)); } +pir::OpResult add_n_array(const std::vector& inputs) { + auto inputs_combine_op = + ApiBuilder::Instance().GetBuilder()->Build(inputs); + paddle::dialect::AddNArrayOp add_n_array_op = + ApiBuilder::Instance().GetBuilder()->Build( + inputs_combine_op.out()); + return add_n_array_op.result(0); +} + pir::OpResult slice_array_dense(pir::Value input, pir::Value starts) { auto op = ApiBuilder::Instance() .GetBuilder() diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index 347e10494696c0..587554ab2c3c81 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -62,6 +62,8 @@ pir::OpResult zeros(const std::vector& shape, pir::OpResult create_array(phi::DataType dtype); +pir::OpResult create_array_like(pir::Value input, float value); + pir::OpResult array_length(pir::Value x); pir::OpResult array_read(pir::Value array, pir::Value i); @@ -72,6 +74,8 @@ std::tuple array_to_tensor(pir::Value x, int axis, bool use_stack); +pir::OpResult add_n_array(const std::vector& inputs); + pir::OpResult slice_array_dense(pir::Value input, pir::Value starts); } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 0a60b4c7d7d819..4c07132cfaa1f0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -14,14 +14,16 @@ #ifdef GET_OP_LIST #undef GET_OP_LIST paddle::dialect::AddNOp, paddle::dialect::AddN_Op, - paddle::dialect::AddNWithKernelOp, paddle::dialect::FusedGemmEpilogueOp, + paddle::dialect::AddNWithKernelOp, paddle::dialect::AddNArrayOp, + paddle::dialect::FusedGemmEpilogueOp, paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp, paddle::dialect::ExpandOp, paddle::dialect::CreateArrayOp, - paddle::dialect::ArrayLengthOp, paddle::dialect::ArrayReadOp, - paddle::dialect::ArrayWrite_Op, paddle::dialect::SliceArrayOp, - paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op, - paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp, - paddle::dialect::IncrementOp, paddle::dialect::Increment_Op + paddle::dialect::CreateArrayLikeOp, paddle::dialect::ArrayLengthOp, + paddle::dialect::ArrayReadOp, paddle::dialect::ArrayWrite_Op, + paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp, + paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp, + paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp, + paddle::dialect::Increment_Op #else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" @@ -421,6 +423,136 @@ void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +OpInfoTuple AddNArrayOp::GetOpInfo() { + std::vector inputs = { + OpInputInfo("inputs", + "pir::VectorType", + false, + false, + false, + true)}; + std::vector attributes = {}; + std::vector outputs = {OpOutputInfo( + "out", "paddle::dialect::DenseTensorArrayType", false, false)}; + paddle::dialect::OpRunTimeInfo run_time_info = + OpRunTimeInfo("AddNTensorArrayInferMeta", + {"inputs"}, + "add_n_array", + {"inputs"}, + {}, + {}, + {}, + {}); + + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "add_n_array"); +} + +void AddNArrayOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNArrayOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", input_size)); + if (auto vec_type = + (*this)->operand(0).type().dyn_cast()) { + for (size_t i = 0; i < vec_type.size(); ++i) { + PADDLE_ENFORCE(vec_type[i].isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + } + } else { + PADDLE_ENFORCE((*this) + ->operand(0) + .type() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + } + } + VLOG(4) << "Verifying attributes:"; + { + // Attributes num is 0, not need to check attributes type. + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + PADDLE_ENFORCE_EQ( + output_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", output_size)); + PADDLE_ENFORCE( + (*this)->result(0).type().isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th output.")); + } + VLOG(4) << "End Verifying for: AddNArrayOp."; +} + +void AddNArrayOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value inputs_) { + VLOG(4) << "Start build AddNArrayOp"; + + VLOG(4) << "Builder construction inputs"; + argument.AddInput(inputs_); + + VLOG(4) << "Builder construction attributes"; + + VLOG(4) << "Builder construction outputs"; + pir::VectorType inputs = inputs_.type().dyn_cast(); + + std::vector vec_dense_inputs; + for (size_t i = 0; i < inputs.size(); i++) { + vec_dense_inputs.push_back(paddle::dialect::IrTensor( + TransToPhiDataType( + inputs[i] + .dyn_cast() + .dtype()), + {}, + inputs[i] + .dyn_cast() + .data_layout(), + {})); + } + + std::vector vec_meta_inputs; + for (size_t i = 0; i < vec_dense_inputs.size(); i++) { + vec_meta_inputs.push_back( + paddle::dialect::IrMetaTensor(&vec_dense_inputs[i])); + } + + std::vector meta_inputs; + for (size_t i = 0; i < static_cast(vec_meta_inputs.size()); i++) { + meta_inputs.push_back(&vec_meta_inputs[i]); + } + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::AddNTensorArrayInferMeta( + meta_inputs, &meta_out, phi::MetaConfig(false, false)); + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + TransToIrDataType(dense_out.dtype()), + dense_out.layout()); + + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void AddNArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::AddNTensorArrayInferMeta); + fn(infer_meta); +} + const char *FusedGemmEpilogueOp::attributes_name[3] = { "trans_x", "trans_y", "activation"}; @@ -1156,6 +1288,114 @@ void CreateArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } +const char *CreateArrayLikeOp::attributes_name[1] = {"val"}; + +OpInfoTuple CreateArrayLikeOp::GetOpInfo() { + std::vector inputs = { + paddle::dialect::OpInputInfo("input", + "paddle::dialect::DenseTensorArrayType", + false, + false, + false, + false)}; + + std::vector attributes = { + paddle::dialect::OpAttributeInfo("val", "pir::FloatAttribute", "")}; + + std::vector outputs = {OpOutputInfo( + "out", "paddle::dialect::DenseTensorArrayType", false, false)}; + + paddle::dialect::OpRunTimeInfo run_time_info = + OpRunTimeInfo("CreateArrayLikeInferMeta", + {"input"}, + "create_array_like", + {"input", "val"}, + {}, + {}, + {}, + {}); + + return std::make_tuple( + inputs, attributes, outputs, run_time_info, "create_array_like"); +} + +void CreateArrayLikeOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value &input_, // NOLINT + float &val) { + VLOG(4) << "Start build CreateArrayLikeOp"; + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {input_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute attr_val = + pir::FloatAttribute::get(pir::IrContext::Instance(), val); + argument.AddAttribute("val", attr_val); + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorArrayType input_type = + input_.type().dyn_cast(); + paddle::dialect::IrTensor dense_input( + paddle::dialect::TransToPhiDataType(input_type.dtype()), + {}, + input_type.data_layout(), + {}); + + paddle::dialect::IrMetaTensor meta_input(&dense_input); + + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::CreateArrayLikeInferMeta(meta_input, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorArrayType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.layout()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +void CreateArrayLikeOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: " + "CreateArrayLikeOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", input_size)); + } + VLOG(4) << "Verifying attributes:"; + { + auto &attributes = this->attributes(); + PADDLE_ENFORCE(attributes.count("val") > 0, "val does not exist."); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + PADDLE_ENFORCE_EQ( + output_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", output_size)); + PADDLE_ENFORCE( + (*this)->result(0).type().isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 0th output.")); + } + VLOG(4) << "End Verifying for: CreateArrayLikeOp."; +} + +void CreateArrayLikeOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(phi::CreateArrayLikeInferMeta); + fn(infer_meta); +} + OpInfoTuple ArrayLengthOp::GetOpInfo() { std::vector inputs = { OpInputInfo("x", @@ -1319,6 +1559,7 @@ void ArrayReadOp::Build(pir::Builder &builder, dense_out.lod()); argument_outputs.push_back(out_type); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); } void ArrayReadOp::Build(pir::Builder &builder, @@ -2691,9 +2932,11 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNArrayOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayLikeOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayLengthOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayReadOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayWrite_Op) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 121c95dee169aa..cbfadb24b97e6d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -102,6 +102,26 @@ class AddNWithKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.add_n_array"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value inputs_); + + void VerifySig(); + pir::Value inputs() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + class FusedGemmEpilogueOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.create_array_like"; } + static constexpr uint32_t attributes_num = 1; + static const char *attributes_name[attributes_num]; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value &input_, // NOLINT + float &val); // NOLINT + void VerifySig(); + pir::Value input() { return operand_source(0); } + pir::OpResult out() { return result(0); } + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + class ArrayLengthOp : public pir::Op { public: @@ -522,9 +561,11 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNArrayOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CreateArrayLikeOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayLengthOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayReadOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ArrayWrite_Op) diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 165a1d3fde4fc7..10bc81ea6eac99 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -292,6 +292,9 @@ static std::vector> PrepareFakeTensors( } else if (inner_types[i].isa()) { res.push_back( fake_sr(inner_types[i].dyn_cast())); + } else if (inner_types[i].isa()) { + res.push_back(fake_tensor_array( + inner_types[i].dyn_cast())); } } } else if (in_type.isa()) { @@ -942,6 +945,7 @@ phi::KernelKey GetKernelKey( if (!combine_op_res) { continue; } + if (combine_op_res.owner()->isa()) { auto data_op = combine_op_res.owner(); auto data_place = @@ -1807,11 +1811,15 @@ std::vector BuildInputs( place = in_i_type.dyn_cast().place(); } else if (in_i_type.isa()) { place = in_i_type.dyn_cast().place(); + } else if (in_i_type.isa()) { + place = + in_i_type.dyn_cast().place(); } else { PADDLE_THROW(phi::errors::Unimplemented( "builtin.combine Input type only support " "VectorType and " - "VectorType")); + "VectorType and" + "VectorType")); } // get input args def type @@ -1844,11 +1852,19 @@ std::vector BuildInputs( pre_define_op->operand_source(j) .type() .dyn_cast()); + } else if (in_i_type.isa()) { + out_type = AllocatedDenseTensorArrayType::get( + ctx, + out_place, + pre_define_op->operand_source(j) + .type() + .dyn_cast()); } else { PADDLE_THROW(phi::errors::Unimplemented( "builtin.combine Input type only support " "VectorType and " - "VectorType")); + "VectorType and" + "VectorType")); } in_i = AddPlaceTransferOp( in_i, out_type, place, out_place, kernel_key, block); diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index dc09d539f39ffb..af733ecbce53f8 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -147,6 +147,31 @@ static PyObject *static_api_create_array(PyObject *self, } } +static PyObject *static_api_create_array_like(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add create_array_like op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *input_obj = PyTuple_GET_ITEM(args, 0); + auto input = CastPyArg2Value(input_obj, "create_array_like", 0); + + // Parse Attributes + PyObject *value_obj = PyTuple_GET_ITEM(args, 1); + float value = CastPyArg2Float(value_obj, "create_array_like", 1); + + // Call ir static api + auto static_api_out = paddle::dialect::create_array_like(input, value); + + return ToPyObject(static_api_out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + static PyObject *static_api_array_length(PyObject *self, PyObject *args, PyObject *kwargs) { @@ -274,6 +299,28 @@ static PyObject *static_api_array_to_tensor(PyObject *self, } } +PyObject *static_api_add_n_array(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add add_n_array op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *inputs_obj = PyTuple_GET_ITEM(args, 0); + auto inputs = CastPyArg2VectorOfValue(inputs_obj, "add_n", 0); + + // Parse Attributes + + // Call ir static api + auto static_api_out = paddle::dialect::add_n_array(inputs); + + return ToPyObject(static_api_out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} static PyObject *static_api_slice_array_dense(PyObject *self, PyObject *args, PyObject *kwargs) { @@ -324,6 +371,10 @@ static PyMethodDef ManualOpsAPI[] = { (PyCFunction)(void (*)(void))static_api_create_array, METH_VARARGS | METH_KEYWORDS, "C++ interface function for create_array."}, + {"create_array_like", + (PyCFunction)(void (*)(void))static_api_create_array_like, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for create_array_like."}, {"array_length", (PyCFunction)(void (*)(void))static_api_array_length, METH_VARARGS | METH_KEYWORDS, @@ -340,6 +391,10 @@ static PyMethodDef ManualOpsAPI[] = { (PyCFunction)(void (*)(void))static_api_array_to_tensor, METH_VARARGS | METH_KEYWORDS, "C++ interface function for array_to_tensor."}, + {"add_n_array", + (PyCFunction)(void (*)(void))static_api_add_n_array, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for add_n_array."}, {"slice_array_dense", (PyCFunction)(void (*)(void))static_api_slice_array_dense, METH_VARARGS | METH_KEYWORDS, diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 1c398cf7cdf975..a477f42e40c485 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -665,6 +665,14 @@ void BindValue(py::module *m) { "is persistable")); } }) + .def("is_tensorarray", + [](Value self) { + if (self.type().isa()) { + return true; + } else { + return false; + } + }) .def_property( "shape", [](Value self) { return phi::vectorize(GetValueDims(self)); }, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 6250b3a3b23c81..5b9708b38a17e1 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -464,22 +464,19 @@ void AddNInferMeta(const std::vector& x, void AddNTensorArrayInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config) { - int64_t max_length = 0; bool has_tensor_array = false; for (auto input : x) { if (input->is_tensor_array()) { + if (out->is_tensor_array()) { + out->set_dtype(input->dtype()); + out->set_layout(input->layout()); + } has_tensor_array = true; - // if input is lod_tensor_array, dims() will return its size (one element) - max_length = - input->dims()[0] > max_length ? input->dims()[0] : max_length; + break; } } - if (has_tensor_array) { - if (out->is_tensor_array()) { - out->set_dims(common::make_ddim({max_length})); - } - } else { + if (!has_tensor_array) { AddNInferMeta(x, out, config); } } diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a75cd4170e2785..b1b06fdbfed715 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -524,6 +524,11 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_layout(x.layout()); } +void CreateArrayLikeInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); +} + void CumInferMeta(const MetaTensor& x, int axis, bool flatten, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index eae4614a8eb5c9..0126b76754fef2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -124,6 +124,8 @@ void CIdentityInferMeta(const MetaTensor& x, void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); +void CreateArrayLikeInferMeta(const MetaTensor& x, MetaTensor* out); + void CropInferMeta(const MetaTensor& x, const IntArray& shape, const IntArray& offsets, diff --git a/paddle/phi/kernels/array_kernel.cc b/paddle/phi/kernels/array_kernel.cc index 8a599dcf9d80d8..9f794be3f42721 100644 --- a/paddle/phi/kernels/array_kernel.cc +++ b/paddle/phi/kernels/array_kernel.cc @@ -27,6 +27,20 @@ void CreateArrayKernel(const Context& dev_ctx, DataType dtype, TensorArray* out) {} +template +void CreateArrayLikeKernel(const Context& dev_ctx, + const TensorArray& input, + float val, + TensorArray* out) { + out->resize(input.size()); + for (size_t i = 0; i < input.size(); i++) { + DenseTensor input_i = input[i]; + out->at(i).Resize(input_i.dims()); + FullLikeKernel( + dev_ctx, input_i, val, input_i.dtype(), &out->at(i)); + } +} + template void ArrayLengthKernel(const Context& dev_ctx, const TensorArray& x, @@ -150,6 +164,36 @@ PD_REGISTER_KERNEL(create_array, phi::dtype::complex) {} #endif +PD_REGISTER_KERNEL(create_array_like, + CPU, + ALL_LAYOUT, + phi::CreateArrayLikeKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(create_array_like, + GPU, + ALL_LAYOUT, + phi::CreateArrayLikeKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} +#endif + PD_REGISTER_KERNEL(array_length, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/array_kernel.h b/paddle/phi/kernels/array_kernel.h index 0c8436501371da..d41fc36d1057f6 100644 --- a/paddle/phi/kernels/array_kernel.h +++ b/paddle/phi/kernels/array_kernel.h @@ -25,6 +25,12 @@ void CreateArrayKernel(const Context& dev_ctx, DataType dtype, TensorArray* out); +template +void CreateArrayLikeKernel(const Context& dev_ctx, + const TensorArray& input, + float val, + TensorArray* out); + template void ArrayLengthKernel(const Context& dev_ctx, const TensorArray& x, diff --git a/test/legacy_test/test_array_read_write_op.py b/test/legacy_test/test_array_read_write_op.py index 5125ec16cf70d5..dbdcb7707c3939 100644 --- a/test/legacy_test/test_array_read_write_op.py +++ b/test/legacy_test/test_array_read_write_op.py @@ -242,6 +242,49 @@ def test_array_backward(self): np.testing.assert_allclose(res[0], mean, rtol=1e-05) np.testing.assert_allclose(res[1], x_grad, rtol=1e-05) + def test_create_array_like_add_n(self): + paddle.enable_static() + np.random.seed(2013) + with paddle.pir_utils.IrGuard(): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + d0 = paddle.static.data(name='d0', shape=[10], dtype='float32') + d1 = paddle.static.data(name='d1', shape=[10], dtype='float32') + i = paddle.zeros(shape=[1], dtype='int64') + mem_array = paddle.tensor.array_write(x=d0, i=i) + i = paddle.increment(i) + paddle.tensor.array_write(x=d1, i=i, array=mem_array) + copy_array = paddle._pir_ops.create_array_like(mem_array, 0.0) + out = paddle.tensor.array_read(array=copy_array, i=i) + + paddle.tensor.array_write(x=d0, i=i, array=copy_array) + i = paddle.increment(i, -1) + paddle.tensor.array_write(x=d1, i=i, array=copy_array) + + add_array = paddle._pir_ops.add_n_array([mem_array, copy_array]) + out_1 = paddle.tensor.array_read(array=add_array, i=i) + i = paddle.increment(i, 1) + out_2 = paddle.tensor.array_read(array=add_array, i=i) + + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + d0 = np.random.random(size=[10]).astype('float32') + d1 = np.random.random(size=[10]).astype('float32') + exe = base.Executor(place) + res = exe.run( + main_program, + feed={'d0': d0, 'd1': d1}, + fetch_list=[out, out_1, out_2], + ) + out = [0.0] * 10 + np.testing.assert_allclose(res[0], out, rtol=1e-05) + np.testing.assert_allclose(res[1], d0 + d1, rtol=1e-05) + np.testing.assert_allclose(res[2], d0 + d1, rtol=1e-05) + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 83fecc6b5ad7f5..44ee6383fa6abc 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -353,7 +353,6 @@ def body(i, x): fetch_list = [out[1]] for p, g in grad_list: fetch_list.append(g) - res = exe.run( main_program, feed={'i': feed_i, 'x': feed_x}, @@ -409,10 +408,14 @@ def internal_body(j, x, mem_array): d1 = paddle.static.data(name='d1', shape=[10], dtype='float32') d2 = paddle.static.data(name='d2', shape=[10], dtype='float32') x = paddle.static.data(name='x', shape=[10], dtype='float32') + d0.persistable = True + d1.persistable = True + d2.persistable = True x.stop_gradient = False x.persistable = True i = paddle.zeros(shape=[1], dtype='int64') i.stop_gradient = True + i.persistable = True init = paddle.zeros(shape=[10], dtype='float32') mem_array = paddle.tensor.array_write(x=init, i=i) data_array = paddle.tensor.array_write(x=d0, i=i) @@ -440,6 +443,7 @@ def internal_body(j, x, mem_array): sum_result = paddle.tensor.array_read(array=out[3], i=j) mean = paddle.mean(sum_result) grad_list = append_backward(mean) + place = ( base.CUDAPlace(0) if core.is_compiled_with_cuda() From 698bb4252ae56d147aa41449b7bf7ddc35c4cc3d Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Wed, 3 Jan 2024 14:17:58 +0800 Subject: [PATCH 084/142] Add align iter space tactic (#60498) Add align iter space tactic --- paddle/cinn/hlir/pe/elementwise.cc | 3 +- .../dy_shape_group_scheduler.cc | 149 ++++++++++++++++-- .../group_schedule/dy_shape_group_scheduler.h | 5 + .../ir/group_schedule/tactic/CMakeLists.txt | 3 +- .../tactic/align_iter_space_tactic.cc | 87 ++++++++++ .../tactic/align_iter_space_tactic.h | 37 +++++ .../tactic/arrange_storage_tactic.cc | 8 +- .../tactic/arrange_storage_tactic.h | 5 +- .../tactic/compute_inline_tactic.cc | 7 +- .../tactic/compute_inline_tactic.h | 6 +- .../group_schedule/tactic/schedule_tactic.h | 37 ++++- paddle/cinn/ir/ir_analyzer/ir_analyzer.cc | 67 ++++++++ paddle/cinn/ir/ir_analyzer/ir_analyzer.h | 6 + .../ir/schedule/impl/loop_transformation.cc | 10 +- 14 files changed, 400 insertions(+), 30 deletions(-) create mode 100644 paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc create mode 100644 paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h diff --git a/paddle/cinn/hlir/pe/elementwise.cc b/paddle/cinn/hlir/pe/elementwise.cc index b09ce0f971d534..60933cd66c4b07 100644 --- a/paddle/cinn/hlir/pe/elementwise.cc +++ b/paddle/cinn/hlir/pe/elementwise.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/cinn/common/cas.h" #include "paddle/cinn/hlir/op/op_util.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/lang/builtin.h" @@ -216,7 +217,7 @@ ir::Tensor Reshape(const ir::Tensor& A, } std::vector indice_a; for (int i = A_expr_shape.size() - 1; i >= 0; i--) { - auto temp = offset % A_expr_shape[i]; + auto temp = common::AutoSimplify(offset % A_expr_shape[i]); indice_a.insert(indice_a.begin(), temp); offset = (offset - temp) / A_expr_shape[i]; } diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index 04e7afa8760f64..d56fc994fdcea3 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -13,31 +13,31 @@ // limitations under the License. #include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h" +#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h" namespace cinn { namespace ir { void DynamicShapeGroupScheduler::Init() { - std::unordered_set output_names = OutputTensorNames(); - tactics_.emplace_back(new ComputeInlineTactic(output_names, target_)); - tactics_.emplace_back(new ArrangeStorageTactic(output_names)); + schedule_context_.output_names = OutputTensorNames(); + schedule_context_.global_master = FindGlobalMasterNode(); + schedule_context_.iter_space_info = + ConstructIterSpaceInfo(schedule_context_.global_master); + schedule_context_.target = target_; + tactics_.emplace_back(new AlignIterSpaceTactic()); + tactics_.emplace_back(new ComputeInlineTactic()); + tactics_.emplace_back(new ArrangeStorageTactic()); } void DynamicShapeGroupScheduler::Schedule() { // Fake schedule for test - std::vector all_blocks = ir_sch_->GetAllBlocks(); - for (int i = 0; i < all_blocks.size(); i++) { - std::vector loops = ir_sch_->GetLoops(all_blocks[i]); - ir_sch_->Fuse(loops); - } - ApplyTactics(); - all_blocks = ir_sch_->GetAllBlocks(); + std::vector all_blocks = ir_sch_->GetAllBlocks(); auto block0_loops = ir_sch_->GetLoops(all_blocks[0]); auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1}); - ir_sch_->Bind(splited_loops1[0], "threadIdx.x"); ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024)); @@ -49,11 +49,22 @@ void DynamicShapeGroupScheduler::Schedule() { void DynamicShapeGroupScheduler::ApplyTactics() { schedule_block_graph_->Update(*ir_sch_); for (const auto& tactic : tactics_) { + VLOG(5) << "[Start " << tactic->TacticName() << "] func body:\n" + << ir_sch_->GetModule().GetExprs().front(); auto ApplyTacticFunc = [&](ir::ScheduleBlockNode* node) { + VLOG(6) << "before applying [" << tactic->TacticName() + << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" + << ir_sch_->GetModule().GetExprs().front(); + tactic->Init(&schedule_context_); tactic->Apply(ir_sch_, node->id()); + VLOG(6) << "after applying [" << tactic->TacticName() + << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" + << ir_sch_->GetModule().GetExprs().front(); }; schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc); schedule_block_graph_->Update(*ir_sch_); + VLOG(5) << "[End " << tactic->TacticName() + << "] func body: " << ir_sch_->GetModule().GetExprs().front(); } } @@ -67,5 +78,121 @@ DynamicShapeGroupScheduler::GetIRs() { return irs; } +IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( + ScheduleBlockNode* node) { + IterativeSpaceInfo info; + std::vector sp_iter_indices; + std::vector rb_iter_indices; + + ir::Expr block = node->Block(); + std::vector iter_values = + block.As()->iter_values; + std::vector iter_vars = block.As() + ->schedule_block.As() + ->iter_vars; + std::vector loops = ir_sch_->GetLoops(block); + std::unordered_set reduce_iter_vars = + analyzer::GetReduceIterVars(block); + std::unordered_map iter_var2value = + analyzer::GetIterVarToValueOfSBlock(block); + + if (!reduce_iter_vars.empty()) { + std::set reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor( + block, + [&](const ir::Expr* x) { + bool find_reduce_var = false; + if (x->As()) { + for (ir::Expr index : x->As()->indices) { + if (index.as_var() && + reduce_iter_vars.count(index.as_var_ref()) > 0) { + find_reduce_var = true; + break; + } + } + } + return find_reduce_var; + }, + /* uniq_target = */ true); + CHECK_EQ(reduce_loads.size(), 1); + + std::vector reduce_load_indices = + reduce_loads.begin()->As()->indices; + int loop_idx = 0; + for (int i = 0; i < reduce_load_indices.size(); ++i) { + ir::Expr& index = reduce_load_indices[i]; + if (index.is_constant()) continue; + CHECK_NOTNULL(index.as_var()); + ir::Var iter_var = index.as_var_ref(); + ir::Expr iter_value = iter_var2value.at(iter_var); + CHECK_NOTNULL(iter_value.as_var()); + ir::For* for_node; + for (ir::Expr& loop : loops) { + if (loop.As()->loop_var == iter_value.as_var_ref()) { + for_node = loop.As(); + } + } + CHECK_NOTNULL(for_node); + bool is_reduce_iter_var = reduce_iter_vars.count(iter_var) > 0; + if (is_reduce_iter_var) { + info.rb_space.emplace_back(for_node->extent, + IterativeSpaceInfo::AxisType::kSerial); + info.memory_consistent_order_space.emplace_back(for_node->extent); + rb_iter_indices.push_back(loop_idx); + } else { + info.sp_space.emplace_back(for_node->extent, + IterativeSpaceInfo::AxisType::kSerial); + info.memory_consistent_order_space.emplace_back(for_node->extent); + sp_iter_indices.push_back(loop_idx); + } + ++loop_idx; + } + info.rb_last_order.insert(info.rb_last_order.end(), + sp_iter_indices.begin(), + sp_iter_indices.end()); + info.rb_last_order.insert(info.rb_last_order.end(), + rb_iter_indices.begin(), + rb_iter_indices.end()); + } else { + for (int i = 0; i < loops.size(); ++i) { + ir::For* for_node = loops[i].As(); + info.memory_consistent_order_space.emplace_back(for_node->extent); + info.sp_space.emplace_back(for_node->extent, + IterativeSpaceInfo::AxisType::kSerial); + info.rb_last_order.push_back(i); + } + } + return info; +} + +ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { + ir::ScheduleBlockNode* master = nullptr; + // 1. reduce + auto FindReduce = [&](ir::ScheduleBlockNode* node) { + if (analyzer::IsReductionSBlock(node->Block())) { + master = node; + } + }; + schedule_block_graph_->NodesWalk(FindReduce); + if (master != nullptr) { + VLOG(6) << "Find the global master node: " << master->id(); + return master; + } + // 2. broadcast + auto FindBroadcast = [&](ir::ScheduleBlockNode* node) { + if (analyzer::IsBroadcastSBlock(node->Block())) { + master = node; + } + }; + schedule_block_graph_->NodesWalk(FindBroadcast); + if (master != nullptr) { + VLOG(6) << "Find the global master node: " << master->id(); + return master; + } + // 3. end point + master = schedule_block_graph_->EndPoints().back(); + VLOG(6) << "Find the global master node: " << master->id(); + return master; +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h index 7d2f9115776dca..896fe86bec852d 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h @@ -42,10 +42,15 @@ class DynamicShapeGroupScheduler : public GroupScheduler { void ApplyTactics(); + ir::ScheduleBlockNode* FindGlobalMasterNode(); + + IterativeSpaceInfo ConstructIterSpaceInfo(ScheduleBlockNode* node); + private: std::vector>> ir_schs_; std::vector> tactics_; + ScheduleContext schedule_context_; }; } // namespace ir diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt index 6ed979ece476b6..da964e770ae9ba 100644 --- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt +++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt @@ -1,4 +1,5 @@ core_gather_headers() -gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc) +gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc) gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc) +gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc) diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc new file mode 100644 index 00000000000000..80a044d26b5bd4 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h" +#include "paddle/cinn/common/cas.h" +#include "paddle/cinn/common/integer_set.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h" +#include "paddle/cinn/ir/op/ir_operators.h" +#include "paddle/cinn/ir/utils/ir_copy.h" + +namespace cinn { +namespace ir { + +void AlignIterSpaceTactic::Init(ScheduleContext* context) { + context_ = context; +} + +void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch, + const std::string& block_id) { + ir::Expr block = sch->GetBlock(block_id); + if (analyzer::IsReductionSBlock(block)) { + return; + } + + std::vector loops = sch->GetLoops(block_id); + ir::Expr src_fused_loop = sch->Fuse(loops); + ir::Expr src_total_extent = src_fused_loop.As()->extent; + + ir::Expr target_sp_extent{1}; + for (const auto& iter : context_->iter_space_info.sp_space) { + target_sp_extent = target_sp_extent * std::get<0>(iter); + } + ir::Expr target_total_extent = ir_utils::IRCopy(target_sp_extent); + for (const auto& iter : context_->iter_space_info.rb_space) { + target_total_extent = target_total_extent * std::get<0>(iter); + } + + common::cas_intervals_t var_intervals; + common::SymbolicExprAnalyzer symbolic_expr_analyzer(var_intervals); + std::optional total_extent_eq = + symbolic_expr_analyzer.ProveEQ(src_total_extent, target_total_extent); + bool need_reorder = false; + for (int i = 0; i < context_->iter_space_info.rb_last_order.size(); ++i) { + if (context_->iter_space_info.rb_last_order[i] != i) { + need_reorder = true; + break; + } + } + + if (total_extent_eq.has_value() && total_extent_eq.value()) { + sch->Split(src_fused_loop, + context_->iter_space_info.memory_consistent_order_space); + loops = sch->GetLoops(block_id); + if (need_reorder) { + sch->Reorder(block_id, context_->iter_space_info.rb_last_order); + } + if (context_->iter_space_info.sp_space.size() < loops.size() - 1) { + loops = sch->GetLoops(block_id); + std::vector rb_loops( + loops.begin() + context_->iter_space_info.sp_space.size(), + loops.end()); + sch->Fuse(rb_loops); + } + if (context_->iter_space_info.sp_space.size() > 1) { + loops = sch->GetLoops(block_id); + std::vector sp_loops( + loops.begin(), + loops.begin() + context_->iter_space_info.sp_space.size()); + sch->Fuse(sp_loops); + } + } +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h new file mode 100644 index 00000000000000..69729ce2bfb8c6 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" + +namespace cinn { +namespace ir { + +class AlignIterSpaceTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "AlignIterSpaceTactic"; } + + private: + ScheduleContext* context_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc index fad7097d097877..cec04ba2c1e877 100644 --- a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc @@ -24,7 +24,7 @@ namespace cinn { namespace ir { -// [block_name, [var_name, for_node]] +// [block_name, [var, for_node]] using VarToForMap = std::unordered_map>; using IntSet = common::SingleIntervalIntSet; @@ -337,9 +337,9 @@ std::optional AnalyzeCrossType(const VarToForMap& var2for_map, return std::nullopt; } -ArrangeStorageTactic::ArrangeStorageTactic( - const std::unordered_set& output_names) - : output_names_(output_names) {} +void ArrangeStorageTactic::Init(ScheduleContext* context) { + output_names_ = context->output_names; +} void ArrangeStorageTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { diff --git a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h index 0371aead7e163e..994108d1662b9d 100644 --- a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h @@ -23,11 +23,12 @@ namespace ir { class ArrangeStorageTactic final : public ScheduleTactic { public: - explicit ArrangeStorageTactic( - const std::unordered_set& output_names); + void Init(ScheduleContext* context) override; void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + std::string TacticName() const override { return "ArrangeStorageTactic"; } + private: std::unordered_set output_names_; }; diff --git a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc index 81bf65366a9683..dc9d33a731b05b 100644 --- a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc @@ -25,9 +25,10 @@ namespace cinn { namespace ir { -ComputeInlineTactic::ComputeInlineTactic( - const std::unordered_set& output_names, const Target& target) - : output_names_(output_names), target_(target) {} +void ComputeInlineTactic::Init(ScheduleContext* context) { + output_names_ = context->output_names; + target_ = context->target; +} void ComputeInlineTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { diff --git a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h index 71754fdc4adcd0..b03e28d579bc88 100644 --- a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h @@ -24,12 +24,12 @@ namespace ir { class ComputeInlineTactic final : public ScheduleTactic { public: - explicit ComputeInlineTactic( - const std::unordered_set& output_names, - const cinn::common::Target& target); + void Init(ScheduleContext* context) override; void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + std::string TacticName() const override { return "ComputeInlineTactic"; } + private: std::unordered_set output_names_; cinn::common::Target target_; diff --git a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h index bc2c88c7d5ccdd..4084c69bf493ae 100644 --- a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h @@ -15,15 +15,50 @@ #pragma once #include - +#include "paddle/cinn/common/integer_set.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule_block_graph.h" namespace cinn { namespace ir { +struct IterativeSpaceInfo { + enum class AxisType : int { + kSerial = 0, + kCudaThreadX = 1, + kCudaThreadY = 2, + kCudaThreadZ = 3, + kCudaBlockX = 4, + kCudaBlockY = 5, + kCudaBlockZ = 6, + }; + // pure spatial iterative space + std::vector> sp_space; + // reduce or broadcast iterative space + std::vector> rb_space; + // original loop order with same iteration order as the memory order + std::vector memory_consistent_order_space; + // index that transform from memory consistent order to rb last order + // for example: + // the memory consistent order axis is [A, B, C], and the B axis is reduce, + // the rb last order axis is [A, C, B], and rb_last_order is [0, 2, 1]. + std::vector rb_last_order; +}; + +struct ScheduleContext { + std::unordered_set output_names; + ScheduleBlockNode* global_master; + IterativeSpaceInfo iter_space_info; + Target target; +}; + class ScheduleTactic { public: + virtual void Init(ScheduleContext* context) = 0; + virtual void Apply(ir::IRSchedule* sch, const std::string& block_id) = 0; + + virtual std::string TacticName() const = 0; }; } // namespace ir diff --git a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc index bdb37d4189ce4d..701d003dbcd2d9 100644 --- a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc +++ b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc @@ -373,6 +373,73 @@ std::vector GetIterValuesOfAccess(ir::Expr load_or_store, return iter_values; } +std::unordered_set GetReduceIterVars(ir::Expr block) { + ir::ScheduleBlockRealize* schedule_block_realize = + block.As(); + CHECK_NOTNULL(schedule_block_realize); + ir::ScheduleBlock* schedule_block = + schedule_block_realize->schedule_block.As(); + CHECK_NOTNULL(schedule_block); + std::vector& iter_vars = schedule_block->iter_vars; + std::unordered_set reduce_vars; + for (int i = 0; i < iter_vars.size(); ++i) { + if (iter_vars[i]->is_reduce_axis) { + reduce_vars.insert(iter_vars[i]); + } + } + return reduce_vars; +} + +bool IsReductionSBlock(ir::Expr block) { + ir::ScheduleBlockRealize* s_block_realize = + block.As(); + CHECK_NOTNULL(s_block_realize); + ir::ScheduleBlock* s_block = + s_block_realize->schedule_block.As(); + CHECK_NOTNULL(s_block); + for (const ir::Var& var : s_block->iter_vars) { + if (var->is_reduce_axis) { + return true; + } + } + return false; +} + +bool IsBroadcastSBlock(ir::Expr block) { + ir::ScheduleBlockRealize* s_block_realize = + block.As(); + CHECK_NOTNULL(s_block_realize); + ir::ScheduleBlock* s_block = + s_block_realize->schedule_block.As(); + CHECK_NOTNULL(s_block); + ir::Expr e_store = GetStoreOfSBlock(block); + ir::Store* store = e_store.As(); + CHECK_NOTNULL(store); + ir::Load* load = store->value.As(); + if (load == nullptr) { + return false; + } + // each load index can be found in store index and maintain relative order + for (size_t i = 0; i < load->indices.size(); ++i) { + bool found = false; + for (size_t j = i; j < store->indices.size(); ++j) { + ir::_Var_* load_var = load->indices[i].as_var(); + ir::_Var_* store_var = store->indices[j].as_var(); + if (load_var == nullptr || store_var == nullptr) { + return false; + } + if (load_var->name == store_var->name) { + found = true; + break; + } + } + if (!found) { + return false; + } + } + return load->indices.size() < store->indices.size(); +} + } // namespace analyzer } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/ir_analyzer/ir_analyzer.h b/paddle/cinn/ir/ir_analyzer/ir_analyzer.h index a24e9726b8ce7a..50f3a4eafaf2d2 100644 --- a/paddle/cinn/ir/ir_analyzer/ir_analyzer.h +++ b/paddle/cinn/ir/ir_analyzer/ir_analyzer.h @@ -67,6 +67,12 @@ ir::Expr ReplaceVarWithExpr(const ir::Expr& source, std::vector GetIterValuesOfAccess(ir::Expr load_or_store, ir::Expr block); +std::unordered_set GetReduceIterVars(ir::Expr block); + +bool IsReductionSBlock(ir::Expr block); + +bool IsBroadcastSBlock(ir::Expr block); + } // namespace analyzer } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index 6f0c1b4f5ae824..1d66697f43d136 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -14,6 +14,7 @@ #include "paddle/cinn/ir/schedule/impl/ir_schedule.h" +#include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/integer_set.h" #include "paddle/cinn/common/macros.h" @@ -124,9 +125,9 @@ std::vector DyScheduleImpl::Split(const Expr& loop, if (factor < 1 && factor != -1) is_positive = false; if (factor == -1) ++num_minus1; }); - CHECK((num_minus1 == 1) && is_positive) - << "The paramss in factors of Split on dynamic shape should contains a " - "-1 and the rest of them should be positive!\n"; + CHECK((num_minus1 <= 1) && is_positive) + << "The params in factors of Split on dynamic shape should contains at " + "most one '-1' and the rest of them should be positive!\n"; std::vector new_loop_vars; Expr substitute_value(0); @@ -182,7 +183,8 @@ std::vector DyScheduleImpl::Split(const Expr& loop, std::vector process_factors(factors); Expr prod_size(1); for (auto factor : factors) prod_size = prod_size * Expr(factor); - cinn::common::SymbolicExprAnalyzer analyzer({}); + common::cas_intervals_t var_intervals = {}; + cinn::common::SymbolicExprAnalyzer analyzer(var_intervals); CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false)) << "Product of factors can't be proved to be equal to the extent of " "current for loop!"; From 54b95ae3fe0d17dd31b14e047d29c61271bca78b Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Wed, 3 Jan 2024 14:22:22 +0800 Subject: [PATCH 085/142] [Dynamic Shape] Add helper function MakeGenerateShapeOpAttribute (#60512) * add helper function MakeGenerateShapeOpAttribute * fix complier complaint * Code format --- .../operator/ir/generate_shape_util.cc | 235 ++++++++++++++++++ .../dialect/operator/ir/generate_shape_util.h | 15 ++ ...e_shape_ops_into_generate_shape_op_pass.cc | 209 +--------------- 3 files changed, 261 insertions(+), 198 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index eef663585a4086..f64bb9269d63a5 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" @@ -422,4 +423,238 @@ MakeGetterDimExpr4SymbolName( }; } +namespace { + +bool IsAtomicImpl(int64_t) { return true; } + +bool IsAtomicImpl(const std::string&) { return true; } + +bool IsAtomicImpl(const symbol::Negative&) { return false; } + +bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } + +bool IsAtomicImpl(const symbol::Add&) { return false; } + +bool IsAtomicImpl(const symbol::Mul&) { return false; } + +bool IsAtomicImpl(const symbol::Max&) { return false; } + +bool IsAtomicImpl(const symbol::Min&) { return false; } + +bool IsAtomicImpl(const symbol::Broadcast&) { return false; } + +bool IsAtomic(const symbol::DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, + dim_expr.variant()); +} + +bool InputDimExprsAllSupported( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + const auto& AllSupported = + [](const std::vector& dim_exprs) -> bool { + for (const auto& dim_expr : dim_exprs) { + if (!IsAtomic(dim_expr)) return false; + } + return true; + }; + for (const auto& input_tensor : input_tensors) { + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + if (!AllSupported(dim_exprs.shape())) return false; + if (dim_exprs.data().has_value()) { + if (!AllSupported(dim_exprs.data().value())) return false; + } + } + return true; +} + +void ConvertDimExprToAttributes(pir::IrContext* ir_context, + const std::vector& dim_exprs, + std::vector* attrs) { + attrs->clear(); + attrs->reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); + } +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret); + +void CollectSymbolNamesImpl(const int64_t& dim_expr, + std::set* ret) { + // do nothing. +} + +void CollectSymbolNamesImpl(const std::string& dim_expr, + std::set* ret) { + ret->insert(dim_expr); +} + +template +void CollectSymbolNamesImplForUnary(const T& dim_expr, + std::set* ret) { + const auto& [operand] = *dim_expr; + CollectSymbolNames(operand, ret); +} + +void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForUnary(dim_expr, ret); +} + +template +void CollectSymbolNamesImplForVariadic(const T& dim_expr, + std::set* ret) { + const auto& operands = *(dim_expr.operands); + for (const auto& operand : operands) { + CollectSymbolNames(operand, ret); + } +} + +void CollectSymbolNamesImpl(const symbol::Add& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Max& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Min& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, + std::set* ret) { + CollectSymbolNamesImplForVariadic(dim_expr, ret); +} + +void CollectSymbolNames(const symbol::DimExpr& dim_expr, + std::set* ret) { + return std::visit( + [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, + dim_expr.variant()); +} + +void CollectSymbolNames(const std::vector& dim_exprs, + std::set* ret) { + for (const auto& dim_expr : dim_exprs) { + CollectSymbolNames(dim_expr, ret); + } +} + +template +void AppendSymbolBindings(const std::vector& dim_exprs, + const std::set& symbol_names, + int in_tensor_idx, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); + ++in_tensor_dim_idx) { + const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); + CHECK(IsAtomic(dim_expr)); + if (!dim_expr.isa()) continue; + const auto& sym_name = dim_expr.dyn_cast(); + if (symbol_names.find(sym_name) == symbol_names.end()) continue; + symbol_bindings->emplace_back(SymbolBindingsT{ + /*.symbol_name=*/sym_name, + /*.input_tensor_idx=*/in_tensor_idx, + /*.input_tensor_dim_idx=*/in_tensor_dim_idx, + }); + } +} + +void GenerateSymbolBindings( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors, + const std::set& symbol_names, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + for (int i = 0; i < input_tensors.size(); ++i) { + const auto& input_tensor = input_tensors.at(i); + const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + if (dim_exprs.data().has_value()) { + AppendSymbolBindings( + dim_exprs.shape(), symbol_names, i, symbol_bindings); + } + } +} + +std::vector GetMinimalInputs( + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& input_tensors) { + std::unordered_set handdled_dim_exprs; + std::unordered_set first_occurred_input_tensors; + auto TryCollectFirstOcurredInput_tensor = + [&](pir::Value input_tensor, + const std::vector& dim_exprs) { + for (const auto& dim_expr : dim_exprs) { + if (dim_expr.isa()) continue; + if (!handdled_dim_exprs.insert(dim_expr).second) { + first_occurred_input_tensors.insert(input_tensor); + } + } + }; + for (pir::Value input_tensor : input_tensors) { + const auto& shape_or_data_dim_exprs = + ShapeOrDataDimExprs4Value(input_tensor); + if (shape_or_data_dim_exprs.data().has_value()) { + TryCollectFirstOcurredInput_tensor( + input_tensor, shape_or_data_dim_exprs.data().value()); + } + TryCollectFirstOcurredInput_tensor(input_tensor, + shape_or_data_dim_exprs.shape()); + } + std::vector ret{}; + ret.reserve(input_tensors.size()); + for (pir::Value input_tensor : input_tensors) { + if (first_occurred_input_tensors.count(input_tensor) > 0) { + ret.emplace_back(input_tensor); + } + } + return ret; +} + +} // namespace + +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& out_dim_exprs, + const std::vector& origin_inputs, + std::vector* minial_inputs, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings) { + *minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs); + if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) { + VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " + "they are handled by other passes"; + return false; + } + // generate output_dim_expr_attrs + ConvertDimExprToAttributes( + ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); + // generate symbol_bindings + std::set symbol_names_in_out_dim_exprs{}; + CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); + GenerateSymbolBindings(ShapeOrDataDimExprs4Value, + *minial_inputs, + symbol_names_in_out_dim_exprs, + /*out*/ symbol_bindings); + return true; +} + } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h index ee4ad3c129e6b4..401c240f61e86f 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" @@ -46,4 +48,17 @@ MakeGetterDimExpr4SymbolName( const std::function& DimExpr4InputDim); +using ShapeOrDataDimExprs4ValueT = + std::function; + +// Returns true if success. +bool MakeGenerateShapeOpAttribute( + pir::IrContext* ir_context, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + const std::vector& out_dim_exprs, + const std::vector& origin_inputs, + std::vector* minial_inputs, + std::vector* output_dim_expr_attrs, + GenerateShapeOp::SymbolBindings* symbol_bindings); + } // namespace cinn::dialect diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 48c7427b402a14..0bff6c7daa886d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -38,9 +38,6 @@ namespace ir { namespace { -using ShapeOrDataDimExprs4ValueT = - std::function; - std::vector FindSourceDenseTensorOfDimTensor( pir::Value shape, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { @@ -86,209 +83,24 @@ std::vector FindSourceDenseTensorOfDimTensor( return ret; } -bool IsConstant(const std::vector& dim_exprs) { - for (const auto& dim_expr : dim_exprs) { - if (dim_expr.isa()) continue; - return false; - } - return true; -} - -bool IsAtomicImpl(int64_t) { return true; } - -bool IsAtomicImpl(const std::string&) { return true; } - -bool IsAtomicImpl(const symbol::Negative&) { return false; } - -bool IsAtomicImpl(const symbol::Reciprocal&) { return false; } - -bool IsAtomicImpl(const symbol::Add&) { return false; } - -bool IsAtomicImpl(const symbol::Mul&) { return false; } - -bool IsAtomicImpl(const symbol::Max&) { return false; } - -bool IsAtomicImpl(const symbol::Min&) { return false; } - -bool IsAtomicImpl(const symbol::Broadcast&) { return false; } - -bool IsAtomic(const symbol::DimExpr& dim_expr) { - return std::visit([](const auto& impl) { return IsAtomicImpl(impl); }, - dim_expr.variant()); -} - -bool InputDimExprsAllSupported( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors) { - const auto& AllSupported = - [](const std::vector& dim_exprs) -> bool { - for (const auto& dim_expr : dim_exprs) { - if (!IsAtomic(dim_expr)) return false; - } - return true; - }; - for (const auto& input_tensor : input_tensors) { - const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - if (!AllSupported(dim_exprs.shape())) return false; - if (dim_exprs.data().has_value()) { - if (!AllSupported(dim_exprs.data().value())) return false; - } - } - return true; -} - -void ConvertDimExprToAttributes(pir::IrContext* ir_context, - const std::vector& dim_exprs, - std::vector* attrs) { - attrs->clear(); - attrs->reserve(dim_exprs.size()); - for (const auto& dim_expr : dim_exprs) { - attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr)); - } -} - -void CollectSymbolNames(const symbol::DimExpr& dim_expr, - std::set* ret); - -void CollectSymbolNamesImpl(const int64_t& dim_expr, - std::set* ret) { - // do nothing. -} - -void CollectSymbolNamesImpl(const std::string& dim_expr, - std::set* ret) { - ret->insert(dim_expr); -} - -template -void CollectSymbolNamesImplForUnary(const T& dim_expr, - std::set* ret) { - const auto& [operand] = *dim_expr; - CollectSymbolNames(operand, ret); -} - -void CollectSymbolNamesImpl(const symbol::Negative& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForUnary(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Reciprocal& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForUnary(dim_expr, ret); -} - -template -void CollectSymbolNamesImplForVariadic(const T& dim_expr, - std::set* ret) { - const auto& operands = *(dim_expr.operands); - for (const auto& operand : operands) { - CollectSymbolNames(operand, ret); - } -} - -void CollectSymbolNamesImpl(const symbol::Add& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Mul& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Max& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Min& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNamesImpl(const symbol::Broadcast& dim_expr, - std::set* ret) { - CollectSymbolNamesImplForVariadic(dim_expr, ret); -} - -void CollectSymbolNames(const symbol::DimExpr& dim_expr, - std::set* ret) { - return std::visit( - [&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); }, - dim_expr.variant()); -} - -void CollectSymbolNames(const std::vector& dim_exprs, - std::set* ret) { - for (const auto& dim_expr : dim_exprs) { - CollectSymbolNames(dim_expr, ret); - } -} - -template -void AppendSymbolBindings(const std::vector& dim_exprs, - const std::set& symbol_names, - int in_tensor_idx, - GenerateShapeOp::SymbolBindings* symbol_bindings) { - for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size(); - ++in_tensor_dim_idx) { - const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx); - CHECK(IsAtomic(dim_expr)); - if (!dim_expr.isa()) continue; - const auto& sym_name = dim_expr.dyn_cast(); - if (symbol_names.find(sym_name) == symbol_names.end()) continue; - symbol_bindings->emplace_back(SymbolBindingsT{ - /*.symbol_name=*/sym_name, - /*.input_tensor_idx=*/in_tensor_idx, - /*.input_tensor_dim_idx=*/in_tensor_dim_idx, - }); - } -} - -void GenerateSymbolBindings( - const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors, - const std::set& symbol_names, - GenerateShapeOp::SymbolBindings* symbol_bindings) { - for (int i = 0; i < input_tensors.size(); ++i) { - const auto& input_tensor = input_tensors.at(i); - const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor); - AppendSymbolBindings( - dim_exprs.shape(), symbol_names, i, symbol_bindings); - if (dim_exprs.data().has_value()) { - AppendSymbolBindings( - dim_exprs.shape(), symbol_names, i, symbol_bindings); - } - } -} - bool MakeGenerateShapeOpAttribute( pir::IrContext* ir_context, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, - const std::vector& input_tensors, pir::Value output_shape, + const std::vector& origin_inputs, + std::vector* minimal_inputs, std::vector* output_dim_expr_attrs, GenerateShapeOp::SymbolBindings* symbol_bindings) { const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); CHECK(shape_or_data_dim_exprs.data().has_value()); const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); - if (IsConstant(out_dim_exprs)) return false; - if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, input_tensors)) { - VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure " - "they are handled by other passes"; - return false; - } - // generate output_dim_expr_attrs - ConvertDimExprToAttributes( - ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs); - // generate symbol_bindings - std::set symbol_names_in_out_dim_exprs{}; - CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs); - GenerateSymbolBindings(ShapeOrDataDimExprs4Value, - input_tensors, - symbol_names_in_out_dim_exprs, - /*out*/ symbol_bindings); - return true; + return MakeGenerateShapeOpAttribute(ir_context, + ShapeOrDataDimExprs4Value, + out_dim_exprs, + origin_inputs, + minimal_inputs, + output_dim_expr_attrs, + symbol_bindings); } std::optional GetOutOfRewritedGenerateShapeOp( @@ -302,8 +114,9 @@ std::optional GetOutOfRewritedGenerateShapeOp( GenerateShapeOp::SymbolBindings symbol_bindings{}; bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), ShapeOrDataDimExprs4Value, - input_tensors, shape, + /*origin inputs*/ input_tensors, + /*minimal inputs*/ &input_tensors, &output_dim_expr_attrs, &symbol_bindings); if (!success) return std::nullopt; From 99af9f790c4d191f5ed2fe9eb285beadddd9318a Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Wed, 3 Jan 2024 14:52:34 +0800 Subject: [PATCH 086/142] [Prim][PIR] Set prim gflag for pure cpp (#60505) * inference support decomp * polish code * add decomp base define * add decomp base define2 * change decomp infer * fix symbol overload * fix test case * debug * debug * decomp add debug info * add cpp flag * revert * remove unused flag --- paddle/fluid/prim/utils/utils.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc index 80721ccf1049d2..fcee9301a9aa77 100644 --- a/paddle/fluid/prim/utils/utils.cc +++ b/paddle/fluid/prim/utils/utils.cc @@ -17,10 +17,15 @@ #include "paddle/fluid/prim/utils/static/static_global_utils.h" PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not"); +PADDLE_DEFINE_EXPORTED_bool(prim_all, false, "enable prim_all or not"); +PADDLE_DEFINE_EXPORTED_bool(prim_forward, false, "enable prim_forward or not"); +PADDLE_DEFINE_EXPORTED_bool(prim_backward, false, "enable prim_backward not"); + namespace paddle { namespace prim { bool PrimCommonUtils::IsBwdPrimEnabled() { - return StaticCompositeContext::Instance().IsBwdPrimEnabled(); + bool res = StaticCompositeContext::Instance().IsBwdPrimEnabled(); + return res || FLAGS_prim_all || FLAGS_prim_backward; } void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) { @@ -36,7 +41,8 @@ void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) { } bool PrimCommonUtils::IsFwdPrimEnabled() { - return StaticCompositeContext::Instance().IsFwdPrimEnabled(); + bool res = StaticCompositeContext::Instance().IsFwdPrimEnabled(); + return res || FLAGS_prim_all || FLAGS_prim_forward; } void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { From 8e280ea5858a7d264818458a6ac984e2e9750272 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 3 Jan 2024 15:23:07 +0800 Subject: [PATCH 087/142] [PIR] Refine and fix pir exe (#60443) * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --- .../control_flow/if_instruction.cc | 10 ++-- .../control_flow/while_instruction.cc | 22 +++------ .../instruction/instruction_util.cc | 49 +++++++++++++++++++ .../instruction/instruction_util.h | 6 +++ test/legacy_test/test_while_op.py | 30 +----------- 5 files changed, 71 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc index ef856c7fc01627..624ce6221cd5e7 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc @@ -71,9 +71,8 @@ IfInstruction::IfInstruction(size_t id, GetInputIds(op, *value_exec_info, &inputs); auto true_outside_inputs = GetExternalInputs(&true_branch_block, *value_exec_info, &inputs); - std::vector false_outside_inputs; auto& false_branch_block = if_op.false_block(); - false_outside_inputs = + auto false_outside_inputs = GetExternalInputs(&false_branch_block, *value_exec_info, &inputs); // NOTE(chenxi67): the variable corresponding to container value if a // Type. It will recursively get the ID of internal @@ -107,9 +106,14 @@ IfInstruction::IfInstruction(size_t id, } } InsertTuplePushContinerToOuts(&true_branch_block, *value_exec_info, &outputs); - InsertTuplePushContinerToOuts( &if_op.false_block(), *value_exec_info, &outputs); + + InsertInplacedExternalInputsToOuts( + &true_branch_block, true_outside_inputs, *value_exec_info, &outputs); + InsertInplacedExternalInputsToOuts( + &false_branch_block, false_outside_inputs, *value_exec_info, &outputs); + for (auto& item : outputs) { auto& var_vec = item.second; for (auto it = var_vec.begin(); it != var_vec.end();) { diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc index a9f23fd60e176f..4f444b35f9b20d 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc @@ -47,33 +47,25 @@ WhileInstruction::WhileInstruction( ValueExecutionInfo* parent_exe_info, interpreter::ExecutionConfig execution_config) : InstructionBase(id, place) { - op_ = op; - VLOG(6) << "finish process dist attributes"; - - SetKernelType(AnalyseOpFuncType(op, place)); - VLOG(6) << "finish process analyse kernel type"; - - VLOG(6) << "finish process inputs outputs index"; - PADDLE_ENFORCE(op->isa(), phi::errors::PreconditionNotMet( "While instruction only support While op")); - + op_ = op; auto while_op = op->dyn_cast(); + body_block_ = &while_op.body(); - cond_var_ = parent_exe_info->GetVarByValue(while_op.operand_source(0)); + SetKernelType(AnalyseOpFuncType(op, place)); + VLOG(6) << "finish process analyse kernel type"; + cond_var_ = parent_exe_info->GetVarByValue(while_op.operand_source(0)); for (size_t i = 1; i < while_op.num_operands(); ++i) { inputs_.push_back( parent_exe_info->GetVarByValue(while_op.operand_source(i))); } - for (size_t i = 0; i < while_op.num_results(); ++i) { outputs_.push_back(parent_exe_info->GetVarByValue(while_op.result(i))); } - body_block_ = &while_op.body(); - std::unordered_map> inputs; GetInputIds(op, *parent_exe_info, &inputs); auto body_outside_inputs = @@ -94,8 +86,10 @@ WhileInstruction::WhileInstruction( std::vector outputs_id = GetValueIds(value, *parent_exe_info); outputs.emplace(value, outputs_id); } - InsertTuplePushContinerToOuts(body_block_, *parent_exe_info, &outputs); } + InsertTuplePushContinerToOuts(body_block_, *parent_exe_info, &outputs); + InsertInplacedExternalInputsToOuts( + body_block_, body_outside_inputs, *parent_exe_info, &outputs); SetOutputs(outputs); Scope* body_scope = &(parent_exe_info->GetScope()->NewScope()); diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index 55dc035b6e0638..9a28eeb39f9bc3 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -346,6 +346,55 @@ void InsertTuplePushContinerToOuts( } } +void InsertInplacedExternalInputsToOuts( + pir::Block* block, + const std::vector& external_inputs, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* outputs) { + for (auto& op : *block) { + if (op.attributes().count("is_inplace") != 0 && + op.attributes() + .at("is_inplace") + .dyn_cast() + .data()) { + std::string op_name = op.name(); + if (op.attributes().count("op_name")) { + op_name = op.attributes() + .at("op_name") + .dyn_cast() + .AsString(); + } + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + paddle::dialect::OpYamlInfoParser yaml_parser( + op_info.GetInterfaceImpl() + ->get_op_info_(op_name), + paddle::dialect::IsLegacyOp(op_name)); + + for (size_t i = 0; i < op.num_results(); ++i) { + pir::Value value = op.result(i); + if (!IsInvalid(value)) { + VLOG(8) << "Number " << i << " result of " << op_name + << " is not invalid, so skip build a variable."; + continue; + } + std::string value_name = yaml_parser.OutputNames()[i]; + if (yaml_parser.HasInplace(value_name)) { + const std::string& inplace_name = yaml_parser.InplaceName(value_name); + pir::Value inplace_value = + op.operand_source(yaml_parser.InputName2Id().at(inplace_name)); + if (std::find(external_inputs.begin(), + external_inputs.end(), + inplace_value) != external_inputs.end()) { + outputs->emplace(value, + GetValueIds(inplace_value, value_exec_info)); + } + } + } + } + } +} + bool GetCondData(const phi::DenseTensor& cond) { if (paddle::platform::is_cpu_place(cond.place())) { return cond.data()[0]; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h index e42a945a63c9f8..810209a9150c69 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h @@ -59,6 +59,12 @@ void InsertTuplePushContinerToOuts( const ValueExecutionInfo& value_exec_info, std::unordered_map>* outputs); +void InsertInplacedExternalInputsToOuts( + pir::Block* block, + const std::vector& external_inputs, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* outputs); + bool GetCondData(const phi::DenseTensor& cond); } // namespace framework } // namespace paddle diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index 63affc80d7cf4a..a8d79af8a93b6c 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import numpy from utils import compare_legacy_with_pt import paddle -from paddle import base, set_flags +from paddle import base from paddle.base import core from paddle.base.backward import append_backward from paddle.base.executor import Executor @@ -82,7 +81,6 @@ def simple_net(self): loss = paddle.mean(sum_result) return loss, sum_result - # TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False). @test_with_pir_api def test_simple_net(self): main_program = base.Program() @@ -92,14 +90,6 @@ def test_simple_net(self): append_backward(loss) - if in_pir_mode(): - flag_1 = "FLAGS_enable_pir_in_executor_trace_run" - flag_2 = "FLAGS_new_executor_serial_run" - os.environ[flag_1] = 'True' - os.environ[flag_2] = 'True' - set_flags({flag_1: True}) - set_flags({flag_2: True}) - cpu = core.CPUPlace() exe = Executor(cpu) d = [] @@ -111,14 +101,8 @@ def test_simple_net(self): feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}, fetch_list=[sum_result], ) - if in_pir_mode(): - del os.environ[flag_1] - del os.environ[flag_2] - set_flags({flag_1: False}) - set_flags({flag_2: False}) self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) - # TODO(winter-wang): Support pir test in (FLAGS_enable_pir_in_executor_trace_run = False && FLAGS_new_executor_serial_run == False). @test_with_pir_api def test_simple_net_forward(self): main_program = base.Program() @@ -136,20 +120,8 @@ def test_simple_net_forward(self): for i in range(3): d.append(numpy.random.random(size=[10]).astype('float32')) - if in_pir_mode(): - flag_1 = "FLAGS_enable_pir_in_executor_trace_run" - flag_2 = "FLAGS_new_executor_serial_run" - os.environ[flag_1] = 'True' - os.environ[flag_2] = 'True' - set_flags({flag_1: True}) - set_flags({flag_2: True}) for _ in range(2): exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) - if in_pir_mode(): - del os.environ[flag_1] - del os.environ[flag_2] - set_flags({flag_1: False}) - set_flags({flag_2: False}) @compare_legacy_with_pt @test_with_pir_api From 4b22e3e6bfd5b66d1d570e70646ef5e49cb618f8 Mon Sep 17 00:00:00 2001 From: Vigi Zhang Date: Wed, 3 Jan 2024 17:01:56 +0800 Subject: [PATCH 088/142] update 2023 security advisory, test=document_fix (#60527) --- security/README.md | 36 ++++++++++++++++++------- security/README_cn.md | 38 ++++++++++++++++++++------- security/README_ja.md | 36 ++++++++++++++++++------- security/advisory/pdsa-2023-004_cn.md | 2 +- security/advisory/pdsa-2023-006.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-006_cn.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-007.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-007_cn.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-008.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-008_cn.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-009.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-009_cn.md | 31 ++++++++++++++++++++++ security/advisory/pdsa-2023-010.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-010_cn.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-011.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-011_cn.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-012.md | 35 ++++++++++++++++++++++++ security/advisory/pdsa-2023-012_cn.md | 35 ++++++++++++++++++++++++ security/advisory/pdsa-2023-013.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-013_cn.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-014.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-014_cn.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-015.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-015_cn.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-016.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-016_cn.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-017.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-017_cn.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-018.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-018_cn.md | 32 ++++++++++++++++++++++ security/advisory/pdsa-2023-019.md | 35 ++++++++++++++++++++++++ security/advisory/pdsa-2023-019_cn.md | 35 ++++++++++++++++++++++++ security/advisory/pdsa-2023-020.md | 28 ++++++++++++++++++++ security/advisory/pdsa-2023-020_cn.md | 28 ++++++++++++++++++++ security/advisory/pdsa-2023-021.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-021_cn.md | 33 +++++++++++++++++++++++ security/advisory/pdsa-2023-022.md | 30 +++++++++++++++++++++ security/advisory/pdsa-2023-022_cn.md | 30 +++++++++++++++++++++ security/advisory/pdsa-2023-023.md | 28 ++++++++++++++++++++ security/advisory/pdsa-2023-023_cn.md | 28 ++++++++++++++++++++ 40 files changed, 1227 insertions(+), 29 deletions(-) create mode 100644 security/advisory/pdsa-2023-006.md create mode 100644 security/advisory/pdsa-2023-006_cn.md create mode 100644 security/advisory/pdsa-2023-007.md create mode 100644 security/advisory/pdsa-2023-007_cn.md create mode 100644 security/advisory/pdsa-2023-008.md create mode 100644 security/advisory/pdsa-2023-008_cn.md create mode 100644 security/advisory/pdsa-2023-009.md create mode 100644 security/advisory/pdsa-2023-009_cn.md create mode 100644 security/advisory/pdsa-2023-010.md create mode 100644 security/advisory/pdsa-2023-010_cn.md create mode 100644 security/advisory/pdsa-2023-011.md create mode 100644 security/advisory/pdsa-2023-011_cn.md create mode 100644 security/advisory/pdsa-2023-012.md create mode 100644 security/advisory/pdsa-2023-012_cn.md create mode 100644 security/advisory/pdsa-2023-013.md create mode 100644 security/advisory/pdsa-2023-013_cn.md create mode 100644 security/advisory/pdsa-2023-014.md create mode 100644 security/advisory/pdsa-2023-014_cn.md create mode 100644 security/advisory/pdsa-2023-015.md create mode 100644 security/advisory/pdsa-2023-015_cn.md create mode 100644 security/advisory/pdsa-2023-016.md create mode 100644 security/advisory/pdsa-2023-016_cn.md create mode 100644 security/advisory/pdsa-2023-017.md create mode 100644 security/advisory/pdsa-2023-017_cn.md create mode 100644 security/advisory/pdsa-2023-018.md create mode 100644 security/advisory/pdsa-2023-018_cn.md create mode 100644 security/advisory/pdsa-2023-019.md create mode 100644 security/advisory/pdsa-2023-019_cn.md create mode 100644 security/advisory/pdsa-2023-020.md create mode 100644 security/advisory/pdsa-2023-020_cn.md create mode 100644 security/advisory/pdsa-2023-021.md create mode 100644 security/advisory/pdsa-2023-021_cn.md create mode 100644 security/advisory/pdsa-2023-022.md create mode 100644 security/advisory/pdsa-2023-022_cn.md create mode 100644 security/advisory/pdsa-2023-023.md create mode 100644 security/advisory/pdsa-2023-023_cn.md diff --git a/security/README.md b/security/README.md index 01559632d7dd45..9bcc28bc318956 100644 --- a/security/README.md +++ b/security/README.md @@ -7,12 +7,30 @@ We regularly publish security advisories about using PaddlePaddle. *Note*: In conjunction with these security advisories, we strongly encourage PaddlePaddle users to read and understand PaddlePaddle's security model as outlined in [SECURITY.md](../SECURITY.md). -| Advisory Number | Type | Versions affected | Reported by | Additional Information | -|----------------------------------------------|------------------------------------------------------|:-----------------:|------------------------------------------------------------------|------------------------| -| [PDSA-2023-005](./advisory/pdsa-2023-005.md) | Command injection in fs.py | < 2.5.0 | Xiaochen Guo from Huazhong University of Science and Technology | | -| [PDSA-2023-004](./advisory/pdsa-2023-004.md) | FPE in paddle.linalg.matrix_power | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-003](./advisory/pdsa-2023-003.md) | Heap buffer overflow in paddle.trace | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-002](./advisory/pdsa-2023-002.md) | Null pointer dereference in paddle.flip | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-001](./advisory/pdsa-2023-001.md) | Use after free in paddle.diagonal | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2022-002](./advisory/pdsa-2022-002.md) | Code injection in paddle.audio.functional.get_window | = 2.4.0-rc0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2022-001](./advisory/pdsa-2022-001.md) | OOB read in gather_tree | < 2.4 | Wang Xuan(王旋) of Qihoo 360 AIVul Team | | +| Advisory Number | Type | Versions affected | Reported by | Additional Information | +|----------------------------------------------|------------------------------------------------------|:-----------------:|-----------------------------------------------------------------|------------------------| +| [PDSA-2023-023](./advisory/pdsa-2023-023.md) | Command injection in convert_shape_compare | < 2.6.0 | leeya_bug | | +| [PDSA-2023-022](./advisory/pdsa-2023-022.md) | FPE in paddle.argmin and paddle.argmax | < 2.6.0 | Peng Zhou (zpbrent) from Shanghai University | | +| [PDSA-2023-021](./advisory/pdsa-2023-021.md) | Null pointer dereference in paddle.crop | < 2.6.0 | Peng Zhou (zpbrent) from Shanghai University | | +| [PDSA-2023-020](./advisory/pdsa-2023-020.md) | Command injection in _wget_download | < 2.6.0 | huntr.com | | +| [PDSA-2023-019](./advisory/pdsa-2023-019.md) | Command injection in get_online_pass_interval | < 2.6.0 | huntr.com | | +| [PDSA-2023-018](./advisory/pdsa-2023-018.md) | Heap buffer overflow in paddle.repeat_interleave | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-017](./advisory/pdsa-2023-017.md) | FPE in paddle.amin | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-016](./advisory/pdsa-2023-016.md) | Stack overflow in paddle.linalg.lu_unpack | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-015](./advisory/pdsa-2023-015.md) | FPE in paddle.lerp | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-014](./advisory/pdsa-2023-014.md) | FPE in paddle.topk | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-013](./advisory/pdsa-2023-013.md) | Stack overflow in paddle.searchsorted | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-012](./advisory/pdsa-2023-012.md) | Segfault in paddle.put_along_axis | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-011](./advisory/pdsa-2023-011.md) | Null pointer dereference in paddle.nextafter | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-010](./advisory/pdsa-2023-010.md) | Segfault in paddle.mode | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-009](./advisory/pdsa-2023-009.md) | FPE in paddle.linalg.eig | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-008](./advisory/pdsa-2023-008.md) | Segfault in paddle.dot | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-007](./advisory/pdsa-2023-007.md) | FPE in paddle.linalg.matrix_rank | < 2.6.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-006](./advisory/pdsa-2023-006.md) | FPE in paddle.nanmedian | < 2.6.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-005](./advisory/pdsa-2023-005.md) | Command injection in fs.py | < 2.5.0 | Xiaochen Guo from Huazhong University of Science and Technology | | +| [PDSA-2023-004](./advisory/pdsa-2023-004.md) | FPE in paddle.linalg.matrix_power | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-003](./advisory/pdsa-2023-003.md) | Heap buffer overflow in paddle.trace | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-002](./advisory/pdsa-2023-002.md) | Null pointer dereference in paddle.flip | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-001](./advisory/pdsa-2023-001.md) | Use after free in paddle.diagonal | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2022-002](./advisory/pdsa-2022-002.md) | Code injection in paddle.audio.functional.get_window | = 2.4.0-rc0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2022-001](./advisory/pdsa-2022-001.md) | OOB read in gather_tree | < 2.4 | Wang Xuan(王旋) of Qihoo 360 AIVul Team | | diff --git a/security/README_cn.md b/security/README_cn.md index 49223df8844f39..0cd8a9743b5be2 100644 --- a/security/README_cn.md +++ b/security/README_cn.md @@ -4,15 +4,33 @@ -注:我们非常建议飞桨用户阅读和理解[SECURITY_cn.md](../SECURITY_cn.md)所介绍的飞桨安全模型,以便更好地了解此安全公告。 +*注*:我们非常建议飞桨用户阅读和理解[SECURITY_cn.md](../SECURITY_cn.md)所介绍的飞桨安全模型,以便更好地了解此安全公告。 -| 安全公告编号 | 类型 | 受影响版本 | 报告者 | 备注 | -|-------------------------------------------------|------------------------------------------------------|:------------:|-----------------------------------------------------------------|----| -| [PDSA-2023-005](./advisory/pdsa-2023-005_cn.md) | Command injection in fs.py | < 2.5.0 | Xiaochen Guo from Huazhong University of Science and Technology | | -| [PDSA-2023-004](./advisory/pdsa-2023-004_cn.md) | FPE in paddle.linalg.matrix_power | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-003](./advisory/pdsa-2023-003_cn.md) | Heap buffer overflow in paddle.trace | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-002](./advisory/pdsa-2023-002_cn.md) | Null pointer dereference in paddle.flip | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-001](./advisory/pdsa-2023-001_cn.md) | Use after free in paddle.diagonal | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2022-002](./advisory/pdsa-2022-002_cn.md) | Code injection in paddle.audio.functional.get_window | = 2.4.0-rc0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2022-001](./advisory/pdsa-2022-001_cn.md) | OOB read in gather_tree | < 2.4 | Wang Xuan(王旋) of Qihoo 360 AIVul Team | | +| 安全公告编号 | 类型 | 受影响版本 | 报告者 | 备注 | +|-------------------------------------------------|------------------------------------------------------|:-----------:|-----------------------------------------------------------------|----| +| [PDSA-2023-023](./advisory/pdsa-2023-023_cn.md) | Command injection in convert_shape_compare | < 2.6.0 | leeya_bug | | +| [PDSA-2023-022](./advisory/pdsa-2023-022_cn.md) | FPE in paddle.argmin and paddle.argmax | < 2.6.0 | Peng Zhou (zpbrent) from Shanghai University | | +| [PDSA-2023-021](./advisory/pdsa-2023-021_cn.md) | Null pointer dereference in paddle.crop | < 2.6.0 | Peng Zhou (zpbrent) from Shanghai University | | +| [PDSA-2023-020](./advisory/pdsa-2023-020_cn.md) | Command injection in _wget_download | < 2.6.0 | huntr.com | | +| [PDSA-2023-019](./advisory/pdsa-2023-019_cn.md) | Command injection in get_online_pass_interval | < 2.6.0 | huntr.com | | +| [PDSA-2023-018](./advisory/pdsa-2023-018_cn.md) | Heap buffer overflow in paddle.repeat_interleave | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-017](./advisory/pdsa-2023-017_cn.md) | FPE in paddle.amin | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-016](./advisory/pdsa-2023-016_cn.md) | Stack overflow in paddle.linalg.lu_unpack | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-015](./advisory/pdsa-2023-015_cn.md) | FPE in paddle.lerp | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-014](./advisory/pdsa-2023-014_cn.md) | FPE in paddle.topk | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-013](./advisory/pdsa-2023-013_cn.md) | Stack overflow in paddle.searchsorted | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-012](./advisory/pdsa-2023-012_cn.md) | Segfault in paddle.put_along_axis | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-011](./advisory/pdsa-2023-011_cn.md) | Null pointer dereference in paddle.nextafter | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-010](./advisory/pdsa-2023-010_cn.md) | Segfault in paddle.mode | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-009](./advisory/pdsa-2023-009_cn.md) | FPE in paddle.linalg.eig | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-008](./advisory/pdsa-2023-008_cn.md) | Segfault in paddle.dot | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-007](./advisory/pdsa-2023-007_cn.md) | FPE in paddle.linalg.matrix_rank | < 2.6.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-006](./advisory/pdsa-2023-006_cn.md) | FPE in paddle.nanmedian | < 2.6.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-005](./advisory/pdsa-2023-005_cn.md) | Command injection in fs.py | < 2.5.0 | Xiaochen Guo from Huazhong University of Science and Technology | | +| [PDSA-2023-004](./advisory/pdsa-2023-004_cn.md) | FPE in paddle.linalg.matrix_power | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-003](./advisory/pdsa-2023-003_cn.md) | Heap buffer overflow in paddle.trace | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-002](./advisory/pdsa-2023-002_cn.md) | Null pointer dereference in paddle.flip | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-001](./advisory/pdsa-2023-001_cn.md) | Use after free in paddle.diagonal | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2022-002](./advisory/pdsa-2022-002_cn.md) | Code injection in paddle.audio.functional.get_window | = 2.4.0-rc0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2022-001](./advisory/pdsa-2022-001_cn.md) | OOB read in gather_tree | < 2.4 | Wang Xuan(王旋) of Qihoo 360 AIVul Team | | diff --git a/security/README_ja.md b/security/README_ja.md index 4bd0b984c5834c..1841cfe8aa6fb1 100644 --- a/security/README_ja.md +++ b/security/README_ja.md @@ -7,12 +7,30 @@ PaddlePaddle の使用に関するセキュリティ勧告を定期的に発表 *注*: これらのセキュリティ勧告と併せ、PaddlePaddle ユーザーには [SECURITY.md](../SECURITY_ja.md) に記載されている PaddlePaddle のセキュリティモデルを読み、理解することを強くお勧めします。 -| アドバイザリー番号 | タイプ | 対象バージョン | 報告者 | 追加情報 | -|----------------------------------------------|------------------------------------------------------|:-----------------:|------------------------------------------------------------------|------------------------| -| [PDSA-2023-005](./advisory/pdsa-2023-005.md) | Command injection in fs.py | < 2.5.0 | Xiaochen Guo from Huazhong University of Science and Technology | | -| [PDSA-2023-004](./advisory/pdsa-2023-004.md) | FPE in paddle.linalg.matrix_power | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-003](./advisory/pdsa-2023-003.md) | Heap buffer overflow in paddle.trace | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-002](./advisory/pdsa-2023-002.md) | Null pointer dereference in paddle.flip | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2023-001](./advisory/pdsa-2023-001.md) | Use after free in paddle.diagonal | < 2.5.0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2022-002](./advisory/pdsa-2022-002.md) | Code injection in paddle.audio.functional.get_window | = 2.4.0-rc0 | Tong Liu of ShanghaiTech University | | -| [PDSA-2022-001](./advisory/pdsa-2022-001.md) | OOB read in gather_tree | < 2.4 | Wang Xuan(王旋) of Qihoo 360 AIVul Team | | +| アドバイザリー番号 | タイプ | 対象バージョン | 報告者 | 追加情報 | +|----------------------------------------------|------------------------------------------------------|:-----------:|-----------------------------------------------------------------|------| +| [PDSA-2023-023](./advisory/pdsa-2023-023.md) | Command injection in convert_shape_compare | < 2.6.0 | leeya_bug | | +| [PDSA-2023-022](./advisory/pdsa-2023-022.md) | FPE in paddle.argmin and paddle.argmax | < 2.6.0 | Peng Zhou (zpbrent) from Shanghai University | | +| [PDSA-2023-021](./advisory/pdsa-2023-021.md) | Null pointer dereference in paddle.crop | < 2.6.0 | Peng Zhou (zpbrent) from Shanghai University | | +| [PDSA-2023-020](./advisory/pdsa-2023-020.md) | Command injection in _wget_download | < 2.6.0 | huntr.com | | +| [PDSA-2023-019](./advisory/pdsa-2023-019.md) | Command injection in get_online_pass_interval | < 2.6.0 | huntr.com | | +| [PDSA-2023-018](./advisory/pdsa-2023-018.md) | Heap buffer overflow in paddle.repeat_interleave | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-017](./advisory/pdsa-2023-017.md) | FPE in paddle.amin | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-016](./advisory/pdsa-2023-016.md) | Stack overflow in paddle.linalg.lu_unpack | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-015](./advisory/pdsa-2023-015.md) | FPE in paddle.lerp | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-014](./advisory/pdsa-2023-014.md) | FPE in paddle.topk | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-013](./advisory/pdsa-2023-013.md) | Stack overflow in paddle.searchsorted | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-012](./advisory/pdsa-2023-012.md) | Segfault in paddle.put_along_axis | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-011](./advisory/pdsa-2023-011.md) | Null pointer dereference in paddle.nextafter | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-010](./advisory/pdsa-2023-010.md) | Segfault in paddle.mode | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-009](./advisory/pdsa-2023-009.md) | FPE in paddle.linalg.eig | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-008](./advisory/pdsa-2023-008.md) | Segfault in paddle.dot | < 2.6.0 | Tong Liu of CAS-IIE | | +| [PDSA-2023-007](./advisory/pdsa-2023-007.md) | FPE in paddle.linalg.matrix_rank | < 2.6.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-006](./advisory/pdsa-2023-006.md) | FPE in paddle.nanmedian | < 2.6.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-005](./advisory/pdsa-2023-005.md) | Command injection in fs.py | < 2.5.0 | Xiaochen Guo from Huazhong University of Science and Technology | | +| [PDSA-2023-004](./advisory/pdsa-2023-004.md) | FPE in paddle.linalg.matrix_power | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-003](./advisory/pdsa-2023-003.md) | Heap buffer overflow in paddle.trace | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-002](./advisory/pdsa-2023-002.md) | Null pointer dereference in paddle.flip | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2023-001](./advisory/pdsa-2023-001.md) | Use after free in paddle.diagonal | < 2.5.0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2022-002](./advisory/pdsa-2022-002.md) | Code injection in paddle.audio.functional.get_window | = 2.4.0-rc0 | Tong Liu of ShanghaiTech University | | +| [PDSA-2022-001](./advisory/pdsa-2022-001.md) | OOB read in gather_tree | < 2.4 | Wang Xuan(王旋) of Qihoo 360 AIVul Team | | diff --git a/security/advisory/pdsa-2023-004_cn.md b/security/advisory/pdsa-2023-004_cn.md index c31c4da4f8728f..11f22a45aca11c 100644 --- a/security/advisory/pdsa-2023-004_cn.md +++ b/security/advisory/pdsa-2023-004_cn.md @@ -6,7 +6,7 @@ CVE-2023-38672 ### 影响 -当张量包含纬度值为0的情况,`paddle.linalg.matrix_power`会触发除0异常,导致程序运行时崩溃,PoC代码如下: +当张量包含维度值为0的情况,`paddle.linalg.matrix_power`会触发除0异常,导致程序运行时崩溃,PoC代码如下: ```python import paddle diff --git a/security/advisory/pdsa-2023-006.md b/security/advisory/pdsa-2023-006.md new file mode 100644 index 00000000000000..4997760cd5000a --- /dev/null +++ b/security/advisory/pdsa-2023-006.md @@ -0,0 +1,31 @@ +## PDSA-2023-006: FPE in paddle.nanmedian + +### CVE Number + +CVE-2023-38674 + +### Impact + +When `x` dim calculates `stride` to 0, `paddle.nanmedian` triggers FPE by `numel / stride`. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = np.random.uniform(0,0,[0,0,0,0,0]).astype(np.float32) +x = paddle.to_tensor(x) +paddle.nanmedian(x) +``` + +### Patches + +We have patched the issue in commit [9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1](https://github.com/PaddlePaddle/Paddle/pull/55644/commits/9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of ShanghaiTech University. diff --git a/security/advisory/pdsa-2023-006_cn.md b/security/advisory/pdsa-2023-006_cn.md new file mode 100644 index 00000000000000..e8ac803c033d6a --- /dev/null +++ b/security/advisory/pdsa-2023-006_cn.md @@ -0,0 +1,31 @@ +## PDSA-2023-006: FPE in paddle.nanmedian + +### CVE编号 + +CVE-2023-38674 + +### 影响 + +当由`x`的dim计算的`stride`为0时,`paddle.nanmedian`会由`numel / stride`触发除0异常,PoC代码如下: + +```python +import paddle +import numpy as np + +x = np.random.uniform(0,0,[0,0,0,0,0]).astype(np.float32) +x = paddle.to_tensor(x) +paddle.nanmedian(x) +``` + +### 补丁 + +我们在commit [9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1](https://github.com/PaddlePaddle/Paddle/pull/55644/commits/9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of ShanghaiTech University 提交。 diff --git a/security/advisory/pdsa-2023-007.md b/security/advisory/pdsa-2023-007.md new file mode 100644 index 00000000000000..f61223193cabfe --- /dev/null +++ b/security/advisory/pdsa-2023-007.md @@ -0,0 +1,31 @@ +## PDSA-2023-007: FPE in paddle.linalg.matrix_rank + +### CVE Number + +CVE-2023-38675 + +### Impact + +When `x` dim calculates `rows` or `cols` to 0, `paddle.linalg.matrix_rank` triggers FPE by `numel / (rows * cols)`. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = np.random.uniform(0,0,[0,0,0,0,0]).astype(np.float32) +x = paddle.to_tensor(x) +paddle.linalg.matrix_rank(x) +``` + +### Patches + +We have patched the issue in commit [9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1](https://github.com/PaddlePaddle/Paddle/pull/55644/commits/9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of ShanghaiTech University. diff --git a/security/advisory/pdsa-2023-007_cn.md b/security/advisory/pdsa-2023-007_cn.md new file mode 100644 index 00000000000000..0572aa1767b36d --- /dev/null +++ b/security/advisory/pdsa-2023-007_cn.md @@ -0,0 +1,31 @@ +## PDSA-2023-007: FPE in paddle.linalg.matrix_rank + +### CVE编号 + +CVE-2023-38675 + +### 影响 + +当由`x`的dim计算的`rows`或者`cols`为0时,`paddle.linalg.matrix_rank`会由`numel / (rows * cols)`触发除0异常,PoC代码如下: + +```python +import paddle +import numpy as np + +x = np.random.uniform(0,0,[0,0,0,0,0]).astype(np.float32) +x = paddle.to_tensor(x) +paddle.linalg.matrix_rank(x) +``` + +### 补丁 + +我们在commit [9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1](https://github.com/PaddlePaddle/Paddle/pull/55644/commits/9bb6c669206c4bcc3ce3f6daf8a55650e190c1a1)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of ShanghaiTech University 提交。 diff --git a/security/advisory/pdsa-2023-008.md b/security/advisory/pdsa-2023-008.md new file mode 100644 index 00000000000000..8994abd90fc23e --- /dev/null +++ b/security/advisory/pdsa-2023-008.md @@ -0,0 +1,31 @@ +## PDSA-2023-008: Segfault in paddle.dot + +### CVE Number + +CVE-2023-38676 + +### Impact + +Segfault occurs when `x` and `y` shape is 0 in `paddle.dot`. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0]).astype(np.float32)) +y = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0]).astype(np.float32)) +paddle.dot(x, y) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-008_cn.md b/security/advisory/pdsa-2023-008_cn.md new file mode 100644 index 00000000000000..92052de2f38090 --- /dev/null +++ b/security/advisory/pdsa-2023-008_cn.md @@ -0,0 +1,31 @@ +## PDSA-2023-008: Segfault in paddle.dot + +### CVE编号 + +CVE-2023-38676 + +### 影响 + +在`paddle.dot`中当`x`和`y`的shape为0时,将造成segfault,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0]).astype(np.float32)) +y = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0]).astype(np.float32)) +paddle.dot(x, y) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-009.md b/security/advisory/pdsa-2023-009.md new file mode 100644 index 00000000000000..2f0450f9eb4e32 --- /dev/null +++ b/security/advisory/pdsa-2023-009.md @@ -0,0 +1,31 @@ +## PDSA-2023-009: FPE in paddle.linalg.eig + +### CVE Number + +CVE-2023-38677 + +### Impact + +When tensor dims contain 0, `paddle.linalg.eig` will trigger a float point exception. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [3, 6, 0, 2, 2]).astype(np.float32)) + +paddle.linalg.eig(x) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-009_cn.md b/security/advisory/pdsa-2023-009_cn.md new file mode 100644 index 00000000000000..a212a2320c8902 --- /dev/null +++ b/security/advisory/pdsa-2023-009_cn.md @@ -0,0 +1,31 @@ +## PDSA-2023-009: FPE in paddle.linalg.eig + +### CVE编号 + +CVE-2023-38677 + +### 影响 + +当张量包含维度值为0的情况,`paddle.linalg.eig`会触发除0异常,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [3, 6, 0, 2, 2]).astype(np.float32)) + +paddle.linalg.eig(x) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-010.md b/security/advisory/pdsa-2023-010.md new file mode 100644 index 00000000000000..3f1c65f6c91c4f --- /dev/null +++ b/security/advisory/pdsa-2023-010.md @@ -0,0 +1,33 @@ +## PDSA-2023-010: Segfault in paddle.mode + +### CVE Number + +CVE-2023-38678 + +### Impact + +Invalid `axis` and `dim_size` may cause `paddle.mode` segfault . The PoC is as follows: + +```python +import paddle +import numpy as np + +paddle.mode( + x=paddle.to_tensor(np.random.uniform(-6666666, 100000000, []).astype(np.float64)), + axis=paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, []).astype(np.int32)), + keepdim=True +) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-010_cn.md b/security/advisory/pdsa-2023-010_cn.md new file mode 100644 index 00000000000000..f72cd8af856360 --- /dev/null +++ b/security/advisory/pdsa-2023-010_cn.md @@ -0,0 +1,33 @@ +## PDSA-2023-010: Segfault in paddle.mode + +### CVE编号 + +CVE-2023-38678 + +### 影响 + +接收异常的`axis`和`dim_size`可能会造成`paddle.mode`发生segfault,PoC代码如下: + +```python +import paddle +import numpy as np + +paddle.mode( + x=paddle.to_tensor(np.random.uniform(-6666666, 100000000, []).astype(np.float64)), + axis=paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, []).astype(np.int32)), + keepdim=True +) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-011.md b/security/advisory/pdsa-2023-011.md new file mode 100644 index 00000000000000..da7985dede7d00 --- /dev/null +++ b/security/advisory/pdsa-2023-011.md @@ -0,0 +1,32 @@ +## PDSA-2023-011: Null pointer dereference in paddle.nextafter + +### CVE Number + +CVE-2023-52302 + +### Impact + +Null pointer dereference in `paddle.nextafter` when tensor dims are invalid . The PoC is as follows: + +```python +import paddle +import numpy as np + +paddle.nextafter( + x=paddle.to_tensor(np.random.uniform(-6666666, 100000000, [1, 2]).astype(np.float32)), + y=paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0, 0, 0, 0]).astype(np.float32)) +) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-011_cn.md b/security/advisory/pdsa-2023-011_cn.md new file mode 100644 index 00000000000000..71440ac2c5d9a2 --- /dev/null +++ b/security/advisory/pdsa-2023-011_cn.md @@ -0,0 +1,32 @@ +## PDSA-2023-011: Null pointer dereference in paddle.nextafter + +### CVE编号 + +CVE-2023-52302 + +### 影响 + +输入张量的维度异常时,`paddle.nextafter`会引发空指针解引用,PoC代码如下: + +```python +import paddle +import numpy as np + +paddle.nextafter( + x=paddle.to_tensor(np.random.uniform(-6666666, 100000000, [1, 2]).astype(np.float32)), + y=paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0, 0, 0, 0]).astype(np.float32)) +) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-012.md b/security/advisory/pdsa-2023-012.md new file mode 100644 index 00000000000000..f659d356154474 --- /dev/null +++ b/security/advisory/pdsa-2023-012.md @@ -0,0 +1,35 @@ +## PDSA-2023-012: Segfault in paddle.put_along_axis + +### CVE Number + +CVE-2023-52303 + +### Impact + +Segfault in `paddle.put_along_axis` when tensor dims are invalid . The PoC is as follows: + +```python +import paddle +import numpy as np + +paddle.put_along_axis( + arr=paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, [1]).astype(np.int32)), + indices=paddle.to_tensor(np.random.uniform(-9223372036854775808, 9223372036854775807, [1]).astype(np.int64)), + values=paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, []).astype(np.int32)), + axis=0, + reduce="assign" +) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-012_cn.md b/security/advisory/pdsa-2023-012_cn.md new file mode 100644 index 00000000000000..234961cded2359 --- /dev/null +++ b/security/advisory/pdsa-2023-012_cn.md @@ -0,0 +1,35 @@ +## PDSA-2023-012: Segfault in paddle.put_along_axis + +### CVE编号 + +CVE-2023-52303 + +### 影响 + +输入张量的维度异常时,`paddle.put_along_axis`会引发segfault,PoC代码如下: + +```python +import paddle +import numpy as np + +paddle.put_along_axis( + arr=paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, [1]).astype(np.int32)), + indices=paddle.to_tensor(np.random.uniform(-9223372036854775808, 9223372036854775807, [1]).astype(np.int64)), + values=paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, []).astype(np.int32)), + axis=0, + reduce="assign" +) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-013.md b/security/advisory/pdsa-2023-013.md new file mode 100644 index 00000000000000..53deab6f3c346a --- /dev/null +++ b/security/advisory/pdsa-2023-013.md @@ -0,0 +1,32 @@ +## PDSA-2023-013: Stack overflow in paddle.searchsorted + +### CVE Number + +CVE-2023-52304 + +### Impact + +Invalid shapes cuase stack buffer overflow in `paddle.searchsorted`. The PoC is as follows: + +```python +import paddle +import numpy as np + +sorted_sequence = paddle.to_tensor(np.array(0)) +values = paddle.to_tensor(np.random.uniform(-10, 10, []).astype(np.float64)) + +paddle.searchsorted(sorted_sequence, values, out_int32=True, right=True) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-013_cn.md b/security/advisory/pdsa-2023-013_cn.md new file mode 100644 index 00000000000000..c5210242f651fd --- /dev/null +++ b/security/advisory/pdsa-2023-013_cn.md @@ -0,0 +1,32 @@ +## PDSA-2023-013: Stack overflow in paddle.searchsorted + +### CVE编号 + +CVE-2023-52304 + +### 影响 + +不正确的shapes会引发`paddle.searchsorted`栈溢出,PoC代码如下: + +```python +import paddle +import numpy as np + +sorted_sequence = paddle.to_tensor(np.array(0)) +values = paddle.to_tensor(np.random.uniform(-10, 10, []).astype(np.float64)) + +paddle.searchsorted(sorted_sequence, values, out_int32=True, right=True) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-014.md b/security/advisory/pdsa-2023-014.md new file mode 100644 index 00000000000000..1792f3b21e8fac --- /dev/null +++ b/security/advisory/pdsa-2023-014.md @@ -0,0 +1,32 @@ +## PDSA-2023-014: FPE in paddle.topk + +### CVE Number + +CVE-2023-52305 + +### Impact + +FPE in `paddle.topk` when `x` and `k` dims not correct. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [6, 2, 1, 4, 2, 0]).astype(np.float64)) +k = paddle.to_tensor(np.array(1).astype(np.int32)) + +paddle.topk(x, k, axis=2,largest=False, sorted=True) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-014_cn.md b/security/advisory/pdsa-2023-014_cn.md new file mode 100644 index 00000000000000..d1be63be148d21 --- /dev/null +++ b/security/advisory/pdsa-2023-014_cn.md @@ -0,0 +1,32 @@ +## PDSA-2023-014: FPE in paddle.topk + +### CVE编号 + +CVE-2023-52305 + +### 影响 + +当`x`和`k`的dims不符合要求时,可能导致`paddle.topk`除0异常,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [6, 2, 1, 4, 2, 0]).astype(np.float64)) +k = paddle.to_tensor(np.array(1).astype(np.int32)) + +paddle.topk(x, k, axis=2,largest=False, sorted=True) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-015.md b/security/advisory/pdsa-2023-015.md new file mode 100644 index 00000000000000..6830516e0505b6 --- /dev/null +++ b/security/advisory/pdsa-2023-015.md @@ -0,0 +1,33 @@ +## PDSA-2023-015: FPE in paddle.lerp + +### CVE Number + +CVE-2023-52306 + +### Impact + +FPE in `paddle.lerp` when tensor shape is invalid. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, []).astype(np.float64)) +y = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [4, 0, 0, 2, 6]).astype(np.float64)) +weight = paddle.to_tensor(np.random.uniform(-6666666, 100000000, []).astype(np.float64)) + +paddle.lerp(x, y, weight) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-015_cn.md b/security/advisory/pdsa-2023-015_cn.md new file mode 100644 index 00000000000000..7daa17bfff490b --- /dev/null +++ b/security/advisory/pdsa-2023-015_cn.md @@ -0,0 +1,33 @@ +## PDSA-2023-015: FPE in paddle.lerp + +### CVE编号 + +CVE-2023-52306 + +### 影响 + +不合法的张量shape可能导致`paddle.lerp`除0异常,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, []).astype(np.float64)) +y = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [4, 0, 0, 2, 6]).astype(np.float64)) +weight = paddle.to_tensor(np.random.uniform(-6666666, 100000000, []).astype(np.float64)) + +paddle.lerp(x, y, weight) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-016.md b/security/advisory/pdsa-2023-016.md new file mode 100644 index 00000000000000..2c6e93e3f87717 --- /dev/null +++ b/security/advisory/pdsa-2023-016.md @@ -0,0 +1,32 @@ +## PDSA-2023-016: Stack overflow in paddle.linalg.lu_unpack + +### CVE Number + +CVE-2023-52307 + +### Impact + +Invalid shapes cuase stack buffer overflow in `paddle.linalg.lu_unpack`. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [1, 6, 4, 8, 2]).astype(np.float32)) +y = paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, []).astype(np.int32)) + +paddle.linalg.lu_unpack(x, y, True, True) +``` + +### Patches + +We have patched the issue in commit [10093636a10f29f73f13729b33570d8cafd58fb6](https://github.com/PaddlePaddle/Paddle/pull/56311/commits/10093636a10f29f73f13729b33570d8cafd58fb6). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-016_cn.md b/security/advisory/pdsa-2023-016_cn.md new file mode 100644 index 00000000000000..cdad03e02dce4a --- /dev/null +++ b/security/advisory/pdsa-2023-016_cn.md @@ -0,0 +1,32 @@ +## PDSA-2023-016: Stack overflow in paddle.linalg.lu_unpack + +### CVE编号 + +CVE-2023-52307 + +### 影响 + +不正确的shapes会引发`paddle.linalg.lu_unpack`栈溢出,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [1, 6, 4, 8, 2]).astype(np.float32)) +y = paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, []).astype(np.int32)) + +paddle.linalg.lu_unpack(x, y, True, True) +``` + +### 补丁 + +我们在commit [10093636a10f29f73f13729b33570d8cafd58fb6](https://github.com/PaddlePaddle/Paddle/pull/56311/commits/10093636a10f29f73f13729b33570d8cafd58fb6)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-017.md b/security/advisory/pdsa-2023-017.md new file mode 100644 index 00000000000000..2d65947f7be858 --- /dev/null +++ b/security/advisory/pdsa-2023-017.md @@ -0,0 +1,33 @@ +## PDSA-2023-017: FPE in paddle.amin + +### CVE Number + +CVE-2023-52308 + +### Impact + +FPE in `paddle.amin` when `x` has invalid dims. The PoC is as follows: + +```python +import paddle +import numpy as np + +paddle.amin( + x=paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0, 6, 3]).astype(np.float32)), + axis=-1, + keepdim=True +) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-017_cn.md b/security/advisory/pdsa-2023-017_cn.md new file mode 100644 index 00000000000000..ac04896e1ffeb4 --- /dev/null +++ b/security/advisory/pdsa-2023-017_cn.md @@ -0,0 +1,33 @@ +## PDSA-2023-017: FPE in paddle.amin + +### CVE编号 + +CVE-2023-52308 + +### 影响 + +当`x` dims不符合要求时,可能导致`paddle.amin`除0异常,PoC代码如下: + +```python +import paddle +import numpy as np + +paddle.amin( + x=paddle.to_tensor(np.random.uniform(-6666666, 100000000, [0, 0, 6, 3]).astype(np.float32)), + axis=-1, + keepdim=True +) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-018.md b/security/advisory/pdsa-2023-018.md new file mode 100644 index 00000000000000..6dbec29738b2f8 --- /dev/null +++ b/security/advisory/pdsa-2023-018.md @@ -0,0 +1,32 @@ +## PDSA-2023-018: Heap buffer overflow in paddle.repeat_interleave + +### CVE Number + +CVE-2023-52309 + +### Impact + +Heap buffer overflow in `paddle.repeat_interleave` by using invalid params. The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [4, 4, 8, 3, 2, 4]).astype(np.float64)) +repeats = paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, [2, 1]).astype(np.int32)) + +paddle.repeat_interleave(x, repeats, axis=-2) +``` + +### Patches + +We have patched the issue in commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Tong Liu of CAS-IIE. diff --git a/security/advisory/pdsa-2023-018_cn.md b/security/advisory/pdsa-2023-018_cn.md new file mode 100644 index 00000000000000..9680099b47d83c --- /dev/null +++ b/security/advisory/pdsa-2023-018_cn.md @@ -0,0 +1,32 @@ +## PDSA-2023-018: Heap buffer overflow in paddle.repeat_interleave + +### CVE编号 + +CVE-2023-52309 + +### 影响 + +非法的参数可能导致`paddle.repeat_interleave`堆溢出,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(-6666666, 100000000, [4, 4, 8, 3, 2, 4]).astype(np.float64)) +repeats = paddle.to_tensor(np.random.uniform(-2147483648, 2147483647, [2, 1]).astype(np.int32)) + +paddle.repeat_interleave(x, repeats, axis=-2) +``` + +### 补丁 + +我们在commit [19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc](https://github.com/PaddlePaddle/Paddle/commit/19da5c0c4d8c5e4dfef2a92e24141c3f51884dcc)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Tong Liu of CAS-IIE 提交。 diff --git a/security/advisory/pdsa-2023-019.md b/security/advisory/pdsa-2023-019.md new file mode 100644 index 00000000000000..c496895190bc81 --- /dev/null +++ b/security/advisory/pdsa-2023-019.md @@ -0,0 +1,35 @@ +## PDSA-2023-019: Command injection in get_online_pass_interval + +### CVE Number + +CVE-2023-52310 + +### Impact + +Command injection in `get_online_pass_interval` which could lead to execute arbitrary commands. The PoC is as follows: + +```python +from paddle.incubate.distributed.fleet.fleet_util import FleetUtil + +fleet_util = FleetUtil() +online_pass_interval = fleet_util.get_online_pass_interval( + days="{20190720..20190729}", + hours="9;touch /home/test/aaaa", + split_interval=5, + split_per_pass=2, + is_data_hourly_placed=False +) +``` + +### Patches + +We have patched the issue in commit [1aae481dfd7d2055c801563e254f1484b974b68e](https://github.com/PaddlePaddle/Paddle/pull/60023/commits/1aae481dfd7d2055c801563e254f1484b974b68e). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by huntr.com. diff --git a/security/advisory/pdsa-2023-019_cn.md b/security/advisory/pdsa-2023-019_cn.md new file mode 100644 index 00000000000000..8bab64810ad416 --- /dev/null +++ b/security/advisory/pdsa-2023-019_cn.md @@ -0,0 +1,35 @@ +## PDSA-2023-019: Command injection in get_online_pass_interval + +### CVE编号 + +CVE-2023-52310 + +### 影响 + +`get_online_pass_interval`存在命令注入漏洞,可造成任意命令执行,PoC代码如下: + +```python +from paddle.incubate.distributed.fleet.fleet_util import FleetUtil + +fleet_util = FleetUtil() +online_pass_interval = fleet_util.get_online_pass_interval( + days="{20190720..20190729}", + hours="9;touch /home/test/aaaa", + split_interval=5, + split_per_pass=2, + is_data_hourly_placed=False +) +``` + +### 补丁 + +我们在commit [1aae481dfd7d2055c801563e254f1484b974b68e](https://github.com/PaddlePaddle/Paddle/pull/60023/commits/1aae481dfd7d2055c801563e254f1484b974b68e)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 huntr.com 提交。 diff --git a/security/advisory/pdsa-2023-020.md b/security/advisory/pdsa-2023-020.md new file mode 100644 index 00000000000000..ed3a5966d6ca60 --- /dev/null +++ b/security/advisory/pdsa-2023-020.md @@ -0,0 +1,28 @@ +## PDSA-2023-020: Command injection in _wget_download + +### CVE Number + +CVE-2023-52311 + +### Impact + +Command injection in `_wget_download` which could lead to execute arbitrary commands. The PoC is as follows: + +```python +from paddle import utils + +utils.download._wget_download("aa; touch codexecution", "bb") +``` + +### Patches + +We have patched the issue in commit [d5550d3f2f5bab48c783b4986ba1cd8e061ce542](https://github.com/PaddlePaddle/Paddle/pull/59957/commits/d5550d3f2f5bab48c783b4986ba1cd8e061ce542). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by huntr.com. diff --git a/security/advisory/pdsa-2023-020_cn.md b/security/advisory/pdsa-2023-020_cn.md new file mode 100644 index 00000000000000..a6bd1321592e62 --- /dev/null +++ b/security/advisory/pdsa-2023-020_cn.md @@ -0,0 +1,28 @@ +## PDSA-2023-020: Command injection in _wget_download + +### CVE编号 + +CVE-2023-52311 + +### 影响 + +`_wget_download`存在命令注入漏洞,可造成任意命令执行,PoC代码如下: + +```python +from paddle import utils + +utils.download._wget_download("aa; touch codexecution", "bb") +``` + +### 补丁 + +我们在commit [d5550d3f2f5bab48c783b4986ba1cd8e061ce542](https://github.com/PaddlePaddle/Paddle/pull/59957/commits/d5550d3f2f5bab48c783b4986ba1cd8e061ce542)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 huntr.com 提交。 diff --git a/security/advisory/pdsa-2023-021.md b/security/advisory/pdsa-2023-021.md new file mode 100644 index 00000000000000..6a8ec45b33e23c --- /dev/null +++ b/security/advisory/pdsa-2023-021.md @@ -0,0 +1,33 @@ +## PDSA-2023-021: Null pointer dereference in paddle.crop + +### CVE Number + +CVE-2023-52312 + +### Impact + +Null pointer dereference in `paddle.crop` when tensor dims are invalid . The PoC is as follows: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(0, 10, [2, 2]).astype(np.int32)) +shape = paddle.to_tensor([-1, 0], dtype='int32') +offsets = paddle.to_tensor([], dtype='int32') + +out = paddle.crop(x, shape, offsets) +``` + +### Patches + +We have patched the issue in commit [c074de6911944d5d30d28cc7ce2c7099f1c87bce](https://github.com/PaddlePaddle/Paddle/pull/59967/commits/c074de6911944d5d30d28cc7ce2c7099f1c87bce). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Peng Zhou (zpbrent) from Shanghai University. diff --git a/security/advisory/pdsa-2023-021_cn.md b/security/advisory/pdsa-2023-021_cn.md new file mode 100644 index 00000000000000..eff0b0c2225aac --- /dev/null +++ b/security/advisory/pdsa-2023-021_cn.md @@ -0,0 +1,33 @@ +## PDSA-2023-021: Null pointer dereference in paddle.crop + +### CVE编号 + +CVE-2023-52312 + +### 影响 + +输入张量的维度异常时,`paddle.crop`会引发空指针解引用,PoC代码如下: + +```python +import paddle +import numpy as np + +x = paddle.to_tensor(np.random.uniform(0, 10, [2, 2]).astype(np.int32)) +shape = paddle.to_tensor([-1, 0], dtype='int32') +offsets = paddle.to_tensor([], dtype='int32') + +out = paddle.crop(x, shape, offsets) +``` + +### 补丁 + +我们在commit [c074de6911944d5d30d28cc7ce2c7099f1c87bce](https://github.com/PaddlePaddle/Paddle/pull/59967/commits/c074de6911944d5d30d28cc7ce2c7099f1c87bce)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Peng Zhou (zpbrent) from Shanghai University 提交。 diff --git a/security/advisory/pdsa-2023-022.md b/security/advisory/pdsa-2023-022.md new file mode 100644 index 00000000000000..b5b3b3519c9c0e --- /dev/null +++ b/security/advisory/pdsa-2023-022.md @@ -0,0 +1,30 @@ +## PDSA-2023-022: FPE in paddle.argmin and paddle.argmax + +### CVE Number + +CVE-2023-52313 + +### Impact + +FPE in `paddle.argmin` and `paddle.argmax` when input `x.numel()` is 0. The PoC is as follows: + +```python +import paddle + +data = paddle.to_tensor([], dtype="int32") + +paddle.argmax(data, axis=0) +``` + +### Patches + +We have patched the issue in commit [41eda9080b12e6f1b3a49cdc8439a1b9f1ed6794](https://github.com/PaddlePaddle/Paddle/pull/59976/commits/41eda9080b12e6f1b3a49cdc8439a1b9f1ed6794). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by Peng Zhou (zpbrent) from Shanghai University. diff --git a/security/advisory/pdsa-2023-022_cn.md b/security/advisory/pdsa-2023-022_cn.md new file mode 100644 index 00000000000000..d7c57f94394955 --- /dev/null +++ b/security/advisory/pdsa-2023-022_cn.md @@ -0,0 +1,30 @@ +## PDSA-2023-022: FPE in paddle.argmin and paddle.argmax + +### CVE编号 + +CVE-2023-52313 + +### 影响 + +输入`x.numel()`为0时`paddle.argmin`和`paddle.argmax`会引发除0异常,PoC代码如下: + +```python +import paddle + +data = paddle.to_tensor([], dtype="int32") + +paddle.argmax(data, axis=0) +``` + +### 补丁 + +我们在commit [41eda9080b12e6f1b3a49cdc8439a1b9f1ed6794](https://github.com/PaddlePaddle/Paddle/pull/59976/commits/41eda9080b12e6f1b3a49cdc8439a1b9f1ed6794)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 Peng Zhou (zpbrent) from Shanghai University 提交。 diff --git a/security/advisory/pdsa-2023-023.md b/security/advisory/pdsa-2023-023.md new file mode 100644 index 00000000000000..c2671f7f87adca --- /dev/null +++ b/security/advisory/pdsa-2023-023.md @@ -0,0 +1,28 @@ +## PDSA-2023-023: Command injection in convert_shape_compare + +### CVE Number + +CVE-2023-52314 + +### Impact + +Command injection in `convert_shape_compare` which could lead to execute arbitrary commands. The PoC is as follows: + +```python +import paddle + +paddle.jit.dy2static.convert_operators.convert_shape_compare('prefix','+ str(__import__("os").system("cat /etc/passwd")) +','1') +``` + +### Patches + +We have patched the issue in commit [c3b6414eb313480f1417abe92d410dfe89723097](https://github.com/PaddlePaddle/Paddle/pull/60097/commits/c3b6414eb313480f1417abe92d410dfe89723097). +The fix will be included in PaddlePaddle 2.6.0. + +### For more information + +Please consult [our security guide](../../SECURITY.md) for more information regarding the security model and how to contact us with issues and questions. + +### Attribution + +This vulnerability has been reported by leeya_bug. diff --git a/security/advisory/pdsa-2023-023_cn.md b/security/advisory/pdsa-2023-023_cn.md new file mode 100644 index 00000000000000..3de87a4d707674 --- /dev/null +++ b/security/advisory/pdsa-2023-023_cn.md @@ -0,0 +1,28 @@ +## PDSA-2023-023: Command injection in convert_shape_compare + +### CVE编号 + +CVE-2023-52314 + +### 影响 + +`convert_shape_compare`存在命令注入漏洞,可造成任意命令执行,PoC代码如下: + +```python +import paddle + +paddle.jit.dy2static.convert_operators.convert_shape_compare('prefix','+ str(__import__("os").system("cat /etc/passwd")) +','1') +``` + +### 补丁 + +我们在commit [c3b6414eb313480f1417abe92d410dfe89723097](https://github.com/PaddlePaddle/Paddle/pull/60097/commits/c3b6414eb313480f1417abe92d410dfe89723097)中对此问题进行了补丁。 +修复将包含在飞桨2.6.0版本当中。 + +### 更多信息 + +请参考我们的[安全指南](../../SECURITY_cn.md)以获得更多关于安全的信息,以及如何与我们联系问题。 + +### 贡献者 + +此漏洞由 leeya_bug 提交。 From d8900198c3268683a47a9517020fbf79c9fa8ffa Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Wed, 3 Jan 2024 19:42:30 +0800 Subject: [PATCH 089/142] [Inference] refine common/*.h for inference lib (#60513) --- cmake/inference_lib.cmake | 4 -- paddle/fluid/operators/bernoulli_op.h | 2 +- paddle/fluid/operators/bilateral_slice_op.h | 2 +- paddle/fluid/operators/dequantize_log_op.cu | 2 +- .../fluid/operators/detection/box_clip_op.cu | 2 +- paddle/fluid/operators/fake_quantize_op.h | 2 +- paddle/fluid/operators/index_impl.cu.h | 2 +- paddle/fluid/operators/interpolate_op.h | 2 +- paddle/fluid/operators/math/cos_sim_functor.h | 2 +- .../fluid/operators/modified_huber_loss_op.cu | 2 +- .../fluid/operators/modified_huber_loss_op.h | 2 +- paddle/fluid/operators/quantize_linear_op.h | 2 +- paddle/phi/api/all.h | 7 +++- paddle/phi/common/bfloat16.h | 2 +- paddle/phi/common/complex.h | 2 +- paddle/phi/common/float16.h | 2 +- paddle/phi/common/transform.h | 2 +- paddle/phi/common/type_safe_sign_math.h | 2 +- paddle/phi/core/hostdevice.h | 37 ------------------- .../phi/kernels/cpu/graph_send_recv_funcs.h | 2 +- .../kernels/cpu/graph_send_ue_recv_funcs.h | 2 +- paddle/phi/kernels/cpu/send_u_recv_kernel.cc | 2 +- .../kernels/cpu/send_ue_recv_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/send_ue_recv_kernel.cc | 2 +- paddle/phi/kernels/cpu/send_uv_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/send_uv_kernel.cc | 2 +- paddle/phi/kernels/funcs/algorithm.h | 2 +- paddle/phi/kernels/funcs/aligned_vector.h | 2 +- paddle/phi/kernels/funcs/complex_functors.h | 2 +- paddle/phi/kernels/funcs/cross_entropy.h | 2 +- .../funcs/detail/activation_functions.h | 2 +- paddle/phi/kernels/funcs/detail/gru_kernel.h | 2 +- paddle/phi/kernels/funcs/detail/lstm_kernel.h | 2 +- paddle/phi/kernels/funcs/diag_functor.h | 2 +- .../phi/kernels/funcs/distribution_helper.h | 2 +- paddle/phi/kernels/funcs/eigen/extensions.h | 2 +- .../phi/kernels/funcs/elementwise_functor.h | 2 +- paddle/phi/kernels/funcs/fft_fill_conj.h | 2 +- paddle/phi/kernels/funcs/index_impl.cu.h | 2 +- paddle/phi/kernels/funcs/math.h | 2 +- paddle/phi/kernels/funcs/pooling.h | 2 +- paddle/phi/kernels/funcs/quant_dequant.h | 2 +- .../phi/kernels/gpu/bce_loss_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/bce_loss_kernel.cu | 2 +- paddle/phi/kernels/gpu/cum_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/cum_kernel.cu | 2 +- paddle/phi/kernels/gpu/cum_maxmin_kernel.cu | 2 +- paddle/phi/kernels/gpu/depthwise_conv.h | 2 +- paddle/phi/kernels/gpu/graph_reindex_funcs.h | 2 +- .../gpu/graph_sample_neighbors_kernel.cu | 2 +- .../phi/kernels/gpu/graph_send_recv_funcs.h | 2 +- .../kernels/gpu/graph_send_ue_recv_funcs.h | 2 +- paddle/phi/kernels/gpu/nll_loss.h | 2 +- .../kernels/gpu/send_u_recv_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/send_u_recv_kernel.cu | 2 +- .../kernels/gpu/send_ue_recv_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/send_ue_recv_kernel.cu | 2 +- paddle/phi/kernels/gpu/send_uv_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/send_uv_kernel.cu | 2 +- .../gpu/sigmoid_cross_entropy_with_logits.h | 2 +- .../gpu/weighted_sample_neighbors_kernel.cu | 2 +- paddle/phi/kernels/impl/amp_kernel_impl.h | 2 +- .../impl/deformable_conv_grad_kernel_impl.h | 2 +- .../impl/deformable_conv_kernel_impl.h | 2 +- .../impl/kldiv_loss_grad_kernel_impl.h | 2 +- .../phi/kernels/impl/kldiv_loss_kernel_impl.h | 2 +- .../phi/kernels/impl/merged_momentum_impl.h | 2 +- .../phi/kernels/impl/quantize_linear_impl.h | 2 +- paddle/phi/kernels/nms_kernel.h | 2 +- paddle/phi/kernels/strings/unicode.h | 2 +- test/cpp/phi/common/transform_test.cu | 2 +- 71 files changed, 73 insertions(+), 111 deletions(-) delete mode 100644 paddle/phi/core/hostdevice.h diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 517ac24cccc72e..f44e23e6da74e8 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -328,10 +328,6 @@ copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) -copy( - inference_lib_dist - SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/hostdevice.h - DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h diff --git a/paddle/fluid/operators/bernoulli_op.h b/paddle/fluid/operators/bernoulli_op.h index e0bd2145db6352..ffa2722ccbb602 100644 --- a/paddle/fluid/operators/bernoulli_op.h +++ b/paddle/fluid/operators/bernoulli_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/core/hostdevice.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/bilateral_slice_op.h b/paddle/fluid/operators/bilateral_slice_op.h index 66783f151ea06a..c88b9c9054e30a 100644 --- a/paddle/fluid/operators/bilateral_slice_op.h +++ b/paddle/fluid/operators/bilateral_slice_op.h @@ -13,8 +13,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/hostdevice.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/dequantize_log_op.cu b/paddle/fluid/operators/dequantize_log_op.cu index 89af422859fe4c..933e074b8bbe77 100644 --- a/paddle/fluid/operators/dequantize_log_op.cu +++ b/paddle/fluid/operators/dequantize_log_op.cu @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/dequantize_log_op.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/math.h" namespace paddle { diff --git a/paddle/fluid/operators/detection/box_clip_op.cu b/paddle/fluid/operators/detection/box_clip_op.cu index ef0f2439c5ee1c..7b802337ef7b70 100644 --- a/paddle/fluid/operators/detection/box_clip_op.cu +++ b/paddle/fluid/operators/detection/box_clip_op.cu @@ -13,10 +13,10 @@ limitations under the License. */ #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detection/box_clip_op.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 13f1e5a3a26124..dd8675331fce6b 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -16,12 +16,12 @@ limitations under the License. */ #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/phi/common/transform.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/blas/blas.h" namespace paddle { diff --git a/paddle/fluid/operators/index_impl.cu.h b/paddle/fluid/operators/index_impl.cu.h index 7d2cdae87950bc..629717f61933a8 100644 --- a/paddle/fluid/operators/index_impl.cu.h +++ b/paddle/fluid/operators/index_impl.cu.h @@ -18,12 +18,12 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/generator.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/distribution_helper.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h index 31767d68b9d3c9..563879e301d12b 100644 --- a/paddle/fluid/operators/interpolate_op.h +++ b/paddle/fluid/operators/interpolate_op.h @@ -14,8 +14,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/math/cos_sim_functor.h b/paddle/fluid/operators/math/cos_sim_functor.h index 0175bec6e7edf5..1c0f994fd87332 100644 --- a/paddle/fluid/operators/math/cos_sim_functor.h +++ b/paddle/fluid/operators/math/cos_sim_functor.h @@ -16,8 +16,8 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/phi/core/hostdevice.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/modified_huber_loss_op.cu b/paddle/fluid/operators/modified_huber_loss_op.cu index d063e8d1cb4d57..aec8c49cdd2506 100644 --- a/paddle/fluid/operators/modified_huber_loss_op.cu +++ b/paddle/fluid/operators/modified_huber_loss_op.cu @@ -16,9 +16,9 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/modified_huber_loss_op.h" -#include "paddle/phi/core/hostdevice.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/modified_huber_loss_op.h b/paddle/fluid/operators/modified_huber_loss_op.h index 4330abde2a828a..88cb91d454e721 100644 --- a/paddle/fluid/operators/modified_huber_loss_op.h +++ b/paddle/fluid/operators/modified_huber_loss_op.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/hostdevice.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index 1f1bfc3dea73bd..48f4b472baab5d 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -15,13 +15,13 @@ limitations under the License. */ #include #include "paddle/common/ddim.h" +#include "paddle/common/hostdevice.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/transform.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/cast_kernel.h" namespace paddle { diff --git a/paddle/phi/api/all.h b/paddle/phi/api/all.h index ec521021859706..93c97605f9f3ff 100644 --- a/paddle/phi/api/all.h +++ b/paddle/phi/api/all.h @@ -29,14 +29,17 @@ limitations under the License. */ #include "paddle/phi/api/include/tensor_utils.h" // phi common headers -#include "paddle/common/layout.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" // original custom op headers -#include "paddle/common/exception.h" #include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/tensor_compat.h" + +// common headers +#include "paddle/common/ddim.h" +#include "paddle/common/exception.h" +#include "paddle/common/layout.h" diff --git a/paddle/phi/common/bfloat16.h b/paddle/phi/common/bfloat16.h index 028851e34c8bc7..d3a7f3daf37624 100644 --- a/paddle/phi/common/bfloat16.h +++ b/paddle/phi/common/bfloat16.h @@ -20,7 +20,7 @@ #include #include #include -#include "paddle/phi/core/hostdevice.h" +#include "paddle/common/hostdevice.h" #ifdef PADDLE_WITH_CUDA #include diff --git a/paddle/phi/common/complex.h b/paddle/phi/common/complex.h index 5de6290fb77057..34605855137e0e 100644 --- a/paddle/phi/common/complex.h +++ b/paddle/phi/common/complex.h @@ -20,7 +20,7 @@ #include #include #include -#include "paddle/phi/core/hostdevice.h" +#include "paddle/common/hostdevice.h" #ifdef PADDLE_WITH_CUDA #include #include diff --git a/paddle/phi/common/float16.h b/paddle/phi/common/float16.h index 9d60b8c6241ae3..04411aa4cec49d 100644 --- a/paddle/phi/common/float16.h +++ b/paddle/phi/common/float16.h @@ -32,7 +32,7 @@ #include #include -#include "paddle/phi/core/hostdevice.h" +#include "paddle/common/hostdevice.h" #ifdef PADDLE_WITH_CUDA #include #endif // PADDLE_WITH_CUDA diff --git a/paddle/phi/common/transform.h b/paddle/phi/common/transform.h index e80561284b885f..d83b698a45bc6f 100644 --- a/paddle/phi/common/transform.h +++ b/paddle/phi/common/transform.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/hostdevice.h" #if defined(__NVCC__) || defined(__HIPCC__) #include diff --git a/paddle/phi/common/type_safe_sign_math.h b/paddle/phi/common/type_safe_sign_math.h index e5d3cf48e022de..3031b3d8f6a584 100644 --- a/paddle/phi/common/type_safe_sign_math.h +++ b/paddle/phi/common/type_safe_sign_math.h @@ -17,9 +17,9 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" -#include "paddle/phi/core/hostdevice.h" namespace phi { diff --git a/paddle/phi/core/hostdevice.h b/paddle/phi/core/hostdevice.h deleted file mode 100644 index decebbe66a5381..00000000000000 --- a/paddle/phi/core/hostdevice.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#ifdef __HIPCC__ -#include -#endif - -#if defined(__xpu__) -#include - -#include "xpu/kernel/cluster_header.h" -#include "xpu/kernel/debug.h" -#include "xpu/kernel/math.h" -#endif - -#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu__)) -#define HOSTDEVICE __host__ __device__ -#define DEVICE __device__ -#define HOST __host__ -#else -#define HOSTDEVICE -#define DEVICE -#define HOST -#endif diff --git a/paddle/phi/kernels/cpu/graph_send_recv_funcs.h b/paddle/phi/kernels/cpu/graph_send_recv_funcs.h index c67480cc9e33e2..dc800876b99a17 100644 --- a/paddle/phi/kernels/cpu/graph_send_recv_funcs.h +++ b/paddle/phi/kernels/cpu/graph_send_recv_funcs.h @@ -16,9 +16,9 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h index 7647415d8e7cbf..4af25b37f943b8 100644 --- a/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h @@ -16,9 +16,9 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/send_u_recv_kernel.cc b/paddle/phi/kernels/cpu/send_u_recv_kernel.cc index 9e186aeedfab36..8d09034f35cb72 100644 --- a/paddle/phi/kernels/cpu/send_u_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/send_u_recv_kernel.cc @@ -18,8 +18,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h" diff --git a/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc index 404a1637380706..0d0210ac661c04 100644 --- a/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc @@ -17,8 +17,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h" #include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h" diff --git a/paddle/phi/kernels/cpu/send_ue_recv_kernel.cc b/paddle/phi/kernels/cpu/send_ue_recv_kernel.cc index a53efc2bc17b05..73a671fde4e48d 100644 --- a/paddle/phi/kernels/cpu/send_ue_recv_kernel.cc +++ b/paddle/phi/kernels/cpu/send_ue_recv_kernel.cc @@ -18,8 +18,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h" #include "paddle/phi/kernels/impl/graph_message_passing_impl.h" diff --git a/paddle/phi/kernels/cpu/send_uv_grad_kernel.cc b/paddle/phi/kernels/cpu/send_uv_grad_kernel.cc index 7c16308f7360a1..152cd948562311 100644 --- a/paddle/phi/kernels/cpu/send_uv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/send_uv_grad_kernel.cc @@ -14,8 +14,8 @@ #include "paddle/phi/kernels/send_uv_grad_kernel.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/cpu/send_uv_kernel.cc b/paddle/phi/kernels/cpu/send_uv_kernel.cc index 59f334c2b67f48..c5200182a1d08c 100644 --- a/paddle/phi/kernels/cpu/send_uv_kernel.cc +++ b/paddle/phi/kernels/cpu/send_uv_kernel.cc @@ -14,8 +14,8 @@ #include "paddle/phi/kernels/send_uv_kernel.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cpu/graph_send_ue_recv_funcs.h" #include "paddle/phi/kernels/impl/graph_message_passing_impl.h" diff --git a/paddle/phi/kernels/funcs/algorithm.h b/paddle/phi/kernels/funcs/algorithm.h index 5f66f6f1abd4d2..7ad7dce7928fc2 100644 --- a/paddle/phi/kernels/funcs/algorithm.h +++ b/paddle/phi/kernels/funcs/algorithm.h @@ -18,7 +18,7 @@ #include // for int64_t #include -#include "paddle/phi/core/hostdevice.h" +#include "paddle/common/hostdevice.h" namespace phi { namespace funcs { diff --git a/paddle/phi/kernels/funcs/aligned_vector.h b/paddle/phi/kernels/funcs/aligned_vector.h index 753aa44b0aa3ae..003b365c16d1d6 100644 --- a/paddle/phi/kernels/funcs/aligned_vector.h +++ b/paddle/phi/kernels/funcs/aligned_vector.h @@ -16,8 +16,8 @@ limitations under the License. */ #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #if defined(__xpu__) #define CHAR_BIT 8 diff --git a/paddle/phi/kernels/funcs/complex_functors.h b/paddle/phi/kernels/funcs/complex_functors.h index e6ffeb3b5602e9..c8848c2f6ce240 100644 --- a/paddle/phi/kernels/funcs/complex_functors.h +++ b/paddle/phi/kernels/funcs/complex_functors.h @@ -19,9 +19,9 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/common/type_traits.h" -#include "paddle/phi/core/hostdevice.h" namespace phi { namespace funcs { diff --git a/paddle/phi/kernels/funcs/cross_entropy.h b/paddle/phi/kernels/funcs/cross_entropy.h index 3c4057420c3d47..ebba150cd9162d 100644 --- a/paddle/phi/kernels/funcs/cross_entropy.h +++ b/paddle/phi/kernels/funcs/cross_entropy.h @@ -15,10 +15,10 @@ limitations under the License. */ #pragma once #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/paddle/phi/kernels/funcs/detail/activation_functions.h b/paddle/phi/kernels/funcs/detail/activation_functions.h index 758503563680be..e635cd1cf03372 100644 --- a/paddle/phi/kernels/funcs/detail/activation_functions.h +++ b/paddle/phi/kernels/funcs/detail/activation_functions.h @@ -18,9 +18,9 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/common/macros.h" #include "paddle/phi/backends/cpu/cpu_info.h" -#include "paddle/phi/core/hostdevice.h" namespace phi { namespace funcs { namespace detail { diff --git a/paddle/phi/kernels/funcs/detail/gru_kernel.h b/paddle/phi/kernels/funcs/detail/gru_kernel.h index 9e2aef19406191..8ab0b44f7e6276 100644 --- a/paddle/phi/kernels/funcs/detail/gru_kernel.h +++ b/paddle/phi/kernels/funcs/detail/gru_kernel.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/phi/core/hostdevice.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h" // TODO(guosheng): refine code style in gru_kernel diff --git a/paddle/phi/kernels/funcs/detail/lstm_kernel.h b/paddle/phi/kernels/funcs/detail/lstm_kernel.h index 0846f05a0c2c53..8af7467e36e295 100644 --- a/paddle/phi/kernels/funcs/detail/lstm_kernel.h +++ b/paddle/phi/kernels/funcs/detail/lstm_kernel.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/phi/core/hostdevice.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/kernels/funcs/detail/activation_functions.h" namespace phi { diff --git a/paddle/phi/kernels/funcs/diag_functor.h b/paddle/phi/kernels/funcs/diag_functor.h index 6fe54363e6f0e2..424b7169a19b18 100644 --- a/paddle/phi/kernels/funcs/diag_functor.h +++ b/paddle/phi/kernels/funcs/diag_functor.h @@ -14,9 +14,9 @@ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/for_range.h" diff --git a/paddle/phi/kernels/funcs/distribution_helper.h b/paddle/phi/kernels/funcs/distribution_helper.h index abade7ac0ef877..07e04e94da60a1 100644 --- a/paddle/phi/kernels/funcs/distribution_helper.h +++ b/paddle/phi/kernels/funcs/distribution_helper.h @@ -21,12 +21,12 @@ limitations under the License. */ #include #endif +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/generator.h" -#include "paddle/phi/core/hostdevice.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/phi/kernels/funcs/index_impl.cu.h" diff --git a/paddle/phi/kernels/funcs/eigen/extensions.h b/paddle/phi/kernels/funcs/eigen/extensions.h index c724564417b19a..285b58d29ed851 100644 --- a/paddle/phi/kernels/funcs/eigen/extensions.h +++ b/paddle/phi/kernels/funcs/eigen/extensions.h @@ -16,10 +16,10 @@ #ifndef __xpu__ +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" -#include "paddle/phi/core/hostdevice.h" #include "unsupported/Eigen/CXX11/Tensor" namespace Eigen { diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index eaf527fbba9f6b..5477b952d08b3b 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -14,12 +14,12 @@ limitations under the License. */ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/common/macros.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/hostdevice.h" #if defined(__xpu__) #include diff --git a/paddle/phi/kernels/funcs/fft_fill_conj.h b/paddle/phi/kernels/funcs/fft_fill_conj.h index ab6d351986ecc2..c47257818f3a34 100644 --- a/paddle/phi/kernels/funcs/fft_fill_conj.h +++ b/paddle/phi/kernels/funcs/fft_fill_conj.h @@ -15,8 +15,8 @@ #pragma once #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/for_range.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "thrust/device_vector.h" diff --git a/paddle/phi/kernels/funcs/index_impl.cu.h b/paddle/phi/kernels/funcs/index_impl.cu.h index cfe95f87f6335d..84a665ac249d11 100644 --- a/paddle/phi/kernels/funcs/index_impl.cu.h +++ b/paddle/phi/kernels/funcs/index_impl.cu.h @@ -18,9 +18,9 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" diff --git a/paddle/phi/kernels/funcs/math.h b/paddle/phi/kernels/funcs/math.h index 004279c25d4afa..f52175ccff7a3d 100644 --- a/paddle/phi/kernels/funcs/math.h +++ b/paddle/phi/kernels/funcs/math.h @@ -15,8 +15,8 @@ #pragma once #include "math.h" // NOLINT +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/float16.h" -#include "paddle/phi/core/hostdevice.h" namespace phi { namespace funcs { diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h index 1ffd747735543c..7fa3e68be5a111 100644 --- a/paddle/phi/kernels/funcs/pooling.h +++ b/paddle/phi/kernels/funcs/pooling.h @@ -18,10 +18,10 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/common/macros.h" // import FLT_MAX #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/phi/backends/gpu/gpu_decls.h" diff --git a/paddle/phi/kernels/funcs/quant_dequant.h b/paddle/phi/kernels/funcs/quant_dequant.h index c0ba1df5c6344a..dc0007b4e606d4 100644 --- a/paddle/phi/kernels/funcs/quant_dequant.h +++ b/paddle/phi/kernels/funcs/quant_dequant.h @@ -15,10 +15,10 @@ limitations under the License. */ #pragma once #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/transform.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/blas/blas.h" diff --git a/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu index b50fc1301748c8..942f1be4f1625d 100644 --- a/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu @@ -17,10 +17,10 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" diff --git a/paddle/phi/kernels/gpu/bce_loss_kernel.cu b/paddle/phi/kernels/gpu/bce_loss_kernel.cu index 49191b3e354a7a..c1e73afac71f98 100644 --- a/paddle/phi/kernels/gpu/bce_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/bce_loss_kernel.cu @@ -17,10 +17,10 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/float16.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" diff --git a/paddle/phi/kernels/gpu/cum_grad_kernel.cu b/paddle/phi/kernels/gpu/cum_grad_kernel.cu index 0e99305ac3127d..846a093db4f843 100644 --- a/paddle/phi/kernels/gpu/cum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_grad_kernel.cu @@ -27,11 +27,11 @@ namespace cub = hipcub; #endif +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu index 831b225c61d844..c85dd99a3d401f 100644 --- a/paddle/phi/kernels/gpu/cum_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_kernel.cu @@ -26,12 +26,12 @@ namespace cub = hipcub; #endif +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu index 24ba48429e10ce..08b8b89afe4b35 100644 --- a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu @@ -16,8 +16,8 @@ #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h index 278b219b453d3d..03c14a6234bda5 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h +++ b/paddle/phi/kernels/gpu/depthwise_conv.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #ifdef __NVCC__ #include diff --git a/paddle/phi/kernels/gpu/graph_reindex_funcs.h b/paddle/phi/kernels/gpu/graph_reindex_funcs.h index 2a5479e076e1d6..eb26b434091ccb 100644 --- a/paddle/phi/kernels/gpu/graph_reindex_funcs.h +++ b/paddle/phi/kernels/gpu/graph_reindex_funcs.h @@ -14,9 +14,9 @@ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/graph_reindex_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu index 20e1c6727ae91e..595b803a68dfdb 100644 --- a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -27,8 +27,8 @@ #include #endif +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/graph_sample_neighbors_kernel.h" diff --git a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h index 9aacba8a7a3aa5..0a6ccbe45b39f6 100644 --- a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h +++ b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h @@ -19,9 +19,9 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/send_u_recv_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h index bff91078865d92..a5d6f3d8c6b06b 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h @@ -17,9 +17,9 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/impl/graph_message_passing_impl.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/nll_loss.h b/paddle/phi/kernels/gpu/nll_loss.h index 9d063d0ef44a0b..648b69b45253c1 100644 --- a/paddle/phi/kernels/gpu/nll_loss.h +++ b/paddle/phi/kernels/gpu/nll_loss.h @@ -19,9 +19,9 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/math.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu index b9c4a8daf2326c..5f310da580d6a8 100644 --- a/paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/send_u_recv_grad_kernel.cu @@ -17,8 +17,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" diff --git a/paddle/phi/kernels/gpu/send_u_recv_kernel.cu b/paddle/phi/kernels/gpu/send_u_recv_kernel.cu index d4a08a72d80a98..69ba6a18300028 100644 --- a/paddle/phi/kernels/gpu/send_u_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/send_u_recv_kernel.cu @@ -21,8 +21,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" diff --git a/paddle/phi/kernels/gpu/send_ue_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/send_ue_recv_grad_kernel.cu index 5703b5faea07c5..d021e877ddcf47 100644 --- a/paddle/phi/kernels/gpu/send_ue_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/send_ue_recv_grad_kernel.cu @@ -13,8 +13,8 @@ // limitations under the License. #include "paddle/phi/kernels/send_ue_recv_grad_kernel.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" diff --git a/paddle/phi/kernels/gpu/send_ue_recv_kernel.cu b/paddle/phi/kernels/gpu/send_ue_recv_kernel.cu index c87f133d07b8d8..1ed59c43e83d47 100644 --- a/paddle/phi/kernels/gpu/send_ue_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/send_ue_recv_kernel.cu @@ -20,8 +20,8 @@ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/gpu/send_uv_grad_kernel.cu b/paddle/phi/kernels/gpu/send_uv_grad_kernel.cu index bc61ae766d6c24..d5126979b84f6c 100644 --- a/paddle/phi/kernels/gpu/send_uv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/send_uv_grad_kernel.cu @@ -14,8 +14,8 @@ #include "paddle/phi/kernels/send_uv_grad_kernel.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" diff --git a/paddle/phi/kernels/gpu/send_uv_kernel.cu b/paddle/phi/kernels/gpu/send_uv_kernel.cu index 860a900dac8341..cbc9e42ff84ea4 100644 --- a/paddle/phi/kernels/gpu/send_uv_kernel.cu +++ b/paddle/phi/kernels/gpu/send_uv_kernel.cu @@ -16,8 +16,8 @@ #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h" diff --git a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h index dc6d8312e06c7e..2baa96d2a51600 100644 --- a/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h +++ b/paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits.h @@ -16,9 +16,9 @@ #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_helper.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" diff --git a/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu index d4e0ca632e04de..3b43459ad37281 100644 --- a/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/weighted_sample_neighbors_kernel.cu @@ -26,9 +26,9 @@ #endif #include "math.h" // NOLINT +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/block_radix_topk.cuh" #include "paddle/phi/kernels/funcs/random.cuh" diff --git a/paddle/phi/kernels/impl/amp_kernel_impl.h b/paddle/phi/kernels/impl/amp_kernel_impl.h index ec857f3f640d56..cfe5b338145944 100644 --- a/paddle/phi/kernels/impl/amp_kernel_impl.h +++ b/paddle/phi/kernels/impl/amp_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/phi/common/amp_type_traits.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/amp_kernel.h" #include "paddle/phi/kernels/full_kernel.h" diff --git a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h index fdd31e510510a6..b9931a89978307 100644 --- a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h index d4647128963e5d..9d3d66ab2dbe65 100644 --- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h +++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h @@ -14,8 +14,8 @@ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/deformable_conv_functor.h" diff --git a/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h b/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h index d272e73303436e..a5e6c3d8fbfae1 100644 --- a/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h @@ -15,8 +15,8 @@ #pragma once #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h b/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h index 851a78b07413ef..4232e32597ed1c 100644 --- a/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h +++ b/paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h @@ -15,9 +15,9 @@ #pragma once #include +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h index 85f253fd32d492..32a6024ebee913 100644 --- a/paddle/phi/kernels/impl/merged_momentum_impl.h +++ b/paddle/phi/kernels/impl/merged_momentum_impl.h @@ -16,10 +16,10 @@ #include "glog/logging.h" +#include "paddle/common/hostdevice.h" #include "paddle/common/macros.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/impl/momentum_kernel_impl.h" #include "paddle/phi/kernels/merged_momentum_kernel.h" diff --git a/paddle/phi/kernels/impl/quantize_linear_impl.h b/paddle/phi/kernels/impl/quantize_linear_impl.h index 9f86fd07447ee5..a454023d859d8a 100644 --- a/paddle/phi/kernels/impl/quantize_linear_impl.h +++ b/paddle/phi/kernels/impl/quantize_linear_impl.h @@ -18,9 +18,9 @@ #include "paddle/phi/kernels/quantize_linear_kernel.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/place.h" -#include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cast_kernel.h" diff --git a/paddle/phi/kernels/nms_kernel.h b/paddle/phi/kernels/nms_kernel.h index e8511f4c4a49f3..a8ceb0c1f3c84b 100644 --- a/paddle/phi/kernels/nms_kernel.h +++ b/paddle/phi/kernels/nms_kernel.h @@ -14,8 +14,8 @@ #pragma once +#include "paddle/common/hostdevice.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/hostdevice.h" namespace phi { diff --git a/paddle/phi/kernels/strings/unicode.h b/paddle/phi/kernels/strings/unicode.h index 410543c27d68fc..6dfb6aeb6ede6a 100644 --- a/paddle/phi/kernels/strings/unicode.h +++ b/paddle/phi/kernels/strings/unicode.h @@ -17,8 +17,8 @@ limitations under the License. */ #include #include +#include "paddle/common/hostdevice.h" #include "paddle/common/macros.h" -#include "paddle/phi/core/hostdevice.h" namespace phi { namespace strings { diff --git a/test/cpp/phi/common/transform_test.cu b/test/cpp/phi/common/transform_test.cu index 22c72b6626f854..8326160d881b61 100644 --- a/test/cpp/phi/common/transform_test.cu +++ b/test/cpp/phi/common/transform_test.cu @@ -16,9 +16,9 @@ limitations under the License. */ #include "paddle/phi/common/transform.h" +#include "paddle/common/hostdevice.h" #include "paddle/phi/backends/context_pool.h" #include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/core/hostdevice.h" template class Scale { From 5d0138206231dc715dfc098d59edf816c7b29185 Mon Sep 17 00:00:00 2001 From: zbt78 <1095497213@qq.com> Date: Wed, 3 Jan 2024 19:56:52 +0800 Subject: [PATCH 090/142] =?UTF-8?q?=E3=80=90complex=20op=E3=80=91No.19=20a?= =?UTF-8?q?dd=20complex=20support=20for=20triangular=5Fsolve=20(#59529)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cpu/triangular_solve_grad_kernel.cc | 4 +- .../kernels/cpu/triangular_solve_kernel.cc | 4 +- paddle/phi/kernels/funcs/blas/blas_impl.h | 2 +- paddle/phi/kernels/funcs/matrix_reduce.cc | 2 + paddle/phi/kernels/funcs/matrix_reduce.cu | 2 + .../gpu/triangular_solve_grad_kernel.cu | 4 +- .../kernels/gpu/triangular_solve_kernel.cu | 4 +- python/paddle/tensor/linalg.py | 14 +- test/legacy_test/test_triangular_solve_op.py | 500 +++++++++++++++++- 9 files changed, 523 insertions(+), 13 deletions(-) diff --git a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc index 80b2015f7318ad..95e96b6d7918cb 100644 --- a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad, ALL_LAYOUT, phi::TriangularSolveGradKernel, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc index 6245eb90426405..68af8bc2b1e924 100644 --- a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc +++ b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc @@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve, ALL_LAYOUT, phi::TriangularSolveKernel, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.h b/paddle/phi/kernels/funcs/blas/blas_impl.h index ffafe15b8fcf2d..b4ee437011f665 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.h @@ -877,7 +877,7 @@ struct CBlas> { const phi::dtype::complex alpha, const phi::dtype::complex *A, const int lda, - phi::dtype::complex *B, + phi::dtype::complex *B, const int ldb) { cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); } diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cc b/paddle/phi/kernels/funcs/matrix_reduce.cc index e20d98984eb5aa..03bdc820abe07d 100644 --- a/paddle/phi/kernels/funcs/matrix_reduce.cc +++ b/paddle/phi/kernels/funcs/matrix_reduce.cc @@ -55,6 +55,8 @@ class MatrixReduceSumFunctor { template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; +template class MatrixReduceSumFunctor, CPUContext>; +template class MatrixReduceSumFunctor, CPUContext>; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu index f4305914c41713..39bb62a6bf3037 100644 --- a/paddle/phi/kernels/funcs/matrix_reduce.cu +++ b/paddle/phi/kernels/funcs/matrix_reduce.cu @@ -52,6 +52,8 @@ class MatrixReduceSumFunctor { template class MatrixReduceSumFunctor; template class MatrixReduceSumFunctor; +template class MatrixReduceSumFunctor, GPUContext>; +template class MatrixReduceSumFunctor, GPUContext>; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu index f7eaa485797947..67861b282529b0 100644 --- a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad, ALL_LAYOUT, phi::TriangularSolveGradKernel, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu index 2a943fd0ac6815..342b8e3885d7b6 100644 --- a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu +++ b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu @@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve, ALL_LAYOUT, phi::TriangularSolveKernel, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index fe80aaa10d6350..92a19d6713d22e 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3192,9 +3192,9 @@ def triangular_solve( Args: x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or - more batch dimensions. Its data type should be float32 or float64. + more batch dimensions. Its data type should be float32, float64, complex64, complex128. y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is - zero or more batch dimensions. Its data type should be float32 or float64. + zero or more batch dimensions. Its data type should be float32, float64, complex64, complex128. upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular system of equations. Default: True. transpose (bool, optional): whether `x` should be transposed before calculation. Default: False. @@ -3233,10 +3233,16 @@ def triangular_solve( inputs = {"X": [x], "Y": [y]} helper = LayerHelper("triangular_solve", **locals()) check_variable_and_dtype( - x, 'x', ['float32', 'float64'], 'triangular_solve' + x, + 'x', + ['float32', 'float64', 'complex64', 'complex128'], + 'triangular_solve', ) check_variable_and_dtype( - y, 'y', ['float32', 'float64'], 'triangular_solve' + y, + 'y', + ['float32', 'float64', 'complex64', 'complex128'], + 'triangular_solve', ) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/test/legacy_test/test_triangular_solve_op.py b/test/legacy_test/test_triangular_solve_op.py index f3624b53328175..d4aecda8780ce5 100644 --- a/test/legacy_test/test_triangular_solve_op.py +++ b/test/legacy_test/test_triangular_solve_op.py @@ -51,10 +51,23 @@ def setUp(self): self.python_api = paddle.tensor.linalg.triangular_solve self.config() - self.inputs = { - 'X': np.random.random(self.x_shape).astype(self.dtype), - 'Y': np.random.random(self.y_shape).astype(self.dtype), - } + if self.dtype is np.complex64 or self.dtype is np.complex128: + self.inputs = { + 'X': ( + np.random.random(self.x_shape) + + 1j * np.random.random(self.x_shape) + ).astype(self.dtype), + 'Y': ( + np.random.random(self.y_shape) + + 1j * np.random.random(self.y_shape) + ).astype(self.dtype), + } + else: + self.inputs = { + 'X': np.random.random(self.x_shape).astype(self.dtype), + 'Y': np.random.random(self.y_shape).astype(self.dtype), + } + self.attrs = { 'upper': self.upper, 'transpose': self.transpose, @@ -248,6 +261,485 @@ def set_output(self): self.output = np.matmul(np.linalg.inv(x), y) +# 3D(broadcast) + 3D complex64 +class TestTriangularSolveOpCp643b3(TestTriangularSolveOp): + """ + case 10 + """ + + def config(self): + self.x_shape = [1, 10, 10] + self.y_shape = [6, 10, 12] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 2D + 2D upper complex64 +class TestTriangularSolveOpCp6422Up(TestTriangularSolveOp): + """ + case 11 + """ + + def config(self): + self.x_shape = [12, 12] + self.y_shape = [12, 10] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + max_relative_error=0.02, + ) + + +# 2D(broadcast) + 3D, test 'transpose' complex64 +class TestTriangularSolveOpCp6423T(TestTriangularSolveOp): + """ + case 12 + """ + + def config(self): + self.x_shape = [10, 10] + self.y_shape = [3, 10, 8] + self.upper = False + self.transpose = True + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.tril(self.inputs['X']).transpose(1, 0) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 2D + 2D , test 'unitriangular' complex64 +class TestTriangularSolveOpCp6422Un(TestTriangularSolveOp): + """ + case 13 + """ + + def config(self): + self.x_shape = [10, 10] + self.y_shape = [10, 10] + self.upper = True + self.transpose = False + self.unitriangular = True + self.dtype = np.complex64 + + def set_output(self): + x = np.triu(self.inputs['X']) + np.fill_diagonal(x, 1.0 + 0j) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +# 4D(broadcast) + 4D(broadcast) complex64 +class TestTriangularSolveOpCp644b4b(TestTriangularSolveOp): + """ + case 14 + """ + + def config(self): + self.x_shape = [1, 3, 10, 10] + self.y_shape = [2, 3, 10, 5] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + max_relative_error=0.008, + ) + + +# 3D(broadcast) + 4D(broadcast), test 'upper' complex64 +class TestTriangularSolveOpCp643b4bUp(TestTriangularSolveOp): + """ + case 15 + """ + + def config(self): + self.x_shape = [2, 10, 10] + self.y_shape = [5, 1, 10, 2] + self.upper = True + self.transpose = True + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.triu(self.inputs['X']).transpose(0, 2, 1) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 3D(broadcast) + 5D complex64 +class TestTriangularSolveOpCp643b5(TestTriangularSolveOp): + """ + case 16 + """ + + def config(self): + self.x_shape = [12, 3, 3] + self.y_shape = [2, 3, 12, 3, 2] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 5D + 4D(broadcast) complex64 +class TestTriangularSolveOpCp6454b(TestTriangularSolveOp): + """ + case 17 + """ + + def config(self): + self.x_shape = [2, 4, 2, 3, 3] + self.y_shape = [4, 1, 3, 10] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex64 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.matmul(np.linalg.inv(x), y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 3D(broadcast) + 3D complex128 +class TestTriangularSolveOpCp1283b3(TestTriangularSolveOp): + """ + case 18 + """ + + def config(self): + self.x_shape = [1, 10, 10] + self.y_shape = [6, 10, 12] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 2D + 2D upper complex128 +class TestTriangularSolveOpCp12822Up(TestTriangularSolveOp): + """ + case 19 + """ + + def config(self): + self.x_shape = [12, 12] + self.y_shape = [12, 10] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 2D(broadcast) + 3D, test 'transpose' complex128 +class TestTriangularSolveOpCp12823T(TestTriangularSolveOp): + """ + case 20 + """ + + def config(self): + self.x_shape = [10, 10] + self.y_shape = [3, 10, 8] + self.upper = False + self.transpose = True + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.tril(self.inputs['X']).transpose(1, 0) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 2D + 2D , test 'unitriangular' complex128 +class TestTriangularSolveOpCp12822Un(TestTriangularSolveOp): + """ + case 21 + """ + + def config(self): + self.x_shape = [10, 10] + self.y_shape = [10, 10] + self.upper = True + self.transpose = False + self.unitriangular = True + self.dtype = np.complex128 + + def set_output(self): + x = np.triu(self.inputs['X']) + np.fill_diagonal(x, 1.0 + 0j) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + ) + + +# 4D(broadcast) + 4D(broadcast) complex128 +class TestTriangularSolveOpCp1284b4b(TestTriangularSolveOp): + """ + case 22 + """ + + def config(self): + self.x_shape = [1, 3, 10, 10] + self.y_shape = [2, 3, 10, 5] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 3D(broadcast) + 4D(broadcast), test 'upper' complex128 +class TestTriangularSolveOpCp1283b4bUp(TestTriangularSolveOp): + """ + case 23 + """ + + def config(self): + self.x_shape = [2, 10, 10] + self.y_shape = [5, 1, 10, 2] + self.upper = True + self.transpose = True + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.triu(self.inputs['X']).transpose(0, 2, 1) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 3D(broadcast) + 5D complex128 +class TestTriangularSolveOpCp1283b5(TestTriangularSolveOp): + """ + case 24 + """ + + def config(self): + self.x_shape = [12, 3, 3] + self.y_shape = [2, 3, 12, 3, 2] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + +# 5D + 4D(broadcast) complex128 +class TestTriangularSolveOpCp12854b(TestTriangularSolveOp): + """ + case 25 + """ + + def config(self): + self.x_shape = [2, 4, 2, 3, 3] + self.y_shape = [4, 1, 3, 10] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = np.complex128 + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.matmul(np.linalg.inv(x), y) + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + check_pir=True, + ) + + class TestTriangularSolveAPI(unittest.TestCase): def setUp(self): np.random.seed(2021) From 2ad9e24a02e752b89d036756f484de6e7a86a11a Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Thu, 4 Jan 2024 07:23:07 +0800 Subject: [PATCH 091/142] fix reshard dist_attr (#60535) --- paddle/phi/core/distributed/auto_parallel/dist_tensor.cc | 3 +++ .../distributed/auto_parallel/reshard/reshard_function.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index 626b5bdf5e4413..fff9af10339a60 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -117,6 +117,9 @@ DistTensor::DistTensor() : value_(std::make_shared()) {} DistTensor::DistTensor(const std::shared_ptr& global_value, const TensorDistAttr& dist_attr) : global_dims_(global_value->dims()), dist_attr_(dist_attr) { + process_mesh_ = dist_attr_.process_mesh(); + placements_ = ToPlacements(dist_attr); + // If the current rank doesn't in process_mesh, we should create an // uninitialized tensor only with tensor_meta. if (IsCurRankInMesh(dist_attr.process_mesh())) { diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc index 9644c0b28e916e..99da6feb54eba0 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function.cc @@ -52,6 +52,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor, tensor->global_dims_ = dims; tensor->dist_attr_ = dist_attr; + tensor->process_mesh_ = dist_attr.process_mesh(); + tensor->placements_ = ToPlacements(dist_attr); } void ReshardFunction::SetDistProps(DistTensor* tensor, @@ -64,6 +66,8 @@ void ReshardFunction::SetDistProps(DistTensor* tensor, str_join(vectorize(tensor->dims())))); tensor->dist_attr_ = dist_attr; + tensor->process_mesh_ = dist_attr.process_mesh(); + tensor->placements_ = ToPlacements(dist_attr); } DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) { From 353cb274dbc03b7eeabe1de5522ed85f7292a771 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 4 Jan 2024 10:40:34 +0800 Subject: [PATCH 092/142] =?UTF-8?q?=E3=80=90auto=20parallel=E3=80=91?= =?UTF-8?q?=E5=89=94=E9=99=A4=E5=88=87=E5=88=86=E6=8E=A8=E5=AF=BC=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E7=9A=84=E5=A4=B4=E6=96=87=E4=BB=B6=E5=AF=B9proto=20?= =?UTF-8?q?=E7=9A=84=E4=BE=9D=E8=B5=96=20(#60543)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * decouple proto * format * format * strcuct pre def --- .../distributed/auto_parallel/dist_attr.cc | 10 ++- .../distributed/auto_parallel/CMakeLists.txt | 1 + .../distributed/auto_parallel/device_mesh.cc | 70 ++++++++----------- .../distributed/auto_parallel/device_mesh.h | 17 +++-- .../distributed/auto_parallel/dist_attr.cc | 20 +++--- .../distributed/auto_parallel/dist_attr.h | 7 +- .../distributed/auto_parallel/dist_mapper.cc | 10 +-- .../distributed/auto_parallel/dist_mapper.h | 5 +- .../distributed/auto_parallel/process_mesh.cc | 13 ++-- .../distributed/auto_parallel/process_mesh.h | 6 +- .../distributed/auto_parallel/proto_helper.cc | 65 +++++++++++++++++ .../distributed/auto_parallel/proto_helper.h | 43 ++++++++++++ test/cpp/auto_parallel/device_mesh_test.cc | 7 +- test/cpp/auto_parallel/dist_attr_test.cc | 3 +- test/cpp/auto_parallel/dist_mapper_test.cc | 3 +- test/cpp/auto_parallel/process_mesh_test.cc | 3 +- 16 files changed, 202 insertions(+), 81 deletions(-) create mode 100644 paddle/phi/core/distributed/auto_parallel/proto_helper.cc create mode 100644 paddle/phi/core/distributed/auto_parallel/proto_helper.h diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc index 8e96ec90723e64..805641cf01837f 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.cc +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" namespace paddle { namespace distributed { @@ -406,14 +407,17 @@ OperatorDistAttrProto OperatorDistAttr::to_proto() const { for (const auto& item : input_dist_attrs_) { auto proto_item = proto.mutable_input_dist_attrs()->Add(); proto_item->set_name(item.first); - proto_item->mutable_tensor_dist_attr()->CopyFrom(item.second.to_proto()); + proto_item->mutable_tensor_dist_attr()->CopyFrom( + phi::distributed::to_proto(item.second)); } for (const auto& item : output_dist_attrs_) { auto proto_item = proto.mutable_output_dist_attrs()->Add(); proto_item->set_name(item.first); - proto_item->mutable_tensor_dist_attr()->CopyFrom(item.second.to_proto()); + proto_item->mutable_tensor_dist_attr()->CopyFrom( + phi::distributed::to_proto(item.second)); } - proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto()); + proto.mutable_process_mesh()->CopyFrom( + phi::distributed::to_proto(process_mesh_)); proto.set_impl_type(impl_type_); proto.set_impl_idx(impl_idx_); proto.set_chunk_id(chunk_id_); diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index 42b2ef2d30a9aa..8cc7f8f20388ee 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -9,6 +9,7 @@ collect_srcs( dist_mapper.cc dist_tensor.cc dist_meta_tensor.cc + proto_helper.cc placement_types.cc inferspmd_utils.cc) diff --git a/paddle/phi/core/distributed/auto_parallel/device_mesh.cc b/paddle/phi/core/distributed/auto_parallel/device_mesh.cc index 98291dabe12ea5..32030b05b55fdc 100644 --- a/paddle/phi/core/distributed/auto_parallel/device_mesh.cc +++ b/paddle/phi/core/distributed/auto_parallel/device_mesh.cc @@ -16,8 +16,8 @@ limitations under the License. */ #include #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" - namespace phi { namespace distributed { namespace auto_parallel { @@ -41,13 +41,11 @@ DeviceCapability DeviceCapability::from_proto( return capability; } -DeviceCapabilityProto DeviceCapability::to_proto() const { - DeviceCapabilityProto proto; - proto.set_single_precision_flops(single_precision_flops); - proto.set_double_precision_flops(double_precision_flops); - proto.set_memory_size_in_bytes(memory_size_in_bytes); - proto.set_clock_rate_in_ghz(clock_rate_in_ghz); - return proto; +void DeviceCapability::to_proto(DeviceCapabilityProto *proto) const { + proto->set_single_precision_flops(single_precision_flops); + proto->set_double_precision_flops(double_precision_flops); + proto->set_memory_size_in_bytes(memory_size_in_bytes); + proto->set_clock_rate_in_ghz(clock_rate_in_ghz); } std::string Device::to_string() const { @@ -69,14 +67,13 @@ Device Device::from_proto(const DeviceProto &proto) { return device; } -DeviceProto Device::to_proto() const { - DeviceProto proto; - proto.set_global_id(global_id_); - proto.set_local_id(local_id_); - proto.set_machine_id(machine_id_); - proto.set_type(type_); - proto.mutable_capability()->CopyFrom(capability_.to_proto()); - return proto; +void Device::to_proto(DeviceProto *proto) const { + proto->set_global_id(global_id_); + proto->set_local_id(local_id_); + proto->set_machine_id(machine_id_); + proto->set_type(type_); + proto->mutable_capability()->CopyFrom( + phi::distributed::to_proto(capability_)); } bool operator==(const Device &lhs, const Device &rhs) { @@ -109,11 +106,9 @@ LinkCapability LinkCapability::from_proto(const LinkCapabilityProto &proto) { return capability; } -LinkCapabilityProto LinkCapability::to_proto() const { - LinkCapabilityProto proto; - proto.set_bandwidth(bandwidth); - proto.set_latency(latency); - return proto; +void LinkCapability::to_proto(LinkCapabilityProto *proto) const { + proto->set_bandwidth(bandwidth); + proto->set_latency(latency); } std::string Link::to_string() const { @@ -133,13 +128,12 @@ Link Link::from_proto(const LinkProto &proto) { return link; } -LinkProto Link::to_proto() const { - LinkProto proto; - proto.set_source_id(source_id_); - proto.set_target_id(target_id_); - proto.set_type(type_); - proto.mutable_capability()->CopyFrom(capability_.to_proto()); - return proto; +void Link::to_proto(LinkProto *proto) const { + proto->set_source_id(source_id_); + proto->set_target_id(target_id_); + proto->set_type(type_); + proto->mutable_capability()->CopyFrom( + phi::distributed::to_proto(capability_)); } bool operator==(const Link &lhs, const Link &rhs) { @@ -355,34 +349,32 @@ DeviceMesh DeviceMesh::from_proto(const DeviceMeshProto &proto) { return mesh; } -DeviceMeshProto DeviceMesh::to_proto() const { - DeviceMeshProto proto; - - proto.set_name(name_); +void DeviceMesh::to_proto(DeviceMeshProto *proto) const { + proto->set_name(name_); for (const auto &i : shape_) { - proto.add_shape(i); + proto->add_shape(i); } for (const auto &i : device_ids_) { - proto.add_device_ids(i); + proto->add_device_ids(i); } for (const auto &i : dim_names_) { - proto.add_dim_names(i); + proto->add_dim_names(i); } for (const auto &device : devices_) { - proto.mutable_devices()->Add()->CopyFrom(device.second.to_proto()); + proto->mutable_devices()->Add()->CopyFrom( + phi::distributed::to_proto(device.second)); } for (const auto &neighbors : links_) { for (const auto &link : neighbors.second) { - proto.mutable_links()->Add()->CopyFrom(link.second.to_proto()); + proto->mutable_links()->Add()->CopyFrom( + phi::distributed::to_proto(link.second)); } } - - return proto; } bool operator==(const DeviceMesh &lhs, const DeviceMesh &rhs) { diff --git a/paddle/phi/core/distributed/auto_parallel/device_mesh.h b/paddle/phi/core/distributed/auto_parallel/device_mesh.h index 0888d5e2e7a2a6..8cfdc6ed242f0e 100644 --- a/paddle/phi/core/distributed/auto_parallel/device_mesh.h +++ b/paddle/phi/core/distributed/auto_parallel/device_mesh.h @@ -30,6 +30,13 @@ limitations under the License. */ namespace phi { namespace distributed { namespace auto_parallel { + +class DeviceCapabilityProto; +class DeviceProto; +class LinkCapabilityProto; +class LinkProto; +class DeviceMeshProto; + struct DeviceCapability { double single_precision_flops = 0.0; double double_precision_flops = 0.0; @@ -40,7 +47,7 @@ struct DeviceCapability { std::string to_string() const; static DeviceCapability from_proto(const DeviceCapabilityProto& proto); - DeviceCapabilityProto to_proto() const; + void to_proto(DeviceCapabilityProto* proto) const; }; inline std::ostream& operator<<(std::ostream& os, const DeviceCapability& obj) { @@ -74,7 +81,7 @@ class Device { std::string to_string() const; static Device from_proto(const DeviceProto& proto); - DeviceProto to_proto() const; + void to_proto(DeviceProto* proto) const; private: int64_t global_id_; @@ -103,7 +110,7 @@ struct LinkCapability { std::string to_string() const; static LinkCapability from_proto(const LinkCapabilityProto& proto); - LinkCapabilityProto to_proto() const; + void to_proto(LinkCapabilityProto* proto) const; }; inline std::ostream& operator<<(std::ostream& os, const LinkCapability& obj) { @@ -131,7 +138,7 @@ class Link { std::string to_string() const; static Link from_proto(const LinkProto& proto); - LinkProto to_proto() const; + void to_proto(LinkProto* proto) const; private: int64_t source_id_; @@ -273,7 +280,7 @@ class DeviceMesh { std::string to_string() const; static DeviceMesh from_proto(const DeviceMeshProto& proto); - DeviceMeshProto to_proto() const; + void to_proto(DeviceMeshProto* proto) const; private: std::string name_; diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc index e60c150f0fb725..1477504be4a73d 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "glog/logging.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" namespace phi { namespace distributed { @@ -308,25 +309,24 @@ void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) { } } -TensorDistAttrProto TensorDistAttr::to_proto() const { - TensorDistAttrProto proto; - proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto()); +void TensorDistAttr::to_proto(TensorDistAttrProto* proto) const { + proto->mutable_process_mesh()->CopyFrom( + phi::distributed::to_proto(process_mesh_)); for (const auto& i : dims_mapping_) { - proto.add_dims_mapping(i); + proto->add_dims_mapping(i); } - proto.set_batch_dim(batch_dim_); - proto.set_chunk_id(chunk_id_); + proto->set_batch_dim(batch_dim_); + proto->set_chunk_id(chunk_id_); for (const auto& i : dynamic_dims_) { - proto.add_dynamic_dims(i); + proto->add_dynamic_dims(i); } - return proto; } std::string TensorDistAttr::serialize_to_string() { std::string data; - auto proto = to_proto(); + auto proto = phi::distributed::to_proto(*this); proto.SerializeToString(&data); - PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data), + PADDLE_ENFORCE_EQ(phi::distributed::to_proto(*this).SerializeToString(&data), true, errors::InvalidArgument( "Failed to serialize tensor dist attr to string.")); diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index d5232ab836261b..d158fc848c8d40 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -22,7 +22,6 @@ limitations under the License. */ #include #include "paddle/phi/common/reduce_type.h" -#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" #include "paddle/phi/core/enforce.h" @@ -32,6 +31,10 @@ limitations under the License. */ namespace phi { namespace distributed { +namespace auto_parallel { +class TensorDistAttrProto; +} + constexpr int kReplicateDim = -1; class PlacementStatus { @@ -169,7 +172,7 @@ class TEST_API TensorDistAttr { // future partial-support-stage-II. void from_proto(const auto_parallel::TensorDistAttrProto& proto); - auto_parallel::TensorDistAttrProto to_proto() const; + void to_proto(auto_parallel::TensorDistAttrProto* proto) const; std::string serialize_to_string(); diff --git a/paddle/phi/core/distributed/auto_parallel/dist_mapper.cc b/paddle/phi/core/distributed/auto_parallel/dist_mapper.cc index fdfa5907bbe550..6b5ff5625f63ee 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_mapper.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_mapper.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" namespace phi { @@ -91,20 +92,19 @@ DistributedMapper DistributedMapper::from_proto( return dist_mapper; } -DistributedMapperProto DistributedMapper::to_proto() const { - DistributedMapperProto proto; +void DistributedMapper::to_proto(DistributedMapperProto* proto) const { for (const auto& item : device_meshes_) { - proto.mutable_device_meshes()->Add()->CopyFrom(item.second.to_proto()); + proto->mutable_device_meshes()->Add()->CopyFrom( + phi::distributed::to_proto(item.second)); } for (const auto& outer : process_id_to_device_ids_) { - auto proto_item = proto.mutable_process_id_to_device_ids()->Add(); + auto proto_item = proto->mutable_process_id_to_device_ids()->Add(); proto_item->set_process_id(outer.first); proto_item->set_device_mesh_name(outer.second.first); for (const auto& inner : outer.second.second) { proto_item->add_device_ids(inner); } } - return proto; } std::string DistributedMapper::to_string() const { diff --git a/paddle/phi/core/distributed/auto_parallel/dist_mapper.h b/paddle/phi/core/distributed/auto_parallel/dist_mapper.h index 527801761e70f5..5436bc7a6cb5b3 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_mapper.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_mapper.h @@ -15,7 +15,6 @@ limitations under the License. */ #include -#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" @@ -23,6 +22,8 @@ namespace phi { namespace distributed { namespace auto_parallel { +class DistributedMapperProto; + class DistributedMapper { public: DistributedMapper() = default; @@ -52,7 +53,7 @@ class DistributedMapper { std::string to_string() const; static DistributedMapper from_proto(const DistributedMapperProto& proto); - DistributedMapperProto to_proto() const; + void to_proto(DistributedMapperProto* proto) const; private: std::map device_meshes_; diff --git a/paddle/phi/core/distributed/auto_parallel/process_mesh.cc b/paddle/phi/core/distributed/auto_parallel/process_mesh.cc index 27702a9c8b9d81..a1b60e27c27e67 100644 --- a/paddle/phi/core/distributed/auto_parallel/process_mesh.cc +++ b/paddle/phi/core/distributed/auto_parallel/process_mesh.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include #include - #include "paddle/phi/core/distributed/auto_parallel/utils.h" namespace phi { @@ -105,22 +104,18 @@ ProcessMesh ProcessMesh::from_proto(const ProcessMeshProto &proto) { return mesh; } -ProcessMeshProto ProcessMesh::to_proto() const { - ProcessMeshProto proto; - +void ProcessMesh::to_proto(ProcessMeshProto *proto) const { for (const auto &i : shape_) { - proto.add_shape(i); + proto->add_shape(i); } for (const auto &i : process_ids_) { - proto.add_process_ids(i); + proto->add_process_ids(i); } for (const auto &i : dim_names_) { - proto.add_dim_names(i); + proto->add_dim_names(i); } - - return proto; } bool operator==(const ProcessMesh &lhs, const ProcessMesh &rhs) { diff --git a/paddle/phi/core/distributed/auto_parallel/process_mesh.h b/paddle/phi/core/distributed/auto_parallel/process_mesh.h index d512255ec10359..792d5e38f5318b 100644 --- a/paddle/phi/core/distributed/auto_parallel/process_mesh.h +++ b/paddle/phi/core/distributed/auto_parallel/process_mesh.h @@ -28,6 +28,10 @@ limitations under the License. */ namespace phi { namespace distributed { +namespace auto_parallel { +class ProcessMeshProto; +} + class ProcessMesh { public: ProcessMesh() = default; @@ -68,7 +72,7 @@ class ProcessMesh { std::string to_string() const; static ProcessMesh from_proto(const auto_parallel::ProcessMeshProto& proto); - auto_parallel::ProcessMeshProto to_proto() const; + void to_proto(auto_parallel::ProcessMeshProto* proto) const; private: std::vector shape_; diff --git a/paddle/phi/core/distributed/auto_parallel/proto_helper.cc b/paddle/phi/core/distributed/auto_parallel/proto_helper.cc new file mode 100644 index 00000000000000..e8e4197a63c08a --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/proto_helper.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" +#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" +#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h" + +#define TO_PROTO_HELPER(object, proto_type) \ + proto_type proto; \ + object.to_proto(&proto); \ + return proto + +namespace phi { +namespace distributed { + +auto_parallel::TensorDistAttrProto to_proto(const TensorDistAttr& dist_attr) { + TO_PROTO_HELPER(dist_attr, auto_parallel::TensorDistAttrProto); +} + +auto_parallel::ProcessMeshProto to_proto(const ProcessMesh& process_mesh) { + TO_PROTO_HELPER(process_mesh, auto_parallel::ProcessMeshProto); +} + +auto_parallel::DeviceCapabilityProto to_proto( + const auto_parallel::DeviceCapability& device_capibilty) { + TO_PROTO_HELPER(device_capibilty, auto_parallel::DeviceCapabilityProto); +} + +auto_parallel::DeviceProto to_proto(const auto_parallel::Device& device) { + TO_PROTO_HELPER(device, auto_parallel::DeviceProto); +} + +auto_parallel::LinkCapabilityProto to_proto( + const auto_parallel::LinkCapability& link_capibilty) { + TO_PROTO_HELPER(link_capibilty, auto_parallel::LinkCapabilityProto); +} + +auto_parallel::LinkProto to_proto(const auto_parallel::Link& link) { + TO_PROTO_HELPER(link, auto_parallel::LinkProto); +} + +auto_parallel::DeviceMeshProto to_proto( + const auto_parallel::DeviceMesh& device_mesh) { + TO_PROTO_HELPER(device_mesh, auto_parallel::DeviceMeshProto); +} + +auto_parallel::DistributedMapperProto to_proto( + const auto_parallel::DistributedMapper& dist_mapper) { + TO_PROTO_HELPER(dist_mapper, auto_parallel::DistributedMapperProto); +} +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/proto_helper.h b/paddle/phi/core/distributed/auto_parallel/proto_helper.h new file mode 100644 index 00000000000000..66bdf2af74406d --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/proto_helper.h @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" +namespace phi { +namespace distributed { +class TensorDistAttr; +class ProcessMesh; +namespace auto_parallel { +struct DeviceCapability; +class Device; +struct LinkCapability; +class Link; +class DeviceMesh; +class DistributedMapper; +} // namespace auto_parallel +auto_parallel::TensorDistAttrProto to_proto(const TensorDistAttr& dist_attr); +auto_parallel::ProcessMeshProto to_proto(const ProcessMesh& dist_attr); + +auto_parallel::DeviceCapabilityProto to_proto( + const auto_parallel::DeviceCapability& device_capibilty); +auto_parallel::DeviceProto to_proto(const auto_parallel::Device& device); +auto_parallel::LinkCapabilityProto to_proto( + const auto_parallel::LinkCapability& link_capibilty); +auto_parallel::LinkProto to_proto(const auto_parallel::Link& link); +auto_parallel::DeviceMeshProto to_proto(const auto_parallel::DeviceMesh& link); +auto_parallel::DistributedMapperProto to_proto( + const auto_parallel::DistributedMapper& dist_mapper); + +} // namespace distributed +} // namespace phi diff --git a/test/cpp/auto_parallel/device_mesh_test.cc b/test/cpp/auto_parallel/device_mesh_test.cc index 9c3c47de921425..d0648fdc97eaf8 100644 --- a/test/cpp/auto_parallel/device_mesh_test.cc +++ b/test/cpp/auto_parallel/device_mesh_test.cc @@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" #include #include + +#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" + #include "gtest/gtest.h" namespace phi { @@ -83,7 +86,7 @@ TEST(DeviceMesh, Ctor) { std::stringstream sstream; sstream << device_mesh; EXPECT_EQ(sstream.str(), device_mesh.to_string()); - auto proto = device_mesh.to_proto(); + auto proto = phi::distributed::to_proto(device_mesh); DeviceMesh new_device_mesh = DeviceMesh::from_proto(proto); EXPECT_EQ(device_mesh, new_device_mesh); } diff --git a/test/cpp/auto_parallel/dist_attr_test.cc b/test/cpp/auto_parallel/dist_attr_test.cc index 383e0c8bf5ba70..a68a56d4003b95 100644 --- a/test/cpp/auto_parallel/dist_attr_test.cc +++ b/test/cpp/auto_parallel/dist_attr_test.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" namespace phi { namespace distributed { @@ -99,7 +100,7 @@ TEST(DistAttr, ctor) { std::stringstream x_sstream; x_sstream << x_dist_attr; EXPECT_EQ(x_sstream.str(), x_dist_attr.to_string()); - auto x_proto = x_dist_attr.to_proto(); + auto x_proto = phi::distributed::to_proto(x_dist_attr); TensorDistAttr new_x_dist_attr = get_dist_attr(x); new_x_dist_attr.from_proto(x_proto); EXPECT_EQ(x_dist_attr, new_x_dist_attr); diff --git a/test/cpp/auto_parallel/dist_mapper_test.cc b/test/cpp/auto_parallel/dist_mapper_test.cc index bc426e5651f97c..30fc6ec3c297f7 100644 --- a/test/cpp/auto_parallel/dist_mapper_test.cc +++ b/test/cpp/auto_parallel/dist_mapper_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include "gtest/gtest.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" namespace phi { namespace distributed { @@ -62,7 +63,7 @@ TEST(DistributedMapper, Ctor) { std::stringstream sstream; sstream << dist_mapper; EXPECT_EQ(sstream.str(), dist_mapper.to_string()); - auto proto = dist_mapper.to_proto(); + auto proto = phi::distributed::to_proto(dist_mapper); DistributedMapper new_dist_mapper = DistributedMapper::from_proto(proto); EXPECT_EQ(dist_mapper, new_dist_mapper); } diff --git a/test/cpp/auto_parallel/process_mesh_test.cc b/test/cpp/auto_parallel/process_mesh_test.cc index 3e88f5629c624c..7e9b23062fbe53 100644 --- a/test/cpp/auto_parallel/process_mesh_test.cc +++ b/test/cpp/auto_parallel/process_mesh_test.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include "gtest/gtest.h" +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" namespace phi { namespace distributed { @@ -43,7 +44,7 @@ TEST(ProcessMesh, Ctor) { std::stringstream sstream; sstream << process_mesh; EXPECT_EQ(sstream.str(), process_mesh.to_string()); - auto proto = process_mesh.to_proto(); + auto proto = phi::distributed::to_proto(process_mesh); ProcessMesh new_process_mesh = ProcessMesh::from_proto(proto); EXPECT_EQ(process_mesh, new_process_mesh); } From 6b2d74cc7bfa9054276b4bc359ff161fc03b2a52 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 4 Jan 2024 10:46:27 +0800 Subject: [PATCH 093/142] [PIR] Support Operation::Clone Interface (#60536) * [PIR] Support Operation::Clone Interface * modify into shared_ptr --- paddle/cinn/hlir/framework/pir/group.h | 44 +++++++++++++++++++++++++- paddle/pir/core/ir_mapping.h | 37 ++++++++++++++++++++++ paddle/pir/core/operation.cc | 26 +++++++++++++++ paddle/pir/core/operation.h | 22 ++++++++++++- 4 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 paddle/pir/core/ir_mapping.h diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 870a159d49c929..2cd3b9b9deddaa 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -13,8 +13,11 @@ // limitations under the License. #pragma once +#include #include +#include #include +#include "glog/logging.h" #include "paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h" #include "paddle/cinn/hlir/framework/op.h" @@ -34,8 +37,17 @@ namespace framework { namespace pir { using framework::OpPatternKind; -// TODO(Aurelius84): Need to be replaced with CinnGroupOp struct Group { + // Control the clone strategy for Group. + class Options { + public: + Options() : only_clone_ops(true) {} + bool OnlyCloneOps() const { return only_clone_ops; } + + private: + bool only_clone_ops = false; + }; + public: Group() = default; Group(const Group&) = delete; @@ -47,6 +59,36 @@ struct Group { explicit Group(std::initializer_list<::pir::Operation*> group_ops) : ops(group_ops) {} + std::shared_ptr Clone(::pir::Block* target_block, + ::pir::IrMapping& ir_mapping, + const Options& option = Options()) const { + CHECK_EQ(option.OnlyCloneOps(), true) + << "Only Support Clone Group ops information."; + std::vector<::pir::Operation*> new_ops; + // Mapper from original to new ops. + std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper; + ::pir::CloneOptions clone_options(false, true); + for (auto* op : this->ops_set) { + auto* new_op = op->Clone(ir_mapping, clone_options); + // NOTE(dev): Must call MoveTo to deal with ownership, otherwise it + // will lead memory-leak. + new_op->MoveTo(target_block, target_block->end()); + new_ops.push_back(new_op); + ops_mapper[op] = new_op; + } + // Construct Base information for new Group + auto new_group = std::make_shared(new_ops); + this->CollectOps(); + for (auto& iter : this->input_ops) { + new_group->input_ops[ops_mapper[iter.first]] = iter.second; + } + for (auto* op : this->output_ops) { + new_group->output_ops.insert(ops_mapper[op]); + } + + return new_group; + } + // distance to last group. int depth{0}; int max_depth{0}; diff --git a/paddle/pir/core/ir_mapping.h b/paddle/pir/core/ir_mapping.h new file mode 100644 index 00000000000000..607c8cc0704f54 --- /dev/null +++ b/paddle/pir/core/ir_mapping.h @@ -0,0 +1,37 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include "paddle/common/enforce.h" +#include "paddle/pir/core/block.h" + +namespace pir { + +class IrMapping { + public: + void Add(Value from, Value to) { value_map_[from] = to; } + + Value Lookup(Value from) const { + IR_ENFORCE(value_map_.count(from) > 0, "Not Found Value in IRMapping."); + return value_map_.at(from); + } + void Earse(Value from) { value_map_.erase(from); } + + void Clear() { value_map_.clear(); } + + private: + std::unordered_map value_map_; +}; + +} // namespace pir diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index c0ce8842155ab6..0a8e26d788ca15 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -137,6 +137,32 @@ Operation *Operation::Create(const std::vector &inputs, return op; } +Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) { + IR_ENFORCE(options.IsCloneRegions() || num_regions_ > 0, + "Operation CloneOperands is unimplemented currently."); + IR_ENFORCE(num_successors_ == 0, + "Operation::Clone is not unimplemented for multiple successors."); + + auto inputs = operands_source(); + if (options.IsCloneOperands()) { + // replace value by IRMapping inplacely. + for (auto &value : inputs) { + value = ir_mapping.Lookup(value); + } + } + + std::vector output_types; + for (auto &result : results()) { + output_types.push_back(result.type()); + } + auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_); + // record outputs mapping info + for (uint32_t i = 0; i < num_results_; ++i) { + ir_mapping.Add(result(i), new_op->result(i)); + } + return new_op; +} + // Call destructors for Region , OpResults, Operation, and OpOperands in // sequence, and finally free memory. void Operation::Destroy() { diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index ea31c85ca7c261..0dafcdd7a5b2b4 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -20,12 +20,12 @@ #include "paddle/common/enforce.h" #include "paddle/common/macros.h" #include "paddle/pir/core/block.h" +#include "paddle/pir/core/ir_mapping.h" #include "paddle/pir/core/iterator.h" #include "paddle/pir/core/op_info.h" #include "paddle/pir/core/operation_utils.h" #include "paddle/pir/core/type.h" #include "paddle/pir/core/visitors.h" - namespace pir { class OpBase; class Program; @@ -37,6 +37,20 @@ class OpResultImpl; class OpOperendImpl; } // namespace detail +class CloneOptions { + public: + CloneOptions() : clone_regions_{false}, clone_operands_{false} {} + CloneOptions(bool clone_regions, bool clone_operands) + : clone_regions_(clone_regions), clone_operands_(clone_operands) {} + + bool IsCloneRegions() const { return clone_regions_; } + bool IsCloneOperands() const { return clone_operands_; } + + private: + bool clone_regions_{true}; + bool clone_operands_{true}; +}; + class IR_API alignas(8) Operation final : public DoubleLevelContainer { public: @@ -53,6 +67,12 @@ class IR_API alignas(8) Operation final size_t num_regions = 0, const std::vector &successors = {}); static Operation *Create(OperationArgument &&op_argument); + + /// + /// \brief Deep copy all information and create a new operation. + /// + Operation *Clone(IrMapping &ir_mapping, + CloneOptions options = CloneOptions()); /// /// \brief Destroy the operation objects and free memory by create(). /// From a05f19533d6a768fb025b47fe408edef374cb41e Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 4 Jan 2024 10:49:28 +0800 Subject: [PATCH 094/142] [Dynamic Shape] Add FullyInsertBroadcastPass and Broadcast Op (#60511) * add ShapeBroadcastOp * add pass FullyInsertBroadcastPass * InferSymbolicShape of BroadcastShape Op * Delete unit test * Fix return error * Code format * Fix error message * Update paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --------- Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com> --- .../operator/transforms/CMakeLists.txt | 9 + .../transforms/fully_insert_broadcast_pass.cc | 112 +++++++++++++ .../transforms/fully_insert_broadcast_pass.h | 35 ++++ .../pir/dialect/operator/ir/manual_op.cc | 156 +++++++++++++++++- .../fluid/pir/dialect/operator/ir/manual_op.h | 33 ++++ 5 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc create mode 100644 paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 6d76ccbec8adc1..dbe7f3c40adad6 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -29,4 +29,13 @@ if(NOT CINN_ONLY) cinn_op_dialect op_dialect_vjp) + cinn_cc_library( + fully_insert_broadcast_pass + SRCS + fully_insert_broadcast_pass.cc + DEPS + pir + cinn_op_dialect + op_dialect_vjp) + endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc new file mode 100644 index 00000000000000..04ba01b4cbea2c --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/pattern_applicator.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter, + pir::Value x, + pir::Value y) { + pir::Value x_shape = rewriter->Build(x).out(); + pir::Value y_shape = rewriter->Build(y).out(); + return rewriter->Build(x_shape, y_shape) + .out(); +} + +bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { + pir::Value x = op->operand_source(0); + pir::Value y = op->operand_source(1); + pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y); + { + pir::Value broadcasted_x = + rewriter->Build(x, output_dim_tensor).out(); + op->operand(0).set_source(broadcasted_x); + } + { + pir::Value broadcasted_y = + rewriter->Build(y, output_dim_tensor).out(); + op->operand(1).set_source(broadcasted_y); + } + return true; +} + +template +class FullyInsertBroadcastPattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(OPTYPE op, + pir::PatternRewriter& rewriter) const override { + return ProcessOp(op, &rewriter); + } +}; + +FullyInsertBroadcastPass::FullyInsertBroadcastPass() + : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {} + +pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns( + pir::IrContext* context) { + pir::RewritePatternSet ps(context); + // elementwise ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>( + context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + // compare ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + // bitwise ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + + return ps; +} + +bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const { + return op->isa() && op->num_regions() > 0; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h new file mode 100644 index 00000000000000..ba174583992784 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class FullyInsertBroadcastPass : public pir::PatternRewritePass { + public: + FullyInsertBroadcastPass(); + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override; + + bool CanApplyOn(pir::Operation *op) const override; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 4c07132cfaa1f0..ca720fdb26ee35 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -23,7 +23,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp, paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp, paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp, - paddle::dialect::Increment_Op + paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp #else #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" @@ -35,6 +35,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" +#include "paddle/phi/api/lib/data_type_set.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" @@ -2925,6 +2926,158 @@ phi::DataType Increment_Op::GetKernelTypeForVar( return expected_kernel_dtype; } +void ShapeBroadcastOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::Value x_, + pir::Value y_) { + VLOG(4) << "Start build ShapeBroadcastOp"; + + VLOG(4) << "Builder construction inputs"; + std::vector argument_inputs = {x_, y_}; + argument.AddInputs(argument_inputs); + + VLOG(4) << "Builder construction attributes"; + + VLOG(4) << "Builder construction outputs"; + paddle::dialect::DenseTensorType x = + x_.type().dyn_cast(); + paddle::dialect::DenseTensorType y = + y_.type().dyn_cast(); + + VLOG(4) << "Builder construction dense_x"; + paddle::dialect::IrTensor ir_tensor_x( + paddle::dialect::TransToPhiDataType(x.dtype()), + x.dims(), + x.data_layout(), + x.lod(), + x.offset()); + VLOG(4) << "Builder construction meta_x"; + paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x); + + VLOG(4) << "Builder construction dense_y"; + paddle::dialect::IrTensor ir_tensor_y( + paddle::dialect::TransToPhiDataType(y.dtype()), + y.dims(), + y.data_layout(), + y.lod(), + y.offset()); + VLOG(4) << "Builder construction meta_y"; + paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y); + paddle::dialect::IrTensor dense_out; + paddle::dialect::IrMetaTensor meta_out(&dense_out); + + phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out); + + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dense_out.dtype()), + dense_out.dims(), + dense_out.layout(), + dense_out.lod(), + dense_out.offset()); + argument_outputs.push_back(out_dense_tensor_type); + argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); +} + +namespace { + +void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x, + const phi::MetaTensor &y, + phi::MetaTensor *out) { + PADDLE_ENFORCE_EQ( + x.dims().size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of x.dims() must be equal to 1.", x.dims().size())); + PADDLE_ENFORCE_EQ( + y.dims().size(), + 1, + phi::errors::PreconditionNotMet( + "The size %d of y.dims() must be equal to 1.", y.dims().size())); + out->set_dims({std::max(x.dims().at(0), y.dims().at(0))}); + // dtype need promote when meet input dtype with more precision + paddle::experimental::DataTypeSet dtype_set{x.dtype()}; + dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype()); + DataType promote_result = PromoteTypes(dtype_set); + if (promote_result == DataType::UNDEFINED) { + promote_result = x.dtype(); + } + out->set_dtype(promote_result); + out->set_layout(x.layout()); + out->share_lod(x); +} + +} // namespace + +void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) { + auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta); + fn(infer_meta); +} + +phi::DataType ShapeBroadcastOp::GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype) { + VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp"; + + return expected_kernel_dtype; +} + +namespace { + +symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs, + const symbol::DimExpr &rhs) { + if (lhs.isa() && rhs.isa()) { + return std::max(lhs.dyn_cast(), rhs.dyn_cast()); + } else if (lhs.isa()) { + return lhs.dyn_cast() == 1 ? rhs : lhs; + } else if (rhs.isa()) { + return rhs.dyn_cast() == 1 ? lhs : rhs; + } else { + return symbol::Broadcast{ + symbol::List{lhs, rhs}}; + } + LOG(FATAL) << "Dead code"; +} + +} // namespace + +bool ShapeBroadcastOp::InferSymbolicShape( + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value x = operand_source(0); + pir::Value y = operand_source(1); + std::string x_id = pir::GetValueId(&x); + std::string y_id = pir::GetValueId(&y); + + IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0, + "x_id does not exist."); + IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0, + "y_id does not exist."); + const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id); + const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id); + IR_ENFORCE(x_data_shape.data().has_value(), + "Value x comes from ShapeOp, it must have data"); + IR_ENFORCE(y_data_shape.data().has_value(), + "Value y comes from ShapeOp, it must have data"); + const auto &x_data = x_data_shape.data().value(); + const auto &y_data = y_data_shape.data().value(); + IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily"); + + std::vector output_data; + for (std::size_t i = 0; i < x_data.size(); ++i) { + output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i))); + } + + pir::OpResult res = result(0); + std::string res_id = pir::GetValueId(&res); + symbol::ShapeOrDataDimExprs output_data_shape = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data); + shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape; + return true; +} + } // namespace dialect } // namespace paddle @@ -2948,4 +3101,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp) #endif diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index cbfadb24b97e6d..6e1e2d67a69e40 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" @@ -554,6 +555,37 @@ class Increment_Op const std::vector> &stop_gradients); }; +class IR_API ShapeBroadcastOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.shape_broadcast"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x_, + pir::Value y_); + + void VerifySig() {} + + pir::Value x() { return operand_source(0); } + pir::Value y() { return operand_source(1); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); + + static phi::DataType GetKernelTypeForVar( + const std::string &var_name, + const phi::DataType &tensor_dtype, + const phi::DataType &expected_kernel_dtype); + + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); +}; + } // namespace dialect } // namespace paddle @@ -577,3 +609,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp) From 488bd17a9ad708504e9cf263513dc08b679eccce Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Thu, 4 Jan 2024 11:05:07 +0800 Subject: [PATCH 095/142] Fix OpTranslatorTest name (#60518) * fix name * fix name * fix name * fix name --- test/ir/pir/translator/CMakeLists.txt | 4 ++-- ...educe_min_translate.py => test_c_reduce_min_translator.py} | 4 ++-- .../{test_op_transcriber.py => test_op_translator.py} | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) rename test/ir/pir/translator/{test_c_reduce_min_translate.py => test_c_reduce_min_translator.py} (92%) rename test/ir/pir/translator/{test_op_transcriber.py => test_op_translator.py} (97%) diff --git a/test/ir/pir/translator/CMakeLists.txt b/test/ir/pir/translator/CMakeLists.txt index 108615b0c204e5..8ac1fb1e7a3b6b 100644 --- a/test/ir/pir/translator/CMakeLists.txt +++ b/test/ir/pir/translator/CMakeLists.txt @@ -4,10 +4,10 @@ file( "test_*.py") string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}") -set(DISTRIBUTED_OP_TRANSLATION_TEST test_c_reduce_min_translate) +set(DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_min_translator) if(NOT WITH_DISTRIBUTE) - list(REMOVE_ITEM TEST_INTERP_CASES ${DISTRIBUTED_OP_TRANSLATION_TEST}) + list(REMOVE_ITEM TEST_INTERP_CASES ${DISTRIBUTED_OP_TRANSLATOR_TEST}) endif() foreach(target ${TEST_INTERP_CASES}) diff --git a/test/ir/pir/translator/test_c_reduce_min_translate.py b/test/ir/pir/translator/test_c_reduce_min_translator.py similarity index 92% rename from test/ir/pir/translator/test_c_reduce_min_translate.py rename to test/ir/pir/translator/test_c_reduce_min_translator.py index 63c4e8271c2e15..71610cf9a3e43c 100644 --- a/test/ir/pir/translator/test_c_reduce_min_translate.py +++ b/test/ir/pir/translator/test_c_reduce_min_translator.py @@ -14,13 +14,13 @@ import unittest -import test_op_transcriber +import test_op_translator import paddle from paddle.base.layer_helper import LayerHelper -class TestCReduceMinOpTranscriber(test_op_transcriber.TestOpTranscriber): +class TestCReduceMinOpTranslator(test_op_translator.TestOpTranslator): def append_op(self): self.op_type = "c_reduce_min" x = paddle.ones(shape=(100, 2, 3), dtype='float32') diff --git a/test/ir/pir/translator/test_op_transcriber.py b/test/ir/pir/translator/test_op_translator.py similarity index 97% rename from test/ir/pir/translator/test_op_transcriber.py rename to test/ir/pir/translator/test_op_translator.py index dfb8fa63a18705..8f257fa59b9f78 100644 --- a/test/ir/pir/translator/test_op_transcriber.py +++ b/test/ir/pir/translator/test_op_translator.py @@ -21,7 +21,7 @@ paddle.enable_static() -class TestOpTranscriber(unittest.TestCase): +class TestOpTranslator(unittest.TestCase): def setUp(self): self.place = core.Place() self.place.set_place(paddle.CPUPlace()) From d3078904bf11e37217239b790720ca167c069197 Mon Sep 17 00:00:00 2001 From: Lu Qi <61354321+MarioLulab@users.noreply.github.com> Date: Thu, 4 Jan 2024 11:09:55 +0800 Subject: [PATCH 096/142] [PIR] migrate DataFeeder into pir (#60434) --- python/paddle/base/data_feeder.py | 42 +++++--- python/paddle/static/input.py | 1 + test/legacy_test/test_data_feeder.py | 151 +++++++++++++++++---------- 3 files changed, 124 insertions(+), 70 deletions(-) diff --git a/python/paddle/base/data_feeder.py b/python/paddle/base/data_feeder.py index 4d56db43fbfdca..9f273c3194e828 100644 --- a/python/paddle/base/data_feeder.py +++ b/python/paddle/base/data_feeder.py @@ -16,6 +16,8 @@ import numpy as np +from paddle import pir + from ..pir import Value from ..pir.core import ParameterMeta from . import core @@ -419,19 +421,35 @@ def __init__(self, feed_list, place, program=None): self.feed_names = [] self.feed_shapes = [] self.feed_lod_level = [] - if program is None: - program = default_main_program() - for each_var in feed_list: - if isinstance(each_var, str): - each_var = program.block(0).var(each_var) - if not isinstance(each_var, Variable): - raise TypeError("Feed list should contain a list of variable") - self.feed_dtypes.append(each_var.dtype) - self.feed_names.append(each_var.name) - self.feed_lod_level.append(each_var.lod_level) - self.feed_shapes.append(each_var.shape) - self.place = place + if in_pir_mode(): + if program is None: + program = pir.core.default_main_program() + for each_var in feed_list: + if isinstance(each_var, str): + raise ValueError( + "In PIR Mode, Not supported string input yet" + ) + if not isinstance(each_var, Value): + raise TypeError("Feed list should contain a list of Value") + self.feed_dtypes.append(each_var.dtype) + self.feed_names.append(each_var.name) + self.feed_lod_level.append(each_var.lod_level) + self.feed_shapes.append(each_var.shape) + else: + if program is None: + program = default_main_program() + for each_var in feed_list: + if isinstance(each_var, str): + each_var = program.block(0).var(each_var) + if not isinstance(each_var, Variable): + raise TypeError( + "Feed list should contain a list of variable" + ) + self.feed_dtypes.append(each_var.dtype) + self.feed_names.append(each_var.name) + self.feed_lod_level.append(each_var.lod_level) + self.feed_shapes.append(each_var.shape) def feed(self, iterable): """ diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index 623ed4d2ef73ff..f8ee2b9bfafbd7 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -134,6 +134,7 @@ def _reset_data_op_insertion_point(): ir_dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype) _reset_data_op_insertion_point() out = paddle._pir_ops.data(name, shape, ir_dtype, core.Place()) + out.lod_level = lod_level paddle.pir.reset_insertion_point_to_end() return out diff --git a/test/legacy_test/test_data_feeder.py b/test/legacy_test/test_data_feeder.py index e8cda8eb45d63f..5653ff7d98b191 100644 --- a/test/legacy_test/test_data_feeder.py +++ b/test/legacy_test/test_data_feeder.py @@ -16,74 +16,109 @@ import paddle from paddle import base +from paddle.pir_utils import test_with_pir_api paddle.enable_static() class TestDataFeeder(unittest.TestCase): + @test_with_pir_api def test_lod_level_0_converter(self): - img = paddle.static.data(name='image', shape=[-1, 1, 28, 28]) - label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64') - feeder = base.DataFeeder([img, label], base.CPUPlace()) - result = feeder.feed([([0] * 784, [9]), ([1] * 784, [1])]) - - self.assertEqual(result['image'].shape(), [2, 1, 28, 28]) - self.assertEqual(result['label'].shape(), [2, 1]) - self.assertEqual(result['image'].recursive_sequence_lengths(), []) - self.assertEqual(result['label'].recursive_sequence_lengths(), []) - - try: - result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])]) - self.assertTrue(False) - except ValueError: - self.assertTrue(True) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + img = paddle.static.data(name='image', shape=[-1, 1, 28, 28]) + label = paddle.static.data( + name='label', shape=[-1, 1], dtype='int64' + ) + feeder = base.DataFeeder([img, label], base.CPUPlace()) + result = feeder.feed([([0] * 784, [9]), ([1] * 784, [1])]) + self.assertEqual(result['image'].shape(), [2, 1, 28, 28]) + self.assertEqual(result['label'].shape(), [2, 1]) + self.assertEqual(result['image'].recursive_sequence_lengths(), []) + self.assertEqual(result['label'].recursive_sequence_lengths(), []) + + try: + result = feeder.feed([([0] * 783, [9]), ([1] * 783, [1])]) + self.assertTrue(False) + except ValueError: + self.assertTrue(True) + + @test_with_pir_api def test_lod_level_1_converter(self): - # lod_level = 1 - # each sentence has a different number of words - sentences = paddle.static.data( - name='sentences', shape=[-1, 1], dtype='int64', lod_level=1 - ) - label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64') - feeder = base.DataFeeder([sentences, label], base.CPUPlace()) - - # lod = [[0, 3, 5, 9]] - # data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] - # label = [1] * len(data) - result = feeder.feed( - [([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])] - ) - - self.assertEqual(result['sentences'].shape(), [9, 1]) - self.assertEqual(result['label'].shape(), [3, 1]) - self.assertEqual( - result['sentences'].recursive_sequence_lengths(), [[3, 2, 4]] - ) - self.assertEqual(result['label'].recursive_sequence_lengths(), []) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + # lod_level = 1 + # each sentence has a different number of words + sentences = paddle.static.data( + name='sentences', shape=[-1, 1], dtype='int64', lod_level=1 + ) + label = paddle.static.data( + name='label', shape=[-1, 1], dtype='int64' + ) + feeder = base.DataFeeder([sentences, label], base.CPUPlace()) + + # lod = [[0, 3, 5, 9]] + # data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]] + # label = [1] * len(data) + result = feeder.feed( + [([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])] + ) + + self.assertEqual(result['sentences'].shape(), [9, 1]) + self.assertEqual(result['label'].shape(), [3, 1]) + self.assertEqual( + result['sentences'].recursive_sequence_lengths(), [[3, 2, 4]] + ) + self.assertEqual(result['label'].recursive_sequence_lengths(), []) + @test_with_pir_api def test_lod_level_2_converter(self): - # lod_level = 2 - # paragraphs -> sentences -> words - paragraphs = paddle.static.data( - name='paragraphs', shape=[-1, 1], dtype='int64', lod_level=2 - ) - label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64') - feeder = base.DataFeeder([paragraphs, label], base.CPUPlace()) - - # lod = [[0, 2, 3], [0, 3, 5, 9]] - # data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]] - # label = [1] * len(data) - result = feeder.feed( - [([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])] - ) - - self.assertEqual(result['paragraphs'].shape(), [9, 1]) - self.assertEqual(result['label'].shape(), [2, 1]) - self.assertEqual( - result['paragraphs'].recursive_sequence_lengths(), - [[2, 1], [3, 2, 4]], - ) - self.assertEqual(result['label'].recursive_sequence_lengths(), []) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + # lod_level = 2 + # paragraphs -> sentences -> words + paragraphs = paddle.static.data( + name='paragraphs', shape=[-1, 1], dtype='int64', lod_level=2 + ) + label = paddle.static.data( + name='label', shape=[-1, 1], dtype='int64' + ) + feeder = base.DataFeeder([paragraphs, label], base.CPUPlace()) + + # lod = [[0, 2, 3], [0, 3, 5, 9]] + # data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]] + # label = [1] * len(data) + result = feeder.feed( + [([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])] + ) + + self.assertEqual(result['paragraphs'].shape(), [9, 1]) + self.assertEqual(result['label'].shape(), [2, 1]) + self.assertEqual( + result['paragraphs'].recursive_sequence_lengths(), + [[2, 1], [3, 2, 4]], + ) + self.assertEqual(result['label'].recursive_sequence_lengths(), []) + + def test_errors(self): + def pir_mode_not_supported_str_feed(): + with paddle.pir_utils.IrGuard(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + img = paddle.static.data( + name='image', shape=[-1, 1, 28, 28] + ) + label = paddle.static.data( + name='label', shape=[-1, 1], dtype='int64' + ) + feeder = base.DataFeeder(['image', label], base.CPUPlace()) + + self.assertRaises(ValueError, pir_mode_not_supported_str_feed) if __name__ == '__main__': From e397b2952db853fb54e5ef0b08688a7f06782dab Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Thu, 4 Jan 2024 11:10:12 +0800 Subject: [PATCH 097/142] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.90,92?= =?UTF-8?q?=E3=80=91Migrate=20some=20ops=20into=20pir=20(#59801)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../incubate/operators/graph_reindex.py | 4 ++-- python/paddle/vision/ops.py | 22 +++++++++++++++++++ .../test_generate_proposals_v2_op.py | 5 ++--- test/legacy_test/test_graph_reindex.py | 1 + 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/python/paddle/incubate/operators/graph_reindex.py b/python/paddle/incubate/operators/graph_reindex.py index abba95b97ca039..af124893541c9c 100644 --- a/python/paddle/incubate/operators/graph_reindex.py +++ b/python/paddle/incubate/operators/graph_reindex.py @@ -15,7 +15,7 @@ from paddle import _C_ops from paddle.base.data_feeder import check_variable_and_dtype from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode from paddle.utils import deprecated @@ -130,7 +130,7 @@ def graph_reindex( "be None if `flag_buffer_hashtable` is True." ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph( x, neighbors, diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index ee34bf5e0ee699..4803c15da7911b 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -26,6 +26,7 @@ convert_np_dtype_to_dtype_, in_dygraph_mode, in_dynamic_or_pir_mode, + in_pir_mode, ) from ..base.layer_helper import LayerHelper from ..framework import _current_expected_place @@ -2144,6 +2145,27 @@ def generate_proposals( scores, bbox_deltas, img_size, anchors, variances, *attrs ) + return rpn_rois, rpn_roi_probs, rpn_rois_num + elif in_pir_mode(): + assert ( + return_rois_num + ), "return_rois_num should be True in PaddlePaddle inner op mode." + rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals( + scores, + bbox_deltas, + img_size, + anchors, + variances, + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + min_size, + eta, + pixel_offset, + ) + rpn_rois.stop_gradient = True + rpn_roi_probs.stop_gradient = True + rpn_rois_num.stop_gradient = True return rpn_rois, rpn_roi_probs, rpn_rois_num else: helper = LayerHelper('generate_proposals_v2', **locals()) diff --git a/test/legacy_test/test_generate_proposals_v2_op.py b/test/legacy_test/test_generate_proposals_v2_op.py index 568c466e066661..87e9e6c60fe7d6 100644 --- a/test/legacy_test/test_generate_proposals_v2_op.py +++ b/test/legacy_test/test_generate_proposals_v2_op.py @@ -54,7 +54,7 @@ def python_generate_proposals_v2( pixel_offset=pixel_offset, return_rois_num=return_rois_num, ) - return rpn_rois, rpn_roi_probs + return rpn_rois, rpn_roi_probs, rpn_rois_num def generate_proposals_v2_in_python( @@ -223,12 +223,11 @@ def set_data(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def setUp(self): self.op_type = "generate_proposals_v2" self.python_api = python_generate_proposals_v2 - self.python_out_sig = ['Out'] self.set_data() def init_test_params(self): diff --git a/test/legacy_test/test_graph_reindex.py b/test/legacy_test/test_graph_reindex.py index 85ad07f86af0f1..13ec55660a9e16 100644 --- a/test/legacy_test/test_graph_reindex.py +++ b/test/legacy_test/test_graph_reindex.py @@ -129,6 +129,7 @@ def test_heter_reindex_result_v2(self): np.testing.assert_allclose(reindex_dst, reindex_dst_, rtol=1e-05) np.testing.assert_allclose(out_nodes, out_nodes_, rtol=1e-05) + @test_with_pir_api def test_reindex_result_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): From 1b2696693c5a53ae94bec81e04ee89fc3adf7432 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 4 Jan 2024 11:14:06 +0800 Subject: [PATCH 098/142] [DimExpr] Convert Broadcast to BroadcastTree (#60440) * backup BroadcastTree * add SubstituteDimExpr * add helper function ConstructBroadcastTree * Fix compile error * Code format * Polish DimExprUtilTest * Add cmake file * Change namesapce * Fix compile error * Fix unittest * reconstruct BroadcastTree * Polish DimExprUtilTest * Reconstruct BroadcastTree * Finish BroadcastBranch * Finish BroadcastBranch * Finish BroadcastBranch * Add Unittest * Remove unnecessary dim_expr_util * Add header file --- paddle/cinn/common/CMakeLists.txt | 2 + paddle/cinn/common/broadcast_tree.cc | 298 ++++++++++++++++++++++ paddle/cinn/common/broadcast_tree.h | 34 +++ paddle/cinn/common/broadcast_tree_test.cc | 116 +++++++++ 4 files changed, 450 insertions(+) create mode 100644 paddle/cinn/common/broadcast_tree.cc create mode 100644 paddle/cinn/common/broadcast_tree.h create mode 100644 paddle/cinn/common/broadcast_tree_test.cc diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt index ff024385d34795..3d0dc05f77b11d 100644 --- a/paddle/cinn/common/CMakeLists.txt +++ b/paddle/cinn/common/CMakeLists.txt @@ -24,6 +24,7 @@ gather_srcs( integer_set.cc dim_expr_simplify.cc dim_expr_converter.cc + broadcast_tree.cc dim_expr_util.cc) cinn_cc_test(test_equation_graph_topo_walker SRCS @@ -54,4 +55,5 @@ if(NOT CINN_ONLY) cinncore) cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS cinncore) + cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore) endif() diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc new file mode 100644 index 00000000000000..ddc78b738c7b78 --- /dev/null +++ b/paddle/cinn/common/broadcast_tree.cc @@ -0,0 +1,298 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/broadcast_tree.h" + +#include +#include + +#include "paddle/cinn/common/dim_expr_simplify.h" +#include "paddle/cinn/common/dim_expr_util.h" + +namespace cinn::common { + +namespace { + +template +bool SearchBroadcast(const symbol::DimExpr& dim_expr, const DoEachT& DoEach); + +template +bool SearchBroadcastImpl(int64_t, const DoEachT& DoEach) { + return false; +} + +template +bool SearchBroadcastImpl(const std::string&, const DoEachT& DoEach) { + return false; +} + +template +bool SearchBroadcastImplForUnary(const T& unary, const DoEachT& DoEach) { + const auto& operand = unary->data; + return SearchBroadcast(operand, DoEach); +} + +template +bool SearchBroadcastImpl(const symbol::Negative& unary, + const DoEachT& DoEach) { + return SearchBroadcastImplForUnary(unary, DoEach); +} + +template +bool SearchBroadcastImpl(const symbol::Reciprocal& unary, + const DoEachT& DoEach) { + return SearchBroadcastImplForUnary(unary, DoEach); +} + +template +bool SearchBroadcastImplForVariadic(const T& variadic, const DoEachT& DoEach) { + const auto& operands = *(variadic.operands); + for (const auto& operand : operands) { + if (SearchBroadcast(operand, DoEach)) return true; + } + return false; +} + +template +bool SearchBroadcastImpl(const symbol::Add& variadic, + const DoEachT& DoEach) { + return SearchBroadcastImplForVariadic(variadic, DoEach); +} + +template +bool SearchBroadcastImpl(const symbol::Mul& variadic, + const DoEachT& DoEach) { + return SearchBroadcastImplForVariadic(variadic, DoEach); +} + +template +bool SearchBroadcastImpl(const symbol::Max& variadic, + const DoEachT& DoEach) { + return SearchBroadcastImplForVariadic(variadic, DoEach); +} + +template +bool SearchBroadcastImpl(const symbol::Min& variadic, + const DoEachT& DoEach) { + return SearchBroadcastImplForVariadic(variadic, DoEach); +} + +template +bool SearchBroadcastImpl(const symbol::Broadcast& variadic, + const DoEachT& DoEach) { + const auto& operands = *(variadic.operands); + for (const auto& operand : operands) { + CHECK(!operand.isa()); + if (SearchBroadcast(operand, DoEach)) return true; + } + return DoEach(variadic); +} + +template +bool SearchBroadcast(const symbol::DimExpr& dim_expr, const DoEachT& DoEach) { + return std::visit( + [&](const auto& impl) { return SearchBroadcastImpl(impl, DoEach); }, + dim_expr.variant()); +} + +template +void ForEachBroadcastDimExpr(const BroadcastLeaf& leaves, + const DoEachT& DoEach) { + for (const auto& dim_exprs : *leaves) { + for (const auto& dim_expr : dim_exprs) { + if (SearchBroadcast(dim_expr, DoEach)) return; + } + } +} + +std::optional> GetFirstCstrBroadcastable( + const BroadcastLeaf& leaves) { + std::optional> ret; + ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { + const auto& operands = broadcast.operands; + std::optional lhs_symbol; + std::optional rhs_symbol; + size_t i = 0; + for (; i < operands->size(); ++i) { + if (operands->at(i).template isa()) { + lhs_symbol = operands->at(i); + break; + } + } + for (i++; i < operands->size(); ++i) { + if (operands->at(i).template isa()) { + rhs_symbol = operands->at(i); + break; + } + } + if (lhs_symbol.has_value() && rhs_symbol.has_value()) { + CHECK(lhs_symbol != rhs_symbol); + ret = symbol::Broadcastable{lhs_symbol.value(), + rhs_symbol.value()}; + return true; + } + return false; + }); + if (ret.has_value()) return ret.value(); + ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { + const auto& operands = broadcast.operands; + std::optional lhs_symbol; + std::optional rhs; + for (const auto& operand : *operands) { + if (operand.template isa()) { + lhs_symbol = operand; + break; + } + } + for (const auto& operand : *operands) { + if (operand != lhs_symbol) { + rhs = operand; + break; + } + } + if (lhs_symbol.has_value() && rhs.has_value()) { + ret = symbol::Broadcastable{lhs_symbol.value(), + rhs.value()}; + return true; + } + return false; + }); + if (ret.has_value()) return ret.value(); + ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { + const auto& operands = broadcast.operands; + CHECK_GE(operands->size(), 2); + CHECK(operands->at(0) != operands->at(1)); + ret = symbol::Broadcastable{operands->at(0), + operands->at(1)}; + return true; + }); + return ret; +} + +using Pattern2Placement = std::unordered_map; + +Pattern2Placement ConstructCstrLhsEqRhsReplacement( + const symbol::Broadcastable& broadcastable_condition) { + auto [lhs, rhs] = *broadcastable_condition; + if (lhs.isa()) return Pattern2Placement{{lhs, rhs}}; + if (rhs.isa()) return Pattern2Placement{{rhs, lhs}}; + return Pattern2Placement{{lhs, rhs}}; +} + +Pattern2Placement ConstructCstrLhsEqOneReplacement( + const symbol::Broadcastable& broadcastable_condition) { + const auto& [lhs, rhs] = *broadcastable_condition; + return Pattern2Placement{{lhs, symbol::DimExpr{1}}}; +} + +Pattern2Placement ConstructCstrRhsEqOneReplacement( + const symbol::Broadcastable& broadcastable_condition) { + const auto& [lhs, rhs] = *broadcastable_condition; + return Pattern2Placement{{rhs, symbol::DimExpr{1}}}; +} + +symbol::DimExpr GetCstrLhsEqRhsDimExpr( + const symbol::Broadcastable& broadcastable_condition, + const symbol::DimExpr& dim_expr) { + const auto& pattern2replacement = + ConstructCstrLhsEqRhsReplacement(broadcastable_condition); + return SimplifyDimExpr(SubstituteDimExpr(dim_expr, pattern2replacement)); +} + +symbol::DimExpr GetCstrLhsEqOneDimExpr( + const symbol::Broadcastable& broadcastable_condition, + const symbol::DimExpr& dim_expr) { + const auto& pattern2replacement = + ConstructCstrLhsEqOneReplacement(broadcastable_condition); + return SimplifyDimExpr(SubstituteDimExpr(dim_expr, pattern2replacement)); +} + +symbol::DimExpr GetCstrRhsEqOneDimExpr( + const symbol::Broadcastable& broadcastable_condition, + const symbol::DimExpr& dim_expr) { + const auto& pattern2replacement = + ConstructCstrRhsEqOneReplacement(broadcastable_condition); + return SimplifyDimExpr(SubstituteDimExpr(dim_expr, pattern2replacement)); +} + +typedef symbol::DimExpr (*ConvertDimExprT)( + const symbol::Broadcastable& broadcastable_condition, + const symbol::DimExpr& dim_expr); + +template +BroadcastLeaf ConvertBroadcastLeaf( + const symbol::Broadcastable& broadcastable_condition, + const BroadcastLeaf& leaves) { + BroadcastLeaf ret{}; + for (const auto& dim_exprs : *leaves) { + std::vector converted{}; + converted.reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + converted.push_back(ConvertDimExpr(broadcastable_condition, dim_expr)); + } + ret->emplace_back(std::move(converted)); + } + return ret; +} + +BroadcastLeaf GetCstrLhsEqRhsLeaves( + const symbol::Broadcastable& broadcastable_condition, + const BroadcastLeaf& leaves) { + return ConvertBroadcastLeaf<&GetCstrLhsEqRhsDimExpr>(broadcastable_condition, + leaves); +} + +BroadcastLeaf GetCstrLhsEqOneLeaves( + const symbol::Broadcastable& broadcastable_condition, + const BroadcastLeaf& leaves) { + return ConvertBroadcastLeaf<&GetCstrLhsEqOneDimExpr>(broadcastable_condition, + leaves); +} + +BroadcastLeaf GetCstrRhsEqOneLeaves( + const symbol::Broadcastable& broadcastable_condition, + const BroadcastLeaf& leaves) { + return ConvertBroadcastLeaf<&GetCstrRhsEqOneDimExpr>(broadcastable_condition, + leaves); +} + +BroadcastBranch ConstructBroadcastBranch( + const symbol::Broadcastable& broadcastable_condition, + const BroadcastLeaf& leaves) { + BroadcastLeaf cstr_lhs_eq_rhs_leaves = + GetCstrLhsEqRhsLeaves(broadcastable_condition, leaves); + BroadcastLeaf cstr_lhs_eq_one_leaves = + GetCstrLhsEqOneLeaves(broadcastable_condition, leaves); + BroadcastLeaf cstr_rhs_eq_one_leaves = + GetCstrRhsEqOneLeaves(broadcastable_condition, leaves); + // clang-format off + return BroadcastBranch{ + /*broadcastable_condition*/ broadcastable_condition, + /*cstr_lhs_eq_rhs_branch*/ ConstructBroadcastTree(cstr_lhs_eq_rhs_leaves), + /*cstr_lhs_eq_one_branch*/ ConstructBroadcastTree(cstr_lhs_eq_one_leaves), + /*cstr_rhs_eq_one_branch*/ ConstructBroadcastTree(cstr_rhs_eq_one_leaves) + }; + // clang-format on +} + +} // namespace + +BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves) { + std::optional> + broadcastable_condition = GetFirstCstrBroadcastable(leaves); + if (!broadcastable_condition.has_value()) return leaves; + return ConstructBroadcastBranch(broadcastable_condition.value(), leaves); +} + +} // namespace cinn::common diff --git a/paddle/cinn/common/broadcast_tree.h b/paddle/cinn/common/broadcast_tree.h new file mode 100644 index 00000000000000..dbabbe7d1d6a59 --- /dev/null +++ b/paddle/cinn/common/broadcast_tree.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/adt/tree.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + +namespace cinn::common { + +template +using BroadcastBranch = adt::Tuple, + /*cstr_lhs_eq_rhs_branch*/ T, + /*cstr_lhs_eq_one_branch*/ T, + /*cstr_rhs_eq_one_branch*/ T>; + +using BroadcastLeaf = adt::List>; + +using BroadcastTree = adt::Tree; + +BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves); + +} // namespace cinn::common diff --git a/paddle/cinn/common/broadcast_tree_test.cc b/paddle/cinn/common/broadcast_tree_test.cc new file mode 100644 index 00000000000000..8a09e8abd7dee1 --- /dev/null +++ b/paddle/cinn/common/broadcast_tree_test.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/common/broadcast_tree.h" + +#include "gtest/gtest.h" + +namespace cinn::common { +using namespace symbol; // NOLINT + +namespace { + +DimExpr MakeBroadcastDimExpr(const DimExpr& expr1, const DimExpr& expr2) { + List operands{expr1, expr2}; + return Broadcast{operands}; +} + +bool DimExprNonBroadcast(const DimExpr& dim_expr) { + if (dim_expr.Has>()) { + return false; + } else { + return true; + } +} + +void CheckLeafNonBroadcast(const BroadcastLeaf& leaf) { + for (const auto& operands : *leaf) { + for (const auto& operand : operands) { + ASSERT_TRUE(DimExprNonBroadcast(operand)); + } + } +} + +void CheckInnerBranchNonBroadcast( + const BroadcastBranch& branch) { + const auto& [_, lhs_eq_rhs_tree, lhs_eq_one_tree, rhs_eq_one_tree] = + branch.tuple(); + ASSERT_TRUE(lhs_eq_rhs_tree.Has()); + ASSERT_TRUE(lhs_eq_one_tree.Has()); + ASSERT_TRUE(rhs_eq_one_tree.Has()); + CheckLeafNonBroadcast(lhs_eq_rhs_tree.Get()); + CheckLeafNonBroadcast(lhs_eq_one_tree.Get()); + CheckLeafNonBroadcast(rhs_eq_one_tree.Get()); +} + +} // namespace + +TEST(BroadcastTree, Naive) { + DimExpr expr1("S1"); + DimExpr expr2("S2"); + DimExpr expr3("S3"); + DimExpr expr4("S4"); + std::vector tensor_shape{expr1, + expr2, + MakeBroadcastDimExpr(expr1, expr2), + MakeBroadcastDimExpr(expr3, expr4)}; + BroadcastLeaf leaf = adt::List>{tensor_shape}; + BroadcastTree tree = ConstructBroadcastTree(leaf); + ASSERT_TRUE(tree.Has>()); + const auto& branch = tree.Get>(); + const auto& [cstr_broadcastable, + lhs_eq_rhs_tree, + lhs_eq_one_tree, + rhs_eq_one_tree] = branch.tuple(); + ASSERT_EQ(cstr_broadcastable->lhs, DimExpr("S1")); + ASSERT_EQ(cstr_broadcastable->rhs, DimExpr("S2")); + ASSERT_TRUE(lhs_eq_rhs_tree.Has>()); + ASSERT_TRUE(lhs_eq_one_tree.Has>()); + ASSERT_TRUE(rhs_eq_one_tree.Has>()); + CheckInnerBranchNonBroadcast( + lhs_eq_rhs_tree.Get>()); + CheckInnerBranchNonBroadcast( + lhs_eq_one_tree.Get>()); + CheckInnerBranchNonBroadcast( + rhs_eq_one_tree.Get>()); +} + +TEST(BroadcastTree, SimplifyConstantBroadcast) { + DimExpr expr1("S1"); + DimExpr expr2("S2"); + DimExpr expr3("S3"); + DimExpr expr4(4); + std::vector tensor_shape{expr1, + expr2, + MakeBroadcastDimExpr(expr1, expr2), + MakeBroadcastDimExpr(expr3, expr4)}; + BroadcastLeaf leaf = adt::List>{tensor_shape}; + BroadcastTree tree = ConstructBroadcastTree(leaf); + ASSERT_TRUE(tree.Has>()); + const auto& branch = tree.Get>(); + const auto& [cstr_broadcastable, + lhs_eq_rhs_tree, + lhs_eq_one_tree, + rhs_eq_one_tree] = branch.tuple(); + ASSERT_EQ(cstr_broadcastable->lhs, DimExpr("S1")); + ASSERT_EQ(cstr_broadcastable->rhs, DimExpr("S2")); + ASSERT_TRUE(lhs_eq_rhs_tree.Has()); + ASSERT_TRUE(lhs_eq_one_tree.Has()); + ASSERT_TRUE(rhs_eq_one_tree.Has()); + CheckLeafNonBroadcast(lhs_eq_rhs_tree.Get()); + CheckLeafNonBroadcast(lhs_eq_one_tree.Get()); + CheckLeafNonBroadcast(rhs_eq_one_tree.Get()); +} + +} // namespace cinn::common From aacdc4d9ac1573fd65e6170afd0335d8e666fba2 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 4 Jan 2024 11:16:57 +0800 Subject: [PATCH 099/142] [Dynamic Shape] Erase expand (#60525) * EraseExpandOp * minor fix * minor fix * Code format --- .../group_merge/cinn_group_lowering_pass.cc | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc index f4aa34bbc72638..db2dd030ba7021 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.cc @@ -34,10 +34,85 @@ #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + PD_DECLARE_bool(cinn_enable_map_expr); namespace { +using ShapeOrDataDimExprs4ValueT = + std::function; + +pir::Block::ConstIterator FindFirstExpandOp(pir::Block* block) { + for (auto iter = block->begin(); iter != block->end(); ++iter) { + if (iter->isa()) { + return iter; + } + } +} + +bool SameInputOutputShape( + paddle::dialect::ExpandOp expand_op, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + const auto& x = ShapeOrDataDimExprs4Value(expand_op.x()); + const auto& shape = ShapeOrDataDimExprs4Value(expand_op.shape()); + const auto& out = ShapeOrDataDimExprs4Value(expand_op.out()); + if (x.data().has_value()) return false; + if (!shape.data().has_value()) return false; + if (out.data().has_value()) return false; + CHECK(shape.data().value() == out.shape()); + return x.shape() == out.shape(); +} + +void ReplaceAllUsesWithInput(paddle::dialect::ExpandOp expand) { + pir::Value x = expand.x(); + expand.out().ReplaceAllUsesWith(x); +} + +void EraseExpandOp(pir::Block* block, pir::Block::ConstIterator expand_it) { + block->erase(expand_it); +} + +void EraseUpstreamGenerateShapeOp( + pir::Block* block, cinn::dialect::GenerateShapeOp generate_shape_op) { + for (auto iter = block->begin(); iter != block->end(); ++iter) { + if (iter->isa()) { + if (iter->dyn_cast() == + generate_shape_op) { + block->erase(iter); + } + } + } +} + +// Returns true if success +bool EraseOneExpand( + pir::Block* block, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + for (auto expand_it = block->begin(); expand_it != block->end(); + ++expand_it) { + if (!expand_it->isa()) continue; + auto expand = expand_it->dyn_cast(); + if (!SameInputOutputShape(expand, ShapeOrDataDimExprs4Value)) continue; + auto generate_shape_op = + expand.shape().defining_op(); + CHECK_NOTNULL(generate_shape_op); + ReplaceAllUsesWithInput(expand); + EraseExpandOp(block, expand_it); + EraseUpstreamGenerateShapeOp(block, generate_shape_op); + return true; + } + return false; +} + +void EraseExpands(pir::Block* block, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { + while (EraseOneExpand(block, ShapeOrDataDimExprs4Value)) { + // Do nothing. + } +} + std::vector GetBlockOutsideInput( const std::vector op_list) { std::vector vec_res; From 193fea3d064e5e099d8b5fc283fdc7bdb20d61e9 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Thu, 4 Jan 2024 11:32:43 +0800 Subject: [PATCH 100/142] [inference] Support wint4 groupwise with cutlass gemm (#60422) * support gemv-groupwise func && weightQuanter-groupwise && weightDeQuanter-groupwise * fix build bug * add unit_test && fix bug * delete useless code * fix ci build bug * fix ci && optimize * fix merge conflict * add op change info * fix weight_only_linear_pass * fix format * solve ci unit_test * init * support cutlass gemm with groupwise * add unit test * fix strange bug * delete random bug * fix sm70 build bug * try to fix ci build bug * fix bug * fix volta build bug * skip sm70 in groupwise mode * change cutlass branch --- .../cutlass/cutlass_extensions/arch/mma.h | 93 ++- .../cutlass_extensions/ft_gemm_configs.h | 4 +- .../gemm/kernel/fpA_intB_gemm.h | 82 +- .../gemm/kernel/fpA_intB_gemm_split_k.h | 2 +- .../threadblock/default_dq_mma_multistage.h | 105 ++- .../dp_mma_multistage_finegrained.h | 741 ++++++++++++++++++ .../threadblock/dp_mma_multistage_percol.h | 684 ++++++++++++++++ .../gemm/threadblock/dq_mma_multistage.h | 610 +------------- .../gemm/threadblock/dq_mma_pipelined.h | 17 +- .../gemm/warp/mma_tensorop_dequantizer.h | 14 + .../fine_grained_scale_zero_iterator.h | 277 +++++++ .../cutlass_kernels/cutlass_heuristic.h | 67 +- .../fpA_intB_gemm/fpA_intB_gemm.h | 10 +- .../fpA_intB_gemm/fpA_intB_gemm_template.cu | 310 +++++--- .../fpA_intB_gemm/fpA_intB_gemm_template.h | 44 +- .../generic_mixed_gemm_kernelLauncher.py | 48 +- .../kernels/gpu/weight_only_linear_kernel.cu | 4 + test/quantization/test_weight_only_linear.py | 86 +- 18 files changed, 2429 insertions(+), 769 deletions(-) create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h create mode 100644 paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h index 151153e4b297da..aef10063f1dc42 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h @@ -1,18 +1,34 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * http://www.apache.org/licenses/LICENSE-2.0 + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. @@ -42,5 +58,58 @@ namespace arch { // Tag which triggers MMA which will trigger struct OpMultiplyAddDequantizeInterleavedBToA; +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount + of code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; + +// The default just forwards the original operator +template +struct TagOperator { + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it +// back to the original operator + the extra information. If no extra info was +// tagged, the dequant op per column scaling as a default. +template +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr bool FineGrained = false; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = false; +}; + +template <> +struct DetagOperator< + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale> { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = true; +}; + } // namespace arch } // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h index 68bf13bb25995d..972c9e1ffa6288 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h @@ -61,7 +61,9 @@ enum class CutlassTileConfig { // configs for large M in encoder CtaShape128x256x64_WarpShape64x64x64, - // CtaShape256x128x64_WarpShape64x64x64 + + // configs for finegrained + CtaShape256x128x64_WarpShape64x64x64, }; enum class SplitKStyle { diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index a14604728baf4c..839745161a3d87 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -56,8 +56,10 @@ template struct GemmFpAIntB { using Mma = Mma_; @@ -103,6 +105,7 @@ struct GemmFpAIntB { /// Parameters structure struct Arguments : UniversalArgumentsBase { cutlass::gemm::GemmCoord problem_size; + int group_size; typename Mma::IteratorA::TensorRef ref_A; typename Mma::IteratorB::TensorRef ref_B; typename Mma::IteratorScale::TensorRef ref_scale; @@ -125,6 +128,7 @@ struct GemmFpAIntB { CUTLASS_HOST_DEVICE Arguments(cutlass::gemm::GemmCoord const& problem_size, + int group_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Mma::IteratorScale::TensorRef ref_scale, @@ -143,6 +147,7 @@ struct GemmFpAIntB { problem_size, /*serial_split_k_factor=*/serial_split_k_factor, /*batch_stride_D=*/0), + group_size(group_size), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), @@ -181,6 +186,7 @@ struct GemmFpAIntB { int const* gather_A_indices; int const* gather_B_indices; int const* scatter_D_indices; + int group_size; // // Methods @@ -192,6 +198,7 @@ struct GemmFpAIntB { CUTLASS_HOST_DEVICE Params(Arguments const& args, int device_sms, int sm_occupancy) : ParamsBase(args, device_sms, sm_occupancy), + group_size(args.group_size), params_A(args.ref_A.layout()), ref_A(args.ref_A), params_B(args.ref_B.layout()), @@ -276,6 +283,52 @@ struct GemmFpAIntB { return Status::kSuccess; } + // Initializes the fine grained scale+bias iterator. Needed since the fine + // grained iterator has a different constructor signature than a regular + // cutlass iterator + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size); + }; + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, + pointer_scale, + extent, + thread_id, + threadblock_offset, + group_size); + } + }; + + template + struct initialize_scale { + CUTLASS_DEVICE static IteratorScale apply( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale( + params, pointer_scale, extent, thread_id, threadblock_offset); + } + }; static size_t get_extra_workspace_size( Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { return 0; @@ -335,8 +388,12 @@ struct GemmFpAIntB { threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + typename MatrixCoord::Index fg_row_offset = + threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = + Finegrained == true ? fg_row_offset : 0; cutlass::MatrixCoord tb_offset_scale{ - 0, threadblock_tile_offset.n() * Mma::Shape::kN}; + scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; // Problem size is a function of threadblock index in the K dimension int problem_size_k = @@ -368,11 +425,16 @@ struct GemmFpAIntB { tb_offset_B, params.gather_B_indices); - typename Mma::IteratorScale iterator_scale(params.params_scale, - params.ref_scale.data(), - {1, params.problem_size.n()}, - thread_idx, - tb_offset_scale); + typename MatrixCoord::Index scale_row_extent = + Finegrained == true ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = + initialize_scale::apply( + params.params_scale, + params.ref_scale.data(), + {scale_row_extent, params.problem_size.n()}, + thread_idx, + tb_offset_scale, + params.group_size); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -383,7 +445,11 @@ struct GemmFpAIntB { // Main loop // // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + Mma mma(shared_storage.main_loop, + params.group_size, + thread_idx, + warp_idx, + lane_idx); typename Mma::FragmentC accumulators; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h index a4fda93533a1f6..b4f10395798612 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h @@ -847,7 +847,7 @@ struct GemmFpAIntBSplitK { // static_assert(print_type()); // Perform this tile's range of multiply-accumulate (MAC) iterations - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); mma(tile_work.k_iters_remaining, accumulator_tile, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index da4b8d73376f65..e27a9e8ee9f84c 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -37,6 +37,7 @@ limitations under the License. */ #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h" @@ -46,6 +47,54 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// +template +struct DefaultScaleIterators; + +// Fine grained iterators +template +struct DefaultScaleIterators { + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIterators { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaShape::kN / Alignment, + Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + template < /// Type for elementA typename ElementA, @@ -80,7 +129,7 @@ template < /// Stages in GEMM int kStages, /// - typename Operator, + typename Operator_, /// SharedMemoryClearOption SharedMemoryClear> struct DqMma= 80)>::type> { + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); @@ -171,22 +223,15 @@ struct DqMma; - // ThreadMap for scale iterator static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; + using ScaleIterators = DefaultScaleIterators; // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape<1, MmaCore::Shape::kN>, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; - + using IteratorScale = typename ScaleIterators::IteratorScale; using SmemIteratorScale = IteratorScale; using Converter = FastInterleavedAndBiasedNumericArrayConverter< @@ -210,7 +255,8 @@ struct DqMma; + SharedMemoryClear, + OperatorInfo::FineGrained>; }; template < @@ -245,7 +291,7 @@ template < /// Stages in GEMM int kStages, /// - typename Operator, + typename Operator_, /// SharedMemoryClearOption SharedMemoryClear, /// @@ -269,10 +315,13 @@ struct DqMma= 80)>::type> { + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); @@ -364,19 +413,14 @@ struct DqMma, - MmaCore::Shape::kN / kAlignmentScale, - kAlignmentScale>; + using ScaleIterators = DefaultScaleIterators; // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape<1, MmaCore::Shape::kN>, - ElementScale, - LayoutScale, - 0, - IteratorScaleThreadMap, - kAlignmentScale>; + using IteratorScale = typename ScaleIterators::IteratorScale; using SmemIteratorScale = IteratorScale; @@ -401,7 +445,8 @@ struct DqMma; + SharedMemoryClear, + OperatorInfo::FineGrained>; }; } // namespace threadblock diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h new file mode 100644 index 00000000000000..5b6e8249aa80ce --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h @@ -0,0 +1,741 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage : public DqMmaBase { + public: + ///< Base class + using Base = + DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to + /// shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, // NOLINT + /// The group size for quantization + int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {Base::kStages, Shape::kN}, + thread_idx, + group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, // NOLINT + int stage = -1, + int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = + iterator_scale.get_scale(); + // typename IteratorScale::AccessType* gmem_zero_ptr = + // iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = + reinterpret_cast( + this->smem_iterator_scale_.get_scale()); + // typename IteratorScale::AccessType* smem_zero_ptr + // = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * + IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async( + smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + // if (gmem_zero_ptr != nullptr) + // { + // cutlass::arch::cp_async(smem_zero_ptr, + // gmem_zero_ptr, iterator_scale.valid()); + // } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT + IteratorB& iterator_B, // NOLINT + IteratorScale& iterator_scale, // NOLINT + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, // NOLINT + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + // typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + iterator_scale, + group_start_iteration_A, + group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for + // the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + iterator_scale, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h new file mode 100644 index 00000000000000..7307131f8dfd33 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h @@ -0,0 +1,684 @@ + +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage + : public DqMmaBase { + public: + ///< Base class + using Base = + DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, // NOLINT + ///< Group size for quantization. Not used by this main loop since it + ///< assumes per-column + int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT + IteratorB& iterator_B, // NOLINT + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, // NOLINT + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h index bf95ed2fc3540c..b6911f05a45000 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -1,33 +1,34 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * http://www.apache.org/licenses/LICENSE-2.0 + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /*! \file \brief Template for a double-buffered threadblock-scoped GEMM kernel. */ @@ -94,559 +95,12 @@ template < /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization - typename Enable = bool> -class DqMmaMultistage : public DqMmaBase { - public: - ///< Base class - using Base = - DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - // - // Dependent types - // - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = - IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / - Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / - Base::kWarpGemmIterations; - }; - - private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave< - typename LayoutDetailsForB::Layout>::value; - static_assert(!RequiresTileInterleave || - (RequiresTileInterleave && - (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - - private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared - /// memory - SmemIteratorScale smem_iterator_scale_; - - public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, // NOLINT - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx), - warp_dequantizer_( - {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / - Base::WarpCount::kM, - lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_scale_(LayoutScale(Shape::kN), - shared_storage.operand_scale.data(), - {1, Shape::kN}, - thread_idx) { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset( - {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset( - {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT - IteratorB& iterator_B, // NOLINT - int group_start_A = 0, - int group_start_B = 0) { - iterator_A.set_iteration_index(group_start_A * - IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::cp_async( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, // NOLINT - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) { - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group - // as - // the first load of A. - FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast( - this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for - // some kernels so that all accumulator elements outside the GEMM footprint - // are zero. - // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - /// Iterator to write threadblock-scoped tile of A operand to shared - /// memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast( - last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared - /// memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast( - last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % - Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - const int warp_tileB_k_compute_offset = - warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = - warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == - Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load( - warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma(warp_mma, - accum, - warp_frag_A[warp_mma_k % 2], - converted_frag_B, - accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, - iterator_B, - group_start_iteration_A, - group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, - iterator_B, - group_start_iteration_A, - group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, - -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterationsForB, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM - // mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// + bool FineGrained = false> +class DqMmaMultistage; } // namespace threadblock } // namespace gemm } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_finegrained.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dp_mma_multistage_percol.h" diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h index 65fe9693727ea0..9071e1affad16b 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -195,11 +195,18 @@ class DqMmaPipelined : public DqMmaBase=80. + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp ) : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_( diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index 1426182b1363c3..e02e79316c460f 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -194,6 +194,13 @@ class MmaTensorOpDequantizer< } } + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; + } + private: ElementScale const* pointer_; }; @@ -297,6 +304,13 @@ class MmaTensorOpDequantizer< } } + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; + } + private: ElementScale const* pointer_; }; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 00000000000000..24c95134cfe293 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,277 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Templates for visiting scales to be used when dequantizing the + weights for weight-only GEMM quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +// #define _DEBUG_CUTLASS_FINE_GRAINED_SCALE_ZERO_ITERATOR + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : stride_(layout.stride(0)) { // NOLINT + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + // BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + // Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), + pointer_scale_(reinterpret_cast( + const_cast(pointer_scale))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / + (group_size / 64) * params_.stride_ * + sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = + threadblock_offset.column() * sizeof_bits::value / 8; + + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + // TODO(freeliuzc): support ZERO + // if (pointer_zero_ != nullptr) + // { + // pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + // } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + const int thread_row = thread_id / THREADS_PER_ROW; + const int thread_col = thread_id % THREADS_PER_ROW; + const LongIndex thread_row_byte_offset = + thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = + thread_col * kAlignment * sizeof_bits::value / 8; + + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + // TODO(freeliuzc): support ZERO + // if (pointer_zero_ != nullptr) + // { + // pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + // } + + // For the rows, we must check that we are within the extent AND the tile to + // avoid extra reads on a given iteration. The same threads will be + // responsible for issues reads since the number of scales read in a given + // iteration is a constant. Therefore, we should never have to update + // is_valid_ outside of the constructor. + const int global_row = threadblock_offset.row() + thread_row; + const int global_col = + threadblock_offset.column() + thread_col * kAlignment; + + const bool row_in_bounds = + global_row < extent.row() && thread_row < Shape::kRow; + const bool col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator( + Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + // Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator(params, + pointer_scale, + extent, + thread_id, + make_Coord(0, 0), + group_size) {} + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = + tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + + // TODO(freeliuzc): support ZERO + // if (pointer_zero_ != nullptr) + // { + // pointer_zero_ += row_byte_offset + col_byte_offset; + // } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return is_valid_; } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + // TODO(freeliuzc): support ZERO + // Returns a zero pointer + // CUTLASS_HOST_DEVICE + // AccessType* get_zero() const + // { + // return reinterpret_cast(pointer_zero_); + // } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h index 38048a08f9c0d0..ff878f896a74de 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -56,6 +56,8 @@ static TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { // {256, 128} have better performance than 128, 128 case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; default: throw std::runtime_error( "[fpA_intB_gemm Error][get_grid_shape_for_config] Invalid config"); @@ -106,7 +108,8 @@ static std::vector get_candidate_tiles( const bool is_weight_only, const bool is_weight_only_encoder, const bool simt_configs_only, - const int sm) { + const int sm, + const int group_size) { VLOG(3) << "get_candidate_tiles sm: " << sm; std::vector simt_configs{ CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; @@ -124,13 +127,23 @@ static std::vector get_candidate_tiles( CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64, CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64}; + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + }; + std::vector quant_B_configs_sm80_finegrained{ + CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + }; std::vector quant_B_configs; switch (sm) { case 86: - case 80: - quant_B_configs = quant_B_configs_sm80; + case 80: { + quant_B_configs = group_size > 0 ? quant_B_configs_sm80_finegrained + : quant_B_configs_sm80; break; + } case 75: case 70: quant_B_configs = quant_B_configs_sm70; @@ -147,12 +160,17 @@ static std::vector get_candidate_tiles( } static std::vector get_candidate_configs( - int sm, + const int sm, + const int group_size, const bool is_weight_only, const bool is_weight_only_encoder, const bool simt_configs_only) { - std::vector tiles = get_candidate_tiles( - is_weight_only, is_weight_only_encoder, simt_configs_only, sm); + std::vector tiles = + get_candidate_tiles(is_weight_only, + is_weight_only_encoder, + simt_configs_only, + sm, + group_size); std::vector candidate_configs; const int min_stages = 2; @@ -174,11 +192,13 @@ static CutlassGemmConfig estimate_best_config_from_occupancies( const int64_t m, const int64_t n, const int64_t k, + const int group_size, const int64_t num_experts, const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, - const int is_weight_only) { + const int is_weight_only, + const int sm) { if (occupancies.size() != candidate_configs.size()) { throw std::runtime_error( "[fpA_intB_gemm Error][estimate_best_config_from_occupancies] " @@ -187,14 +207,41 @@ static CutlassGemmConfig estimate_best_config_from_occupancies( } CutlassGemmConfig best_config; - if (m >= 256 && + + if (m >= 256 && sm == 86 && group_size > 0 && std::find_if( candidate_configs.begin(), candidate_configs.end(), [](const CutlassGemmConfig& gemm_config) { return gemm_config.tile_config == - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64; + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64; }) != candidate_configs.end()) { + best_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + SplitKStyle::NO_SPLIT_K, + 1, + 2}; + } else if (m >= 256 && sm == 80 && group_size > 0 && + std::find_if(candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig:: + CtaShape256x128x64_WarpShape64x64x64; + }) != candidate_configs.end()) { + best_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + SplitKStyle::NO_SPLIT_K, + 1, + 4}; + } else if (m >= 256 && sm == 80 && group_size <= 0 && + std::find_if(candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig:: + CtaShape128x256x64_WarpShape64x64x64; + }) != candidate_configs.end()) { best_config = CutlassGemmConfig{ CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, SplitKStyle::NO_SPLIT_K, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h index 15c5267ae0f9dc..0fef3771f2f05a 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -63,6 +63,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream); @@ -75,6 +76,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, @@ -84,7 +86,7 @@ class CutlassFpAIntBGemmRunner { int getWorkspaceSize(const int m, const int n, const int k); private: - template + template void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, @@ -93,13 +95,14 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); - template + template void run_gemm(const T* A, const WeightType* B, const T* weight_scales, @@ -108,6 +111,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream); @@ -136,6 +140,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream); @@ -148,6 +153,7 @@ class CutlassFpAIntBGemmRunner { int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu index 2f566d4dbc35e1..dce644bd7ae1d6 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu @@ -42,6 +42,7 @@ template void dispatch_gemm_config(const T* A, @@ -52,6 +53,7 @@ void dispatch_gemm_config(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -63,6 +65,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 2>; @@ -74,6 +77,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -85,6 +89,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 3>; @@ -96,6 +101,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -107,6 +113,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 4>; @@ -118,6 +125,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -129,6 +137,7 @@ void dispatch_gemm_config(const T* A, WeightType, arch, EpilogueTag, + FineGrained, ThreadblockShape, WarpShape, 5>; @@ -140,6 +149,7 @@ void dispatch_gemm_config(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -154,7 +164,11 @@ void dispatch_gemm_config(const T* A, } } -template +template void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, @@ -163,6 +177,7 @@ void dispatch_gemm_to_cutlass(const T* A, int m, int n, int k, + int group_size, char* workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, @@ -179,6 +194,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::GemmShape<16, 32, 64>>( A, @@ -189,6 +205,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -201,6 +218,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>( A, @@ -211,6 +229,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -222,6 +241,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -232,6 +252,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -244,6 +265,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -254,6 +276,7 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -266,6 +289,7 @@ void dispatch_gemm_to_cutlass(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -276,6 +300,30 @@ void dispatch_gemm_to_cutlass(const T* A, m, n, k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, gemm_config, workspace, workspace_bytes, @@ -300,7 +348,11 @@ void dispatch_gemm_to_cutlass(const T* A, } } -template +template void dispatch_gemm_to_cutlass_sm7x(const T* A, const WeightType* B, const T* weight_scales, @@ -309,6 +361,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, int m, int n, int k, + int group_size, char* workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, @@ -324,6 +377,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>( A, @@ -334,6 +388,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -345,6 +400,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, WeightType, arch, EpilogueTag, + FineGrained, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>>( A, @@ -355,6 +411,7 @@ void dispatch_gemm_to_cutlass_sm7x(const T* A, m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -394,8 +451,9 @@ CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { } template -template -void CutlassFpAIntBGemmRunner::dispatch_to_arch( +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch( const T* A, const WeightType* B, const T* weight_scales, @@ -404,6 +462,7 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, @@ -415,19 +474,21 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( dispatch_gemm_to_cutlass_sm7x(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); + EpilogueTag, + false>(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); #else throw std::runtime_error( "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " @@ -438,19 +499,21 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( dispatch_gemm_to_cutlass_sm7x(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); + EpilogueTag, + false>(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); #else throw std::runtime_error( "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " @@ -458,20 +521,24 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( #endif } else if (sm_ >= 80 && sm_ < 90) { #if defined(USE_FPAINTB_GEMM_WITH_SM80) - dispatch_gemm_to_cutlass( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); + dispatch_gemm_to_cutlass(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); #else throw std::runtime_error( "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " @@ -485,8 +552,9 @@ void CutlassFpAIntBGemmRunner::dispatch_to_arch( } template -template -void CutlassFpAIntBGemmRunner::run_gemm( +template +void CutlassFpAIntBGemmRunner::run_gemm( const T* A, const WeightType* B, const T* weight_scales, @@ -495,30 +563,32 @@ void CutlassFpAIntBGemmRunner::run_gemm( int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { // VLOG(3)<<__PRETTY_FUNCTION__; static constexpr bool is_weight_only = !std::is_same::value; const bool is_weight_only_encoder = m >= 512 ? true : false; - std::vector candidate_configs = - get_candidate_configs(sm_, is_weight_only, is_weight_only_encoder, false); + std::vector candidate_configs = get_candidate_configs( + sm_, group_size, is_weight_only, is_weight_only_encoder, false); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - candidate_configs[ii], - workspace_ptr, - workspace_bytes, - stream, - &occupancies[ii]); + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + candidate_configs[ii], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[ii]); } // Standard GEMM, so 1 "expert". We use the same function for MoE and regular // FFN. @@ -529,24 +599,27 @@ void CutlassFpAIntBGemmRunner::run_gemm( m, n, k, + group_size, num_experts, split_k_limit, workspace_bytes, multi_processor_count_, - is_weight_only); + is_weight_only, + sm_); - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - chosen_config, - workspace_ptr, - workspace_bytes, - stream); + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + chosen_config, + workspace_ptr, + workspace_bytes, + stream); } template @@ -559,6 +632,7 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, @@ -570,17 +644,37 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( PADDLE_THROW(phi::errors::Unimplemented( "Activation_type = relu for fpA_intB gemm is not instantiated.")); } else if (activation_type == "none") { - run_gemm(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); + if (group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } else { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } } else { throw std::runtime_error(("Invalid activation type.")); } @@ -594,21 +688,41 @@ void CutlassFpAIntBGemmRunner::gemm(const T* A, int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { - // VLOG(3)<<__PRETTY_FUNCTION__; - run_gemm(A, - B, - weight_scales, - nullptr, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); + if (group_size > 0) { + PADDLE_ENFORCE_GE(sm_, + 80, + phi::errors::Unimplemented( + "Groupwise mode is not supported on SM < 8.0")); + run_gemm(A, + B, + weight_scales, + nullptr, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } else { + run_gemm(A, + B, + weight_scales, + nullptr, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + stream); + } } template @@ -636,6 +750,7 @@ void CutlassFpAIntBGemmRunner::gemm_bias_act( int m, int n, int k, + int group_size, std::string activation_type, char* workspace_ptr, const size_t workspace_bytes, @@ -654,6 +769,7 @@ void CutlassFpAIntBGemmRunner::gemm( int m, int n, int k, + int group_size, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 8ae1047c43afc8..f7c73dc99cede8 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -54,6 +54,7 @@ template @@ -65,6 +66,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -117,7 +119,13 @@ void generic_mixed_gemm_kernelLauncher(const T* A, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op; - if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K) { + + if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K || + FineGrained == true) { + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = + typename cutlass::arch::TagOperator::TaggedOperator; using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< ElementType, cutlass::layout::RowMajor, @@ -137,14 +145,15 @@ void generic_mixed_gemm_kernelLauncher(const T* A, typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, Stages, true, - typename MixedGemmArchTraits::Operator>::GemmKernel; + TaggedOperator>::GemmKernel; using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< typename GemmKernel_::Mma, typename GemmKernel_::Epilogue, typename GemmKernel_::ThreadblockSwizzle, arch, // Ensure top level arch is used for dispatch - GemmKernel_::kSplitKSerial>; + GemmKernel_::kSplitKSerial, + FineGrained>; if (occupancy != nullptr) { *occupancy = compute_occupancy_for_kernel(); @@ -161,9 +170,10 @@ void generic_mixed_gemm_kernelLauncher(const T* A, typename Gemm::Arguments args( {m, n, k}, + group_size, {reinterpret_cast(const_cast(A)), k}, {reinterpret_cast(const_cast(B)), ldb}, - {reinterpret_cast(const_cast(weight_scales)), 0}, + {reinterpret_cast(const_cast(weight_scales)), n}, {reinterpret_cast(const_cast(biases)), 0}, {reinterpret_cast(C), n}, gemm_config.split_k_factor, @@ -221,7 +231,8 @@ void generic_mixed_gemm_kernelLauncher(const T* A, std::string(cutlassGetStatusString(run_status)); throw std::runtime_error("[fpA_intB Runner] " + err_msg); } - } else { + + } else /* Per-Channel mode */ { // for stream-k, we set gemm_config.split_k_factor = 1 to use default load // balance. gemm_config.split_k_factor = 1; @@ -334,6 +345,7 @@ template @@ -345,6 +357,7 @@ void generic_mixed_gemm_kernelLauncher_template(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -355,6 +368,7 @@ template struct dispatch_stages { @@ -401,6 +418,7 @@ struct dispatch_stages(A, @@ -422,6 +441,7 @@ struct dispatch_stages @@ -441,6 +462,7 @@ struct dispatch_stages(A, @@ -472,6 +496,7 @@ struct dispatch_stages void dispatch_gemm_config(const T* A, @@ -495,13 +521,18 @@ void dispatch_gemm_config(const T* A, int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy); -template +template void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, @@ -510,6 +541,7 @@ void dispatch_gemm_to_cutlass(const T* A, int m, int n, int k, + int group_size, char* workspace, size_t workspace_bytes, CutlassGemmConfig gemm_config, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py index ad7f1e65591ce9..5847956020cebe 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py @@ -33,6 +33,7 @@ {WeightType}, {arch}, {EpilogueTag}, + {FineGrained}, {ThreadblockShape}, {WarpShape}, {Stages}>( @@ -44,6 +45,7 @@ int m, int n, int k, + int group_size, CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, @@ -53,6 +55,7 @@ {WeightType}, {arch}, {EpilogueTag}, + {FineGrained}, {ThreadblockShape}, {WarpShape}, {Stages}>( @@ -64,6 +67,7 @@ m, n, k, + group_size, gemm_config, workspace, workspace_bytes, @@ -87,6 +91,7 @@ "cutlass::gemm::GemmShape<64, 128, 64>", "cutlass::gemm::GemmShape<128, 128, 64>", "cutlass::gemm::GemmShape<128, 256, 64>", + "cutlass::gemm::GemmShape<256, 128, 64>", ] WarpShapes = [ "cutlass::gemm::GemmShape<16, 32, 64>", @@ -94,6 +99,7 @@ "cutlass::gemm::GemmShape<64, 64, 64>", "cutlass::gemm::GemmShape<64, 64, 64>", "cutlass::gemm::GemmShape<64, 64, 64>", + "cutlass::gemm::GemmShape<64, 64, 64>", ] ThreadblockShapes_sm70 = [ @@ -119,6 +125,9 @@ # "biasReLU": "EpilogueOpBiasReLU", } +FineGrainedTypes = ["true", "false"] +FineGrainedTypes_sm70 = ["false"] + def SubstituteTemplate(template, values): text = template @@ -174,28 +183,36 @@ def parse_args(): # generate source cu def generate_source_cu( - element_type: str, arch: int, epilogue_tag: str, stages: int + element_type: str, + arch: int, + epilogue_tag: str, + stages: int, ): all_code = CommonHead ThreadblockShapes_arch = ThreadblockShapes WarpShapes_arch = WarpShapes + FineGrainedTypes_arch = FineGrainedTypes + if arch < 80: ThreadblockShapes_arch = ThreadblockShapes_sm70 WarpShapes_arch = WarpShapes_sm70 + FineGrainedTypes_arch = FineGrainedTypes_sm70 for WeightType in WeightTypes: for i in range(len(ThreadblockShapes_arch)): - value_dict = { - "T": ElementTypes[element_type], - "WeightType": WeightType, - "arch": Archs[arch], - "EpilogueTag": EpilogueTags[epilogue_tag], - "ThreadblockShape": ThreadblockShapes_arch[i], - "WarpShape": WarpShapes_arch[i], - "Stages": str(stages), - } - all_code += SubstituteTemplate( - DispatchGemmConfigInstanceDeclare, value_dict - ) + for j in range(len(FineGrainedTypes_arch)): + value_dict = { + "T": ElementTypes[element_type], + "WeightType": WeightType, + "arch": Archs[arch], + "EpilogueTag": EpilogueTags[epilogue_tag], + "FineGrained": FineGrainedTypes_arch[j], + "ThreadblockShape": ThreadblockShapes_arch[i], + "WarpShape": WarpShapes_arch[i], + "Stages": str(stages), + } + all_code += SubstituteTemplate( + DispatchGemmConfigInstanceDeclare, value_dict + ) all_code += CommonTail return all_code @@ -221,7 +238,10 @@ def generate_source_cu( element_type, arch, stages, epilogue_tag ) all_code = generate_source_cu( - element_type, arch, epilogue_tag, stages + element_type, + arch, + epilogue_tag, + stages, ) with open(file_name, "w") as f: f.write(all_code) diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu index c41b86148291de..901a291d3924db 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -89,6 +89,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, "none", mixgemm_workspace_data, mixgemm_workspace_size_bytes, @@ -104,6 +105,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, mixgemm_workspace_data, mixgemm_workspace_size_bytes, dev_ctx.stream()); @@ -134,6 +136,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, "none", mixgemm_workspace_data, mixgemm_workspace_size_bytes, @@ -149,6 +152,7 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. m, n, k, + group_size, mixgemm_workspace_data, mixgemm_workspace_size_bytes, dev_ctx.stream()); diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py index f3749d0b4fb15c..f09698e4a1a68b 100644 --- a/test/quantization/test_weight_only_linear.py +++ b/test/quantization/test_weight_only_linear.py @@ -109,7 +109,7 @@ def weightQuantizeCPUGPUConsistenceCheck(self, weight_float): def setUp(self): self.config() if self.dtype == "bfloat16" or self.weight_dtype == "int4": - self.atol = 1.5e-1 + self.atol = 1.3e-1 x = np.random.random((self.batch, self.token, self.in_features)) self.x = paddle.to_tensor(x, dtype=self.dtype) if self.bias: @@ -451,8 +451,10 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_cuda_version() < 11020, - "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul groupwise mode need CUDA >= 11.2 and CUDA_ARCH >= 8", ) class WeightOnlyLinearTestCase17(WeightOnlyLinearTestCase): def config(self): @@ -466,8 +468,10 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_cuda_version() < 11020, - "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul groupwise mode need CUDA >= 11.2 and CUDA_ARCH >= 8", ) class WeightOnlyLinearTestCase18(WeightOnlyLinearTestCase): def config(self): @@ -576,6 +580,78 @@ def config(self): self.out_features = 288 +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase25(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase26(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.group_size = 64 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase27(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase28(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.token = 300 + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase29(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.token = 300 + self.group_size = 128 + + @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11020, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", From 7b616c498fa0f4f43eba9eb96e29f8375181c0d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Thu, 4 Jan 2024 12:57:18 +0800 Subject: [PATCH 101/142] simplify extent of loop after fuse and add corresponding test case (#60538) --- .../ir/schedule/impl/loop_transformation.cc | 2 +- test/cinn/ir/test_llir_schedule_fuse_split.py | 55 ++++++++++++++++++- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index 1d66697f43d136..daa453d85a7442 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -264,7 +264,7 @@ Expr DyScheduleImpl::Fuse(const std::vector& loops) { for (int i = 0; i < loops_number; ++i) { fused_extent = fused_extent * for_nodes[i]->extent; } - + fused_extent = cinn::common::AutoSimplify(fused_extent); if (!fused_body.As()) fused_body = Block::Make({fused_body}); Expr new_stmt = For::Make(fused_var, Expr(0), diff --git a/test/cinn/ir/test_llir_schedule_fuse_split.py b/test/cinn/ir/test_llir_schedule_fuse_split.py index 362cb81f87b964..b4722a1a02434e 100644 --- a/test/cinn/ir/test_llir_schedule_fuse_split.py +++ b/test/cinn/ir/test_llir_schedule_fuse_split.py @@ -37,7 +37,7 @@ def elementwise_fuse_assign_loop( def elementwise_fuse_assign_loop_gt( X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128)) ): - for i in range(((1 * 128) * 128) * 128): + for i in range(2097152): with ir.ScheduleBlockContext("Y") as block_y: i1_1, j1_1, k1_1 = ir.AxisMap( "SSS", [(i / 128) / 128, (i / 128) % 128, i % 128] @@ -148,7 +148,7 @@ def elementwise_fuse_assign_loop( Y: DataArray((-1, 128, 128)), N: ir.Var(), ): - for i_j_k_fused in range(((1 * N) * 128) * 128): + for i_j_k_fused in range(16384 * N): with ir.ScheduleBlockContext("Y") as block_y: i1, j1, k1 = ir.AxisMap( "SSS", @@ -207,9 +207,60 @@ def elementwise_split( assert_llir_equal(origin.elementwise_split, expected.elementwise_split) +def test_fuse_split(): + @to_cinn_llir + def elementwise_fuse_split_origin( + X: DataArray((64, 128, 128)), Y: DataArray((64, 128, 128)) + ): + for i in range(64): + for j in range(128): + for k in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap("SSS", [i, j, k]) + fused = sch.fuse([i, j]) + sch.split(fused, factors=[2, 512, -1]) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + @to_cinn_llir + def elementwise_fuse_split_expected( + X: DataArray((64, 128, 128)), Y: DataArray((64, 128, 128)) + ): + for i_j_fused in range(2): + for i_j_fused_0 in range(512): + for i_j_fused_1 in range(8): + for k in range(128): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1, k1 = ir.AxisMap( + "SSS", + [ + ( + ( + (4096 * i_j_fused) + + ((8 * i_j_fused_0) + i_j_fused_1) + ) + / 128 + ), + ( + ( + (4096 * i_j_fused) + + ((8 * i_j_fused_0) + i_j_fused_1) + ) + % 128 + ), + k, + ], + ) + Y[i1, j1, k1] = X[i1, j1, k1] * 2.0 + + assert_llir_equal( + elementwise_fuse_split_origin, elementwise_fuse_split_expected + ) + + if __name__ == "__main__": test_fuse() test_split() + test_fuse_split() test_split_predicate() test_fuse_dynamic() test_split_dynamic() From f46072341f18bacdc260ffc96dafad0056b7eb5a Mon Sep 17 00:00:00 2001 From: YibLiu <68105073+YibinLiu666@users.noreply.github.com> Date: Thu, 4 Jan 2024 14:29:55 +0800 Subject: [PATCH 102/142] fix bug of put_along_axis (#60551) --- .../kernels/funcs/gather_scatter_functor.cu | 316 ++++++++++-------- test/legacy_test/test_put_along_axis_op.py | 39 +++ 2 files changed, 221 insertions(+), 134 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index d293b3f7a0efac..7939589d7c6628 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -108,10 +108,10 @@ __global__ void ScatterAssignGPUKernel(tensor_t* self_data, int64_t outer_dim_size_src, int64_t numel, int64_t numel_data, - const func_t& reduce_op) { + const func_t& reduce_op, + int* thread_ids) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int thread_ids[]; if (tid == 0) { for (int i = 0; i < numel_data; i++) { @@ -182,10 +182,10 @@ __global__ void GatherScatterGPUKernel(tensor_t* self_data, int64_t numel, int64_t numel_data, bool include_self, - const func_t& reduce_op) { + const func_t& reduce_op, + int* shared_mem) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int shared_mem[]; if (include_self == false) { if (tid == 0) { for (int i = 0; i < numel_data; i++) { @@ -262,10 +262,10 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data, int64_t numel, int64_t numel_data, bool include_self, - const func_t& reduce_op) { + const func_t& reduce_op, + int* shared_mem) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int shared_mem[]; if (tid == 0) { for (int i = 0; i < numel_data; i++) { @@ -381,40 +381,55 @@ struct gpu_gather_scatter_functor { auto stream = reinterpret_cast(ctx).stream(); if (method_name == "scatter_assign_gpu") { int shared_mem_size = sizeof(int) * self_size; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterAssignGPUKernel - <<>>(self_data, - dim, - index_data, - src_data, - select_dim_size, - self_select_dim_size, - src_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_src, - index_size, - self_size, - reduce_op); + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + reduce_op, + shared_mem); + cudaFreeAsync(shared_mem, stream); } else if (method_name == "scatter_mean_gpu") { int shared_mem_size = sizeof(int) * self_size * 2; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterMeanGPUKernel - <<>>(self_data, - dim, - index_data, - src_data, - select_dim_size, - self_select_dim_size, - src_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_src, - index_size, - self_size, - include_self, - reduce_op); + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + include_self, + reduce_op, + shared_mem); + cudaFreeAsync(shared_mem, stream); } else { int shared_mem_size = 0; - if (include_self == false) shared_mem_size = sizeof(int) * self_size; + int* shared_mem = nullptr; + if (include_self == false) { + shared_mem_size = sizeof(int) * self_size; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); + } GatherScatterGPUKernel <<>>(self_data, dim, @@ -429,7 +444,11 @@ struct gpu_gather_scatter_functor { index_size, self_size, include_self, - reduce_op); + reduce_op, + shared_mem); + if (include_self == false) { + cudaFreeAsync(shared_mem, stream); + } } } }; // struct gpu_gather_scatter_functor @@ -594,17 +613,16 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, int64_t n = inner_dim_size * select_dim_size * outer_dim_size; int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); - int shared_mem_size = sizeof(int) * grad_size; ScatterInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_data, - index_size, - grad_size); + <<>>(grad_data, + dim, + index_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_data, + index_size, + grad_size); } template @@ -618,10 +636,10 @@ __global__ void ScatterMulInputGradGPUKernel(tensor_t* grad_data, int64_t outer_dim_size, int64_t outer_dim_size_grad, int64_t numel, - int64_t numel_grad) { + int64_t numel_grad, + int* thread_ids) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int thread_ids[]; if (tid == 0) { for (int i = 0; i < numel_grad; i++) { thread_ids[i] = 0; @@ -660,10 +678,10 @@ __global__ void ScatterMinMaxInputGradGPUKernel(tensor_t* grad_data, int64_t outer_dim_size_value, int64_t numel, int64_t numel_grad, - const std::string& reduce) { + const std::string& reduce, + int* shared_mem) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int shared_mem[]; if (tid == 0) { for (int i = 0; i < numel_grad; i++) { @@ -741,37 +759,47 @@ void gpu_scatter_mul_min_max_input_grad_kernel(phi::DenseTensor self, auto stream = reinterpret_cast(ctx).stream(); if (reduce == "mul" || reduce == "multiply") { int shared_mem_size = sizeof(int) * grad_size; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterMulInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - out_data, - x_data, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_grad, - index_size, - grad_size); + <<>>(grad_data, + dim, + index_data, + out_data, + x_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_grad, + index_size, + grad_size, + shared_mem); + cudaFreeAsync(shared_mem, stream); } else if (reduce == "amin" || reduce == "amax") { int shared_mem_size = sizeof(int) * grad_size; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterMinMaxInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - out_data, - x_data, - value_data, - self_data, - select_dim_size, - grad_select_dim_size, - value_select_dim_size, - outer_dim_size, - outer_dim_size_grad, - outer_dim_size_value, - index_size, - grad_size, - reduce); + <<>>(grad_data, + dim, + index_data, + out_data, + x_data, + value_data, + self_data, + select_dim_size, + grad_select_dim_size, + value_select_dim_size, + outer_dim_size, + outer_dim_size_grad, + outer_dim_size_value, + index_size, + grad_size, + reduce, + shared_mem); + cudaFreeAsync(shared_mem, stream); } } @@ -784,10 +812,10 @@ __global__ void ScatterMeanInputGradGPUKernel(tensor_t* grad_data, int64_t outer_dim_size, int64_t outer_dim_size_grad, int64_t numel, - int64_t numel_grad) { + int64_t numel_grad, + int* shared_mem) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int shared_mem[]; if (tid == 0) { for (int i = 0; i < numel_grad; i++) { shared_mem[i] = 0; // thread_ids @@ -848,16 +876,21 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self, int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); int shared_mem_size = sizeof(int) * grad_size * 2; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterMeanInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_grad, - index_size, - grad_size); + <<>>(grad_data, + dim, + index_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_grad, + index_size, + grad_size, + shared_mem); + cudaFreeAsync(shared_mem, stream); } template @@ -872,10 +905,10 @@ __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, int64_t outer_dim_size_self, int64_t outer_dim_size_grad, int64_t numel, - int64_t numel_data) { + int64_t numel_data, + int* thread_ids) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int thread_ids[]; if (tid == 0) { for (int i = 0; i < numel_data; i++) { @@ -939,19 +972,24 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self, int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); int shared_mem_size = sizeof(int) * self_size; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterValueGradGPUKernel - <<>>(grad_data, - dim, - self_data, - index_data, - select_dim_size, - self_select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_grad, - index_size, - self_size); + <<>>(grad_data, + dim, + self_data, + index_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size, + shared_mem); + cudaFreeAsync(shared_mem, stream); } template @@ -967,10 +1005,10 @@ __global__ void ScatterMeanValueGradGPUKernel(tensor_t* grad_data, int64_t outer_dim_size_grad, int64_t numel, int64_t numel_self, - bool include_self) { + bool include_self, + int* shared_mem) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int shared_mem[]; if (tid == 0) { for (int i = 0; i < numel_self; i++) { @@ -1073,20 +1111,25 @@ void gpu_scatter_add_mean_value_grad_kernel( auto stream = reinterpret_cast(ctx).stream(); if (reduce == "mean") { int shared_mem_size = sizeof(int) * self_size; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterMeanValueGradGPUKernel - <<>>(grad_data, - dim, - self_data, - index_data, - select_dim_size, - self_select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_grad, - index_size, - self_size, - include_self); + <<>>(grad_data, + dim, + self_data, + index_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size, + include_self, + shared_mem); + cudaFreeAsync(shared_mem, stream); } else if (reduce == "add") { ScatterAddValueGradGPUKernel <<>>(grad_data, @@ -1150,10 +1193,10 @@ __global__ void ScatterMinMaxValueGradGPUKernel(tensor_t* grad_data, int64_t outer_dim_size_grad, int64_t numel, int64_t numel_self, - bool include_self) { + bool include_self, + int* shared_mem) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; - extern __shared__ int shared_mem[]; int64_t i, j, k; i = tid / (select_dim_size * outer_dim_size); int64_t remind = tid % (select_dim_size * outer_dim_size); @@ -1246,23 +1289,28 @@ void gpu_scatter_mul_min_max_value_grad_kernel(phi::DenseTensor self, index_size); } else if (reduce == "amin" || reduce == "amax") { int shared_mem_size = sizeof(int) * self_size; + int* shared_mem; + cudaMallocAsync( + reinterpret_cast(&shared_mem), shared_mem_size, stream); ScatterMinMaxValueGradGPUKernel - <<>>(grad_data, - dim, - index_data, - self_data, - value_data, - out_data, - x_data, - select_dim_size, - self_select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_grad, - index_size, - self_size, - include_self); + <<>>(grad_data, + dim, + index_data, + self_data, + value_data, + out_data, + x_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size, + include_self, + shared_mem); + cudaFreeAsync(shared_mem, stream); } } diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index e3e7077ea0f339..47cfc65d617136 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -738,6 +738,45 @@ def run(place): run(place) +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not complied with CUDA", +) +class TestPutAlongAxisAPILargeCase(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [64, 1327104] + self.index_shape = [64, 1327104] + self.index_np = np.zeros(self.index_shape).astype('int64') + self.x_np = np.random.random(self.shape).astype(np.float32) + self.axis = 1 + self.value_np = np.ones(self.index_shape).astype(np.float32) + self.x_feed = copy.deepcopy(self.x_np) + self.place = [paddle.CUDAPlace(0)] + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + x_tensor = paddle.to_tensor(self.x_np) + index_tensor = paddle.to_tensor(self.index_np) + value_tensor = paddle.to_tensor(self.value_np) + out = paddle.put_along_axis( + x_tensor, index_tensor, value_tensor, self.axis + ) + np.array( + np.put_along_axis( + self.x_np, self.index_np, self.value_np, self.axis + ) + ) + out_ref = self.x_np + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + + paddle.enable_static() + + for place in self.place: + run(place) + + class TestPutAlongAxisAPICase2(TestPutAlongAxisAPI): def setUp(self): np.random.seed(0) From f07846a3aeb384fbce6951af4cec51b41b65c38d Mon Sep 17 00:00:00 2001 From: engineer1109 Date: Thu, 4 Jan 2024 14:37:34 +0800 Subject: [PATCH 103/142] remove clearPass to allow custom device use fusion under fp16 (#60541) --- paddle/fluid/inference/api/analysis_predictor.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index e042f358c9874c..c7164b61bb7c00 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1781,7 +1781,6 @@ void AnalysisPredictor::PrepareArgument() { argument_->SetEnableCustomDeviceMixed(config_.enable_custom_device_mixed()); if (config_.enable_custom_device_mixed_) { argument_->SetEnableIrOptim(true); - pass_builder->ClearPasses(); pass_builder->AppendPass("auto_mixed_precision_pass"); LOG(INFO) << "This model run in Custom Device mixed precision mode."; } From 5f1727bface4a5a31fc81e45ef70506aa0c4e643 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Thu, 4 Jan 2024 14:46:11 +0800 Subject: [PATCH 104/142] fix fleetutil get_online_pass_interval bug2; test=develop (#60544) --- python/paddle/incubate/distributed/fleet/fleet_util.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/incubate/distributed/fleet/fleet_util.py b/python/paddle/incubate/distributed/fleet/fleet_util.py index 1777afffe9aaf4..e0450948e61714 100644 --- a/python/paddle/incubate/distributed/fleet/fleet_util.py +++ b/python/paddle/incubate/distributed/fleet/fleet_util.py @@ -1324,7 +1324,8 @@ def get_online_pass_interval( and "/" not in days and "(" not in days and ")" not in days - ), r"days should not contain [|,;,\,/,(,)]" + and "&" not in days + ), r"days should not contain [|,;,\,/,(,),&]" days = os.popen("echo -n " + days).read().split(" ") assert ( "|" not in hours @@ -1332,8 +1333,9 @@ def get_online_pass_interval( and "\\" not in hours and "/" not in hours and "(" not in hours - and ")" not in days - ), r"hours should not contain [|,;,\,/,(,)]" + and ")" not in hours + and "&" not in hours + ), r"hours should not contain [|,;,\,/,(,),&]" hours = os.popen("echo -n " + hours).read().split(" ") split_interval = int(split_interval) split_per_pass = int(split_per_pass) From 90183718b44d22074540975dde285ba2c51fe2b2 Mon Sep 17 00:00:00 2001 From: xuxinyi389 <104957571+xuxinyi389@users.noreply.github.com> Date: Thu, 4 Jan 2024 15:10:31 +0800 Subject: [PATCH 105/142] fix vs2017 limit (#60528) --- .../hlir/dialect/operator/ir/op_dialect.cc | 15 ++++++ .../fluid/pir/dialect/op_generator/op_gen.py | 50 ++++++++++++++++--- .../pir/dialect/operator/ir/op_dialect.cc | 17 ++++++- .../dialect/operator/ir/op_onednn_dialect.cc | 15 ++++++ 4 files changed, 89 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc index d7e55095fcc2ec..5c46fc4be85e50 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc @@ -33,10 +33,25 @@ void OperatorDialect::initialize() { // NOTE(chenxi67): GET_OP_LIST is defined in cinn_op.h which is // generated by op_gen.py, see details in // paddle/cinn/hlir/dialect/CMakeLists.txt. + + // NOTE(cocoshe): VS2017 has a limit on the length of template + // parameters, which causes "fatal error C1202". + // Split GET_OP_LIST into two part on WIN32 here. +#ifdef WIN32 + RegisterOps< +#define GET_OP_LIST1 +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT + >(); + RegisterOps< +#define GET_OP_LIST2 +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT + >(); +#else RegisterOps< #define GET_OP_LIST #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT >(); +#endif RegisterOp(); RegisterOp(); RegisterOp(); diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index d29982d22e5f77..01ba59f79c4e2d 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -167,10 +167,24 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {define_type_id} """ - -CC_OP_INFO_FILE_TEMPLATE = """#ifdef GET_OP_LIST +# NOTE(cocoshe): Use CC_OP_INFO_FILE_TEMPLATE_WIN_PART1 to generate two GET_OP_LIST to avoid +# "fatal error C1202: recursive type or function dependency context too complex" error +# when compiling on vs2017 because the GET_OP_LIST is too long. +# And use CC_OP_INFO_FILE_TEMPLATE_PART1 to generate just one GET_OP_LIST for other compiler. +CC_OP_INFO_FILE_TEMPLATE_PART1 = """#ifdef GET_OP_LIST #undef GET_OP_LIST {op_declare} +""" + +CC_OP_INFO_FILE_TEMPLATE_WIN_PART1 = """#ifdef GET_OP_LIST1 +#undef GET_OP_LIST1 +{op_declare_first_part} +#elif defined(GET_OP_LIST2) +#undef GET_OP_LIST2 +{op_declare_second_part} +""" + +CC_OP_INFO_FILE_TEMPLATE_PART2 = """ #else // This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py" #include "{h_file}" @@ -1994,11 +2008,33 @@ def OpGenerator( op_to_multi_kernels_map_str = "" if op_info_file is not None: - op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( - op_declare=",".join(op_list_strs).replace("\n", ""), - op_to_multi_kernels_map=op_to_multi_kernels_map_str, - h_file=op_def_h_file[:-4], - ) + if sys.platform == "win32": + n = len(op_list_strs) // 2 + first_part_op_info = op_list_strs[:n] + second_part_op_info = op_list_strs[n:] + CC_OP_INFO_FILE_TEMPLATE = ( + CC_OP_INFO_FILE_TEMPLATE_WIN_PART1 + + CC_OP_INFO_FILE_TEMPLATE_PART2 + ) + op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( + op_declare_first_part=",".join(first_part_op_info).replace( + "\n", "" + ), + op_declare_second_part=",".join(second_part_op_info).replace( + "\n", "" + ), + op_to_multi_kernels_map=op_to_multi_kernels_map_str, + h_file=op_def_h_file[:-4], + ) + else: + CC_OP_INFO_FILE_TEMPLATE = ( + CC_OP_INFO_FILE_TEMPLATE_PART1 + CC_OP_INFO_FILE_TEMPLATE_PART2 + ) + op_info_str = CC_OP_INFO_FILE_TEMPLATE.format( + op_declare=",".join(op_list_strs).replace("\n", ""), + op_to_multi_kernels_map=op_to_multi_kernels_map_str, + h_file=op_def_h_file[:-4], + ) with open(op_info_file, 'w') as f: f.write(op_info_str) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 80f6e598f967c2..a9129a28793b09 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -156,11 +156,26 @@ void OperatorDialect::initialize() { // paddle/fluid/pir/dialect/CMakeLists.txt. // NOTE(Ruting)GET_MANUAL_OP_LIST is define in manual_op.h" // use RegisterOps when list has more than two ops. + + // NOTE(cocoshe): VS2017 has a limit on the length of template + // parameters, which causes "fatal error C1202". + // Split GET_OP_LIST into two part on WIN32 here. +#ifdef WIN32 RegisterOps< -#define GET_OP_LIST +#define GET_OP_LIST1 #include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT >(); + RegisterOps< +#define GET_OP_LIST2 +#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT + >(); +#else + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT + >(); +#endif RegisterOps< #define GET_OP_LIST #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc index 0d65389cc4922b..b2d817b506199d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc @@ -44,10 +44,25 @@ void OneDNNOperatorDialect::initialize() { // paddle/fluid/pir/dialect/CMakeLists.txt. // NOTE(Ruting)GET_MANUAL_OP_LIST is define in manual_op.h" // use RegisterOps when list has more than two ops. + + // NOTE(cocoshe): VS2017 has a limit on the length of template + // parameters, which causes "fatal error C1202". + // Split GET_OP_LIST into two part on WIN32 here. +#ifdef WIN32 + RegisterOps< +#define GET_OP_LIST1 +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT + >(); + RegisterOps< +#define GET_OP_LIST2 +#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT + >(); +#else RegisterOps< #define GET_OP_LIST #include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT >(); +#endif } void OneDNNOperatorDialect::PrintType(pir::Type type, std::ostream &os) const { From 08ed2a5dbb957d4ecb4ea9acca78abaaa9e9e5ef Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 4 Jan 2024 15:52:07 +0800 Subject: [PATCH 106/142] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.20?= =?UTF-8?q?=E3=80=91=E4=B8=BA=20Paddle=20=E6=96=B0=E5=A2=9E=20Exponential?= =?UTF-8?q?=20=E5=92=8C=20Gamma=20API=20(#57899)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add exponential * add gamma distribution * refine docs * add kl_divergence and test * resolve conflicts * resolve conflicts * fix bug * refine test * fix test timeout * refine code * add standard_gamma kernel * fix comments * fix tests * fix tests * fix comments * fix tests * fix gamma grad * fix yaml * fix bugs * fix tests * fix standard_gamma_grad * fix test * fix test * add cdf & icdf * add cdf & icdf * refine comments * fix * fix * fix head file * fix * fix cuda op * fix * fix * refine test * fix test * refine comments * fix comments * fix * fix * fix type check * fix docs * delete useless comments --- paddle/phi/api/ext/tensor_compat.h | 1 + paddle/phi/api/yaml/ops.yaml | 8 + paddle/phi/kernels/cpu/dirichlet_kernel.cc | 82 --- .../phi/kernels/cpu/standard_gamma_kernel.cc | 20 + paddle/phi/kernels/gpu/dirichlet_kernel.cu | 97 ---- .../phi/kernels/gpu/standard_gamma_kernel.cu | 27 + .../phi/kernels/impl/dirichlet_kernel_impl.h | 204 ++++++- .../kernels/impl/standard_gamma_kernel_impl.h | 29 + paddle/phi/kernels/standard_gamma_kernel.h | 34 ++ python/paddle/__init__.py | 2 + python/paddle/distribution/__init__.py | 4 + python/paddle/distribution/exponential.py | 225 ++++++++ python/paddle/distribution/gamma.py | 228 ++++++++ python/paddle/distribution/kl.py | 12 + python/paddle/distribution/normal.py | 4 +- python/paddle/tensor/random.py | 51 ++ .../test_distribution_beta_static.py | 3 +- .../test_distribution_exponential.py | 366 +++++++++++++ .../test_distribution_exponential_static.py | 445 +++++++++++++++ test/distribution/test_distribution_gamma.py | 463 ++++++++++++++++ .../test_distribution_gamma_static.py | 509 ++++++++++++++++++ .../test_distribution_geometric_static.py | 5 +- 22 files changed, 2631 insertions(+), 188 deletions(-) create mode 100644 paddle/phi/kernels/cpu/standard_gamma_kernel.cc create mode 100644 paddle/phi/kernels/gpu/standard_gamma_kernel.cu create mode 100644 paddle/phi/kernels/impl/standard_gamma_kernel_impl.h create mode 100644 paddle/phi/kernels/standard_gamma_kernel.h create mode 100644 python/paddle/distribution/exponential.py create mode 100644 python/paddle/distribution/gamma.py create mode 100644 test/distribution/test_distribution_exponential.py create mode 100644 test/distribution/test_distribution_exponential_static.py create mode 100644 test/distribution/test_distribution_gamma.py create mode 100644 test/distribution/test_distribution_gamma_static.py diff --git a/paddle/phi/api/ext/tensor_compat.h b/paddle/phi/api/ext/tensor_compat.h index 202876d4dc6986..b1a140da46a890 100644 --- a/paddle/phi/api/ext/tensor_compat.h +++ b/paddle/phi/api/ext/tensor_compat.h @@ -144,6 +144,7 @@ using experimental::split; using experimental::sqrt; using experimental::square; using experimental::stack; +using experimental::standard_gamma; using experimental::strided_slice; using experimental::subtract; using experimental::swish; diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index de4d700cdf80ee..d4ee3628ad19a4 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2542,6 +2542,14 @@ backward : stack_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : standard_gamma + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : standard_gamma + - op : stanh args : (Tensor x, float scale_a=0.67f, float scale_b=1.7159f) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/dirichlet_kernel.cc b/paddle/phi/kernels/cpu/dirichlet_kernel.cc index 855e6bdfe1e1ff..b18fee4694ee67 100644 --- a/paddle/phi/kernels/cpu/dirichlet_kernel.cc +++ b/paddle/phi/kernels/cpu/dirichlet_kernel.cc @@ -13,90 +13,8 @@ // limitations under the License. #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/cpu/elementwise.h" -#include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/funcs/for_range.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/funcs/reduce_functor.h" #include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h" -namespace phi { - -template -struct GammaCPUFunctor { - GammaCPUFunctor(const T* alpha, - T* gamma, - BaseSampler uniform, - BaseSampler normal) - : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {} - - HOST void operator()(int64_t index) { - auto sample = sample_gamma( - alpha_[index], uniform_, normal_); - gamma_[index] = std::max(std::numeric_limits::min(), sample); - } - - const T* alpha_; - T* gamma_; - BaseSampler uniform_; - BaseSampler normal_; -}; - -template -struct DirichletSampler { - void operator()(const CPUContext& dev_ctx, - const DenseTensor& alpha, - DenseTensor* out) { - auto generator = dev_ctx.GetGenerator()->GetCPUEngine(); - - auto uniform = [&generator]() -> T { - std::uniform_real_distribution u(0.0, 1.0); - return u(*generator); - }; - BaseSampler standard_uniform(uniform); - - auto normal = [&generator]() { - std::normal_distribution n(0.0, 1.0); - return n(*generator); - }; - BaseSampler standard_normal(normal); - - // sample from K gamma distributions, where K=alpha.numel() - DenseTensor gamma_samples; - gamma_samples.Resize(alpha.dims()); - dev_ctx.template Alloc(&gamma_samples); - - GammaCPUFunctor gamma_functor( - alpha.data(), - gamma_samples.data(), - standard_uniform, - standard_normal); - funcs::ForRange for_range(dev_ctx, alpha.numel()); - for_range(gamma_functor); - - // normalize them into a simplex, along the last axis - DenseTensor gamma_sum; - auto new_shape = gamma_samples.dims(); - new_shape[new_shape.size() - 1] = 1; - gamma_sum.Resize(new_shape); - dev_ctx.template Alloc(&gamma_sum); - - funcs::ReduceKernelImpl( - dev_ctx, - gamma_samples, - &gamma_sum, - {new_shape.size() - 1}, - true, - false); - - funcs::ElementwiseCompute, T>( - dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor(), out); - } -}; - -} // namespace phi - PD_REGISTER_KERNEL( dirichlet, CPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/standard_gamma_kernel.cc b/paddle/phi/kernels/cpu/standard_gamma_kernel.cc new file mode 100644 index 00000000000000..74cd594bdb6cf4 --- /dev/null +++ b/paddle/phi/kernels/cpu/standard_gamma_kernel.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/standard_gamma_kernel_impl.h" + +PD_REGISTER_KERNEL( + standard_gamma, CPU, ALL_LAYOUT, phi::StandardGammaKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/dirichlet_kernel.cu b/paddle/phi/kernels/gpu/dirichlet_kernel.cu index 912c84bf26c210..12b70c3ec68a55 100644 --- a/paddle/phi/kernels/gpu/dirichlet_kernel.cu +++ b/paddle/phi/kernels/gpu/dirichlet_kernel.cu @@ -1,5 +1,3 @@ - - // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,102 +14,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/elementwise_divide_kernel.h" -#include "paddle/phi/kernels/funcs/broadcast_function.h" -#include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/funcs/for_range.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" -#include "paddle/phi/kernels/funcs/reduce_functor.h" #include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h" -#include "paddle/phi/kernels/reduce_sum_kernel.h" - -#ifdef PADDLE_WITH_CUDA -#include -#endif -#ifdef PADDLE_WITH_HIP -#include -#endif - -#if defined(PADDLE_WITH_CUDA) -using COMPAT_RANDSTATEPHILOX4_32_10_T = curandStatePhilox4_32_10_t; -#define COMPAT_RAND_INIT curand_init -#define COMPAT_RAND_UNIFORM curand_uniform -#define COMPAT_RAND_NORMAL curand_normal -#elif defined(PADDLE_WITH_HIP) -using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t; -#define COMPAT_RAND_INIT hiprand_init -#define COMPAT_RAND_UNIFORM hiprand_uniform -#define COMPAT_RAND_NORMAL hiprand_normal -#endif - -namespace phi { -template -struct GammaCUDAFunctor { - GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset) - : alpha_(alpha), gamma_(gamma), seed_(seed), offset_(offset) {} - - DEVICE void operator()(int64_t index) { - // curand initialization - COMPAT_RANDSTATEPHILOX4_32_10_T state; - COMPAT_RAND_INIT( - /*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_, &state); - - // sample - auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); }; - BaseSampler standard_uniform(uniform_lambda); - auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); }; - BaseSampler standard_normal(normal_lambda); - - auto sample = - sample_gamma( - alpha_[index], standard_uniform, standard_normal); - gamma_[index] = std::max(std::numeric_limits::min(), sample); - } - - const T* alpha_; - T* gamma_; - const uint64_t seed_; - const uint64_t offset_; -}; - -template -struct DirichletSampler { - void operator()(const GPUContext& dev_ctx, - const DenseTensor& alpha, - DenseTensor* out) { - auto p_gen = dev_ctx.GetGenerator(); - auto seed_and_offset = p_gen->IncrementOffset(10); // hard-coded offset - auto seed = seed_and_offset.first; - auto offset = seed_and_offset.second; - - // sample from K gamma distributions, where K=alpha.numel() - DenseTensor gamma_samples; - gamma_samples.Resize(alpha.dims()); - dev_ctx.template Alloc(&gamma_samples); - - GammaCUDAFunctor gamma_functor( - alpha.data(), gamma_samples.data(), seed, offset); - funcs::ForRange for_range(dev_ctx, out->numel()); - for_range(gamma_functor); - - // normalize them into a simplex, along the last axis - DenseTensor gamma_sum; - auto new_shape = gamma_samples.dims(); - new_shape[new_shape.size() - 1] = 1; - gamma_sum.Resize(new_shape); - dev_ctx.template Alloc(&gamma_sum); - - phi::SumRawKernel(dev_ctx, - gamma_samples, - {new_shape.size() - 1}, - true, - false, - gamma_sum.dtype(), - &gamma_sum); - phi::DivideKernel(dev_ctx, gamma_samples, gamma_sum, out); - } -}; -} // namespace phi PD_REGISTER_KERNEL(dirichlet, GPU, diff --git a/paddle/phi/kernels/gpu/standard_gamma_kernel.cu b/paddle/phi/kernels/gpu/standard_gamma_kernel.cu new file mode 100644 index 00000000000000..9573181b3164b5 --- /dev/null +++ b/paddle/phi/kernels/gpu/standard_gamma_kernel.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/standard_gamma_kernel_impl.h" + +PD_REGISTER_KERNEL(standard_gamma, + GPU, + ALL_LAYOUT, + phi::StandardGammaKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/dirichlet_kernel_impl.h b/paddle/phi/kernels/impl/dirichlet_kernel_impl.h index 82eeeee6f4c64f..9b09ca51ab6de1 100644 --- a/paddle/phi/kernels/impl/dirichlet_kernel_impl.h +++ b/paddle/phi/kernels/impl/dirichlet_kernel_impl.h @@ -16,8 +16,19 @@ #include #include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/cpu/elementwise.h" #include "paddle/phi/kernels/dirichlet_kernel.h" +#include "paddle/phi/kernels/elementwise_divide_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" // ROCM hcc doesn't work well with using std:: in kernel functions #if defined(PADDLE_WITH_CUDA) @@ -42,6 +53,25 @@ #define COMPAT_LOG1P std::log1p #endif +#ifdef PADDLE_WITH_CUDA +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#endif + +#if defined(PADDLE_WITH_CUDA) +using COMPAT_RANDSTATEPHILOX4_32_10_T = curandStatePhilox4_32_10_t; +#define COMPAT_RAND_INIT curand_init +#define COMPAT_RAND_UNIFORM curand_uniform +#define COMPAT_RAND_NORMAL curand_normal +#elif defined(PADDLE_WITH_HIP) +using COMPAT_RANDSTATEPHILOX4_32_10_T = hiprandStatePhilox4_32_10_t; +#define COMPAT_RAND_INIT hiprand_init +#define COMPAT_RAND_UNIFORM hiprand_uniform +#define COMPAT_RAND_NORMAL hiprand_normal +#endif + namespace phi { template @@ -54,6 +84,20 @@ struct BaseSampler { } }; +template +struct GammaSampler { + void operator()(const Context& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out); +}; + +template +struct DirichletSampler { + void operator()(const Context& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out); +}; + // `sample_gamma` is d from Numpy's distributions.c, and add support for // paddle data type and code style. // Source MIT licensed: @@ -124,13 +168,164 @@ sample_gamma(ScalarT alpha, } } -template -struct DirichletSampler { - void operator()(const Context& dev_ctx, +template +struct GammaCPUFunctor { + GammaCPUFunctor(const T* alpha, + T* gamma, + BaseSampler uniform, + BaseSampler normal) + : alpha_(alpha), gamma_(gamma), uniform_(uniform), normal_(normal) {} + + HOST void operator()(int64_t index) { + auto sample = sample_gamma( + alpha_[index], uniform_, normal_); + gamma_[index] = std::max(std::numeric_limits::min(), sample); + } + + const T* alpha_; + T* gamma_; + BaseSampler uniform_; + BaseSampler normal_; +}; + +template +struct GammaSampler { + void operator()(const CPUContext& dev_ctx, const DenseTensor& alpha, - DenseTensor* out); + DenseTensor* out) { + auto generator = dev_ctx.GetGenerator()->GetCPUEngine(); + + auto uniform = [&generator]() -> T { + std::uniform_real_distribution u(0.0, 1.0); + return u(*generator); + }; + BaseSampler standard_uniform(uniform); + + auto normal = [&generator]() { + std::normal_distribution n(0.0, 1.0); + return n(*generator); + }; + BaseSampler standard_normal(normal); + + GammaCPUFunctor gamma_functor( + alpha.data(), out->data(), standard_uniform, standard_normal); + funcs::ForRange for_range(dev_ctx, out->numel()); + for_range(gamma_functor); + } +}; + +template +struct DirichletSampler { + void operator()(const CPUContext& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + // sample from K gamma distributions, where K=alpha.numel() + DenseTensor gamma_samples; + gamma_samples.Resize(alpha.dims()); + dev_ctx.template Alloc(&gamma_samples); + + GammaSampler gamma_sampler; + gamma_sampler(dev_ctx, alpha, &gamma_samples); + + // normalize them into a simplex, along the last axis + DenseTensor gamma_sum; + auto new_shape = gamma_samples.dims(); + new_shape[new_shape.size() - 1] = 1; + gamma_sum.Resize(new_shape); + dev_ctx.template Alloc(&gamma_sum); + + funcs::ReduceKernelImpl( + dev_ctx, + gamma_samples, + &gamma_sum, + {new_shape.size() - 1}, + true, + false); + + funcs::ElementwiseCompute, T>( + dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor(), out); + } +}; + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +struct GammaCUDAFunctor { + GammaCUDAFunctor(const T* alpha, T* gamma, uint64_t seed, uint64_t offset) + : alpha_(alpha), gamma_(gamma), seed_(seed), offset_(offset) {} + + DEVICE void operator()(int64_t index) { + // curand initialization + COMPAT_RANDSTATEPHILOX4_32_10_T state; + COMPAT_RAND_INIT( + /*seed=*/seed_, /*subsequence=*/index, /*offset=*/offset_, &state); + + // sample + auto uniform_lambda = [&state]() { return COMPAT_RAND_UNIFORM(&state); }; + BaseSampler standard_uniform(uniform_lambda); + auto normal_lambda = [&state]() { return COMPAT_RAND_NORMAL(&state); }; + BaseSampler standard_normal(normal_lambda); + + auto sample = + sample_gamma( + alpha_[index], standard_uniform, standard_normal); + gamma_[index] = std::max(std::numeric_limits::min(), sample); + } + + const T* alpha_; + T* gamma_; + const uint64_t seed_; + const uint64_t offset_; +}; + +template +struct GammaSampler { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + auto p_gen = dev_ctx.GetGenerator(); + auto seed_and_offset = p_gen->IncrementOffset(10); // hard-coded offset + auto seed = seed_and_offset.first; + auto offset = seed_and_offset.second; + + GammaCUDAFunctor gamma_functor( + alpha.data(), out->data(), seed, offset); + funcs::ForRange for_range(dev_ctx, out->numel()); + for_range(gamma_functor); + } }; +template +struct DirichletSampler { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + // sample from K gamma distributions, where K=alpha.numel() + DenseTensor gamma_samples; + gamma_samples.Resize(alpha.dims()); + dev_ctx.template Alloc(&gamma_samples); + + GammaSampler gamma_sampler; + gamma_sampler(dev_ctx, alpha, &gamma_samples); + + // normalize them into a simplex, along the last axis + DenseTensor gamma_sum; + auto new_shape = gamma_samples.dims(); + new_shape[new_shape.size() - 1] = 1; + gamma_sum.Resize(new_shape); + dev_ctx.template Alloc(&gamma_sum); + + phi::SumRawKernel(dev_ctx, + gamma_samples, + {new_shape.size() - 1}, + true, + false, + gamma_sum.dtype(), + &gamma_sum); + phi::DivideKernel(dev_ctx, gamma_samples, gamma_sum, out); + } +}; +#endif + template void Dirichletkernel(const Context& dev_ctx, const DenseTensor& alpha, @@ -139,4 +334,5 @@ void Dirichletkernel(const Context& dev_ctx, DirichletSampler sampler; sampler(dev_ctx, alpha, out); } + } // namespace phi diff --git a/paddle/phi/kernels/impl/standard_gamma_kernel_impl.h b/paddle/phi/kernels/impl/standard_gamma_kernel_impl.h new file mode 100644 index 00000000000000..12461b2c7b7127 --- /dev/null +++ b/paddle/phi/kernels/impl/standard_gamma_kernel_impl.h @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h" +#include "paddle/phi/kernels/standard_gamma_kernel.h" + +namespace phi { +template +void StandardGammaKernel(const Context& dev_ctx, + const DenseTensor& alpha, + DenseTensor* out) { + dev_ctx.template Alloc(out); + GammaSampler sampler; + sampler(dev_ctx, alpha, out); +} +} // namespace phi diff --git a/paddle/phi/kernels/standard_gamma_kernel.h b/paddle/phi/kernels/standard_gamma_kernel.h new file mode 100644 index 00000000000000..d77ebecaa2db21 --- /dev/null +++ b/paddle/phi/kernels/standard_gamma_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +/** + * @brief This kernel generate random value that obey standard gamma + * distribution. + * @param ctx device context + * @param x The input tensor of standard gamma kernel + * @param out The output tensor of standard gamma kernel, it has the same + * shape and dtype with input. Each element corresponds to input tensor + */ +template +void StandardGammaKernel(const Context& ctx, + const DenseTensor& alpha, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index fc7b2a3533f892..8902b82cadf847 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -506,6 +506,7 @@ randint_like, randn, randperm, + standard_gamma, standard_normal, uniform, ) @@ -753,6 +754,7 @@ 'bernoulli', 'binomial', 'poisson', + 'standard_gamma', 'sinh', 'sinh_', 'round', diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 446c75aeaea700..246c4ffb71173b 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -21,7 +21,9 @@ from .continuous_bernoulli import ContinuousBernoulli from .dirichlet import Dirichlet from .distribution import Distribution +from .exponential import Exponential from .exponential_family import ExponentialFamily +from .gamma import Gamma from .geometric import Geometric from .gumbel import Gumbel from .independent import Independent @@ -58,6 +60,7 @@ 'ContinuousBernoulli', 'Dirichlet', 'Distribution', + 'Exponential', 'ExponentialFamily', 'Multinomial', 'MultivariateNormal', @@ -69,6 +72,7 @@ 'TransformedDistribution', 'Laplace', 'LogNormal', + 'Gamma', 'Gumbel', 'Geometric', 'Binomial', diff --git a/python/paddle/distribution/exponential.py b/python/paddle/distribution/exponential.py new file mode 100644 index 00000000000000..abf88bc3c4c37e --- /dev/null +++ b/python/paddle/distribution/exponential.py @@ -0,0 +1,225 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np + +import paddle +from paddle import distribution +from paddle.base.data_feeder import check_type, convert_dtype +from paddle.base.framework import Variable +from paddle.distribution import exponential_family +from paddle.framework import in_dynamic_mode + + +class Exponential(exponential_family.ExponentialFamily): + r""" + Exponential distribution parameterized by :attr:`rate`. + + The probability density function (pdf) is + + .. math:: + + f(x; \theta) = \theta e^{- \theta x }, (x \ge 0) $$ + + In the above equation: + + * :math:`rate = \theta`: is the rate parameter. + + Args: + rate (float|Tensor): Rate parameter. The value of rate must be positive. + + Example: + .. code-block:: python + + >>> import paddle + + >>> expon = paddle.distribution.Exponential(paddle.to_tensor([0.5])) + >>> print(expon.mean) + Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [2.]) + + >>> print(expon.variance) + Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [4.]) + + >>> print(expon.entropy()) + Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [1.69314718]) + """ + + def __init__(self, rate): + if not in_dynamic_mode(): + check_type( + rate, + 'rate', + (float, Variable), + 'Exponential', + ) + + # Get/convert rate to tensor. + if self._validate_args(rate): + self.rate = rate + self.dtype = convert_dtype(rate.dtype) + else: + [self.rate] = self._to_tensor(rate) + self.dtype = paddle.get_default_dtype() + + if not paddle.all(self.rate > 0): + raise ValueError("The arg of `rate` must be positive.") + + super().__init__(self.rate.shape) + + @property + def mean(self): + """Mean of exponential distribuion. + + Returns: + Tensor: mean value. + """ + return self.rate.reciprocal() + + @property + def variance(self): + """Variance of exponential distribution. + + Returns: + Tensor: variance value. + """ + return self.rate.pow(-2) + + def sample(self, shape=()): + """Generate samples of the specified shape. + + Args: + shape (Sequence[int], optional): Shape of the generated samples. + + Returns: + Tensor, A tensor with prepended dimensions shape. The data type is float32. + """ + with paddle.no_grad(): + return self.rsample(shape) + + def rsample(self, shape=()): + """Generate reparameterized samples of the specified shape. + + Args: + shape (Sequence[int], optional): Shape of the generated samples. + + Returns: + Tensor: A tensor with prepended dimensions shape. The data type is float32. + """ + shape = distribution.Distribution._extend_shape( + self, sample_shape=shape + ) + + uniform = paddle.uniform( + shape=shape, + min=float(np.finfo(dtype='float32').tiny), + max=1.0, + dtype=self.rate.dtype, + ) + + return -paddle.log(uniform) / self.rate + + def prob(self, value): + r"""Probability density funciotn evaluated at value. + + .. math:: + + { f(x; \theta) = \theta e^{- \theta x}, (x \ge 0 ) } + + Args: + value (float|Tensor): Value to be evaluated. + + Returns: + Tensor: Probability. + """ + return self.rate * paddle.exp(-self.rate * value) + + def log_prob(self, value): + """Log probability density function evaluated at value. + + Args: + value (float|Tensor): Value to be evaluated + + Returns: + Tensor: Log probability. + """ + return paddle.log(self.rate) - self.rate * value + + def entropy(self): + """Entropy of exponential distribution. + + Returns: + Tensor: Entropy. + """ + return 1.0 - paddle.log(self.rate) + + def cdf(self, value): + r"""Cumulative distribution function(CDF) evaluated at value. + + .. math:: + + + { cdf(x; \theta) = 1 - e^{- \theta x }, (x \ge 0) } + + Args: + value (float|Tensor): Value to be evaluated. + + Returns: + Tensor: CDF evaluated at value. + """ + return 1.0 - paddle.exp(-self.rate * value) + + def icdf(self, value): + r"""Inverse cumulative distribution function(CDF) evaluated at value. + + .. math:: + + + { icdf(x; \theta) = -\frac{ 1 }{ \theta } ln(1 + x), (x \ge 0) } + + Args: + value (float|Tensor): Value to be evaluated. + + Returns: + Tensor: CDF evaluated at value. + """ + return -paddle.log1p(-value) / self.rate + + def kl_divergence(self, other): + """The KL-divergence between two exponential distributions. + + Args: + other (Exponential): instance of Exponential. + + Returns: + Tensor: kl-divergence between two exponential distributions. + """ + if not isinstance(other, Exponential): + raise TypeError( + f"Expected type of other is Exponential, but got {type(other)}" + ) + + rate_ratio = other.rate / self.rate + t1 = -paddle.log(rate_ratio) + return t1 + rate_ratio - 1 + + @property + def _natural_parameters(self): + return (-self.rate,) + + def _log_normalizer(self, x): + return -paddle.log(-x) diff --git a/python/paddle/distribution/gamma.py b/python/paddle/distribution/gamma.py new file mode 100644 index 00000000000000..e1ae3a1f636583 --- /dev/null +++ b/python/paddle/distribution/gamma.py @@ -0,0 +1,228 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import paddle +from paddle import distribution +from paddle.base.data_feeder import check_type, convert_dtype +from paddle.base.framework import Variable +from paddle.distribution import exponential_family +from paddle.framework import in_dynamic_mode + + +class Gamma(exponential_family.ExponentialFamily): + r""" + Gamma distribution parameterized by :attr:`concentration` (aka "alpha") and :attr:`rate` (aka "beta"). + + The probability density function (pdf) is + + .. math:: + + f(x; \alpha, \beta, x > 0) = \frac{\beta^{\alpha}}{\Gamma(\alpha)} x^{\alpha-1}e^{-\beta x} + + \Gamma(\alpha)=\int_{0}^{\infty} x^{\alpha-1} e^{-x} \mathrm{~d} x, (\alpha>0) + + Args: + concentration (float|Tensor): Concentration parameter. It supports broadcast semantics. + The value of concentration must be positive. When the parameter is a tensor, + it represents multiple independent distribution with + a batch_shape(refer to :ref:`api_paddle_distribution_Distribution`). + rate (float|Tensor): Rate parameter. It supports broadcast semantics. + The value of rate must be positive. When the parameter is tensor, + it represent multiple independent distribution with + a batch_shape(refer to :ref:`api_paddle_distribution_Distribution`). + + Example: + .. code-block:: python + + >>> import paddle + + >>> # scale input + >>> gamma = paddle.distribution.Gamma(0.5, 0.5) + >>> print(gamma.mean) + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + 1.) + + >>> print(gamma.variance) + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + 2.) + + >>> print(gamma.entropy()) + Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, + 0.78375685) + + >>> # tensor input with broadcast + >>> gamma = paddle.distribution.Gamma(paddle.to_tensor([0.2, 0.4]), paddle.to_tensor(0.6)) + >>> print(gamma.mean) + Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0.33333331, 0.66666663]) + + >>> print(gamma.variance) + Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [0.55555552, 1.11111104]) + + >>> print(gamma.entropy()) + Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [-1.99634242, 0.17067254]) + """ + + def __init__(self, concentration, rate): + if not in_dynamic_mode(): + check_type( + concentration, + 'concentration', + (float, Variable), + 'Gamma', + ) + check_type( + rate, + 'rate', + (float, Variable), + 'Gamma', + ) + + # Get/convert concentration/rate to tensor. + if self._validate_args(concentration, rate): + self.concentration = concentration + self.rate = rate + self.dtype = convert_dtype(concentration.dtype) + else: + [self.concentration, self.rate] = self._to_tensor( + concentration, rate + ) + self.dtype = paddle.get_default_dtype() + + if not paddle.all(self.concentration > 0): + raise ValueError("The arg of `concentration` must be positive.") + + if not paddle.all(self.rate > 0): + raise ValueError("The arg of `rate` must be positive.") + + super().__init__(self.concentration.shape) + + @property + def mean(self): + """Mean of gamma distribuion. + + Returns: + Tensor: mean value. + """ + return self.concentration / self.rate + + @property + def variance(self): + """Variance of gamma distribution. + + Returns: + Tensor: variance value. + """ + return self.concentration / self.rate.pow(2) + + def prob(self, value): + """Probability density funciotn evaluated at value + + Args: + value (float|Tensor): Value to be evaluated. + + Returns: + Tensor: Probability. + """ + return paddle.exp(self.log_prob(value)) + + def log_prob(self, value): + """Log probability density function evaluated at value + + Args: + value (float|Tensor): Value to be evaluated + + Returns: + Tensor: Log probability. + """ + return ( + self.concentration * paddle.log(self.rate) + + (self.concentration - 1) * paddle.log(value) + - self.rate * value + - paddle.lgamma(self.concentration) + ) + + def entropy(self): + """Entropy of gamma distribution + + Returns: + Tensor: Entropy. + """ + return ( + self.concentration + - paddle.log(self.rate) + + paddle.lgamma(self.concentration) + + (1.0 - self.concentration) * paddle.digamma(self.concentration) + ) + + def sample(self, shape=()): + """Generate samples of the specified shape. + + Args: + shape (Sequence[int], optional): Shape of the generated samples. + + Returns: + Tensor, A tensor with prepended dimensions shape.The data type is float32. + """ + with paddle.no_grad(): + return self.rsample(shape) + + def rsample(self, shape=()): + """Generate reparameterized samples of the specified shape. + + Args: + shape (Sequence[int], optional): Shape of the generated samples. + + Returns: + Tensor: A tensor with prepended dimensions shape.The data type is float32. + """ + shape = distribution.Distribution._extend_shape( + self, sample_shape=shape + ) + return paddle.standard_gamma( + self.concentration.expand(shape) + ) / self.rate.expand(shape) + + def kl_divergence(self, other): + """The KL-divergence between two gamma distributions. + + Args: + other (Gamma): instance of Gamma. + + Returns: + Tensor: kl-divergence between two gamma distributions. + """ + if not isinstance(other, Gamma): + raise TypeError( + f"Expected type of other is Exponential, but got {type(other)}" + ) + + t1 = other.concentration * paddle.log(self.rate / other.rate) + t2 = paddle.lgamma(other.concentration) - paddle.lgamma( + self.concentration + ) + t3 = (self.concentration - other.concentration) * paddle.digamma( + self.concentration + ) + t4 = (other.rate - self.rate) * (self.concentration / self.rate) + return t1 + t2 + t3 + t4 + + def _natural_parameters(self): + return (self.concentration - 1, -self.rate) + + def _log_normalizer(self, x, y): + return paddle.lgamma(x + 1) + (x + 1) * paddle.log(-y.reciprocal()) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index ecec0f425d2d6c..44474b4ab5f79a 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -23,7 +23,9 @@ from paddle.distribution.continuous_bernoulli import ContinuousBernoulli from paddle.distribution.dirichlet import Dirichlet from paddle.distribution.distribution import Distribution +from paddle.distribution.exponential import Exponential from paddle.distribution.exponential_family import ExponentialFamily +from paddle.distribution.gamma import Gamma from paddle.distribution.geometric import Geometric from paddle.distribution.laplace import Laplace from paddle.distribution.lognormal import LogNormal @@ -271,6 +273,16 @@ def _kl_expfamily_expfamily(p, q): return kl +@register_kl(Exponential, Exponential) +def _kl_exponential_exponential(p, q): + return p.kl_divergence(q) + + +@register_kl(Gamma, Gamma) +def _kl_gamma_gamma(p, q): + return p.kl_divergence(q) + + @register_kl(LogNormal, LogNormal) def _kl_lognormal_lognormal(p, q): return p._base.kl_divergence(q._base) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index aacf8ffa635a26..a985556a5c7fd4 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -159,7 +159,7 @@ def __init__(self, loc, scale, name=None): @property def mean(self): - """Mean of multinomial distribuion. + """Mean of normal distribuion. Returns: Tensor: mean value. @@ -168,7 +168,7 @@ def mean(self): @property def variance(self): - """Variance of lognormal distribution. + """Variance of normal distribution. Returns: Tensor: variance value. diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 945cc8ba00fb74..e2c9aca080e1a9 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -216,6 +216,57 @@ def poisson(x, name=None): return out +def standard_gamma(x, name=None): + r""" + Returns a tensor filled with random number from a Standard Gamma Distribution. + + .. math:: + + out_i \sim Gamma (alpha = x_i, beta = 1.0) + + Args: + x(Tensor): A tensor with rate parameter of standrad gamma Distribution. The data type + should be bfloat16, float16, float32, float64. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + Returns: + Tensor: A Tensor filled with random number with the same shape and dtype as ``x``. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.set_device('cpu') + >>> paddle.seed(100) + + >>> x = paddle.uniform([2,3], min=1.0, max=5.0) + >>> out = paddle.standard_gamma(x) + >>> print(out) + >>> # doctest: +SKIP("Random output") + Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[3.35393834, 0.80538225, 0.36511323], + [6.10344696, 4.28612375, 6.37196636]]) + >>> # doctest: -SKIP + """ + if in_dynamic_or_pir_mode(): + return _C_ops.standard_gamma(x) + else: + check_variable_and_dtype( + x, "x", ["float32", "float64"], "standard_gamma" + ) + + helper = LayerHelper("standard_gamma", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='standard_gamma', + inputs={'x': x}, + outputs={'out': out}, + attrs={}, + ) + return out + + def multinomial(x, num_samples=1, replacement=False, name=None): """ Returns a Tensor filled with random values sampled from a Multinomical diff --git a/test/distribution/test_distribution_beta_static.py b/test/distribution/test_distribution_beta_static.py index afc0da1c69f7a7..b3d81dd0105eac 100644 --- a/test/distribution/test_distribution_beta_static.py +++ b/test/distribution/test_distribution_beta_static.py @@ -139,5 +139,6 @@ def test_sample(self): fetch_list=self._paddle_beta.sample(), ) self.assertTrue( - data.shape, np.broadcast_arrays(self.alpha, self.beta)[0].shape + data.shape + == np.broadcast_arrays(self.alpha, self.beta)[0].shape ) diff --git a/test/distribution/test_distribution_exponential.py b/test/distribution/test_distribution_exponential.py new file mode 100644 index 00000000000000..15da6ed59a325c --- /dev/null +++ b/test/distribution/test_distribution_exponential.py @@ -0,0 +1,366 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +import unittest + +import numpy as np +import parameterize +import scipy.stats +from distribution import config + +import paddle +from paddle.distribution import exponential, kl + +np.random.seed(2023) +paddle.seed(2023) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate'), + [ + ( + '0-dim', + 0.5, + ), + ( + 'one-dim', + parameterize.xrand( + (4,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (10, 12), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestExponential(unittest.TestCase): + def setUp(self): + rate = self.rate + if not isinstance(self.rate, numbers.Real): + rate = paddle.to_tensor(self.rate, dtype=paddle.float32) + + self.scale = 1 / rate + self._paddle_expon = exponential.Exponential(rate) + + def test_mean(self): + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_expon.mean, + scipy.stats.expon.mean(scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + def test_variance(self): + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_expon.variance, + scipy.stats.expon.var(scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + def test_prob(self): + value = np.random.rand(*self._paddle_expon.rate.shape) + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_expon.prob(paddle.to_tensor(value)), + scipy.stats.expon.pdf(value, scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + def test_cdf(self): + value = np.random.rand(*self._paddle_expon.rate.shape) + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_expon.cdf(paddle.to_tensor(value)), + scipy.stats.expon.cdf(value, scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + def test_icdf(self): + value = np.random.rand(*self._paddle_expon.rate.shape) + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_expon.icdf(paddle.to_tensor(value)), + scipy.stats.expon.ppf(value, scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + def test_entropy(self): + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_expon.entropy(), + scipy.stats.expon.entropy(scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate'), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestExponentialSample(unittest.TestCase): + def setUp(self): + rate = self.rate + if not isinstance(self.rate, numbers.Real): + rate = paddle.to_tensor(self.rate, dtype=paddle.float32) + + self.scale = 1 / rate + self._paddle_expon = exponential.Exponential(rate) + + def test_sample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + + tuple(paddle.squeeze(self._paddle_expon.rate).shape), + }, + { + 'input': (3, 2), + 'expect': (3, 2) + + tuple(paddle.squeeze(self._paddle_expon.rate).shape), + }, + ] + for case in cases: + self.assertTrue( + tuple(self._paddle_expon.sample(case.get('input')).shape) + == case.get('expect') + ) + + def test_sample(self): + sample_shape = (20000,) + samples = self._paddle_expon.sample(sample_shape) + sample_values = samples.numpy() + self.assertEqual(sample_values.dtype, self.rate.dtype) + + np.testing.assert_allclose( + sample_values.mean(axis=0), + scipy.stats.expon.mean(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self._paddle_expon.rate.numpy().dtype)), + ) + np.testing.assert_allclose( + sample_values.var(axis=0), + scipy.stats.expon.var(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self._paddle_expon.rate.numpy().dtype)), + ) + + def test_rsample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + + tuple(paddle.squeeze(self._paddle_expon.rate).shape), + }, + { + 'input': (2, 5), + 'expect': (2, 5) + + tuple(paddle.squeeze(self._paddle_expon.rate).shape), + }, + ] + for case in cases: + self.assertTrue( + tuple(self._paddle_expon.rsample(case.get('input')).shape) + == case.get('expect') + ) + + def test_rsample(self): + sample_shape = (20000,) + samples = self._paddle_expon.rsample(sample_shape) + sample_values = samples.numpy() + self.assertEqual(sample_values.dtype, self.rate.dtype) + + np.testing.assert_allclose( + sample_values.mean(axis=0), + scipy.stats.expon.mean(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self._paddle_expon.rate.numpy().dtype)), + ) + np.testing.assert_allclose( + sample_values.var(axis=0), + scipy.stats.expon.var(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self._paddle_expon.rate.numpy().dtype)), + ) + + def test_rsample_backpropagation(self): + sample_shape = (1000, 2) + with paddle.base.dygraph.guard(self.place): + self._paddle_expon.rate.stop_gradient = False + samples = self._paddle_expon.rsample(sample_shape) + grads = paddle.grad([samples], [self._paddle_expon.rate]) + self.assertEqual(len(grads), 1) + self.assertEqual(grads[0].dtype, self._paddle_expon.rate.dtype) + self.assertEqual(grads[0].shape, self._paddle_expon.rate.shape) + axis = list(range(len(sample_shape))) + np.testing.assert_allclose( + -samples.sum(axis) / self._paddle_expon.rate, + grads[0], + rtol=config.RTOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_expon.rate.numpy().dtype) + ), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate'), + [ + ('0-dim', 0.4), + ], +) +class TestExponentialSampleKS(unittest.TestCase): + def setUp(self): + rate = paddle.to_tensor(self.rate, dtype=paddle.float32) + self.scale = rate.reciprocal() + self._paddle_expon = exponential.Exponential(rate) + + def test_sample_ks(self): + sample_shape = (10000,) + samples = self._paddle_expon.sample(sample_shape) + self.assertTrue(self._kstest(samples)) + + def test_rsample_ks(self): + sample_shape = (10000,) + samples = self._paddle_expon.rsample(sample_shape) + self.assertTrue(self._kstest(samples)) + + def _kstest(self, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, scipy.stats.expon(scale=self.scale).cdf + ) + return ks < 0.02 + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate1', 'rate2'), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestExponentialKL(unittest.TestCase): + def setUp(self): + self._expon1 = exponential.Exponential(paddle.to_tensor(self.rate1)) + self._expon2 = exponential.Exponential(paddle.to_tensor(self.rate2)) + + def test_kl_divergence(self): + np.testing.assert_allclose( + kl.kl_divergence(self._expon1, self._expon2), + self._kl(), + rtol=config.RTOL.get(str(self._expon1.rate.numpy().dtype)), + atol=config.ATOL.get(str(self._expon1.rate.numpy().dtype)), + ) + + def test_kl1_error(self): + self.assertRaises( + TypeError, + self._expon1.kl_divergence, + paddle.distribution.beta.Beta, + ) + + def _kl(self): + rate_ratio = self.rate2 / self.rate1 + t1 = -np.log(rate_ratio) + return t1 + rate_ratio - 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_exponential_static.py b/test/distribution/test_distribution_exponential_static.py new file mode 100644 index 00000000000000..71de03b16b36d4 --- /dev/null +++ b/test/distribution/test_distribution_exponential_static.py @@ -0,0 +1,445 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import parameterize +import scipy.stats +from distribution import config + +import paddle +from paddle.distribution import exponential + +np.random.seed(2023) +paddle.seed(2023) + +paddle.enable_static() + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate'), + [ + ( + 'one-dim', + parameterize.xrand( + (4,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (10, 12), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestExponential(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + self.scale = 1 / self.rate + rate = paddle.static.data('rate', self.rate.shape, self.rate.dtype) + self._paddle_expon = exponential.Exponential(rate) + self.feeds = {'rate': self.rate} + + def test_mean(self): + with paddle.static.program_guard(self.program): + [mean] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=[self._paddle_expon.mean], + ) + np.testing.assert_allclose( + mean, + scipy.stats.expon.mean(scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_variance(self): + with paddle.static.program_guard(self.program): + [variance] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=[self._paddle_expon.variance], + ) + np.testing.assert_allclose( + variance, + scipy.stats.expon.var(scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_entropy(self): + with paddle.static.program_guard(self.program): + [entropy] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=[self._paddle_expon.entropy()], + ) + np.testing.assert_allclose( + entropy, + scipy.stats.expon.entropy(scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_prob(self): + with paddle.static.program_guard(self.program): + value = paddle.static.data( + 'value', + self._paddle_expon.rate.shape, + self._paddle_expon.rate.dtype, + ) + prob = self._paddle_expon.prob(value) + + random_number = np.random.rand( + *self._paddle_expon.rate.shape + ).astype(self.rate.dtype) + feeds = dict(self.feeds, value=random_number) + [prob] = self.executor.run( + self.program, feed=feeds, fetch_list=[prob] + ) + np.testing.assert_allclose( + prob, + scipy.stats.expon.pdf(random_number, scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_log_prob(self): + with paddle.static.program_guard(self.program): + value = paddle.static.data( + 'value', + self._paddle_expon.rate.shape, + self._paddle_expon.rate.dtype, + ) + log_prob = self._paddle_expon.log_prob(value) + + random_number = np.random.rand( + *self._paddle_expon.rate.shape + ).astype(self.rate.dtype) + feeds = dict(self.feeds, value=random_number) + [log_prob] = self.executor.run( + self.program, feed=feeds, fetch_list=[log_prob] + ) + np.testing.assert_allclose( + log_prob, + scipy.stats.expon.logpdf(random_number, scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_cdf(self): + with paddle.static.program_guard(self.program): + value = paddle.static.data( + 'value', + self._paddle_expon.rate.shape, + self._paddle_expon.rate.dtype, + ) + cdf = self._paddle_expon.cdf(value) + + random_number = np.random.rand( + *self._paddle_expon.rate.shape + ).astype(self.rate.dtype) + feeds = dict(self.feeds, value=random_number) + [cdf] = self.executor.run( + self.program, feed=feeds, fetch_list=[cdf] + ) + np.testing.assert_allclose( + cdf, + scipy.stats.expon.cdf(random_number, scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_icdf(self): + with paddle.static.program_guard(self.program): + value = paddle.static.data( + 'value', + self._paddle_expon.rate.shape, + self._paddle_expon.rate.dtype, + ) + icdf = self._paddle_expon.icdf(value) + + random_number = np.random.rand( + *self._paddle_expon.rate.shape + ).astype(self.rate.dtype) + feeds = dict(self.feeds, value=random_number) + [icdf] = self.executor.run( + self.program, feed=feeds, fetch_list=[icdf] + ) + np.testing.assert_allclose( + icdf, + scipy.stats.expon.ppf(random_number, scale=self.scale), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate'), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestExponentialSample(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + self.scale = 1 / self.rate + rate = paddle.static.data('rate', self.rate.shape, self.rate.dtype) + self._paddle_expon = exponential.Exponential(rate) + self.feeds = {'rate': self.rate} + + def test_sample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + np.squeeze(self.rate).shape, + }, + { + 'input': (4, 2), + 'expect': (4, 2) + np.squeeze(self.rate).shape, + }, + ] + for case in cases: + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_expon.sample(case.get('input')), + ) + + self.assertTrue(data.shape == case.get('expect')) + + def test_rsample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + np.squeeze(self.rate).shape, + }, + { + 'input': (3, 2), + 'expect': (3, 2) + np.squeeze(self.rate).shape, + }, + ] + for case in cases: + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_expon.rsample(case.get('input')), + ) + + self.assertTrue(data.shape == case.get('expect')) + + def test_sample(self): + sample_shape = (20000,) + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_expon.sample(sample_shape), + ) + except_shape = sample_shape + np.squeeze(self.rate).shape + self.assertTrue(data.shape == except_shape) + np.testing.assert_allclose( + data.mean(axis=0), + scipy.stats.expon.mean(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.rate.dtype)), + ) + np.testing.assert_allclose( + data.var(axis=0), + scipy.stats.expon.var(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + def test_rsample(self): + sample_shape = (20000,) + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_expon.rsample(sample_shape), + ) + except_shape = sample_shape + np.squeeze(self.rate).shape + self.assertTrue(data.shape == except_shape) + np.testing.assert_allclose( + data.mean(axis=0), + scipy.stats.expon.mean(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.rate.dtype)), + ) + np.testing.assert_allclose( + data.var(axis=0), + scipy.stats.expon.var(scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate'), + [ + ('0-dim', 0.4), + ], +) +class TestExponentialSampleKS(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + self.scale = 1 / self.rate + rate = paddle.static.data('rate', (), 'float') + self._paddle_expon = exponential.Exponential(rate) + self.feeds = {'rate': self.rate} + + def test_sample(self): + sample_shape = (10000,) + with paddle.static.program_guard(self.program): + [samples] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_expon.sample(sample_shape), + ) + self.assertTrue(self._kstest(samples)) + + def test_rsample(self): + sample_shape = (10000,) + with paddle.static.program_guard(self.program): + [samples] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_expon.rsample(sample_shape), + ) + self.assertTrue(self._kstest(samples)) + + def _kstest(self, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, scipy.stats.expon(scale=self.scale).cdf + ) + return ks < 0.02 + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate1', 'rate2'), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestExponentialKL(unittest.TestCase): + def setUp(self): + self.program1 = paddle.static.Program() + self.program2 = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program1, self.program2): + rate1 = paddle.static.data( + 'rate1', self.rate1.shape, self.rate1.dtype + ) + rate2 = paddle.static.data( + 'rate2', self.rate2.shape, self.rate2.dtype + ) + + self._expon1 = exponential.Exponential(rate1) + self._expon2 = exponential.Exponential(rate2) + + self.feeds = { + 'rate1': self.rate1, + 'rate2': self.rate2, + } + + def test_kl_divergence(self): + with paddle.static.program_guard(self.program1, self.program2): + self.executor.run(self.program2) + [kl] = self.executor.run( + self.program1, + feed=self.feeds, + fetch_list=[self._expon1.kl_divergence(self._expon2)], + ) + np.testing.assert_allclose( + kl, + self._kl(), + rtol=config.RTOL.get(str(self.rate1.dtype)), + atol=config.ATOL.get(str(self.rate1.dtype)), + ) + + def test_kl1_error(self): + self.assertRaises( + TypeError, + self._expon1.kl_divergence, + paddle.distribution.beta.Beta, + ) + + def _kl(self): + rate_ratio = self.rate2 / self.rate1 + t1 = -np.log(rate_ratio) + return t1 + rate_ratio - 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_gamma.py b/test/distribution/test_distribution_gamma.py new file mode 100644 index 00000000000000..630d7ec0c608bf --- /dev/null +++ b/test/distribution/test_distribution_gamma.py @@ -0,0 +1,463 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +import unittest + +import numpy as np +import parameterize +import scipy.stats +from distribution import config + +import paddle +from paddle.distribution import gamma, kl + +np.random.seed(2023) +paddle.seed(2023) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'concentration', 'rate'), + [ + ( + '0-dim', + 0.5, + 0.5, + ), + ( + 'one-dim', + parameterize.xrand( + (6,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (6,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (10, 12), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (10, 12), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'broadcast', + parameterize.xrand( + (4, 1), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (4, 6), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestGamma(unittest.TestCase): + def setUp(self): + concentration = self.concentration + if not isinstance(self.concentration, numbers.Real): + concentration = paddle.to_tensor(self.concentration) + + rate = self.rate + if not isinstance(self.rate, numbers.Real): + rate = paddle.to_tensor(self.rate) + + self.scale = 1 / rate + self._paddle_gamma = gamma.Gamma(concentration, rate) + + def test_mean(self): + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_gamma.mean, + scipy.stats.gamma.mean(self.concentration, scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + def test_variance(self): + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_gamma.variance, + scipy.stats.gamma.var(self.concentration, scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + def test_prob(self): + value = np.random.rand(*self._paddle_gamma.rate.shape) + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_gamma.prob(paddle.to_tensor(value)), + scipy.stats.gamma.pdf( + value, self.concentration, scale=self.scale + ), + rtol=config.RTOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + def test_log_prob(self): + value = np.random.rand(*self._paddle_gamma.rate.shape) + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_gamma.log_prob(paddle.to_tensor(value)), + scipy.stats.gamma.logpdf( + value, self.concentration, scale=self.scale + ), + rtol=config.RTOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + def test_entropy(self): + with paddle.base.dygraph.guard(self.place): + np.testing.assert_allclose( + self._paddle_gamma.entropy(), + scipy.stats.gamma.entropy(self.concentration, scale=self.scale), + rtol=config.RTOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'concentration', 'rate'), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestGammaSample(unittest.TestCase): + def setUp(self): + concentration = self.concentration + if not isinstance(self.concentration, numbers.Real): + concentration = paddle.to_tensor(self.concentration) + + rate = self.rate + if not isinstance(self.rate, numbers.Real): + rate = paddle.to_tensor(self.rate) + + self.scale = 1 / rate + self._paddle_gamma = gamma.Gamma(concentration, rate) + + def test_sample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + + tuple(paddle.squeeze(self._paddle_gamma.rate).shape), + }, + { + 'input': (3, 2), + 'expect': (3, 2) + + tuple(paddle.squeeze(self._paddle_gamma.rate).shape), + }, + ] + for case in cases: + self.assertTrue( + tuple(self._paddle_gamma.sample(case.get('input')).shape) + == case.get('expect') + ) + + def test_rsample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + + tuple(paddle.squeeze(self._paddle_gamma.rate).shape), + }, + { + 'input': (3, 2), + 'expect': (3, 2) + + tuple(paddle.squeeze(self._paddle_gamma.rate).shape), + }, + ] + for case in cases: + self.assertTrue( + tuple(self._paddle_gamma.rsample(case.get('input')).shape) + == case.get('expect') + ) + + def test_sample(self): + sample_shape = (30000,) + samples = self._paddle_gamma.sample(sample_shape) + sample_values = samples.numpy() + self.assertEqual(sample_values.dtype, self.rate.dtype) + + np.testing.assert_allclose( + sample_values.mean(axis=0), + scipy.stats.gamma.mean(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + np.testing.assert_allclose( + sample_values.var(axis=0), + scipy.stats.gamma.var(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + def test_rsample(self): + sample_shape = (30000,) + samples = self._paddle_gamma.rsample(sample_shape) + sample_values = samples.numpy() + self.assertEqual(sample_values.dtype, self.rate.dtype) + + np.testing.assert_allclose( + sample_values.mean(axis=0), + scipy.stats.gamma.mean(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + np.testing.assert_allclose( + sample_values.var(axis=0), + scipy.stats.gamma.var(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get( + str(self._paddle_gamma.concentration.numpy().dtype) + ), + ) + + @unittest.skip("TODO: implement standard_gamma grad op.") + def test_rsample_backpropagation(self): + sample_shape = (1000,) + with paddle.base.dygraph.guard(self.place): + self._paddle_gamma.concentration.stop_gradient = False + self._paddle_gamma.rate.stop_gradient = False + samples = self._paddle_gamma.rsample(sample_shape) + grads = paddle.grad( + [samples], + [self._paddle_gamma.concentration, self._paddle_gamma.rate], + ) + self.assertEqual(len(grads), 2) + self.assertEqual( + grads[0].dtype, self._paddle_gamma.concentration.dtype + ) + self.assertEqual( + grads[0].shape, self._paddle_gamma.concentration.shape + ) + self.assertEqual(grads[1].dtype, self._paddle_gamma.rate.dtype) + self.assertEqual(grads[1].shape, self._paddle_gamma.rate.shape) + + samples.backward() + self.assertEqual( + list(self._paddle_gamma.concentration.gradient().shape), + self._paddle_gamma.concentration.shape, + ) + self.assertEqual( + list(self._paddle_gamma.rate.gradient().shape), + self._paddle_gamma.rate.shape, + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'concentration', 'rate'), + [ + ('0-dim', 0.4, 0.5), + ], +) +class TestGammaSampleKS(unittest.TestCase): + def setUp(self): + concentration = self.concentration + if not isinstance(self.concentration, numbers.Real): + concentration = paddle.to_tensor(self.concentration) + + rate = self.rate + if not isinstance(self.rate, numbers.Real): + rate = paddle.to_tensor(self.rate) + + self.scale = 1 / rate + self._paddle_gamma = gamma.Gamma(concentration, rate) + + def test_sample_ks(self): + sample_shape = (15000,) + samples = self._paddle_gamma.sample(sample_shape) + self.assertTrue(self._kstest(samples)) + + def test_rsample_ks(self): + sample_shape = (15000,) + samples = self._paddle_gamma.rsample(sample_shape) + self.assertTrue(self._kstest(samples)) + + def _kstest(self, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, scipy.stats.gamma(self.concentration, scale=self.scale).cdf + ) + return ks < 0.02 + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + ( + parameterize.TEST_CASE_NAME, + 'concentration1', + 'rate1', + 'concentration2', + 'rate2', + ), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestGammaKL(unittest.TestCase): + def setUp(self): + self._gamma1 = gamma.Gamma( + paddle.to_tensor(self.concentration1), paddle.to_tensor(self.rate1) + ) + self._gamma2 = gamma.Gamma( + paddle.to_tensor(self.concentration2), paddle.to_tensor(self.rate2) + ) + + def test_kl_divergence(self): + np.testing.assert_allclose( + kl.kl_divergence(self._gamma1, self._gamma2), + self._kl(), + rtol=config.RTOL.get(str(self._gamma1.concentration.numpy().dtype)), + atol=config.ATOL.get(str(self._gamma1.concentration.numpy().dtype)), + ) + + def test_kl1_error(self): + self.assertRaises( + TypeError, + self._gamma1.kl_divergence, + paddle.distribution.beta.Beta, + ) + + def _kl(self): + concentration1 = self.concentration1 + concentration2 = self.concentration2 + rate1 = self.rate1 + rate2 = self.rate2 + t1 = concentration2 * np.log(rate1 / rate2) + t2 = scipy.special.gammaln(concentration2) - scipy.special.gammaln( + concentration1 + ) + t3 = (concentration1 - concentration2) * scipy.special.digamma( + concentration1 + ) + t4 = (rate2 - rate1) * (concentration1 / rate1) + return t1 + t2 + t3 + t4 + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_gamma_static.py b/test/distribution/test_distribution_gamma_static.py new file mode 100644 index 00000000000000..06bdf335529226 --- /dev/null +++ b/test/distribution/test_distribution_gamma_static.py @@ -0,0 +1,509 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterize +import scipy.stats +from distribution import config + +import paddle +from paddle.distribution import gamma + +np.random.seed(2023) +paddle.seed(2023) + +paddle.enable_static() + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'concentration', 'rate'), + [ + ( + 'one-dim', + parameterize.xrand( + (6,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (6,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (10, 12), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (10, 12), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'broadcast', + parameterize.xrand( + (4, 1), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (4, 6), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestGamma(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + self.scale = 1 / self.rate + concentration = paddle.static.data( + 'concentration', + self.concentration.shape, + self.concentration.dtype, + ) + rate = paddle.static.data('rate', self.rate.shape, self.rate.dtype) + self._paddle_gamma = gamma.Gamma(concentration, rate) + self.feeds = { + 'concentration': self.concentration, + 'rate': self.rate, + } + + def test_mean(self): + with paddle.static.program_guard(self.program): + [mean] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=[self._paddle_gamma.mean], + ) + np.testing.assert_allclose( + mean, + scipy.stats.gamma.mean(self.concentration, scale=self.scale), + rtol=config.RTOL.get(str(self.concentration.dtype)), + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + def test_variance(self): + with paddle.static.program_guard(self.program): + [variance] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=[self._paddle_gamma.variance], + ) + np.testing.assert_allclose( + variance, + scipy.stats.gamma.var(self.concentration, scale=self.scale), + rtol=config.RTOL.get(str(self.concentration.dtype)), + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + def test_entropy(self): + with paddle.static.program_guard(self.program): + [entropy] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=[self._paddle_gamma.entropy()], + ) + np.testing.assert_allclose( + entropy, + scipy.stats.gamma.entropy(self.concentration, scale=self.scale), + rtol=config.RTOL.get(str(self.concentration.dtype)), + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + def test_prob(self): + with paddle.static.program_guard(self.program): + value = paddle.static.data( + 'value', + self._paddle_gamma.concentration.shape, + self._paddle_gamma.concentration.dtype, + ) + prob = self._paddle_gamma.prob(value) + + random_number = np.random.rand( + *self._paddle_gamma.concentration.shape + ).astype(self.concentration.dtype) + feeds = dict(self.feeds, value=random_number) + [prob] = self.executor.run( + self.program, feed=feeds, fetch_list=[prob] + ) + np.testing.assert_allclose( + prob, + scipy.stats.gamma.pdf( + random_number, self.concentration, scale=self.scale + ), + rtol=config.RTOL.get(str(self.concentration.dtype)), + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + def test_log_prob(self): + with paddle.static.program_guard(self.program): + value = paddle.static.data( + 'value', + self._paddle_gamma.concentration.shape, + self._paddle_gamma.concentration.dtype, + ) + log_prob = self._paddle_gamma.log_prob(value) + + random_number = np.random.rand( + *self._paddle_gamma.concentration.shape + ).astype(self.concentration.dtype) + feeds = dict(self.feeds, value=random_number) + [log_prob] = self.executor.run( + self.program, feed=feeds, fetch_list=[log_prob] + ) + np.testing.assert_allclose( + log_prob, + scipy.stats.gamma.logpdf( + random_number, self.concentration, scale=self.scale + ), + rtol=config.RTOL.get(str(self.concentration.dtype)), + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'concentration', 'rate'), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestGammaSample(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + self.scale = 1 / self.rate + concentration = paddle.static.data( + 'concentration', + self.concentration.shape, + self.concentration.dtype, + ) + rate = paddle.static.data('rate', self.rate.shape, self.rate.dtype) + self._paddle_gamma = gamma.Gamma(concentration, rate) + self.feeds = { + 'concentration': self.concentration, + 'rate': self.rate, + } + + def test_sample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + np.squeeze(self.rate).shape, + }, + { + 'input': (4, 2), + 'expect': (4, 2) + np.squeeze(self.rate).shape, + }, + ] + for case in cases: + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_gamma.sample(case.get('input')), + ) + + self.assertTrue(data.shape == case.get('expect')) + + def test_rsample_shape(self): + cases = [ + { + 'input': (), + 'expect': () + np.squeeze(self.rate).shape, + }, + { + 'input': (3, 2), + 'expect': (3, 2) + np.squeeze(self.rate).shape, + }, + ] + for case in cases: + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_gamma.rsample(case.get('input')), + ) + + self.assertTrue(data.shape == case.get('expect')) + + def test_sample(self): + sample_shape = (30000,) + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_gamma.sample(sample_shape), + ) + except_shape = sample_shape + np.squeeze(self.rate).shape + self.assertTrue(data.shape == except_shape) + np.testing.assert_allclose( + data.mean(axis=0), + scipy.stats.gamma.mean(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + np.testing.assert_allclose( + data.var(axis=0), + scipy.stats.gamma.var(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + def test_rsample(self): + sample_shape = (30000,) + with paddle.static.program_guard(self.program): + [data] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_gamma.rsample(sample_shape), + ) + except_shape = sample_shape + np.squeeze(self.rate).shape + self.assertTrue(data.shape == except_shape) + np.testing.assert_allclose( + data.mean(axis=0), + scipy.stats.gamma.mean(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + np.testing.assert_allclose( + data.var(axis=0), + scipy.stats.gamma.var(self.concentration, scale=self.scale), + rtol=0.1, + atol=config.ATOL.get(str(self.concentration.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'concentration', 'rate'), + [ + ('0-dim', 0.4, 0.5), + ], +) +class TestGammaSampleKS(unittest.TestCase): + def setUp(self): + self.program = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program): + self.scale = 1 / self.rate + concentration = paddle.static.data( + 'concentration', + (), + 'float', + ) + rate = paddle.static.data('rate', (), 'float') + self._paddle_gamma = gamma.Gamma(concentration, rate) + self.feeds = { + 'concentration': self.concentration, + 'rate': self.rate, + } + + def test_sample(self): + sample_shape = (15000,) + with paddle.static.program_guard(self.program): + [samples] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_gamma.sample(sample_shape), + ) + self.assertTrue(self._kstest(samples)) + + def test_rsample(self): + sample_shape = (15000,) + with paddle.static.program_guard(self.program): + [samples] = self.executor.run( + self.program, + feed=self.feeds, + fetch_list=self._paddle_gamma.rsample(sample_shape), + ) + self.assertTrue(self._kstest(samples)) + + def _kstest(self, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, scipy.stats.gamma(self.concentration, scale=self.scale).cdf + ) + return ks < 0.02 + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + ( + parameterize.TEST_CASE_NAME, + 'concentration1', + 'rate1', + 'concentration2', + 'rate2', + ), + [ + ( + 'one-dim', + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2,), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ( + 'multi-dim', + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + parameterize.xrand( + (2, 3), + dtype='float32', + min=np.finfo(dtype='float32').tiny, + ), + ), + ], +) +class TestGammaKL(unittest.TestCase): + def setUp(self): + self.program1 = paddle.static.Program() + self.program2 = paddle.static.Program() + self.executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(self.program1, self.program2): + concentration1 = paddle.static.data( + 'concentration1', + self.concentration1.shape, + self.concentration1.dtype, + ) + concentration2 = paddle.static.data( + 'concentration2', + self.concentration2.shape, + self.concentration2.dtype, + ) + rate1 = paddle.static.data( + 'rate1', self.rate1.shape, self.rate1.dtype + ) + rate2 = paddle.static.data( + 'rate2', self.rate2.shape, self.rate2.dtype + ) + + self._gamma1 = gamma.Gamma(concentration1, rate1) + self._gamma2 = gamma.Gamma(concentration2, rate2) + + self.feeds = { + 'concentration1': self.concentration1, + 'concentration2': self.concentration2, + 'rate1': self.rate1, + 'rate2': self.rate2, + } + + def test_kl_divergence(self): + with paddle.static.program_guard(self.program1, self.program2): + self.executor.run(self.program2) + [kl] = self.executor.run( + self.program1, + feed=self.feeds, + fetch_list=[self._gamma1.kl_divergence(self._gamma2)], + ) + np.testing.assert_allclose( + kl, + self._kl(), + rtol=config.RTOL.get(str(self.concentration1.dtype)), + atol=config.ATOL.get(str(self.concentration1.dtype)), + ) + + def test_kl1_error(self): + self.assertRaises( + TypeError, + self._gamma1.kl_divergence, + paddle.distribution.beta.Beta, + ) + + def _kl(self): + concentration1 = self.concentration1 + concentration2 = self.concentration2 + rate1 = self.rate1 + rate2 = self.rate2 + t1 = concentration2 * np.log(rate1 / rate2) + t2 = scipy.special.gammaln(concentration2) - scipy.special.gammaln( + concentration1 + ) + t3 = (concentration1 - concentration2) * scipy.special.digamma( + concentration1 + ) + t4 = (rate2 - rate1) * (concentration1 / rate1) + return t1 + t2 + t3 + t4 + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_geometric_static.py b/test/distribution/test_distribution_geometric_static.py index c56d9029d617bc..0a5aa67ef3d788 100644 --- a/test/distribution/test_distribution_geometric_static.py +++ b/test/distribution/test_distribution_geometric_static.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import unittest import numpy as np @@ -113,7 +114,7 @@ def test_sample(self): fetch_list=self._paddle_geometric.sample(), ) self.assertTrue( - data.shape, np.broadcast_arrays(self.probs)[0].shape + data.shape == np.broadcast_arrays(self.probs)[0].shape ) def test_rsample(self): @@ -124,7 +125,7 @@ def test_rsample(self): fetch_list=self._paddle_geometric.rsample(), ) self.assertTrue( - data.shape, np.broadcast_arrays(self.probs)[0].shape + data.shape == np.broadcast_arrays(self.probs)[0].shape ) def test_entropy(self): From 9c24a2ac84ca376f65b20d4d24e735d6d2e4ea8a Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 4 Jan 2024 16:16:12 +0800 Subject: [PATCH 107/142] [CINN] Add IntrinsicOps into ir_codes_collector (#60556) This PR fixed a bug of running Resnet PaddleClas. The bug is due to vectorize introduce an intrinsic GetAddr and we didn't collect the tensor of GetAddr in ir_node_collector, this would caused tensor alias won't create in cuda code. TODO: we may modify IntrinsicOp in the near future --- paddle/cinn/ir/ir_base.h | 9 ++- paddle/cinn/ir/utils/ir_nodes_collector.cc | 67 +++++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h index c333448d029ae0..0047100ebcfdfc 100644 --- a/paddle/cinn/ir/ir_base.h +++ b/paddle/cinn/ir/ir_base.h @@ -110,16 +110,23 @@ class Dim; macro__(Product) \ macro__(Sum) \ macro__(PrimitiveNode) \ - macro__(IntrinsicOp) \ macro__(_BufferRange_) \ macro__(ScheduleBlock) \ macro__(ScheduleBlockRealize) \ macro__(_Dim_) \ +#define NODETY_CONTROL_OP_FOR_INTRINSIC(macro__) \ + macro__(IntrinsicOp) \ #define NODETY_FORALL(__m) \ NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ NODETY_OP_FOR_EACH(__m) \ + NODETY_CONTROL_OP_FOR_INTRINSIC(__m) \ + NODETY_CONTROL_OP_FOR_EACH(__m) + +#define NODETY_FORALL_EXCEPT_INTRINSIC(__m) \ + NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ + NODETY_OP_FOR_EACH(__m) \ NODETY_CONTROL_OP_FOR_EACH(__m) // clang-format on diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index ac2f0317e9213f..e4ebaca653bae9 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -15,6 +15,8 @@ #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include +#include "paddle/cinn/ir/intrinsic_ops.h" +#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" @@ -71,8 +73,71 @@ struct IrNodesCollector : public IRVisitorRequireReImpl { } \ } - NODETY_FORALL(__m) + NODETY_FORALL_EXCEPT_INTRINSIC(__m) #undef __m + + void Visit(const ir::IntrinsicOp* op) { + switch (op->getKind()) { +#define __(x) \ + case ir::IntrinsicKind::k##x: \ + Visit(llvm::dyn_cast(op)); \ + break; + + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + } + } + + void Visit(const ir::intrinsics::GetAddr* x) { + if (x->data.defined()) { + Visit(&(x->data)); + } + } + + void Visit(const ir::intrinsics::BufferGetDataHandle* x) { + if (x->buffer.defined()) { + Visit(&(x->buffer)); + } + } + + void Visit(const ir::intrinsics::BufferGetDataConstHandle* x) { + if (x->buffer.defined()) { + Visit(&(x->buffer)); + } + } + + void Visit(const ir::intrinsics::PodValueToX* x) { + if (x->pod_value_ptr.defined()) { + Visit(&(x->pod_value_ptr)); + } + } + + void Visit(const ir::intrinsics::BufferCreate* x) { + if (x->buffer.defined()) { + Visit(&(x->buffer)); + } + } + + void Visit(const ir::intrinsics::ArgsConstruct* x) { + if (x->var.defined()) { + Expr convert = Expr(x->var); + Visit(&convert); + } + for (int i = 0; i < x->args.size(); ++i) { + if (x->args[i].defined()) { + Visit(&(x->args[i])); + } + } + } + + void Visit(const ir::intrinsics::BuiltinIntrin* x) { + for (int i = 0; i < x->args.size(); ++i) { + if (x->args[i].defined()) { + Visit(&(x->args[i])); + } + } + } + std::set visited_; }; From 8ed3d1859749e2ee2281f91fc43b1b57cd2b0b3f Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Thu, 4 Jan 2024 16:42:15 +0800 Subject: [PATCH 108/142] =?UTF-8?q?=E3=80=90auto=20parallel=E3=80=91custom?= =?UTF-8?q?=20op=20=20spmd=20rule=20register=20=20(#60509)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * custom op spmd rule register * custom op spmd rule register * custom op spmd rule register * custom op spmd rule register * polish --- cmake/inference_lib.cmake | 7 + paddle/phi/api/ext/op_meta_info.h | 16 + paddle/phi/api/ext/spmd_infer.h | 140 ++++ paddle/phi/api/lib/op_meta_info.cc | 15 + paddle/phi/core/distributed/type_defs.h | 1 + paddle/phi/infermeta/spmd_rules/rules.cc | 605 ++++++++++++++++++ paddle/phi/infermeta/spmd_rules/rules.h | 590 ----------------- test/cpp/auto_parallel/CMakeLists.txt | 24 +- .../auto_parallel/custom_op_spmd_rule_test.cc | 89 +++ tools/gpups_test.sh | 2 +- 10 files changed, 892 insertions(+), 597 deletions(-) create mode 100644 paddle/phi/api/ext/spmd_infer.h create mode 100644 paddle/phi/infermeta/spmd_rules/rules.cc create mode 100644 test/cpp/auto_parallel/custom_op_spmd_rule_test.cc diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index f44e23e6da74e8..d0a055d0f2e64c 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -328,6 +328,13 @@ copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/visit_type.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/) + +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/type_defs.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/ +) + copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index e5273958504fd1..2b73e28b448581 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/common/exception.h" #include "paddle/phi/api/include/dll_decl.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/utils/any.h" #include "paddle/utils/none.h" #include "paddle/utils/optional.h" @@ -996,6 +997,15 @@ struct TrtGetOutputDimsFuncImpl { ////////////////////// Op Meta Info ////////////////////// +using CustomSpmdInferTensorArg = + paddle::variant>; +using CustomSpmdInferAttrArg = paddle::any; + +using InferSpmdFunc = phi::distributed::SpmdInfo (*)( + const std::vector& inputs, + const std::vector& attrs); + class PADDLE_API OpMetaInfo { public: explicit OpMetaInfo(const std::string& op_name) : name_(op_name) {} @@ -1023,6 +1033,9 @@ class PADDLE_API OpMetaInfo { // format: PD_INFER_DTYPE(...) OpMetaInfo& SetInferDtypeFn(InferDtypeFunc&& func); + // format: PD_INFER_SPMD_RULE(...) + OpMetaInfo& SetInferSpmdFn(InferSpmdFunc&& func); + #ifdef PADDLE_WITH_TENSORRT // format: PD_TRT_INFER_SHAPE(...) OpMetaInfo& SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func); @@ -1045,6 +1058,7 @@ class PADDLE_API OpMetaInfo { KernelFunc kernel_fn_{nullptr}; InferShapeFunc infer_shape_fn_{nullptr}; InferDtypeFunc infer_dtype_fn_{nullptr}; + InferSpmdFunc infer_spmd_fn_{nullptr}; #ifdef PADDLE_WITH_TENSORRT TrtGetOutputDimsFunc trt_infer_shape_fn_{nullptr}; std::vector trt_supports_format_config_; @@ -1068,6 +1082,7 @@ class OpMetaInfoHelper { static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info); static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info); static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info); + static const InferSpmdFunc& GetInferSpmdFn(const paddle::OpMetaInfo& info); #ifdef PADDLE_WITH_TENSORRT static const TrtGetOutputDimsFunc& GetTrtInferShapeFn( @@ -1108,6 +1123,7 @@ class PADDLE_API OpMetaInfoBuilder { OpMetaInfoBuilder& SetKernelFn(KernelFunc func); OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func); OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func); + OpMetaInfoBuilder& SetInferSpmdFn(InferSpmdFunc func); #ifdef PADDLE_WITH_TENSORRT OpMetaInfoBuilder& SetTrtInferShapeFn(TrtGetOutputDimsFunc func); diff --git a/paddle/phi/api/ext/spmd_infer.h b/paddle/phi/api/ext/spmd_infer.h new file mode 100644 index 00000000000000..df4d177054a9a8 --- /dev/null +++ b/paddle/phi/api/ext/spmd_infer.h @@ -0,0 +1,140 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace paddle { + +using CustomSpmdInferTensorArg = + paddle::variant>; + +using CustomSpmdInferAttrArg = paddle::any; +template +struct SpmdInferHelperTypeEnd {}; + +#define PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(attr_type) \ + template \ + struct SpmdInferHelper { \ + template \ + static phi::distributed::SpmdInfo InferSpmd( \ + const std::vector& inputs, \ + const std::vector& attrs, \ + const PreviousArgs&... pargs) { \ + try { \ + attr_type arg = paddle::any_cast(attrs[attr_idx]); \ + return SpmdInferHelper::template InferSpmd( \ + inputs, attrs, pargs..., arg); \ + } catch (paddle::bad_any_cast&) { \ + PD_THROW( \ + "Attribute cast error in custom operator SpmdInferFunc " \ + "function. " \ + "Expected " #attr_type \ + " value. SpmdInferFunc's attribute list must be exactly " \ + "same " \ + "as " \ + "Forward " \ + "KernelFn's attribute list except std::vector " \ + "attribute."); \ + } \ + } \ + } + +template +struct SpmdInferImpl; + +template +struct SpmdInferImpl { + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs) { + return SpmdInferHelper>:: + template InferSpmd<0, 0>(inputs, attrs); + } + + private: + template + struct SpmdInferHelper; + + // Handle args for general tensor input case + template + struct SpmdInferHelper { + template + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs, + PreviousArgs&... pargs) { + auto& arg = + PADDLE_GET_CONST(phi::distributed::DistMetaTensor, inputs[in_idx]); + return SpmdInferHelper::template InferSpmd( + inputs, attrs, pargs..., arg); + } + }; + + // Handle args for vector of Tensor input case + template + struct SpmdInferHelper&, + Tail...> { + template + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs, + PreviousArgs&... pargs) { + auto& arg = PADDLE_GET_CONST( + std::vector, inputs[in_idx]); + return SpmdInferHelper::template InferSpmd( + inputs, attrs, pargs..., arg); + } + }; + + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(bool); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(float); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(int64_t); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::string&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const bool&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const float&); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const int64_t&); + + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::string); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(std::vector); + PD_SPECIALIZE_SpmdInferHelper_FOR_AttrType(const std::vector); + + // end: base template + template + struct SpmdInferHelper> { + template + static phi::distributed::SpmdInfo InferSpmd( + const std::vector& inputs, + const std::vector& attrs, + PreviousArgs&... pargs) { + return impl_fn(pargs...); + } + }; +}; + +#define PD_INFER_SPMD_RULE(...) \ + ::paddle::SpmdInferImpl::InferSpmd + +} // namespace paddle diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 857c2930da45f9..3cef3187193f72 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -358,6 +358,11 @@ OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) { return *this; } +OpMetaInfo& OpMetaInfo::SetInferSpmdFn(InferSpmdFunc&& func) { + infer_spmd_fn_ = std::forward(func); + return *this; +} + #ifdef PADDLE_WITH_TENSORRT OpMetaInfo& OpMetaInfo::SetTrtInferShapeFn(TrtGetOutputDimsFunc&& func) { trt_infer_shape_fn_ = std::forward(func); @@ -407,6 +412,11 @@ const InferDtypeFunc& OpMetaInfoHelper::GetInferDtypeFn( return info.infer_dtype_fn_; } +const InferSpmdFunc& OpMetaInfoHelper::GetInferSpmdFn( + const paddle::OpMetaInfo& info) { + return info.infer_spmd_fn_; +} + #ifdef PADDLE_WITH_TENSORRT const TrtGetOutputDimsFunc& OpMetaInfoHelper::GetTrtInferShapeFn( const paddle::OpMetaInfo& info) { @@ -559,6 +569,11 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) { return *this; } +OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferSpmdFn(InferSpmdFunc func) { + info_ptr_->SetInferSpmdFn(std::forward(func)); + return *this; +} + #ifdef PADDLE_WITH_TENSORRT OpMetaInfoBuilder& OpMetaInfoBuilder::SetTrtInferShapeFn( TrtGetOutputDimsFunc func) { diff --git a/paddle/phi/core/distributed/type_defs.h b/paddle/phi/core/distributed/type_defs.h index 1b7035c1a45287..a629fccbf9fbb9 100644 --- a/paddle/phi/core/distributed/type_defs.h +++ b/paddle/phi/core/distributed/type_defs.h @@ -23,6 +23,7 @@ namespace phi { namespace distributed { class TensorDistAttr; +class DistMetaTensor; using ArgDistAttr = paddle::variant>; diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc new file mode 100644 index 00000000000000..cef950dfd2d81f --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/rules.cc @@ -0,0 +1,605 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/rules.h" + +/** + * Design Notes: + * + * 1. SPMD info is the special meta info of DistTensor, so we put Spmd infer + * functions in `infermeta` directory. + * + * 2. Since the infer functions of Spmd forward and backward are closely related + * and need to be registered together, we manage them together in one file. + * + * 3. SPMD rules are much smaller than infermeta function, and we manage files + * in operator units. + * + * 4. The previous registration used some compile-time regular matching methods, + * which was less flexible, and the registration of SPMD rules here is declare + * directly in the header file + */ + +namespace phi { +namespace distributed { + +// matmul rule +PD_REGISTER_SPMD_RULE(matmul, + PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), + PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); +PD_REGISTER_SPMD_RULE(matmul_v2, // static mode + PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), + PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + elementwise_unary, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + elementwise_binary, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + +// default data parallel rule +PD_REGISTER_SPMD_RULE( + default_data_parallel, + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + default_, + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), + PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); + +// fused rope +PD_REGISTER_SPMD_RULE( + fused_rotary_position_embedding, + PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd), + PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse)); + +// replicated rule /* for unittest */ +PD_REGISTER_SPMD_RULE( + replicated, + PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), + PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); + +// unsqueeze rule +PD_REGISTER_SPMD_RULE( + unsqueeze, + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + unsqueeze2, + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); + +// elementwise unary rule +PD_REGISTER_SPMD_RULE( + assign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardswish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + mish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + relu6, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + swish, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + acos, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + acosh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + asin, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + asinh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + atan, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + atanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bernoulli, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_not, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + ceil, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + celu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + clip, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + conj, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + cos, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + cosh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + digamma, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + erf, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + erfinv, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + exp, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + expm1, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fill, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + floor, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + gelu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardshrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardsigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + hardtanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + label_smooth, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + leaky_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + lgamma, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log10, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log1p, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + log2, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_not, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logit, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logsigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + poisson, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + pow, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + reciprocal, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + round, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + rsqrt, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + scale, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + selu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sigmoid, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + silu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sin, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sinh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softplus, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softshrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + softsign, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sqrt, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + square, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + stanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tan, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tanh, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + tanh_shrink, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + thresholded_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + trunc, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + dropout, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); + +// elementwise binary rule +PD_REGISTER_SPMD_RULE( + add, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_add, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + divide, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_div, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_pow, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + floor_divide, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fmin, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + heaviside, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + maximum, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + minimum, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + multiply, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + elementwise_mul, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + remainder, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + subtract, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_and, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_or, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + bitwise_xor, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + fmax, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_and, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_or, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + logical_xor, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + not_equal, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + +// TODO(pkuzyc): add multiary elementwise rule + +// reduction rule +PD_REGISTER_SPMD_RULE( + all, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + amax, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + amin, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + any, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + frobenius_norm, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + max, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + reduce_max, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + min, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + prod, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + sum, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + reduce_sum, // static + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +// layer_norm +PD_REGISTER_SPMD_RULE( + layer_norm, + PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd), + PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + flash_attention, + PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic), + PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdReverse)); + +// reshape rule +PD_REGISTER_SPMD_RULE(reshape, + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); +PD_REGISTER_SPMD_RULE(reshape2, + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), + PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); + +// squeeze rule +PD_REGISTER_SPMD_RULE(squeeze, + PD_INFER_SPMD(phi::distributed::SqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::SqueezeInferSpmdReverse)); +// flatten rule +PD_REGISTER_SPMD_RULE(flatten, + PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), + PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse)); + +// embedding rule +PD_REGISTER_SPMD_RULE( + embedding, + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + lookup_table_v2, + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), + PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); + +// split rule +PD_REGISTER_SPMD_RULE(split, + PD_INFER_SPMD(phi::distributed::SplitInferSpmd), + PD_INFER_SPMD(phi::distributed::SplitInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + split_with_num, + PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), + PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); + +// slice rule +PD_REGISTER_SPMD_RULE(slice, + PD_INFER_SPMD(phi::distributed::SliceInferSpmd), + PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(concat, + PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), + PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); + +// transpose rule +PD_REGISTER_SPMD_RULE( + transpose, + PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), + PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + transpose2, + PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), + PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); + +// softmax rule +PD_REGISTER_SPMD_RULE(softmax, + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(log_softmax, + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(where, + PD_INFER_SPMD(phi::distributed::WhereInferSpmd), + PD_INFER_SPMD(phi::distributed::WhereInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(triu, + PD_INFER_SPMD(phi::distributed::TriuInferSpmd), + PD_INFER_SPMD(phi::distributed::TriuInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + tril_triu, + PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmd), + PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(tile, + PD_INFER_SPMD(phi::distributed::TileInferSpmd), + PD_INFER_SPMD(phi::distributed::TileInferSpmdReverse)); + +// cross_entropy_with_softmax +PD_REGISTER_SPMD_RULE( + cross_entropy_with_softmax, + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + softmax_with_cross_entropy, + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), + PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); + +// fused_linear_param_grad_add got no reverse infer spmd rule +PD_REGISTER_SPMD_RULE( + fused_linear_param_grad_add, + PD_INFER_SPMD(phi::distributed::FusedLinearParamGradAddInferSpmd), + PD_INFER_SPMD( + phi::distributed::FusedLinearParamGradAddInferSpmdFakeReverse)); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 1015f61802bc4a..37eab9f57ba73c 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -44,593 +44,3 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/triu.h" #include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" #include "paddle/phi/infermeta/spmd_rules/where.h" - -/** - * Design Notes: - * - * 1. SPMD info is the special meta info of DistTensor, so we put Spmd infer - * functions in `infermeta` directory. - * - * 2. Since the infer functions of Spmd forward and backward are closely related - * and need to be registered together, we manage them together in one file. - * - * 3. SPMD rules are much smaller than infermeta function, and we manage files - * in operator units. - * - * 4. The previous registration used some compile-time regular matching methods, - * which was less flexible, and the registration of SPMD rules here is declare - * directly in the header file - */ - -namespace phi { -namespace distributed { - -// matmul rule -PD_REGISTER_SPMD_RULE(matmul, - PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), - PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); -PD_REGISTER_SPMD_RULE(matmul_v2, // static mode - PD_INFER_SPMD(phi::distributed::MatmulInferSpmd), - PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - elementwise_unary, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - elementwise_binary, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - -// default data parallel rule -PD_REGISTER_SPMD_RULE( - default_data_parallel, - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - default_, - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), - PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); - -// fused rope -PD_REGISTER_SPMD_RULE( - fused_rotary_position_embedding, - PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd), - PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse)); - -// replicated rule /* for unittest */ -PD_REGISTER_SPMD_RULE( - replicated, - PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), - PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); - -// unsqueeze rule -PD_REGISTER_SPMD_RULE( - unsqueeze, - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - unsqueeze2, - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), - PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); - -// elementwise unary rule -PD_REGISTER_SPMD_RULE( - assign, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardswish, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - mish, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - relu6, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - swish, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - acos, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - acosh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - asin, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - asinh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - atan, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - atanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bernoulli, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_not, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - ceil, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - celu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - clip, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - conj, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - cos, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - cosh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - digamma, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - erf, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - erfinv, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - exp, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - expm1, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - fill, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - floor, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - gelu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardshrink, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardsigmoid, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - hardtanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - label_smooth, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - leaky_relu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - lgamma, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log10, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log1p, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - log2, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_not, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logit, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logsigmoid, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - poisson, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - pow, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - reciprocal, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - relu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - round, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - rsqrt, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - scale, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - selu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sigmoid, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sign, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - silu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sin, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sinh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - softplus, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - softshrink, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - softsign, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sqrt, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - square, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - stanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - tan, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - tanh, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - tanh_shrink, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - thresholded_relu, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - trunc, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - dropout, - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); - -// elementwise binary rule -PD_REGISTER_SPMD_RULE( - add, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_add, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - divide, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_div, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_pow, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - floor_divide, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - fmin, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - heaviside, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - maximum, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - minimum, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - multiply, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - elementwise_mul, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - remainder, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - subtract, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_and, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_or, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - bitwise_xor, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - fmax, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_and, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_or, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - logical_xor, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - not_equal, - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), - PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - -// TODO(pkuzyc): add multiary elementwise rule - -// reduction rule -PD_REGISTER_SPMD_RULE( - all, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - amax, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - amin, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - any, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - frobenius_norm, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - max, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - reduce_max, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - min, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - prod, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - sum, - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - reduce_sum, // static - PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), - PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); - -// layer_norm -PD_REGISTER_SPMD_RULE( - layer_norm, - PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd), - PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - flash_attention, - PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdStatic), - PD_INFER_SPMD(phi::distributed::FlashAttInferSpmdReverse)); - -// reshape rule -PD_REGISTER_SPMD_RULE(reshape, - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); -PD_REGISTER_SPMD_RULE(reshape2, - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), - PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); - -// squeeze rule -PD_REGISTER_SPMD_RULE(squeeze, - PD_INFER_SPMD(phi::distributed::SqueezeInferSpmd), - PD_INFER_SPMD(phi::distributed::SqueezeInferSpmdReverse)); -// flatten rule -PD_REGISTER_SPMD_RULE(flatten, - PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), - PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse)); - -// embedding rule -PD_REGISTER_SPMD_RULE( - embedding, - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - lookup_table_v2, - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), - PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); - -// split rule -PD_REGISTER_SPMD_RULE(split, - PD_INFER_SPMD(phi::distributed::SplitInferSpmd), - PD_INFER_SPMD(phi::distributed::SplitInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - split_with_num, - PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), - PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); - -// slice rule -PD_REGISTER_SPMD_RULE(slice, - PD_INFER_SPMD(phi::distributed::SliceInferSpmd), - PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(concat, - PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), - PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); - -// transpose rule -PD_REGISTER_SPMD_RULE( - transpose, - PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), - PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); -PD_REGISTER_SPMD_RULE( - transpose2, - PD_INFER_SPMD(phi::distributed::TransposeInferSpmd), - PD_INFER_SPMD(phi::distributed::TransposeInferSpmdReverse)); - -// softmax rule -PD_REGISTER_SPMD_RULE(softmax, - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(log_softmax, - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::SoftmaxInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(where, - PD_INFER_SPMD(phi::distributed::WhereInferSpmd), - PD_INFER_SPMD(phi::distributed::WhereInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(triu, - PD_INFER_SPMD(phi::distributed::TriuInferSpmd), - PD_INFER_SPMD(phi::distributed::TriuInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - tril_triu, - PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmd), - PD_INFER_SPMD(phi::distributed::TrilTriuInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE(tile, - PD_INFER_SPMD(phi::distributed::TileInferSpmd), - PD_INFER_SPMD(phi::distributed::TileInferSpmdReverse)); - -// cross_entropy_with_softmax -PD_REGISTER_SPMD_RULE( - cross_entropy_with_softmax, - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); - -PD_REGISTER_SPMD_RULE( - softmax_with_cross_entropy, - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd), - PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse)); - -// fused_linear_param_grad_add got no reverse infer spmd rule -PD_REGISTER_SPMD_RULE( - fused_linear_param_grad_add, - PD_INFER_SPMD(phi::distributed::FusedLinearParamGradAddInferSpmd), - PD_INFER_SPMD( - phi::distributed::FusedLinearParamGradAddInferSpmdFakeReverse)); - -} // namespace distributed -} // namespace phi diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt index 311958d2e10310..39a7cd28f1c6d8 100644 --- a/test/cpp/auto_parallel/CMakeLists.txt +++ b/test/cpp/auto_parallel/CMakeLists.txt @@ -15,20 +15,32 @@ if(WITH_DISTRIBUTE) SRCS dist_tensor_test.cc DEPS phi common) - paddle_test(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule_test_util) + paddle_test(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule_test_util + spmd_rules) paddle_test(softmax_grad_spmd_rule_test SRCS softmax_grad_spmd_rule_test.cc - DEPS spmd_rule_test_util) + DEPS spmd_rule_test_util spmd_rules) paddle_test(tile_spmd_rule_test SRCS tile_spmd_rule_test.cc DEPS - spmd_rule_test_util) + spmd_rule_test_util spmd_rules) paddle_test( fused_linear_param_grad_add_spmd_rule_test SRCS - fused_linear_param_grad_add_spmd_rule_test.cc DEPS spmd_rule_test_util) + fused_linear_param_grad_add_spmd_rule_test.cc DEPS spmd_rule_test_util + spmd_rules) - paddle_test(cross_entropy_softmax_spmd_rule_test SRCS - cross_entropy_softmax_spmd_rule_test.cc DEPS spmd_rule_test_util) + paddle_test( + cross_entropy_softmax_spmd_rule_test SRCS + cross_entropy_softmax_spmd_rule_test.cc DEPS spmd_rule_test_util spmd_rules) + + paddle_test( + custom_op_spmd_rule_test + SRCS + custom_op_spmd_rule_test.cc + DEPS + spmd_rule_test_util + spmd_rules + phi) endif() diff --git a/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc b/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc new file mode 100644 index 00000000000000..6e51634e1df492 --- /dev/null +++ b/test/cpp/auto_parallel/custom_op_spmd_rule_test.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/api/ext/spmd_infer.h" +#include "test/cpp/auto_parallel/spmd_rule_test_util.h" + +namespace paddle { +namespace distributed { +namespace auto_parallel { +TEST(CustomOp, Ctor) { + // test with concat rule + std::vector mesh_shape = {2, 2}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x", "y"}; + ProcessMesh process_mesh(mesh_shape, process_ids, dim_names); + + std::vector> shapes = { + {16, 16, 16}, {4, 16, 16}, {2, 16, 16}}; + std::vector> dim_mappings = { + {-1, 0, 1}, {-1, 1, 0}, {-1, -1, 0}}; + std::vector> partial_status = {{}, {}, {1}}; + + auto build_inputs = [&] { + std::vector inputs; + for (int i = 0; i < 3; i++) { + auto t_dist_attr = TensorDistAttr(); + t_dist_attr.set_process_mesh(process_mesh); + t_dist_attr.set_dims_mapping(dim_mappings[i]); + t_dist_attr.set_dynamic_dims({false, false, false}); + auto input = phi::distributed::DistMetaTensor( + common::make_ddim(shapes[i]), t_dist_attr); + inputs.push_back(input); + } + return inputs; + }; + + // test 1, inputs are aligned according to cost, and partial status is cleared + auto inputs = build_inputs(); + + auto forward_spmd_func = + PD_INFER_SPMD_RULE(phi::distributed::ConcatInferSpmd); + int axis = 0; + std::vector infer_inputs = {inputs}; + std::vector attrs = {axis}; + + auto infered_dist_attrs = forward_spmd_func(infer_inputs, attrs); + // list of tensor => sigle tensor + EXPECT_EQ(infered_dist_attrs.first.size(), static_cast(1)); + EXPECT_EQ(infered_dist_attrs.second.size(), static_cast(1)); + EXPECT_TRUE( + paddle::holds_alternative>( + infered_dist_attrs.first[0])); + EXPECT_TRUE(paddle::holds_alternative( + infered_dist_attrs.second[0])); + auto& inputs_infer1 = + PADDLE_GET_CONST(std::vector, + infered_dist_attrs.first[0]); + + for (auto e : inputs_infer1) { + check_dim_mapping(e, {-1, 1, 0}); + check_partial_dims(e, {}); + } + check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0}); + check_partial_dims(infered_dist_attrs.second[0], {}); +} + +TEST(CustomOp, Register) { + OpMetaInfoBuilder builder("test_custom_op_smpd", 0); + auto iter = OpMetaInfoMap::Instance().GetMap().find("test_custom_op_smpd"); + EXPECT_TRUE(iter != OpMetaInfoMap::Instance().GetMap().end()); + EXPECT_TRUE(OpMetaInfoHelper::GetInferSpmdFn(iter->second[0]) == nullptr); + builder.SetInferSpmdFn(PD_INFER_SPMD_RULE(phi::distributed::ConcatInferSpmd)); + EXPECT_TRUE(OpMetaInfoHelper::GetInferSpmdFn(iter->second[0]) != nullptr); +} +} // namespace auto_parallel +} // namespace distributed +} // namespace paddle diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 91cc6627dd7e29..a482de9074eac9 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -124,7 +124,7 @@ set +e ctest --output-on-failure -R "($parallel_list)" --timeout 120 -j4 | tee -a $tmpfile; test ${PIPESTATUS[0]} -eq 0; EXIT_CODE_1=$? -ctest --output-on-failure -R "($serial_list)" --timeout 120 -j1 | tee -a $tmpfile; test ${PIPESTATUS[0]} -eq 0; +ctest --output-on-failure -R "($serial_list)" --timeout 180 -j1 | tee -a $tmpfile; test ${PIPESTATUS[0]} -eq 0; EXIT_CODE_2=$? set -e From f84fbddddefe2cde310bcba5db707d62e73c8139 Mon Sep 17 00:00:00 2001 From: lzydev Date: Thu, 4 Jan 2024 17:06:41 +0800 Subject: [PATCH 109/142] =?UTF-8?q?=E3=80=90AutoParallel=E3=80=91Add=20mas?= =?UTF-8?q?ter=20grad=20in=20AMP-O2=20of=20AutoParallel=20(#59987)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add master_grad in auto-parallel * reset third_party * fix coverage * support bf16 master_grad * fix bug in master_grad * change code according to review * change the way to find optimizer op --- .../distributed/auto_parallel/constants.py | 1 + .../auto_parallel/static/parallelizer_v2.py | 22 +- python/paddle/distributed/passes/__init__.py | 1 + .../distributed/passes/auto_parallel_fp16.py | 2 +- .../passes/auto_parallel_gradient_merge.py | 13 +- .../passes/auto_parallel_master_grad.py | 239 ++++++++++++++++++ test/auto_parallel/amp_o2_pass.py | 60 ++++- 7 files changed, 325 insertions(+), 13 deletions(-) create mode 100644 python/paddle/distributed/passes/auto_parallel_master_grad.py diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 2d2073f293ed79..bcc64a50ae2187 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -78,6 +78,7 @@ def set_field_default_config(category, field, default_value): set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "use_fp16_guard", False) set_field_default_config(AMP, "use_bf16_guard", False) +set_field_default_config(AMP, "use_master_grad", False) ######################################### # sharding configuration diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 73dd1de8508bf9..7f38ebb9f6bedd 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -252,14 +252,15 @@ def _generate_optimizer( # but optimizer will be called repeatedly in re-launch, so optimizer need to be copied. # 2. lr_scheduler cannot be deepcopy, cause 'deepcopy' will lead to difference of learning_rate between executor and engine. learning_rate = optimizer._learning_rate - optimizer = copy.deepcopy(optimizer) + new_optimizer = copy.deepcopy(optimizer) + new_optimizer._learning_rate = learning_rate + new_optimizer._sorted = False self._dist_context._serial_optimizer = optimizer self._dist_context._serial_optimizer._learning_rate = learning_rate - optimizer._sorted = False with program_guard(main_program, startup_program): with main_program.switch_name_generator_guard("opt_"): - optimizer_ops = optimizer.apply_gradients(params_grads) + optimizer_ops = new_optimizer.apply_gradients(params_grads) self._completer.complete_update_annotation(main_program) return optimizer_ops @@ -380,6 +381,21 @@ def _apply_post_optimization( [main_program], [startup_program], self._pass_context ) + # apply master grad pass + if self._strategy.amp.enable: + amp_config = copy.deepcopy(self._strategy.amp.to_dict()) + config = {} + config["dist_context"] = self._dist_context + config["params_grads"] = params_grads + config["completer"] = self._completer + if amp_config['level'] == "o2" and amp_config["use_master_grad"]: + master_grad_pass = new_pass( + "auto_parallel_master_grad_pass", config + ) + master_grad_pass.apply( + [main_program], [startup_program], self._pass_context + ) + # data parallel optimization if self._strategy.dp_optimization.enable: config = copy.deepcopy(self._strategy.dp_optimization.to_dict()) diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 107fe74a569d0a..e78cc5bbd0081d 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -17,6 +17,7 @@ from .auto_parallel_gradient_merge import * # noqa: F403 from .auto_parallel_sharding import * # noqa: F403 from .auto_parallel_amp import * # noqa: F403 +from .auto_parallel_master_grad import * # noqa: F403 from .auto_parallel_fp16 import * # noqa: F403 from .auto_parallel_recompute import * # noqa: F403 from .auto_parallel_quantization import * # noqa: F403 diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 92259dee3ae057..cd29cbbacc2cef 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -79,7 +79,7 @@ def set_auto_cast_attr(cast_op, block): ), f"in_var {in_name} or out_var {out_name} is None of cast op" if is_forward_op(cast_op): cast_op._set_attr('in_dtype', in_var.dtype) - cast_op._set_attr('out_dtype', out_var.dtype) + out_var.desc.set_dtype(paddle.dtype(cast_op.attr('out_dtype'))) elif is_backward_op(cast_op): in_var_fw = block._find_var_recursive(in_name[: in_name.find("@")]) out_var_fw = block._find_var_recursive(out_name[: out_name.find("@")]) diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index c793639c5ba013..51a781b6f0f85b 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -178,6 +178,7 @@ def _append_gradient_merge_backward_op( for out_name in op.desc.output_arg_names(): if out_name in grad_to_params_grads: param = grad_to_params_grads[out_name][0] + grad = grad_to_params_grads[out_name][1] assert param is not None ref_dist_attr = dist_context.get_tensor_dist_attr_for_program( param @@ -188,8 +189,8 @@ def _append_gradient_merge_backward_op( # Add persistable gradient variables in main_program gradient_merge_var = main_block.create_var( name=param.name + "@GRAD@MERGE", - shape=param.shape, - dtype=param.dtype, + shape=grad.shape, + dtype=grad.dtype, persistable=True, ) ref_process_mesh = ref_dist_attr.process_mesh @@ -205,8 +206,8 @@ def _append_gradient_merge_backward_op( # Add persistable gradient variables in startup_program startup_gradient_merge_var = startup_block.create_var( name=param.name + "@GRAD@MERGE", - shape=param.shape, - dtype=param.dtype, + shape=grad.shape, + dtype=grad.dtype, persistable=True, ) # Initial persistable gradient variables in startup_program @@ -214,8 +215,8 @@ def _append_gradient_merge_backward_op( type="fill_constant", outputs={"Out": startup_gradient_merge_var}, attrs={ - "shape": param.shape, - "dtype": param.dtype, + "shape": grad.shape, + "dtype": grad.dtype, "value": float(0), }, ) diff --git a/python/paddle/distributed/passes/auto_parallel_master_grad.py b/python/paddle/distributed/passes/auto_parallel_master_grad.py new file mode 100644 index 00000000000000..9d105acade045b --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_master_grad.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +from collections import OrderedDict +from typing import List, Tuple + +from paddle.base import Variable +from paddle.distributed.auto_parallel.static.utils import ( + is_backward_op, + is_gradient_clip_op, + is_optimize_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, + set_var_dist_attr, +) +from paddle.distributed.fleet.meta_optimizers.common import ( + OP_ROLE_KEY, + OpRole, +) +from paddle.framework import core +from paddle.static import program_guard + +from ..utils.log_utils import get_logger +from .auto_parallel_sharding import _supported_optimizer_type +from .pass_base import PassBase, register_pass + +logger = get_logger(logging.INFO, "MasterGradPass") + + +def get_output_in_varlist(op, var_names) -> List[str]: + grad_names = [] + for output_name in op.output_arg_names: + if output_name in var_names: + grad_names.append(output_name) + return grad_names + + +@register_pass("auto_parallel_master_grad_pass") +class MasterGradPass(PassBase): + """ + Use the high precision gradient to replace the low precision gradient in optimizer to avoid inf/nan values of low precision. + The high precision gradient 'master grad' will be used by communication operator, `update_loss_scaling`, `GradClip` and `optimizer`. + """ + + def __init__(self): + super().__init__() + + def _check_self(self): + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + self._completer = self.get_attr("completer") + dist_context = self.get_attr("dist_context") + params_grads = self.get_attr("params_grads") + logger.debug(f"Origin main_program: {main_program}") + self._add_master_grad(main_program, params_grads, dist_context) + self._regenerate_optimizer( + main_program, startup_program, params_grads, dist_context + ) + logger.debug(f"After main program: {main_program}") + + def _add_cast_op(self, cur_block, grad_names: List[str], dist_context): + grad_first_ids = OrderedDict() + for idx, op in enumerate(cur_block.ops): + if is_optimize_op(op): + break + elif is_backward_op(op): + var_names = get_output_in_varlist(op, grad_names) + for var_name in var_names: + if var_name not in grad_first_ids: + grad_first_ids[var_name] = idx + # Communication operators such as 'allreduce_sum' use input var as output. + else: + pass + + # insert cast op + for grad_name, idx in reversed(grad_first_ids.items()): + grad_var = cur_block.var(grad_name) + if ( + grad_var.dtype == core.VarDesc.VarType.FP16 + or grad_var.dtype == core.VarDesc.VarType.BF16 + ): + is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16 + producer_op = cur_block.ops[idx] + producer_op_dist_attr = ( + dist_context.get_op_dist_attr_for_program(producer_op) + ) + assert ( + producer_op_dist_attr is not None + ), f"The op: '{producer_op}' should be distributed" + ref_output_dist_attr = ( + producer_op_dist_attr.get_output_dist_attr(grad_name) + ) + assert ( + ref_output_dist_attr is not None + ), f"The output: '{grad_name}' should be distributed" + ref_mesh = ref_output_dist_attr.process_mesh + ref_dims_mapping = ref_output_dist_attr.dims_mapping + ref_chunk_id = producer_op_dist_attr.chunk_id + grad_half_precision_name = ( + grad_name + '@tmp_fp16' + if is_fp16 + else grad_name + '@tmp_bf16' + ) + grad_half_precision = cur_block.create_var( + name=grad_half_precision_name, + dtype=grad_var.dtype, + shape=grad_var.shape, + persistable=False, + stop_gradient=False, + ) + set_var_dist_attr( + dist_context, + grad_half_precision, + ref_dims_mapping, + ref_mesh, + chunk_id=ref_chunk_id, + ) + producer_op._rename_output(grad_name, grad_half_precision.name) + grad_var.desc.set_dtype(core.VarDesc.VarType.FP32) + cast_op = cur_block._insert_op_without_sync( + idx + 1, + type="cast", + inputs={"X": grad_half_precision}, + outputs={"Out": grad_var}, + attrs={ + "in_dtype": grad_half_precision.dtype, + "out_dtype": grad_var.dtype, + }, + ) + cast_op._set_attr(OP_ROLE_KEY, OpRole.Backward) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, + ref_mesh, + ref_dims_mapping, + dist_context, + chunk_id=ref_chunk_id, + ) + cur_block._sync_with_cpp() + + def _regenerate_optimizer( + self, + main_program, + startup_program, + params_grads: List[Tuple[Variable, Variable]], + dist_context, + ): + grad_names = [g.name for _, g in params_grads] + # 1. delete the origin optimizer op + # 1.1 delete the var and op associated with the optimizer op in main_program + main_ops = main_program.global_block().ops + main_ops_len = len(main_ops) + first_optimize_idx = main_ops_len + for idx, op in enumerate(main_ops): + # We don't delete the operators for check_nan_inf + if is_optimize_op(op) and is_gradient_clip_op(op): + first_optimize_idx = idx + break + assert ( + first_optimize_idx < main_ops_len + ), "The first optimizer op is not found!" + deleted_temp_var_names = [] + deleted_persist_var_names = [] + reserved_var_names = [] + for idx in range(main_ops_len - 1, first_optimize_idx - 1, -1): + op = main_ops[idx] + inout_arg_names = op.input_arg_names + op.output_arg_names + if op.type in _supported_optimizer_type: + param_names = op.input("Param") + skip_update_names = op.input("SkipUpdate") + for reserved_name in param_names + skip_update_names: + if reserved_name not in reserved_var_names: + reserved_var_names.append(reserved_name) + for input_name in inout_arg_names: + if input_name in grad_names: + continue + var = main_program.global_block().var(input_name) + if ( + var.persistable + and input_name not in deleted_persist_var_names + ): + deleted_persist_var_names.append(input_name) + elif ( + not var.persistable + and input_name not in deleted_temp_var_names + ): + deleted_temp_var_names.append(input_name) + main_program.global_block()._remove_op(idx) + + for var_name in deleted_temp_var_names + deleted_persist_var_names: + if var_name not in reserved_var_names: + main_program.global_block()._remove_var(var_name) + main_program.global_block()._sync_with_cpp() + + # 1.2 delete the var and op in startup_program + for reserved_name in reserved_var_names: + if reserved_name in deleted_persist_var_names: + deleted_persist_var_names.remove(reserved_name) + startup_global_block = startup_program.global_block() + for var_name in deleted_persist_var_names: + if startup_global_block.has_var(var_name): + startup_global_block._remove_var(var_name) + for idx, op in reversed(list(enumerate(startup_global_block.ops))): + inout_arg_names = op.input_arg_names + op.output_arg_names + for var_name in inout_arg_names: + if var_name in deleted_persist_var_names: + startup_program.global_block()._remove_op(idx) + break + + # 2. re-generate new optimizer op + serial_optimizer = copy.deepcopy(dist_context._serial_optimizer) + serial_optimizer._learning_rate = ( + dist_context._serial_optimizer._learning_rate + ) + serial_optimizer._sorted = False + with program_guard(main_program, startup_program): + with main_program.switch_name_generator_guard("opt_"): + _ = serial_optimizer.apply_gradients(params_grads) + self._completer.complete_update_annotation(main_program) + + def _add_master_grad(self, main_program, params_grads, dist_context): + grad_names = [g.name for _, g in params_grads] + for sub_block in main_program.blocks: + self._add_cast_op(sub_block, grad_names, dist_context) diff --git a/test/auto_parallel/amp_o2_pass.py b/test/auto_parallel/amp_o2_pass.py index a770be6d1e4283..501d1f92cae658 100644 --- a/test/auto_parallel/amp_o2_pass.py +++ b/test/auto_parallel/amp_o2_pass.py @@ -39,7 +39,7 @@ def get_cuda_version(): return -1 -def apply_pass(use_amp=False, amp_dtype="bfloat16"): +def apply_pass(use_amp=False, use_master_grad=False, amp_dtype="bfloat16"): strategy = auto.Strategy() strategy.auto_mode = "semi" strategy.reinit = True @@ -54,6 +54,8 @@ def apply_pass(use_amp=False, amp_dtype="bfloat16"): 'elementwise_div', 'reduce_sum', ] + if use_master_grad: + amp.use_master_grad = True return strategy @@ -77,10 +79,12 @@ def init(self, engine): place = paddle.base.CUDAPlace(paddle.distributed.ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) - def get_engine(self, use_amp=False, amp_dtype="bfloat16"): + def get_engine( + self, use_amp=False, use_master_grad=False, amp_dtype="bfloat16" + ): reset_prog() - strategy = apply_pass(use_amp, amp_dtype) + strategy = apply_pass(use_amp, use_master_grad, amp_dtype) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) model, loss = generate_model("mp") @@ -105,6 +109,23 @@ def check_bf16(self, program): self.assertEqual(num_fp16, 0) self.assertEqual(num_fp32, 10) + def check_fp16(self, program): + num_bf16 = 0 + num_fp16 = 0 + num_fp32 = 0 + + for p in program.all_parameters(): + if p.dtype == core.VarDesc.VarType.FP32: + num_fp32 += 1 + if p.dtype == core.VarDesc.VarType.FP16: + num_fp16 += 1 + if p.dtype == core.VarDesc.VarType.BF16: + num_bf16 += 1 + + self.assertEqual(num_bf16, 0) + self.assertEqual(num_fp16, 26) + self.assertEqual(num_fp32, 10) + def test_param_grad_fuse_overlap(self): # std mp_engine = self.get_engine(use_amp=False) @@ -139,6 +160,39 @@ def test_param_grad_fuse_overlap(self): self.check_bf16(mp_bf16_engine.main_program) + def test_master_grad(self): + # fp16 + mp_fp16_engine = self.get_engine(use_amp=True, amp_dtype="float16") + if not (paddle.amp.is_float16_supported()): + return + + mp_fp16_history = mp_fp16_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + loss1 = mp_fp16_history.history['loss'][0] + self.check_fp16(mp_fp16_engine.main_program) + # fp16 + mater_grad + mp_fp16_mater_grad_engine = self.get_engine( + use_amp=True, use_master_grad=True, amp_dtype="float16" + ) + mp_fp16_master_grad_history = mp_fp16_mater_grad_engine.fit( + self.dataset, + 3, + epochs=1, + steps_per_epoch=self.batch_num, + log_freq=1, + batch_size=self.batch_size, + ) + loss2 = mp_fp16_master_grad_history.history['loss'][0] + np.testing.assert_allclose(loss1, loss2, atol=1e-3, rtol=1e-2) + + self.check_fp16(mp_fp16_mater_grad_engine.main_program) + if __name__ == "__main__": unittest.main() From 6dfb15e092308ee9202fd57d538146bf999c9134 Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Thu, 4 Jan 2024 17:12:33 +0800 Subject: [PATCH 110/142] [Dy2St] Fix `NameloadJstTransformer` missing transform call kwargs (#60515) --------- Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com> --- python/paddle/base/dygraph/math_op_patch.py | 4 +-- .../transformers/basic_api_transformer.py | 2 ++ python/paddle/pir/math_op_patch.py | 30 +++++++++++++++++++ .../test_load_transformer.py | 22 ++++++++++++++ test/legacy_test/test_math_op_patch_pir.py | 14 +++++++++ 5 files changed, 69 insertions(+), 3 deletions(-) diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index 172f73bf7f531f..3f7b7a40ffa461 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -167,9 +167,7 @@ def _size_(var): def _T_(var): if len(var.shape) == 1: return var - perm = [] - for i in range(len(var.shape)): - perm.insert(0, i) + perm = list(reversed(range(len(var.shape)))) out = _C_ops.transpose(var, perm) return out diff --git a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py index 0902a3558b2b0d..01b831706cceb6 100644 --- a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py @@ -152,6 +152,8 @@ def visit_Call(self, node): Can't convert name of function call, bacause this will affect CallTransformer. """ node.args = [self.visit(arg) for arg in node.args] + for keyword in node.keywords: + keyword.value = self.visit(keyword.value) node.func = self.visit(node.func) return node diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index 45f8917bf04de6..74cb7157c6f244 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -405,6 +405,35 @@ def _size_(self): """ return paddle.numel(self) + @property + def _T_(self): + """ + + Permute current Value with its dimensions reversed. + + If `n` is the dimensions of `x` , `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + + >>> x = paddle.ones(shape=[2, 3, 5]) + >>> x_T = x.T + + >>> exe = paddle.static.Executor() + >>> x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0] + >>> print(x_T_np.shape) + (5, 3, 2) + + """ + if len(self.shape) == 1: + return self + perm = list(reversed(range(len(self.shape)))) + + return _C_ops.transpose(self, perm) + def clone(self): """ Returns a new static Value, which is the clone of the original static @@ -511,6 +540,7 @@ def value_hash(self): ('ndim', _ndim), ('astype', astype), ('size', _size_), + ('T', _T_), ('clone', clone), ('clear_gradient', clear_gradient), ('append', append), diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py index 6698ba7ef60757..80652734e933e7 100644 --- a/test/dygraph_to_static/test_load_transformer.py +++ b/test/dygraph_to_static/test_load_transformer.py @@ -71,5 +71,27 @@ def func(x): np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) +class LoadInCallKwargsNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.extra_inputs = [] + + def forward(self, x): + for i in range(len(self.extra_inputs)): + x = paddle.nn.functional.linear(weight=self.extra_inputs[i].T, x=x) + return x + + +class TestLoadInCallKwargs(Dy2StTestBase): + @test_legacy_and_pt_and_pir + def test_name_load_nograd(self): + net = LoadInCallKwargsNet() + x = paddle.rand([10, 10]) + net.extra_inputs.append(paddle.rand([10, 10])) + output_st = paddle.jit.to_static(net)(x) + output_dy = net(x) + np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) + + if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index dc2fe9abed1a96..dbd57c1999115d 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -450,6 +450,20 @@ def test_size(self): (output_x,) = exe.run(main_program, fetch_list=[x.size]) self.assertEqual(output_x, 24) + def test_T(self): + with paddle.pir_utils.IrGuard(): + for ndim in range(5): + # shape is [], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4] + shape = list(range(1, ndim + 1)) + out_shape = list(reversed(shape)) + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_T = x.T + self.assertEqual(x_T.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_T]) + self.assertEqual(output_x.shape, tuple(out_shape)) + def test_hash_error(self): with paddle.pir_utils.IrGuard(): _, _, program_guard = new_program() From 0ebae80fbc5ea4bcfcba5949f5229a4022440ae7 Mon Sep 17 00:00:00 2001 From: 6clc Date: Thu, 4 Jan 2024 17:22:18 +0800 Subject: [PATCH 111/142] cinn(backends): generate infer shape kernel to infer shape of output tensor (#60519) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 通过二维指针来返回后端infer shape的结果。生成的cinn ir如下。tensor_shape_args是一个二维指针。 infer_shape_set_value(0, 0, S1, tensor_shape_args) 表示将第0个output tensor的第0维设置为S1。 --- paddle/cinn/backends/codegen_cuda_host.cc | 10 +++ paddle/cinn/backends/codegen_cuda_host.h | 2 +- paddle/cinn/backends/codegen_cuda_util.h | 24 ++++- paddle/cinn/backends/llvm/codegen_llvm.cc | 3 +- paddle/cinn/common/type.h | 12 +++ paddle/cinn/hlir/framework/op_lowering.h | 11 +-- paddle/cinn/hlir/framework/op_lowering_impl.h | 2 +- .../hlir/framework/op_lowering_impl_base.h | 9 +- .../hlir/framework/pir/compilation_task.cc | 22 +++-- .../hlir/framework/pir/compilation_task.h | 4 +- .../hlir/framework/pir/op_lowering_impl.cc | 90 +++++++++++++++---- .../hlir/framework/pir/op_lowering_impl.h | 26 ++++-- paddle/cinn/hlir/framework/pir/utils.h | 1 + paddle/cinn/ir/ir.h | 1 + paddle/cinn/ir/lowered_func.cc | 10 +++ paddle/cinn/ir/module.cc | 4 + paddle/cinn/ir/module.h | 1 + paddle/cinn/ir/utils/ir_copy.cc | 6 +- paddle/cinn/runtime/cinn_runtime.cc | 3 + paddle/cinn/runtime/cinn_runtime.h | 1 + paddle/cinn/runtime/cuda/cuda_intrinsics.cc | 10 +++ paddle/cinn/runtime/cuda/cuda_util.cc | 3 + paddle/cinn/runtime/cuda/cuda_util.h | 1 + paddle/cinn/runtime/intrinsic.h | 2 + .../instruction/cinn_jit_instruction.cc | 70 ++++++++++++--- .../instruction/cinn_jit_instruction.h | 3 +- 26 files changed, 274 insertions(+), 57 deletions(-) diff --git a/paddle/cinn/backends/codegen_cuda_host.cc b/paddle/cinn/backends/codegen_cuda_host.cc index b23028355a06ae..11e986bb9ace1b 100644 --- a/paddle/cinn/backends/codegen_cuda_host.cc +++ b/paddle/cinn/backends/codegen_cuda_host.cc @@ -198,6 +198,11 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) { [](auto& arg) { return std::addressof(arg); }); // @} + // Set local scope table + CHECK_EQ(ll_function_args.size(), func->args.size()); + for (int i = 0; i < ll_function_args.size(); ++i) { + SetVar(func->args[i].name(), ll_function_args[i]); + } llvm::BasicBlock* entry = llvm::BasicBlock::Create( /*Context=*/b_->getContext(), /*Name=*/"entry", @@ -205,6 +210,11 @@ llvm::Value* CodeGenCUDA_Host::LowerHostFunc(const ir::_LoweredFunc_* func) { /*InsertBefore=*/nullptr); b_->SetInsertPoint(entry); CodeGenLLVM::Visit(&func->body); + + // Reset local scope table + for (const ir::Argument& func_arg : func->args) { + symbol_table_->Erase(func_arg.name()); + } RetVoid(); return f_; diff --git a/paddle/cinn/backends/codegen_cuda_host.h b/paddle/cinn/backends/codegen_cuda_host.h index aafaeebc248eb0..3a3453f80522b3 100644 --- a/paddle/cinn/backends/codegen_cuda_host.h +++ b/paddle/cinn/backends/codegen_cuda_host.h @@ -53,7 +53,7 @@ class CodeGenCUDA_Host : public CodeGenLLVM { } else if (op->name == runtime::intrinsic::call_cuda_kernel) { return LowerCUDAKernelCall(op); } else { - CINN_NOT_IMPLEMENTED; + return CodeGenLLVM::Visit(op); } } diff --git a/paddle/cinn/backends/codegen_cuda_util.h b/paddle/cinn/backends/codegen_cuda_util.h index 5a7f1f5882bf9b..52296bd2a8807b 100644 --- a/paddle/cinn/backends/codegen_cuda_util.h +++ b/paddle/cinn/backends/codegen_cuda_util.h @@ -31,6 +31,7 @@ namespace backends { #define KERNEL_ARGS "kernel_args" #define KERNEL_ARGS_NUM "kernel_args_num" #define KERNEL_STREAM "kernel_stream" +#define TENSOR_SHAPE_ARGS "tensor_shape_args" /** * Split a CINN Module into two separate modules, one cantains the host @@ -150,7 +151,8 @@ struct CollectBucketStrategyHostFunctionVisitor : CollectHostFunctionVisitor(module_name), kernel_args_(KERNEL_ARGS, type_of()), kernel_args_num_(KERNEL_ARGS_NUM, type_of()), - kernel_stream_(KERNEL_STREAM, type_of()) {} + kernel_stream_(KERNEL_STREAM, type_of()), + tensor_shape_args_(TENSOR_SHAPE_ARGS, type_of()) {} std::tuple operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); @@ -181,6 +183,25 @@ struct CollectBucketStrategyHostFunctionVisitor {}); host_module_builder.AddFunctionWithoutOptim( host_func.as_lowered_func_ref()); + + // Parse LoweredFunc to infer output tensor's shape + std::vector infer_shape_func_body_stmts(arg_defs_); + infer_shape_func_body_stmts.insert( + infer_shape_func_body_stmts.end(), + op->infer_shape_func.as_lowered_func()->body); + + std::vector infer_shape_arguments = { + ir::Argument(kernel_args_, ir::Argument::IO::kOutput), + ir::Argument(kernel_args_num_, ir::Argument::IO::kInput), + ir::Argument(tensor_shape_args_, ir::Argument::IO::kOutput)}; + + ir::Expr host_infer_shape_func = + ir::_LoweredFunc_::Make(op->infer_shape_func.as_lowered_func()->name, + infer_shape_arguments, + ir::Block::Make(infer_shape_func_body_stmts), + {}); + host_module_builder.AddFunctionWithoutOptim( + host_infer_shape_func.as_lowered_func_ref()); } void ProcessLoweredFunc(ir::Expr func, ir::Expr predicate); @@ -199,6 +220,7 @@ struct CollectBucketStrategyHostFunctionVisitor ir::Var kernel_args_; ir::Var kernel_args_num_; ir::Var kernel_stream_; + ir::Var tensor_shape_args_; }; } // namespace detail diff --git a/paddle/cinn/backends/llvm/codegen_llvm.cc b/paddle/cinn/backends/llvm/codegen_llvm.cc index a79e67fd6c4839..e554eca8795a43 100644 --- a/paddle/cinn/backends/llvm/codegen_llvm.cc +++ b/paddle/cinn/backends/llvm/codegen_llvm.cc @@ -818,7 +818,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) { // TODO(fc500110) hard coding if (LLVM_WillVarLowerAsPointer(op->name)) { result = value; - } else if (value->getType()->isPointerTy()) { + } else if (value->getType()->isPointerTy() && + !value->getType()->getPointerElementType()->isPointerTy()) { result = Load(value, op->name + "_load"); } else { result = value; diff --git a/paddle/cinn/common/type.h b/paddle/cinn/common/type.h index 9ce9402d84f8fe..b11a320bbd5a19 100644 --- a/paddle/cinn/common/type.h +++ b/paddle/cinn/common/type.h @@ -251,6 +251,18 @@ inline Type type_of() { return x; } template <> +inline Type type_of() { + Type x = Int(32); + x.set_cpp_handle(); + return x; +} +template <> +inline Type type_of() { + Type x = Int(32); + x.set_cpp_handle2(); + return x; +} +template <> inline Type type_of() { Type x = type_of(); x.set_cpp_handle(); diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index 57a54310c77198..d4b4a78e9cd3fa 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -47,11 +47,12 @@ class OpLowerer { group, apply_op_schedule, apply_group_schedule, apply_pass); } - std::vector> BucketLower( - const T& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) { + std::vector< + std::pair> + BucketLower(const T& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) { return impl_->BucketLower( group, apply_op_schedule, apply_group_schedule, apply_pass); } diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.h b/paddle/cinn/hlir/framework/op_lowering_impl.h index 038c6f1ec8bf33..d48cbbeb7e9b4a 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl.h @@ -60,7 +60,7 @@ class OpLowererImpl : public OpLowererImplBase { bool apply_group_schedule = true, bool apply_pass = true); - std::vector> BucketLower( + std::vector> BucketLower( const GroupPtr& group, bool apply_op_schedule = false, bool apply_group_schedule = true, diff --git a/paddle/cinn/hlir/framework/op_lowering_impl_base.h b/paddle/cinn/hlir/framework/op_lowering_impl_base.h index bab0a700891121..32bda3ca50f675 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl_base.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl_base.h @@ -30,6 +30,13 @@ namespace framework { template class OpLowererImplBase { public: + struct WrapLoweredFunc { + ir::LoweredFunc kernel_func; + ir::LoweredFunc infer_shape_func; + WrapLoweredFunc(ir::LoweredFunc kernel_func, + ir::LoweredFunc infer_shape_func = ir::LoweredFunc()) + : infer_shape_func(infer_shape_func), kernel_func(kernel_func) {} + }; OpLowererImplBase() = default; ~OpLowererImplBase() = default; @@ -38,7 +45,7 @@ class OpLowererImplBase { bool apply_group_schedule = true, bool apply_pass = true) = 0; - virtual std::vector> + virtual std::vector> BucketLower(const T& group, bool apply_op_schedule = false, bool apply_group_schedule = true, diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.cc b/paddle/cinn/hlir/framework/pir/compilation_task.cc index 01c940b228a3dc..c6d3412102c302 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.cc +++ b/paddle/cinn/hlir/framework/pir/compilation_task.cc @@ -15,6 +15,7 @@ #pragma once #include "paddle/cinn/hlir/framework/pir/compilation_task.h" +#include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/op_lowering.h" #include "paddle/cinn/ir/module.h" @@ -23,11 +24,14 @@ namespace hlir { namespace framework { void GroupCompilationContext::SetLoweredFuncs( - std::vector>&& funcs) { - for (std::pair& predicate2func : - funcs) { + std::vector>&& funcs) { + for (std::pair& + predicate2func : funcs) { predicates_.push_back(predicate2func.first); - lowered_funcs_.push_back(predicate2func.second); + lowered_funcs_.push_back(predicate2func.second.kernel_func); + infer_shape_lowered_funcs_.push_back( + predicate2func.second.infer_shape_func); ++func_size_; } } @@ -67,12 +71,13 @@ void CompilationTask::CodegenAndJit() { ir::Module::Builder builder(cinn::common::UniqName("module"), context_->target_); CHECK_EQ(context_->predicates_.size(), context_->lowered_funcs_.size()); - for (const ir::Expr predicate : context_->predicates_) { + for (const ir::Expr& predicate : context_->predicates_) { builder.AddPredicate(predicate); } for (const ir::LoweredFunc& func : context_->lowered_funcs_) { builder.AddFunction(func); } + builder.AddInferShapeFunc(context_->infer_shape_lowered_funcs_[0]); ir::Module ir_module = builder.Build(); context_->backend_compiler_ = backends::Compiler::Create(context_->target_); @@ -90,6 +95,9 @@ std::unique_ptr CompilationTask::BuildInstruction() { VLOG(4) << "Lookup kernel name: " << fn_name; auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name); CHECK(fn_ptr); + auto* infer_shape_fn_ptr = + context_->backend_compiler_->Lookup(fn_name + "_infer_shape" + fn_name); + CHECK(infer_shape_fn_ptr); instr->SetLoweredFunc(reinterpret_cast(fn_ptr), fn_name); instr->Finalize(); return instr; @@ -100,8 +108,12 @@ pir::CINNKernelInfo CompilationTask::BuildPirCINNKernelInfo() { VLOG(4) << "Lookup kernel name: " << fn_name; auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name); CHECK(fn_ptr); + auto* infer_shape_fn_ptr = + context_->backend_compiler_->Lookup(fn_name + "_infer_shape"); + CHECK(infer_shape_fn_ptr); pir::CINNKernelInfo cinn_kernel_info; cinn_kernel_info.fn_ptr = fn_ptr; + cinn_kernel_info.infer_shape_fn_ptr = infer_shape_fn_ptr; cinn_kernel_info.int_args_map = context_->group_->int_args_map; return cinn_kernel_info; } diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.h b/paddle/cinn/hlir/framework/pir/compilation_task.h index 5291cafe4a2f32..9e96c64694527e 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.h +++ b/paddle/cinn/hlir/framework/pir/compilation_task.h @@ -32,7 +32,8 @@ class GroupCompilationContext { : target_(target), group_(group), scope_(scope) {} void SetLoweredFuncs( - std::vector>&& funcs); + std::vector>&& funcs); std::string PrintPredicate2Funcs() const; void* FuncPtr(); std::shared_ptr BackendCompiler(); @@ -47,6 +48,7 @@ class GroupCompilationContext { size_t func_size_ = 0; std::vector predicates_; std::vector lowered_funcs_; + std::vector infer_shape_lowered_funcs_; std::string host_func_name_; std::string host_code_; std::vector device_code_; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 643e4ed294b4cd..062e5db1cc1f8c 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -18,6 +18,7 @@ #include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/ast_gen_ius/tensor_group.h" +#include "paddle/cinn/backends/codegen_cuda_util.h" #include "paddle/cinn/hlir/framework/compile_error.h" #include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" #include "paddle/cinn/hlir/framework/pir/utils.h" @@ -99,7 +100,7 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, } } -std::vector> +std::vector> OpLowererImpl::BucketLower(const GroupPtr& group, bool apply_op_schedule, bool apply_group_schedule, @@ -107,7 +108,8 @@ OpLowererImpl::BucketLower(const GroupPtr& group, // 1.Do compute, lower and schedule for each op. auto& ops = group->ops; if (ops.size() == 1 && ops[0]->name() == "custom_call") { - return {{ir::Expr(1), LowerCustomCall(group)[0]}}; + return {{ir::Expr(1), + pir::OpLowererImpl::WrapLoweredFunc(LowerCustomCall(group)[0])}}; } std::vector group_func_arg_tensors; std::unordered_map<::pir::Value, ir::Tensor> tensor_map; @@ -150,20 +152,22 @@ OpLowererImpl::BucketLower(const GroupPtr& group, // 3.Do post-processing, // including preparing function args and temporary variables, // applying low-level optimization passes, etc. - std::vector> cond2funcs; + std::vector> cond2funcs; for (std::pair& cond2body : cond2func_bodies) { std::vector group_func_arg_tensors_copy = group_func_arg_tensors; + std::vector group_func_args; std::vector funcs = PostProcess(group, tensor_map, apply_op_schedule, cond2body.second, - &group_func_arg_tensors_copy); - for (ir::LoweredFunc& func : funcs) { - cond2funcs.emplace_back(cond2body.first, func); - } + &group_func_arg_tensors_copy, + &group_func_args); + ir::LoweredFunc infer_shape_func = GenerateInferShapeFunc( + group, group_func_arg_tensors_copy, group_func_args); + cond2funcs.push_back({cond2body.first, {funcs[0], infer_shape_func}}); } return cond2funcs; } @@ -292,11 +296,13 @@ std::vector OpLowererImpl::LowerMapExpr( // 3.Do post-processing, // including preparing function args and temporary variables, // applying low-level optimization passes, etc. + std::vector group_func_args; return PostProcess(group, *tensor_map, apply_op_schedule, ir_sch.GetModule().GetExprs()[0], - group_func_arg_tensors); + group_func_arg_tensors, + &group_func_args); } std::vector OpLowererImpl::LowerGroup( @@ -345,11 +351,13 @@ std::vector OpLowererImpl::LowerGroup( // 3.Do post-processing, // including preparing function args and temporary variables, // applying low-level optimization passes, etc. + std::vector group_func_args; return PostProcess(group, tensor_map, do_op_schedule, ir_sch.GetModule().GetExprs().at(0), - &group_func_arg_tensors); + &group_func_arg_tensors, + &group_func_args); } std::vector OpLowererImpl::LowerCustomCall( @@ -403,16 +411,17 @@ std::vector OpLowererImpl::PostProcess( const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, ir::Expr func_body, - std::vector* group_func_arg_tensors) { + std::vector* group_func_arg_tensors, + std::vector* group_func_args) { // 1.Prepare function args group->input_names.clear(); - std::vector group_func_args; std::unordered_set arg_name_set; for (auto& arg_tensor : *group_func_arg_tensors) { // input data name. group->input_names.push_back(arg_tensor->name); // input args - group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput); + (*group_func_args) + .emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput); arg_name_set.insert(arg_tensor->buffer->name); } @@ -434,14 +443,15 @@ std::vector OpLowererImpl::PostProcess( group_func_arg_tensors->push_back(tensor); // output args group->output_names.push_back(tensor->name); - group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + (*group_func_args) + .emplace_back(tensor->buffer, ir::Argument::IO::kOutput); arg_name_set.insert(tensor->buffer->name); } } if (!done_op_schedule) { std::unordered_set args_set; - for (auto arg : group_func_args) { + for (auto arg : (*group_func_args)) { args_set.insert(arg.name()); } for (auto& op : group->ops) { @@ -457,15 +467,16 @@ std::vector OpLowererImpl::PostProcess( group->output_values.push_back(opresult); group_func_arg_tensors->push_back(tensor); group->output_names.push_back(tensor->name); - group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + group_func_args->emplace_back(tensor->buffer, + ir::Argument::IO::kOutput); } } } std::map mps; // update args for dynamic dim - int num_tensor_args = static_cast(group_func_args.size()); - int non_tensor_arg_idx = group_func_args.size(); + int num_tensor_args = static_cast(group_func_args->size()); + int non_tensor_arg_idx = group_func_args->size(); std::unordered_set int_args_set; for (int tensor_arg_idx = 0; tensor_arg_idx < num_tensor_args; tensor_arg_idx++) { @@ -480,7 +491,7 @@ std::vector OpLowererImpl::PostProcess( continue; } int_args_set.insert(symbol_name); - group_func_args.emplace_back( + group_func_args->emplace_back( ir::_Var_::Make(symbol_name, cinn::common::Int(32))); group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx, tensor_arg_dim_idx}; @@ -500,7 +511,7 @@ std::vector OpLowererImpl::PostProcess( lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); // 3.Building LoweredFunc auto func = ir::_LoweredFunc_::Make( - group->FuncName(), group_func_args, func_body, temp_buffers); + group->FuncName(), *group_func_args, func_body, temp_buffers); if (!done_op_schedule) { func->PrepareBufferCastExprs(); } @@ -1023,6 +1034,47 @@ bool OpLowererImpl::IsInTensorMap( return false; } +ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc( + const GroupPtr& group, + const std::vector group_func_arg_tensors, + const std::vector group_func_args) { + // CHECK_EQ(group_func_arg_tensors.size(), group_func_args.size()); + std::vector ir_bodys; + int output_tensor_idx = 0; + for (int tensor_arg_idx = 0; tensor_arg_idx < group_func_arg_tensors.size(); + ++tensor_arg_idx) { + if (group_func_args[tensor_arg_idx].is_input()) { + continue; + } + auto tensor_dim = group_func_arg_tensors[tensor_arg_idx]->sym_shape; + int tensor_dim_size = tensor_dim.size(); + auto tensor_shape = group_func_arg_tensors[tensor_arg_idx]->shape; + + ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of()); + for (int i = 0; i < tensor_shape.size(); i++) { + ir::Expr call_set_infer_shape_value = + ir::Call::Make(type_of(), + runtime::intrinsic::infer_shape_set_value, + {ir::Expr(output_tensor_idx), + ir::Expr(i), + tensor_shape[i], + tensor_shape_args}, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0); + ir_bodys.push_back(call_set_infer_shape_value); + } + ++output_tensor_idx; + } + ir::LoweredFunc infer_shape_func = + ir::_LoweredFunc_::Make(group->FuncName() + "_infer_shape", + group_func_args, + ir::Block::Make(ir_bodys), + {}); + return infer_shape_func; +} + } // namespace pir } // namespace framework } // namespace hlir diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h index aa29119281b51c..0a9f4d4b33820a 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -70,11 +70,11 @@ class OpLowererImpl : public OpLowererImplBase { * @param apply_group_schedule Whether to schedule at group level. * @return The lowered funcs. */ - std::vector> BucketLower( - const GroupPtr& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true); + std::vector> + BucketLower(const GroupPtr& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true); void InsertNameGeneToScope(std::shared_ptr scope); @@ -110,6 +110,7 @@ class OpLowererImpl : public OpLowererImplBase { * applied. * @param func_body The scheduled func body of group. * @param group_func_arg_tensors Tensors used as the group function arguments. + * @param group_func_args Arguments used as the group function arguments. * @return The lowered funcs after the post processing. */ std::vector PostProcess( @@ -117,7 +118,8 @@ class OpLowererImpl : public OpLowererImplBase { const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, ir::Expr func_body, - std::vector* group_func_arg_tensors); + std::vector* group_func_arg_tensors, + std::vector* group_func_args); /** * @brief Lower an Op set to CINN IR. @@ -214,6 +216,18 @@ class OpLowererImpl : public OpLowererImplBase { const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, const std::unordered_map& tmp_tensor_info); + /** + * @brief Generates the output tensor infer shape function. + * @param group The group to be lowered. + * @param group_func_arg_tensors Tensors used as the group function arguments. + * @param group_func_args Arguments used as the group function arguments. + * @return The lowered func to infer output tensor's shape. + */ + ir::LoweredFunc GenerateInferShapeFunc( + const GroupPtr& group, + const std::vector group_func_arg_tensors, + const std::vector group_func_args); + // Functions used to determine which Ops to schedule at op level, define a // policy for each type of group. inline bool ReduceScheduleDetermineFunction(::pir::Operation* op); diff --git a/paddle/cinn/hlir/framework/pir/utils.h b/paddle/cinn/hlir/framework/pir/utils.h index 4d97d48291903a..ce9fa8c1cb9f17 100644 --- a/paddle/cinn/hlir/framework/pir/utils.h +++ b/paddle/cinn/hlir/framework/pir/utils.h @@ -31,6 +31,7 @@ namespace pir { struct CINNKernelInfo { void* fn_ptr; + void* infer_shape_fn_ptr; struct ArgDimIdx { int arg_idx; diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index 7859a7181c527b..3e9460e084a36f 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -1018,6 +1018,7 @@ struct _Module_ : public ExprNode<_Module_> { std::vector functions; std::vector submodules; std::vector predicates; + Expr infer_shape_func; static ir::Module Make(const std::string& name, Target target); diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc index 129fc5d6e32782..d252a5e44954f5 100644 --- a/paddle/cinn/ir/lowered_func.cc +++ b/paddle/cinn/ir/lowered_func.cc @@ -398,11 +398,21 @@ void _LoweredFunc_::PrepareArgumentExprs() { } else if (arg.type() == type_of()) { pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = + ir::intrinsics::PodValueToX::Make(load_expr, type_of()); } else { LOG(ERROR) << "Not supported type [" << arg.type() << "]"; CINN_NOT_IMPLEMENTED } + VLOG(6) << "args " << i << "convert"; Expr let_expr = Let::Make(_arg, pod_cast_expr); CHECK(let_expr.type().valid()); argument_prepare_exprs.push_back(let_expr); diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc index d54286d9fc2ec6..fc58e44956fe76 100644 --- a/paddle/cinn/ir/module.cc +++ b/paddle/cinn/ir/module.cc @@ -53,6 +53,10 @@ void Module::Builder::AddPredicate(ir::Expr predicate) { module_->predicates.push_back(predicate); } +void Module::Builder::AddInferShapeFunc(ir::Expr infer_shape_func) { + module_->infer_shape_func = infer_shape_func; +} + void Module::Builder::Clear() { module_->buffers.clear(); module_->functions.clear(); diff --git a/paddle/cinn/ir/module.h b/paddle/cinn/ir/module.h index fad8377e6b0158..9910caab42b503 100644 --- a/paddle/cinn/ir/module.h +++ b/paddle/cinn/ir/module.h @@ -45,6 +45,7 @@ class Module : public ir::IrNodeRef { void AddFunctionWithoutOptim(const ir::LoweredFunc& func); void AddBuffer(ir::Buffer buffer); void AddPredicate(ir::Expr predicate); + void AddInferShapeFunc(ir::Expr infer_shape_func); void Clear(); Target::Arch GetTargetArch(); diff --git a/paddle/cinn/ir/utils/ir_copy.cc b/paddle/cinn/ir/utils/ir_copy.cc index 08dc2bc1e628cd..b444be218c39a5 100644 --- a/paddle/cinn/ir/utils/ir_copy.cc +++ b/paddle/cinn/ir/utils/ir_copy.cc @@ -242,7 +242,7 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl { std::vector functions; std::vector submodules; std::vector predicates; - + Expr infer_shape_func; for (auto& expr : op->buffers) { buffers.push_back(Visit(&expr)); } @@ -258,12 +258,16 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl { for (auto& expr : op->predicates) { predicates.push_back(Visit(&expr)); } + if (op->infer_shape_func.defined()) { + infer_shape_func = Visit(&op->infer_shape_func); + } auto res = ir::_Module_::Make(op->name, op->target); res->buffers = buffers; res->functions = functions; res->submodules = submodules; res->predicates = predicates; + res->infer_shape_func = infer_shape_func; return Expr(res); } diff --git a/paddle/cinn/runtime/cinn_runtime.cc b/paddle/cinn/runtime/cinn_runtime.cc index b8bc96d508877b..c4c25e8f867868 100644 --- a/paddle/cinn/runtime/cinn_runtime.cc +++ b/paddle/cinn/runtime/cinn_runtime.cc @@ -375,6 +375,9 @@ uint8_t cinn_pod_value_to_uint8(cinn_pod_value_t* value) { return *value; } bool cinn_pod_value_to_bool(cinn_pod_value_t* value) { return *value; } void* cinn_pod_value_to_void_p(cinn_pod_value_t* value) { return *value; } +int32_t* cinn_pod_value_to_int32_p(cinn_pod_value_t* value) { + return reinterpret_cast(value->data_addr()); +} cinn_buffer_t* cinn_pod_value_to_buffer_p(cinn_pod_value_t* value) { return *value; } diff --git a/paddle/cinn/runtime/cinn_runtime.h b/paddle/cinn/runtime/cinn_runtime.h index 17b5a400fd122b..4a5ce5d18d179c 100644 --- a/paddle/cinn/runtime/cinn_runtime.h +++ b/paddle/cinn/runtime/cinn_runtime.h @@ -561,6 +561,7 @@ uint8_t cinn_pod_value_to_uint8(cinn_pod_value_t* value); bool cinn_pod_value_to_bool(cinn_pod_value_t* value); void* cinn_pod_value_to_void_p(cinn_pod_value_t* value); +int32_t* cinn_pod_value_to_int32_p(cinn_pod_value_t* value); cinn_buffer_t* cinn_pod_value_to_buffer_p(cinn_pod_value_t* value); // @} diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc index e090117a423e4b..c4f335603963be 100644 --- a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc @@ -434,6 +434,16 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) { .AddInputType() // index .End(); + using cinn::runtime::cuda::infer_shape_set_value; + REGISTER_EXTERN_FUNC_HELPER(infer_shape_set_value, + cinn::common::DefaultHostTarget()) + .SetRetType() + .AddInputType() + .AddInputType() + .AddInputType() + .AddInputType() + .End(); + using cinn::runtime::cuda::cinn_call_cuda_kernel; REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_kernel, cinn::common::DefaultHostTarget()) diff --git a/paddle/cinn/runtime/cuda/cuda_util.cc b/paddle/cinn/runtime/cuda/cuda_util.cc index 326e5a3aac561d..98ba1c52d7edc3 100644 --- a/paddle/cinn/runtime/cuda/cuda_util.cc +++ b/paddle/cinn/runtime/cuda/cuda_util.cc @@ -2748,6 +2748,9 @@ void cinn_gpu_cudnn_pool2d(const std::vector &attrs, cudnnDestroyPoolingDescriptor(pooling_desc); } +void infer_shape_set_value(int row, int col, int32_t value, int32_t **v) { + v[row][col] = value; +} void cinn_gpu_cudnn_softmax(const std::vector &attrs, cinn_buffer_t *input, cinn_buffer_t *output, diff --git a/paddle/cinn/runtime/cuda/cuda_util.h b/paddle/cinn/runtime/cuda/cuda_util.h index 7ea9dbe00a2c5b..c7d9220e00688f 100644 --- a/paddle/cinn/runtime/cuda/cuda_util.h +++ b/paddle/cinn/runtime/cuda/cuda_util.h @@ -96,6 +96,7 @@ void cinn_call_cuda_memcpy(void* v_args, void* stream = nullptr); int32_t cinn_get_value_in_cuda_kernel_args(void* v_args, int idx); +void infer_shape_set_value(int row, int col, int32_t value, int32_t** v); /** * Call a CUDA compiled kernel. diff --git a/paddle/cinn/runtime/intrinsic.h b/paddle/cinn/runtime/intrinsic.h index 6939a8ea1f457f..c2db240de2d12f 100644 --- a/paddle/cinn/runtime/intrinsic.h +++ b/paddle/cinn/runtime/intrinsic.h @@ -107,6 +107,8 @@ static const char* call_cuda_kernel = "cinn_call_cuda_kernel"; static const char* get_value_in_cuda_kernel_args = "cinn_get_value_in_cuda_kernel_args"; +static const char* infer_shape_set_value = "infer_shape_set_value"; + static const char* pod_values_to_array_repr = "pod_values_to_array"; static const char* get_address_repr = "get_address"; diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 9ff10d0ae7c91c..180eb4f478fa6b 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -30,6 +30,7 @@ namespace paddle { namespace framework { typedef void (*lower_func_ptr_g)(void*, int32_t, void*); +typedef void (*infer_shape_func_ptr_g)(void*, int32_t, int32_t**); class CinnJitInstruction::FnPtrImpl { using CINNKernelInfo = cinn::hlir::framework::pir::CINNKernelInfo; @@ -61,6 +62,49 @@ class CinnJitInstruction::FnPtrImpl { static_cast(func_args_.data()), func_args_.size(), stream); } + void InferShape(const std::vector& kernel_args, + int32_t input_tensor_size, + int32_t output_tensor_size) { + func_args_.clear(); + + // 1. Convert the phi::DenseTensor type to cinn_pod_value_t + for (size_t i = 0; i < kernel_args.size(); ++i) { + auto* buffer = new cinn_buffer_t(); + func_args_.emplace_back(buffer); + } + + // 2. Convert arg's data about shape of Tensor to cinn_pod_value_t + for (const auto& int_arg_mp : cinn_kernel_info_.int_args_map) { + func_args_.emplace_back(kernel_args[int_arg_mp.second.arg_idx]->dims().at( + int_arg_mp.second.dim_idx)); + func_args_.emplace_back(static_cast( + kernel_args[int_arg_mp.second.arg_idx]->dims().at( + int_arg_mp.second.dim_idx))); + } + + // 3. Define an array of Pointers to hold the output tensor shape + int32_t* output_tensor_shapes[output_tensor_size]; + for (int i = 0; i < output_tensor_size; ++i) { + output_tensor_shapes[i] = reinterpret_cast( + malloc(kernel_args[input_tensor_size + i]->dims().size() * + sizeof(int32_t*))); + } + + // 4. Launch infer_shape_fn_ptr to infer shape of output tensor + ((infer_shape_func_ptr_g)cinn_kernel_info_.infer_shape_fn_ptr)( + static_cast(func_args_.data()), + func_args_.size(), + output_tensor_shapes); + + // 5. Resize shape of output tensor + for (int i = 0; i < output_tensor_size; ++i) { + DDim dim(output_tensor_shapes[i], + kernel_args[input_tensor_size + i]->dims().size()); + kernel_args[input_tensor_size + i]->Resize(dim); + free(output_tensor_shapes[i]); + } + } + private: CINNKernelInfo cinn_kernel_info_; @@ -76,6 +120,8 @@ CinnJitInstruction::CinnJitInstruction( auto jit_kernel_op = op->dyn_cast(); fn_ptr_impl_ = std::make_shared(jit_kernel_op.cinn_kernel_info()); op_ = op; + input_tensor_size = op->num_operands(); + output_tensor_size = op->num_results(); place_ = place; @@ -103,14 +149,11 @@ CinnJitInstruction::CinnJitInstruction( ->GetMutable(); tensor_args_.push_back(tensor); - - if (!FLAGS_cinn_bucket_compile) { - auto alloc_tensor_type = - result.type().dyn_cast(); - tensor->set_type( - paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype())); - tensor->Resize(alloc_tensor_type.dims()); - } + auto alloc_tensor_type = + result.type().dyn_cast(); + tensor->set_type( + paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype())); + tensor->Resize(alloc_tensor_type.dims()); } } @@ -120,16 +163,15 @@ void CinnJitInstruction::Run() { auto stream = gpu_ctx->stream(); + if (FLAGS_cinn_bucket_compile) { + fn_ptr_impl_->InferShape( + tensor_args_, input_tensor_size, output_tensor_size); + } for (size_t i = 0; i < tensor_args_.size(); ++i) { - // TODO(6clc): template infer shape from tensor_args_[0]. - // After supporting symbolic calculation, perfect the code to query shape - // of output tensor - if (FLAGS_cinn_bucket_compile) { - tensor_args_[i]->Resize(tensor_args_[0]->dims()); - } gpu_ctx->Alloc(tensor_args_[i], tensor_args_[i]->dtype()); } + // 2. exexute kernel fn_ptr_impl_->Run(tensor_args_, static_cast(stream)); #else VLOG(phi::FATAL) << "Not Supported: cinn jit instruction currently does not " diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h index b15fae77bdbe77..5f744f4229d911 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h @@ -49,7 +49,8 @@ class CinnJitInstruction : public InstructionBase { phi::DeviceContext* dev_ctx_; - phi::DenseTensor* out_tensor_; + int32_t input_tensor_size; + int32_t output_tensor_size; std::vector tensor_args_; From a6cd3bd8096f0abf22a819b40b596cdcb269cfbb Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Thu, 4 Jan 2024 17:40:22 +0800 Subject: [PATCH 112/142] fix tensor math method inplace converter (#60546) --- paddle/fluid/pybind/eager_math_op_patch.cc | 803 +++++++++++++-------- python/paddle/optimizer/optimizer.py | 4 +- 2 files changed, 486 insertions(+), 321 deletions(-) diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 2c01e122914aa4..cdfc882118401b 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -233,24 +233,34 @@ static PyObject* tensor__add__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::Tensor other_tensor; if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__add__", 0); - { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__add__", 0); + { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var, float type promotion @@ -313,8 +323,8 @@ static PyObject* tensor__sub__method(TensorObject* self, SetDevice(place); paddle::Tensor ret; - paddle::Tensor self_tensor = self->tensor; + paddle::Tensor self_tensor = self->tensor; PyObject* other_obj = PyTuple_GET_ITEM(args, 0); // 1. scalar exists cases if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || @@ -337,27 +347,38 @@ static PyObject* tensor__sub__method(TensorObject* self, return ToPyObject(ret); } + // 2. create or get tensor for other_obj paddle::Tensor other_tensor; if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__sub__", 0); - { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__sub__", 0); + { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var, float type promotion @@ -444,24 +465,34 @@ static PyObject* tensor__rsub__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::Tensor other_tensor; if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + auto& self_tensor_ref = self->tensor; + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__rsub__", 0); - { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + } else { + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__rsub__", 0); + { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var, float type promotion @@ -521,8 +552,8 @@ static PyObject* tensor__mul__method(TensorObject* self, SetDevice(place); paddle::Tensor ret; - paddle::Tensor self_tensor = self->tensor; + paddle::Tensor self_tensor = self->tensor; PyObject* other_obj = PyTuple_GET_ITEM(args, 0); // 1. scalar exists cases @@ -547,32 +578,46 @@ static PyObject* tensor__mul__method(TensorObject* self, } // 2. create or get tensor for other_obj + // if lhs or rhs input is tensor, we need to inplace cast it to dist_tensor + // if one of the input is numpy or scalar, no need to do inplace cast. paddle::Tensor other_tensor; if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__mul__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__mul__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + DataType::COMPLEX64, + self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); + } } } - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); - } - // 3. promote types or unify right var type to left var, float type promotion // mv to multiply_ad_func phi::DataType lhs_dtype = self_tensor.dtype(); @@ -633,8 +678,8 @@ static PyObject* tensor__div__method(TensorObject* self, SetDevice(place); paddle::Tensor ret; - paddle::Tensor self_tensor = self->tensor; + paddle::Tensor self_tensor = self->tensor; PyObject* other_obj = PyTuple_GET_ITEM(args, 0); // 1. scalar exists cases @@ -661,28 +706,38 @@ static PyObject* tensor__div__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::Tensor other_tensor; if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__div__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__div__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -753,8 +808,8 @@ static PyObject* tensor__rdiv__method(TensorObject* self, SetDevice(place); paddle::Tensor ret; - paddle::Tensor self_tensor = self->tensor; + paddle::Tensor self_tensor = self->tensor; PyObject* other_obj = PyTuple_GET_ITEM(args, 0); // 1. scalar exists cases @@ -786,28 +841,38 @@ static PyObject* tensor__rdiv__method(TensorObject* self, self_tensor.dtype(), place); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__rdiv__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__rdiv__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -909,28 +974,38 @@ static PyObject* tensor__gt__method(TensorObject* self, self_tensor.dtype(), place); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__gt__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__gt__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1004,28 +1079,38 @@ static PyObject* tensor__ge__method(TensorObject* self, self_tensor.dtype(), place); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__ge__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__ge__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1067,8 +1152,8 @@ static PyObject* tensor__mod__method(TensorObject* self, SetDevice(place); paddle::Tensor ret; - paddle::Tensor self_tensor = self->tensor; + paddle::Tensor self_tensor = self->tensor; PyObject* other_obj = PyTuple_GET_ITEM(args, 0); // 1. scalar exists cases @@ -1100,28 +1185,38 @@ static PyObject* tensor__mod__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__mod__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__mod__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1195,28 +1290,36 @@ static PyObject* tensor__matmul__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__matmul__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__matmul__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1308,28 +1411,38 @@ static PyObject* tensor__lt__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__lt__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__lt__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1403,28 +1516,38 @@ static PyObject* tensor__le__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__le__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__le__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1499,28 +1622,36 @@ static PyObject* tensor__floordiv__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__floordiv__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__floordiv__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1593,28 +1724,36 @@ static PyObject* tensor__pow__method(TensorObject* self, // 2. create or get tensor for other_obj paddle::Tensor other_tensor; if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__pow__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__pow__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1690,28 +1829,38 @@ static PyObject* tensor__rpow__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__rpow__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = full_ad_func( - self_tensor.shape(), value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__rpow__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = full_ad_func(self_tensor.shape(), + value, + self_tensor.dtype(), + self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1785,28 +1934,36 @@ static PyObject* tensor__ne__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__ne__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__ne__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var @@ -1880,28 +2037,36 @@ static PyObject* tensor__eq__method(TensorObject* self, self_tensor.dtype(), self_tensor.place()); } else if (PyCheckTensor(other_obj)) { - other_tensor = CastPyArg2Tensor(other_obj, 0); - } else if (IsNumpyArray(other_obj)) { - py::object numpy_value = py::object(py::handle(other_obj), true); - other_tensor = paddle::Tensor(place); - InitTensorWithNumpyValue(numpy_value, place, &other_tensor); + auto& self_tensor_ref = self->tensor; + auto& other_tensor_ref = CastPyArg2Tensor(other_obj, 0); + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor_ref, other_tensor_ref)) { + ConvertAllInputsToDistTensor(mesh, self_tensor_ref, other_tensor_ref); + } + self_tensor = self_tensor_ref; + other_tensor = other_tensor_ref; } else { - paddle::experimental::Scalar value = - CastPyArg2Scalar(other_obj, "__eq__", 0); - if (PyComplex_Check(other_obj)) { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + if (IsNumpyArray(other_obj)) { + py::object numpy_value = py::object(py::handle(other_obj), true); + other_tensor = paddle::Tensor(place); + InitTensorWithNumpyValue(numpy_value, place, &other_tensor); } else { - eager_gil_scoped_release guard; - other_tensor = - full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + paddle::experimental::Scalar value = + CastPyArg2Scalar(other_obj, "__eq__", 0); + if (PyComplex_Check(other_obj)) { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); + } else { + eager_gil_scoped_release guard; + other_tensor = + full_ad_func({1}, value, self_tensor.dtype(), self_tensor.place()); + } + } + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { + ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - } - - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) { - ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } // 3. promote types or unify right var type to left var diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 134b164409a95f..d4047b76b6a683 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -1174,7 +1174,7 @@ def _create_optimization_pass( # need to filter again here. if ( param_and_grad[1] is None - or not param_and_grad[1]._is_initialized() + or not param_and_grad[0]._is_initialized() ): continue if param_and_grad[0].stop_gradient is False: @@ -1185,7 +1185,7 @@ def _create_optimization_pass( for param_and_grad in parameters_and_grads['params']: if ( param_and_grad[1] is None - or not param_and_grad[1]._is_initialized() + or not param_and_grad[0]._is_initialized() ): continue if param_and_grad[0].stop_gradient is False: From 51c869e88eb668dccd864f6766ae973eb1064581 Mon Sep 17 00:00:00 2001 From: Terry <38135104+TR666@users.noreply.github.com> Date: Thu, 4 Jan 2024 18:03:05 +0800 Subject: [PATCH 113/142] [xpu]Add vis_decoder_attention_xpu_pass && modify qkv_attention_xpu_kernel (#60361) --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../ir/xpu/decoder_attention_xpu_fuse_pass.cc | 313 ++++++++++++++++++ .../ir/xpu/decoder_attention_xpu_fuse_pass.h | 104 ++++++ .../ir/xpu/qk_qkv_attention_xpu_fuse_pass.cc | 5 +- .../inference/api/paddle_pass_builder.cc | 1 + paddle/phi/api/yaml/fused_ops.yaml | 2 +- paddle/phi/infermeta/fusion.cc | 16 +- paddle/phi/infermeta/fusion.h | 1 + .../fusion/xpu/qkv_attention_xpu_kernel.cc | 14 +- ...est_xpu_decoder_attention_xpu_fuse_pass.py | 172 ++++++++++ 10 files changed, 617 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.h create mode 100644 test/ir/inference/test_xpu_decoder_attention_xpu_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 35f5ba1522368e..94b13f7c4fa51e 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -272,6 +272,8 @@ if(WITH_XPU) ${XPU_PASS_DEPS}) pass_library(qk_qkv_attention_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(decoder_attention_xpu_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu diff --git a/paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.cc new file mode 100644 index 00000000000000..ad8dd1a55a8686 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.cc @@ -0,0 +1,313 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.h" + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace patterns { + +struct DecoderAttentionFusePattern : public PatternBase { + DecoderAttentionFusePattern(PDPattern* pattern, + const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_3); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_3); + PATTERN_DECL_NODE(qk_matmul); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(qk_softmax); + PATTERN_DECL_NODE(qkv_matmul); + PATTERN_DECL_NODE(transpose2_4); + PATTERN_DECL_NODE(reshape2_4); + + // declare variable node's name + PATTERN_DECL_NODE(input_q); + PATTERN_DECL_NODE(input_k); + PATTERN_DECL_NODE(input_v); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_3_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_3_out); + PATTERN_DECL_NODE(qk_matmul_out); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(qk_softmax_out); + PATTERN_DECL_NODE(qkv_matmul_out); + PATTERN_DECL_NODE(transpose2_4_out); + PATTERN_DECL_NODE(output); +}; + +DecoderAttentionFusePattern::DecoderAttentionFusePattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* input_q = pattern->NewNode(input_q_repr()) + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + auto* input_k = pattern->NewNode(input_k_repr()) + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + auto* input_v = pattern->NewNode(input_v_repr()) + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + auto* reshape2_1_out = pattern->NewNode(reshape2_1_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2", "X"); + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + auto* reshape2_2_out = pattern->NewNode(reshape2_2_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2", "X"); + auto* reshape2_3 = + pattern->NewNode(reshape2_3_repr())->assert_is_op("reshape2"); + auto* reshape2_3_out = pattern->NewNode(reshape2_3_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2", "X"); + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr()) + ->assert_is_op("transpose2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis = op_desc->GetAttrIfExists>("axis"); + size_t axis_rank = axis.size(); + return axis_rank == 4 && axis[0] == 0 && axis[1] == 2 && + axis[2] == 1 && axis[3] == 3; + }); + + auto* transpose2_1_out = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("matmul_v2", "X"); + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr()) + ->assert_is_op("transpose2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis = op_desc->GetAttrIfExists>("axis"); + size_t axis_rank = axis.size(); + return axis_rank == 4 && axis[0] == 0 && axis[1] == 2 && + axis[2] == 1 && axis[3] == 3; + }); + auto* transpose2_2_out = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("matmul_v2", "Y"); + auto* transpose2_3 = + pattern->NewNode(transpose2_3_repr()) + ->assert_is_op("transpose2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis = op_desc->GetAttrIfExists>("axis"); + size_t axis_rank = axis.size(); + return axis_rank == 4 && axis[0] == 0 && axis[1] == 2 && + axis[2] == 1 && axis[3] == 3; + }); + auto* transpose2_3_out = pattern->NewNode(transpose2_3_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("matmul_v2", "Y"); + auto* qk_matmul = + pattern->NewNode(qk_matmul_repr())->assert_is_op("matmul_v2"); + auto* qk_matmul_out = pattern->NewNode(qk_matmul_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_input("scale", "X"); + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out = pattern->NewNode(scale_out_repr()) + ->assert_is_op_output("scale", "Out") + ->assert_is_op_input("softmax", "X"); + auto* qk_softmax = + pattern->NewNode(qk_softmax_repr())->assert_is_op("softmax"); + auto* qk_softmax_out = pattern->NewNode(qk_softmax_out_repr()) + ->assert_is_op_output("softmax", "Out") + ->assert_is_op_input("matmul_v2", "X"); + auto* qkv_matmul = + pattern->NewNode(qkv_matmul_repr())->assert_is_op("matmul_v2"); + auto* qkv_matmul_out = pattern->NewNode(qkv_matmul_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_input("transpose2", "X"); + auto* transpose2_4 = + pattern->NewNode(transpose2_4_repr())->assert_is_op("transpose2"); + auto* transpose2_4_out = pattern->NewNode(transpose2_4_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2", "X"); + auto* reshape2_4 = + pattern->NewNode(reshape2_4_repr())->assert_is_op("reshape2"); + auto* output = pattern->NewNode(output_repr()) + ->AsOutput() + ->assert_is_op_output("reshape2", "Out"); + + // link nodes + reshape2_1->LinksFrom({input_q}).LinksTo({reshape2_1_out}); + reshape2_2->LinksFrom({input_k}).LinksTo({reshape2_2_out}); + reshape2_3->LinksFrom({input_v}).LinksTo({reshape2_3_out}); + transpose2_1->LinksFrom({reshape2_1_out}).LinksTo({transpose2_1_out}); + transpose2_2->LinksFrom({reshape2_2_out}).LinksTo({transpose2_2_out}); + transpose2_3->LinksFrom({reshape2_3_out}).LinksTo({transpose2_3_out}); + qk_matmul->LinksFrom({transpose2_1_out, transpose2_2_out}) + .LinksTo({qk_matmul_out}); + scale->LinksFrom({qk_matmul_out}).LinksTo({scale_out}); + qk_softmax->LinksFrom({scale_out}).LinksTo({qk_softmax_out}); + qkv_matmul->LinksFrom({qk_softmax_out, transpose2_3_out}) + .LinksTo({qkv_matmul_out}); + transpose2_4->LinksFrom({qkv_matmul_out}).LinksTo({transpose2_4_out}); + reshape2_4->LinksFrom({transpose2_4_out}).LinksTo({output}); +} + +} // namespace patterns + +void DecoderAttentionXPUFusePass::ApplyDecoderAttentionXPUFuse( + ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::DecoderAttentionFusePattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle DecoderAttentionXPUFusePass"; + + // declare operator node's name + GET_IR_NODE(reshape2_1); + GET_IR_NODE(reshape2_2); + GET_IR_NODE(reshape2_3); + GET_IR_NODE(transpose2_1); + GET_IR_NODE(transpose2_2); + GET_IR_NODE(transpose2_3); + GET_IR_NODE(qk_matmul); + GET_IR_NODE(scale); + GET_IR_NODE(qk_softmax); + GET_IR_NODE(qkv_matmul); + GET_IR_NODE(transpose2_4); + GET_IR_NODE(reshape2_4); + + // declare variable node's name + GET_IR_NODE(input_q); + GET_IR_NODE(input_k); + GET_IR_NODE(input_v); + GET_IR_NODE(reshape2_1_out); + GET_IR_NODE(reshape2_2_out); + GET_IR_NODE(reshape2_3_out); + GET_IR_NODE(transpose2_1_out); + GET_IR_NODE(transpose2_2_out); + GET_IR_NODE(transpose2_3_out); + GET_IR_NODE(qk_matmul_out); + GET_IR_NODE(scale_out); + GET_IR_NODE(qk_softmax_out); + GET_IR_NODE(qkv_matmul_out); + GET_IR_NODE(transpose2_4_out); + GET_IR_NODE(output); + + // Generate fuse op + auto* block = reshape2_1->Op()->Block(); + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("qkv_attention_xpu"); + + // set input of fuse_op + fused_op_desc.SetInput("q", {input_q->Name()}); + fused_op_desc.SetInput("k", {input_k->Name()}); + fused_op_desc.SetInput("v", {input_v->Name()}); + + // set attributes of fuse_op + float scale_val = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + fused_op_desc.SetAttr("alpha", scale_val); + fused_op_desc.SetAttr( + "head_num", static_cast(transpose2_1_out->Var()->GetShape()[1])); + fused_op_desc.SetAttr( + "head_dim", static_cast(transpose2_1_out->Var()->GetShape()[3])); + // In this pattern, there is only one possible situation. + fused_op_desc.SetAttr("qkv_fc_fusion", false); + + // TODO(tianrui): support more out_dtype + fused_op_desc.SetAttr("out_dtype", input_q->Var()->GetDataType()); + + // set output of fuse_op + VarDesc fused_op_out_max_desc("qkv_max"); + Node* fused_op_out_max = graph->CreateVarNode(&fused_op_out_max_desc); + fused_op_desc.SetOutput("qkv_max", {"qkv_max"}); + fused_op_desc.SetOutput("qkv", {output->Name()}); + + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + + IR_NODE_LINK_TO(input_q, fused_op); + IR_NODE_LINK_TO(input_k, fused_op); + IR_NODE_LINK_TO(input_v, fused_op); + IR_NODE_LINK_TO(fused_op, output); + IR_NODE_LINK_TO(fused_op, fused_op_out_max); + + // delete useless node + std::unordered_set del_node_set; + del_node_set.insert(reshape2_1); + del_node_set.insert(reshape2_2); + del_node_set.insert(reshape2_3); + del_node_set.insert(transpose2_1); + del_node_set.insert(transpose2_2); + del_node_set.insert(transpose2_3); + del_node_set.insert(qk_matmul); + del_node_set.insert(scale); + del_node_set.insert(qk_softmax); + del_node_set.insert(qkv_matmul); + del_node_set.insert(transpose2_4); + del_node_set.insert(reshape2_4); + del_node_set.insert(reshape2_1_out); + del_node_set.insert(reshape2_2_out); + del_node_set.insert(reshape2_3_out); + del_node_set.insert(transpose2_1_out); + del_node_set.insert(transpose2_2_out); + del_node_set.insert(transpose2_3_out); + del_node_set.insert(qk_matmul_out); + del_node_set.insert(scale_out); + del_node_set.insert(qk_softmax_out); + del_node_set.insert(qkv_matmul_out); + del_node_set.insert(transpose2_4_out); + + GraphSafeRemoveNodes(graph, del_node_set); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void DecoderAttentionXPUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + ApplyDecoderAttentionXPUFuse(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(decoder_attention_xpu_fuse_pass, + paddle::framework::ir::DecoderAttentionXPUFusePass); + +REGISTER_PASS_CAPABILITY(decoder_attention_xpu_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "qkv_attention_xpu", 0)); diff --git a/paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.h new file mode 100644 index 00000000000000..c41e455f2acc48 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.h @@ -0,0 +1,104 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +/* +This pass is used to fuse the QKV attention subgraph into one op in decoder +module of visual models . + +Origin subgraph: + + v q k + | | | + | | | + | | | + reshape reshape reshape + | | | + | | | + | | | + transpose transpose transpose + | | | + | \ / + | \ / + | qk_matmul + | | + | | + | | + | scale + | | + | | + | | + \ qk_softmax + \ | + \ / + \ / + qkv_matmul + | + | + | + transpose + | + | + | + reshape + | + | + | + output + +------------------------------------------------------- +Fused subgraph: + q k v + \ | / + \ | / + \ | / + qkv_attention_xpu + | + | + | + output + +*/ + +class DecoderAttentionXPUFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void ApplyDecoderAttentionXPUFuse(ir::Graph* graph) const; + + const std::string name_scope_{"vis_decoder_attention_xpu_fuse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/qk_qkv_attention_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/qk_qkv_attention_xpu_fuse_pass.cc index 7cbb12de5d2e74..2ca1d081aab89d 100644 --- a/paddle/fluid/framework/ir/xpu/qk_qkv_attention_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/qk_qkv_attention_xpu_fuse_pass.cc @@ -232,10 +232,13 @@ void QkQkvAttentionXPUFusePass::ApplyQkQkvAttentionXPUFuse( "head_num", static_cast(transpose2_1_out->Var()->GetShape()[2])); fused_op_desc.SetAttr( "head_dim", static_cast(transpose2_1_out->Var()->GetShape()[4])); + // In this pattern, there is only one possible situation. + fused_op_desc.SetAttr("qkv_fc_fusion", true); + // TODO(tianrui): support more out_dtype fused_op_desc.SetAttr("out_dtype", input->Var()->GetDataType()); - // set input of fuse_op + // set output of fuse_op VarDesc fused_op_out_max_desc("qkv_max"); Node* fused_op_out_max = graph->CreateVarNode(&fused_op_out_max_desc); fused_op_desc.SetOutput("qkv_max", {"qkv_max"}); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 0a0e6b591ef899..fa979a46a19627 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -543,6 +543,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "multi_encoder_xpu_slice_fuse_pass", "fused_multi_transformer_cachekv_layout_trans_pass", "fused_multi_transformer_int8_cachekv_layout_trans_pass", + "decoder_attention_xpu_fuse_pass", "one_beam_size_fuse_pass", "fold_interp_outsize_fuse_pass", "fold_two_squeeze2_fuse_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index f1d253945139ed..1b429fc958de7e 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -404,7 +404,7 @@ optional : bias_qk - op : qkv_attention_xpu - args : (Tensor q, Tensor k, Tensor v, Tensor q_max, Tensor k_max, Tensor v_max, float alpha, int head_num, int head_dim, DataType out_dtype) + args : (Tensor q, Tensor k, Tensor v, Tensor q_max, Tensor k_max, Tensor v_max, float alpha, int head_num, int head_dim, bool qkv_fc_fusion, DataType out_dtype) output : Tensor(qkv), Tensor(qkv_max) infer_meta : func : QKVAttentionXPUInferMeta diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 41329efaa86d53..d25c41e4538942 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -3640,13 +3640,14 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, float alpha, int head_num, int head_dim, + bool qkv_fc_fusion, DataType out_dtype, MetaTensor* qkv, MetaTensor* qkv_max) { auto q_dims = q.dims(); auto k_dims = k.dims(); auto v_dims = v.dims(); - // input shape : {B, L, 3*H*D} + // input shape : {B, L, 3*H*D} or {B, L, H*D} PADDLE_ENFORCE_EQ(q_dims.size(), 3, phi::errors::InvalidArgument("The dim of q should be 3! " @@ -3671,13 +3672,16 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, phi::errors::InvalidArgument("The shape of k , v should be the same! " "But received .")); } + int hidden_dim = + qkv_fc_fusion ? 3 * head_num * head_dim : head_num * head_dim; PADDLE_ENFORCE_EQ( q_dims[2], - 3 * head_num * head_dim, - phi::errors::InvalidArgument("To support do_fc_qkv_fusion," - "The shape of q should be [B, L, 3*H*D]! " - "But received q_dims[2]: [%d] != 3*H*D.", - q_dims[2])); + hidden_dim, + phi::errors::InvalidArgument( + "The shape of q should be [B, L, H*D] or [B, L, 3*H*D]! " + "But received q_dims[2]: [%d] != expected hidden_dim : [%d].", + q_dims[2], + hidden_dim)); // output shape: {B, L, HD} qkv->set_dims(phi::make_ddim({q_dims[0], q_dims[1], head_num * head_dim})); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index e294e67aa1c951..c5d3981b66f40e 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -831,6 +831,7 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, float alpha, int head_num, int head_dim, + bool qkv_fc_fusion, DataType out_dtype, MetaTensor* qkv, MetaTensor* qkv_max); diff --git a/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc index 02ff19da4e259d..b08921e750a80c 100644 --- a/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc @@ -32,6 +32,7 @@ void QKVAttentionXPUKernelImpl(const Context& ctx, float alpha, int head_num, int head_dim, + bool qkv_fc_fusion, DenseTensor* qkv, DenseTensor* qkv_max) { using XPUTypeX = typename XPUTypeTrait::Type; @@ -40,8 +41,10 @@ void QKVAttentionXPUKernelImpl(const Context& ctx, auto* q_data = reinterpret_cast(q.data()); auto* k_data = reinterpret_cast(k.data()); auto* v_data = reinterpret_cast(v.data()); - k_data += head_num * head_dim; - v_data += 2 * head_num * head_dim; + if (qkv_fc_fusion) { + k_data += head_num * head_dim; + v_data += 2 * head_num * head_dim; + } const float* q_max_data = q_max.get_ptr() == nullptr ? nullptr : q_max.get_ptr()->data(); const float* k_max_data = @@ -54,9 +57,7 @@ void QKVAttentionXPUKernelImpl(const Context& ctx, auto* qkv_max_data = ctx.template Alloc(qkv_max); int batch = q.dims()[0]; int max_seq_len = q.dims()[1]; - // int qkv_shape = 2; // B x H x L x D int qkv_shape = 0; // B x L x H x D - bool do_fc_qkv_fusion = true; int hidden_dim = head_num * head_dim; // no mask input, construct a fake LOD to compute via vsl std::vector lod; @@ -69,7 +70,8 @@ void QKVAttentionXPUKernelImpl(const Context& ctx, qkv_attn_param.qkv_shape = qkv_shape; qkv_attn_param.hidden_dim = hidden_dim; qkv_attn_param.alpha = alpha; - qkv_attn_param.do_fc_qkv_fusion = do_fc_qkv_fusion; + qkv_attn_param.do_fc_qkv_fusion = qkv_fc_fusion; + // TODO(tianrui): ctrl by env // This feature may cause precision diff, // but it is more efficient, especially in long seqL cases @@ -126,6 +128,7 @@ void QKVAttentionXPUKernelImpl(const Context& ctx, alpha, \ head_num, \ head_dim, \ + qkv_fc_fusion, \ qkv, \ qkv_max); @@ -140,6 +143,7 @@ void QKVAttentionXPUKernel(const Context& ctx, float alpha, int head_num, int head_dim, + bool qkv_fc_fusion, DataType qkv_dtype, DenseTensor* qkv, DenseTensor* qkv_max) { diff --git a/test/ir/inference/test_xpu_decoder_attention_xpu_fuse_pass.py b/test/ir/inference/test_xpu_decoder_attention_xpu_fuse_pass.py new file mode 100644 index 00000000000000..a7c5e069e23d1a --- /dev/null +++ b/test/ir/inference/test_xpu_decoder_attention_xpu_fuse_pass.py @@ -0,0 +1,172 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import unittest + +import hypothesis.strategies as st +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestDecoderAttentionXPUFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["qkv_attention_xpu"], (1e-1, 1e-1) + + def sample_program_config(self, draw): + # set input shape + batch_size = draw(st.integers(min_value=1, max_value=50)) + seqlen = draw(st.integers(min_value=100, max_value=2000)) + input_shape = [batch_size, seqlen, 256] + + # Here we will compose a program + # Still has some risks that the program is invalid or cause bug while running + # Use function `is_program_valid` to filter the invalid programs before running + # Use function `add_skip_pass_case` to ignore the programs even if they cause bug while runing + reshape2_1_op = OpConfig( + "reshape2", + inputs={"X": ["input_q"]}, + outputs={ + "Out": ["reshape2_1_out"], + "XShape": ["reshape2_1_xshape"], + }, + shape=[0, 0, 8, 32], + ) + reshape2_2_op = OpConfig( + "reshape2", + inputs={"X": ["input_k"]}, + outputs={ + "Out": ["reshape2_2_out"], + "XShape": ["reshape2_2_xshape"], + }, + shape=[0, 0, 8, 32], + ) + reshape2_3_op = OpConfig( + "reshape2", + inputs={"X": ["input_v"]}, + outputs={ + "Out": ["reshape2_3_out"], + "XShape": ["reshape2_3_xshape"], + }, + shape=[0, 0, 8, 32], + ) + transpose2_1_op = OpConfig( + "transpose2", + inputs={"X": ["reshape2_1_out"]}, + outputs={ + "Out": ["transpose2_1_out"], + "XShape": ["transpose2_1_xshape"], + }, + axis=[0, 2, 1, 3], + ) + transpose2_2_op = OpConfig( + "transpose2", + inputs={"X": ["reshape2_2_out"]}, + outputs={ + "Out": ["transpose2_2_out"], + "XShape": ["transpose2_2_xshape"], + }, + axis=[0, 2, 1, 3], + ) + transpose2_3_op = OpConfig( + "transpose2", + inputs={"X": ["reshape2_3_out"]}, + outputs={ + "Out": ["transpose2_3_out"], + "XShape": ["transpose2_3_xshape"], + }, + axis=[0, 2, 1, 3], + ) + qk_matmul_op = OpConfig( + "matmul_v2", + inputs={"X": ["transpose2_1_out"], "Y": ["transpose2_2_out"]}, + outputs={"Out": ["qk_matmul_out"]}, + trans_x=False, + trans_y=True, + ) + scale_op = OpConfig( + "scale", + inputs={"X": ["qk_matmul_out"]}, + outputs={"Out": ["scale_out"]}, + scale=1 / math.sqrt(32), + bias=0, + bias_after_scale=True, + ) + qk_softmax_op = OpConfig( + "softmax", + inputs={"X": ["scale_out"]}, + outputs={"Out": ["qk_softmax_out"]}, + axis=-1, + ) + qkv_matmul_op = OpConfig( + "matmul_v2", + inputs={"X": ["qk_softmax_out"], "Y": ["transpose2_3_out"]}, + outputs={"Out": ["qkv_matmul_out"]}, + trans_x=False, + trans_y=False, + ) + transpose2_4_op = OpConfig( + "transpose2", + inputs={"X": ["qkv_matmul_out"]}, + outputs={ + "Out": ["transpose2_4_out"], + "XShape": ["transpose2_4_xshape"], + }, + axis=[0, 2, 1, 3], + ) + reshape2_4_op = OpConfig( + "reshape2", + inputs={"X": ["transpose2_4_out"]}, + outputs={"Out": ["output"], "XShape": ["reshape2_4_xshape"]}, + shape=[0, 0, 256], + ) + + ops = [ + reshape2_1_op, + reshape2_2_op, + reshape2_3_op, + transpose2_1_op, + transpose2_2_op, + transpose2_3_op, + qk_matmul_op, + scale_op, + qk_softmax_op, + qkv_matmul_op, + transpose2_4_op, + reshape2_4_op, + ] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "input_q": TensorConfig(shape=input_shape), + "input_k": TensorConfig(shape=input_shape), + "input_v": TensorConfig(shape=input_shape), + }, + weights={}, + outputs=["output"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["decoder_attention_xpu_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main() From 0033033be6df47731221cb0a94a8612d0415d424 Mon Sep 17 00:00:00 2001 From: kevin Date: Thu, 4 Jan 2024 18:50:04 +0800 Subject: [PATCH 114/142] [Prim][PIR] support abs, instance_norm op backward in prim pir (#60444) * abs op backward * add test case * update code * update code * update code * update code * update code * instance_norm op backward * add instance_norm_v2 test cast * custom op --- paddle/fluid/primitive/codegen/gen.py | 2 + paddle/fluid/primitive/rule/vjp/details.h | 100 +++++++++++++++++++ test/legacy_test/test_activation_op.py | 21 +++- test/legacy_test/test_instance_norm_op.py | 6 +- test/legacy_test/test_instance_norm_op_v2.py | 9 ++ 5 files changed, 133 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 005eae29593434..7a80f05784b0ff 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -55,6 +55,7 @@ # prim op with one input and one output, with no attribute UNARY_PRIM_VJP_OPS = [ + 'abs_grad', 'erf_grad', 'exp_grad', 'floor_grad', @@ -103,6 +104,7 @@ 'dropout_grad', 'gelu_grad', 'hardswish_grad', + 'instance_norm_grad', 'layer_norm_grad', 'leaky_relu_grad', 'relu_grad', diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 1be68ba043e19f..716f8f7040aa4f 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -29,6 +29,14 @@ namespace paddle { namespace primitive { namespace details { +template +void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + auto sign_tmp = sign(x); + set_output(out_grad * sign_tmp, x_grad); + } +} + template void assign_grad(const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { @@ -930,6 +938,98 @@ void gather_nd_grad(const Tensor& x, } } +template +void instance_norm_grad(const Tensor& x, + const paddle::optional& scale, + const Tensor& saved_mean, + const Tensor& saved_variance, + const Tensor& y_grad, + float epsilon, + Tensor* x_grad, + Tensor* scale_grad, + Tensor* bias_grad) { + const int n = x.dims()[0]; + const int c = x.dims()[1]; + const int h = x.dims()[2]; + const int w = x.dims()[3]; + + auto promoted_y_grad = y_grad; + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + promoted_y_grad = cast(y_grad, phi::DataType::FLOAT32); + } + + Tensor x_hat; + Tensor std_inv; + if (scale_grad || x_grad) { + auto promoted_x = x; + auto promoted_saved_mean = saved_mean; + auto promoted_saved_var = saved_variance; + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + promoted_x = cast(x, phi::DataType::FLOAT32); + promoted_saved_mean = cast(saved_mean, phi::DataType::FLOAT32); + promoted_saved_var = cast(saved_variance, phi::DataType::FLOAT32); + } + auto mean = reshape(promoted_saved_mean, IntArray({n, c, 1, 1})) + .tile(IntArray({1, 1, h, w})); + std_inv = reshape(promoted_saved_var, IntArray({n, c, 1, 1})) + .tile(IntArray({1, 1, h, w})); + x_hat = (promoted_x - mean) * std_inv; + } + + // x_grad = scale * inv_var * (y_grad - y_grad.mean(2,3) - x_hat * (y_grad * + // x_hat).mean((h,w))) + if (x_grad) { + auto scale_data = + reshape(scale.get_ptr() ? scale.get() + : full(IntArray({c}), 1., x.dtype()), + IntArray({1, c, 1, 1})) + .tile(IntArray({n, 1, h, w})); + auto promoted_scale = scale_data; + if (scale_data.dtype() == phi::DataType::FLOAT16 || + scale_data.dtype() == phi::DataType::BFLOAT16) { + promoted_scale = cast(scale_data, phi::DataType::FLOAT32); + } + auto result = + (promoted_scale * std_inv) * + (promoted_y_grad - + promoted_y_grad.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) / + (h * w) - + (x_hat * ((promoted_y_grad * x_hat) + .sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) / + (h * w)))); + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + set_output(cast(result, x.dtype()), x_grad); + } else { + set_output(result, x_grad); + } + } + // scale_grad = x_hat * y_grad.sum(n, h, w) + if (scale_grad) { + auto result = (promoted_y_grad * x_hat).sum(IntArray({0, 2, 3})); + auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype(); + if (scale_dtype == phi::DataType::FLOAT16 || + scale_dtype == phi::DataType::BFLOAT16) { + set_output(cast(result, scale_dtype), scale_grad); + } else { + set_output(result, scale_grad); + } + } + // d_bias = y_grad.sum(n, h, w) + if (bias_grad) { + auto result = promoted_y_grad.sum(IntArray({0, 2, 3})); + auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype(); + if (scale_dtype == phi::DataType::FLOAT16 || + scale_dtype == phi::DataType::BFLOAT16) { + set_output(cast(result, scale_dtype), bias_grad); + } else { + set_output(result, bias_grad); + } + } +} + template void pad_grad(const Tensor& input, const Tensor& out_grad, diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index f25a9ce3a78dcc..8a9379f528c1eb 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -474,7 +474,12 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_pir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -1803,7 +1808,9 @@ def test_check_output(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) + self.check_grad( + ['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True + ) class TestAbs_ZeroDim(TestAbs): @@ -4852,7 +4859,11 @@ def test_check_grad(self): check_prim_pir=True, ) create_test_act_fp16_class( - TestAbs, check_prim=True, enable_cinn=True, check_pir=True + TestAbs, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, ) create_test_act_fp16_class(TestCeil, grad_check=False, check_pir=True) create_test_act_fp16_class( @@ -5019,7 +5030,9 @@ def test_check_grad(self): create_test_act_bf16_class( TestSqrtComp, check_prim=True, check_pir=True, check_prim_pir=True ) -create_test_act_bf16_class(TestAbs, check_prim=True, check_pir=True) +create_test_act_bf16_class( + TestAbs, check_prim=True, check_pir=True, check_prim_pir=True +) create_test_act_bf16_class(TestCeil, grad_check=False, check_pir=True) create_test_act_bf16_class( TestFloor, diff --git a/test/legacy_test/test_instance_norm_op.py b/test/legacy_test/test_instance_norm_op.py index 9910a9c04db6d5..c5fd7af6b48799 100644 --- a/test/legacy_test/test_instance_norm_op.py +++ b/test/legacy_test/test_instance_norm_op.py @@ -134,7 +134,11 @@ def test_check_output(self): def test_check_grad(self): self.check_grad( - ['X', 'Scale', 'Bias'], 'Y', check_prim=True, check_pir=True + ['X', 'Scale', 'Bias'], + 'Y', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) def init_test_case(self): diff --git a/test/legacy_test/test_instance_norm_op_v2.py b/test/legacy_test/test_instance_norm_op_v2.py index c05deb1feaf0fb..fe8e26aaec7839 100644 --- a/test/legacy_test/test_instance_norm_op_v2.py +++ b/test/legacy_test/test_instance_norm_op_v2.py @@ -229,6 +229,9 @@ def test_check_grad(self): 'Y', check_prim=self.check_prim, check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def init_dtype(self): @@ -284,6 +287,9 @@ def test_check_grad(self): max_relative_error=self.max_relative_error, check_prim=self.check_prim, check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) @@ -356,6 +362,9 @@ def test_check_grad(self): user_defined_grads=self.user_defined_grads, check_prim=self.check_prim, check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) From c2dd2025a10c93c8511c74d8f6594e5053798151 Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Thu, 4 Jan 2024 19:33:06 +0800 Subject: [PATCH 115/142] [PIR] remove log simply name mechnism from phi to common. (#60507) --- paddle/common/enforce.cc | 103 ++++++++++++++++++ paddle/common/enforce.h | 74 +++++-------- paddle/fluid/framework/CMakeLists.txt | 5 +- paddle/fluid/framework/type_defs.cc | 45 ++++++++ paddle/fluid/imperative/CMakeLists.txt | 2 +- paddle/fluid/imperative/type_defs.cc | 21 ++++ paddle/fluid/memory/stats.h | 6 +- paddle/fluid/platform/enforce.h | 1 + paddle/fluid/platform/init.cc | 4 +- paddle/phi/core/enforce.cc | 144 ++----------------------- paddle/phi/core/enforce.h | 23 ++-- paddle/utils/variant_test.cc | 2 +- 12 files changed, 221 insertions(+), 209 deletions(-) create mode 100644 paddle/common/enforce.cc create mode 100644 paddle/fluid/framework/type_defs.cc create mode 100644 paddle/fluid/imperative/type_defs.cc diff --git a/paddle/common/enforce.cc b/paddle/common/enforce.cc new file mode 100644 index 00000000000000..22161f890eb651 --- /dev/null +++ b/paddle/common/enforce.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/common/enforce.h" + +#include +#include +#include +#include + +REGISTER_LOG_SIMPLY_STR(std::string); + +namespace { +class StrSizeCmp { + public: + bool operator()(const std::string& lhs, const std::string& rhs) const { + return lhs.size() > rhs.size(); + } +}; + +using LogSimplyStrMap = std::map; + +LogSimplyStrMap& GetLogStrSimplyMap() { + static LogSimplyStrMap str_simply_map; + return str_simply_map; +} + +std::string SimplifyDemangleStr(std::string str) { + auto& str_map = GetLogStrSimplyMap(); + for (auto& value : str_map) { + size_t start_pos = 0; + while ((start_pos = str.find(value.first, start_pos)) != + std::string::npos) { + str.replace(start_pos, value.first.length(), value.second); + start_pos += value.second.length(); + } + } + return str; +} +} // namespace + +namespace common { +namespace enforce { + +bool RegisterLogSimplyStr(const std::string& type_name, + const std::string& simply_name) { + return GetLogStrSimplyMap() + .emplace(std::make_pair(type_name, simply_name)) + .second; +} + +std::string GetCurrentTraceBackString(bool for_signal) { + std::ostringstream sout; + + if (!for_signal) { + sout << "\n\n--------------------------------------\n"; + sout << "C++ Traceback (most recent call last):"; + sout << "\n--------------------------------------\n"; + } +#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) + static constexpr int TRACE_STACK_LIMIT = 100; + + std::array call_stack; + auto size = backtrace(call_stack.data(), TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack.data(), size); + Dl_info info; + int idx = 0; + // `for_signal` used to remove the stack trace introduced by + // obtaining the error stack trace when the signal error occurred, + // that is not related to the signal error self, remove it to + // avoid misleading users and developers + int end_idx = for_signal ? 2 : 0; + for (int i = size - 1; i >= end_idx; --i) { + if (dladdr(call_stack[i], &info) && info.dli_sname) { + auto demangled = common::demangle(info.dli_sname); + std::string path(info.dli_fname); + // C++ traceback info are from core.so + if (path.substr(path.length() - 3).compare(".so") == 0) { + sout << paddle::string::Sprintf( + "%-3d %s\n", idx++, SimplifyDemangleStr(demangled)); + } + } + } + free(symbols); // NOLINT +#else + sout << "Not support stack backtrace yet.\n"; +#endif + return sout.str(); +} + +} // namespace enforce +} // namespace common diff --git a/paddle/common/enforce.h b/paddle/common/enforce.h index 13e33c7e32a76c..b734c90d0672bc 100644 --- a/paddle/common/enforce.h +++ b/paddle/common/enforce.h @@ -31,6 +31,7 @@ #include "paddle/common/errors.h" #include "paddle/common/macros.h" +#include "paddle/utils/test_macros.h" #if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) #include @@ -40,10 +41,20 @@ #define GLOG_NO_ABBREVIATED_SEVERITIES #include "paddle/utils/string/printf.h" #include "paddle/utils/string/to_string.h" -#include "paddle/utils/test_macros.h" #include "paddle/utils/variant.h" namespace common { +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif + class CommonNotMetException : public std::exception { public: explicit CommonNotMetException(const std::string& str) : err_str_(str) {} @@ -53,9 +64,7 @@ class CommonNotMetException : public std::exception { private: std::string err_str_; }; -} // namespace common -namespace common { namespace enforce { /** HELPER MACROS AND FUNCTIONS **/ @@ -161,6 +170,20 @@ using CommonType2 = typename std::add_lvalue_reference< #define COMMON_ENFORCE_LE(__VAL0, __VAL1, ...) \ __COMMON_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) +TEST_API bool RegisterLogSimplyStr(const std::string& type, + const std::string& simply); +TEST_API std::string GetCurrentTraceBackString(bool for_signal = false); +template +class LogSimplyStrRegistrar { + public: + static bool success; +}; + +#define REGISTER_LOG_SIMPLY_STR(Type) \ + template <> \ + bool ::common::enforce::LogSimplyStrRegistrar::success = \ + ::common::enforce::RegisterLogSimplyStr( \ + ::common::demangle(typeid(Type).name()), #Type); } // namespace enforce } // namespace common @@ -172,53 +195,10 @@ inline bool is_error(const T& stat) { } namespace pir { - -#ifdef __GNUC__ -inline std::string demangle(std::string name) { - int status = -4; // some arbitrary value to eliminate the compiler warning - std::unique_ptr res{ - abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; - return (status == 0) ? res.get() : name; -} -#else -inline std::string demangle(std::string name) { return name; } -#endif - -static std::string GetCurrentTraceBackString() { - std::ostringstream sout; - sout << "\n\n--------------------------------------\n"; - sout << "C++ Traceback (most recent call last):"; - sout << "\n--------------------------------------\n"; -#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) - static constexpr int TRACE_STACK_LIMIT = 100; - - void* call_stack[TRACE_STACK_LIMIT]; - auto size = backtrace(call_stack, TRACE_STACK_LIMIT); - auto symbols = backtrace_symbols(call_stack, size); - Dl_info info; - int idx = 0; - int end_idx = 0; - for (int i = size - 1; i >= end_idx; --i) { - if (dladdr(call_stack[i], &info) && info.dli_sname) { - auto demangled = demangle(info.dli_sname); - std::string path(info.dli_fname); - // C++ traceback info are from core.so - if (path.substr(path.length() - 3).compare(".so") == 0) { - sout << idx++ << " " << demangled << "\n"; - } - } - } - free(symbols); -#else - sout << "Not support stack backtrace yet.\n"; -#endif - return sout.str(); -} - class IrNotMetException : public std::exception { public: explicit IrNotMetException(const std::string& str) - : err_str_(str + GetCurrentTraceBackString()) {} + : err_str_(str + ::common::enforce::GetCurrentTraceBackString()) {} const char* what() const noexcept override { return err_str_.c_str(); } diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 338130c64d9a06..6e70167a853b69 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -947,7 +947,10 @@ cc_library( imperative_flag layer) -cc_library(type_info SRCS type_info.cc) +cc_library( + type_info + SRCS type_info.cc type_defs.cc + DEPS common) target_link_libraries(type_info pir op_dialect) add_dependencies(type_info framework_proto auto_parallel_proto xxhash) if(WITH_MKLDNN) diff --git a/paddle/fluid/framework/type_defs.cc b/paddle/fluid/framework/type_defs.cc new file mode 100644 index 00000000000000..d8a6546ea718d6 --- /dev/null +++ b/paddle/fluid/framework/type_defs.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/type_defs.h" + +#include "paddle/common/enforce.h" + +namespace paddle { + +using namespace framework; // NOLINT +template class variant, + std::vector, + std::vector, + bool, + std::vector, + BlockDesc*, + int64_t, + std::vector, + std::vector, + std::vector, + VarDesc*, + std::vector, + double, + paddle::experimental::Scalar, + std::vector, + ::pir::Block*, + std::vector<::pir::Value>>; +} // namespace paddle +REGISTER_LOG_SIMPLY_STR(paddle::framework::AttributeMap); +REGISTER_LOG_SIMPLY_STR(paddle::framework::Attribute); diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 7a764f5302021b..86688213ef1867 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -4,7 +4,7 @@ cc_library( DEPS phi common) cc_library( var_helper - SRCS var_helper.cc + SRCS var_helper.cc type_defs.cc DEPS tensor phi common) if(WITH_XPU) cc_library( diff --git a/paddle/fluid/imperative/type_defs.cc b/paddle/fluid/imperative/type_defs.cc new file mode 100644 index 00000000000000..fa4a327f8a21ec --- /dev/null +++ b/paddle/fluid/imperative/type_defs.cc @@ -0,0 +1,21 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/imperative/type_defs.h" + +#include "paddle/common/enforce.h" + +REGISTER_LOG_SIMPLY_STR(paddle::imperative::NameTensorMap); +REGISTER_LOG_SIMPLY_STR(paddle::imperative::NameVarBaseMap); +REGISTER_LOG_SIMPLY_STR(paddle::imperative::NameVariableWrapperMap); diff --git a/paddle/fluid/memory/stats.h b/paddle/fluid/memory/stats.h index d2c8b04bc70ab7..2cecfb16c3e014 100644 --- a/paddle/fluid/memory/stats.h +++ b/paddle/fluid/memory/stats.h @@ -77,8 +77,7 @@ class Stat : public StatBase { thread_local_stat->current += increment; VLOG(8) << string::split_string( - phi::enforce::demangle(typeid(*thread_local_stat).name()), - "::") + common::demangle(typeid(*thread_local_stat).name()), "::") .back() << ": Update current_value with " << increment << ", after update, current value = " << GetCurrentValue(); @@ -91,8 +90,7 @@ class Stat : public StatBase { !peak_value_.compare_exchange_weak(prev_value, current_value)) { } VLOG(8) << string::split_string( - phi::enforce::demangle(typeid(*thread_local_stat).name()), - "::") + common::demangle(typeid(*thread_local_stat).name()), "::") .back() << ": Update current_value with " << increment << ", after update, peak_value = " << peak_value_.load() diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index 1a82b05f3bc3af..c7ee1707c42868 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -108,6 +108,7 @@ PHI_DECLARE_int32(call_stack_level); namespace paddle { namespace platform { using namespace ::phi::enforce; // NOLINT +using ::common::demangle; /** HELPER MACROS AND FUNCTIONS **/ diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index a3fff528f7903e..62353a415617cc 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -51,6 +51,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device/ipu/ipu_info.h" #endif +#include "paddle/common/enforce.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/flags.h" @@ -310,7 +311,8 @@ void SignalHandle(const char *data, int size) { sout << "\n\n--------------------------------------\n"; sout << "C++ Traceback (most recent call last):"; sout << "\n--------------------------------------\n"; - auto traceback = platform::GetCurrentTraceBackString(/*for_signal=*/true); + auto traceback = + ::common::enforce::GetCurrentTraceBackString(/*for_signal=*/true); if (traceback.empty()) { sout << "No stack trace in paddle, may be caused by external reasons.\n"; diff --git a/paddle/phi/core/enforce.cc b/paddle/phi/core/enforce.cc index 5d4041738bc4fc..979b147c6b3e19 100644 --- a/paddle/phi/core/enforce.cc +++ b/paddle/phi/core/enforce.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include #include "glog/logging.h" +#include "paddle/common/enforce.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/blank.h" #include "paddle/utils/flags.h" @@ -30,146 +31,11 @@ limitations under the License. */ #endif // PADDLE_WITH_CUDA PD_DECLARE_int32(call_stack_level); - -namespace egr { -class EagerVariable; -} -namespace paddle { -namespace framework { -class VarDesc; -class BlockDesc; -using Attribute = paddle::variant, - std::vector, - std::vector, - bool, - std::vector, - BlockDesc*, - int64_t, - std::vector, - std::vector, - std::vector, - VarDesc*, - std::vector, - double, - paddle::experimental::Scalar, - std::vector>; -using AttributeMap = std::unordered_map; -} // namespace framework -namespace imperative { -class VariableWrapper; -class SavedVariableWrapperList; -class VarBase; - -namespace details { -template -struct NameVarMapTrait {}; - -template <> -struct NameVarMapTrait { - using Type = std::map>>; -}; - -template <> -struct NameVarMapTrait { - using Type = std::map; -}; - -template <> -struct NameVarMapTrait { - using Type = - std::map>>; -}; - -} // namespace details - -template -using NameVarMap = typename details::NameVarMapTrait::Type; - -using NameVarBaseMap = NameVarMap; -using NameVariableWrapperMap = NameVarMap; -using NameTensorMap = NameVarMap; - -} // namespace imperative -} // namespace paddle - namespace phi { namespace enforce { TEST_API int GetCallStackLevel() { return FLAGS_call_stack_level; } -template -static std::string ReplaceComplexTypeStr(std::string str, - const std::string& type_name) { - auto demangle_type_str = demangle(typeid(T).name()); - size_t start_pos = 0; - while ((start_pos = str.find(demangle_type_str, start_pos)) != - std::string::npos) { - str.replace(start_pos, demangle_type_str.length(), type_name); - start_pos += type_name.length(); - } - return str; -} - -#define __REPLACE_COMPLEX_TYPE_STR__(__TYPENAME, __STR) \ - do { \ - __STR = \ - phi::enforce::ReplaceComplexTypeStr<__TYPENAME>(__STR, #__TYPENAME); \ - } while (0) - -static std::string SimplifyDemangleStr(std::string str) { - // the older is important, you have to put complex types in front - __REPLACE_COMPLEX_TYPE_STR__(paddle::framework::AttributeMap, str); - __REPLACE_COMPLEX_TYPE_STR__(paddle::framework::Attribute, str); - __REPLACE_COMPLEX_TYPE_STR__(paddle::imperative::NameVariableWrapperMap, str); - __REPLACE_COMPLEX_TYPE_STR__(paddle::imperative::NameVarBaseMap, str); - __REPLACE_COMPLEX_TYPE_STR__(paddle::imperative::NameTensorMap, str); - __REPLACE_COMPLEX_TYPE_STR__(std::string, str); - return str; -} - -TEST_API std::string GetCurrentTraceBackString(bool for_signal) { - std::ostringstream sout; - - if (!for_signal) { - sout << "\n\n--------------------------------------\n"; - sout << "C++ Traceback (most recent call last):"; - sout << "\n--------------------------------------\n"; - } -#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL) - static constexpr int TRACE_STACK_LIMIT = 100; - - std::array call_stack; - auto size = backtrace(call_stack.data(), TRACE_STACK_LIMIT); - auto symbols = backtrace_symbols(call_stack.data(), size); - Dl_info info; - int idx = 0; - // `for_signal` used to remove the stack trace introduced by - // obtaining the error stack trace when the signal error occurred, - // that is not related to the signal error self, remove it to - // avoid misleading users and developers - int end_idx = for_signal ? 2 : 0; - for (int i = size - 1; i >= end_idx; --i) { - if (dladdr(call_stack[i], &info) && info.dli_sname) { - auto demangled = demangle(info.dli_sname); - std::string path(info.dli_fname); - // C++ traceback info are from core.so - if (path.substr(path.length() - 3).compare(".so") == 0) { - sout << paddle::string::Sprintf( - "%-3d %s\n", idx++, SimplifyDemangleStr(demangled)); - } - } - } - free(symbols); // NOLINT -#else - sout << "Not support stack backtrace yet.\n"; -#endif - return sout.str(); -} - void ThrowWarnInternal(const std::string& msg) { LOG(WARNING) << "WARNING :" << msg; } @@ -276,7 +142,9 @@ std::string GetExternalErrorMsg(T status) { std::string search_path_3; #if !defined(_WIN32) Dl_info info; - if (dladdr(reinterpret_cast(GetCurrentTraceBackString), &info)) { + if (dladdr( + reinterpret_cast(common::enforce::GetCurrentTraceBackString), + &info)) { std::string phi_so_path(info.dli_fname); const size_t last_slash_idx = phi_so_path.find_last_of('/'); if (std::string::npos != last_slash_idx) { @@ -297,7 +165,9 @@ std::string GetExternalErrorMsg(T status) { char buf[512]; MEMORY_BASIC_INFORMATION mbi; HMODULE h_module = - (::VirtualQuery(GetCurrentTraceBackString, &mbi, sizeof(mbi)) != 0) + (::VirtualQuery(common::enforce::GetCurrentTraceBackString, + &mbi, + sizeof(mbi)) != 0) ? (HMODULE)mbi.AllocationBase : NULL; GetModuleFileName(h_module, buf, 512); diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h index 61e502951f24ee..feb2852a9dc679 100644 --- a/paddle/phi/core/enforce.h +++ b/paddle/phi/core/enforce.h @@ -79,17 +79,6 @@ limitations under the License. */ namespace phi { namespace enforce { -#ifdef __GNUC__ -inline std::string demangle(std::string name) { - int status = -4; // some arbitrary value to eliminate the compiler warning - std::unique_ptr res{ - abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; - return (status == 0) ? res.get() : name; -} -#else -inline std::string demangle(std::string name) { return name; } -#endif - namespace details { template inline constexpr bool IsArithmetic() { @@ -164,7 +153,6 @@ struct BinaryCompareMessageConverter { } // namespace details TEST_API int GetCallStackLevel(); -TEST_API std::string GetCurrentTraceBackString(bool for_signal = false); TEST_API std::string SimplifyErrorTypeFormat(const std::string& str); template @@ -192,7 +180,7 @@ std::string GetCompleteTraceBackString(StrType&& what, sout << paddle::string::Sprintf( "%s (at %s:%d)", std::forward(what), file, line) << std::endl; - return GetCurrentTraceBackString() + sout.str(); + return ::common::enforce::GetCurrentTraceBackString() + sout.str(); } template @@ -201,7 +189,8 @@ static std::string GetTraceBackString(StrType&& what, int line) { if (GetCallStackLevel() > 1) { // FLAGS_call_stack_level>1 means showing c++ call stack - return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line); + return ::common::enforce::GetCurrentTraceBackString() + + GetErrorSumaryString(what, file, line); } else { return GetErrorSumaryString(what, file, line); } @@ -453,7 +442,7 @@ struct EnforceNotMet : public std::exception { " 1. The %s is not the %s of operator %s;\n" \ " 2. The %s has no corresponding variable passed in;\n" \ " 3. The %s corresponding variable is not initialized.", \ - phi::demangle( \ + common::demangle( \ typeid(std::add_lvalue_reference::type) \ .name()), \ __ROLE, \ @@ -514,8 +503,8 @@ namespace details { "paddle::get failed, cannot get value " \ "(%s) by type %s, its type is %s.", \ expression, \ - phi::enforce::demangle(typeid(OutputType).name()), \ - phi::enforce::demangle(input.type().name())), \ + common::demangle(typeid(OutputType).name()), \ + common::demangle(input.type().name())), \ file, \ line); \ END_HANDLE_THE_ERROR \ diff --git a/paddle/utils/variant_test.cc b/paddle/utils/variant_test.cc index ef4a6cf8cd89cd..854fe742e686d6 100644 --- a/paddle/utils/variant_test.cc +++ b/paddle/utils/variant_test.cc @@ -18,7 +18,7 @@ #include "paddle/phi/core/enforce.h" TEST(interface_test, type) { - using phi::enforce::demangle; + using common::demangle; paddle::variant var; From 1d59851583f3dc39bc6f5899e546d6449f64b078 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 4 Jan 2024 19:55:06 +0800 Subject: [PATCH 116/142] [InferSymbolicShape] Delete redundent value_id_to_shapeordata_ (#60554) --- paddle/pir/dialect/shape/utils/shape_utils.cc | 4 ++-- paddle/pir/dialect/shape/utils/shape_utils.h | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 05bbb76db8937c..4beb53dde4911b 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -178,13 +178,13 @@ std::string GetValueId(Value* val) { const symbol::ShapeOrDataDimExprs& ShapeConstraintIRAnalysis::GetShapeOrDataForValue(Value* val) { auto val_id = GetValueId(val); - return value_id_to_shapeordata[val_id]; + return value_id_to_shapeordata_[val_id]; } void ShapeConstraintIRAnalysis::SetShapeOrDataForValue( Value* val, const symbol::ShapeOrDataDimExprs& shape_or_data) { auto val_id = GetValueId(val); - value_id_to_shapeordata[val_id] = shape_or_data; + value_id_to_shapeordata_[val_id] = shape_or_data; } } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 7e4eafa6722763..8f383f3ad6e05a 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -105,9 +105,6 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { int64_t next_sym_idx_ = 0; std::vector constraints_; - std::unordered_map - value_id_to_shapeordata; - public: explicit ShapeConstraintIRAnalysis(std::shared_ptr&& program) : ShapeConstraintIRAnalysis(program->module_op()) { From 09544f6681946db8a76d62a503a1e33faa930941 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Fri, 5 Jan 2024 10:57:57 +0800 Subject: [PATCH 117/142] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.25?= =?UTF-8?q?=E3=80=91add=20gammaln=20api=20(#60553)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/api/yaml/backward.yaml | 10 ++ paddle/phi/api/yaml/ops.yaml | 10 ++ paddle/phi/kernels/cpu/gammaln_grad_kernel.cc | 22 +++ paddle/phi/kernels/cpu/gammaln_kernel.cc | 22 +++ paddle/phi/kernels/gammaln_grad_kernel.h | 27 +++ paddle/phi/kernels/gammaln_kernel.h | 26 +++ paddle/phi/kernels/gpu/gammaln_grad_kernel.cu | 30 ++++ paddle/phi/kernels/gpu/gammaln_kernel.cu | 29 ++++ .../kernels/impl/gammaln_grad_kernel_impl.h | 92 ++++++++++ paddle/phi/kernels/impl/gammaln_kernel_impl.h | 49 ++++++ python/paddle/__init__.py | 4 + python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 45 +++++ test/legacy_test/test_gammaln_op.py | 160 ++++++++++++++++++ test/legacy_test/test_inplace.py | 8 + 15 files changed, 538 insertions(+) create mode 100644 paddle/phi/kernels/cpu/gammaln_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/gammaln_kernel.cc create mode 100644 paddle/phi/kernels/gammaln_grad_kernel.h create mode 100644 paddle/phi/kernels/gammaln_kernel.h create mode 100644 paddle/phi/kernels/gpu/gammaln_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/gammaln_kernel.cu create mode 100644 paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/gammaln_kernel_impl.h create mode 100644 test/legacy_test/test_gammaln_op.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 938ea9d5000460..d5748145ffe49d 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -922,6 +922,16 @@ kernel : func : frame_grad +- backward_op : gammaln_grad + forward : gammaln(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : gammaln_grad + - backward_op : gather_grad forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index d4ee3628ad19a4..835046c1e7911b 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1042,6 +1042,16 @@ data_type : dtype backend : place +- op : gammaln + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : gammaln + inplace: (x -> out) + backward : gammaln_grad + - op : gather args : (Tensor x, Tensor index, Scalar axis=0) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc new file mode 100644 index 00000000000000..c52ee8b3848e9a --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaln_grad, CPU, ALL_LAYOUT, phi::GammalnGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gammaln_kernel.cc b/paddle/phi/kernels/cpu/gammaln_kernel.cc new file mode 100644 index 00000000000000..ff62f86d2522fd --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaln_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaln, CPU, ALL_LAYOUT, phi::GammalnKernel, float, double) {} diff --git a/paddle/phi/kernels/gammaln_grad_kernel.h b/paddle/phi/kernels/gammaln_grad_kernel.h new file mode 100644 index 00000000000000..440dca72a9d469 --- /dev/null +++ b/paddle/phi/kernels/gammaln_grad_kernel.h @@ -0,0 +1,27 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammalnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& d_out, + DenseTensor* d_x); +} // namespace phi diff --git a/paddle/phi/kernels/gammaln_kernel.h b/paddle/phi/kernels/gammaln_kernel.h new file mode 100644 index 00000000000000..db3015c4a747db --- /dev/null +++ b/paddle/phi/kernels/gammaln_kernel.h @@ -0,0 +1,26 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammalnKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu new file mode 100644 index 00000000000000..b2513d9e3f25ca --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(gammaln_grad, + GPU, + ALL_LAYOUT, + phi::GammalnGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gammaln_kernel.cu b/paddle/phi/kernels/gpu/gammaln_kernel.cu new file mode 100644 index 00000000000000..3d57be7b277335 --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaln_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" + +PD_REGISTER_KERNEL(gammaln, + GPU, + ALL_LAYOUT, + phi::GammalnKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h new file mode 100644 index 00000000000000..50c73cff27ce4a --- /dev/null +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -0,0 +1,92 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +template +HOSTDEVICE T digamma(T x) { + static T c = T{8.5}; + static T euler_mascheroni = T{0.57721566490153286060}; + T r; + T value; + T x2; + + if (x <= T{0.0}) { + value = T{0.0}; + return value; + } + + if (x <= T{0.000001}) { + value = -euler_mascheroni - T{1.0} / x + T{1.6449340668482264365} * x; + return value; + } + + value = T{0.0}; + x2 = x; + while (x2 < c) { + value = value - T{1.0} / x2; + x2 = x2 + T{1.0}; + } + + r = T{1.0} / x2; + value = value + std::log(x2) - T{0.5} * r; + + r = r * r; + + value = value - + r * (T{1.0} / T{12.0} - + r * (T{1.0} / T{120.0} - + r * (T{1.0} / T{252.0} - + r * (T{1.0} / T{240.0} - r * (T{1.0} / T{132.0}))))); + + return value; +} + +template +struct GammalnGradFunctor { + GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel) + : dout_(dout), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_dout = static_cast(dout_[idx]); + const MT mp_x = static_cast(x_[idx]); + output_[idx] = static_cast(mp_dout * digamma(mp_x)); + } + + private: + const T* dout_; + const T* x_; + T* output_; + int64_t numel_; +}; +template +void GammalnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& d_out, + DenseTensor* d_x) { + auto numel = d_out.numel(); + auto* dout_data = d_out.data(); + auto* x_data = x.data(); + auto* dx_data = + dev_ctx.template Alloc(d_x, static_cast(numel * sizeof(T))); + phi::funcs::ForRange for_range(dev_ctx, numel); + GammalnGradFunctor functor(dout_data, x_data, dx_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h new file mode 100644 index 00000000000000..38385610de0de6 --- /dev/null +++ b/paddle/phi/kernels/impl/gammaln_kernel_impl.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +template +struct GammalnFunctor { + GammalnFunctor(const T* x, T* output, int64_t numel) + : x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_x = static_cast(x_[idx]); + output_[idx] = static_cast(std::lgamma(mp_x)); + } + + private: + const T* x_; + T* output_; + int64_t numel_; +}; + +template +void GammalnKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = dev_ctx.template Alloc(out); + phi::funcs::ForRange for_range(dev_ctx, numel); + GammalnFunctor functor(x_data, out_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 8902b82cadf847..ae35e627553563 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -398,6 +398,8 @@ frac, frac_, frexp, + gammaln, + gammaln_, gcd, gcd_, heaviside, @@ -775,6 +777,8 @@ 'square_', 'divide', 'divide_', + 'gammaln', + 'gammaln_', 'ceil', 'atan', 'atan_', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b26798892a2b2f..b718910348d8ff 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -278,6 +278,8 @@ frac, frac_, frexp, + gammaln, + gammaln_, gcd, gcd_, heaviside, @@ -668,6 +670,8 @@ 'real', 'imag', 'is_floating_point', + 'gammaln', + 'gammaln_', 'digamma', 'digamma_', 'diagonal', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index acaa0905ce6f40..6d75d41b4949ca 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5003,6 +5003,51 @@ def conj(x, name=None): return out +def gammaln(x, name=None): + r""" + Calculates the logarithm of the absolute value of the gamma function elementwisely. + + Args: + x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, bfloat16. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, The values of the logarithm of the absolute value of the gamma at the given tensor x. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.arange(1.5, 4.5, 0.5) + >>> out = paddle.gammaln(x) + >>> print(out) + Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.12078224, 0. , 0.28468287, 0.69314718, 1.20097363, + 1.79175949]) + """ + if in_dynamic_or_pir_mode(): + return _C_ops.gammaln(x) + else: + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'bfloat16'], 'gammaln' + ) + helper = LayerHelper('gammaln', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='gammaln', inputs={'x': x}, outputs={'out': out}) + return out + + +@inplace_apis_in_dygraph_only +def gammaln_(x, name=None): + r""" + Inplace version of ``gammaln`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_gammaln`. + """ + if in_dynamic_mode(): + return _C_ops.gammaln_(x) + + def digamma(x, name=None): r""" Calculates the digamma of the given input tensor, element-wise. diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py new file mode 100644 index 00000000000000..50331af5c7a34c --- /dev/null +++ b/test/legacy_test/test_gammaln_op.py @@ -0,0 +1,160 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest, convert_float_to_uint16 +from scipy import special + +import paddle +from paddle.base import core + + +def ref_gammaln(x): + return special.gammaln(x) + + +def ref_gammaln_grad(x, dout): + return dout * special.polygamma(0, x) + + +class TestGammalnOp(OpTest): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.init_dtype_type() + self.shape = (3, 40) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.inputs = {'x': self.x} + out = ref_gammaln(self.x) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + self.check_grad(['x'], 'out', check_pir=True) + + +class TestGammalnOpFp32(TestGammalnOp): + def init_dtype_type(self): + self.dtype = np.float32 + + +class TestGammalnFP16Op(TestGammalnOp): + def init_dtype_type(self): + self.dtype = np.float16 + + +class TestGammalnBigNumberOp(TestGammalnOp): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.init_dtype_type() + self.shape = (100, 1) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.x[:5, 0] = np.array([1e5, 1e10, 1e20, 1e40, 1e80]) + self.inputs = {'x': self.x} + out = ref_gammaln(self.x) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_grad(self): + d_out = self.outputs['out'] + d_x = ref_gammaln_grad(self.x, d_out) + self.check_grad( + ['x'], + 'out', + user_defined_grads=[ + d_x, + ], + user_defined_grad_outputs=[ + d_out, + ], + check_pir=True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestGammalnBF16Op(OpTest): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.dtype = np.uint16 + self.shape = (5, 30) + x = np.random.random(self.shape).astype("float32") + 1 + self.inputs = {'x': convert_float_to_uint16(x)} + out = ref_gammaln(x) + self.outputs = {'out': convert_float_to_uint16(out)} + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + + def test_check_grad(self): + self.check_grad_with_place( + core.CUDAPlace(0), ['x'], 'out', check_pir=True + ) + + +class TestGammalnOpApi(unittest.TestCase): + def setUp(self): + self.shape = [2, 3, 4, 5] + self.init_dtype_type() + self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def init_dtype_type(self): + self.dtype = "float64" + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype) + out = paddle.gammaln(x) + exe = paddle.static.Executor(self.place) + (res,) = exe.run(feed={'x': self.x_np}, fetch_list=[out]) + out_ref = ref_gammaln(self.x_np) + np.testing.assert_allclose(out_ref, res, rtol=1e-5, atol=1e-5) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out = paddle.gammaln(x) + out_ref = ref_gammaln(self.x_np) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5) + paddle.enable_static() + + +class TestGammalnOpApiFp32(TestGammalnOpApi): + def init_dtype_type(self): + self.dtype = "float32" + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 42f9a46cfb9100..38fbac0357d6df 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -869,6 +869,14 @@ def test_leaf_inplace_var_error(self): pass +class TestDygraphInplaceGammaln(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.gammaln_(var) + + def non_inplace_api_processing(self, var): + return paddle.gammaln(var) + + class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous): def inplace_api_processing(self, var): return paddle.neg_(var) From c3106c4d92b9b785626aa260e4928469400e8e8a Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:01:58 +0800 Subject: [PATCH 118/142] fix (#60570) --- paddle/phi/kernels/gpu/arange_kernel.cu | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index 10905ff89e18e9..ce8f997fe7d158 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -66,8 +66,12 @@ void ArangeNullaryKernel(const Context& dev_ctx, const T end_value, const T step_value, DenseTensor* out) { + using MPType = typename phi::dtype::MPTypeTrait::Type; + MPType start_value_mpt = static_cast(start_value); + MPType end_value_mpt = static_cast(end_value); + MPType step_value_mpt = static_cast(step_value); int64_t size = 0; - phi::funcs::GetSize(start_value, end_value, step_value, &size); + phi::funcs::GetSize(start_value_mpt, end_value_mpt, step_value_mpt, &size); out->Resize(common::make_ddim({size})); T* out_data = dev_ctx.template Alloc(out); @@ -77,7 +81,8 @@ void ArangeNullaryKernel(const Context& dev_ctx, return; } int64_t grid = (size + block - 1) / block; - Range<<>>(start_value, step_value, size, out_data); + Range<<>>( + start_value_mpt, step_value_mpt, size, out_data); } template @@ -86,11 +91,10 @@ void ArangeKernel(const Context& dev_ctx, const Scalar& end, const Scalar& step, DenseTensor* out) { - using MPType = typename phi::dtype::MPTypeTrait::Type; - MPType start_value = start.to(); - MPType end_value = end.to(); - MPType step_value = step.to(); - ArangeNullaryKernel( + T start_value = start.to(); + T end_value = end.to(); + T step_value = step.to(); + ArangeNullaryKernel( dev_ctx, start_value, end_value, step_value, out); } From bc13117dece35f6d50613c1fdac66ba41274cff6 Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:11:10 +0800 Subject: [PATCH 119/142] [CINN] Add tile tactic and bind cuda tactic (#60534) * [CINN] Add tile tactic * [CINN] Add bind cuda tactic --- .../dy_shape_group_scheduler.cc | 36 ++++++-- .../ir/group_schedule/tactic/CMakeLists.txt | 2 + .../tactic/align_iter_space_tactic.h | 1 - .../group_schedule/tactic/bind_cuda_tactic.cc | 58 +++++++++++++ .../group_schedule/tactic/bind_cuda_tactic.h | 36 ++++++++ .../group_schedule/tactic/schedule_tactic.h | 12 +++ .../ir/group_schedule/tactic/tile_tactic.cc | 84 +++++++++++++++++++ .../ir/group_schedule/tactic/tile_tactic.h | 36 ++++++++ 8 files changed, 256 insertions(+), 9 deletions(-) create mode 100644 paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc create mode 100644 paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h create mode 100644 paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc create mode 100644 paddle/cinn/ir/group_schedule/tactic/tile_tactic.h diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index d56fc994fdcea3..9f7a52d97fb178 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -15,31 +15,36 @@ #include "paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h" #include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h" +#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h" +#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h" #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h" +#include "paddle/cinn/ir/op/ir_operators.h" namespace cinn { namespace ir { void DynamicShapeGroupScheduler::Init() { + // Only 1 bucket for test now. + schedule_context_.target = target_; schedule_context_.output_names = OutputTensorNames(); schedule_context_.global_master = FindGlobalMasterNode(); schedule_context_.iter_space_info = ConstructIterSpaceInfo(schedule_context_.global_master); - schedule_context_.target = target_; + schedule_context_.bucket_info = {/* sp_lower_bound = */ 1024, + /* sp_upper_bound = */ INT_MAX, + /* rb_lower_bound = */ 64, + /* rb_upper_bound = */ INT_MAX}; tactics_.emplace_back(new AlignIterSpaceTactic()); + tactics_.emplace_back(new TileTactic()); tactics_.emplace_back(new ComputeInlineTactic()); + tactics_.emplace_back(new BindCudaTactic()); tactics_.emplace_back(new ArrangeStorageTactic()); } void DynamicShapeGroupScheduler::Schedule() { - // Fake schedule for test ApplyTactics(); - std::vector all_blocks = ir_sch_->GetAllBlocks(); - auto block0_loops = ir_sch_->GetLoops(all_blocks[0]); - auto splited_loops1 = ir_sch_->Split(block0_loops[0], {1024, -1}); - ir_sch_->Bind(splited_loops1[0], "threadIdx.x"); - + // Fake bucket for test ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024)); std::unique_ptr new_ir_sch1 = std::make_unique(*ir_sch_); @@ -55,12 +60,12 @@ void DynamicShapeGroupScheduler::ApplyTactics() { VLOG(6) << "before applying [" << tactic->TacticName() << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" << ir_sch_->GetModule().GetExprs().front(); - tactic->Init(&schedule_context_); tactic->Apply(ir_sch_, node->id()); VLOG(6) << "after applying [" << tactic->TacticName() << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" << ir_sch_->GetModule().GetExprs().front(); }; + tactic->Init(&schedule_context_); schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc); schedule_block_graph_->Update(*ir_sch_); VLOG(5) << "[End " << tactic->TacticName() @@ -96,6 +101,7 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( std::unordered_map iter_var2value = analyzer::GetIterVarToValueOfSBlock(block); + // init iter info if (!reduce_iter_vars.empty()) { std::set reduce_loads = ir::ir_utils::CollectIRNodesWithoutTensor( block, @@ -161,6 +167,20 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( info.rb_last_order.push_back(i); } } + // init total extents + ir::Expr sp_extent = ir::Expr(1); + ir::Expr rb_extent = ir::Expr(1); + for (const auto& axis : info.sp_space) { + const ir::Expr& extent = std::get<0>(axis); + sp_extent = sp_extent * extent; + } + for (const auto& axis : info.rb_space) { + const ir::Expr& extent = std::get<0>(axis); + rb_extent = rb_extent * extent; + } + info.total_sp_extent = sp_extent; + info.total_rb_extent = rb_extent; + return info; } diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt index da964e770ae9ba..b12e669b8c2d07 100644 --- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt +++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt @@ -1,5 +1,7 @@ core_gather_headers() gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc) +gather_srcs(cinnapi_src SRCS tile_tactic.cc) gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc) +gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc) gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc) diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h index 69729ce2bfb8c6..ef30f80ce470b2 100644 --- a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h @@ -15,7 +15,6 @@ #pragma once #include -#include #include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" namespace cinn { diff --git a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc new file mode 100644 index 00000000000000..0da0ce3bcb396c --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h" +#include +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace ir { + +void BindCudaTactic::Init(ScheduleContext* context) { context_ = context; } + +const std::unordered_map + axis_type2bind_info = { + {IterativeSpaceInfo::AxisType::kCudaBlockX, "blockIdx.x"}, + {IterativeSpaceInfo::AxisType::kCudaBlockY, "blockIdx.y"}, + {IterativeSpaceInfo::AxisType::kCudaBlockZ, "blockIdx.z"}, + {IterativeSpaceInfo::AxisType::kCudaThreadX, "threadIdx.x"}, + {IterativeSpaceInfo::AxisType::kCudaThreadY, "threadIdx.y"}, + {IterativeSpaceInfo::AxisType::kCudaThreadZ, "threadIdx.z"}, +}; + +void BindCudaTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { + std::vector loops = sch->GetLoops(block_id); + int loop_idx = 0; + for (int i = 0; + i < context_->iter_space_info.sp_space.size() && loop_idx < loops.size(); + ++i, ++loop_idx) { + const auto& axis = context_->iter_space_info.sp_space[i]; + const IterativeSpaceInfo::AxisType& axis_type = std::get<1>(axis); + if (axis_type2bind_info.count(axis_type) != 0) { + sch->Bind(loops[loop_idx], axis_type2bind_info.at(axis_type)); + } + } + for (int i = 0; + i < context_->iter_space_info.rb_space.size() && loop_idx < loops.size(); + ++i, ++loop_idx) { + const auto& axis = context_->iter_space_info.rb_space[i]; + const IterativeSpaceInfo::AxisType& axis_type = std::get<1>(axis); + if (axis_type2bind_info.count(axis_type) != 0) { + sch->Bind(loops[loop_idx], axis_type2bind_info.at(axis_type)); + } + } +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h new file mode 100644 index 00000000000000..b66c7d1fb802c0 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h @@ -0,0 +1,36 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" + +namespace cinn { +namespace ir { + +class BindCudaTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "BindCudaTactic"; } + + private: + ScheduleContext* context_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h index 4084c69bf493ae..05c258b82c47ce 100644 --- a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h @@ -36,6 +36,10 @@ struct IterativeSpaceInfo { std::vector> sp_space; // reduce or broadcast iterative space std::vector> rb_space; + // total sp extent + ir::Expr total_sp_extent; + // total rb extent + ir::Expr total_rb_extent; // original loop order with same iteration order as the memory order std::vector memory_consistent_order_space; // index that transform from memory consistent order to rb last order @@ -45,11 +49,19 @@ struct IterativeSpaceInfo { std::vector rb_last_order; }; +struct BucketInfo { + int sp_lower_bound = 0; + int sp_upper_bound = UINT_MAX; + int rb_lower_bound = 0; + int rb_upper_bound = UINT_MAX; +}; + struct ScheduleContext { std::unordered_set output_names; ScheduleBlockNode* global_master; IterativeSpaceInfo iter_space_info; Target target; + BucketInfo bucket_info; }; class ScheduleTactic { diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc new file mode 100644 index 00000000000000..3cace2636f2d39 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h" +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace ir { + +void TileTactic::Init(ScheduleContext* context) { + context_ = context; + // fake strategy + auto GetFirstFactor = [](int num) { + int factor = 1; + for (int i = num - 1; i >= 1; --i) { + if (num % i == 0) { + return i; + } + } + }; + + bool has_rb_iter = !context_->iter_space_info.rb_space.empty(); + bool has_sp_iter = !context_->iter_space_info.sp_space.empty(); + context_->iter_space_info.rb_space.clear(); + context_->iter_space_info.sp_space.clear(); + + if (has_sp_iter) { + int sp_factor = GetFirstFactor(context_->bucket_info.sp_lower_bound); + context_->iter_space_info.sp_space.emplace_back( + ir::Expr(context_->bucket_info.sp_lower_bound / sp_factor), + IterativeSpaceInfo::AxisType::kCudaBlockX); + context_->iter_space_info.sp_space.emplace_back( + ir::Expr(sp_factor), + has_rb_iter ? IterativeSpaceInfo::AxisType::kCudaThreadY + : IterativeSpaceInfo::AxisType::kCudaThreadX); + context_->iter_space_info.sp_space.emplace_back( + ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + } + + if (has_rb_iter) { + context_->iter_space_info.rb_space.emplace_back( + ir::Expr(context_->bucket_info.rb_lower_bound), + IterativeSpaceInfo::AxisType::kCudaThreadX); + context_->iter_space_info.rb_space.emplace_back( + ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + } +} + +void TileTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { + std::vector loops = sch->GetLoops(block_id); + CHECK(loops.size() == 1 || loops.size() == 2) + << "All loops must be unified as sp_loop or rb_loop."; + if (loops.size() == 2) { + std::vector rb_factors; + for (const auto& axis : context_->iter_space_info.rb_space) { + rb_factors.push_back(std::get<0>(axis)); + } + sch->Split(loops[1], rb_factors); + loops = sch->GetLoops(block_id); + VLOG(6) << "after split rb loop of " << block_id << ": " + << sch->GetModule().GetExprs()[0]; + } + std::vector sp_factors; + for (const auto& axis : context_->iter_space_info.sp_space) { + sp_factors.push_back(std::get<0>(axis)); + } + sch->Split(loops[0], sp_factors); + VLOG(6) << "after split sp loop of " << block_id << ": " + << sch->GetModule().GetExprs()[0]; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h new file mode 100644 index 00000000000000..8a6d2bb8dd7668 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h @@ -0,0 +1,36 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" + +namespace cinn { +namespace ir { + +class TileTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "TileTactic"; } + + private: + ScheduleContext* context_; +}; + +} // namespace ir +} // namespace cinn From 58689d36b4b1e5a62e9467247c447d03a3a7423d Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:29:12 +0800 Subject: [PATCH 120/142] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.8?= =?UTF-8?q?=E3=80=91=20fix=20test=5Fshuffle=5Fbatch=5Fop=20(#59631)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix test_shuffle_batch_op * fix --- .../fluid/pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 10 ++++++++++ .../pir/dialect/operator/ir/ops_backward.yaml | 10 ++++++++++ paddle/phi/api/yaml/op_compat.yaml | 7 +++++++ paddle/phi/infermeta/backward.cc | 9 +++++++++ paddle/phi/infermeta/backward.h | 5 +++++ paddle/phi/infermeta/binary.cc | 15 +++++++++++++++ paddle/phi/infermeta/binary.h | 9 +++++++++ test/white_list/pir_op_test_no_check_list | 1 + test/white_list/pir_op_test_white_list | 1 + 10 files changed, 68 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 79cbad13c0f56c..cf3cd898b7b732 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -134,6 +134,7 @@ 'seed', 'send_v2', 'shadow_feed', + 'shuffle_batch', 'sparse_momentum', 'soft_relu', 'uniform_random_batch_size_like', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index b992c139b8543a..0ba8fac6b52be2 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1160,6 +1160,16 @@ func: share_data param: [x] +- op : shuffle_batch + args : (Tensor x, Tensor seed, int startup_seed=0) + output : Tensor(out), Tensor(shuffle_idx), Tensor(seed_out) + infer_meta: + func: ShuffleBatchInferMeta + kernel: + func: shuffle_batch + data_type: x + backward : shuffle_batch_grad + - op : slice args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) output : Tensor diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index bf0b939267e1b6..43f271b7b7c49a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -925,6 +925,16 @@ func: fused_elemwise_add_activation_grad optional : x, intermediate_out +- backward_op: shuffle_batch_grad + forward: shuffle_batch (Tensor x, Tensor seed, int startup_seed=0) -> Tensor(out), Tensor(shuffle_idx), Tensor(seed_out) + args: (Tensor shuffle_idx, Tensor out_grad,int startup_seed=0) + output : Tensor(x_grad) + infer_meta: + func: ShuffleBatchGradInferMeta + kernel: + func: shuffle_batch_grad + data_type : out_grad + - backward_op: unpool_grad forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out) args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index e605dab1543371..73c976116f1c52 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2799,6 +2799,13 @@ out : Out xout : XOut +- op : shuffle_batch + backward: shuffle_batch_grad + inputs: + {x : X, seed : Seed} + outputs: + {out : Out, shuffle_idx : ShuffleIdx, seed_out : SeedOut} + - op : shuffle_channel backward : shuffle_channel_grad extra : diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index ee2388762668b3..79a5f291267347 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1039,6 +1039,15 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, } } +void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, + const MetaTensor& out_grad, + int startup_seed, + MetaTensor* x_grad) { + x_grad->share_dims(out_grad); + x_grad->share_lod(out_grad); + x_grad->set_dtype(out_grad.dtype()); +} + void SpectralNormGradInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 922bafed0add8c..cdbb8e399fe369 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -428,6 +428,11 @@ void ScatterNdAddGradInferMeta(const MetaTensor& index, MetaTensor* x_grad, MetaTensor* updates_grad); +void ShuffleBatchGradInferMeta(const MetaTensor& shuffle_idx, + const MetaTensor& out_grad, + int startup_seed, + MetaTensor* x_grad); + void SpectralNormGradInferMeta(const MetaTensor& weight, const MetaTensor& u, const MetaTensor& v, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index b771fba031317a..5ffc2c45f9a23e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2832,6 +2832,21 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, } } +void ShuffleBatchInferMeta(const MetaTensor& x, + const MetaTensor& seed, + int startup_seed, + MetaTensor* out, + MetaTensor* shuffle_idx, + MetaTensor* seed_out + +) { + out->share_dims(x); + out->share_lod(x); + seed_out->share_dims(seed); + seed_out->share_lod(seed); + shuffle_idx->set_dims(phi::make_ddim({-1})); +} + void SequenceMaskInferMeta(const MetaTensor& x, const MetaTensor& max_len_tensor, int maxlen, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 82f5fc64d57a53..2f0626b6524358 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -455,6 +455,15 @@ void SequenceMaskInferMeta(const MetaTensor& x, int out_dtype, MetaTensor* y); +void ShuffleBatchInferMeta(const MetaTensor& x, + const MetaTensor& seed, + int startup_seed, + MetaTensor* out, + MetaTensor* shuffle_idx, + MetaTensor* seed_out + +); + void SoftmaxMaskFuseInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out); diff --git a/test/white_list/pir_op_test_no_check_list b/test/white_list/pir_op_test_no_check_list index 99e67a1dc2d027..00d8f054df131f 100644 --- a/test/white_list/pir_op_test_no_check_list +++ b/test/white_list/pir_op_test_no_check_list @@ -9,3 +9,4 @@ test_randperm_op test_seed_op test_uniform_random_bf16_op test_uniform_random_inplace_op +test_shuffle_batch_op diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 045e8b4df94595..275c7013c4650a 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -269,6 +269,7 @@ test_sgd_op test_shape_mkldnn_op test_shape_op test_shard_index_op +test_shuffle_batch_op test_sigmoid_cross_entropy_with_logits_op test_sign_op test_size_op From 2033381298067f972afa16301153db20244c9ac0 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:29:22 +0800 Subject: [PATCH 121/142] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.14?= =?UTF-8?q?=E3=80=91=20fix=20test=5Fnce=20(#60255)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix test_nce * fix test_nce * Update ops.yaml * fix * Update utils.cc * Update ops.yaml --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 11 +++ .../pir/dialect/operator/ir/ops_backward.yaml | 12 +++ .../fluid/pir/dialect/operator/utils/utils.cc | 2 + paddle/phi/api/yaml/op_compat.yaml | 7 ++ paddle/phi/infermeta/backward.cc | 27 ++++++ paddle/phi/infermeta/backward.h | 7 ++ paddle/phi/infermeta/multiary.cc | 92 +++++++++++++++++++ paddle/phi/infermeta/multiary.h | 21 +++++ test/white_list/pir_op_test_white_list | 1 + 10 files changed, 181 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index cf3cd898b7b732..78904eefc51004 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -127,6 +127,7 @@ 'fused_scale_bias_add_relu', 'fused_dconv_drelu_dbn', 'fused_dot_product_attention', + 'nce', 'lars_momentum', 'recv_v2', 'rnn_', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 0ba8fac6b52be2..6c8b4f6d3fdcf6 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1481,6 +1481,17 @@ data_type: param optional: master_param, master_param_out +- op: nce + args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) + output: Tensor(cost), Tensor(sample_logits), Tensor(sample_labels) + infer_meta: + func: NceInferMeta + kernel: + func: nce + data_type: input + optional: bias, sample_weight, custom_dist_probs, custom_dist_alias, custom_dist_alias_probs + backward: nce_grad + - op: number_count args: (Tensor numbers, int upper_range) output: Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index 43f271b7b7c49a..90221982ebbddf 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -503,6 +503,18 @@ func : multiply_triple_grad optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad +- backward_op : nce_grad + forward: nec (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) -> Tensor(cost), Tensor(sample_logits), Tensor(sample_labels) + args: (Tensor input, Tensor label, Tensor bias, Tensor weight, Tensor sample_logits, Tensor sample_labels, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, Tensor cost_grad, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false) + output: Tensor(input_grad), Tensor(bias_grad), Tensor(weight_grad) + infer_meta: + func: NceGradInferMeta + param: [input, bias, weight] + kernel: + func: nce_grad + data_type: input + optional: bias, sample_weight, custom_dist_probs, custom_dist_alias, custom_dist_alias_probs + - backward_op : norm_grad forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm) args : (Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index ebc1615a16d51a..68b8edbd6da8a9 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -59,6 +59,8 @@ const std::unordered_set LegacyOpList = { RowConvGradOp::name(), SoftReluOp::name(), SoftReluGradOp::name(), + NceOp::name(), + NceGradOp::name(), CReduceMinOp::name()}; const std::unordered_set OneDNNLegacyOpList = {}; diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 73c976116f1c52..26e180c95405a9 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3517,6 +3517,13 @@ outputs : out : Out +- op: nce + backward: nce_grad + inputs: + {input : Input, label : Label, weight : Weight, bias : Bias, sample_weight : SampleWeight, custom_dist_probs : CustomDistProbs, custom_dist_alias : CustomDistAlias, custom_dist_alias_probs : CustomDistAliasProbs} + outputs: + {cost : Cost, sample_logits : SampleLogits, sample_labels : SampleLabels} + - op: number_count inputs : {numbers: numbers} diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 79a5f291267347..6d33afeb94898a 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -818,6 +818,33 @@ void NanmedianGradInferMeta(const MetaTensor& x, x_grad->set_dtype(x.dtype()); } +void NceGradInferMeta(const MetaTensor& input, + const MetaTensor& bias, + const MetaTensor& weight, + MetaTensor* input_grad, + MetaTensor* bias_grad, + MetaTensor* weight_grad + +) { + auto x_dims = input.dims(); + if (input_grad != nullptr) { + input_grad->set_dims(x_dims); + input_grad->set_dtype(input.dtype()); + } + + auto w_dims = weight.dims(); + if (weight_grad) { + weight_grad->set_dims(w_dims); + weight_grad->set_dtype(weight.dtype()); + } + + auto bias_dims = bias.dims(); + if (bias_grad) { + bias_grad->set_dims(bias_dims); + bias_grad->set_dtype(bias.dtype()); + } +} + void NllLossGradInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index cdbb8e399fe369..3112fa8b9ddad4 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -361,6 +361,13 @@ void NanmedianGradInferMeta(const MetaTensor& x, bool keep_dim, MetaTensor* x_grad); +void NceGradInferMeta(const MetaTensor& input, + const MetaTensor& bias, + const MetaTensor& weight, + MetaTensor* input_grad, + MetaTensor* bias_grad, + MetaTensor* weight_grad); + void NllLossGradInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 5b9708b38a17e1..1f72c07c59826b 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3178,6 +3178,98 @@ void MultiplexInferMeta(const std::vector& ins, out->set_dtype(ins[0]->dtype()); } +void NceInferMeta(const MetaTensor& input, + const MetaTensor& label, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& sample_weight, + const MetaTensor& custom_dist_probs, + const MetaTensor& custom_dist_alias, + const MetaTensor& custom_dist_alias_probs, + int num_total_classes, + const std::vector& custom_neg_classes, + int num_neg_samples, + int sampler, + int seed, + bool is_sparse, + bool remote_prefetch, + bool is_test, + MetaTensor* cost, + MetaTensor* sample_logits, + MetaTensor* sample_labels, + MetaConfig config) { + auto x_dims = input.dims(); + auto label_dims = label.dims(); + if (config.is_runtime || (x_dims[0] > 0 && label_dims[0] > 0)) { + PADDLE_ENFORCE_EQ( + x_dims[0], + label_dims[0], + phi::errors::InvalidArgument( + "The first dimension of Input(Input) and Input(Label) should be " + "equal in runtime. But received: Input(Input)'s shape = [%s] " + "with 1st dim = %d, Input(Label)'s shape = [%s] with 1st dim = " + "%d.", + x_dims, + x_dims[0], + label_dims, + label_dims[0])); + } + int num_true_classes = + static_cast(label_dims.size() == 2 ? label_dims[1] : 1); + if (bias) { + PADDLE_ENFORCE_EQ( + weight.dims()[0], + bias.dims()[0], + phi::errors::InvalidArgument( + "The first dimension of Input(Weight) and Input(Bias) " + "should be equal. But received: Input(Weight)'s shape = [%s] " + "with 1st dim = %d, and Input(Bias)'s shape = [%s] with 1st dim " + "= %d.", + weight.dims(), + weight.dims()[0], + bias.dims(), + bias.dims()[0])); + } + + PADDLE_ENFORCE_EQ( + num_total_classes, + weight.dims()[0], + phi::errors::InvalidArgument( + "The number of total classes should be equal to the first " + "dimension of Input(Weight). But received: Attr(num_total_classes) " + "= %d, Input(Weight)'s shape = [%s] with 1st dim = %d.", + num_total_classes, + weight.dims(), + weight.dims()[0])); + if (custom_neg_classes.size() > 0) { + PADDLE_ENFORCE_EQ( + custom_neg_classes.size(), + static_cast(num_neg_samples), + phi::errors::InvalidArgument( + "The size of Attr(custom_neg_classes) should be equal " + "to the number of negative samples. But received: " + "custom_neg_classes.size() = %d, num_neg_samples = %d.", + custom_neg_classes.size(), + num_neg_samples)); + } + // set dims of output(Out) + std::vector out_dims; + out_dims.push_back(x_dims[0]); + out_dims.push_back(1); + cost->set_dims(common::make_ddim(out_dims)); + cost->set_dtype(DataType::FLOAT32); + + if (!is_test) { + // set dims of output(SampleOut) + std::vector sample_out_dims; + sample_out_dims.push_back(x_dims[0]); + sample_out_dims.push_back( + (num_true_classes == -1) ? -1 : (num_neg_samples + num_true_classes)); + sample_logits->set_dims(common::make_ddim(sample_out_dims)); + sample_labels->set_dims(common::make_ddim(sample_out_dims)); + } +} + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index f51c3dacb19095..ebda904e4d4a7f 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -569,6 +569,27 @@ void MultiplexInferMeta(const std::vector& ins, const MetaTensor& ids, MetaTensor* out); +void NceInferMeta(const MetaTensor& input, + const MetaTensor& label, + const MetaTensor& weight, + const MetaTensor& bias, + const MetaTensor& sample_weight, + const MetaTensor& custom_dist_probs, + const MetaTensor& custom_dist_alias, + const MetaTensor& custom_dist_alias_probs, + int num_total_classes, + const std::vector& custom_neg_classes, + int num_neg_samples, + int sampler, + int seed, + bool is_sparse, + bool remote_prefetch, + bool is_test, + MetaTensor* cost, + MetaTensor* sample_logits, + MetaTensor* sample_labels, + MetaConfig config = MetaConfig()); + void PsroiPoolInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index 275c7013c4650a..b277a75f91605c 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -210,6 +210,7 @@ test_multinomial_op test_multiplex_op test_mv_op test_nanmedian +test_nce test_nearest_interp_mkldnn_op test_nearest_interp_v2_op test_nextafter_op From 1874d1c5eddd90a1641943b722563a796db22d6a Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:29:32 +0800 Subject: [PATCH 122/142] =?UTF-8?q?=E3=80=90PIR=20OpTest=20Fix=20No.19?= =?UTF-8?q?=E3=80=91=20fix=20test=5Fftrl=5Fop=20(#60329)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix test_ftrl_op * fix --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 9 ++++ .../fluid/pir/dialect/operator/utils/utils.cc | 1 + paddle/phi/api/yaml/op_compat.yaml | 6 +++ paddle/phi/infermeta/multiary.cc | 42 +++++++++++++++++++ paddle/phi/infermeta/multiary.h | 12 ++++++ test/white_list/pir_op_test_white_list | 1 + 7 files changed, 72 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 78904eefc51004..af224cb5be8ab3 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -120,6 +120,7 @@ 'decayed_adagrad', 'dpsgd', 'embedding_grad_sparse', + 'ftrl', 'fused_batch_norm_act_', 'fused_bn_add_activation_', 'fused_elemwise_add_activation', diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 6c8b4f6d3fdcf6..221aeb6c7dfa30 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1437,6 +1437,15 @@ func: dpsgd data_type: param +- op: ftrl + args: (Tensor param, Tensor squared_accumulator, Tensor linear_accumulator, Tensor grad, Tensor learning_rate, float l1=0.0f, float l2=0.0f, float lr_power=-0.5f) + output: Tensor(param_out), Tensor(squared_accum_out), Tensor(linear_accum_out) + infer_meta: + func: FtrlInferMeta + kernel: + func: ftrl + data_type: param + - op: fused_attention args: (Tensor x, Tensor ln_scale, Tensor ln_bias, Tensor qkv_weight, Tensor qkv_bias, Tensor cache_kv, Tensor src_mask, Tensor out_linear_weight, Tensor out_linear_bias, Tensor ln_scale_2, Tensor ln_bias_2, int num_heads, bool transpose_qkv_wb, bool pre_layer_norm, float epsilon, float attn_dropout_rate, bool is_test, bool attn_dropout_fix_seed, int attn_dropout_seed, str attn_dropout_implementation, float dropout_rate, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon, bool add_residual, int ring_id) output: Tensor(ln_mean), Tensor(ln_var), Tensor(ln_out), Tensor(qkv_out), Tensor(qkv_bias_out), Tensor(transpose_out_2), Tensor(qk_out), Tensor(qktv_out), Tensor(softmax_out), Tensor(attn_dropout_mask_out), Tensor(attn_dropout_out), Tensor(src_mask_out), Tensor(fmha_out), Tensor(out_linear_out), Tensor(dropout_mask_out), Tensor(ln_mean_2), Tensor(ln_var_2), Tensor(bias_dropout_residual_out), Tensor(cache_kv_out), Tensor(out) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 68b8edbd6da8a9..d0e800f157c5f0 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -36,6 +36,7 @@ const std::unordered_set LegacyOpList = { CBroadcast_Op::name(), CSyncCalcStream_Op::name(), CSyncCommStream_Op::name(), + FtrlOp::name(), FusedElemwiseAddActivationOp::name(), FusedElemwiseAddActivationGradOp::name(), FusedGemmEpilogueOp::name(), diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 26e180c95405a9..7071df37e4aa58 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -3443,6 +3443,12 @@ inputs: {x: X} outputs: {out: Out} +- op: ftrl + inputs: + {param: Param, squared_accumulator: SquaredAccumulator, linear_accumulator: LinearAccumulator, grad: Grad, learning_rate: LearningRate} + outputs: + {param_out: ParamOut, squared_accum_out: SquaredAccumOut, linear_accum_out: LinearAccumOut} + - op: full_batch_size_like (fill_constant_batch_size_like) inputs: {input: Input} diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 1f72c07c59826b..77af8e5c19f94c 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1525,6 +1525,48 @@ void EditDistanceInferMeta(const MetaTensor& hyps, sequencenum->set_dtype(DataType::FLOAT32); } +void FtrlInferMeta(const MetaTensor& param, + const MetaTensor& squared_accumulator, + const MetaTensor& linear_accumulator, + const MetaTensor& grad, + const MetaTensor& learning_rate, + float l1, + float l2, + float lr_power, + MetaTensor* param_out, + MetaTensor* squared_accum_out, + MetaTensor* linear_accum_out) { + auto param_dim = param.dims(); + PADDLE_ENFORCE_EQ(param_dim, + grad.dims(), + phi::errors::InvalidArgument( + "Two input of FTRL Op's dimension must be same, but " + "param_dim is %d, Grad is %d", + param_dim, + grad.dims())); + + auto lr_dim = learning_rate.dims(); + PADDLE_ENFORCE_NE(common::product(lr_dim), + 0, + phi::errors::InvalidArgument( + "Maybe the Input variable LearningRate has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + PADDLE_ENFORCE_EQ(common::product(lr_dim), + 1, + phi::errors::InvalidArgument( + "Learning Rate should be a scalar, but got %d", + common::product(lr_dim))); + + param_out->set_dims(param_dim); + param_out->set_dtype(param.dtype()); + squared_accum_out->set_dims(param_dim); + squared_accum_out->set_dtype(param.dtype()); + linear_accum_out->set_dims(param_dim); + linear_accum_out->set_dtype(param.dtype()); +} + void FusedBatchNormActInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index ebda904e4d4a7f..0c18d27836adc4 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -308,6 +308,18 @@ void EditDistanceInferMeta(const MetaTensor& hyps, MetaTensor* sequencenum, MetaTensor* out); +void FtrlInferMeta(const MetaTensor& param, + const MetaTensor& squared_accumulator, + const MetaTensor& linear_accumulator, + const MetaTensor& grad, + const MetaTensor& learning_rate, + float l1, + float l2, + float lr_power, + MetaTensor* param_out, + MetaTensor* squared_accum_out, + MetaTensor* linear_accum_out); + void FusedBatchNormActInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index b277a75f91605c..d696ab19863a76 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -121,6 +121,7 @@ test_fmax_op test_fmin_op test_fold_op test_frame_op +test_ftrl_op test_full_like_op test_fused_attention_op test_fused_attention_op_api From a9712d14f7428cb4adc44726edfb502c1b5d2672 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Fri, 5 Jan 2024 12:32:46 +0800 Subject: [PATCH 123/142] [auto parallel] Lazy init for MP. Add reshard infer shape. (#60563) --- .../distributed/auto_parallel/dist_tensor.cc | 31 +++++---- .../auto_parallel/reshard/reshard_utils.cc | 23 +++++++ .../auto_parallel/reshard/reshard_utils.h | 4 ++ python/paddle/nn/initializer/constant.py | 17 ++++- .../semi_auto_parallel_lazy_init.py | 63 ++++++++++++++++++- .../test_semi_auto_parallel_lazy_init.py | 2 +- 6 files changed, 124 insertions(+), 16 deletions(-) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index fff9af10339a60..c41effe6c85220 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -174,17 +174,26 @@ DistTensor::DistTensor(const std::shared_ptr& global_value, // uninitialized tensor only with dist_tensor_meta_. if (IsCurRankInMesh(process_mesh)) { if (!dist_attr_.is_replicated()) { - value_ = std::make_shared(); - // 1. create replicated global tensor - TensorDistAttr replicated_dist_attr( - common::vectorize(global_value->dims())); - replicated_dist_attr.set_process_mesh(process_mesh); - DistTensor replicated_tensor(global_value, replicated_dist_attr); - - // 2. reshard from replicated to other state - auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_); - auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place()); - func->Eval(dev_ctx, replicated_tensor, dist_attr_, this); + if (global_value->initialized()) { + value_ = std::make_shared(); + // 1. create replicated global tensor + TensorDistAttr replicated_dist_attr( + common::vectorize(global_value->dims())); + replicated_dist_attr.set_process_mesh(process_mesh); + DistTensor replicated_tensor(global_value, replicated_dist_attr); + + // 2. reshard from replicated to other state + auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_); + auto* dev_ctx = + DeviceContextPool::Instance().Get(global_value->place()); + func->Eval(dev_ctx, replicated_tensor, dist_attr_, this); + } else { + // For lazy init, the global value is an uninitialized tensor. + // Just infer the local shape of the dist tensor. + value_ = global_value; + value_->Resize( + InferShapeForReshardFromReplicate(global_value, dist_attr_)); + } } else { value_ = global_value; } diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc index e7a1ec15da307a..c4ae44120ddcc5 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc @@ -180,5 +180,28 @@ phi::DeviceContext* GetDistTensorDeviceContext( return phi::DeviceContextPool::Instance().Get(place); } +phi::DDim InferShapeForReshardFromReplicate( + const std::shared_ptr& global_value, + const TensorDistAttr& dist_attr) { + phi::DDim out_dim = global_value->dims(); + auto coord_id = GetCurRankCoordInMesh(dist_attr.process_mesh()); + for (int tensor_axis = 0; tensor_axis < global_value->dims().size(); + ++tensor_axis) { + if (dist_attr.is_shard(-1, tensor_axis)) { + for (int mesh_axis = 0; mesh_axis < dist_attr.process_mesh().ndim(); + ++mesh_axis) { + if (dist_attr.is_shard(mesh_axis, tensor_axis)) { + // handle the shard axis + int64_t global_shape = out_dim[tensor_axis]; + int64_t mesh_size = dist_attr.process_mesh().dim_size(mesh_axis); + auto balance_shard = BalancedSplit(global_shape, mesh_size); + out_dim[tensor_axis] = balance_shard[coord_id[mesh_axis]]; + } + } + } + } + return out_dim; +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h index 022dc065980641..a7d0177407b618 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h @@ -71,6 +71,10 @@ std::vector BalancedSplit(int64_t total_nums, int64_t num_of_pieces); CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx, const std::vector& process_ids); +phi::DDim InferShapeForReshardFromReplicate( + const std::shared_ptr& global_value, + const TensorDistAttr& dist_attr); + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \ do { \ diff --git a/python/paddle/nn/initializer/constant.py b/python/paddle/nn/initializer/constant.py index d45c784c20b8cd..91a283dac39374 100644 --- a/python/paddle/nn/initializer/constant.py +++ b/python/paddle/nn/initializer/constant.py @@ -72,9 +72,20 @@ def forward(self, var, block=None): if self._force_cpu: place = core.CPUPlace() if in_dygraph_mode(): - _C_ops.full_( - var, var.shape, float(self._value), var.dtype, place - ) + if isinstance(var, framework.EagerParamBase) and var.is_dist(): + out_var = _C_ops.full( + var._local_shape, float(self._value), var.dtype, place + ) + out_var = ( + paddle.distributed.auto_parallel.api.dtensor_from_local( + out_var, var.process_mesh, var.placements + ) + ) + out_var._share_underline_tensor_to(var) + else: + _C_ops.full_( + var, var.shape, float(self._value), var.dtype, place + ) return None else: return _C_ops.full( diff --git a/test/auto_parallel/semi_auto_parallel_lazy_init.py b/test/auto_parallel/semi_auto_parallel_lazy_init.py index cfeff65b2733a1..2b6f2f96ddf8db 100644 --- a/test/auto_parallel/semi_auto_parallel_lazy_init.py +++ b/test/auto_parallel/semi_auto_parallel_lazy_init.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging import os import paddle @@ -34,6 +34,11 @@ def __init__(self): self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"]) self._placements_weight = [dist.Replicate()] self._placements_bias = [dist.Replicate()] + elif self._placements_type == "MP": + self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"]) + self._placements_weight = [dist.Shard(1)] + self._placements_bias = [dist.Shard(0)] def test_different_xavier(self): paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) @@ -53,6 +58,31 @@ def test_different_xavier(self): linear.bias = dist.shard_tensor( linear.bias, self._mesh_bias, self._placements_bias ) + for param in linear.parameters(): + param.initialize() + logging.info(param) + + def test_constant(self): + paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) + weight_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(2.0) + ) + bias_attr = paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(1.0) + ) + with LazyGuard(): + linear = paddle.nn.Linear( + 10, 10, weight_attr=weight_attr, bias_attr=bias_attr + ) + linear.weight = dist.shard_tensor( + linear.weight, self._mesh_weight, self._placements_weight + ) + linear.bias = dist.shard_tensor( + linear.bias, self._mesh_bias, self._placements_bias + ) + for param in linear.parameters(): + param.initialize() + logging.info(param) def test_placements(self): paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) @@ -67,6 +97,7 @@ def test_placements(self): for param in linear.parameters(): assert not param._is_initialized() param.initialize() + logging.info(param) if self._placements_type == "DP": assert linear.weight._is_initialized() @@ -93,10 +124,40 @@ def test_placements(self): else: assert not linear.weight._is_initialized() assert linear.bias._is_initialized() + elif self._placements_type == "MP": + assert linear.weight._is_initialized() + assert linear.bias._is_initialized() + assert linear.weight._local_shape == [10, 5] + assert linear.bias._local_shape == [5] + + def test_unbalance_mp(self): + paddle.distributed.auto_parallel.parallel_manual_seed(self._seed) + with LazyGuard(): + linear = paddle.nn.Linear(11, 11) + linear.weight = dist.shard_tensor( + linear.weight, self._mesh_weight, self._placements_weight + ) + linear.bias = dist.shard_tensor( + linear.bias, self._mesh_bias, self._placements_bias + ) + for param in linear.parameters(): + assert not param._is_initialized() + param.initialize() + assert param._is_initialized() + + if dist.get_rank() == 0: + assert linear.weight._local_shape == [11, 6] + assert linear.bias._local_shape == [6] + else: + assert linear.weight._local_shape == [11, 5] + assert linear.bias._local_shape == [5] def run_test_case(self): self.test_placements() self.test_different_xavier() + self.test_constant() + if self._placements_type == "MP": + self.test_unbalance_mp() if __name__ == '__main__': diff --git a/test/auto_parallel/test_semi_auto_parallel_lazy_init.py b/test/auto_parallel/test_semi_auto_parallel_lazy_init.py index b55423184b9188..1c1d4a2ef5c128 100644 --- a/test/auto_parallel/test_semi_auto_parallel_lazy_init.py +++ b/test/auto_parallel/test_semi_auto_parallel_lazy_init.py @@ -29,7 +29,7 @@ def setUp(self): } self._changeable_envs = { "backend": ["cpu", "gpu"], - "_placements_type": ["DP", "PP"], + "_placements_type": ["DP", "PP", "MP"], } def test_lazy_init(self): From 75e62a25b5ed81ffce61db204ce727763175951e Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 5 Jan 2024 12:43:16 +0800 Subject: [PATCH 124/142] [PIR] Add unittest for Operation::Clone and Group::Clone (#60577) --- paddle/cinn/hlir/framework/pir/group.h | 12 +-- paddle/pir/core/operation.cc | 4 +- test/cpp/pir/cinn/group_op_test.cc | 110 +++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 2cd3b9b9deddaa..535260c2ee96b9 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -68,22 +68,22 @@ struct Group { // Mapper from original to new ops. std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper; ::pir::CloneOptions clone_options(false, true); - for (auto* op : this->ops_set) { + for (auto* op : ops) { + VLOG(4) << "clone op :" << op->name(); auto* new_op = op->Clone(ir_mapping, clone_options); - // NOTE(dev): Must call MoveTo to deal with ownership, otherwise it + // NOTE(dev): Must call block.insert to deal with ownership, otherwise it // will lead memory-leak. - new_op->MoveTo(target_block, target_block->end()); + target_block->insert(target_block->end(), new_op); new_ops.push_back(new_op); ops_mapper[op] = new_op; } // Construct Base information for new Group auto new_group = std::make_shared(new_ops); - this->CollectOps(); for (auto& iter : this->input_ops) { - new_group->input_ops[ops_mapper[iter.first]] = iter.second; + new_group->input_ops[ops_mapper.at(iter.first)] = iter.second; } for (auto* op : this->output_ops) { - new_group->output_ops.insert(ops_mapper[op]); + new_group->output_ops.insert(ops_mapper.at(op)); } return new_group; diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 0a8e26d788ca15..4d14213dd9f910 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -138,8 +138,8 @@ Operation *Operation::Create(const std::vector &inputs, } Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) { - IR_ENFORCE(options.IsCloneRegions() || num_regions_ > 0, - "Operation CloneOperands is unimplemented currently."); + IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0, + "Operation CloneRegions is unimplemented currently."); IR_ENFORCE(num_successors_ == 0, "Operation::Clone is not unimplemented for multiple successors."); diff --git a/test/cpp/pir/cinn/group_op_test.cc b/test/cpp/pir/cinn/group_op_test.cc index 7fcc8ae6bef317..9f9643c85f4a85 100644 --- a/test/cpp/pir/cinn/group_op_test.cc +++ b/test/cpp/pir/cinn/group_op_test.cc @@ -20,6 +20,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/cinn_group_lowering_pass.h" +#include "paddle/cinn/hlir/framework/pir/group.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -243,3 +244,112 @@ TEST(GroupOp, CINNLowering) { EXPECT_EQ(res2, true); EXPECT_EQ(res3, true); } + +class GroupOpPattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + using Group = cinn::hlir::framework::pir::Group; + + bool MatchAndRewrite(cinn::dialect::GroupOp group_op, + pir::PatternRewriter& rewriter) const override { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + auto* program = group_op->GetParentProgram(); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + VLOG(4) << "Before GroupOpPattern: " << *program; + std::vector<::pir::Operation*> group_ops = group_op.ops(); + auto yeild_op = group_ops.back(); + std::vector<::pir::Type> output_type{yeild_op->operand_source(0).type()}; + + // construct hlir::Group + Group group({group_ops.begin(), group_ops.end() - 1}); + group.input_ops[group_ops[0]] = 0; // first tan + auto last_op_idx = group_ops.size() - 2; + group.output_ops.insert(group_ops[last_op_idx]); // last relu + + // clone group and sync their op into new GroupOp + builder.SetInsertionPointAfter(group_op.operation()); + auto new_group_op = builder.Build(output_type); + + // prepare IrMapping + ::pir::IrMapping ir_mapping; + auto depend_value = group_ops[0]->operand_source(0); + ir_mapping.Add(depend_value, depend_value); + std::shared_ptr new_group = + group.Clone(new_group_op.block(), ir_mapping); + + EXPECT_EQ(new_group->ops.size(), group.ops.size()); + EXPECT_EQ(new_group->input_ops.size(), group.input_ops.size()); + EXPECT_EQ(new_group->output_ops.size(), group.output_ops.size()); + + // Add yield op + builder.SetInsertionPointToBlockEnd(new_group_op.block()); + std::vector<::pir::Value> yield_inputs{ + new_group_op.ops().back()->result(0)}; + builder.Build<::pir::YieldOp>(yield_inputs); + EXPECT_EQ(new_group_op.ops().size(), group_ops.size()); + + // replace result UD between GroupOp + rewriter.ReplaceAllUsesWith(group_op->result(0), new_group_op->result(0)); + rewriter.EraseOp(group_op); + VLOG(4) << "After GroupOpPattern.EraseOp: " << *program; + return true; + } +}; + +class TestGroupClonePass : public pir::PatternRewritePass { + public: + TestGroupClonePass() : pir::PatternRewritePass("test_group_clone", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->isa() && op->num_regions() > 0; + } +}; + +std::shared_ptr<::pir::Program> BuildSingleGroupProgram() { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); + + auto program = std::make_shared<::pir::Program>(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program->block()); + const std::vector shape = {64, 128}; + // full op + auto full_x = builder.Build( + shape, 0.5, phi::DataType::FLOAT32, phi::GPUPlace()); + + // group op + auto group_op = builder.Build( + CreateDenseTensorTypes(common::make_ddim(shape))); + pir::Block* block = group_op.block(); + builder.SetInsertionPointToBlockEnd(block); + + auto tan_op_x = builder.Build(full_x->result(0)); + auto relu_op_x = builder.Build(tan_op_x->result(0)); + auto tan_op_y = builder.Build(relu_op_x->result(0)); + auto relu_op_y = builder.Build(tan_op_y->result(0)); + builder.Build<::pir::YieldOp>(std::vector<::pir::Value>{relu_op_y.out()}); + + // tan op + builder.SetInsertionPointToBlockEnd(program->block()); + auto final_op = builder.Build(group_op->result(0)); + + return program; +} + +TEST(Group, Clone) { + // Step 1: Construct pir::Program + std::shared_ptr<::pir::Program> program = BuildSingleGroupProgram(); + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ::pir::PassManager pm(ctx); + // Step 2: Run TestGroupClonePass + pm.AddPass(std::make_unique()); + pm.Run(program.get()); +} From 488f367b298b8cb19b45cb6b10dd9c56e4bb2402 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Fri, 5 Jan 2024 13:15:29 +0800 Subject: [PATCH 125/142] [PIR] dce pass disable custom op (#60578) --- .../instruction/phi_kernel_instruction.cc | 6 +++-- .../transforms/dead_code_elimination_pass.cc | 22 +++++++++++-------- .../test_pir_matmul_scale_fuse_pass.py | 5 ----- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index ed5bee9ce87772..094b15b8bc2dbe 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -177,12 +177,14 @@ PhiKernelInstruction::~PhiKernelInstruction() { } void PhiKernelInstruction::Run() { + VLOG(6) << "Begin run op " << phi_op_name_ << " infer meta."; if (infer_meta_interface_) { infer_meta_interface_->infer_meta_(&(infer_meta_context_)); } - VLOG(6) << "Run op " << phi_op_name_ << " infer meta."; + VLOG(6) << "End run op " << phi_op_name_ << " infer meta."; + VLOG(6) << "Begin run op " << phi_op_name_ << " kernel."; (*(phi_kernel_))(&(kernel_context_)); - VLOG(6) << "Run op " << phi_op_name_ << " kernel."; + VLOG(6) << "End run op " << phi_op_name_ << " kernel."; } } // namespace framework diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc index bc2421cfe1a869..1a6433e233ed1f 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builtin_op.h" @@ -39,27 +40,30 @@ class DeadCodeEliminationPass : public pir::Pass { std::vector deleted_ops; for (auto& op : block) { if (op.HasTrait() || - op.isa()) { + op.isa() || + paddle::dialect::IsCustomOp(&op)) { continue; } if (op.use_empty()) { deleted_ops.push_back(&op); } } + for (auto* op : deleted_ops) { op->Erase(); (*num_erasers)++; } - for (auto& op : block) { - for (size_t i = 0; i < op.num_regions(); ++i) { - auto& inner_region = op.region(i); - for (auto& inner_block : inner_region) { - EraseOp(inner_block, num_erasers); + + if (deleted_ops.empty()) { + for (auto& op : block) { + for (size_t i = 0; i < op.num_regions(); ++i) { + auto& inner_region = op.region(i); + for (auto& inner_block : inner_region) { + EraseOp(inner_block, num_erasers); + } } } - } - - if (!deleted_ops.empty()) { + } else { EraseOp(block, num_erasers); } } diff --git a/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py b/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py index 54a551e258218f..8320997a074926 100644 --- a/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py +++ b/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py @@ -85,10 +85,5 @@ def test_check_output(self): self.check_pass_correct() -class TestMatmulScaleFusePatternWtihCpu(TestMatmulScaleFusePattern): - def setUp(self): - self.place_runtime = "cpu" - - if __name__ == "__main__": unittest.main() From 2b866379094776cd5b4f7937330d752dc5a8802d Mon Sep 17 00:00:00 2001 From: ming1753 <61511741+ming1753@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:32:54 +0800 Subject: [PATCH 126/142] [Inference] Fix bug of RunWithExternalStream API in new executor (#60122) * fix bug of RunWithExternalStream API in new executor * add test * fix bug of RunWithExternalStream API in new executor * reset flage in RunWithExternalStream * fix bug * add param swith_stream * fix bug * modify python api * fix bug --- paddle/fluid/framework/naive_executor.cc | 6 +- paddle/fluid/framework/naive_executor.h | 3 +- .../new_executor/interpreter_base_impl.h | 6 +- .../framework/new_executor/interpretercore.cc | 16 +++- .../framework/new_executor/interpretercore.h | 6 +- .../framework/new_executor/pir_interpreter.cc | 18 +++- .../framework/new_executor/pir_interpreter.h | 6 +- .../new_executor/program_interpreter.cc | 96 +++++++++++-------- .../new_executor/program_interpreter.h | 11 ++- .../fluid/inference/api/analysis_predictor.cc | 10 +- .../fluid/inference/api/analysis_predictor.h | 3 +- .../inference/api/onnxruntime_predictor.cc | 2 +- .../inference/api/onnxruntime_predictor.h | 3 +- paddle/fluid/inference/api/paddle_api.h | 3 +- paddle/fluid/pybind/inference_api.cc | 12 ++- .../api/analysis_predictor_tester.cc | 4 +- 16 files changed, 131 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 925c04f658b0a7..90f5b93dcb2efa 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -72,12 +72,14 @@ void NaiveExecutor::PrepareInterpreterCore( } void NaiveExecutor::RunInterpreterCore( - const std::vector &feed_names, bool need_fetch) { + const std::vector &feed_names, + bool need_fetch, + bool switch_stream) { platform::ScopedFlushDenormal flush; #ifdef PADDLE_WITH_NVTX platform::CudaNvtxRangePush("model", platform::NvtxRangeColor::Yellow); #endif - interpreter_core_->Run(feed_names, need_fetch); + interpreter_core_->Run(feed_names, need_fetch, false, false, switch_stream); #ifdef PADDLE_WITH_NVTX platform::CudaNvtxRangePop(); #endif diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 5a558f3bd69216..8388bfe3a37fc1 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -77,7 +77,8 @@ class NaiveExecutor { void Run(); void RunInterpreterCore(const std::vector& feed_names = {}, - bool need_fetch = false); + bool need_fetch = false, + bool switch_stream = false); // Get an tensor to operating directly, without the need for feed_ops. phi::DenseTensor* FindTensor(const std::string& name); diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h index ff5832ba8335e6..a7a618ac90284e 100644 --- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h +++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h @@ -68,13 +68,15 @@ class InterpreterBaseImpl { const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch = true, - bool enable_job_schedule_profiler = false) = 0; + bool enable_job_schedule_profiler = false, + bool switch_stream = false) = 0; virtual paddle::framework::FetchList Run( const std::vector& feed_names, bool need_fetch = true, bool enable_job_schedule_profiler = false, - bool enable_op_profiling = false) = 0; + bool enable_op_profiling = false, + bool switch_stream = false) = 0; virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index b0bbd11aef0dbd..8fdddb1548d9d2 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -67,19 +67,25 @@ FetchList InterpreterCore::Run( const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch, - bool enable_job_schedule_profiler) { - return impl_->Run( - feed_names, feed_tensors, need_fetch, enable_job_schedule_profiler); + bool enable_job_schedule_profiler, + bool switch_stream) { + return impl_->Run(feed_names, + feed_tensors, + need_fetch, + enable_job_schedule_profiler, + switch_stream); } FetchList InterpreterCore::Run(const std::vector& feed_names, bool need_fetch, bool enable_job_schedule_profiler, - bool enable_op_profiling) { + bool enable_op_profiling, + bool switch_stream) { return impl_->Run(feed_names, need_fetch, enable_job_schedule_profiler, - enable_op_profiling); + enable_op_profiling, + switch_stream); } void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr src) { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index b8c1913d931dcb..7731620565fb82 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -49,12 +49,14 @@ class InterpreterCore { const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch = true, - bool enable_job_schedule_profiler = false); + bool enable_job_schedule_profiler = false, + bool switch_stream = false); paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true, bool enable_job_schedule_profiler = false, - bool enable_op_profiling = false); + bool enable_op_profiling = false, + bool switch_stream = false); void RunProfile(const std::vector& feed_names); diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 19e3d6e86ebdeb..84c8fa753eb31d 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -1255,7 +1255,8 @@ paddle::framework::FetchList PirInterpreter::Run( const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch, - bool enable_job_schedule_profiler) { + bool enable_job_schedule_profiler, + bool switch_stream) { enable_job_schedule_profiler_ = enable_job_schedule_profiler; auto FeedInput = [&] { @@ -1318,6 +1319,12 @@ paddle::framework::FetchList PirInterpreter::Run( is_build_ = true; is_shared_results_build_ = true; } else { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (switch_stream) { + BuildInstruction(); + VLOG(4) << "Done BuildInstruction"; + } +#endif if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 || execution_config_.used_for_inference || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && @@ -1350,7 +1357,8 @@ paddle::framework::FetchList PirInterpreter::Run( FetchList PirInterpreter::Run(const std::vector& feed_names, bool need_fetch, bool enable_job_schedule_profiler, - bool enable_op_profiling) { + bool enable_op_profiling, + bool switch_stream) { enable_job_schedule_profiler_ = enable_job_schedule_profiler; if (enable_op_profiling) { @@ -1401,6 +1409,12 @@ FetchList PirInterpreter::Run(const std::vector& feed_names, is_build_ = true; is_shared_results_build_ = true; } else { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (switch_stream) { + BuildInstruction(); + VLOG(4) << "Done BuildInstruction"; + } +#endif if (FLAGS_enable_pir_in_executor_trace_run || nccl_op_num_ > 1 || execution_config_.used_for_inference || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h index 1684aeffef8cfa..3f197f53e12f8c 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.h +++ b/paddle/fluid/framework/new_executor/pir_interpreter.h @@ -57,12 +57,14 @@ class PirInterpreter : public InterpreterBaseImpl { const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch = true, - bool enable_job_schedule_profiler = false) override; + bool enable_job_schedule_profiler = false, + bool switch_stream = false) override; paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true, bool enable_job_schedule_profiler = false, - bool enable_op_profiling = false) override; + bool enable_op_profiling = false, + bool switch_stream = false) override; void ShareWorkQueueFrom(InterpreterBaseImpl* src) override; diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index bc41742437ff9c..0f50665e1621e8 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -144,7 +144,8 @@ void ProgramInterpreter::RunImpl() { FetchList ProgramInterpreter::Run(const std::vector& feed_names, bool need_fetch, bool enable_job_schedule_profiler, - bool enable_op_profiling) { + bool enable_op_profiling, + bool switch_stream) { enable_job_schedule_profiler_ = enable_job_schedule_profiler; is_in_op_profiling_mode_ = enable_op_profiling; @@ -163,6 +164,11 @@ FetchList ProgramInterpreter::Run(const std::vector& feed_names, is_build_ = true; is_shared_results_build_ = true; } else { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (switch_stream) { + BuildOpFuncNode(&op_func_nodes); + } +#endif RunImpl(); } @@ -233,7 +239,8 @@ FetchList ProgramInterpreter::Run( const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch, - bool enable_job_schedule_profiler) { + bool enable_job_schedule_profiler, + bool switch_stream) { enable_job_schedule_profiler_ = enable_job_schedule_profiler; SetDeviceId(place_); @@ -244,7 +251,7 @@ FetchList ProgramInterpreter::Run( #endif bool is_build = is_build_; - Prepare(feed_names, feed_tensors, is_build); + Prepare(feed_names, feed_tensors, is_build, switch_stream); if (is_build) { RunImpl(); @@ -671,42 +678,7 @@ std::tuple ProgramInterpreter::InterpreterRunTime() { void ProgramInterpreter::Convert( std::vector* op_func_nodes) { auto& vec_meta_info = var_scope_.MutableVecMetaInfo(); - auto nodes = *op_func_nodes; - auto op_nums = nodes.size(); - vec_instruction_.clear(); - vec_instruction_.reserve(op_nums); - for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { - auto& op_func_node = nodes[op_idx]; - stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_); - auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); -#ifdef PADDLE_WITH_CUDA - if (FLAGS_new_executor_use_cuda_graph) { - auto& op = op_func_node.operator_base_; - auto& op_type = op->Type(); - if (op_type == interpreter::kMemcpyD2H || - op_type == interpreter::kMemcpyH2D) { - PADDLE_THROW(paddle::platform::errors::Fatal( - "Cuda memory copy d2h/h2d is not allowed while using cuda graph.")); - } - PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext), - true, - platform::errors::InvalidArgument( - "Device context of op %s must be [%s] while using " - "cuda graph, but got [%s].", - op_type, - typeid(phi::GPUContext).name(), - typeid(*dev_ctx_).name())); - // cuda graph needs to record all stream - phi::backends::gpu::CUDAGraphContextManager::Instance() - .RecordCapturingDeviceContext(dev_ctx_); - } -#endif - vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - vec_instruction_.back().UpdataRecordStreamForGcInfo(); -#endif - } + BuildOpFuncNode(op_func_nodes); BuildOperatorDependences(); @@ -743,6 +715,7 @@ void ProgramInterpreter::Convert( } // calculate last_live_ops_ + auto op_nums = (*op_func_nodes).size(); for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { Instruction& instr = vec_instruction_[op_idx]; OpInOutInfo info; @@ -879,6 +852,46 @@ void ProgramInterpreter::Convert( AnalyseExecuteOrderForTrace(); } +void ProgramInterpreter::BuildOpFuncNode( + std::vector* op_func_nodes) { + auto nodes = *op_func_nodes; + auto op_nums = nodes.size(); + vec_instruction_.clear(); + vec_instruction_.reserve(op_nums); + for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { + auto& op_func_node = nodes[op_idx]; + stream_analyzer_.SetForceEventsToWaitInfo(force_evnets_to_wait_); + auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); +#ifdef PADDLE_WITH_CUDA + if (FLAGS_new_executor_use_cuda_graph) { + auto& op = op_func_node.operator_base_; + auto& op_type = op->Type(); + if (op_type == interpreter::kMemcpyD2H || + op_type == interpreter::kMemcpyH2D) { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Cuda memory copy d2h/h2d is not allowed while using cuda graph.")); + } + PADDLE_ENFORCE_EQ(typeid(*dev_ctx_) == typeid(phi::GPUContext), + true, + platform::errors::InvalidArgument( + "Device context of op %s must be [%s] while using " + "cuda graph, but got [%s].", + op_type, + typeid(phi::GPUContext).name(), + typeid(*dev_ctx_).name())); + // cuda graph needs to record all stream + phi::backends::gpu::CUDAGraphContextManager::Instance() + .RecordCapturingDeviceContext(dev_ctx_); + } +#endif + vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + vec_instruction_.back().UpdataRecordStreamForGcInfo(); +#endif + } +} + void ProgramInterpreter::BuildSkipShareLoDInfo() { for (size_t i = 0; i < vec_instruction_.size(); ++i) { bool can_skip_lod = true; @@ -1494,7 +1507,8 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) { void ProgramInterpreter::Prepare( const std::vector& feed_names, const std::vector& feed_tensors, - bool prepare_feed) { + bool prepare_feed, + bool switch_stream) { PADDLE_ENFORCE_EQ(feed_names.size(), feed_tensors.size(), platform::errors::PreconditionNotMet( @@ -1517,7 +1531,7 @@ void ProgramInterpreter::Prepare( } }; - if (!is_build_) { + if (!is_build_ || switch_stream) { paddle::framework::interpreter::BuildVariableScope( block_, execution_config_, &var_scope_); FeedInput(); diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h index b19e3a06a42588..5359c41fddcdc6 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.h +++ b/paddle/fluid/framework/new_executor/program_interpreter.h @@ -49,12 +49,14 @@ class ProgramInterpreter : public InterpreterBaseImpl { const std::vector& feed_names, const std::vector& feed_tensors, bool need_fetch = true, - bool enable_job_schedule_profiler = false) override; + bool enable_job_schedule_profiler = false, + bool switch_stream = false) override; paddle::framework::FetchList Run(const std::vector& feed_names, bool need_fetch = true, bool enable_job_schedule_profiler = false, - bool enable_op_profiling = false) override; + bool enable_op_profiling = false, + bool switch_stream = false) override; std::shared_ptr GetMutableCopyProgram() override; @@ -125,6 +127,8 @@ class ProgramInterpreter : public InterpreterBaseImpl { void BuildSkipShareLoDInfo(); void UpdateSyncOpNum(); void AnalyseExecuteOrderForTrace(); + void BuildOpFuncNode( + std::vector* op_func_nodes); // inplace void BuildInplace(); @@ -150,7 +154,8 @@ class ProgramInterpreter : public InterpreterBaseImpl { // only used when program contains no feed op void Prepare(const std::vector& feed_names, const std::vector& feed_tensors, - bool prepare_feed); + bool prepare_feed, + bool switch_stream = false); void RecordMemcpyD2H(const Instruction& instr_node); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index c7164b61bb7c00..5bd88276e76bc7 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2249,7 +2249,7 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( return res; } -bool AnalysisPredictor::ZeroCopyRun() { +bool AnalysisPredictor::ZeroCopyRun(bool switch_stream) { inference::DisplayMemoryInfo(place_, "before run"); #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) if (config_.dist_config().use_dist_model()) { @@ -2312,7 +2312,7 @@ bool AnalysisPredictor::ZeroCopyRun() { #endif if (config_.new_executor_enabled()) { - executor_->RunInterpreterCore(); + executor_->RunInterpreterCore({}, false, switch_stream); } else { executor_->Run(); } @@ -2353,7 +2353,7 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) { "Please use config.SetExecStream to init gpu resources, and then we " "will bind gpu resources to execution stream.")); } - + bool switch_stream = false; if (stream != predictor_stream_) { #ifdef PADDLE_WITH_HIP hipStreamSynchronize(static_cast(predictor_stream_)); @@ -2383,9 +2383,9 @@ bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) { })); auto &pool = paddle::experimental::DeviceContextPool::Instance(); pool.SyncDeviceContext(place_); + switch_stream = true; } - - return ZeroCopyRun(); + return ZeroCopyRun(switch_stream); } #endif diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 4a5cfb229a459e..0f2091478af2a1 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -204,9 +204,10 @@ class AnalysisPredictor : public PaddlePredictor { /// /// \brief Run the prediction engine /// + /// \param switch_stream Whether the stream is switched /// \return Whether the function executed successfully /// - bool ZeroCopyRun() override; + bool ZeroCopyRun(bool switch_stream = false) override; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // Note: Can only be used under thread_local semantics. diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.cc b/paddle/fluid/inference/api/onnxruntime_predictor.cc index 25970440469168..f2d8f7478d9024 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor.cc +++ b/paddle/fluid/inference/api/onnxruntime_predictor.cc @@ -333,7 +333,7 @@ bool ONNXRuntimePredictor::Run(const std::vector &inputs, return false; } -bool ONNXRuntimePredictor::ZeroCopyRun() { +bool ONNXRuntimePredictor::ZeroCopyRun(bool switch_stream) { try { const char *device_name = platform::is_cpu_place(place_) ? "Cpu" : "Cuda"; std::vector inputs; diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.h b/paddle/fluid/inference/api/onnxruntime_predictor.h index 971632c4b3c7a6..c983f8acdae281 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor.h +++ b/paddle/fluid/inference/api/onnxruntime_predictor.h @@ -175,9 +175,10 @@ class ONNXRuntimePredictor : public PaddlePredictor { /// /// \brief Run the prediction engine /// + /// \param switch_stream Whether the stream is switched /// \return Whether the function executed successfully /// - bool ZeroCopyRun() override; + bool ZeroCopyRun(bool switch_stream = false) override; /// /// \brief Release all tmp tensor to compress the size of the memory pool. diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 3fefba9ef22be8..89540a91e37895 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -295,8 +295,9 @@ class PD_INFER_DECL PaddlePredictor { /// To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(false) /// and then use the `GetInputTensor` and `GetOutputTensor` /// to directly write or read the input/output tensors. + /// \param switch_stream Whether the stream is switched. /// \return Whether the run is successful - virtual bool ZeroCopyRun() { return false; } + virtual bool ZeroCopyRun(bool switch_stream = false) { return false; } /// /// \brief Clear the intermediate tensors of the predictor diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 03a95e870b8105..94df6a0ee0d418 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -691,7 +691,9 @@ void BindPaddlePredictor(py::module *m) { .def("get_output_tensor", &PaddlePredictor::GetOutputTensor) .def("get_input_names", &PaddlePredictor::GetInputNames) .def("get_output_names", &PaddlePredictor::GetOutputNames) - .def("zero_copy_run", &PaddlePredictor::ZeroCopyRun) + .def("zero_copy_run", + &PaddlePredictor::ZeroCopyRun, + py::arg("switch_stream") = false) .def("clone", [](PaddlePredictor &self) { return self.Clone(nullptr); }) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) .def("clone", @@ -740,7 +742,9 @@ void BindNativePredictor(py::module *m) { }) .def("get_input_tensor", &NativePaddlePredictor::GetInputTensor) .def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor) - .def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun) + .def("zero_copy_run", + &NativePaddlePredictor::ZeroCopyRun, + py::arg("switch_stream") = false) .def("clone", [](NativePaddlePredictor &self) { return self.Clone(nullptr); }) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -1130,7 +1134,9 @@ void BindAnalysisPredictor(py::module *m) { .def("get_input_names", &AnalysisPredictor::GetInputNames) .def("get_output_names", &AnalysisPredictor::GetOutputNames) .def("get_input_tensor_shape", &AnalysisPredictor::GetInputTensorShape) - .def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun) + .def("zero_copy_run", + &AnalysisPredictor::ZeroCopyRun, + py::arg("switch_stream") = false) .def("clear_intermediate_tensor", &AnalysisPredictor::ClearIntermediateTensor) .def("try_shrink_memory", &AnalysisPredictor::TryShrinkMemory) diff --git a/test/cpp/inference/api/analysis_predictor_tester.cc b/test/cpp/inference/api/analysis_predictor_tester.cc index 3d841954a89d65..3d87140d9c05a7 100644 --- a/test/cpp/inference/api/analysis_predictor_tester.cc +++ b/test/cpp/inference/api/analysis_predictor_tester.cc @@ -668,6 +668,7 @@ TEST(Tensor, RunWithExternalStream) { cudaStream_t stream; cudaStreamCreate(&stream); config.SetExecStream(stream); + config.EnableNewExecutor(); auto predictor = CreatePredictor(config); auto w0 = predictor->GetInputHandle("firstw"); @@ -703,8 +704,7 @@ TEST(Tensor, RunWithExternalStream) { cudaStream_t external_stream; cudaStreamCreate(&external_stream); - Config tmp_config(config); - tmp_config.SetExecStream(external_stream); + predictor->Run(); paddle_infer::experimental::InternalUtils::RunWithExternalStream( predictor.get(), external_stream); From a11aabd748bfffbb04dd4b01ebe9af615d8dcc0c Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Fri, 5 Jan 2024 16:07:28 +0800 Subject: [PATCH 127/142] Resubmit PR-58859 (#60310) * allow multiple rng state in generator * Fix 60142; Fix some comments from sneaxiy * Overwrite copy constructors * add api * pre-commit --- paddle/fluid/pybind/generator_py.cc | 23 +- paddle/phi/core/generator.cc | 164 ++++++------ paddle/phi/core/generator.h | 103 +++++-- paddle/phi/kernels/funcs/dropout_impl.cu.h | 34 ++- .../gpu/fused_dropout_add_grad_kernel.cu | 4 +- .../fusion/gpu/fused_dropout_add_kernel.cu | 48 ++-- .../paddle/device/cuda/cuda_graphed_layer.py | 2 +- .../distributed/fleet/layers/mpu/random.py | 33 ++- python/paddle/incubate/__init__.py | 5 + python/paddle/incubate/framework/__init__.py | 6 + python/paddle/incubate/framework/random.py | 251 ++++++++++++++++++ python/setup.py.in | 1 + setup.py | 1 + test/legacy_test/test_cuda_graphed_layer.py | 2 +- .../test_random_generator_set_get_state.py | 85 ++++++ 15 files changed, 605 insertions(+), 157 deletions(-) create mode 100644 python/paddle/incubate/framework/random.py create mode 100644 test/legacy_test/test_random_generator_set_get_state.py diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc index 520fe09bc710cd..c4fd46fd623e38 100644 --- a/paddle/fluid/pybind/generator_py.cc +++ b/paddle/fluid/pybind/generator_py.cc @@ -38,7 +38,7 @@ void BindGenerator(py::module* m_ptr) { "GeneratorState") .def("current_seed", [](std::shared_ptr& self) { - return self->current_seed; + return self->seed; }) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU) @@ -46,7 +46,7 @@ void BindGenerator(py::module* m_ptr) { // type, resulting in a problem with precision under the cpu. .def(py::pickle( [](const phi::Generator::GeneratorState& s) { // __getstate__ - return py::make_tuple(s.device, s.current_seed, s.thread_offset); + return py::make_tuple(s.device, s.seed, s.offset); }, [](py::tuple s) { // __setstate__ if (s.size() != 3) @@ -54,21 +54,19 @@ void BindGenerator(py::module* m_ptr) { "Invalid Random state. Please check the format(device, " "current_seed, thread_offset)."); - phi::Generator::GeneratorState state; - state.device = s[0].cast(); - state.current_seed = s[1].cast(); - state.thread_offset = s[2].cast(); + int64_t device = s[0].cast(); + int64_t seed = s[1].cast(); + uint64_t offset = s[2].cast(); + + phi::Generator::GeneratorState state(device, seed, offset); - std::seed_seq seq({state.current_seed}); - auto engine = std::make_shared(seq); - state.cpu_engine = *engine; return state; })) #endif .def("__str__", [](const phi::Generator::GeneratorState& self) { std::stringstream ostr; - ostr << self.device << " " << self.current_seed << " " - << self.thread_offset << " " << self.cpu_engine; + ostr << self.device << " " << self.seed << " " << self.offset << " " + << self.cpu_engine; return ostr.str(); }); @@ -78,6 +76,9 @@ void BindGenerator(py::module* m_ptr) { [](phi::Generator& self) { new (&self) phi::Generator(); }) .def("get_state", &phi::Generator::GetState) .def("set_state", &phi::Generator::SetState) + .def("get_state_index", &phi::Generator::GetStateIndex) + .def("set_state_index", &phi::Generator::SetStateIndex) + .def("register_state_index", &phi::Generator::RegisterStateIndex) .def("manual_seed", [](std::shared_ptr& self, uint64_t seed) { self->SetCurrentSeed(seed); diff --git a/paddle/phi/core/generator.cc b/paddle/phi/core/generator.cc index b3f8a2d19caba0..ef5d3e7a646211 100644 --- a/paddle/phi/core/generator.cc +++ b/paddle/phi/core/generator.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include +#include #include #include @@ -157,134 +158,125 @@ const std::shared_ptr& GetRandomSeedGenerator( // RandomGenerator. std::shared_ptr GetCPURandomEngine(uint64_t seed) { if (seed == 0) { - VLOG(4) << "Use random engine from generator"; + VLOG(4) << "Use random cpu_engine from generator"; return DefaultCPUGenerator()->GetCPUEngine(); } else { - // NOTE(zhiqiu): creating an engine instance everytime instead of using + // NOTE(zhiqiu): creating an cpu_engine instance everytime instead of using // OpDefaultCPUEngine(), this is the legacy behavior of random operators. // The benefit is that when runing PE with fixed-seed in multiple thrads, - // each thread has their own engine, and doesn't affect each other. + // each thread has their own cpu_engine, and doesn't affect each other. // // And we need to measure the determinacy of Generator in PE. - auto engine = std::make_shared(); + auto cpu_engine = std::make_shared(); static std::mutex mu_; { std::lock_guard lock(mu_); - engine->seed(seed); + cpu_engine->seed(seed); } - return engine; + return cpu_engine; } } +inline void Generator::print_state_info() { + VLOG(4) << "Generator Random state " + << "device id: " << state().device << ", seed: " << state().seed + << ", offset: " << state().offset << ", cpu_engine: " << cpu_engine(); +} + Generator::Generator() { auto seed = GetRandomSeed(); - std::seed_seq seq({seed}); - auto engine = std::make_shared(seq); - this->state_.cpu_engine = *engine; - this->state_.device = -1; - this->state_.current_seed = seed; - this->state_.thread_offset = 0; - this->engine_ = engine; - VLOG(4) << "initial seed: " << this->state_.current_seed - << ", cpu engine: " << &this->state_.cpu_engine; + current_index = states_.size(); + states_.emplace_back(-1, seed); + print_state_info(); } Generator::Generator(uint64_t seed) { - std::seed_seq seq({seed}); - auto engine = std::make_shared(seq); - this->state_.cpu_engine = *engine; - this->state_.device = -1; - this->state_.current_seed = seed; - this->state_.thread_offset = 0; - this->engine_ = engine; - VLOG(4) << "initial seed: " << this->state_.current_seed - << ", cpu engine: " << &this->state_.cpu_engine; + current_index = states_.size(); + states_.emplace_back(-1, seed); + print_state_info(); } -Generator::Generator(uint64_t seed, uint64_t device_id) { - std::seed_seq seq({seed}); - auto engine = std::make_shared(seq); - this->state_.cpu_engine = *engine; - this->state_.device = static_cast(device_id); - this->state_.current_seed = seed; - this->state_.thread_offset = 0; - this->engine_ = engine; - VLOG(4) << "initial seed: " << this->state_.current_seed - << ", cpu engine: " << &this->state_.cpu_engine; +Generator::Generator(uint64_t seed, int64_t device_id) { + current_index = states_.size(); + // device id first, then seed + states_.emplace_back(device_id, seed); + print_state_info(); } -phi::Generator::GeneratorState Generator::GetState() { - std::lock_guard lock(this->mu_); - state_.cpu_engine = *engine_; - VLOG(4) << "Get Random state: " - << "device id: " << (uint64_t)(this->state_.device) - << ", current_seed: " << this->state_.current_seed - << ", thread_offset: " << this->state_.thread_offset - << ", cpu engine: " << *(this->engine_); - return this->state_; -} +phi::Generator::GeneratorState Generator::GetState() { return state(); } void Generator::SetState(const phi::Generator::GeneratorState& state) { - std::lock_guard lock(this->mu_); - this->state_ = state; - this->engine_ = std::make_shared(state.cpu_engine); - VLOG(4) << "Set Random state: " - << "device id: " << (uint64_t)(this->state_.device) - << ", current_seed: " << this->state_.current_seed - << ", thread_offset: " << this->state_.thread_offset - << ", cpu engine: " << *(this->engine_); + std::lock_guard lock(mu_); + if (current_index < states_.size()) + states_[current_index] = state; + else + PADDLE_THROW(phi::errors::NotFound("Generator index is not found")); + print_state_info(); +} + +uint64_t Generator::GetStateIndex() { return current_index; } + +void Generator::SetStateIndex(uint64_t StateIndex) { + std::lock_guard lock(mu_); + if (current_index < states_.size()) + current_index = StateIndex; + else + PADDLE_THROW(phi::errors::NotFound("Generator index is not found")); +} + +uint64_t Generator::RegisterStateIndex(const GeneratorState& state) { + std::lock_guard lock(mu_); + auto new_index = states_.size(); + states_.push_back(state); + current_index = new_index; + return new_index; +} + +inline Generator::GeneratorState& Generator::state() { + if (current_index < states_.size()) + return states_[current_index]; + else + PADDLE_THROW(phi::errors::NotFound("Generator index is not found")); +} + +inline std::shared_ptr Generator::cpu_engine() { + return state().cpu_engine; } uint64_t Generator::GetCurrentSeed() { - std::lock_guard lock(this->mu_); - return this->state_.current_seed; + std::lock_guard lock(mu_); + return state().seed; } uint64_t Generator::Seed() { - std::lock_guard lock(this->mu_); - uint64_t seed = 0; - std::random_device de; - seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF; - this->state_.current_seed = seed; - std::seed_seq seq({seed}); - this->engine_->seed(seq); - - return this->state_.current_seed; + std::lock_guard lock(mu_); + uint64_t seed = GetRandomSeed(); + state().reset(seed); + return seed; } void Generator::SetCurrentSeed(uint64_t seed) { - std::lock_guard lock(this->mu_); - this->state_.current_seed = seed; - this->state_.thread_offset = 0; - std::seed_seq seq({seed}); - this->engine_->seed(seq); + std::lock_guard lock(mu_); + state().reset(seed); } std::shared_ptr Generator::GetCPUEngine() { - std::lock_guard lock(this->mu_); - return this->engine_; -} - -void Generator::SetCPUEngine(std::shared_ptr engine) { - std::lock_guard lock(this->mu_); - this->engine_ = engine; + return cpu_engine(); } uint64_t Generator::Random64() { - std::lock_guard lock(this->mu_); - auto engine = this->engine_; - return (*engine)(); + std::lock_guard lock(mu_); + auto current_engine = cpu_engine(); + return (*current_engine)(); } -std::pair Generator::IncrementOffset( - uint64_t increment_offset) { +std::pair Generator::IncrementOffset(uint64_t increment) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - std::lock_guard lock(this->mu_); - uint64_t cur_offset = this->state_.thread_offset; - VLOG(10) << "cur_offset = " << cur_offset - << " increment_offset = " << increment_offset; - this->state_.thread_offset += increment_offset; - return std::make_pair(this->state_.current_seed, cur_offset); + std::lock_guard lock(mu_); + uint64_t offset = state().offset; + state().offset = offset + increment; + print_state_info(); + return std::make_pair(state().seed, offset); #else PADDLE_THROW(phi::errors::PermissionDenied( "Increment Offset only support in CUDA place")); diff --git a/paddle/phi/core/generator.h b/paddle/phi/core/generator.h index aeaa89f2abf5ef..2aadeb9910abe1 100644 --- a/paddle/phi/core/generator.h +++ b/paddle/phi/core/generator.h @@ -14,64 +14,127 @@ limitations under the License. */ #pragma once -#include - #include +#include #include -#include // temp for debug #include #include // NOLINT #include #include #include +#include #include "paddle/phi/common/place.h" namespace phi { +#define MAGIC_RANDOM_SEED 34342423252 class Generator { public: struct GeneratorState { - int64_t device = -1; - uint64_t current_seed = 34342423252; - uint64_t thread_offset = 0; - std::mt19937_64 cpu_engine; + int64_t device; + uint64_t seed; + uint64_t offset; + std::shared_ptr cpu_engine; + + GeneratorState(int64_t device_ = -1, + uint64_t seed_ = MAGIC_RANDOM_SEED, + uint64_t offset_ = 0) + : device(device_), seed(seed_), offset(offset_) { + std::seed_seq seq({seed}); + cpu_engine = std::make_shared(seq); + } + + GeneratorState(const GeneratorState& state) + : device(state.device), seed(state.seed), offset(state.offset) { + if (state.cpu_engine) { + std::seed_seq seq({state.seed}); + cpu_engine = std::make_shared(seq); + // Clone the engine state + *(cpu_engine) = *(state.cpu_engine); + } + } + + GeneratorState& operator=(const GeneratorState& state) { + if (this != &state) { + device = state.device; + seed = state.seed; + offset = state.offset; + + if (state.cpu_engine) { + std::seed_seq seq({state.seed}); + cpu_engine = std::make_shared(seq); + *cpu_engine = *(state.cpu_engine); + } else { + cpu_engine = nullptr; + } + } + return *this; + } + + void reset(uint64_t new_seed = MAGIC_RANDOM_SEED) { + std::seed_seq seq({new_seed}); + cpu_engine->seed(seq); + offset = 0; + seed = new_seed; + } }; Generator(); explicit Generator(uint64_t seed); - Generator(uint64_t seed, uint64_t device_id); + Generator(uint64_t seed, int64_t device_id); Generator(const Generator& other) = delete; ~Generator() = default; - // get random state + // Retrieves the cloned current state of the generator. GeneratorState GetState(); - // set random state + // Directly sets the generator's current state to a specified state. void SetState(const GeneratorState&); - // get current seed + + // Retrieves the seed of the current generator state. uint64_t GetCurrentSeed(); - // random a seed and get + // Retrieves the offset of the current generator state. + uint64_t GetCurrentOffset(); + + // Retrieves the index of the current generator state. + uint64_t GetStateIndex(); + // Sets the index for the current generator state, switching the active state. + void SetStateIndex(uint64_t StateIndex); + + // Registers a new state with the generator and switch to new state. + // Returns the index of this new state. + uint64_t RegisterStateIndex(const GeneratorState&); + + // Generates and sets a new random seed. uint64_t Seed(); - // set seed + // Sets the seed of the current generator state. void SetCurrentSeed(uint64_t seed); - // get cpu engine + + // Retrieves cpu cpu_engine in current state. std::shared_ptr GetCPUEngine(); - // set cpu engine - void SetCPUEngine(std::shared_ptr); + // Set CPU random number generation cpu_engine to current state + void SetCPUEngine(std::shared_ptr cpu_engine); uint64_t Random64(); + // Increments the offset of the current generator state by a specified amount + // and returns the new seed and offset. std::pair IncrementOffset(uint64_t increment_offset); - uint64_t get_device_id() { return this->state_.device; } - private: - GeneratorState state_; - std::shared_ptr engine_; + // Accesses the current generator state by index. + inline GeneratorState& state(); + // Accesses the current cpu cpu_engine by index. + inline std::shared_ptr cpu_engine(); + // Outputs detailed information about the current generator state to the log. + inline void print_state_info(); + + size_t current_index = 0; + std::vector states_; mutable std::mutex mu_; }; diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index 985c028afb2a88..0a89a3ca825ba5 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -366,18 +366,34 @@ void DropoutFwGPUKernelDriver( reinterpret_cast(&(VectorizedRandomGenerator)); cudaFunction_t cudaFunc; PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); + const phi::GPUContext* dev_ctx_p = &dev_ctx; + auto gen_cuda = dev_ctx.GetGenerator(); + auto state_index = gen_cuda->GetStateIndex(); + phi::backends::gpu::CUDAGraphNodeLauncher::parameterSetter_t - parameterSetter = [offset, dev_ctx_p]( + parameterSetter = [offset, dev_ctx_p, state_index, is_fix_seed]( phi::backends::gpu::CUDAKernelParams& params) { - uint64_t seed_data, increment; - phi::funcs::GetSeedDataAndIncrement( - *dev_ctx_p, nullptr, false, 0, offset, &seed_data, &increment); - params.As(2) = seed_data; - params.As(8) = increment; - VLOG(10) << "CUDA_GRAPH seed_data = " << seed_data - << ", increment = " << increment; + if (!is_fix_seed) { + // we assume seed is null pointer + // seed copy to cpu is meaningless here + assert(seed_tensor_ptr == nullptr); + + auto gen_cuda = dev_ctx_p->GetGenerator(); + // ensure the generator use correct state index + gen_cuda->SetStateIndex(state_index); + + uint64_t seed, increment; + std::tie(seed, increment) = gen_cuda->IncrementOffset(offset); + + params.As(2) = seed; + params.As(8) = increment; + + VLOG(10) << "CUDA_GRAPH seed = " << seed + << ", increment = " << increment; + } }; + phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) { VectorizedRandomGenerator @@ -395,7 +411,7 @@ void DropoutFwGPUKernelDriver( phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch( cudaFunc, parameterSetter, cudaKernelCallback); - VLOG(10) << "NON_CUDA_GRAPH seed_data = " << seed_data + VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data << ", increment = " << increment; #endif } diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu index 4d06cc27a34e34..a3eb1483b4f749 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu @@ -225,7 +225,7 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx, params.As(2) = seed_data; params.As(6) = increment; - VLOG(10) << "CUDA_GRAPH seed_data = " << seed_data + VLOG(10) << "CUDA_GRAPH seed = " << seed_data << ", increment = " << increment; }; void* functionPtr = reinterpret_cast( @@ -249,7 +249,7 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx, phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch( cudaFunc, parameterSetter, cudaKernelCallback); - VLOG(10) << "NON_CUDA_GRAPH seed_data = " << seed_data + VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data << ", increment = " << increment; #endif } diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu index e4effaf6be28c4..935e5af3179486 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu @@ -205,25 +205,37 @@ void FusedDropoutAddKernel(const Context& dev_ctx, &(VectorizedDropoutForward>)); cudaFunction_t cudaFunc; PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); + + // seed_offset_data should preserved by cudaGraph pool + auto gen_cuda = dev_ctx.GetGenerator(); + auto state_index = gen_cuda->GetStateIndex(); auto parameterSetter = - [numel, dev_ctx_p, seed, offset, seed_offset_data, seed_tensor_ptr]( - phi::backends::gpu::CUDAKernelParams& params) { - uint64_t seed_data, increment; - // we get the seed_data/increment from seed/offset - phi::funcs::GetSeedDataAndIncrement(*dev_ctx_p, - seed_tensor_ptr, - false, // fix_seed - seed, - offset, - &seed_data, - &increment); - params.As(2) = seed_data; - params.As(6) = increment; - VLOG(10) << "CUDA_GRAPH seed_data = " << seed_data - << ", increment = " << increment; + [dev_ctx_p, + offset, + seed_offset_data, + state_index, + seed_tensor_ptr, + fix_seed](phi::backends::gpu::CUDAKernelParams& params) { + if (!fix_seed) { + auto gen_cuda = dev_ctx_p->GetGenerator(); + // ensure the generator use correct state index + gen_cuda->SetStateIndex(state_index); + + // we assume seed is null pointer + // seed copy to cpu is meaningless here + assert(seed_tensor_ptr == nullptr); + + uint64_t seed, increment; + std::tie(seed, increment) = gen_cuda->IncrementOffset(offset); + VLOG(10) << "CUDA_GRAPH seed = " << seed + << ", increment = " << increment; + + params.As(2) = seed; + params.As(6) = increment; - seed_offset_data[0] = static_cast(seed_data); - seed_offset_data[1] = static_cast(increment); + seed_offset_data[0] = static_cast(seed); + seed_offset_data[1] = static_cast(increment); + } }; phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) { @@ -241,7 +253,7 @@ void FusedDropoutAddKernel(const Context& dev_ctx, phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch( cudaFunc, parameterSetter, cudaKernelCallback); - VLOG(10) << "NON_CUDA_GRAPH seed_data = " << seed_data + VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data << ", increment = " << increment; #endif } else { diff --git a/python/paddle/device/cuda/cuda_graphed_layer.py b/python/paddle/device/cuda/cuda_graphed_layer.py index 9765cc19690c51..3e3f7da3092c3e 100644 --- a/python/paddle/device/cuda/cuda_graphed_layer.py +++ b/python/paddle/device/cuda/cuda_graphed_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py index fad40dc03409b1..d3f8521a3d0cbc 100644 --- a/python/paddle/distributed/fleet/layers/mpu/random.py +++ b/python/paddle/distributed/fleet/layers/mpu/random.py @@ -51,31 +51,46 @@ def add(self, name, seed): self.seeds_.add(seed) if name in self.states_: raise ValueError(f'state {name} already exists') - orig_rng_state = paddle.get_rng_state() + orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True) + # register a new state and set that state with the seed, store the indices into states_ + self.states_[name] = paddle.incubate.register_rng_state_as_index() paddle.seed(seed) - self.states_[name] = paddle.get_rng_state() - paddle.set_rng_state(orig_rng_state) + paddle.incubate.set_rng_state(orig_rng_state_index, use_index=True) def get_states_tracker(self): states = {} + orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True) for name in self.states_: - states[name] = self.states_[name] + # switch index to name + paddle.incubate.set_rng_state(self.states_[name], use_index=True) + # export the saved state + states[name] = paddle.get_cuda_rng_state() + paddle.incubate.set_rng_state(orig_rng_state_index, use_index=True) return states def set_states_tracker(self, states): - self.states_ = states + orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True) + for name in states: + if name not in self.states_: + raise ValueError(f'state {name} does not exists') + # switch index to name + paddle.incubate.set_rng_state(self.states_[name], use_index=True) + # set the state to saved state + paddle.set_cuda_rng_state(states[name]) + + paddle.incubate.set_rng_state(orig_rng_state_index, use_index=True) @contextlib.contextmanager def rng_state(self, name=MODEL_PARALLEL_RNG): if name not in self.states_: raise ValueError(f'state {name} does not exist') - orig_rng_state = paddle.get_rng_state() - paddle.set_rng_state(self.states_[name]) + orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True) + paddle.incubate.set_rng_state(self.states_[name], use_index=True) try: yield finally: - self.states_[name] = paddle.get_rng_state() - paddle.set_rng_state(orig_rng_state) + self.states_[name] = paddle.incubate.get_rng_state(use_index=True) + paddle.incubate.set_rng_state(orig_rng_state_index, use_index=True) RNG_STATE_TRACKER = RNGStatesTracker() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 6966cedc44d7e8..e6e2dc766fc874 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -24,6 +24,11 @@ xpu, ) from .checkpoint import auto_checkpoint # noqa: F401 +from .framework import ( # noqa: F401 + get_rng_state, + register_rng_state_as_index, + set_rng_state, +) from .nn.loss import identity_loss from .operators import ( graph_khop_sampler, diff --git a/python/paddle/incubate/framework/__init__.py b/python/paddle/incubate/framework/__init__.py index c2a4d5037531ce..144665307fe44c 100644 --- a/python/paddle/incubate/framework/__init__.py +++ b/python/paddle/incubate/framework/__init__.py @@ -12,4 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .random import ( # noqa: F401 + get_rng_state, + register_rng_state_as_index, + set_rng_state, +) + __all__ = [] diff --git a/python/paddle/incubate/framework/random.py b/python/paddle/incubate/framework/random.py new file mode 100644 index 00000000000000..32388dbd015466 --- /dev/null +++ b/python/paddle/incubate/framework/random.py @@ -0,0 +1,251 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: define random api +import paddle +from paddle import base +from paddle.base import core + +__all__ = [] + + +def get_rng_state(device=None, use_index=False): + """ + Get all random states of random generators of specified device. + Args: + device(str): This parameter determines the specific running device. + It can be ``cpu``, ``gpu``, ``xpu``, Default is None. + If None, return the generators of current device (specified by ``set_device``). + use_index(bool): If use index is True, return the index that saved in the generator + Returns: + GeneratorState: object. + Examples: + .. code-block:: python + >>> import paddle + >>> sts = paddle.incubate.get_rng_state() + """ + + def get_state(generator): + if use_index: + return generator.get_state_index() + else: + return generator.get_state() + + state_list = [] + if device is None: + place = base.framework._current_expected_place() + else: + place = paddle.device._convert_to_place(device) + + if isinstance(place, core.CPUPlace): + state_list.append(get_state(core.default_cpu_generator())) + elif isinstance(place, core.CUDAPlace): + for i in range(core.get_cuda_device_count()): + state_list.append(get_state(core.default_cuda_generator(i))) + elif isinstance(place, core.XPUPlace): + for i in range(core.get_xpu_device_count()): + state_list.append(get_state(core.default_xpu_generator(i))) + elif isinstance(place, core.CustomPlace): + dev_cnt = sum( + [ + place.get_device_type() == s.split(':')[0] + for s in core.get_available_custom_device() + ] + ) + for i in range(dev_cnt): + state_list.append( + get_state( + core.default_custom_device_generator( + core.CustomPlace(place.get_device_type(), i) + ) + ) + ) + else: + raise ValueError( + f"get_rng_state is not implemented for current device: {place}" + ) + + return state_list + + +def set_rng_state(state_list, device=None, use_index=False): + """ + + Sets generator state for all device generators. + + Args: + state_list(list|tuple): The device states to set back to device generators. state_list is obtained from get_rng_state(). + device(str): This parameter determines the specific running device. + It can be ``cpu``, ``gpu``, ``xpu``, Default is None. + If None, return the generators of current device (specified by ``set_device``). + use_index(bool): If use index is True, state_list should be the indices of the states + + Returns: + None. + + Examples: + .. code-block:: python + + >>> import paddle + >>> sts = paddle.incubate.get_rng_state() + >>> paddle.incubate.set_rng_state(sts) + + """ + + def set_state(generator, state): + if use_index: + generator.set_state_index(state) + else: + generator.set_state(state) + + if device is None: + place = base.framework._current_expected_place() + else: + place = device._convert_to_place(device) + + if isinstance(place, core.CUDAPlace): + if not len(state_list) == core.get_cuda_device_count(): + raise ValueError( + "Length of gpu state list shoule be equal to the gpu device count" + ) + for i in range(core.get_cuda_device_count()): + set_state(core.default_cuda_generator(i), state_list[i]) + elif isinstance(place, core.XPUPlace): + if not len(state_list) == core.get_xpu_device_count(): + raise ValueError( + "Length of xpu state list shoule be equal to the xpu device count" + ) + for i in range(core.get_xpu_device_count()): + set_state(core.default_xpu_generator(i), state_list[i]) + elif isinstance(place, core.CustomPlace): + dev_cnt = sum( + [ + place.get_device_type() == s.split(':')[0] + for s in core.get_available_custom_device() + ] + ) + if not len(state_list) == dev_cnt: + raise ValueError( + f"Length of custom device state list shoule be equal to the {place.get_dtype_type()} device count" + ) + for i in range(dev_cnt): + set_state( + core.default_custom_device_generator( + core.CustomPlace(place.get_device_type(), i) + ), + state_list[i], + ) + elif isinstance(place, core.CPUPlace): + if not len(state_list) == 1: + raise ValueError("Length of cpu state list shoule be equal to 1") + set_state(core.default_cpu_generator(), state_list[0]) + else: + raise ValueError( + f"set_rng_state is not implemented for current device: {place}" + ) + + +def register_rng_state_as_index(state_list=None, device=None): + """ + + The register_rng_state_as_index function creates and registers a new generator state within the generator. + It enables users to manage multiple generator states via indices, + offering a convenient way to switch between these states without directly manipulating the generator's state. + + Args: + state_list(list|tuple): A list or tuple representing the RNG states for devices. + If not provided, the function will register the current state. + device(str): This parameter determines the specific running device. + It can be ``cpu``, ``gpu``, ``xpu``, Default is None. + If None, return the generators of current device (specified by ``set_device``). + + Returns: + A list of indices representing the positions at which the new states were saved within the generator. + These indices can be used to switch between states using set_rng_state(use_index=True) + + + Examples: + .. code-block:: python + + >>> import paddle + >>> old_index = paddle.incubate.get_rng_state(use_index=True) + >>> print(old_index) + [0] + >>> new_index = paddle.incubate.register_rng_state_as_index() + >>> print(new_index) + [1] + >>> paddle.incubate.set_rng_state(old_index, use_index=True) + >>> paddle.incubate.set_rng_state(new_index, use_index=True) + + """ + new_state_index_list = [] + + if device is None: + place = base.framework._current_expected_place() + else: + place = device._convert_to_place(device) + + if state_list is None: + state_list = get_rng_state(device) + + if isinstance(place, core.CUDAPlace): + if not len(state_list) == core.get_cuda_device_count(): + raise ValueError( + "Length of gpu state list shoule be equal to the gpu device count" + ) + for i in range(core.get_cuda_device_count()): + new_state_index_list.append( + core.default_cuda_generator(i).register_state_index( + state_list[i] + ) + ) + elif isinstance(place, core.XPUPlace): + if not len(state_list) == core.get_xpu_device_count(): + raise ValueError( + "Length of xpu state list shoule be equal to the xpu device count" + ) + for i in range(core.get_xpu_device_count()): + new_state_index_list.append( + core.default_xpu_generator(i).register_state_index( + state_list[i] + ) + ) + elif isinstance(place, core.CustomPlace): + dev_cnt = sum( + [ + place.get_device_type() == s.split(':')[0] + for s in core.get_available_custom_device() + ] + ) + if not len(state_list) == dev_cnt: + raise ValueError( + f"Length of custom device state list shoule be equal to the {place.get_dtype_type()} device count" + ) + for i in range(dev_cnt): + new_state_index_list.append( + core.default_custom_device_generator( + core.CustomPlace(place.get_device_type(), i) + ).register_state_index(state_list[i]) + ) + elif isinstance(place, core.CPUPlace): + if not len(state_list) == 1: + raise ValueError("Length of cpu state list shoule be equal to 1") + new_state_index_list.append( + core.default_cpu_generator().register_state_index(state_list[0]) + ) + else: + raise ValueError( + f"register_rng_state_index is not implemented for current device: {place}" + ) + return new_state_index_list diff --git a/python/setup.py.in b/python/setup.py.in index 772d9f77aca622..3ec9e1577009f2 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -418,6 +418,7 @@ packages=['paddle', 'paddle.incubate.nn', 'paddle.incubate.asp', 'paddle.incubate.passes', + 'paddle.incubate.framework', 'paddle.distribution', 'paddle.distributed.utils', 'paddle.distributed.sharding', diff --git a/setup.py b/setup.py index 72ca9b1bce7239..7e7189e10cb435 100644 --- a/setup.py +++ b/setup.py @@ -1421,6 +1421,7 @@ def get_setup_parameters(): 'paddle.incubate.nn', 'paddle.incubate.asp', 'paddle.incubate.passes', + 'paddle.incubate.framework', 'paddle.distribution', 'paddle.distributed.utils', 'paddle.distributed.sharding', diff --git a/test/legacy_test/test_cuda_graphed_layer.py b/test/legacy_test/test_cuda_graphed_layer.py index 5bfdd3c81f5c8e..cc54699a951c60 100644 --- a/test/legacy_test/test_cuda_graphed_layer.py +++ b/test/legacy_test/test_cuda_graphed_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/test/legacy_test/test_random_generator_set_get_state.py b/test/legacy_test/test_random_generator_set_get_state.py new file mode 100644 index 00000000000000..d3840a1ee0d8a2 --- /dev/null +++ b/test/legacy_test/test_random_generator_set_get_state.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.base import core, framework + + +def get_default_generator(): + """Get default generator for different devices.""" + place = framework._current_expected_place() + if isinstance(place, core.CPUPlace): + return core.default_cpu_generator() + elif isinstance(place, core.CUDAPlace): + return core.default_cuda_generator(0) + elif isinstance(place, core.XPUPlace): + return core.default_xpu_generator(0) + elif isinstance(place, core.CustomPlace): + return core.default_custom_device_generator( + core.CustomPlace(place.get_device_type(), 0) + ) + + +def convert_state_to_seed_offset(state): + """Get seed and offset from state.""" + device, seed, offset = (int(i) for i in str(state).split(' ')[:3]) + return np.array([device, seed, offset]) + + +def generate_random_number_and_states(gen): + """Concatenate random number and state for compare.""" + ret = [] + for i in range(3): + x = paddle.uniform([10], dtype="float32", min=0.0, max=1.0).numpy() + state = convert_state_to_seed_offset(gen.get_state()) + ret.append(np.concatenate([x, state])) + return np.array(ret) + + +class TestRandomGeneratorSetGetState(unittest.TestCase): + def test_random_generator_set_get_state(self): + """Test Generator Get/Set state with Index.""" + paddle.seed(102) + gen = get_default_generator() + orig_state = gen.get_state() + + x = generate_random_number_and_states(gen) + + assert_array_equal = lambda x, y: np.testing.assert_array_equal(x, y) + + paddle.seed(102) + + assert_array_equal(x, generate_random_number_and_states(gen)) + + gen.set_state(orig_state) + + assert_array_equal(x, generate_random_number_and_states(gen)) + + gen.set_state(orig_state) + orig_index = gen.get_state_index() + new_index = gen.register_state_index(orig_state) + + assert_array_equal(x, generate_random_number_and_states(gen)) + + gen.set_state_index(orig_index) + + assert_array_equal(x, generate_random_number_and_states(gen)) + + +if __name__ == "__main__": + unittest.main() From 116c892a087a17c0a7121c8fe54c738d96c4d69a Mon Sep 17 00:00:00 2001 From: JYChen Date: Fri, 5 Jan 2024 16:53:13 +0800 Subject: [PATCH 128/142] tensor_array slice in PIR (#60503) * use slice_array, now will meet error of destory opresult still in use * disable the pir test until the bug fixed --- python/paddle/base/variable_index.py | 21 +++++++++++++++------ test/dygraph_to_static/test_list.py | 1 + 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index efbb5eb40edc7b..c4b20843864dfa 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -250,7 +250,19 @@ def slice_is_same_to_original(start, end, step): def parse_index(x, indices): - advanced_index = [None] * 2 * len(x.shape) # content is (dim, index) + from .framework import in_pir_mode + + if in_pir_mode(): + is_tensor_array = x.is_dense_tensor_array_type() + else: + is_tensor_array = ( + hasattr(x, "desc") + and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY + ) + + advanced_index = ( + [] if is_tensor_array else [None] * 2 * len(x.shape) + ) # content is (dim, index) # for set_value / slice / strided_slice OP decrease_axes = [] axes = [] @@ -267,11 +279,6 @@ def parse_index(x, indices): indices = replace_ellipsis(x, indices) indices, none_axes = replace_none(indices) - is_tensor_array = ( - hasattr(x, "desc") - and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY - ) - estimated_dim = 0 dim = 0 for i, slice_item in enumerate(indices): @@ -740,6 +747,8 @@ def get_tensor_with_basic_indexing( if isinstance(end, (list, tuple)): if paddle.utils._contain_var(end): end = paddle.utils.get_int_tensor_list(end) + if x.is_dense_tensor_array_type(): + return paddle._pir_ops.slice_array_dense(x, st) out = paddle._C_ops.slice( x, axes, diff --git a/test/dygraph_to_static/test_list.py b/test/dygraph_to_static/test_list.py index 52db0e53eb6255..ef3d195d90805d 100644 --- a/test/dygraph_to_static/test_list.py +++ b/test/dygraph_to_static/test_list.py @@ -292,6 +292,7 @@ def init_dygraph_func(self): test_list_pop_in_while_loop, ] + # TODO(zhangbo): Refine BuildOpFrom for op with sub_block def train(self, to_static=False): with base.dygraph.guard(): if to_static: From 57feb0a59e36bee800a6aa9331422c79e91141bd Mon Sep 17 00:00:00 2001 From: pangengzheng <117730991+pangengzheng@users.noreply.github.com> Date: Fri, 5 Jan 2024 17:18:41 +0800 Subject: [PATCH 129/142] Set DistModel state_dict keys to structure_names (#60478) * exclude xpu * check structure name mapping * test pp * polish * support dynamic save static load * support dygraph save static load * polish * polish * use structured_name as key in DistModel state_dict * polish * polish * fix checkpoint path conflict * test get_rank_to_files * static save dynamic load test --- .../paddle/distributed/auto_parallel/api.py | 22 +- .../distributed/checkpoint/load_state_dict.py | 62 +++-- .../distributed/checkpoint/save_state_dict.py | 5 +- ...e_dict.py => semi_auto_load_state_dict.py} | 2 +- ..._mutual_load_between_dynamic_and_static.py | 216 ++++++++++++++++++ .../semi_auto_parallel_simple_net_dp_mp.py | 4 + .../semi_auto_parallel_simple_net_dp_mp_pp.py | 4 + ...e_dict.py => semi_auto_save_state_dict.py} | 39 +++- .../test_save_load_state_dict.py | 30 ++- .../semi_auto_parallel_shard_optimizer_api.py | 14 +- .../test_dist_checkpoint_utils.py | 57 +++++ 11 files changed, 414 insertions(+), 41 deletions(-) rename test/auto_parallel/hybrid_strategy/{load_state_dict.py => semi_auto_load_state_dict.py} (98%) create mode 100644 test/auto_parallel/hybrid_strategy/semi_auto_parallel_mutual_load_between_dynamic_and_static.py rename test/auto_parallel/hybrid_strategy/{save_state_dict.py => semi_auto_save_state_dict.py} (56%) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 7e734bd95b1b10..efec4383022ed4 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -1015,6 +1015,12 @@ def __init__( ): self._feed_name_list = [] self._inner_strategy = self.__convert_strategy(strategy) + self._structured_to_parameter_name = { + k: v.name for k, v in layer.state_dict().items() + } + self._parameter_to_structured_name = { + v: k for k, v in self._structured_to_parameter_name.items() + } self._engine = Engine( layer, loss, optimizer, metrics, strategy=self._inner_strategy ) @@ -1257,6 +1263,15 @@ def state_dict(self, mode="all"): mode=self._engine._mode ).state_dict(mode) dist_state_dict = self._build_distributed_state_dict(local_state_dict) + mapping_names = [ + self._parameter_to_structured_name[k] + if k in self._parameter_to_structured_name + else k + for k in dist_state_dict.keys() + ] + dist_state_dict = dict( + zip(mapping_names, list(dist_state_dict.values())) + ) return dist_state_dict def _build_distributed_state_dict(self, local_state_dict): @@ -1331,7 +1346,12 @@ def set_state_dict(self, state_dict): ].process_mesh or check_placements_equal( v.placements, cur_v.placements ), f"process_mesh:{v.process_mesh} != {cur_v.process_mesh} or placements:{v.placements} != {cur_v.placements} not match" - local_state_dict[k] = v._local_value() + param_name = ( + self._structured_to_parameter_name[k] + if k in self._structured_to_parameter_name + else k + ) + local_state_dict[param_name] = v._local_value() dist_main_program.set_state_dict(local_state_dict) diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py index 4ae82398713aee..5505d73ab3843c 100644 --- a/python/paddle/distributed/checkpoint/load_state_dict.py +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -15,7 +15,7 @@ import copy import os from dataclasses import dataclass -from typing import Tuple +from typing import Dict, Tuple import paddle from paddle.distributed.communication.group import is_initialized @@ -37,7 +37,13 @@ class ReadItem: lengths: Tuple[int] -def get_rank_to_files(path, state_dict, process_group, use_dist): +PATH_TO_CHECKPOINT_FILES: Dict[str, Tuple[list, list]] = {} + + +def get_checkpoint_files(path, use_cache=True): + global PATH_TO_CHECKPOINT_FILES + if use_cache and path in PATH_TO_CHECKPOINT_FILES: + return PATH_TO_CHECKPOINT_FILES[path] accessible_files = os.listdir(path) metadata_files = [ file for file in accessible_files if file.endswith(".metadata") @@ -45,6 +51,22 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): assert ( len(metadata_files) > 0 ), f"No metadata file found in the checkpoint directory:{path}." + local_data_files = [ + file for file in accessible_files if file.endswith(".distcp") + ] + assert ( + len(local_data_files) > 0 + ), f"No data file found in the checkpoint directory:{path}." + if use_cache: + PATH_TO_CHECKPOINT_FILES[path] = (metadata_files, local_data_files) + return (metadata_files, local_data_files) + + +def get_rank_to_files(path, state_dict, process_group, use_dist): + """ + Get the mapping of rank to its accessible files. + """ + metadata_files, local_data_files = get_checkpoint_files(path) # The neccesary files to be read tensor_key_list = [] necessary_files = [] @@ -62,12 +84,10 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): logger.warning( f"No necessary data files found in the checkpoint directory:{path}. Please check the metadata_files:{metadata_files}" ) - return {} + missing_keys = set(state_dict.keys()) + return {}, missing_keys # allgather all accessible files - local_data_files = [ - file for file in accessible_files if file.endswith(".distcp") - ] global_data_files = [] if use_dist: paddle.distributed.all_gather_object( @@ -101,12 +121,16 @@ def get_rank_to_files(path, state_dict, process_group, use_dist): ] rank_to_files[rank] = local_files logger.debug(f"mapping rank_to_files:{rank_to_files}") - return rank_to_files + return rank_to_files, missing_keys def get_local_load_files(rank_to_files): """ Load files in a load-balanced manner. + + Args: + rank_to_files (dict): mapping from rank to files. + Example: Case1: all ranks access the same data files rank_to_files = {rank0:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp], rank1:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp]} @@ -196,13 +220,7 @@ def update(rank_to_read_files, rank_to_not_read_files, rank_file): def get_load_infos(path, local_load_files, process_group, use_dist): load_info = {} - accessible_files = os.listdir(path) - metadata_files = [ - file for file in accessible_files if file.endswith(".metadata") - ] - assert ( - len(metadata_files) > 0 - ), "No metadata file found in the checkpoint directory:{path}." + metadata_files, _ = get_checkpoint_files(path) for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) for local_tensor_index, file_name in metadata.storage_metadata.items(): @@ -277,14 +295,8 @@ def not_overlap( def get_read_items(path, state_dict, process_group, use_dist): - accessible_files = os.listdir(path) - metadata_files = [ - file for file in accessible_files if file.endswith(".metadata") - ] - assert ( - len(metadata_files) > 0 - ), "No metadata file found in the checkpoint directory:{path}." storage_state_dict_metadata = {} + metadata_files, _ = get_checkpoint_files(path) for metadata_file in metadata_files: metadata = paddle.load(os.path.join(path, metadata_file)) for ( @@ -410,7 +422,7 @@ def load_state_dict( for val in flat_state_dict.values(): assert isinstance( val, paddle.Tensor - ), f"Only support dygraph Tensor now, but is {val}" + ), f"The value of state_dict should be a paddle.Tensor, but got: {val}." use_dist = True if paddle.distributed.get_world_size() > 1 else False @@ -422,9 +434,13 @@ def load_state_dict( # sync to avoid some ranks not write path yet paddle.distributed.barrier(process_group) - rank_to_files = get_rank_to_files( + rank_to_files, missing_keys = get_rank_to_files( path, flat_state_dict, process_group, use_dist ) + if len(missing_keys) > 0: + logger.warning( + f"The following keys:{missing_keys} are not found in checkpoint path: {path}." + ) if len(rank_to_files) <= 0: return local_load_files = get_local_load_files(rank_to_files) diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py index 86047e637e3609..ab84003a6d2abf 100644 --- a/python/paddle/distributed/checkpoint/save_state_dict.py +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -139,7 +139,7 @@ def save_state_dict( for val in flat_state_dict.values(): assert isinstance( val, paddle.Tensor - ), "Only support dygraph Tensor now, support static DistributedTensor later" + ), f"The value of state_dict should be a paddle.Tensor, but got: {val}." if not os.path.exists(path): os.makedirs(path, exist_ok=True) @@ -188,6 +188,8 @@ def save_state_dict( if local_shape is None or global_offset is None: continue local_tensor = val._local_value() + # Note: The local_tensor must keep the same name with the original tensor. Otherwise, the StructuredToParameterName@@ mapping will be wrong. + local_tensor.name = val.name else: local_shape = tuple(val.shape) global_offset = ( @@ -203,6 +205,7 @@ def save_state_dict( local_storage_metadata[ LocalTensorIndex(key, tuple(global_offset)) ] = file_name + global_state_dict_metadata = [] global_storage_metadata = [] global_flatten_mapping = [] diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/semi_auto_load_state_dict.py similarity index 98% rename from test/auto_parallel/hybrid_strategy/load_state_dict.py rename to test/auto_parallel/hybrid_strategy/semi_auto_load_state_dict.py index 092fbcc8c09955..a235370c7b3ed3 100644 --- a/test/auto_parallel/hybrid_strategy/load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_load_state_dict.py @@ -15,7 +15,7 @@ import os import numpy as np -from auto_parallel.hybrid_strategy.save_state_dict import ( +from auto_parallel.hybrid_strategy.semi_auto_save_state_dict import ( get_global_state_dict, ) diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_mutual_load_between_dynamic_and_static.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_mutual_load_between_dynamic_and_static.py new file mode 100644 index 00000000000000..ce751649cd03d4 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_mutual_load_between_dynamic_and_static.py @@ -0,0 +1,216 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random + +import numpy as np +from auto_parallel.semi_auto_parallel_dist_to_static_mlp import RandomDataset +from auto_parallel.semi_auto_parallel_simple_net import ( + BATCH_SIZE, + CLASS_NUM, + IMAGE_SIZE, + DemoNet, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.io import DataLoader + + +class TestSemiAutoParallelMutualLoadBetweenDynamicAndStatic( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._ckpt_path = os.environ.get("ckpt_path") + self._seed = os.environ.get("seed", 123) + + def create_data_loader(self): + images = np.random.rand(BATCH_SIZE, IMAGE_SIZE).astype('float32') + labels = np.random.rand(BATCH_SIZE, CLASS_NUM).astype('float32') + dataset = RandomDataset(images, labels, BATCH_SIZE) + loader = DataLoader(dataset, batch_size=BATCH_SIZE) + return loader + + def run_dynamic(self, layer, opt, data_loader, is_recompute=False): + loss_fn = nn.MSELoss() + + loss_list = [] + for _ in range(5): + for batch_id, (image, label) in enumerate(data_loader()): + if is_recompute: + image.stop_gradient = False + out = layer(image) + loss = loss_fn(out, label) + loss_list.append(loss.numpy()) + loss.backward() + + opt.step() + opt.clear_grad() + return np.array(loss_list) + + def run_dy2static(self, layer, opt, data_loader): + # create loss + loss_fn = nn.MSELoss() + + # static training + dist_model, dist_loader = dist.to_static( + layer, data_loader, loss_fn, opt + ) + loss_list = [] + dist_model.train() + for epoch in range(5): + for batch_id, (image, label) in enumerate(dist_loader()): + loss = dist_model(image, label) + loss_list.append(loss) + + return np.array(loss_list), dist_model + + def set_random_seed(self, seed): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + def test_dygraph_save_static_load(self): + paddle.disable_static() + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + # set seed to promise the same input for different tp rank + self.set_random_seed(self._seed) + data_loader = self.create_data_loader() + + dy_layer = dist.shard_layer( + DemoNet("dp_mp_hybrid_strategy"), mesh, self.shard_fn + ) + dy_opt = paddle.optimizer.SGD( + learning_rate=0.1, parameters=dy_layer.parameters() + ) + dy_losses = self.run_dynamic(dy_layer, dy_opt, data_loader) + saved_dy_layer_state_dict = dy_layer.state_dict() + ckpt_path = os.path.join( + self._ckpt_path, "test_dygraph_save_static_load" + ) + dist.save_state_dict(saved_dy_layer_state_dict, ckpt_path) + dist.barrier() + + dy2static_opt = paddle.optimizer.SGD( + learning_rate=0.1, parameters=dy_layer.parameters() + ) + + loss_fn = nn.MSELoss() + dist_model, _ = dist.to_static( + dy_layer, data_loader, loss_fn, dy2static_opt + ) + need_load_state_dict = {} + expected_state_dict = {} + with paddle.base.dygraph.guard(): + for k, v in saved_dy_layer_state_dict.items(): + expected_state_dict[k] = v._local_value().clone() + need_load_state_dict[k] = paddle.zeros_like(v) + dist_model.set_state_dict(need_load_state_dict) + state_dict_to_load = dist_model.state_dict(mode="param") + assert len(state_dict_to_load) == len(expected_state_dict) + for k, v in state_dict_to_load.items(): + assert ( + k in expected_state_dict + ), f"key {k} not in expected_state_dict:{expected_state_dict}" + assert np.any( + np.not_equal( + v._local_value().numpy(), + expected_state_dict[k].numpy(), + ) + ), f"key:{k}, v:{v}, expected_state_dict[k]:{expected_state_dict[k]}" + + dist.load_state_dict(state_dict_to_load, ckpt_path) + dist_model.set_state_dict(state_dict_to_load) + + program_state_dict = dist_model.state_dict(mode="param") + assert len(expected_state_dict) == len(program_state_dict) + for k, v in program_state_dict.items(): + assert ( + k in expected_state_dict + ), f"key {k} not in expected_state_dict:{expected_state_dict}" + np.testing.assert_equal( + v._local_value().numpy(), + expected_state_dict[k].numpy(), + ) + + def test_static_save_dynamic_load(self): + paddle.disable_static() + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + # set seed to promise the same input for different tp rank + self.set_random_seed(self._seed) + data_loader = self.create_data_loader() + + dy_layer = dist.shard_layer( + DemoNet("dp_mp_hybrid_strategy"), mesh, self.shard_fn + ) + + dy2static_opt = paddle.optimizer.SGD( + learning_rate=0.1, parameters=dy_layer.parameters() + ) + dy2static_losses, dist_model = self.run_dy2static( + dy_layer, dy2static_opt, data_loader + ) + + saved_static_layer_state_dict = dist_model.state_dict("param") + ckpt_path = os.path.join( + self._ckpt_path, "test_static_save_dynamic_load" + ) + dist.save_state_dict(saved_static_layer_state_dict, ckpt_path) + dist.barrier() + + paddle.disable_static() + need_load_state_dict = {} + expected_state_dict = {} + with paddle.base.dygraph.guard(): + for k, v in saved_static_layer_state_dict.items(): + expected_state_dict[k] = v._local_value().clone() + need_load_state_dict[k] = paddle.zeros_like(v) + dy_layer.set_state_dict(need_load_state_dict) + state_dict_to_load = dy_layer.state_dict() + assert len(state_dict_to_load) == len(expected_state_dict) + for k, v in state_dict_to_load.items(): + assert ( + k in expected_state_dict + ), f"key {k} not in expected_state_dict:{expected_state_dict}" + assert np.any( + np.not_equal( + v._local_value().numpy(), + expected_state_dict[k].numpy(), + ) + ), f"key:{k}, v:{v}, expected_state_dict[k]:{expected_state_dict[k]}" + + dist.load_state_dict(state_dict_to_load, ckpt_path) + dy_layer.set_state_dict(state_dict_to_load) + + state_dict = dy_layer.state_dict() + assert len(expected_state_dict) == len(state_dict) + for k, v in state_dict.items(): + assert ( + k in expected_state_dict + ), f"key {k} not in expected_state_dict:{expected_state_dict}" + np.testing.assert_equal( + v._local_value().numpy(), + expected_state_dict[k].numpy(), + ) + + def run_test_case(self): + self.test_dygraph_save_static_load() + self.test_static_save_dynamic_load() + + +if __name__ == "__main__": + TestSemiAutoParallelMutualLoadBetweenDynamicAndStatic().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py index 4d067108e115d5..8f19196af918d9 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py @@ -14,6 +14,9 @@ import os +from auto_parallel.hybrid_strategy.semi_auto_save_state_dict import ( + check_structure_name_mapping, +) from auto_parallel.semi_auto_parallel_simple_net import ( DemoNet, TestSimpleNetForSemiAutoParallel, @@ -60,6 +63,7 @@ def test_dp_mp_demo_net(self): state_dict = model.state_dict() paddle.distributed.save_state_dict(state_dict, self._ckpt_path) paddle.distributed.barrier() + check_structure_name_mapping(self._ckpt_path, state_dict) expected_local_state_dict = {} need_load_state_dict = {} for k, v in state_dict.items(): diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py index 86a9008ac6aa95..0986af8795f92e 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py @@ -14,6 +14,9 @@ import os +from auto_parallel.hybrid_strategy.semi_auto_save_state_dict import ( + check_structure_name_mapping, +) from auto_parallel.semi_auto_parallel_simple_net import ( DemoNet, TestSimpleNetForSemiAutoParallel, @@ -108,6 +111,7 @@ def test_dp_mp_pp_demo_net(self): state_dict = model.state_dict() paddle.distributed.save_state_dict(state_dict, self._ckpt_path) paddle.distributed.barrier() + check_structure_name_mapping(self._ckpt_path, state_dict) expected_local_state_dict = {} need_load_state_dict = {} for k, v in state_dict.items(): diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/semi_auto_save_state_dict.py similarity index 56% rename from test/auto_parallel/hybrid_strategy/save_state_dict.py rename to test/auto_parallel/hybrid_strategy/semi_auto_save_state_dict.py index 0fd2f5d7049dbf..ef103ade40c6e8 100644 --- a/test/auto_parallel/hybrid_strategy/save_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_save_state_dict.py @@ -25,6 +25,34 @@ def get_global_state_dict(): return {"w1": w1, "w2": w2} +def check_structure_name_mapping(ckpt_path, state_dict): + metadata_file_path = os.path.join(ckpt_path, "0.metadata") + data_file_path = os.path.join( + ckpt_path, f"{paddle.distributed.get_rank()}_0.distcp" + ) + assert os.path.exists( + metadata_file_path + ), f"metadata file {metadata_file_path} is not found" + assert os.path.exists( + data_file_path + ), f"data file {data_file_path} is not found" + metadata = paddle.load(metadata_file_path) + cur_rank_state_dict = paddle.load(data_file_path, keep_name_table=True) + local_structure_name_mapping = cur_rank_state_dict.pop( + "StructuredToParameterName@@" + ) + assert isinstance( + local_structure_name_mapping, dict + ), f"local_structure_name_mapping:{local_structure_name_mapping} is not dict type" + for structure_name, param_name in local_structure_name_mapping.items(): + assert ( + structure_name in state_dict + ), f"tensor key:{structure_name} is not found in state dict:{state_dict}" + assert ( + param_name == state_dict[structure_name].name + ), f"param name:{param_name} is not equal to param name in state_dict:{state_dict[structure_name].name}" + + class TestSaveStateDict: def __init__(self): self._ckpt_path = os.getenv("ckpt_path") @@ -35,6 +63,7 @@ def test_save_state_dict_with_one_device(self): w1, w2 = list(global_state_dict.values()) state_dict = dict(zip(keys, [w1, w2])) save_state_dict(state_dict, self._ckpt_path) + check_structure_name_mapping(self._ckpt_path, state_dict) def test_save_state_dict_with_four_devices(self): global_state_dict = get_global_state_dict() @@ -42,14 +71,12 @@ def test_save_state_dict_with_four_devices(self): w1, w2 = list(global_state_dict.values()) mesh = dist.ProcessMesh([0, 1]) mesh2 = dist.ProcessMesh([2, 3]) - sharded_w1 = dist.shard_tensor( - w1, mesh, [dist.Shard(0), dist.Replicate()] - ) - sharded_w2 = dist.shard_tensor( - w2, mesh2, [dist.Shard(0), dist.Replicate()] - ) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) + sharded_w2 = dist.shard_tensor(w2, mesh2, [dist.Shard(0)]) state_dict = dict(zip(keys, [sharded_w1, sharded_w2])) save_state_dict(state_dict, self._ckpt_path) + paddle.distributed.barrier() + check_structure_name_mapping(self._ckpt_path, state_dict) def run_test_case(self): device_num = int(os.getenv("device_num")) diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py index a0b64a374d6274..446ab9a9f75572 100644 --- a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -23,12 +23,12 @@ def setUp(self): self._default_envs = {} self._changeable_envs = {"device_num": ["1", "2", "4", "8"]} - def test_save_load_state_dict(self): + def test_reshard(self): # save with 1 device ckpt_path = tempfile.TemporaryDirectory() super().setUp(num_of_devices=1, timeout=120, nnode=1) self.run_test_case( - "save_state_dict.py", + "semi_auto_save_state_dict.py", user_defined_envs={"device_num": "1", "ckpt_path": ckpt_path.name}, ) @@ -44,7 +44,7 @@ def test_save_load_state_dict(self): nnode=1, ) self.run_test_case( - "load_state_dict.py", + "semi_auto_load_state_dict.py", user_defined_envs=envs, ) ckpt_path.cleanup() @@ -53,7 +53,7 @@ def test_save_load_state_dict(self): ckpt_path = tempfile.TemporaryDirectory() super().setUp(num_of_devices=4, timeout=120, nnode=1) self.run_test_case( - "save_state_dict.py", + "semi_auto_save_state_dict.py", user_defined_envs={"device_num": "4", "ckpt_path": ckpt_path.name}, ) # load with 1, 2, 4, 8 devices @@ -68,11 +68,31 @@ def test_save_load_state_dict(self): nnode=1, ) self.run_test_case( - "load_state_dict.py", + "semi_auto_load_state_dict.py", user_defined_envs=envs, ) ckpt_path.cleanup() + def test_mutual_load_between_dynamic_and_static(self): + changeable_envs = {"device_num": ["2"]} + envs_list = test_base.gen_product_envs_list( + self._default_envs, changeable_envs + ) + + for envs in envs_list: + ckpt_path = tempfile.TemporaryDirectory() + envs["ckpt_path"] = ckpt_path.name + super().setUp( + num_of_devices=int(envs["device_num"]), + timeout=180, + nnode=1, + ) + self.run_test_case( + "semi_auto_parallel_mutual_load_between_dynamic_and_static.py", + user_defined_envs=envs, + ) + ckpt_path.cleanup() + if __name__ == '__main__': unittest.main() diff --git a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py index 0153d3bd21216e..d68a3eeb73d303 100644 --- a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py +++ b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py @@ -116,7 +116,10 @@ def test_shard_optimizer_from_non_shard_layer(self): # save load ckpt_state_dict = opt.state_dict() ckpt_state_dict_keys = list(ckpt_state_dict.keys()) - dist.save_state_dict(ckpt_state_dict, self._ckpt_path) + ckpt_path = os.path.join( + self._ckpt_path, "test_shard_optimizer_from_non_shard_layer" + ) + dist.save_state_dict(ckpt_state_dict, ckpt_path) linear = paddle.nn.Linear(10, 10) new_opt = paddle.optimizer.AdamW(parameters=linear.parameters()) new_opt = dist.shard_optimizer(new_opt) @@ -125,7 +128,7 @@ def test_shard_optimizer_from_non_shard_layer(self): ckpt_state_dict_keys[i]: v for i, (k, v) in enumerate(new_state_dict.items()) } - dist.load_state_dict(new_state_dict, self._ckpt_path) + dist.load_state_dict(new_state_dict, ckpt_path) assert len(new_state_dict) > 0, "load_state_dict fail" for k, v in new_state_dict.items(): assert k in ckpt_state_dict @@ -181,7 +184,10 @@ def test_shard_optimizer_master_params(self): # save load ckpt_state_dict = opt.state_dict() - dist.save_state_dict(ckpt_state_dict, self._ckpt_path) + ckpt_path = os.path.join( + self._ckpt_path, "test_shard_optimizer_master_params" + ) + dist.save_state_dict(ckpt_state_dict, ckpt_path) paddle.distributed.barrier() expected_local_state_dict = {} expected_local_state_dict.setdefault("master_weights", {}) @@ -216,7 +222,7 @@ def test_shard_optimizer_master_params(self): assert ( need_load_state_dict[k].numpy().sum() == 0.0 ), f"state_dict {k} is not zero" - dist.load_state_dict(need_load_state_dict, self._ckpt_path) + dist.load_state_dict(need_load_state_dict, ckpt_path) opt.set_state_dict(need_load_state_dict) new_state_dict = opt.state_dict() assert "master_weights" in new_state_dict, new_state_dict diff --git a/test/auto_parallel/test_dist_checkpoint_utils.py b/test/auto_parallel/test_dist_checkpoint_utils.py index 5a51f73f0fa56c..e709178d1b53fe 100644 --- a/test/auto_parallel/test_dist_checkpoint_utils.py +++ b/test/auto_parallel/test_dist_checkpoint_utils.py @@ -19,6 +19,7 @@ import numpy as np import paddle +import paddle.distributed as dist from paddle.distributed.checkpoint.utils import ( flatten_state_dict, unflatten_state_dict, @@ -100,6 +101,62 @@ def check_state_dict(d1, d2): check_state_dict(recover_state_dict, state_dict) + def test_get_rank_to_files(self): + process_group = None + use_dist = False + ckpt_dir_tmp = tempfile.TemporaryDirectory() + ckpt_dir = ckpt_dir_tmp.name + state_dict = { + "w1": paddle.to_tensor([1, 2]), + "w2": paddle.to_tensor([3, 4]), + } + dist.save_state_dict(state_dict, ckpt_dir) + new_state_dict = { + "w1": paddle.to_tensor([1, 2]), + "w2": paddle.to_tensor([3, 4]), + } + ( + rank_to_files, + missing_keys, + ) = dist.checkpoint.load_state_dict.get_rank_to_files( + ckpt_dir, new_state_dict, process_group, use_dist + ) + self.assertTrue(len(rank_to_files) == 1 and 0 in rank_to_files) + self.assertTrue(rank_to_files[0] == ["0_0.distcp"]) + self.assertTrue(len(missing_keys) == 0) + + new_state_dict = { + "w1": paddle.to_tensor([1, 2]), + "w3": paddle.to_tensor([3, 4]), + } + ( + rank_to_files, + missing_keys, + ) = dist.checkpoint.load_state_dict.get_rank_to_files( + ckpt_dir, new_state_dict, process_group, use_dist + ) + self.assertTrue(len(rank_to_files) == 1 and 0 in rank_to_files) + self.assertTrue(rank_to_files[0] == ["0_0.distcp"]) + self.assertTrue(len(missing_keys) == 1) + self.assertTrue("w3" in missing_keys) + + new_state_dict = { + "w3": paddle.to_tensor([3, 4]), + "w4": paddle.to_tensor([5, 6]), + } + ( + rank_to_files, + missing_keys, + ) = dist.checkpoint.load_state_dict.get_rank_to_files( + ckpt_dir, new_state_dict, process_group, use_dist + ) + self.assertTrue(len(rank_to_files) == 0) + self.assertTrue(len(missing_keys) == 2) + self.assertTrue("w3" in missing_keys) + self.assertTrue("w4" in missing_keys) + + ckpt_dir_tmp.cleanup() + if __name__ == "__main__": unittest.main() From ed6f32d6c47d6fe02bbddae074aec9ec6a26fe69 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Fri, 5 Jan 2024 21:02:26 +0800 Subject: [PATCH 130/142] fix sm75 build bug (#60583) --- .../gemm/threadblock/default_dq_mma_pipelined.h | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h index 8ccf64f3214c37..4ccef1ca983a65 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -78,7 +78,7 @@ template < /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM - typename Operator> + typename Operator_> struct DqMma< ElementA, LayoutA, @@ -97,9 +97,12 @@ struct DqMma< WarpShape, InstructionShape, 2, - Operator, + Operator_, SharedMemoryClearOption::kNone, typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); @@ -227,7 +230,7 @@ template < /// Instruction-level tile size (concept: GemmShape) typename InstructionShape, /// Operation performed by GEMM - typename Operator, + typename Operator_, /// int RowsPerTile, /// @@ -250,9 +253,12 @@ struct DqMma< WarpShape, InstructionShape, 2, - Operator, + Operator_, SharedMemoryClearOption::kNone, typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value || platform::is_same::value, "Element A must be fp16 or bf16"); From ee3d2fc7af200280c5c58e0942038d139e70a3a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Sat, 6 Jan 2024 09:53:25 +0800 Subject: [PATCH 131/142] Add CanProveDivisible for symbolic calculation (#60572) * add CanProveDivisible for symbolic calculation * delete extra cout for debug * fix according to some comments --- paddle/cinn/common/integer_set.cc | 121 ++++++++++++++++++++++++- paddle/cinn/common/integer_set.h | 2 + paddle/cinn/common/integer_set_test.cc | 44 +++++++++ 3 files changed, 163 insertions(+), 4 deletions(-) diff --git a/paddle/cinn/common/integer_set.cc b/paddle/cinn/common/integer_set.cc index 762c273caef7c5..1887238c2eb4a8 100644 --- a/paddle/cinn/common/integer_set.cc +++ b/paddle/cinn/common/integer_set.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/cinn/common/integer_set.h" + +#include "paddle/cinn/common/arithmatic.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_copy.h" @@ -164,11 +166,115 @@ std::optional SymbolicExprAnalyzer::ProveLT(const ir::Expr& lhs, return ProveGT(rhs, lhs); } +// Tell whether lhs can be divisible by rhs, lhs must be a pure math expression +// and rhs must be a var +std::optional SymbolicExprAnalyzer::ProveDivisible( + const ir::Expr& lhs, const ir::Expr& rhs) const { + CHECK(rhs.is_var()) << "Rhs in ProveDivisible must be a var temporarily!\n"; + CHECK(lhs.defined()); + CHECK(rhs.defined()); + CHECK(cinn::common::IsPureMath(lhs)); + + ir::Expr lhs_copy = ir::ir_utils::IRCopy(lhs); + if (cinn::common::is_zero(lhs_copy)) return true; + + auto OptionalAnd = [](const std::optional& lhs, + const std::optional& rhs) -> std::optional { + if (lhs.has_value() && rhs.has_value()) { + return lhs.value() && rhs.value(); + } else { + return std::nullopt; + } + }; + auto OptionalOr = [](const std::optional& lhs, + const std::optional& rhs) -> std::optional { + if (lhs.has_value() && rhs.has_value()) { + return lhs.value() || rhs.value(); + } else if ((!lhs.has_value()) && (!rhs.has_value())) { + return std::nullopt; + } else if (lhs.has_value() && (!rhs.has_value())) { + return lhs.value() ? std::optional(lhs.value()) + : std::optional(std::nullopt); + } else { + return rhs.value() ? std::optional(rhs.value()) + : std::optional(std::nullopt); + } + }; + + std::vector ops{}; + std::optional res = std::nullopt; + ir::Expr zero(0); + ir::Expr tmp_expr; + + auto is_ge = ProveGE(lhs, rhs); + + switch (lhs.node_type()) { + case cinn::ir::IrNodeTy::_Var_: + return ProveEQ(lhs, rhs); + case cinn::ir::IrNodeTy::IntImm: + return false; + case cinn::ir::IrNodeTy::Sum: + res = true; + ops = lhs.As()->operands(); + CHECK(!ops.empty()); + std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) { + res = OptionalAnd(res, this->ProveDivisible(expr, rhs)); + }); + res = OptionalAnd(res, is_ge); + return res; + case cinn::ir::IrNodeTy::Product: + res = false; + ops = lhs.As()->operands(); + CHECK(!ops.empty()); + std::for_each(ops.begin(), ops.end(), [&](const ir::Expr& expr) { + res = OptionalOr(res, this->ProveDivisible(expr, rhs)); + if (res.has_value() && res.value()) return; + }); + res = OptionalAnd(res, is_ge); + return res; + case cinn::ir::IrNodeTy::FracOp: + tmp_expr = cinn::common::AutoSimplify(lhs); + if (tmp_expr.node_type() == cinn::ir::IrNodeTy::FracOp) + return std::nullopt; + return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge); + case cinn::ir::IrNodeTy::FloatImm: + return false; + case cinn::ir::IrNodeTy::Add: + return OptionalAnd( + OptionalAnd(ProveDivisible(lhs.As()->a(), rhs), + ProveDivisible(lhs.As()->b(), rhs)), + is_ge); + case cinn::ir::IrNodeTy::Sub: + return OptionalAnd( + OptionalAnd(ProveDivisible(lhs.As()->a(), rhs), + ProveDivisible(lhs.As()->b(), rhs)), + is_ge); + case cinn::ir::IrNodeTy::Div: + tmp_expr = cinn::common::AutoSimplify(lhs); + if (tmp_expr.node_type() == cinn::ir::IrNodeTy::Div) return std::nullopt; + return OptionalAnd(ProveDivisible(tmp_expr, rhs), is_ge); + case cinn::ir::IrNodeTy::Mul: + return OptionalAnd( + OptionalOr(ProveDivisible(lhs.As()->a(), rhs), + ProveDivisible(lhs.As()->b(), rhs)), + is_ge); + case cinn::ir::IrNodeTy::Mod: + return false; + case cinn::ir::IrNodeTy::Minus: + return ProveDivisible(lhs.As()->v(), rhs); + default: + LOG(FATAL) << "Not supported yet!"; + break; + } +} + class BoundReplacer : public ir::IRMutator<> { public: explicit BoundReplacer(const cas_intervals_t& var_intervals, bool is_lower_bound) - : var_intervals_(var_intervals), sign_(is_lower_bound) {} + : var_intervals_(var_intervals), + sign_(is_lower_bound), + var_visited_({}) {} void operator()(ir::Expr* expr) { IRMutator::Visit(expr, expr); } @@ -183,10 +289,16 @@ class BoundReplacer : public ir::IRMutator<> { upper_bound = interval.e_r.defined() ? interval.e_r : ir::Expr(interval.r); } - if (sign_) { - *op = ir::ir_utils::IRCopy(lower_bound); + if (!var_visited_.count(var->name)) { + if (sign_) { + *op = ir::ir_utils::IRCopy(lower_bound); + var_visited_.insert({var->name, lower_bound}); + } else { + *op = ir::ir_utils::IRCopy(upper_bound); + var_visited_.insert({var->name, upper_bound}); + } } else { - *op = ir::ir_utils::IRCopy(upper_bound); + *op = ir::ir_utils::IRCopy(var_visited_.at(var->name)); } } @@ -248,6 +360,7 @@ class BoundReplacer : public ir::IRMutator<> { private: const cas_intervals_t& var_intervals_; + std::unordered_map var_visited_; // Determine replacing with upper or lower bound, // True means lower bound and False means upper bound. bool sign_; diff --git a/paddle/cinn/common/integer_set.h b/paddle/cinn/common/integer_set.h index e0f23da2e744f8..6d095b12083f11 100644 --- a/paddle/cinn/common/integer_set.h +++ b/paddle/cinn/common/integer_set.h @@ -41,6 +41,8 @@ class SymbolicExprAnalyzer { std::optional ProveLE(const ir::Expr& lhs, const ir::Expr& rhs) const; std::optional ProveGT(const ir::Expr& lhs, const ir::Expr& rhs) const; std::optional ProveLT(const ir::Expr& lhs, const ir::Expr& rhs) const; + std::optional ProveDivisible(const ir::Expr& lhs, + const ir::Expr& rhs) const; ir::Expr LowerBound(const ir::Expr& expr) const; ir::Expr UpperBound(const ir::Expr& expr) const; diff --git a/paddle/cinn/common/integer_set_test.cc b/paddle/cinn/common/integer_set_test.cc index 23406ec2f770ea..6d57f2dd0ed257 100644 --- a/paddle/cinn/common/integer_set_test.cc +++ b/paddle/cinn/common/integer_set_test.cc @@ -136,6 +136,50 @@ TEST_F(TestSymbolicExprAnalyzer, compare) { analyzer.Prove(e3 < e4).value()); } +TEST_F(TestSymbolicExprAnalyzer, Divisible) { + auto x = ir::Var(ir::Expr(1), ir::Expr(7), "x"); + auto y = ir::Var(ir::Expr(1), ir::Expr(15), "y"); + auto S = ir::Var(ir::Expr(16), ir::Expr(256), "S"); + + cas_intervals_t divisible_var_intervals = { + {"x", CasInterval(x->lower_bound, x->upper_bound)}, + {"y", CasInterval(y->lower_bound, y->upper_bound)}, + {"S", CasInterval(S->lower_bound, S->upper_bound)}, + }; + SymbolicExprAnalyzer divisible_analyzer{divisible_var_intervals}; + + // case 1 + ir::Expr e1 = 4 * x + 2 * y * x; + ir::Expr e2 = x; + ir::Expr e3 = y; + + EXPECT_TRUE(divisible_analyzer.ProveDivisible(e1, e2).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e1, e3).value_or(false)); + + // case 2 + ir::Expr e4 = y + y * x + 4 * y - x * y; + + EXPECT_TRUE(divisible_analyzer.ProveDivisible(e4, e3).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e4, e2).value_or(false)); + + // case 3 + ir::Expr e5 = x / y + x + y; + + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e3).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e5, e2).value_or(false)); + + // case 4 + ir::Expr e6 = S * x / 4 + x * y; + + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e2).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e6, e3).value_or(false)); + + ir::Expr e7 = 16 * x / 4 + x * y; + + EXPECT_TRUE(divisible_analyzer.ProveDivisible(e7, e2).value_or(false)); + EXPECT_FALSE(divisible_analyzer.ProveDivisible(e7, e3).value_or(false)); +} + TEST(SingleIntervalIntSet, constant) { SingleIntervalIntSet empty_set(ir::Expr(0), ir::Expr(-1)); SingleIntervalIntSet all_set(SymbolicExprLimit::negative_inf, From 7c7c5b11b2c908d98f9c9b512231eb7d7a57be79 Mon Sep 17 00:00:00 2001 From: lanxianghit <47554610+lanxianghit@users.noreply.github.com> Date: Sat, 6 Jan 2024 12:57:27 +0800 Subject: [PATCH 132/142] [PIR][DynamicShape] make shape pass default and fix some bugs (#60548) att, make shape pass default and fix some bugs --- .../cinn/hlir/dialect/operator/ir/manual_op.h | 13 +- .../interface/infer_symbolic_shape.cc | 80 +++- .../operator/interface/infer_symbolic_shape.h | 3 + .../pir/dialect/operator/ir/op_dialect.cc | 26 +- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 1 + paddle/fluid/pir/transforms/CMakeLists.txt | 7 +- .../pir/transforms/shape_optimization_pass.cc | 138 +----- .../pir/transforms/shape_optimization_pass.h | 4 - paddle/fluid/pybind/pir.cc | 16 +- paddle/phi/core/flags.cc | 13 + paddle/pir/dialect/shape/utils/dim_expr.h | 6 +- .../shape/utils/shape_optimization_utils.cc | 212 --------- .../shape/utils/shape_optimization_utils.h | 8 - paddle/pir/dialect/shape/utils/shape_utils.cc | 5 +- .../jit/dy2static/pir_partial_program.py | 9 + test/cpp/pir/cinn/adt/map_expr_test.cc | 6 + .../shape_dialect/shape_optimization_test.cc | 12 - .../pir/shape_dialect/shape_struct_test.cc | 424 ------------------ 18 files changed, 152 insertions(+), 831 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 8a9acef15aa9d7..77c0e0196d15aa 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -16,6 +16,7 @@ #include #include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/core/builder.h" +#include "paddle/pir/core/dll_decl.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/op_base.h" #include "paddle/pir/core/operation.h" @@ -24,7 +25,7 @@ namespace cinn { namespace dialect { -class GroupOp : public pir::Op { +class IR_API GroupOp : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.group"; } @@ -82,7 +83,7 @@ class IR_API SplitOp : public pir::Op { void VerifySig() const {} }; -class GenerateShapeOp : public pir::Op { +class IR_API GenerateShapeOp : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.generate_shape"; } @@ -121,7 +122,7 @@ class GenerateShapeOp : public pir::Op { } // namespace dialect } // namespace cinn -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 1b9ca43b7d9f10..79c8e703e1184f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -116,7 +116,9 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, shapes.push_back(dim_expr); } - symbol::ShapeOrDataDimExprs shape_data{shapes}; + symbol::ShapeOrDataDimExprs shape_data{ + shapes, + shape_analysis->value_id_to_shapeordata_[operand_source_id].shape()}; shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } @@ -146,9 +148,9 @@ bool ReshapeOpInferSymbolicShape( pir::OpResult res = op->result(0); std::string res_id = pir::GetValueId(&res); - symbol::ShapeOrDataDimExprs shape_data; + symbol::ShapeOrDataDimExprs shape_data{ + *(shape_analysis->value_id_to_shapeordata_[operand_source_1_id].data())}; - shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } @@ -158,6 +160,54 @@ bool Reshape_OpInferSymbolicShape( return ReshapeOpInferSymbolicShape(op, shape_analysis); } +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + auto operand_source_1 = op->operand_source(1); + std::string operand_source_1_id = pir::GetValueId(&operand_source_1); + auto starts_array = + (shape_analysis->value_id_to_shapeordata_[operand_source_1_id]).data(); + auto start = starts_array->at(0).Get(); + + auto operand_source_2 = op->operand_source(2); + std::string operand_source_2_id = pir::GetValueId(&operand_source_2); + auto ends_array = + (shape_analysis->value_id_to_shapeordata_[operand_source_2_id]).data(); + auto end = ends_array->at(0).Get(); + + std::vector data; + auto source_data = + (shape_analysis->value_id_to_shapeordata_[operand_source_id]).data(); + + for (int i = start; i < end; i++) { + data.emplace_back(source_data->at(i)); + } + + symbol::ShapeOrDataDimExprs shape_data{shapes, data}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + return true; +} + } // namespace paddle::dialect namespace cinn::dialect { @@ -184,17 +234,25 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, shapes.push_back(dim_expr); } - // pir::AttributeMap attributes = op->attributes(); + pir::AttributeMap attributes = op->attributes(); + + auto attr_starts = + attributes["starts"].dyn_cast().AsVector(); + auto start = attr_starts[0].dyn_cast().data(); - // auto attr_starts = - // attributes["starts"].dyn_cast().AsVector(); - // auto start = attr_starts[0].dyn_cast().data(); + auto attr_ends = + attributes["ends"].dyn_cast().AsVector(); + auto end = attr_ends[0].dyn_cast().data(); - // auto attr_ends = - // attributes["ends"].dyn_cast().AsVector(); - // auto end = attr_ends[0].dyn_cast().data(); + std::vector data; + auto source_data = + (shape_analysis->value_id_to_shapeordata_[operand_source_id]).data(); + + for (int i = start; i < end; i++) { + data.emplace_back(source_data->at(i)); + } - symbol::ShapeOrDataDimExprs shape_data{shapes}; + symbol::ShapeOrDataDimExprs shape_data{shapes, data}; shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index b1c72e3111df23..fc96df40596af5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -101,6 +101,9 @@ bool ReshapeOpInferSymbolicShape( bool Reshape_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + } // namespace paddle::dialect namespace cinn::dialect { diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index a9129a28793b09..969edf32204bf4 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -44,21 +44,25 @@ struct CombineOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - symbol::ShapeOrDataDimExprs value_shape; - - // for (auto operand_source : op->operands_source()) { - // std::string operand_source_id = pir::GetValueId(&operand_source); - // auto source_shape_vec = - // shape_analysis->value_id_to_shapeordata_[operand_source_id]; - // for (int i = 0; i < source_shape_vec.size(); i++) { - // value_shape.second.emplace_back(source_shape_vec[i]); - // } - // } + std::vector shapes; + std::vector data; + + for (auto operand_source : op->operands_source()) { + std::string operand_source_id = pir::GetValueId(&operand_source); + auto source_data_p = + shape_analysis->value_id_to_shapeordata_[operand_source_id].data(); + auto source_shape_vec = + source_data_p.value_or(std::vector{}); + for (size_t i = 0; i < source_shape_vec.size(); i++) { + data.emplace_back(source_shape_vec.at(i)); + } + } auto res = op->result(0); auto res_id = pir::GetValueId(&res); - shape_analysis->value_id_to_shapeordata_[res_id] = value_shape; + symbol::ShapeOrDataDimExprs shape_data{shapes, data}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 221aeb6c7dfa30..97fa1a6879e0a3 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1179,6 +1179,7 @@ kernel : func : slice backward : slice_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : soft_relu args : (Tensor x, float threshold = 20.0f) diff --git a/paddle/fluid/pir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt index 83f9680e1cd5ed..a5ffb11f0063c2 100644 --- a/paddle/fluid/pir/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/transforms/CMakeLists.txt @@ -1,12 +1,9 @@ file(GLOB_RECURSE transforms_srcs "*.cc") if(NOT WITH_CINN) list( - REMOVE_ITEM - transforms_srcs - ${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc + REMOVE_ITEM transforms_srcs ${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_extract_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/shape_optimization_pass.cc) + ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc) endif() set(transforms_deps drr op_dialect op_dialect_vjp standalone_executor pir diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 5c6481110034e2..1ad2700684186e 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/shape_optimization_pass.h" -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -27,117 +26,6 @@ #include "paddle/pir/pattern_rewrite/pattern_match.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -namespace { - -void InferUnaryElementwiseSymbolicShape( - const pir::Operation& op, - const std::shared_ptr& shape_analysis) { - auto input = op.operand_source(0); - auto output = op.result(0); - const auto& in_sym_dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(input); - const auto& out_sym_dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(output); - pir::SymbolicDimMgr& sym_dim_mgr = shape_analysis->symbolicDimMgr(); - for (auto i = 0; i < out_sym_dims.size(); ++i) { - if (in_sym_dims[i].IsDynamic() || out_sym_dims[i].IsDynamic()) { - sym_dim_mgr.MapSymbolicDimEqual(in_sym_dims[i], out_sym_dims[i]); - } else { - // do nothing - } - } -} - -// TODO(zyfncg): support broadcast for elementwise ops. -void InferBinaryElementwiseSymbolicShape( - const pir::Operation& op, - const std::shared_ptr& shape_analysis) { - auto input0 = op.operand_source(0); - auto input1 = op.operand_source(1); - auto output = op.result(0); - const auto& in_sym_dims0 = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(input0); - const auto& in_sym_dims1 = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(input1); - const auto& out_sym_dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(output); - pir::SymbolicDimMgr& sym_dim_mgr = shape_analysis->symbolicDimMgr(); - for (auto i = 0; i < out_sym_dims.size(); ++i) { - if (in_sym_dims0[i].IsDynamic() || in_sym_dims1[i].IsDynamic() || - out_sym_dims[i].IsDynamic()) { - sym_dim_mgr.MapSymbolicDimEqual(in_sym_dims0[i], out_sym_dims[i]); - sym_dim_mgr.MapSymbolicDimEqual(in_sym_dims1[i], out_sym_dims[i]); - } else { - // do nothing - } - } -} - -class InferSymbolicShapePass : public pir::Pass { - public: - InferSymbolicShapePass( - const std::shared_ptr& shape_analysis) - : pir::Pass("infer_symbolic_shape_pass", /*opt_level=*/1), - shape_analysis_(shape_analysis) {} - - void Run(pir::Operation* op) override { - auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "infer_symbolic_shape_pass should run on module op."); - - for (auto& op : module_op.block()) { - if (op.isa()) { - for (auto* local_op : op.dyn_cast().ops()) { - InferSymbolicShape(*local_op); - } - } else { - InferSymbolicShape(op); - } - } - } - - bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; - } - - private: - typedef void (*InferSymShapeFunc)( - const pir::Operation&, - const std::shared_ptr&); - void InferSymbolicShape(const pir::Operation& op) { - thread_local static std::unordered_map - infer_sym_shape_map(GetInferSymShapeMap()); - auto it = infer_sym_shape_map.find(op.name()); - - if (it != infer_sym_shape_map.end()) { - it->second(op, shape_analysis_); - } else { - LOG(WARNING) << "[" << op.name() - << "] is not supported for infer_symbolic_shape pass."; - } - } - - static std::unordered_map - GetInferSymShapeMap() { - return std::unordered_map{ - {paddle::dialect::ExpOp::name(), &InferUnaryElementwiseSymbolicShape}, - {paddle::dialect::SubtractOp::name(), - &InferBinaryElementwiseSymbolicShape}}; - } - - std::shared_ptr shape_analysis_; -}; - -} // namespace - -namespace pir { - -std::unique_ptr CreateInferSymbolicShapePass( - const std::shared_ptr& shape_analysis) { - return std::make_unique(shape_analysis); -} - -} // namespace pir - namespace pir { namespace { @@ -446,14 +334,12 @@ bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { // TODO(zhangbopd): Do some Canonicalizer. pir::SymbolicDimMgr mgr(m); - IR_ENFORCE(mgr.Load(), - "SymbolicDimMgr Load failed in OptimizeShapeComputation."); + ShapeComputationIRAnalysis analysis(m, mgr); if (!analysis.Run()) { return false; } - IR_ENFORCE(mgr.Save(), - "SymbolicDimMgr save failed in OptimizeShapeComputation."); + return true; } @@ -469,7 +355,7 @@ void PrintProgram(pir::ModuleOp m, std::string mgs) { void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - VLOG(0) << op->name() << ", num_operands: " << op->num_operands(); + VLOG(3) << op->name() << ", num_operands: " << op->num_operands(); for (auto& res : op->results()) { auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; @@ -503,7 +389,7 @@ void DebugPrintOpInfo( } print_stream << "]\n"; } - VLOG(0) << print_stream.str(); + VLOG(3) << print_stream.str(); } } @@ -511,7 +397,7 @@ void InferSymExprForAllValues(ModuleOp module_op) { auto shape_analysis_mgr = ShapeAnalysisManager::Instance(); ShapeConstraintIRAnalysis& shape_analysis = shape_analysis_mgr.Get(module_op.program()); - for (int i = 0; i < module_op->num_regions(); i++) { + for (uint32_t i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { if (op.num_operands() == 0) { @@ -534,18 +420,21 @@ void InferSymExprForAllValues(ModuleOp module_op) { shapes.push_back(dim_expr); } + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; + if (op.name() == "pd_op.full_int_array") { + std::vector data; auto attributes = op.attributes(); auto attr = attributes["value"]; auto arr = attr.dyn_cast(); const auto& vec = arr.AsVector(); for (auto item : vec) { int64_t i = item.dyn_cast().data(); - shapes.push_back(symbol::DimExpr(i)); + data.push_back(symbol::DimExpr(i)); } + shape_analysis.value_id_to_shapeordata_[value_id].SetData(data); } - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; } } else { auto infer_symbolic_shape_interface = @@ -574,14 +463,11 @@ class ShapeOptimizationPass : public pir::Pass { PrintProgram(module_op, "Origin Program"); InferSymExprForAllValues(module_op); - MaterializeShapeComputation(module_op); // Runner is for Canonicalizer. PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { return pm.Run(m.program()); }; - // if (!OptimizeShapeComputation(module_op, runner)) { - // return; - // } + VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; } diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.h b/paddle/fluid/pir/transforms/shape_optimization_pass.h index fa192972a41b8d..cbaa377157823e 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.h +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.h @@ -22,10 +22,6 @@ namespace pir { class Pass; -// Apply some shape-related optimization. -IR_API std::unique_ptr CreateInferSymbolicShapePass( - const std::shared_ptr& shape_analysis); - IR_API std::unique_ptr CreateShapeOptimizationPass(); } // namespace pir diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index a477f42e40c485..bb0d1d230052a1 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -130,6 +130,7 @@ USE_PIR_PASS(conv2d_add_act_fuse_pass); USE_PIR_PASS(fused_dot_product_attention_pass); PHI_DECLARE_bool(print_ir); +PHI_DECLARE_bool(pir_apply_shape_optimization_pass); namespace paddle { namespace pybind { @@ -1629,7 +1630,6 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT has_dynamic_shape ? std::make_shared(ctx) : nullptr; - pass_manager->AddPass(pir::CreateShapeOptimizationPass()); cinn::dialect::ir::PdOp2CinnOpConverter(&program); pass_manager->AddPass( @@ -1637,10 +1637,6 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); pass_manager->AddPass(pir::CreateBuildCinnPass()); - if (has_dynamic_shape) { - pass_manager->AddPass(pir::CreateInferSymbolicShapePass(shape_analysis)); - } - pass_manager->AddPass( cinn::dialect::ir::CreateCinnGroupLoweringPass(shape_analysis)); VLOG(4) << "has_dynamic_shape :" << has_dynamic_shape @@ -1651,8 +1647,18 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT "compile PaddlePaddle with CINN")); #endif } + +void InferSymbolicShapePass( + std::shared_ptr &pass_manager, // NOLINT + Program &program) { // NOLINT + if (FLAGS_pir_apply_shape_optimization_pass) { + pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + } +} + void BindIrPass(pybind11::module *m) { m->def("add_cinn_pass", AddCinnPass); + m->def("infer_symbolic_shape_pass", InferSymbolicShapePass); py::class_> pass(*m, "Pass", diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index ea1af5eee4d0b9..77b03f7efda2e6 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1413,6 +1413,19 @@ PHI_DEFINE_EXPORTED_bool(pir_apply_inplace_pass, "Whether to apply inplace pass on lowering " "::pir::Program to Kernel Dialect"); +/** + * Apply shape optimization pass to new IR FLAG + * Name: pir_apply_shape_optimization_pass + * Since Version: 3.0.0 + * Value Range: bool, default=false + * Example: + * Note: If Ture, will apply shape_optimization pass to new IR. + */ +PHI_DEFINE_EXPORTED_bool(pir_apply_shape_optimization_pass, + false, + "Whether to apply shape_optimization pass " + "to infer symbolic shape"); + PHI_DEFINE_EXPORTED_string( ir_inplace_kernel_blacklist, "", diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index a65390200cd062..4363d507691708 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -223,6 +223,8 @@ class ShapeOrData { public: explicit ShapeOrData(const std::vector& shape) : shape_(shape), data_(std::nullopt) {} + explicit ShapeOrData(const std::vector& shape, const std::vector& data) + : shape_(shape), data_(data) {} ShapeOrData() = default; ShapeOrData(const ShapeOrData&) = default; ShapeOrData(ShapeOrData&&) = default; @@ -238,11 +240,9 @@ class ShapeOrData { const std::vector& shape() const { return shape_; } // Specfic for Tensor generated by shape-relevant ops const std::optional>& data() const { return data_; } + void SetData(const std::vector& data) { data_ = data; } private: - explicit ShapeOrData(const std::vector& shape, const std::vector& data) - : shape_(shape), data_(data) {} - std::vector shape_; std::optional> data_; }; diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc index 0d8305c5c934aa..a2e8f4c6ee10a3 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc @@ -60,55 +60,6 @@ SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { symbol_table_ = SymbolTable(func); } -bool SymbolicDimMgr::Load() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - IR_ENFORCE(func_op); - for (auto& op : *(func_op.block())) { - symbol_table_.insert(&op); - if (SymbolicDimOp sym_dim_op = op.dyn_cast()) { - symbol_dim_union_set_[sym_dim_op] = sym_dim_op; - symbol_name_set_.insert(sym_dim_op.GetSymName()); - } - } - return LoadShapeConstraintGraph(); -} - -bool SymbolicDimMgr::LoadShapeConstraintGraph() { - // TODO(zhangbopd): add more constraint function. currently, only support - // tie_product_equal. - auto constraint_vec = - symbol_table_.Lookup("tie_product_equal"); - - if (!constraint_vec.size()) return true; - - auto build_sym_product = [&](std::vector range, - SymbolicDimProduct& product) { - for (Value v : range) { - auto defining_op = v.dyn_cast().owner(); - if (auto constOp = defining_op->dyn_cast()) { - product.factor *= constOp.value().dyn_cast().data(); - continue; - } else if (auto dim_op = defining_op->dyn_cast()) { - auto sym = symbol_table_.Lookup(dim_op.GetName()); - if (!sym) return false; - product.symbols.push_back(sym); - continue; - } - return false; - } - return true; - }; - - for (auto op : constraint_vec) { - SymbolicDimProduct lhs, rhs; - if (!build_sym_product(op.lhs(), lhs) || - !build_sym_product(op.rhs(), rhs) || - !MapSymbolicDimProductEqual(lhs, rhs)) - return false; - } - return true; -} - bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct new_lhs, new_rhs; @@ -457,167 +408,4 @@ bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, return IsMultipleOfKnownSymbolicDimProductEqualPair(new_lhs, new_rhs); } -bool SymbolicDimMgr::Save() { - using Name2SymbolFn = std::function; - auto update_attrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { - std::vector new_attrs; - for (Attribute attr : attrs.AsVector()) { - auto sym = fn(attr.dyn_cast().AsString()); - IR_ENFORCE(sym); - SymbolicDimOp root = GetRootSymbolicDim(sym); - Attribute root_symbol = - StrAttribute::get(m_->ir_context(), root.GetSymName()); - new_attrs.push_back(root_symbol); - } - return ArrayAttribute::get(m_->ir_context(), new_attrs); - }; - - // TODO(zhangbopd): update attributes attached in DenseTensorType - for (auto& op : m_.block()) { - if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; - auto attrs = - op.attribute(SymbolicDimOp::GetSymbolicDimAttrName()); - auto symbolic_shape_attr = - update_attrs(attrs, [&](const std::string& name) { - return symbol_table_.Lookup(name); - }); - op.set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), - symbolic_shape_attr); - } - if (!UpdateProductEqualityMap()) { - return false; - } - std::unordered_set used_symbolic_ops; - std::vector used_symbol_names; - // TODO(zhangbopd): collect uses in value. - auto collect_used_symbols = [&](ArrayAttribute attrs) { - for (Attribute attr : attrs.AsVector()) { - auto sym = symbol_table_.Lookup( - attr.dyn_cast().AsString()); - IR_ENFORCE(sym); - if (used_symbolic_ops.insert(sym).second) - used_symbol_names.push_back(sym.GetSymName()); - } - }; - for (auto& op : m_.block()) { - if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; - auto attrs = - op.attribute(SymbolicDimOp::GetSymbolicDimAttrName()); - collect_used_symbols(attrs); - } - auto func_op = symbol_table_.getOp()->dyn_cast(); - IR_ENFORCE(func_op); - for (auto& p : symbol_dim_union_set_) { - if (!used_symbolic_ops.count(p.first)) { - func_op.block()->erase(*(p.first.operation())); - } - } - - std::vector candidates; - for (auto& outter : product_equality_map_) { - if (std::any_of(outter.first.symbols.begin(), - outter.first.symbols.end(), - [&](SymbolicDimOp sym) { - return used_symbolic_ops.count(sym) == 0; - })) - candidates.push_back(outter.first); - } - - for (auto& prod : candidates) product_equality_map_.erase(prod); - for (auto& outter : product_equality_map_) { - std::vector candidates; - for (auto& inner : outter.second) { - if (std::any_of(inner.first.symbols.begin(), - inner.first.symbols.end(), - [&](SymbolicDimOp sym) { - return used_symbolic_ops.count(sym) == 0; - })) - candidates.push_back(outter.first); - } - for (auto& prod : candidates) outter.second.erase(prod); - } - - std::sort(used_symbol_names.begin(), - used_symbol_names.end(), - [&](const std::string& lhs, const std::string& rhs) { - return CompareSymbolicDimNames(lhs, rhs); - }); - int non_const_dims_num = 0; - std::unordered_map name_mapping; - for (const auto& name : used_symbol_names) { - if (name.size() > 0 && name[0] == 'C') { - name_mapping[name] = name; - } else { - name_mapping[name] = ("S" + std::to_string(non_const_dims_num++)); - } - } - - std::unordered_map name_to_symbol; - for (SymbolicDimOp op : used_symbolic_ops) { - auto name = op.GetSymName(); - op.SetSymName(name_mapping[name]); - name_to_symbol[name] = op; - } - - for (auto& op : m_.block()) { - if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; - auto attrs = - op.attribute(SymbolicDimOp::GetSymbolicDimAttrName()); - auto symbolic_shape_attr = update_attrs( - attrs, [&](const std::string& name) { return name_to_symbol[name]; }); - op.set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), - symbolic_shape_attr); - } - - // TODO(zhangbopd): update attributes attached to values. - - return SaveShapeConstraintGraph(); -} - -bool SymbolicDimMgr::SaveShapeConstraintGraph() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - IR_ENFORCE(func_op); - auto op_it = func_op.block()->rbegin(); - while (op_it != func_op.block()->rend()) { - if ((op_it->isa()) || - (op_it->isa())) - op_it++; - else - op_it = decltype(op_it)(func_op.block()->erase(*op_it)); - } - - // save product equal predicate - Builder builder = Builder(m_->ir_context(), func_op.block()); - auto build_operands = [&](const SymbolicDimProduct& prod) { - std::vector values; - - if (prod.factor != 1) { - values.push_back( - builder - .Build( - Int32Attribute::get(m_->ir_context(), prod.factor), - Int32Type::get(m_->ir_context())) - ->result(0)); - } - for (SymbolicDimOp sym : prod.symbols) { - values.push_back(builder.Build(sym.GetSymName()).out()); - } - return values; - }; - std::vector sorted_product_vec; - for (auto& p : product_equality_map_) sorted_product_vec.push_back(p.first); - std::sort(sorted_product_vec.begin(), - sorted_product_vec.end(), - CompareSymbolicDimProduct); - for (auto& x : sorted_product_vec) { - for (auto& y : sorted_product_vec) { - if (!CompareSymbolicDimProduct(x, y)) continue; - if (!product_equality_map_[x][y]) continue; - auto lhs_operands = build_operands(x); - auto rhs_operands = build_operands(y); - builder.Build(lhs_operands, rhs_operands); - } - } - return true; -} } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h index a2a67c27ff713d..7797ab4f2ffb24 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h @@ -65,9 +65,6 @@ class IR_API SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); - // Loads pre-defined SymbolicDimOp ops from the module this mgr runs on. - bool Load(); - // Create a new symbolicDim instance owned by this mgr. SymbolicDimOp NewSymbolicDim(const std::string& name = {}); @@ -117,16 +114,11 @@ class IR_API SymbolicDimMgr { bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - // Saves the updated shape constraint IR - bool Save(); - // retuns the SymbolTable. SymbolTable& symbolTable() { return symbol_table_; } private: const std::string GetNextName(); - bool SaveShapeConstraintGraph(); - bool LoadShapeConstraintGraph(); bool UpdateProductEqualityMap(); bool IsMultipleOfKnownSymbolicDimProductEqualPair( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 4beb53dde4911b..574805c61f020b 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -47,7 +47,6 @@ bool ShapeAnalysis::IsProductEqual( ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) : m_(m), mgr_(m) { - mgr_.Load(); for (auto& op : m.block()) { auto tie_shape_op = op.dyn_cast(); if (!tie_shape_op) continue; @@ -66,9 +65,7 @@ ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) } } -ShapeConstraintIRAnalysis::~ShapeConstraintIRAnalysis() { - // mgr_.Save(); -} +ShapeConstraintIRAnalysis::~ShapeConstraintIRAnalysis() {} bool ShapeConstraintIRAnalysis::IsShapeEqual(Value lhs, Value rhs) { if (lhs == rhs) return true; diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 88b51f827581c9..574821ab5b3420 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -548,6 +548,14 @@ def _get_scope(self, program_id=None, use_scope_cache=False): @switch_to_static_graph def _create_program(self, is_infer_mode=False): if is_infer_mode: + # TODO(lanxianghit) mv this into pass_fn + def shape_pass_fn(forward_program, backward_program): + pm = paddle.base.libpaddle.pir.PassManager() + paddle.base.libpaddle.pir.infer_symbolic_shape_pass( + pm, forward_program + ) + pm.run(forward_program) + return forward_program, backward_program def pass_fn(forward_program, backward_program): pm = paddle.base.libpaddle.pir.PassManager() @@ -560,6 +568,7 @@ def pass_fn(forward_program, backward_program): infer_program = self.origin_runable_program.clone() if self._hooker: self._hooker.after_infer(infer_program) + infer_program.apply_pir_program_pass(shape_pass_fn) infer_program.apply_pir_program_pass(pass_fn) return infer_program else: diff --git a/test/cpp/pir/cinn/adt/map_expr_test.cc b/test/cpp/pir/cinn/adt/map_expr_test.cc index 578862495d1e4a..14c90fbc80dedd 100644 --- a/test/cpp/pir/cinn/adt/map_expr_test.cc +++ b/test/cpp/pir/cinn/adt/map_expr_test.cc @@ -74,6 +74,11 @@ TEST(MapExpr, ElementWise_Fusion_0) { ::pir::PassManager pass_manager(ctx); auto shape_analysis = std::make_shared(ctx); + + // TODO(@jiahy0825): use CreateShapeOptimizationPass() instead of + // CreateInferSymbolicShapePass() which is a fake pass + + /* pass_manager.AddPass(::pir::CreateInferSymbolicShapePass(shape_analysis)); pass_manager.Run(&program); @@ -112,4 +117,5 @@ MapExprTest(t_var_2, t_var_1) { } )TEST"; ASSERT_EQ(Trim(map_expr_str), Trim(target_str)); + */ } diff --git a/test/cpp/pir/shape_dialect/shape_optimization_test.cc b/test/cpp/pir/shape_dialect/shape_optimization_test.cc index 63621cce181df3..fb32a6f234f15c 100644 --- a/test/cpp/pir/shape_dialect/shape_optimization_test.cc +++ b/test/cpp/pir/shape_dialect/shape_optimization_test.cc @@ -43,10 +43,6 @@ TEST(shape_optimization, shape_optimization_pass) { // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 + 1 funcOp == 15 Ops. EXPECT_EQ(program.block()->size(), 2u); - - pir::SymbolicDimMgr mgr(program.module_op()); - EXPECT_TRUE(mgr.Load()); - EXPECT_TRUE(mgr.Save()); } TEST(shape_optimization, expand_shape_of_op_pattern) { @@ -69,10 +65,6 @@ TEST(shape_optimization, expand_shape_of_op_pattern) { pm.EnableIRPrinting(); pm.AddPass(pir::CreateShapeOptimizationPass()); pm.Run(&program); - - pir::SymbolicDimMgr mgr(program.module_op()); - EXPECT_TRUE(mgr.Load()); - EXPECT_TRUE(mgr.Save()); } TEST(shape_optimization, dim_of_shaped_type_op_interface_pattern) { @@ -100,8 +92,4 @@ TEST(shape_optimization, dim_of_shaped_type_op_interface_pattern) { pm.EnableIRPrinting(); pm.AddPass(pir::CreateShapeOptimizationPass()); pm.Run(&program); - - pir::SymbolicDimMgr mgr(program.module_op()); - EXPECT_TRUE(mgr.Load()); - EXPECT_TRUE(mgr.Save()); } diff --git a/test/cpp/pir/shape_dialect/shape_struct_test.cc b/test/cpp/pir/shape_dialect/shape_struct_test.cc index d2ed8a21b4e6c8..12fbb641ba90cc 100644 --- a/test/cpp/pir/shape_dialect/shape_struct_test.cc +++ b/test/cpp/pir/shape_dialect/shape_struct_test.cc @@ -86,427 +86,3 @@ TEST(shape_struct_test, symbolic_dim_mgr_simple) { EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s1)); EXPECT_FALSE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_c10)); } - -TEST(shape_struct_test, symbolic_dim_mgr_complex) { - /***************************************************************/ - /* Mgr with constraintOp, and SymbolicDimProduct related func. */ - /***************************************************************/ - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); - auto func_op = - sym_dim_mgr.symbolTable().getOp()->dyn_cast(); - - pir::Builder builder = pir::Builder(ctx, func_op.block()); - - pir::shape::SymbolicDimOp sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); - pir::shape::SymbolicDimOp sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); - pir::shape::SymbolicDimOp sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); - pir::shape::SymbolicDimOp sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); - pir::shape::SymbolicDimOp sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); - pir::shape::SymbolicDimOp sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); - pir::shape::SymbolicDimOp sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); - pir::shape::SymbolicDimOp sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); - pir::shape::SymbolicDimOp sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); - pir::shape::SymbolicDimOp sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); - pir::shape::SymbolicDimOp sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); - pir::shape::SymbolicDimOp sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); - pir::shape::SymbolicDimOp sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); - pir::shape::SymbolicDimOp sym_dim_c10 = - sym_dim_mgr.NewConstantSymbolicDim(10); - pir::shape::SymbolicDimOp sym_dim_c20 = - sym_dim_mgr.NewConstantSymbolicDim(20); - - pir::OpResult dim_op_s0 = builder.Build("S0").out(); - pir::OpResult dim_op_s1 = builder.Build("S1").out(); - pir::OpResult dim_op_s2 = builder.Build("S2").out(); - pir::OpResult dim_op_s3 = builder.Build("S3").out(); - pir::OpResult dim_op_s4 = builder.Build("S4").out(); - pir::OpResult dim_op_s5 = builder.Build("S5").out(); - pir::OpResult dim_op_s6 = builder.Build("S6").out(); - pir::OpResult dim_op_s7 = builder.Build("S7").out(); - pir::OpResult dim_op_s8 = builder.Build("S8").out(); - pir::OpResult dim_op_s9 = builder.Build("S9").out(); - pir::OpResult dim_op_s10 = builder.Build("S10").out(); - pir::OpResult dim_op_s11 = builder.Build("S11").out(); - pir::OpResult dim_op_c10 = builder.Build("C10").out(); - pir::OpResult dim_op_c20 = builder.Build("C20").out(); - pir::OpResult constant = - builder - .Build(pir::Int32Attribute::get(ctx, 2), - pir::Int32Type::get(ctx)) - ->result(0); - - // Mark S1 == S2. - builder.Build( - 2, 2, std::vector{constant, dim_op_s1, dim_op_s2, constant}); - // Mark S0 * S1 == S2 * S3, For check S0 == S3. - builder.Build( - 2, - 2, - std::vector{dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3}); - // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. - builder.Build( - 3, - 3, - std::vector{ - dim_op_s4, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s5}); - // For check S6 == C10 * C20. - builder.Build( - 1, 2, std::vector{dim_op_s6, dim_op_c10, dim_op_c20}); - // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. - builder.Build( - 3, - 3, - std::vector{ - dim_op_c10, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s7}); - - // For unsimplify product case: S8 * S9 == S10 * S11 - builder.Build( - 2, - 2, - std::vector{dim_op_s8, dim_op_s9, dim_op_s10, dim_op_s11}); - - auto op = test::CreateDenseTensorOp( - ctx, {-1, -1, -1, -1, -1, -1}, {"op0_attr"}, {"op0_name"}); - auto op_ = test::CreateDenseTensorOp( - ctx, {-1, -1, -1, -1, -1, 10, 20}, {"op1_attr"}, {"op1_name"}); - pir::OpResult res = op->result(0); - pir::OpResult res_ = op_->result(0); - - builder.SetInsertionPointToBlockEnd(program.block()); - pir::shape::TieShapeOp tie_shape_op1 = - builder.Build(res); - pir::shape::TieShapeOp tie_shape_op2 = - builder.Build(res_); - - pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attr_s3 = pir::StrAttribute::get(ctx, "S3"); - pir::Attribute attr_s4 = pir::StrAttribute::get(ctx, "S4"); - pir::Attribute attr_s5 = pir::StrAttribute::get(ctx, "S5"); - pir::Attribute attr_s6 = pir::StrAttribute::get(ctx, "S6"); - pir::Attribute attr_s7 = pir::StrAttribute::get(ctx, "S7"); - pir::Attribute attr_s8 = pir::StrAttribute::get(ctx, "S8"); - pir::Attribute attr_s9 = pir::StrAttribute::get(ctx, "S9"); - pir::Attribute attr_s10 = pir::StrAttribute::get(ctx, "S10"); - pir::Attribute attr_s11 = pir::StrAttribute::get(ctx, "S11"); - pir::Attribute attr_c10 = pir::StrAttribute::get(ctx, "C10"); - pir::Attribute attr_c20 = pir::StrAttribute::get(ctx, "C20"); - - std::vector new_attrs1 = { - attr_s0, attr_s1, attr_s2, attr_s3, attr_s4, attr_s5}; - std::vector new_attrs2 = {attr_s6, - attr_s7, - attr_s8, - attr_s9, - attr_s10, - attr_s11, - attr_c10, - attr_c20}; - std::vector new_attrs_ref = { - attr_s0, attr_s1, attr_s1, attr_s0, attr_s2, attr_s2}; - - auto array_attr1 = pir::ArrayAttribute::get(ctx, new_attrs1); - auto array_attr2 = pir::ArrayAttribute::get(ctx, new_attrs2); - auto array_attr_ref = pir::ArrayAttribute::get(ctx, new_attrs_ref); - - tie_shape_op1->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr1); - tie_shape_op2->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr2); - - EXPECT_TRUE(sym_dim_mgr.Load()); - - // For check indirect equality: S1 * S4 == S2 * S5 - pir::SymbolicDimProduct sym_dim_product_lhs1; - pir::SymbolicDimProduct sym_dim_product_rhs1; - - sym_dim_product_lhs1.symbols.push_back(sym_dim_s1); - sym_dim_product_lhs1.symbols.push_back(sym_dim_s4); - - sym_dim_product_rhs1.symbols.push_back(sym_dim_s2); - sym_dim_product_rhs1.symbols.push_back(sym_dim_s5); - - // For uncompletely simplied product check: S8 * S9 * S12 == S10 * S11 * S12 - pir::SymbolicDimProduct sym_dim_product_lhs2; - pir::SymbolicDimProduct sym_dim_product_rhs2; - - sym_dim_product_lhs2.symbols.push_back(sym_dim_s8); - sym_dim_product_lhs2.symbols.push_back(sym_dim_s9); - sym_dim_product_lhs2.symbols.push_back(sym_dim_s12); - - sym_dim_product_rhs2.symbols.push_back(sym_dim_s10); - sym_dim_product_rhs2.symbols.push_back(sym_dim_s11); - sym_dim_product_rhs2.symbols.push_back(sym_dim_s12); - - // For check SimplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = - // 10} - pir::SymbolicDimProduct sym_dim_product_s7; - sym_dim_product_s7.symbols.push_back(sym_dim_s7); - pir::SymbolicDimProduct simplified_product_s7 = - sym_dim_mgr.SimplifySymbolicDimProduct(sym_dim_product_s7); - - // For check SimplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z - pir::SymbolicDimProduct sym_dim_product_pair_lhs; - pir::SymbolicDimProduct sym_dim_product_pair_rhs; - pir::SymbolicDimProduct new_lhs, new_rhs; - sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s4); - sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s1); - sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s2); - sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s1); - sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s2); - sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s3); - - std::tie(new_lhs, new_rhs) = sym_dim_mgr.SimplifySymbolicDimProductPair( - sym_dim_product_pair_lhs, sym_dim_product_pair_rhs); - - // For check SymbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor - // = 2 Sym = {S4}} - pir::SymbolicDimProduct sym_dim_product_div_lhs; - pir::SymbolicDimProduct sym_dim_product_div_rhs; - sym_dim_product_div_lhs.symbols.push_back(sym_dim_s4); - sym_dim_product_div_lhs.symbols.push_back(sym_dim_s1); - sym_dim_product_div_lhs.symbols.push_back(sym_dim_c20); - sym_dim_product_div_rhs.symbols.push_back(sym_dim_s1); - sym_dim_product_div_rhs.symbols.push_back(sym_dim_c10); - - pir::SymbolicDimProduct *divRes = sym_dim_mgr.SymbolicDimProductDivide( - sym_dim_product_div_lhs, sym_dim_product_div_rhs); - - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s1, sym_dim_s2)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s3)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s4, sym_dim_s5)); - EXPECT_EQ(sym_dim_s6.GetDimSize(), 200); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), - sym_dim_c20); - EXPECT_EQ(sym_dim_s7.GetDimSize(), sym_dim_c10.GetDimSize()); - EXPECT_EQ(simplified_product_s7.factor, 10); - EXPECT_EQ(simplified_product_s7.symbols.size(), static_cast(0)); - EXPECT_EQ(new_lhs.symbols.size(), static_cast(1)); - EXPECT_EQ(new_rhs.symbols.size(), static_cast(1)); - EXPECT_EQ(new_lhs.symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s4)); - EXPECT_EQ(new_rhs.symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s3)); - EXPECT_EQ(divRes->factor, 2); - EXPECT_EQ(divRes->symbols.size(), static_cast(1)); - EXPECT_EQ(divRes->symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s4)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimProductEqual(sym_dim_product_lhs1, - sym_dim_product_rhs1)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimProductEqual(sym_dim_product_lhs2, - sym_dim_product_rhs2)); - EXPECT_TRUE(sym_dim_mgr.Save()); - - pir::SymbolicDimMgr sym_dim_mgr_new(program.module_op()); - EXPECT_TRUE(sym_dim_mgr_new.Load()); - - auto attrs = tie_shape_op1.attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName()); - EXPECT_FALSE( - sym_dim_mgr_new.symbolTable().Lookup("S7")); - EXPECT_EQ(sym_dim_mgr_new.symbolTable() - .Lookup("tie_product_equal") - .size(), - static_cast(1)); - - EXPECT_EQ(attrs.AsVector(), array_attr_ref.AsVector()); -} - -TEST(shape_struct_test, shape_analysis) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::shape::FuncOp func_op = builder.Build(); - - phi::DDim dims_D_2 = {-1, 2}; - phi::DDim dims_2_2 = {2, 2}; - phi::DDim dims_D = {-1}; - - // same shape with dynamic: value1 == value2 - auto op1 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); - pir::OpResult value1 = op1->result(0); - pir::OpResult value2 = op2->result(0); - - // same shape with static: value3 == value4 - auto op3 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); - pir::OpResult value3 = op3->result(0); - pir::OpResult value4 = op4->result(0); - - // one dimension with dynamic: value5 != value1 != value3 - auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); - pir::OpResult value5 = op5->result(0); - - pir::shape::TieShapeOp tie_shape_op1 = - builder.Build(value1); - pir::shape::TieShapeOp tie_shape_op2 = - builder.Build(value2); - pir::shape::TieShapeOp tie_shape_op3 = - builder.Build(value3); - pir::shape::TieShapeOp tie_shape_op4 = - builder.Build(value4); - pir::shape::TieShapeOp tie_shape_op5 = - builder.Build(value5); - - builder.SetInsertionPointToBlockEnd(func_op.block()); - builder.Build("C2", 2, true, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s0 = - builder.Build( - "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s1 = - builder.Build( - "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s2 = - builder.Build( - "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - - pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attr_c2 = pir::StrAttribute::get(ctx, "C2"); - - auto attr_op1 = pir::ArrayAttribute::get(ctx, {attr_s0, attr_c2}); - auto attr_op2 = pir::ArrayAttribute::get(ctx, {attr_s1, attr_c2}); - auto attr_op3 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op4 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); - - tie_shape_op1->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1); - tie_shape_op2->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2); - tie_shape_op3->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3); - tie_shape_op4->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4); - tie_shape_op5->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5); - - pir::ShapeConstraintIRAnalysis shape_analysis(program.module_op()); - EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value3)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value3, value5)); - EXPECT_TRUE(shape_analysis.IsProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shape_analysis.IsSameNumElements(value4, value3)); - - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s2); - - const auto &val_sym_dim1 = - shape_analysis.GetOrCreateSymbolicDimsForRankedValue(value1); - const auto &val_sym_dim2 = - shape_analysis.GetOrCreateSymbolicDimsForRankedValue(value2); - EXPECT_TRUE(shape_analysis.symbolicDimMgr().IsSymbolicDimEqual( - val_sym_dim1[0], val_sym_dim2[0])); - - EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); -} - -TEST(shape_struct_test, shape_analysis_manager) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::shape::FuncOp func_op = builder.Build(); - - phi::DDim dims_D_2 = {-1, 2}; - phi::DDim dims_2_2 = {2, 2}; - phi::DDim dims_D = {-1}; - - // same shape with dynamic: value1 == value2 - auto op1 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); - pir::OpResult value1 = op1->result(0); - pir::OpResult value2 = op2->result(0); - - // same shape with static: value3 == value4 - auto op3 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); - pir::OpResult value3 = op3->result(0); - pir::OpResult value4 = op4->result(0); - - // one dimension with dynamic: value5 != value1 != value3 - auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); - pir::OpResult value5 = op5->result(0); - - pir::shape::TieShapeOp tie_shape_op1 = - builder.Build(value1); - pir::shape::TieShapeOp tie_shape_op2 = - builder.Build(value2); - pir::shape::TieShapeOp tie_shape_op3 = - builder.Build(value3); - pir::shape::TieShapeOp tie_shape_op4 = - builder.Build(value4); - pir::shape::TieShapeOp tie_shape_op5 = - builder.Build(value5); - - builder.SetInsertionPointToBlockEnd(func_op.block()); - builder.Build("C2", 2, true, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s0 = - builder.Build( - "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s1 = - builder.Build( - "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s2 = - builder.Build( - "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - - pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attr_c2 = pir::StrAttribute::get(ctx, "C2"); - - auto attr_op1 = pir::ArrayAttribute::get(ctx, {attr_s0, attr_c2}); - auto attr_op2 = pir::ArrayAttribute::get(ctx, {attr_s1, attr_c2}); - auto attr_op3 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op4 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); - - tie_shape_op1->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1); - tie_shape_op2->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2); - tie_shape_op3->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3); - tie_shape_op4->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4); - tie_shape_op5->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5); - - auto shape_analysis_mgr = pir::ShapeAnalysisManager::Instance(); - pir::ShapeConstraintIRAnalysis &shape_analysis = - shape_analysis_mgr.Get(&program); - - EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value3)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value3, value5)); - EXPECT_TRUE(shape_analysis.IsProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shape_analysis.IsSameNumElements(value4, value3)); - - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s2); - - EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); -} From 5e4f4994e3cd417bb16d3d8a278d09e26a162591 Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 8 Jan 2024 10:19:05 +0800 Subject: [PATCH 133/142] Fix words (#60603) --- test/ir/inference/auto_scan_test.py | 14 +++++++------- test/ir/inference/program_config.py | 2 +- .../test_conv_elementwise_add2_act_fuse_pass.py | 2 +- test/ir/inference/test_trt_convert_pad.py | 2 +- test/ir/inference/test_trt_convert_reshape.py | 16 ++++++++-------- test/ir/inference/test_trt_convert_rnn.py | 2 +- test/ir/pass_test.py | 10 +++++----- test/ir/pir/test_if_api.py | 8 ++++---- test/ir/pir/test_while_api.py | 4 ++-- test/ir/test_fuse_resnet_unit.py | 2 +- 10 files changed, 31 insertions(+), 31 deletions(-) diff --git a/test/ir/inference/auto_scan_test.py b/test/ir/inference/auto_scan_test.py index 7075c0db0d0a8d..813468c3f73a55 100755 --- a/test/ir/inference/auto_scan_test.py +++ b/test/ir/inference/auto_scan_test.py @@ -278,9 +278,9 @@ def run_test(self, quant=False, *args, **kwargs): model, params, prog_config, base_config, feed_data ) ) - self.success_log(f"basline program_config: {prog_config}") + self.success_log(f"baseline program_config: {prog_config}") self.success_log( - f"basline predictor_config: {self.inference_config_str(base_config)}" + f"baseline predictor_config: {self.inference_config_str(base_config)}" ) for pred_config, (atol, rtol) in self.sample_predictor_configs( @@ -561,11 +561,11 @@ def inference_config_str(self, config) -> str: dic["passes"] = self.passes enable_trt = config.tensorrt_engine_enabled() - trt_precison = config.tensorrt_precision_mode() + trt_precision = config.tensorrt_precision_mode() trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled() if enable_trt: dic["use_trt"] = True - dic["trt_precision"] = trt_precison + dic["trt_precision"] = trt_precision dic["use_dynamic_shape"] = trt_dynamic_shape else: dic["use_trt"] = False @@ -713,11 +713,11 @@ def assert_op_size(self, trt_engine_num, paddle_op_num): def inference_config_str(self, config: paddle_infer.Config) -> str: dic = {} enable_trt = config.tensorrt_engine_enabled() - trt_precison = config.tensorrt_precision_mode() + trt_precision = config.tensorrt_precision_mode() trt_dynamic_shape = config.tensorrt_dynamic_shape_enabled() if enable_trt: dic["use_trt"] = True - dic["trt_precision"] = trt_precison + dic["trt_precision"] = trt_precision dic["use_dynamic_shape"] = trt_dynamic_shape else: dic["use_trt"] = False @@ -755,7 +755,7 @@ def random_to_skip(): gpu_config, prog_config.get_feed_data(), ) - self.success_log(f"basline program_config: {prog_config}") + self.success_log(f"baseline program_config: {prog_config}") for ( pred_config, diff --git a/test/ir/inference/program_config.py b/test/ir/inference/program_config.py index a2f36617de8e6e..7458376de86c8b 100644 --- a/test/ir/inference/program_config.py +++ b/test/ir/inference/program_config.py @@ -260,7 +260,7 @@ def __init__( no_cast_list: Optional[List[str]] = None, ): self.ops = ops - # if no weight need to save, we create a place_holder to help seriazlie params. + # if no weight need to save, we create a place_holder to help serialize params. if not weights: def generate_weight(): diff --git a/test/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py b/test/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py index 6d4076683c19f7..1221c56b331bcf 100755 --- a/test/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py +++ b/test/ir/inference/test_conv_elementwise_add2_act_fuse_pass.py @@ -190,7 +190,7 @@ def sample_program_config(self, draw): ) ) - # 9. Generate legal elemntwise_add: X of conv2d + # 9. Generate legal elementwise_add: X of conv2d bias_2_dict = {} bias_2_dict[1] = [ x_shape[0], diff --git a/test/ir/inference/test_trt_convert_pad.py b/test/ir/inference/test_trt_convert_pad.py index f20b915c2a467b..b34e56aebf9e52 100644 --- a/test/ir/inference/test_trt_convert_pad.py +++ b/test/ir/inference/test_trt_convert_pad.py @@ -137,7 +137,7 @@ def teller1(program_config, predictor_config): self.add_skip_case( teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "NOT Implemented: we need to add support pad not only inplement on h or w, such as paddings = [0, 0, 1, 1, 1, 1, 1, 1]", + "NOT Implemented: we need to add support pad not only implement on h or w, such as paddings = [0, 0, 1, 1, 1, 1, 1, 1]", ) def test(self): diff --git a/test/ir/inference/test_trt_convert_reshape.py b/test/ir/inference/test_trt_convert_reshape.py index 5762d51c3480f7..43600cfe1c3b47 100644 --- a/test/ir/inference/test_trt_convert_reshape.py +++ b/test/ir/inference/test_trt_convert_reshape.py @@ -88,12 +88,12 @@ def generate_shapeT2_data(attrs: List[Dict[str, Any]]): }, ] self.dims = dims - dics_intput = [{"X": ["reshape_input"]}] + dics_input = [{"X": ["reshape_input"]}] ops_config = [ { "op_type": "reshape", - "op_inputs": dics_intput[0], + "op_inputs": dics_input[0], "op_outputs": {"Out": ["reshape_out"]}, "op_attrs": dics[0], } @@ -228,7 +228,7 @@ def generate_input1(attrs: List[Dict[str, Any]]): {}, ] self.dims = dims - dics_intput = [ + dics_input = [ { "X": ["reshape_input"], "ShapeTensor": ["shapeT1_data", "shapeT2_data"], @@ -257,7 +257,7 @@ def generate_input1(attrs: List[Dict[str, Any]]): }, { "op_type": "reshape", - "op_inputs": dics_intput[0], + "op_inputs": dics_input[0], "op_outputs": {"Out": ["reshape_out"]}, "op_attrs": dics[0], }, @@ -351,7 +351,7 @@ def generate_input1(attrs: List[Dict[str, Any]]): {}, ] self.dims = dims - dics_intput = [ + dics_input = [ { "X": ["reshape_input"], "shape_data": ["shape_data"], @@ -370,7 +370,7 @@ def generate_input1(attrs: List[Dict[str, Any]]): }, { "op_type": "reshape", - "op_inputs": dics_intput[0], + "op_inputs": dics_input[0], "op_outputs": {"Out": ["reshape_out"]}, "op_attrs": dics[0], }, @@ -463,12 +463,12 @@ def generate_input1(attrs: List[Dict[str, Any]]): }, ] self.dims = dims - dics_intput = [{"X": ["reshape_input"]}] + dics_input = [{"X": ["reshape_input"]}] ops_config = [ { "op_type": "reshape", - "op_inputs": dics_intput[0], + "op_inputs": dics_input[0], "op_outputs": {"Out": ["reshape_out"]}, "op_attrs": dics[0], } diff --git a/test/ir/inference/test_trt_convert_rnn.py b/test/ir/inference/test_trt_convert_rnn.py index 296da4db148bc6..d93a1e59bfae5e 100644 --- a/test/ir/inference/test_trt_convert_rnn.py +++ b/test/ir/inference/test_trt_convert_rnn.py @@ -46,7 +46,7 @@ def sample_program_configs(self): "is_bidirec": is_bidirec, "is_test": True, "dropout_prob": 0.0, - # for my convience + # for my convenience "batch": batch, "seq_len": seq_len, } diff --git a/test/ir/pass_test.py b/test/ir/pass_test.py index 2d809e2f5e9bb8..7d892b74590bad 100644 --- a/test/ir/pass_test.py +++ b/test/ir/pass_test.py @@ -187,16 +187,16 @@ def _check_fused_ops(self, program): if program is None or program == self.main_program: program = self._apply_ir_passes() - acctual_num_fused_ops = 0 - # Ir passes can only be applyed to block 0. + actual_num_fused_ops = 0 + # Ir passes can only be applied to block 0. for op in program.block(0).ops: if op.type == self.fused_op_type: - acctual_num_fused_ops += 1 + actual_num_fused_ops += 1 self.assertTrue( - self.num_fused_ops == acctual_num_fused_ops, + self.num_fused_ops == actual_num_fused_ops, "Checking of the number of fused operator < {} > failed. " "Expected: {}, Received: {}".format( - self.fused_op_type, self.num_fused_ops, acctual_num_fused_ops + self.fused_op_type, self.num_fused_ops, actual_num_fused_ops ), ) diff --git a/test/ir/pir/test_if_api.py b/test/ir/pir/test_if_api.py index ff82ff80e16b6f..d33335f7644ae2 100644 --- a/test/ir/pir/test_if_api.py +++ b/test/ir/pir/test_if_api.py @@ -68,7 +68,7 @@ def test_if_with_multiple_output(self): self.assertEqual(last_op.name(), "pd_op.if") self.assertEqual(len(out), 2) - # check Operaion::as_if_op interface + # check Operation::as_if_op interface if_op = last_op.as_if_op() true_block = if_op.true_block() self.assertEqual(len(true_block), 3) @@ -77,7 +77,7 @@ def test_if_with_multiple_output(self): build_pipe_for_block(true_block) self.assertEqual(len(true_block), 4) - # check Operaion::blocks interface + # check Operation::blocks interface block_list = [] for block in out[0].get_defining_op().blocks(): block_list.append(block) @@ -94,7 +94,7 @@ def test_if_op_vjp_interface(self): out_grad = paddle.full(shape=[6, 1], dtype='float32', fill_value=3) # check vjp interface for if_op if_input = [[input] for input in get_used_external_value(if_op)] - if_input_stop_graditents = [[True], [False], [False], [True]] + if_input_stop_gradients = [[True], [False], [False], [True]] if_output = [if_op.results()] if_output_grad = [[out_grad]] self.assertEqual(has_vjp(if_op), True) @@ -103,7 +103,7 @@ def test_if_op_vjp_interface(self): if_input, if_output, if_output_grad, - if_input_stop_graditents, + if_input_stop_gradients, ) self.assertEqual(grad_outs[0][0], None) diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 1a5ee3186d692a..44bb4541dc3749 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -104,7 +104,7 @@ def test_while_op_vjp_interface(self): [input] for input in get_used_external_value(body_block) ] self.assertEqual(len(while_input), 4) - while_input_stop_graditents = [[True], [False], [True], [True]] + while_input_stop_gradients = [[True], [False], [True], [True]] while_output = [[value] for value in while_op.results()] while_output_grad = [[out_grad], [out_grad], [out_grad]] self.assertEqual(has_vjp(while_op), True) @@ -113,7 +113,7 @@ def test_while_op_vjp_interface(self): while_input, while_output, while_output_grad, - while_input_stop_graditents, + while_input_stop_gradients, ) self.assertEqual(grad_outs[0][0], None) diff --git a/test/ir/test_fuse_resnet_unit.py b/test/ir/test_fuse_resnet_unit.py index f9599386f8bfcf..7e5885116e087f 100644 --- a/test/ir/test_fuse_resnet_unit.py +++ b/test/ir/test_fuse_resnet_unit.py @@ -33,7 +33,7 @@ "and device's compute capability is at least 7.0 and less than 9.0", ) class TestFuseResNetUnit(unittest.TestCase): - def test_fuse_resenet_unit(self): + def test_fuse_resnet_unit(self): place = paddle.CUDAPlace(0) program = paddle.static.Program() startup_program = paddle.static.Program() From 0cb83681b3e649f98b37bdaa4b60ce9dac49448b Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 8 Jan 2024 10:38:33 +0800 Subject: [PATCH 134/142] =?UTF-8?q?=E3=80=90auto=20parallel=E3=80=91custom?= =?UTF-8?q?=20op=20use=20spmd=20rule=20(#60571)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * custom op use smpd rule * custom op use smpd rule --- cmake/inference_lib.cmake | 6 + .../custom_operator/custom_operator_utils.cc | 342 +++++++++++++----- .../distributed/auto_parallel/device_mesh.h | 1 - .../distributed/auto_parallel/dist_attr.h | 1 + .../distributed/auto_parallel/dist_tensor.cc | 15 + .../distributed/auto_parallel/dist_tensor.h | 7 + .../distributed/auto_parallel/process_mesh.cc | 2 + .../distributed/auto_parallel/process_mesh.h | 1 - test/auto_parallel/CMakeLists.txt | 1 + test/auto_parallel/custom_op/CMakeLists.txt | 16 + .../auto_parallel/custom_op/custom_relu_op.cc | 134 +++++++ .../auto_parallel/custom_op/custom_relu_op.cu | 82 +++++ .../custom_op/custom_relu_setup.py | 31 ++ .../semi_auto_parallel_for_custom_op.py | 89 +++++ .../test_semi_auto_parallel_custom_op.py | 52 +++ test/auto_parallel/custom_op/utils.py | 47 +++ 16 files changed, 729 insertions(+), 98 deletions(-) create mode 100644 test/auto_parallel/custom_op/CMakeLists.txt create mode 100644 test/auto_parallel/custom_op/custom_relu_op.cc create mode 100644 test/auto_parallel/custom_op/custom_relu_op.cu create mode 100644 test/auto_parallel/custom_op/custom_relu_setup.py create mode 100644 test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py create mode 100644 test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py create mode 100644 test/auto_parallel/custom_op/utils.py diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index d0a055d0f2e64c..9f1268ce36c41f 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -335,6 +335,12 @@ copy( DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/ ) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/ +) + copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 8894a06267b514..b28357672c046e 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -453,14 +453,176 @@ paddle::Tensor BuildEmptyDistPaddleTensor( #endif #ifdef PADDLE_WITH_DISTRIBUTE -std::tuple PrepareCtxForAutoParallel( + +phi::distributed::SpmdInfo RunInferSpmdFn( + const paddle::OpMetaInfo& op_info, + const std::vector& inputs, + const std::vector& outputs, + paddle::CustomOpKernelContext& ctx) { // NOLINT + auto& infer_spmd_func = paddle::OpMetaInfoHelper::GetInferSpmdFn(op_info); + if (infer_spmd_func == nullptr) { + // default rule + std::vector meta_dist_inputs; + auto all_inputs = ctx.AllMutableInput(); + for (auto& t : *all_inputs) { + phi::distributed::DistMetaTensor meta_dist_input; + if (t.impl().get()) { + meta_dist_input = + paddle::experimental::MakeDistMetaTensor(*(t.impl().get())); + } + meta_dist_inputs.push_back(meta_dist_input); + } + auto spmd_info_tmp = + phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_inputs); + // flatten input + phi::distributed::SpmdInfo spmd_info; + auto dist_attrs = PADDLE_GET(std::vector, + spmd_info_tmp.first[0]); + for (auto& e : dist_attrs) { + spmd_info.first.push_back(std::move(e)); + } + return spmd_info; + } + std::vector tensor_inputs; + size_t input_size = inputs.size(); + for (size_t i = 0; i < input_size; ++i) { + const auto& in_name = inputs[i]; + if (paddle::framework::detail::IsDuplicableVar(in_name)) { + std::vector meta_tensors; + auto& range = ctx.InputRangeAt(i); + for (size_t j = range.first; j < range.second; ++j) { + auto& t = ctx.InputAt(j); + phi::distributed::DistMetaTensor meta_tensor; + if (t.impl().get()) { + meta_tensor = + paddle::experimental::MakeDistMetaTensor(*(t.impl().get())); + } + meta_tensors.emplace_back(std::move(meta_tensor)); + } + tensor_inputs.emplace_back(std::move(meta_tensors)); + } else { + auto& range = ctx.InputRangeAt(i); + auto& t = ctx.InputAt(range.first); + phi::distributed::DistMetaTensor meta_tensor; + if (t.impl().get()) { + meta_tensor = + paddle::experimental::MakeDistMetaTensor(*(t.impl().get())); + } + tensor_inputs.emplace_back(std::move(meta_tensor)); + } + } + const std::vector& attrs = ctx.Attrs(); + auto spmd_info_tmp = infer_spmd_func(tensor_inputs, attrs); + // flatten input + phi::distributed::SpmdInfo spmd_info; + for (auto& e : spmd_info_tmp.first) { + if (paddle::holds_alternative(e)) { + spmd_info.first.push_back( + std::move(PADDLE_GET(phi::distributed::TensorDistAttr, e))); + } else { + for (auto& ee : + PADDLE_GET(std::vector, e)) { + spmd_info.first.push_back(std::move(ee)); + } + } + } + + // flatten output + for (auto& e : spmd_info_tmp.second) { + if (paddle::holds_alternative(e)) { + spmd_info.second.push_back( + std::move(PADDLE_GET(phi::distributed::TensorDistAttr, e))); + } else { + for (auto& ee : + PADDLE_GET(std::vector, e)) { + spmd_info.second.push_back(std::move(ee)); + } + } + } + + return spmd_info; +} + +std::vector> RunInferShapeFn( const paddle::OpMetaInfo& op_info, bool is_forward, bool is_double_grad, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map, paddle::CustomOpKernelContext& ctx) { // NOLINT + auto& infer_shape_func = paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = + RunInferShapeFunc(ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dims = + RunDefaultGradInferShapeFunc(ctx, inputs, outputs, is_double_grad); + } + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_shape return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dims.size())); + + return out_dims; +} + +std::vector> RunInferDtypeFn( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map, + paddle::CustomOpKernelContext& ctx) { // NOLINT + + auto& infer_dtype_func = paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); + std::vector> out_dtypes; + if (infer_dtype_func) { + out_dtypes = + RunInferDtypeFunc(ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dtypes = RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dtypes = + RunDefaultGradInferDtypeFunc(ctx, inputs, outputs, is_double_grad); + } + } + PADDLE_ENFORCE_EQ( + out_dtypes.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_dtype return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dtypes.size())); + return out_dtypes; +} + +std:: + tuple + PrepareCtxForAutoParallel( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx, // NOLINT + std::vector>& + dist_inputs, // NOLINT + std::vector& output_dims) { // NOLINT bool run_auto_parallel = false; bool rank_is_in_current_mesh = true; phi::distributed::ProcessMesh current_process_mesh; + phi::distributed::SpmdInfo spmd_info; const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); @@ -483,115 +645,73 @@ std::tuple PrepareCtxForAutoParallel( .process_mesh(); rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); - std::vector input_x(x.size()); - for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = x.at(i).impl().get(); - } + spmd_info = RunInferSpmdFn(op_info, inputs, outputs, ctx); - auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); - auto spmd_info = - phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_input_x); current_process_mesh = paddle::holds_alternative( spmd_info.first[0]) ? paddle::get<0>(spmd_info.first[0]).process_mesh() : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); + std::vector> out_dims = RunInferShapeFn( + op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx); + + std::vector> out_dtypes = RunInferDtypeFn( + op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx); + if (rank_is_in_current_mesh) { auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); - auto dist_input_x = paddle::experimental::ReshardApiInputToKernelInput( - dev_ctx, x, spmd_info.first[0]); for (size_t i = 0; i < x.size(); ++i) { + auto dist_input_i = paddle::experimental::ReshardApiInputToKernelInput( + dev_ctx, x[i], spmd_info.first[i]); all_inputs->at(i).set_impl( - std::make_shared(dist_input_x[i]->value())); - } - } else { - auto& infer_shape_func = - paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); - auto& infer_dtype_func = - paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); - - std::vector> out_dims; - if (infer_shape_func) { - out_dims = RunInferShapeFunc( - ctx, infer_shape_func, inputs, outputs, inplace_map); - } else { - if (is_forward) { - out_dims = - RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); - } else { - out_dims = RunDefaultGradInferShapeFunc( - ctx, inputs, outputs, is_double_grad); - } - } - - std::vector> out_dtypes; - if (infer_dtype_func) { - out_dtypes = RunInferDtypeFunc( - ctx, infer_dtype_func, inputs, outputs, inplace_map); - } else { - if (is_forward) { - out_dtypes = - RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); - } else { - out_dtypes = RunDefaultGradInferDtypeFunc( - ctx, inputs, outputs, is_double_grad); - } + std::make_shared(dist_input_i->value())); + dist_inputs.emplace_back(dist_input_i); } + } + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + const auto& pair = ctx.OutputRangeAt(i); PADDLE_ENFORCE_EQ( - out_dims.size(), - ctx.OutputRange().size(), - phi::errors::InvalidArgument( - "Custome op infer_shape return size should be %d, but got %d.", - ctx.OutputRange().size(), - out_dims.size())); - + out_dim.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dim.size())); PADDLE_ENFORCE_EQ( - out_dtypes.size(), - ctx.OutputRange().size(), - phi::errors::InvalidArgument( - "Custome op infer_dtype return size should be %d, but got %d.", - ctx.OutputRange().size(), - out_dtypes.size())); - - for (size_t i = 0; i < out_dims.size(); ++i) { - const auto& out_dim = out_dims.at(i); - const auto& out_dtype = out_dtypes.at(i); - const auto& pair = ctx.OutputRangeAt(i); - PADDLE_ENFORCE_EQ( - out_dim.size(), - pair.second - pair.first, - phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " - "size should be %d, but got %d.", - i, - pair.second - pair.first, - out_dim.size())); - PADDLE_ENFORCE_EQ( - out_dtype.size(), - pair.second - pair.first, - phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " - "size should be %d, but got %d.", - i, - pair.second - pair.first, - out_dtype.size())); - - if (out_dim.size() == 1) { + out_dtype.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dtype.size())); + + if (out_dim.size() == 1) { + output_dims.emplace_back(out_dim[0]); + if (!rank_is_in_current_mesh) { *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[0], out_dtype[0]); - } else { - for (size_t j = pair.first; j < pair.second; j++) { + } + } else { + for (size_t j = pair.first; j < pair.second; j++) { + output_dims.emplace_back(out_dim[j]); + if (!rank_is_in_current_mesh) { *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[j], out_dtype[j]); } } } - return std::tuple( - run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); } } - return std::tuple( - run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); + return {run_auto_parallel, + rank_is_in_current_mesh, + current_process_mesh, + spmd_info}; } #endif @@ -599,27 +719,48 @@ std::tuple PrepareCtxForAutoParallel( void TransCtxTensorsToDistTensors( paddle::CustomOpKernelContext& ctx, // NOLINT bool run_auto_parallel, - const phi::distributed::ProcessMesh& current_process_mesh) { + const phi::distributed::ProcessMesh& current_process_mesh, + const phi::distributed::SpmdInfo& spmd_info, + std::vector>& + dist_inputs, // NOLINT + std::vector& output_dims) { // NOLINT if (run_auto_parallel) { std::vector* output_all = ctx.AllMutableOutput(); for (size_t i = 0; i < output_all->size(); ++i) { auto& tensor = output_all->at(i); - phi::distributed::TensorDistAttr dist_attr = - phi::distributed::TensorDistAttr(common::vectorize(tensor.dims())); - dist_attr.set_process_mesh(current_process_mesh); + phi::distributed::TensorDistAttr dist_attr; + if (!spmd_info.second.empty()) { + dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, + spmd_info.second[i]); + } else { + std::vector shape = common::vectorize(output_dims[i]); + dist_attr.set_default_dims_mapping(shape); + dist_attr.set_process_mesh(current_process_mesh); + } auto dist_t = std::make_shared( std::dynamic_pointer_cast(tensor.impl()), + output_dims[i], dist_attr); tensor.set_impl(dist_t); } std::vector* input_all = ctx.AllMutableInput(); for (size_t i = 0; i < input_all->size(); ++i) { auto& tensor = input_all->at(i); - phi::distributed::TensorDistAttr dist_attr = - phi::distributed::TensorDistAttr(common::vectorize(tensor.dims())); - dist_attr.set_process_mesh(current_process_mesh); + phi::distributed::TensorDistAttr dist_attr; + phi::DDim global_dims; + + if (i < dist_inputs.size()) { + auto& dist_input = dist_inputs.at(i); + global_dims = dist_input->dims(); + dist_attr = dist_input->dist_attr(); + } else { + dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, + spmd_info.first[i]); + global_dims = tensor.dims(); + } auto dist_t = std::make_shared( std::dynamic_pointer_cast(tensor.impl()), + global_dims, dist_attr); tensor.set_impl(dist_t); } @@ -637,11 +778,15 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); #ifdef PADDLE_WITH_DISTRIBUTE - auto result = - PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); + // for output + std::vector> dist_inputs; + std::vector output_dims; + auto result = PrepareCtxForAutoParallel( + op_info, is_forward, is_double_grad, ctx, dist_inputs, output_dims); bool run_auto_parallel = std::get<0>(result); bool rank_is_in_current_mesh = std::get<1>(result); phi::distributed::ProcessMesh current_process_mesh = std::get<2>(result); + auto& spmd_info = std::get<3>(result); if (!rank_is_in_current_mesh) { return; } @@ -667,7 +812,12 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ctx.AssignInplaceOutputs(); #ifdef PADDLE_WITH_DISTRIBUTE - TransCtxTensorsToDistTensors(ctx, run_auto_parallel, current_process_mesh); + TransCtxTensorsToDistTensors(ctx, + run_auto_parallel, + current_process_mesh, + spmd_info, + dist_inputs, + output_dims); #endif } diff --git a/paddle/phi/core/distributed/auto_parallel/device_mesh.h b/paddle/phi/core/distributed/auto_parallel/device_mesh.h index 8cfdc6ed242f0e..0741e03fe94c0f 100644 --- a/paddle/phi/core/distributed/auto_parallel/device_mesh.h +++ b/paddle/phi/core/distributed/auto_parallel/device_mesh.h @@ -23,7 +23,6 @@ limitations under the License. */ #include #include -#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" #include "paddle/phi/core/enforce.h" diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index d158fc848c8d40..e4016b9f65cdc8 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include +#include #include #include diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index c41effe6c85220..885797b7386e4f 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -162,6 +162,21 @@ DistTensor::DistTensor(const std::shared_ptr& local_value, } } +DistTensor::DistTensor(const std::shared_ptr& local_value, + const DDim& global_dims, + const TensorDistAttr& dist_attr) + : global_dims_(global_dims), dist_attr_(dist_attr) { + process_mesh_ = dist_attr_.process_mesh(); + placements_ = ToPlacements(dist_attr); + if (IsCurRankInMesh(process_mesh_)) { + value_ = local_value; + } else { + value_ = std::make_shared( + std::make_shared(nullptr, 0, local_value->place()), + phi::DenseTensorMeta(local_value->dtype(), global_dims_)); + } +} + DistTensor::DistTensor(const std::shared_ptr& global_value, const ProcessMesh& process_mesh, const Placements& placements) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index 55b35ffe5c25a5..5ad10c76b25087 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -58,6 +58,13 @@ class DistTensor final const ProcessMesh& process_mesh, const Placements& placements); + /// \brief Construct a dist tensor based local dense tensor. + /// \param global_dims The global dim of the dist tensor. + /// \param dist_attr The distributed attributes of the current tensor. + DistTensor(const std::shared_ptr& local_value, + const DDim& global_dims, + const TensorDistAttr& dist_attr); + /// \brief Construct a dist tensor based local dense tensor. /// \param global_dims The global dim of the dist tensor. /// \param process_mesh The process mesh of the current tensor. diff --git a/paddle/phi/core/distributed/auto_parallel/process_mesh.cc b/paddle/phi/core/distributed/auto_parallel/process_mesh.cc index a1b60e27c27e67..983725880f3525 100644 --- a/paddle/phi/core/distributed/auto_parallel/process_mesh.cc +++ b/paddle/phi/core/distributed/auto_parallel/process_mesh.cc @@ -16,6 +16,8 @@ limitations under the License. */ #include #include + +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" namespace phi { diff --git a/paddle/phi/core/distributed/auto_parallel/process_mesh.h b/paddle/phi/core/distributed/auto_parallel/process_mesh.h index 792d5e38f5318b..1b76dec23c2a05 100644 --- a/paddle/phi/core/distributed/auto_parallel/process_mesh.h +++ b/paddle/phi/core/distributed/auto_parallel/process_mesh.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include -#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" #include "paddle/phi/core/enforce.h" diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index a735762cce6581..ab2b09680c5ad4 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(spmd_rules) add_subdirectory(hybrid_strategy) +add_subdirectory(custom_op) if(WITH_DISTRIBUTE AND WITH_GPU) diff --git a/test/auto_parallel/custom_op/CMakeLists.txt b/test/auto_parallel/custom_op/CMakeLists.txt new file mode 100644 index 00000000000000..b3537bc09c4e04 --- /dev/null +++ b/test/auto_parallel/custom_op/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LOCAL_ALL_ARCH ON) +set(LOCAL_ALL_PLAT ON) +if(WITH_DISTRIBUTE + AND WITH_GPU + AND (LINUX)) + py_test_modules( + test_semi_auto_parallel_custom_op + MODULES + test_semi_auto_parallel_custom_op + ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python;PADDLE_SOURCE_DIR=${PROJECT_SOURCE_DIR};WITH_MKLDNN=${WITH_MKLDNN};MKLDNN_INSTALL_DIR=${MKLDNN_INSTALL_DIR};WITH_MKLDNN=${WITH_MKLDNN};WITH_GPU=${WITH_GPU};WITH_ROCM=${WITH_ROCM};externalError_INCLUDE_DIR=${externalError_INCLUDE_DIR};PYBIND_INCLUDE_DIR=${PYBIND_INCLUDE_DIR}" + ) + set_tests_properties(test_semi_auto_parallel_custom_op + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + +endif() diff --git a/test/auto_parallel/custom_op/custom_relu_op.cc b/test/auto_parallel/custom_op/custom_relu_op.cc new file mode 100644 index 00000000000000..7f76ab92cb2d1a --- /dev/null +++ b/test/auto_parallel/custom_op/custom_relu_op.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/extension.h" +#include "paddle/phi/api/ext/spmd_infer.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" + +#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") + +template +void relu_cpu_forward_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + PD_CHECK(x_data != nullptr, "x_data is nullptr."); + PD_CHECK(out_data != nullptr, "out_data is nullptr."); + for (int64_t i = 0; i < x_numel; ++i) { + out_data[i] = std::max(static_cast(0.), x_data[i]); + } +} + +template +void relu_cpu_backward_kernel(const data_t* grad_out_data, + const data_t* out_data, + data_t* grad_x_data, + int64_t out_numel) { + for (int64_t i = 0; i < out_numel; ++i) { + grad_x_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +std::vector relu_cpu_forward(const paddle::Tensor& x) { + auto out = paddle::empty_like(x); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel( + x.data(), out.data(), x.numel()); + })); + + return {out}; +} + +std::vector relu_cpu_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::empty_like(x); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.data(), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x) { + if (x.is_cpu()) { + return relu_cpu_forward(x); + } else if (x.is_gpu()) { + return relu_cuda_forward(x); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + if (x.is_cpu()) { + return relu_cpu_backward(x, out, grad_out); + } else if (x.is_gpu()) { + return relu_cuda_backward(x, out, grad_out); + } else { + PD_THROW("Not implemented."); + } +} + +phi::distributed::SpmdInfo ReluGradInferSpmd( + const phi::distributed::DistMetaTensor& x, + const phi::distributed::DistMetaTensor& out, + const phi::distributed::DistMetaTensor& out_grad) { + return phi::distributed::ElementwiseUnaryGradInferSpmd(x, out, out_grad); +} + +PD_BUILD_OP(custom_relu) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)) + .SetInferSpmdFn( + PD_INFER_SPMD_RULE(phi::distributed::ElementwiseUnaryInferSpmd)); + +PD_BUILD_GRAD_OP(custom_relu) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)) + .SetInferSpmdFn(PD_INFER_SPMD_RULE(ReluGradInferSpmd)); + +PD_BUILD_OP(custom_relu_no_spmd) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu_no_spmd) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); + +PD_REGISTER_SPMD_RULE( + custom_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); diff --git a/test/auto_parallel/custom_op/custom_relu_op.cu b/test/auto_parallel/custom_op/custom_relu_op.cu new file mode 100644 index 00000000000000..810ff75be5578f --- /dev/null +++ b/test/auto_parallel/custom_op/custom_relu_op.cu @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +template +__global__ void relu_cuda_forward_kernel(const data_t* x, + data_t* y, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = x[i] > static_cast(0.) ? x[i] : static_cast(0.); + } +} + +template +__global__ void relu_cuda_backward_kernel(const data_t* dy, + const data_t* y, + data_t* dx, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > static_cast(0.) ? static_cast(1.) + : static_cast(0.)); + } +} + +std::vector relu_cuda_forward(const paddle::Tensor& x) { + CHECK_GPU_INPUT(x); + auto out = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out.data(), numel); + })); + + return {out}; +} + +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + CHECK_GPU_INPUT(x); + CHECK_GPU_INPUT(out); + CHECK_GPU_INPUT(grad_out); + auto grad_x = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = out.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + numel); + })); + + return {grad_x}; +} diff --git a/test/auto_parallel/custom_op/custom_relu_setup.py b/test/auto_parallel/custom_op/custom_relu_setup.py new file mode 100644 index 00000000000000..567e7ac65d1e36 --- /dev/null +++ b/test/auto_parallel/custom_op/custom_relu_setup.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from utils import extra_compile_args, paddle_includes + +from paddle.utils.cpp_extension import CUDAExtension, setup + +# Mac-CI don't support GPU +Extension = CUDAExtension +sources = ['custom_relu_op.cc', 'custom_relu_op.cu'] + +setup( + name='custom_relu', + ext_modules=Extension( + sources=sources, + include_dirs=paddle_includes, + extra_compile_args=extra_compile_args, + verbose=True, + ), +) diff --git a/test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py b/test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py new file mode 100644 index 00000000000000..1de800abe1b99c --- /dev/null +++ b/test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py @@ -0,0 +1,89 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist +from paddle.framework import core + +import custom_relu # noqa: F401 # pylint: disable=unused-import # isort:skip + +assert core.contains_spmd_rule("custom_relu") + + +class TestCusomOpSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + + def check_placements(self, output, expected_placements): + assert ( + output.placements == expected_placements + ), f"{output.placements} vs {expected_placements}" + + def test_custom_relu(self): + shapes = [16, 4, 4] + specs = ['x', None, None] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=custom_relu.custom_relu, + with_backward=True, + ) + self.check_placements(outputs, [dist.Shard(0)]) + + def test_custom_relu_no_spmd(self): + shapes = [16, 4, 4] + specs = ['x', None, None] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=custom_relu.custom_relu_no_spmd, + with_backward=True, + ) + self.check_placements(outputs, [dist.Replicate()]) + + def test_custom_relu_no_shard(self): + shapes = [16, 4, 4] + specs = [None, None, None] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=custom_relu.custom_relu, + with_backward=True, + ) + self.check_placements(outputs, [dist.Replicate()]) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + self.test_custom_relu_no_shard() + self.test_custom_relu() + self.test_custom_relu_no_spmd() + + +if __name__ == '__main__': + TestCusomOpSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py b/test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py new file mode 100644 index 00000000000000..a8014a81c2548f --- /dev/null +++ b/test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest + +import collective.test_communication_api_base as test_base + +from paddle.utils.cpp_extension.extension_utils import run_cmd + + +class TestCusomOp(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=200, nnode=1) + self._default_envs = {"dtype": "float32", "seed": "2023"} + self._changeable_envs = {"backend": ["cpu", "gpu"]} + cur_dir = os.path.dirname(os.path.abspath(__file__)) + # compile, install the custom op egg into site-packages under background + if os.name == 'nt': + cmd = f'cd /d {cur_dir} && python custom_relu_setup.py install' + else: + cmd = ( + f'cd {cur_dir} && {sys.executable} custom_relu_setup.py install' + ) + run_cmd(cmd) + + # test dynamic auto parallel run + def test_dynamic_auto_parallel(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_custom_op.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/custom_op/utils.py b/test/auto_parallel/custom_op/utils.py new file mode 100644 index 00000000000000..07f08648c0a62e --- /dev/null +++ b/test/auto_parallel/custom_op/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from site import getsitepackages + +# Test for extra compile args +extra_cc_args = ['-w', '-g'] +extra_nvcc_args = ['-O3'] +extra_compile_args = {'cc': extra_cc_args, 'nvcc': extra_nvcc_args} + + +def get_paddle_includes(): + env_dict = os.environ + paddle_includes = [] + paddle_includes.append(f"{env_dict.get('PADDLE_SOURCE_DIR')}") + + # mkldnn + if env_dict.get("WITH_MKLDNN") == 'ON': + paddle_includes.append(f"{env_dict.get('MKLDNN_INSTALL_DIR')}/include") + if env_dict.get("WITH_GPU") == 'ON' or env_dict.get("WITH_ROCM") == 'ON': + paddle_includes.append(f"{env_dict.get('externalError_INCLUDE_DIR')}") + paddle_includes.append(f"{env_dict.get('PYBIND_INCLUDE_DIR')}") + + for site_packages_path in getsitepackages(): + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include') + ) + paddle_includes.append( + os.path.join(site_packages_path, 'paddle', 'include', 'third_party') + ) + + return paddle_includes + + +paddle_includes = get_paddle_includes() From be98374a4e3a9a02010e5bc6a522b5bc53035ac9 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Mon, 8 Jan 2024 10:53:56 +0800 Subject: [PATCH 135/142] [auto parallel] add lazy init ut to llama (#60585) --- python/paddle/distributed/auto_parallel/api.py | 3 +++ .../hybrid_strategy/semi_auto_llama.py | 13 ++++++++++++- .../test_semi_auto_parallel_llama_model.py | 17 +++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index efec4383022ed4..428c1878f6381a 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -195,6 +195,9 @@ def lazy_init_hook(param, origin_hook): # lazy init hook with randomness controlling def _init_func(var, block): + if dist.get_rank() not in param.process_mesh.process_ids: + # None calc rank, just return no init. + return # get the unique rng name rng_name = determinate_rng( dist.get_rank(), diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_llama.py b/test/auto_parallel/hybrid_strategy/semi_auto_llama.py index 129824fe4d8b6a..c22c9ce2b6760f 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_llama.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_llama.py @@ -24,6 +24,7 @@ import paddle import paddle.distributed as dist +from paddle import LazyGuard from paddle.io import BatchSampler, DataLoader, Dataset @@ -44,6 +45,7 @@ class Config: rope = True recompute = False recompute_granularity = None + use_lazy_init = False class RandomDataset(Dataset): @@ -104,6 +106,8 @@ def __init__(self): if os.getenv("recompute") == "true": self.config.recompute = True self.config.recompute_granularity = os.getenv("recompute_granularity") + if os.getenv("use_lazy_init") == "true": + self.config.use_lazy_init = True self.gradient_accumulation_steps = int(os.getenv("acc_step")) self.init_dist_env() @@ -126,7 +130,14 @@ def init_dist_env(self): set_global_mesh(global_mesh) def run_llama(self, to_static=0): - model = LlamaForCausalLMAuto(self.config) + if self.config.use_lazy_init: + with LazyGuard(): + model = LlamaForCausalLMAuto(self.config) + for param in model.parameters(): + assert not param._is_initialized() + param.initialize() + else: + model = LlamaForCausalLMAuto(self.config) criterion = LlamaPretrainingCriterionAuto(self.config) lr_scheduler = paddle.optimizer.lr.LinearWarmup( diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py index 3ace2754c7123c..a12f511e88afe3 100644 --- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model.py @@ -117,5 +117,22 @@ def test_simple_net_hybrid_strategy_acc(self): ) +class TestSemiAutoParallelLlamaLazyInit(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=8, timeout=200, nnode=1) + self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"} + self._changeable_envs = {"backend": ["gpu"], "use_lazy_init": ["true"]} + + def test_simple_net_hybrid_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_llama.py", + user_defined_envs=envs, + ) + + if __name__ == "__main__": unittest.main() From 1646a83bc07f80901c081799063891c0cbf69859 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Mon, 8 Jan 2024 10:55:39 +0800 Subject: [PATCH 136/142] =?UTF-8?q?=E3=80=90pir=E3=80=91=20modify=20array?= =?UTF-8?q?=5Fwrite=20and=20array=5Fread=20vjp=20,=20add=20a=20simple=20wh?= =?UTF-8?q?ile=20with=20array=5Fwrite=20(#60575)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * optimize backward * [PIR] add vjp interface for while op * [PIR] fix ci error. * modify while stopgradient * merge * modify while grad bug * modify while grad op * modify * increment vp * [PIR] add get_used_external_value interface for block. * while case * delete print * delete print * Update python/paddle/autograd/ir_backward.py * [PIR] add unit_test for get_used_external_value * modify while_loop * code_style * modofy ci bug * modify while api * modify ci * modify array * Update python/paddle/autograd/ir_backward.py * Update test/legacy_test/test_cond.py * update * modify array_write grad info * merge * add_n and createarraylike * conflict * modify array_write vjp * modify array_write vjp * Update paddle/fluid/pybind/manual_static_op_function.h * modify array_write vjp * modify ci bug * modify * modify * Update test/legacy_test/test_while_loop_op.py * modify inplace array_read * Update test/legacy_test/test_while_op.py * Update test/ir/pir/test_while_api.py --------- Co-authored-by: winter-wang <1030748926@qq.com> --- .../pir/dialect/operator/ir/manual_op.cc | 2 +- .../pir/dialect/operator/ir/manual_op_vjp.cc | 40 ++-- python/paddle/autograd/ir_backward.py | 177 +++++++++++++----- python/paddle/tensor/array.py | 2 +- test/ir/pir/test_while_api.py | 4 +- test/legacy_test/test_array_read_write_op.py | 7 +- test/legacy_test/test_while_loop_op.py | 78 +++++++- test/legacy_test/test_while_op.py | 4 +- 8 files changed, 242 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index ca720fdb26ee35..490330f0fe8ace 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -1501,7 +1501,7 @@ OpInfoTuple ArrayReadOp::GetOpInfo() { false, false, false, - true), + false), OpInputInfo( "i", "paddle::dialect::ScalarAttribute", false, false, true, false)}; diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index 2ce536aa3d1d71..ece8b2fff8b8a9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -196,7 +196,7 @@ std::vector> ArrayWrite_Op::Vjp( pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, - const std::vector>& out_grads, + const std::vector>& in_grads, const std::vector>& stop_gradients) { PADDLE_ENFORCE_EQ( inputs_.size(), @@ -212,19 +212,21 @@ std::vector> ArrayWrite_Op::Vjp( outputs.size())); PADDLE_ENFORCE_EQ( - out_grads.size(), + in_grads.size(), 1, platform::errors::InvalidArgument( "ArrayWrite_ op's outputs size should be 1, but now is %d.", outputs.size())); VLOG(6) << "Vjp prepare call ArrayWrite_'s vjp inteface"; - pir::OpResult tensor_res = - paddle::dialect::array_read(out_grads[0][0], inputs_[2][0]); - - std::vector> res{{tensor_res}}; - if (stop_gradients[0][0]) { - res = {{}}; + pir::OpResult x_grad = + paddle::dialect::array_read(in_grads[0][0], inputs_[2][0]); + pir::OpResult zero = paddle::dialect::zeros_like(inputs_[1][0]); + paddle::dialect::array_write_(in_grads[0][0], zero, inputs_[2][0]); + std::vector> res(1); + res[0].resize(1); + if (!stop_gradients[0][0]) { + res[0][0] = x_grad; } return res; } @@ -247,27 +249,25 @@ std::vector> ArrayReadOp::Vjp( platform::errors::InvalidArgument( "Array_read op's outputs size should be 1, but now is %d.", outputs.size())); - + // x = array_read(input, i) + // out_grads[0][0] is x_grad + // out_grads[1][0] is input_array_grad PADDLE_ENFORCE_EQ( out_grads.size(), - 1, + 2, platform::errors::InvalidArgument( "Array_read op's outputs size should be 1, but now is %d.", outputs.size())); VLOG(6) << "Vjp prepare call Array_read's vjp inteface"; - paddle::dialect::DenseTensorType outgrad_type = - out_grads[0][0].type().dyn_cast(); - pir::Value new_array = paddle::dialect::create_array( - paddle::dialect::TransToPhiDataType(outgrad_type.dtype())); - pir::OpResult tensor_res = - paddle::dialect::array_write_(new_array, out_grads[0][0], inputs_[1][0]); + pir::Value array_grad_i_origin = + paddle::dialect::array_read(out_grads[1][0], inputs_[1][0]); + pir::Value array_grad_i = + paddle::dialect::add(array_grad_i_origin, out_grads[0][0]); + paddle::dialect::array_write_(out_grads[1][0], array_grad_i, inputs_[1][0]); - std::vector> res{{tensor_res}}; - if (stop_gradients[0][0]) { - res = {{}}; - } + std::vector> res; return res; } } // namespace dialect diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index eed96992a1d528..aa26c86ce31a42 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -63,17 +63,26 @@ def check_all_puts(block, inputs, outputs): def append_full_like(float_value, copy_value, value, state, backward_ops): - value_grad = paddle.full_like( - copy_value, - float_value, - dtype=copy_value.dtype, - ) - full_like_op = value_grad.get_defining_op() - full_op = full_like_op.operand_source(1).get_defining_op() + if copy_value.is_tensorarray(): + value_grad = paddle._pir_ops.create_array_like( + copy_value, + float_value, + ) + full_like_op = value_grad.get_defining_op() + backward_ops_ = [full_like_op] + else: + value_grad = paddle.full_like( + copy_value, + float_value, + dtype=copy_value.dtype, + ) + full_like_op = value_grad.get_defining_op() + full_op = full_like_op.operand_source(1).get_defining_op() + backward_ops_ = [full_like_op, full_op] update_bwdop_structure( backward_ops, state.op_to_opgrad[value.get_defining_op()], - [full_like_op, full_op], + backward_ops_, ) state.value_to_valuegrad[value] = [[value_grad]] return value_grad @@ -367,6 +376,16 @@ def inverse_sort_op(ops): return sorted_list +def inplace_net(op_list): + op_name_list = [op.name() for op in op_list] + if ( + "pd_op.array_write_" in op_name_list + or "pd_op.array_read" in op_name_list + ): + return True + return False + + def append_backward_ops( base_op, base_inputs, @@ -425,13 +444,28 @@ def return_map_value(value, map): output = map[output] return output + def return_map_value_list(grad_value, map): + output = [] + for i in range(len(grad_value)): + if grad_value[i] in map: + output.append(map[grad_value[i]]) + else: + output.append(grad_value[i]) + return output + def append_add_n(value): # value is input of more than one fwd_op, # so more than one bwd_op create input_grad, # need add sum op to accumulate gradient - add_n_value = paddle.add_n( - [item[0] for item in state.value_to_valuegrad[value]] - ) + if value.is_tensorarray(): + add_n_value = paddle._pir_ops.add_n_array( + [item[0] for item in state.value_to_valuegrad[value]] + ) + else: + add_n_value = paddle.add_n( + [item[0] for item in state.value_to_valuegrad[value]] + ) + add_n_op = add_n_value.get_defining_op() combine_op = add_n_op.operand_source(0).get_defining_op() update_bwdop_structure( @@ -446,7 +480,12 @@ def make_output_with_output_grad(op): zero_flag = [False] * op.num_results() outputs = [] output_grads = [] - for i, value in enumerate(op.results()): + if op.name() == "pd_op.array_write_": + output_list = [op.operand_source(0)] + else: + output_list = op.results() + + for i, value in enumerate(output_list): new_value = [ return_map_value(value, control_flow_value_to_copyvalue_map) ] @@ -496,9 +535,39 @@ def make_output_with_output_grad(op): outputs.append(new_value) grad_value = state.value_to_valuegrad[value][0] output_grads.append( - [bwd_value_to_block_argument_map[grad_value[0]]] - if grad_value[0] in bwd_value_to_block_argument_map - else grad_value + return_map_value_list( + grad_value, bwd_value_to_block_argument_map + ) + ) + + if op.name() == "pd_op.array_read": + value = op.operand_source(0) + while value in state.inside_value_to_outside_value_map: + value = state.inside_value_to_outside_value_map[value] + + if value in state.value_to_valuegrad: + if len(state.value_to_valuegrad[value]) > 1: + append_add_n(value) + + if ( + value not in state.value_to_valuegrad + or state.value_to_valuegrad[value] == [] + ): + append_full_like( + 0.0, + return_map_value( + value, control_flow_value_to_copyvalue_map + ), + value, + state, + backward_ops, + ) + + grad_value = state.value_to_valuegrad[value][0] + output_grads.append( + return_map_value_list( + grad_value, bwd_value_to_block_argument_map + ) ) return zero_flag, outputs, output_grads @@ -692,7 +761,11 @@ def argument_to_value(while_op): else: forward_ops = effective_forward_ops - inverse_effective_forward_ops = inverse_sort_op(forward_ops) + if inplace_net(forward_ops): + inverse_effective_forward_ops = reversed(forward_ops) + else: + inverse_effective_forward_ops = inverse_sort_op(forward_ops) + clear_effective_forward_ops = [] for op in inverse_effective_forward_ops: if op.name() != "builtin.combine" and op.name() != "builtin.split": @@ -743,7 +816,9 @@ def argument_to_value(while_op): else: # all(zero_flag) support this op has no contribution for grad # should be delete (prune sub_graph) - if len(output_grads) == 0 or all(zero_flag): + if ( + len(output_grads) == 0 or all(zero_flag) + ) and op.name() != "pd_op.while": continue if op.name() == "pd_op.if": @@ -787,17 +862,28 @@ def argument_to_value(while_op): for i, input in enumerate( get_used_external_value(while_block) ): - append_full_like( - 0.0, input, input, sub_state, backward_ops - ) + if input in sub_state.value_to_valuegrad: + if len(sub_state.value_to_valuegrad[input]) > 1: + append_add_n(input) + + if ( + input not in sub_state.value_to_valuegrad + or sub_state.value_to_valuegrad[input] == [] + ): + append_full_like( + 0.0, input, input, sub_state, backward_ops + ) + grad_value = sub_state.value_to_valuegrad[input][0] + for tmp in state.value_to_valuegrad[input]: + state.value_to_sumvaluegrad[input].append(tmp) + state.value_to_valuegrad[input] = [] output_grads.append( - [bwd_value_to_block_argument_map[grad_value[0]]] - if grad_value[0] - in bwd_value_to_block_argument_map - else grad_value + return_map_value_list( + grad_value, + bwd_value_to_block_argument_map, + ) ) - build_pipe_for_block(while_block) with dynamic_shape_prim_vjp_guard(op, inputs): input_grads = paddle.framework.core.call_vjp( @@ -953,6 +1039,11 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): block = outputs[0].get_defining_op().get_parent_block() state = State(block) + total_ops = [] + if block.parent_block is not None: + total_ops += block.parent_block.ops + total_ops += block.ops + # update no_grad_set if some value stop_gradient=True update_no_grad_set_by_stopgradient(block, no_grad_set) complete_outputs, _, backward_ops = prepare_grad_outputs( @@ -961,14 +1052,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): inputs_set = ValueSet(inputs) outputs_set = ValueSet(complete_outputs) - total_ops = [] - if block.parent_block is not None: - total_ops += block.parent_block.ops - total_ops += block.ops - effective_forward_ops, _ = prune_ops( - total_ops, inputs_set, outputs_set, no_grad_set - ) + if inplace_net(total_ops): + effective_forward_ops = total_ops + else: + effective_forward_ops, _ = prune_ops( + total_ops, inputs_set, outputs_set, no_grad_set + ) + update_no_grad_set_after_prune( total_ops, effective_forward_ops, no_grad_set, inputs, complete_outputs ) @@ -993,18 +1084,18 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set( outputs_fwd_set, inputs_fwd_set, no_grad_set, state ) + if not inplace_net(backward_ops): + _, remove_ops = prune_ops( + backward_ops, inputs_set, outputs_set, no_gradvar_set + ) - _, remove_ops = prune_ops( - backward_ops, inputs_set, outputs_set, no_gradvar_set - ) - - state.turn_map() - for bwd_op in inverse_sort_op(remove_ops): - if bwd_op.result(0) in ValueSet(grad_outputs): - continue - if bwd_op.result(0).use_empty(): - remove_op(block, bwd_op, state) - state.turn_map() + state.turn_map() + for bwd_op in inverse_sort_op(remove_ops): + if bwd_op.result(0) in ValueSet(grad_outputs): + continue + if bwd_op.result(0).use_empty(): + remove_op(block, bwd_op, state) + state.turn_map() input_grad_map = state.value_to_valuegrad diff --git a/python/paddle/tensor/array.py b/python/paddle/tensor/array.py index b618c7a0f85c62..b5cc0912650bc7 100644 --- a/python/paddle/tensor/array.py +++ b/python/paddle/tensor/array.py @@ -230,7 +230,7 @@ def array_write(x, i, array=None): if array is None: array = paddle._pir_ops.create_array(x.dtype) - array = paddle._pir_ops.array_write_(array, x, i) + paddle._pir_ops.array_write_(array, x, i) return array else: check_variable_and_dtype(i, 'i', ['int64'], 'array_write') diff --git a/test/ir/pir/test_while_api.py b/test/ir/pir/test_while_api.py index 44bb4541dc3749..337cf31970c1f6 100644 --- a/test/ir/pir/test_while_api.py +++ b/test/ir/pir/test_while_api.py @@ -177,7 +177,7 @@ def test_backward(self): ) self.assertEqual( main_program.global_block() - .ops[-3] + .ops[-1] .as_while_op() .body() .ops[-4] @@ -187,7 +187,7 @@ def test_backward(self): self.assertEqual( main_program.global_block() - .ops[-3] + .ops[-1] .as_while_op() .body() .ops[-5] diff --git a/test/legacy_test/test_array_read_write_op.py b/test/legacy_test/test_array_read_write_op.py index dbdcb7707c3939..05452a9690e2c0 100644 --- a/test/legacy_test/test_array_read_write_op.py +++ b/test/legacy_test/test_array_read_write_op.py @@ -223,18 +223,23 @@ def test_array_backward(self): dd0 = g if p.is_same(mem_array): dmem_array = g + dmem0 = paddle.tensor.array_read( + dmem_array, paddle.zeros(shape=[1], dtype='int64') + ) res = exe.run( main_program, feed={'d0': d}, - fetch_list=[mean, dd0], # dmem_array + fetch_list=[mean, dd0, dmem0], # dmem_array ) # pir not support fetch tensorarray + np.testing.assert_allclose(res[2], [0.0] * 10, rtol=1e-05) else: res = exe.run( main_program, feed={'d0': d}, fetch_list=[mean.name, d0.grad_name, mem_array.grad_name], ) + # this ans is wrong array is empty at begining ,so it no grad. np.testing.assert_allclose(res[2], [[0.1] * 10], rtol=1e-05) mean = 0.6097253 diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 44ee6383fa6abc..cda75f2bd0ef75 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -388,6 +388,7 @@ def internal_body(j, x, mem_array): inner_sum_1 = paddle.add(x=x, y=inner_sum_0) j = paddle.increment(x=j) paddle.tensor.array_write(inner_sum_1, i=j, array=mem_array) + return [j, x, mem_array] outer_data = paddle.tensor.array_read(array=data_array, i=i) @@ -409,6 +410,7 @@ def internal_body(j, x, mem_array): d2 = paddle.static.data(name='d2', shape=[10], dtype='float32') x = paddle.static.data(name='x', shape=[10], dtype='float32') d0.persistable = True + d0.stop_gradient = False d1.persistable = True d2.persistable = True x.stop_gradient = False @@ -417,9 +419,11 @@ def internal_body(j, x, mem_array): i.stop_gradient = True i.persistable = True init = paddle.zeros(shape=[10], dtype='float32') + init.stop_gradient = False mem_array = paddle.tensor.array_write(x=init, i=i) data_array = paddle.tensor.array_write(x=d0, i=i) mem_array.stop_gradient = False + data_array.stop_gradient = False mem_array.persistable = True i = paddle.increment(i) paddle.tensor.array_write(d1, i, array=data_array) @@ -443,7 +447,6 @@ def internal_body(j, x, mem_array): sum_result = paddle.tensor.array_read(array=out[3], i=j) mean = paddle.mean(sum_result) grad_list = append_backward(mean) - place = ( base.CUDAPlace(0) if core.is_compiled_with_cuda() @@ -475,6 +478,79 @@ def internal_body(j, x, mem_array): np.testing.assert_allclose(res[0], data_sum, rtol=1e-05) np.testing.assert_allclose(res[1], x_grad, rtol=1e-05) + def _test_while_with_inplace(self): + with paddle.pir_utils.IrGuard(): + + def internal_cond(i, x, mem_array): + return paddle.less_than(i, array_len) + + def internal_body(i, x, mem_array): + t0 = paddle.tensor.array_read(array=mem_array, i=i) + t1 = paddle.add(t0, x) + i = paddle.increment(i) + paddle.tensor.array_write(t1, i, array=mem_array) + return [i, x, mem_array] + + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + i = paddle.zeros(shape=[1], dtype='int64') + x = paddle.static.data(name='x', shape=[10], dtype='float32') + x.stop_gradient = False + init = paddle.zeros(shape=[10], dtype='float32') + mem_array = paddle.tensor.array_write(init, i) + mem_array.stop_gradient = False + array_len = paddle.tensor.fill_constant( + shape=[1], dtype='int64', value=3 + ) + + i, x, mem_array = paddle.static.nn.while_loop( + internal_cond, internal_body, [i, x, mem_array] + ) + + out = paddle.tensor.array_read(mem_array, i) + mean_out = paddle.mean(out) + dx, dmem_array = paddle.static.gradients( + mean_out, [x, mem_array] + ) + + j = paddle.zeros(shape=[1], dtype='int64') + dmem0 = paddle.tensor.array_read(dmem_array, j) + j = paddle.increment(j) + dmem1 = paddle.tensor.array_read(dmem_array, j) + j = paddle.increment(j) + dmem2 = paddle.tensor.array_read(mem_array, j) + j = paddle.increment(j) + dmem3 = paddle.tensor.array_read(mem_array, j) + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + exe = base.Executor(place) + + feed_x = np.ones(10).astype('float32') + + if paddle.framework.in_pir_mode(): + res = exe.run( + main_program, + feed={"x": feed_x}, + fetch_list=[out, dx], # dmem0, dmem1, dmem2, dmem3], + ) + else: + res = exe.run( + main_program, + feed={"x": feed_x}, + fetch_list=[out, dx], # dmem0, dmem1, dmem2, dmem3], + ) + + # print("out = ", res[0], [3] * 10) + # print("dx = ", res[1], [0.3] * 10) + # print("dmem0 = ", res[2], [0.0] * 10) + # print("dmem1 = ", res[3], [0.0] * 10) + # print("dmem2 = ", res[4], [0.0] * 10) + # print("dmem3 = ", res[5], [0.0] * 10) + class TestApiWhileLoopWithSwitchCase(unittest.TestCase): @compare_legacy_with_pt diff --git a/test/legacy_test/test_while_op.py b/test/legacy_test/test_while_op.py index a8d79af8a93b6c..1d053fbc18eff2 100644 --- a/test/legacy_test/test_while_op.py +++ b/test/legacy_test/test_while_op.py @@ -20,7 +20,6 @@ import paddle from paddle import base from paddle.base import core -from paddle.base.backward import append_backward from paddle.base.executor import Executor from paddle.base.framework import in_pir_mode from paddle.incubate.layers.nn import shuffle_batch @@ -87,8 +86,7 @@ def test_simple_net(self): startup_program = base.Program() with base.program_guard(main_program, startup_program): loss, sum_result = self.simple_net() - - append_backward(loss) + # append_backward(loss) cpu = core.CPUPlace() exe = Executor(cpu) From 385ec43b96e82b8cf2576eb9263e7c3684bf778b Mon Sep 17 00:00:00 2001 From: kevin Date: Mon, 8 Jan 2024 11:01:01 +0800 Subject: [PATCH 137/142] [Prim][PIR] add leaky_relu, sigmoid, instance_norm op forward prim (#60564) * hardswish op prim sink * hardswish op prim * add composite * add leaky_relu, sigmoid op forward prim * remove hardswish op forward * add instance_norm op forward prim --- .../decomp_interface_gen_op_list.py | 6 + .../pir/dialect/operator/ir/ops_backward.yaml | 1 + paddle/fluid/primitive/composite/composite.h | 109 ++++++++++++++++++ test/legacy_test/test_activation_op.py | 65 ++++++++--- test/legacy_test/test_instance_norm_op.py | 2 +- test/legacy_test/test_instance_norm_op_v2.py | 22 +++- 6 files changed, 183 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index ce4c0a1cd094fd..5a1e8f361bd626 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -24,11 +24,14 @@ "dropout", "full_like", "gelu", + "instance_norm", "layer_norm", + "leaky_relu", "mean", "pow", "relu", "rsqrt", + "sigmoid", "silu", "softmax", "sqrt", @@ -44,11 +47,14 @@ "dropout", "full_like", "gelu", + "instance_norm", "layer_norm", + "leaky_relu", "mean", "pow", "relu", "rsqrt", + "sigmoid", "silu", "softmax", "sqrt", diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index 90221982ebbddf..71124cf5593960 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -344,6 +344,7 @@ kernel : func : hardswish_grad inplace : (out_grad -> x_grad) + composite : hardswish_grad(x, out_grad, x_grad) - backward_op : hsigmoid_loss_grad forward : hsigmoid_loss (Tensor x, Tensor label, Tensor w, Tensor bias, Tensor path, Tensor code, int num_classes, bool is_sparse) -> Tensor(out), Tensor(pre_out), Tensor(w_out) diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 1ab1f33f4f5f68..cb3366229a92f4 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -528,6 +528,115 @@ Tensor gelu_decomp(const Tensor& x, bool approximate) { } } +template +Tensor sigmoid_decomp(const Tensor& x) { + auto org_dtype = x.dtype(); + Tensor x_cast = x; + + bool need_cast = is_half_dtype(org_dtype); + if (need_cast) { + x_cast = cast(x, phi::DataType::FLOAT32); + } + + // res = 1 / (1 + exp(-x)) + auto one = full(common::vectorize(x_cast.dims()), 1, x_cast.dtype()); + auto exp_tmp = exp( + full(common::vectorize(x_cast.dims()), -1, x_cast.dtype()) * x_cast); + auto res = one / (one + exp_tmp); + if (need_cast) { + return cast(res, org_dtype); + } else { + return res; + } +} + +template +Tensor leaky_relu_decomp(const Tensor& x, float negative_slope) { + auto multiply_tmp = + full(phi::vectorize(x.dims()), negative_slope, x.dtype()) * x; + if (negative_slope < 1.0) { + return maximum(x, multiply_tmp); + } else { + return minimum(x, multiply_tmp); + } +} + +template +std::tuple instance_norm_decomp( + const Tensor& x, + const paddle::optional& scale, + const paddle::optional& bias, + float epsilon) { + auto org_dtype = x.dtype(); + Tensor x_cast = x; + + bool need_cast = is_half_dtype(org_dtype); + if (need_cast) { + x_cast = cast(x, phi::DataType::FLOAT32); + } + + std::vector axis; + auto x_dim = common::vectorize(x.dims()); + for (size_t i = 2; i < x_dim.size(); i++) { + axis.push_back(static_cast(i)); + } + + // out = (x - mean(x)) / sqrt(var + epsilon)) + // var = mean((x-mean(x))^2) + auto mean_ = mean_decomp(x_cast, IntArray(axis), true); + auto difference = x_cast - mean_; + auto var_tmp1 = difference * difference; + auto variance = mean_decomp(var_tmp1, IntArray(axis), true); + auto var_tmp3 = variance + epsilon; + auto rsqrt_var = elementwise_pow( + var_tmp3, + full(common::vectorize(var_tmp3.dims()), 0.5, var_tmp3.dtype())); + auto out = difference / rsqrt_var; + + auto scale_ptr = scale.get_ptr(); + auto bias_ptr = bias.get_ptr(); + std::vector slice_shape(x_dim.size(), 1); + slice_shape[1] = x_dim[1]; + + Tensor scale_cast; + if (scale_ptr) { + if (slice_shape != scale_ptr->shape()) { + scale_cast = reshape(*scale_ptr, slice_shape); + } else { + scale_cast = *scale_ptr; + } + if (need_cast) { + scale_cast = cast(scale_cast, phi::DataType::FLOAT32); + } + out = out * scale_cast; + } + Tensor bias_cast; + if (bias_ptr) { + if (slice_shape != bias_ptr->shape()) { + bias_cast = reshape(*bias_ptr, slice_shape); + } else { + bias_cast = *bias_ptr; + } + if (need_cast) { + bias_cast = cast(bias_cast, phi::DataType::FLOAT32); + } + out = out + bias_cast; + } + + std::vector res_shape(1, -1); + auto mean_out = reshape(mean_, res_shape); + auto variance_out = reshape(1 / rsqrt_var, res_shape); + + Tensor res; + if (need_cast) { + res = cast(out, org_dtype); + } else { + res = out; + } + + return std::make_tuple(res, mean_out, variance_out); +} + } // namespace details } // namespace primitive diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 8a9379f528c1eb..7b2b70d7b515f6 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -390,7 +390,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=True, check_prim_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -411,7 +411,7 @@ def init_dtype(self): def test_check_output(self): with paddle.static.scope_guard(paddle.static.Scope()): - self.check_output(check_prim=False) + self.check_output(check_prim=False, check_prim_pir=False) def test_check_grad(self): self.check_grad( @@ -420,6 +420,7 @@ def test_check_grad(self): max_relative_error=0.006, check_prim=False, check_pir=True, + check_prim_pir=False, ) @@ -428,7 +429,9 @@ def init_dtype(self): self.dtype = np.complex128 def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) + self.check_grad( + ['X'], 'Out', check_prim=False, check_pir=True, check_prim_pir=False + ) class TestSigmoid_ZeroDim(TestSigmoid): @@ -469,7 +472,9 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_prim=True, check_pir=True) + self.check_output_with_place( + place, check_prim=True, check_pir=True, check_prim_pir=True + ) def test_check_grad(self): place = core.CUDAPlace(0) @@ -2555,7 +2560,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_pir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -3038,7 +3043,9 @@ def test_check_grad(self): else False, only_check_prim=self.if_only_check_prim(), check_pir=True, - check_prim_pir=True, + check_prim_pir=True + if self.dtype not in [np.complex64, np.complex128] + else False, ) def test_check_output(self): @@ -4832,7 +4839,11 @@ def test_check_grad(self): ) create_test_act_fp16_class(TestExpm1) create_test_act_fp16_class( - TestSigmoid, check_prim=True, enable_cinn=True, check_pir=True + TestSigmoid, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, ) create_test_act_fp16_class( TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True @@ -4929,18 +4940,24 @@ def test_check_grad(self): create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True) create_test_act_fp16_class(TestMish, check_pir=True) create_test_act_fp16_class( - TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True + TestLeakyRelu, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, +) +create_test_act_fp16_class( + TestLeakyReluAlpha1, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class( - TestLeakyReluAlpha1, check_prim=True, enable_cinn=True + TestLeakyReluAlpha2, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class( - TestLeakyReluAlpha2, check_prim=True, enable_cinn=True + TestLeakyReluAlpha3, check_prim=True, enable_cinn=True, check_prim_pir=True ) create_test_act_fp16_class( - TestLeakyReluAlpha3, check_prim=True, enable_cinn=True + TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True ) -create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True) create_test_act_fp16_class( TestRsqrt, check_prim=True, @@ -5017,7 +5034,9 @@ def test_check_grad(self): TestExpFp32_Prim, check_prim=True, check_prim_pir=True ) create_test_act_bf16_class(TestExpm1) -create_test_act_bf16_class(TestSigmoid, check_prim=True, check_pir=True) +create_test_act_bf16_class( + TestSigmoid, check_prim=True, check_pir=True, check_prim_pir=True +) create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestLogSigmoid) create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True) @@ -5089,11 +5108,21 @@ def test_check_grad(self): create_test_act_bf16_class(TestSwish) create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True) create_test_act_bf16_class(TestMish, check_pir=True) -create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True) -create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True) -create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) -create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) -create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True) +create_test_act_bf16_class( + TestLeakyRelu, check_prim=True, check_pir=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyReluAlpha1, check_prim=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyReluAlpha2, check_prim=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyReluAlpha3, check_prim=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestLeakyRelu_ZeroDim, check_prim=True, check_prim_pir=True +) create_test_act_bf16_class( TestRsqrt, check_prim=True, check_pir=True, check_prim_pir=True ) diff --git a/test/legacy_test/test_instance_norm_op.py b/test/legacy_test/test_instance_norm_op.py index c5fd7af6b48799..3ac10a9547d5c8 100644 --- a/test/legacy_test/test_instance_norm_op.py +++ b/test/legacy_test/test_instance_norm_op.py @@ -130,7 +130,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_prim=True, check_pir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): self.check_grad( diff --git a/test/legacy_test/test_instance_norm_op_v2.py b/test/legacy_test/test_instance_norm_op_v2.py index fe8e26aaec7839..90641cc20ef8df 100644 --- a/test/legacy_test/test_instance_norm_op_v2.py +++ b/test/legacy_test/test_instance_norm_op_v2.py @@ -220,7 +220,12 @@ def setUp(self): def test_check_output(self): self.check_output( - atol=self.atol, check_prim=self.check_prim, check_pir=True + atol=self.atol, + check_prim=self.check_prim, + check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def test_check_grad(self): @@ -275,7 +280,13 @@ def set_err_thre(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, atol=self.atol, check_prim=self.check_prim, check_pir=True + place, + atol=self.atol, + check_prim=self.check_prim, + check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def test_check_grad(self): @@ -350,7 +361,12 @@ def init_shape(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, check_prim=self.check_prim, check_pir=True + place, + check_prim=self.check_prim, + check_pir=True, + check_prim_pir=False + if os.getenv("FLAGS_enable_pir_in_executor") + else True, ) def test_check_grad(self): From e2b42473b2173f77c6ad05ee12b91e6ac771641b Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:11:40 +0800 Subject: [PATCH 138/142] [CINN]Add bucket context (#60549) * [CINN] Add tile tactic * [CINN] Add bind cuda tactic * [CINN] Add bucket contexts * fix group output args bug --- paddle/cinn/hlir/framework/op_lowering.h | 10 +- paddle/cinn/hlir/framework/op_lowering_impl.h | 9 +- .../hlir/framework/op_lowering_impl_base.h | 24 ++-- .../hlir/framework/pir/compilation_task.cc | 16 ++- .../hlir/framework/pir/compilation_task.h | 6 +- .../hlir/framework/pir/op_lowering_impl.cc | 85 +++++++------ .../hlir/framework/pir/op_lowering_impl.h | 13 +- .../dy_shape_group_scheduler.cc | 118 ++++++++++++------ .../group_schedule/dy_shape_group_scheduler.h | 18 ++- .../group_schedule/tactic/schedule_tactic.h | 3 +- .../ir/group_schedule/tactic/tile_tactic.cc | 33 +++++ paddle/cinn/ir/module.cc | 2 +- paddle/cinn/ir/module.h | 2 +- .../instruction/cinn_jit_instruction.cc | 2 - 14 files changed, 213 insertions(+), 128 deletions(-) diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index d4b4a78e9cd3fa..f1f1554870663f 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -47,12 +47,10 @@ class OpLowerer { group, apply_op_schedule, apply_group_schedule, apply_pass); } - std::vector< - std::pair> - BucketLower(const T& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) { + BucketLoweredFuncsWrapper BucketLower(const T& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) { return impl_->BucketLower( group, apply_op_schedule, apply_group_schedule, apply_pass); } diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.h b/paddle/cinn/hlir/framework/op_lowering_impl.h index d48cbbeb7e9b4a..5e57c607c93e1e 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl.h @@ -60,11 +60,10 @@ class OpLowererImpl : public OpLowererImplBase { bool apply_group_schedule = true, bool apply_pass = true); - std::vector> BucketLower( - const GroupPtr& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) { + BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) { CINN_NOT_IMPLEMENTED; } diff --git a/paddle/cinn/hlir/framework/op_lowering_impl_base.h b/paddle/cinn/hlir/framework/op_lowering_impl_base.h index 32bda3ca50f675..b67deedbbb7c58 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl_base.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl_base.h @@ -27,16 +27,15 @@ namespace cinn { namespace hlir { namespace framework { +struct BucketLoweredFuncsWrapper { + std::vector> + predicate2funcs; + ir::LoweredFunc infer_shape_func; +}; + template class OpLowererImplBase { public: - struct WrapLoweredFunc { - ir::LoweredFunc kernel_func; - ir::LoweredFunc infer_shape_func; - WrapLoweredFunc(ir::LoweredFunc kernel_func, - ir::LoweredFunc infer_shape_func = ir::LoweredFunc()) - : infer_shape_func(infer_shape_func), kernel_func(kernel_func) {} - }; OpLowererImplBase() = default; ~OpLowererImplBase() = default; @@ -45,11 +44,12 @@ class OpLowererImplBase { bool apply_group_schedule = true, bool apply_pass = true) = 0; - virtual std::vector> - BucketLower(const T& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true) = 0; + virtual BucketLoweredFuncsWrapper BucketLower( + const T& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true) = 0; + virtual void InsertNameGeneToScope(std::shared_ptr scope) = 0; }; diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.cc b/paddle/cinn/hlir/framework/pir/compilation_task.cc index c6d3412102c302..6be8e616005857 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.cc +++ b/paddle/cinn/hlir/framework/pir/compilation_task.cc @@ -24,16 +24,14 @@ namespace hlir { namespace framework { void GroupCompilationContext::SetLoweredFuncs( - std::vector>&& funcs) { - for (std::pair& - predicate2func : funcs) { - predicates_.push_back(predicate2func.first); - lowered_funcs_.push_back(predicate2func.second.kernel_func); - infer_shape_lowered_funcs_.push_back( - predicate2func.second.infer_shape_func); + BucketLoweredFuncsWrapper&& funcs) { + for (std::pair& predicate2func : + funcs.predicate2funcs) { + predicates_.push_back(std::move(predicate2func.first)); + lowered_funcs_.push_back(std::move(predicate2func.second)); ++func_size_; } + infer_shape_lowered_func_ = std::move(funcs.infer_shape_func); } std::string GroupCompilationContext::PrintPredicate2Funcs() const { @@ -77,7 +75,7 @@ void CompilationTask::CodegenAndJit() { for (const ir::LoweredFunc& func : context_->lowered_funcs_) { builder.AddFunction(func); } - builder.AddInferShapeFunc(context_->infer_shape_lowered_funcs_[0]); + builder.SetInferShapeFunc(context_->infer_shape_lowered_func_); ir::Module ir_module = builder.Build(); context_->backend_compiler_ = backends::Compiler::Create(context_->target_); diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.h b/paddle/cinn/hlir/framework/pir/compilation_task.h index 9e96c64694527e..e76f93d2060962 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.h +++ b/paddle/cinn/hlir/framework/pir/compilation_task.h @@ -31,9 +31,7 @@ class GroupCompilationContext { std::shared_ptr scope) : target_(target), group_(group), scope_(scope) {} - void SetLoweredFuncs( - std::vector>&& funcs); + void SetLoweredFuncs(BucketLoweredFuncsWrapper&& funcs); std::string PrintPredicate2Funcs() const; void* FuncPtr(); std::shared_ptr BackendCompiler(); @@ -48,7 +46,7 @@ class GroupCompilationContext { size_t func_size_ = 0; std::vector predicates_; std::vector lowered_funcs_; - std::vector infer_shape_lowered_funcs_; + ir::LoweredFunc infer_shape_lowered_func_; std::string host_func_name_; std::string host_code_; std::vector device_code_; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 062e5db1cc1f8c..1802a1404da0ac 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -99,17 +99,14 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, LOG(FATAL) << "Group Pattern Kind Is Unknown!"; } } - -std::vector> -OpLowererImpl::BucketLower(const GroupPtr& group, - bool apply_op_schedule, - bool apply_group_schedule, - bool apply_pass) { +BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + bool apply_pass) { // 1.Do compute, lower and schedule for each op. auto& ops = group->ops; if (ops.size() == 1 && ops[0]->name() == "custom_call") { - return {{ir::Expr(1), - pir::OpLowererImpl::WrapLoweredFunc(LowerCustomCall(group)[0])}}; + return {{{ir::Expr(1), LowerCustomCall(group)[0]}}, ir::LoweredFunc()}; } std::vector group_func_arg_tensors; std::unordered_map<::pir::Value, ir::Tensor> tensor_map; @@ -152,24 +149,29 @@ OpLowererImpl::BucketLower(const GroupPtr& group, // 3.Do post-processing, // including preparing function args and temporary variables, // applying low-level optimization passes, etc. - std::vector> cond2funcs; + std::vector scheduled_func_bodies; for (std::pair& cond2body : cond2func_bodies) { - std::vector group_func_arg_tensors_copy = - group_func_arg_tensors; - std::vector group_func_args; - std::vector funcs = - PostProcess(group, - tensor_map, - apply_op_schedule, - cond2body.second, - &group_func_arg_tensors_copy, - &group_func_args); - ir::LoweredFunc infer_shape_func = GenerateInferShapeFunc( - group, group_func_arg_tensors_copy, group_func_args); - cond2funcs.push_back({cond2body.first, {funcs[0], infer_shape_func}}); + scheduled_func_bodies.push_back(cond2body.second); + } + std::vector group_func_arg_tensors_copy = group_func_arg_tensors; + std::vector group_func_args; + std::vector funcs = PostProcess(group, + tensor_map, + apply_op_schedule, + {scheduled_func_bodies}, + &group_func_arg_tensors_copy, + &group_func_args); + CHECK_EQ(funcs.size(), cond2func_bodies.size()); + BucketLoweredFuncsWrapper funcs_wrapper; + for (int i = 0; i < funcs.size(); ++i) { + funcs_wrapper.predicate2funcs.emplace_back(cond2func_bodies[i].first, + funcs[i]); } - return cond2funcs; + funcs_wrapper.infer_shape_func = GenerateInferShapeFunc( + group, group_func_arg_tensors_copy, group_func_args); + + return funcs_wrapper; } void OpLowererImpl::InsertNameGeneToScope(std::shared_ptr scope) { @@ -300,7 +302,7 @@ std::vector OpLowererImpl::LowerMapExpr( return PostProcess(group, *tensor_map, apply_op_schedule, - ir_sch.GetModule().GetExprs()[0], + {ir_sch.GetModule().GetExprs()[0]}, group_func_arg_tensors, &group_func_args); } @@ -355,7 +357,7 @@ std::vector OpLowererImpl::LowerGroup( return PostProcess(group, tensor_map, do_op_schedule, - ir_sch.GetModule().GetExprs().at(0), + {ir_sch.GetModule().GetExprs().at(0)}, &group_func_arg_tensors, &group_func_args); } @@ -410,7 +412,7 @@ std::vector OpLowererImpl::PostProcess( const GroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, - ir::Expr func_body, + std::vector func_bodies, std::vector* group_func_arg_tensors, std::vector* group_func_args) { // 1.Prepare function args @@ -501,23 +503,28 @@ std::vector OpLowererImpl::PostProcess( } } + std::vector lowered_funcs; + for (ir::Expr func_body : func_bodies) { #ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); + optim::OptimizeExprGPU(&(func_body)); #endif - // 2.Prepare temp buffers - poly::StageMap stages; - auto temp_buffers = - lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); - // 3.Building LoweredFunc - auto func = ir::_LoweredFunc_::Make( - group->FuncName(), *group_func_args, func_body, temp_buffers); - if (!done_op_schedule) { - func->PrepareBufferCastExprs(); + // 2.Prepare temp buffers + poly::StageMap stages; + auto temp_buffers = + lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); + // 3.Building LoweredFunc + auto func = ir::_LoweredFunc_::Make( + group->FuncName(), *group_func_args, func_body, temp_buffers); + if (!done_op_schedule) { + func->PrepareBufferCastExprs(); + } + // 4.Apply low level pass + func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); + lowered_funcs.push_back(std::move(func)); } - // 4.Apply low level pass - func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - return {func}; + + return lowered_funcs; } std::vector OpLowererImpl::LowerOps( diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h index 0a9f4d4b33820a..f1ab9730a2df9b 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -70,11 +70,10 @@ class OpLowererImpl : public OpLowererImplBase { * @param apply_group_schedule Whether to schedule at group level. * @return The lowered funcs. */ - std::vector> - BucketLower(const GroupPtr& group, - bool apply_op_schedule = false, - bool apply_group_schedule = true, - bool apply_pass = true); + BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group, + bool apply_op_schedule = false, + bool apply_group_schedule = true, + bool apply_pass = true); void InsertNameGeneToScope(std::shared_ptr scope); @@ -108,7 +107,7 @@ class OpLowererImpl : public OpLowererImplBase { * @param tensor_map All tensors used for calculating the group. * @param done_op_schedule Mark whether the Op level schedule has been * applied. - * @param func_body The scheduled func body of group. + * @param func_bodies The scheduled func bodies of group. * @param group_func_arg_tensors Tensors used as the group function arguments. * @param group_func_args Arguments used as the group function arguments. * @return The lowered funcs after the post processing. @@ -117,7 +116,7 @@ class OpLowererImpl : public OpLowererImplBase { const GroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, - ir::Expr func_body, + std::vector func_bodies, std::vector* group_func_arg_tensors, std::vector* group_func_args); diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index 9f7a52d97fb178..657742e37ab421 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -25,16 +25,7 @@ namespace cinn { namespace ir { void DynamicShapeGroupScheduler::Init() { - // Only 1 bucket for test now. - schedule_context_.target = target_; - schedule_context_.output_names = OutputTensorNames(); - schedule_context_.global_master = FindGlobalMasterNode(); - schedule_context_.iter_space_info = - ConstructIterSpaceInfo(schedule_context_.global_master); - schedule_context_.bucket_info = {/* sp_lower_bound = */ 1024, - /* sp_upper_bound = */ INT_MAX, - /* rb_lower_bound = */ 64, - /* rb_upper_bound = */ INT_MAX}; + InitBuckets(); tactics_.emplace_back(new AlignIterSpaceTactic()); tactics_.emplace_back(new TileTactic()); tactics_.emplace_back(new ComputeInlineTactic()); @@ -42,43 +33,99 @@ void DynamicShapeGroupScheduler::Init() { tactics_.emplace_back(new ArrangeStorageTactic()); } +void DynamicShapeGroupScheduler::InitBuckets() { + std::unordered_set output_names = OutputTensorNames(); + ir::Expr fake_predicate = ir::LE::Make(Expr(1023), Expr(1024)); + auto InitBucket = [&](BucketInfo&& bucket_info) { + std::unique_ptr ir_sch = + std::make_unique(*ir_sch_); + std::unique_ptr schedule_block_graph = + std::make_unique(*ir_sch); + ir::ScheduleBlockNode* global_master = + FindGlobalMasterNode(schedule_block_graph); + IterativeSpaceInfo iter_space_info = ConstructIterSpaceInfo(global_master); + SymbolicPredicate sp_lower_bound_predicate = ir::GE::Make( + iter_space_info.total_sp_extent, ir::Expr(bucket_info.sp_lower_bound)); + SymbolicPredicate sp_upper_bound_predicate = ir::LT::Make( + iter_space_info.total_sp_extent, ir::Expr(bucket_info.sp_upper_bound)); + SymbolicPredicate rb_lower_bound_predicate = ir::GE::Make( + iter_space_info.total_rb_extent, ir::Expr(bucket_info.rb_lower_bound)); + SymbolicPredicate rb_upper_bound_predicate = ir::LT::Make( + iter_space_info.total_rb_extent, ir::Expr(bucket_info.rb_upper_bound)); + SymbolicPredicate sp_predicate = + ir::And::Make(sp_lower_bound_predicate, sp_upper_bound_predicate); + SymbolicPredicate rb_predicate = + ir::And::Make(rb_lower_bound_predicate, rb_upper_bound_predicate); + SymbolicPredicate predicate = ir::And::Make(sp_predicate, rb_predicate); + ScheduleContext schedule_context{output_names, + target_, + std::move(iter_space_info), + std::move(bucket_info)}; + BucketContext bucket_context{std::move(predicate), + std::move(ir_sch), + std::move(schedule_block_graph), + std::move(schedule_context)}; + bucket_contexts_.emplace_back(std::move(bucket_context)); + }; + // naive buckets + // 1. {sp_extent[1 - 1024], rb_extent[1 - 256]} + InitBucket({/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1024, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 256}); + // 2. {sp_extent[1024 - +oo], rb_extent[1 - 256]} + InitBucket({/* sp_lower_bound = */ 1024, + /* sp_upper_bound = */ INT_MAX, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 256}); + // 3. {sp_extent[1 - 1024], rb_extent[256 - +oo]} + InitBucket({/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1024, + /* rb_lower_bound = */ 256, + /* rb_upper_bound = */ INT_MAX}); + // 4. {sp_extent[1024 - +oo], rb_extent[256 - +oo]} + InitBucket({/* sp_lower_bound = */ 1024, + /* sp_upper_bound = */ INT_MAX, + /* rb_lower_bound = */ 256, + /* rb_upper_bound = */ INT_MAX}); +} + void DynamicShapeGroupScheduler::Schedule() { - ApplyTactics(); - // Fake bucket for test - ir::Expr predicate1 = ir::LE::Make(Expr(1023), Expr(1024)); - std::unique_ptr new_ir_sch1 = - std::make_unique(*ir_sch_); - ir_schs_.emplace_back(predicate1, std::move(new_ir_sch1)); + for (BucketContext& bucket_context : bucket_contexts_) { + VLOG(4) << "===========================Apply tactics on Bucket [" + << bucket_context.predicate << "]=========================="; + ApplyTactics(&bucket_context); + } } -void DynamicShapeGroupScheduler::ApplyTactics() { - schedule_block_graph_->Update(*ir_sch_); +void DynamicShapeGroupScheduler::ApplyTactics(BucketContext* bucket_context) { + bucket_context->schedule_block_graph->Update(*(bucket_context->ir_sch)); for (const auto& tactic : tactics_) { VLOG(5) << "[Start " << tactic->TacticName() << "] func body:\n" - << ir_sch_->GetModule().GetExprs().front(); + << bucket_context->ir_sch->GetModule().GetExprs().front(); auto ApplyTacticFunc = [&](ir::ScheduleBlockNode* node) { VLOG(6) << "before applying [" << tactic->TacticName() << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" - << ir_sch_->GetModule().GetExprs().front(); - tactic->Apply(ir_sch_, node->id()); + << bucket_context->ir_sch->GetModule().GetExprs().front(); + tactic->Apply(bucket_context->ir_sch.get(), node->id()); VLOG(6) << "after applying [" << tactic->TacticName() << "] on ScheduleBlockNode [" << node->id() << "] func body:\n" - << ir_sch_->GetModule().GetExprs().front(); + << bucket_context->ir_sch->GetModule().GetExprs().front(); }; - tactic->Init(&schedule_context_); - schedule_block_graph_->DFSTopoWalk(ApplyTacticFunc); - schedule_block_graph_->Update(*ir_sch_); - VLOG(5) << "[End " << tactic->TacticName() - << "] func body: " << ir_sch_->GetModule().GetExprs().front(); + tactic->Init(&(bucket_context->schedule_context)); + bucket_context->schedule_block_graph->DFSTopoWalk(ApplyTacticFunc); + bucket_context->schedule_block_graph->Update(*(bucket_context->ir_sch)); + VLOG(5) << "[End " << tactic->TacticName() << "] func body: " + << bucket_context->ir_sch->GetModule().GetExprs().front(); } } std::vector> DynamicShapeGroupScheduler::GetIRs() { std::vector> irs; - for (auto& sch_pair : ir_schs_) { - irs.emplace_back(sch_pair.first, - sch_pair.second->GetModule().GetExprs()[0]); + for (BucketContext& context : bucket_contexts_) { + irs.emplace_back(context.predicate, + context.ir_sch->GetModule().GetExprs()[0]); } return irs; } @@ -95,7 +142,7 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( std::vector iter_vars = block.As() ->schedule_block.As() ->iter_vars; - std::vector loops = ir_sch_->GetLoops(block); + std::vector loops = node->GetLoops(); std::unordered_set reduce_iter_vars = analyzer::GetReduceIterVars(block); std::unordered_map iter_var2value = @@ -184,7 +231,8 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo( return info; } -ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { +ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode( + const std::unique_ptr& schedule_block_graph) { ir::ScheduleBlockNode* master = nullptr; // 1. reduce auto FindReduce = [&](ir::ScheduleBlockNode* node) { @@ -192,7 +240,7 @@ ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { master = node; } }; - schedule_block_graph_->NodesWalk(FindReduce); + schedule_block_graph->NodesWalk(FindReduce); if (master != nullptr) { VLOG(6) << "Find the global master node: " << master->id(); return master; @@ -203,13 +251,13 @@ ir::ScheduleBlockNode* DynamicShapeGroupScheduler::FindGlobalMasterNode() { master = node; } }; - schedule_block_graph_->NodesWalk(FindBroadcast); + schedule_block_graph->NodesWalk(FindBroadcast); if (master != nullptr) { VLOG(6) << "Find the global master node: " << master->id(); return master; } // 3. end point - master = schedule_block_graph_->EndPoints().back(); + master = schedule_block_graph->EndPoints().back(); VLOG(6) << "Find the global master node: " << master->id(); return master; } diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h index 896fe86bec852d..e226059011b633 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h @@ -37,20 +37,28 @@ class DynamicShapeGroupScheduler : public GroupScheduler { std::vector> GetIRs() override; + struct BucketContext { + SymbolicPredicate predicate; + std::unique_ptr ir_sch; + std::unique_ptr schedule_block_graph; + ScheduleContext schedule_context; + }; + private: void Init(); - void ApplyTactics(); + void InitBuckets(); + + void ApplyTactics(BucketContext* bucket_context); - ir::ScheduleBlockNode* FindGlobalMasterNode(); + ir::ScheduleBlockNode* FindGlobalMasterNode( + const std::unique_ptr& schedule_block_graph); IterativeSpaceInfo ConstructIterSpaceInfo(ScheduleBlockNode* node); private: - std::vector>> - ir_schs_; + std::vector bucket_contexts_; std::vector> tactics_; - ScheduleContext schedule_context_; }; } // namespace ir diff --git a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h index 05c258b82c47ce..87c387c65d817e 100644 --- a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h @@ -58,9 +58,8 @@ struct BucketInfo { struct ScheduleContext { std::unordered_set output_names; - ScheduleBlockNode* global_master; - IterativeSpaceInfo iter_space_info; Target target; + IterativeSpaceInfo iter_space_info; BucketInfo bucket_info; }; diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc index 3cace2636f2d39..9586568f51f73e 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc @@ -22,6 +22,7 @@ void TileTactic::Init(ScheduleContext* context) { context_ = context; // fake strategy auto GetFirstFactor = [](int num) { + if (num == 1) return 1; int factor = 1; for (int i = num - 1; i >= 1; --i) { if (num % i == 0) { @@ -32,6 +33,8 @@ void TileTactic::Init(ScheduleContext* context) { bool has_rb_iter = !context_->iter_space_info.rb_space.empty(); bool has_sp_iter = !context_->iter_space_info.sp_space.empty(); + VLOG(6) << "has_sp_iter = " << has_sp_iter + << ", has_rb_iter = " << has_rb_iter; context_->iter_space_info.rb_space.clear(); context_->iter_space_info.sp_space.clear(); @@ -40,20 +43,50 @@ void TileTactic::Init(ScheduleContext* context) { context_->iter_space_info.sp_space.emplace_back( ir::Expr(context_->bucket_info.sp_lower_bound / sp_factor), IterativeSpaceInfo::AxisType::kCudaBlockX); + VLOG(6) << "sp_space: <" + << std::get<0>(context_->iter_space_info.sp_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.sp_space.back())) + << "]>"; context_->iter_space_info.sp_space.emplace_back( ir::Expr(sp_factor), has_rb_iter ? IterativeSpaceInfo::AxisType::kCudaThreadY : IterativeSpaceInfo::AxisType::kCudaThreadX); + VLOG(6) << "sp_space: <" + << std::get<0>(context_->iter_space_info.sp_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.sp_space.back())) + << "]>"; context_->iter_space_info.sp_space.emplace_back( ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + VLOG(6) << "sp_space: <" + << std::get<0>(context_->iter_space_info.sp_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.sp_space.back())) + << "]>"; } if (has_rb_iter) { context_->iter_space_info.rb_space.emplace_back( ir::Expr(context_->bucket_info.rb_lower_bound), IterativeSpaceInfo::AxisType::kCudaThreadX); + VLOG(6) << "rb_space: <" + << std::get<0>(context_->iter_space_info.rb_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.rb_space.back())) + << "]>"; context_->iter_space_info.rb_space.emplace_back( ir::Expr(-1), IterativeSpaceInfo::AxisType::kSerial); + VLOG(6) << "rb_space: <" + << std::get<0>(context_->iter_space_info.rb_space.back()) + << ", AxisType[" + << static_cast( + std::get<1>(context_->iter_space_info.rb_space.back())) + << "]>"; } } diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc index fc58e44956fe76..20298e32920fb5 100644 --- a/paddle/cinn/ir/module.cc +++ b/paddle/cinn/ir/module.cc @@ -53,7 +53,7 @@ void Module::Builder::AddPredicate(ir::Expr predicate) { module_->predicates.push_back(predicate); } -void Module::Builder::AddInferShapeFunc(ir::Expr infer_shape_func) { +void Module::Builder::SetInferShapeFunc(ir::Expr infer_shape_func) { module_->infer_shape_func = infer_shape_func; } diff --git a/paddle/cinn/ir/module.h b/paddle/cinn/ir/module.h index 9910caab42b503..160d0087a0e545 100644 --- a/paddle/cinn/ir/module.h +++ b/paddle/cinn/ir/module.h @@ -45,7 +45,7 @@ class Module : public ir::IrNodeRef { void AddFunctionWithoutOptim(const ir::LoweredFunc& func); void AddBuffer(ir::Buffer buffer); void AddPredicate(ir::Expr predicate); - void AddInferShapeFunc(ir::Expr infer_shape_func); + void SetInferShapeFunc(ir::Expr infer_shape_func); void Clear(); Target::Arch GetTargetArch(); diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 180eb4f478fa6b..d8fd3db290b331 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -50,8 +50,6 @@ class CinnJitInstruction::FnPtrImpl { } // 2. Convert arg's data about shape of Tensor to cinn_pod_value_t for (const auto& int_arg_mp : cinn_kernel_info_.int_args_map) { - func_args_.emplace_back(kernel_args[int_arg_mp.second.arg_idx]->dims().at( - int_arg_mp.second.dim_idx)); func_args_.emplace_back(static_cast( kernel_args[int_arg_mp.second.arg_idx]->dims().at( int_arg_mp.second.dim_idx))); From 41679e4c1c6395a5d9f21e654882c8261befb72a Mon Sep 17 00:00:00 2001 From: Tian Zheng Date: Mon, 8 Jan 2024 11:26:56 +0800 Subject: [PATCH 139/142] Add CUDNNv8 max pooling (#59413) * Add CUDNNv8 version of pool2d * Minor fix * Fix build failure * Remove dygraph API * Fix CI failure * Fix CI failure * Fix timeout * Fix timeout * Add comments * Minor fix --- .../pir/dialect/op_generator/ops_api_gen.py | 1 + .../op_generator/vjp_interface_black_list.py | 1 + paddle/fluid/pybind/pybind.cc | 9 + paddle/phi/api/yaml/fused_backward.yaml | 11 + paddle/phi/api/yaml/fused_ops.yaml | 15 ++ paddle/phi/infermeta/unary.cc | 31 +++ paddle/phi/infermeta/unary.h | 11 + paddle/phi/kernels/CMakeLists.txt | 4 +- paddle/phi/kernels/autotune/cache.h | 4 +- .../fusion/gpu/max_pool2d_v2_grad_kernel.cu | 255 ++++++++++++++++++ .../fusion/gpu/max_pool2d_v2_kernel.cu | 236 ++++++++++++++++ paddle/phi/kernels/gpudnn/pool_gpudnn.h | 15 ++ test/legacy_test/CMakeLists.txt | 1 + test/legacy_test/test_pool_max_op.py | 236 ++++++++++++++++ test/white_list/no_check_set_white_list.py | 1 + test/white_list/op_accuracy_white_list.py | 1 + tools/gpups_test.sh | 1 + 17 files changed, 831 insertions(+), 2 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu create mode 100644 paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index af224cb5be8ab3..314f4b343d481e 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -130,6 +130,7 @@ 'fused_dot_product_attention', 'nce', 'lars_momentum', + 'max_pool2d_v2', 'recv_v2', 'rnn_', 'row_conv', diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py index 9551bfc425ebc6..4a4c4707ac9a8c 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -29,4 +29,5 @@ 'fused_rotary_position_embedding', 'fused_bias_dropout_residual_layer_norm', 'fused_dot_product_attention', + 'max_pool2d_v2', ] diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 53df4c25034abd..32e9ffd3a5c639 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -254,6 +254,14 @@ bool IsCompiledWithCUDA() { #endif } +bool IsCompiledWithCudnnFrontend() { +#ifndef PADDLE_WITH_CUDNN_FRONTEND + return false; +#else + return true; +#endif +} + bool IsCompiledWithDISTRIBUTE() { #if !defined(PADDLE_WITH_DISTRIBUTE) return false; @@ -2124,6 +2132,7 @@ All parameter, weight, gradient are variables in Paddle. }); m.def("is_compiled_with_avx", IsCompiledWithAVX); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); + m.def("is_compiled_with_cudnn_frontend", IsCompiledWithCudnnFrontend); m.def("is_compiled_with_rocm", IsCompiledWithROCM); m.def("is_compiled_with_custom_device", IsCompiledWithCustomDevice); m.def("is_compiled_with_ipu", IsCompiledWithIPU); diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml index 649e427b25a34d..8a2a9786a837a0 100644 --- a/paddle/phi/api/yaml/fused_backward.yaml +++ b/paddle/phi/api/yaml/fused_backward.yaml @@ -51,3 +51,14 @@ func : fused_rotary_position_embedding_grad data_type : out_q_grad support_dygraph_mode : true + +- backward_op : max_pool2d_v2_grad + forward : max_pool2d_v2(Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, str data_format = "NCHW", bool global_pooling = false, bool adaptive = false) -> Tensor(out), Tensor(saved_idx) + args : (Tensor x, Tensor out, Tensor saved_idx, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, str data_format, bool global_pooling, bool adaptive) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : max_pool2d_v2_grad + param: [x, out, saved_idx, out_grad, kernel_size, strides, paddings, data_format, global_pooling, adaptive] diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 1b429fc958de7e..235ddaaacc6948 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -383,6 +383,21 @@ func : layer_norm_act_xpu data_type : x +# This op is implemented using CUDNN Frontend API, which serves as a supplement to +# legacy max pooling implementation. It shows better performance with NHWC layout and +# half precision. +- op : max_pool2d_v2 + args : (Tensor x, int[] kernel_size, int[] strides= {1, 1}, int[] paddings = {0, 0}, str data_format = "NCHW", bool global_pooling = false, bool adaptive = false) + output : Tensor(out), Tensor(saved_idx) + infer_meta : + func : MaxPoolV2InferMeta + param : [x, kernel_size, strides, paddings, data_format, global_pooling, adaptive] + kernel : + func : max_pool2d_v2 + param : [x, kernel_size, strides, paddings, data_format, global_pooling, adaptive] + intermediate: saved_idx + backward : max_pool2d_v2_grad + - op : multi_encoder_xpu args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx) output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index b1b06fdbfed715..39cec09e3db86f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2349,6 +2349,37 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, mask->set_dtype(phi::CppTypeToDataType::Type()); } +void MaxPoolV2InferMeta(const MetaTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::string& data_format, + bool global_pooling, + bool adaptive, + MetaTensor* out, + MetaTensor* saved_idx, + MetaConfig config) { + PADDLE_ENFORCE_EQ(adaptive, + false, + phi::errors::InvalidArgument( + "max_pool2d_v2 op does not support adaptive.")); + Pool2DInferMeta(x, + kernel_size, + strides, + paddings, + false, + false, + data_format, + "max", + global_pooling, + adaptive, + "EXPLICIT", + out, + config); + saved_idx->set_dims(out->dims()); + saved_idx->set_dtype(phi::CppTypeToDataType::Type()); +} + void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_dims(common::make_ddim({})); out->set_dtype(x.dtype()); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 0126b76754fef2..aaa85da1f85242 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -350,6 +350,17 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x, MetaTensor* mask, MetaConfig config = MetaConfig()); +void MaxPoolV2InferMeta(const MetaTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::string& data_format, + bool global_pooling, + bool adaptive, + MetaTensor* out, + MetaTensor* saved_idx, + MetaConfig config = MetaConfig()); + void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out); void ModeInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index f6ed266577bac5..c0ef08cb0b5ef2 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -222,7 +222,9 @@ if(NOT WITH_CUDNN_FRONTEND) "fusion/gpu/fused_scale_bias_relu_conv_bn_kernel.cu" "fusion/gpu/fused_scale_bias_add_relu_kernel.cu" "fusion/gpu/fused_dconv_drelu_dbn_kernel.cu" - "fusion/gpu/fused_dot_product_attention_op.cu") + "fusion/gpu/fused_dot_product_attention_op.cu" + "fusion/gpu/max_pool2d_v2_grad_kernel.cu" + "fusion/gpu/max_pool2d_v2_kernel.cu") endif() set(cc_search_pattern diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index fcb9058cd0a760..0554ab526d5ee0 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -61,7 +61,9 @@ enum class AlgorithmType { kDgradDreluBnBwdWeight = 16, kDbnApply = 17, kBnActWgrad = 18, - kAlgorithmCount = 19 + kPoolingForwardV8 = 19, + kPoolingBackwardV8 = 20, + kAlgorithmCount = 21 #endif }; diff --git a/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu new file mode 100644 index 00000000000000..9cd45357f0bfbe --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_grad_kernel.cu @@ -0,0 +1,255 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/funcs/pooling.h" +#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" +#include "paddle/phi/kernels/gpudnn/pool_gpudnn.h" + +PHI_DECLARE_bool(cudnn_exhaustive_search); + +namespace phi { + +template +void MaxPoolV2GradCUDNNKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& saved_idx, + const DenseTensor& dout, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::string& data_format, + bool global_pooling, + bool adaptive, + DenseTensor* dx) { + PADDLE_ENFORCE_GE(ctx.GetComputeCapability(), + 80, + phi::errors::PreconditionNotMet( + "This op only supports Ampere and later devices, " + "but got compute capability: %d.", + ctx.GetComputeCapability())); + // Additional options + bool exhaustive_search = FLAGS_cudnn_exhaustive_search; + bool deterministic = FLAGS_cudnn_deterministic; + PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, + false, + phi::errors::InvalidArgument( + "Can't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time.")); + // Allocate output tensors + ctx.template Alloc(dx); + // Update paddings + std::vector paddings_ = paddings; + std::vector kernel_size_ = kernel_size; + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + PADDLE_ENFORCE_EQ( + channel_last, + true, + phi::errors::InvalidArgument( + "NCHW layout is currently not supported for max pooling bwd.")); + const std::string padding_algorithm = "EXPLICIT"; + + auto x_dims = x.dims(); + DDim data_dims; + if (channel_last) { + data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1); + } else { + data_dims = slice_ddim(x_dims, 2, x_dims.size()); + } + funcs::UpdatePadding(&paddings_, + global_pooling, + adaptive, + padding_algorithm, + data_dims, + strides, + kernel_size_); + + const auto data_dim = data_dims.size(); + std::vector pre_padding(data_dim, 0); + std::vector post_padding(data_dim, 0); + for (size_t i = 0; i < data_dim; ++i) { + pre_padding[i] = static_cast(paddings_[2 * i]); + post_padding[i] = static_cast(paddings_[2 * i + 1]); + } + + if (global_pooling) { + funcs::UpdateKernelSize(&kernel_size_, data_dims); + } + + using helper = CudnnFrontendConvHelper; + auto kernel_size_int64 = helper::GetInt64Array(kernel_size_); + auto strides_int64 = helper::GetInt64Array(strides); + + // Create tensor descriptors + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kPoolingBackwardV8); + + T2* saved_idx_data = const_cast(saved_idx.data()); + T1* dout_data = const_cast(dout.data()); + T1* dx_data = dx->data(); + + auto uid = [](std::string name) { + const std::map _uid = { + {"saved_idx", 0}, {"dout", 1}, {"dx", 2}}; + PADDLE_ENFORCE_GT(_uid.count(name), + 0, + phi::errors::InvalidArgument( + "The tensor name %s is unknown. " + "Should be in one of [saved_idx, dout, dx].", + name)); + return _uid.at(name); + }; + + cudnnHandle_t handle = const_cast(ctx.cudnn_handle()); + auto workspace_handle = ctx.cudnn_workspace_handle(); + + auto layout = GetLayoutFromStr(data_format); + auto layout_format = phi::backends::gpu::GetCudnnTensorFormat(layout); + auto input_dtype = phi::backends::gpu::CudnnDataType::type; + auto saved_idx_dtype = CudnnIndexType::type; + + // Create plan and execute + std::vector data_ptrs({saved_idx_data, dout_data, dx_data}); + std::vector uids({uid("saved_idx"), uid("dout"), uid("dx")}); + + // Create feature vector for plan caching + cudnn_frontend::feature_vector_t feature_vector; + auto dim_x = phi::vectorize(x.dims()); + phi::autotune::BuildFeatureVector(&feature_vector, + dim_x, + kernel_size_int64, + strides_int64, + pre_padding, + post_padding, + data_format, + input_dtype, + saved_idx_dtype); + + if (plan_cache.FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache.GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + helper::ExecutePlan(handle, + &workspace_handle, + &data_ptrs, + &uids, + cached_plan->get_raw_desc(), + workspace_size); + return; + } + + auto saved_idx_desc = + helper::GetTensorDescriptor(&saved_idx, uid("saved_idx"), layout_format); + auto dout_desc = + helper::GetTensorDescriptor(&dout, uid("dout"), layout_format); + auto dx_desc = helper::GetTensorDescriptor(dx, uid("dx"), layout_format); + + // Create maxpooling descriptor + auto const nan_opt = CUDNN_NOT_PROPAGATE_NAN; + auto const mode = cudnn_frontend::cudnnResampleMode_t::CUDNN_RESAMPLE_MAXPOOL; + auto const padding_mode = + cudnn_frontend::cudnnPaddingMode_t::CUDNN_NEG_INF_PAD; + auto pool_desc = cudnn_frontend::ResampleDescBuilder_v8() + .setComputeType(CUDNN_DATA_FLOAT) + .setNanPropagation(nan_opt) + .setResampleMode(mode) + .setPaddingMode(padding_mode) + .setSpatialDim(data_dim, kernel_size_int64.data()) + .setSpatialStride(data_dim, strides_int64.data()) + .setPrePadding(data_dim, pre_padding.data()) + .setPostPadding(data_dim, post_padding.data()) + .build(); + + // Create maxpooling bwd op + auto pool_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESAMPLE_BWD_DESCRIPTOR) + .setdxDesc(dx_desc) + .setdyDesc(dout_desc) + .setidxDesc(saved_idx_desc) + .setResampleDesc(pool_desc) + .build(); + + // Create op graph + std::array ops = {&pool_op}; + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + auto plans = helper::FindExecutionPlans(&op_graph, + exhaustive_search, + deterministic, + &data_ptrs, + &uids, + handle, + &workspace_handle); + + helper::ExecutePlansAndCache(handle, + &workspace_handle, + &data_ptrs, + &uids, + &plans, + exhaustive_search, + feature_vector, + &plan_cache); +} + +template +void MaxPool2dV2GradCUDNNKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& saved_idx, + const DenseTensor& dout, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::string& data_format, + bool global_pooling, + bool adaptive, + DenseTensor* dx) { + MaxPoolV2GradCUDNNKernel(ctx, + x, + out, + saved_idx, + dout, + kernel_size, + strides, + paddings, + data_format, + global_pooling, + adaptive, + dx); +} + +} // namespace phi + +using phi::dtype::float16; + +PD_REGISTER_KERNEL(max_pool2d_v2_grad, // cuda_only + GPU, + ALL_LAYOUT, + phi::MaxPool2dV2GradCUDNNKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(2).SetDataType(phi::CppTypeToDataType::Type()); +} diff --git a/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu new file mode 100644 index 00000000000000..46cabfe8b2d857 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/max_pool2d_v2_kernel.cu @@ -0,0 +1,236 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/funcs/pooling.h" +#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h" +#include "paddle/phi/kernels/gpudnn/pool_gpudnn.h" + +PHI_DECLARE_bool(cudnn_exhaustive_search); + +namespace phi { + +template +void MaxPoolV2CUDNNKernel(const Context& ctx, + const DenseTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::string& data_format, + bool global_pooling, + bool adaptive, + DenseTensor* out, + DenseTensor* saved_idx) { + PADDLE_ENFORCE_GE(ctx.GetComputeCapability(), + 80, + phi::errors::PreconditionNotMet( + "This op only supports Ampere and later devices, " + "but got compute capability: %d.", + ctx.GetComputeCapability())); + // Additional options + bool exhaustive_search = FLAGS_cudnn_exhaustive_search; + bool deterministic = FLAGS_cudnn_deterministic; + PADDLE_ENFORCE_EQ(exhaustive_search && deterministic, + false, + phi::errors::InvalidArgument( + "Cann't set exhaustive_search True and " + "FLAGS_cudnn_deterministic True at same time.")); + // Allocate output tensors + ctx.template Alloc(out); + ctx.template Alloc(saved_idx); + // Update paddings + std::vector paddings_ = paddings; + std::vector kernel_size_ = kernel_size; + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + const std::string padding_algorithm = "EXPLICIT"; + + auto x_dims = x.dims(); + DDim data_dims; + if (channel_last) { + data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1); + } else { + data_dims = slice_ddim(x_dims, 2, x_dims.size()); + } + funcs::UpdatePadding(&paddings_, + global_pooling, + adaptive, + padding_algorithm, + data_dims, + strides, + kernel_size_); + + const auto data_dim = data_dims.size(); + std::vector pre_padding(data_dim, 0); + std::vector post_padding(data_dim, 0); + for (size_t i = 0; i < data_dim; ++i) { + pre_padding[i] = static_cast(paddings_[2 * i]); + post_padding[i] = static_cast(paddings_[2 * i + 1]); + } + + if (global_pooling) { + funcs::UpdateKernelSize(&kernel_size_, data_dims); + } + + using helper = CudnnFrontendConvHelper; + auto kernel_size_int64 = helper::GetInt64Array(kernel_size_); + auto strides_int64 = helper::GetInt64Array(strides); + + // Prepare for execution + auto& plan_cache = phi::autotune::AutoTuneCache::Instance().GetConvV8( + phi::autotune::AlgorithmType::kPoolingForwardV8); + + T1* input_data = const_cast(x.data()); + T1* output_data = out->data(); + T2* saved_idx_data = saved_idx->data(); + + cudnnHandle_t handle = const_cast(ctx.cudnn_handle()); + auto workspace_handle = ctx.cudnn_workspace_handle(); + + auto layout = GetLayoutFromStr(data_format); + auto layout_format = phi::backends::gpu::GetCudnnTensorFormat(layout); + auto input_dtype = phi::backends::gpu::CudnnDataType::type; + auto saved_idx_dtype = CudnnIndexType::type; + + // Create plan and execute + std::vector data_ptrs({input_data, output_data, saved_idx_data}); + std::vector uids({'x', 'o', 's'}); + + // Create feature vector for plan caching + cudnn_frontend::feature_vector_t feature_vector; + auto dim_x = phi::vectorize(x.dims()); + + phi::autotune::BuildFeatureVector(&feature_vector, + dim_x, + kernel_size_int64, + strides_int64, + pre_padding, + post_padding, + data_format, + input_dtype, + saved_idx_dtype); + + // Query cache and execute + if (plan_cache.FindPlan(feature_vector, handle)) { + const cudnn_frontend::ExecutionPlan* cached_plan = nullptr; + int64_t workspace_size = 0; + plan_cache.GetPlanAndWorkspaceSize( + feature_vector, &cached_plan, &workspace_size, handle); + helper::ExecutePlan(handle, + &workspace_handle, + &data_ptrs, + &uids, + cached_plan->get_raw_desc(), + workspace_size); + return; + } + + // Create tensor descriptors + auto x_desc = helper::GetTensorDescriptor(&x, 'x', layout_format); + auto out_desc = helper::GetTensorDescriptor(out, 'o', layout_format); + auto saved_idx_desc = + helper::GetTensorDescriptor(saved_idx, 's', layout_format); + + // Create maxpooling descriptor + auto const nan_opt = CUDNN_NOT_PROPAGATE_NAN; + auto const mode = cudnn_frontend::cudnnResampleMode_t::CUDNN_RESAMPLE_MAXPOOL; + auto const padding_mode = + cudnn_frontend::cudnnPaddingMode_t::CUDNN_NEG_INF_PAD; + auto pool_desc = cudnn_frontend::ResampleDescBuilder_v8() + .setComputeType(CUDNN_DATA_FLOAT) + .setNanPropagation(nan_opt) + .setResampleMode(mode) + .setPaddingMode(padding_mode) + .setSpatialDim(data_dim, kernel_size_int64.data()) + .setSpatialStride(data_dim, strides_int64.data()) + .setPrePadding(data_dim, pre_padding.data()) + .setPostPadding(data_dim, post_padding.data()) + .build(); + + // Create maxpooling op + auto pool_op = cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_RESAMPLE_FWD_DESCRIPTOR) + .setxDesc(x_desc) + .setyDesc(out_desc) + .setidxDesc(saved_idx_desc) + .setResampleDesc(pool_desc) + .build(); + + // Create op graph + std::array ops = {&pool_op}; + auto op_graph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + auto plans = helper::FindExecutionPlans(&op_graph, + exhaustive_search, + deterministic, + &data_ptrs, + &uids, + handle, + &workspace_handle); + + helper::ExecutePlansAndCache(handle, + &workspace_handle, + &data_ptrs, + &uids, + &plans, + exhaustive_search, + feature_vector, + &plan_cache); +} + +template +void MaxPool2dV2CUDNNKernel(const Context& ctx, + const DenseTensor& x, + const std::vector& kernel_size, + const std::vector& strides, + const std::vector& paddings, + const std::string& data_format, + bool global_pooling, + bool adaptive, + DenseTensor* out, + DenseTensor* saved_idx) { + // TODO(tizheng): support int8 mask + MaxPoolV2CUDNNKernel(ctx, + x, + kernel_size, + strides, + paddings, + data_format, + global_pooling, + adaptive, + out, + saved_idx); +} + +} // namespace phi + +using phi::dtype::float16; + +PD_REGISTER_KERNEL(max_pool2d_v2, // cuda_only + GPU, + ALL_LAYOUT, + phi::MaxPool2dV2CUDNNKernel, + float, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType::Type()); +} diff --git a/paddle/phi/kernels/gpudnn/pool_gpudnn.h b/paddle/phi/kernels/gpudnn/pool_gpudnn.h index d830aad6b4f4f3..cd2758109f28ce 100644 --- a/paddle/phi/kernels/gpudnn/pool_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/pool_gpudnn.h @@ -29,6 +29,21 @@ template using ScalingParamType = typename phi::backends::gpu::CudnnDataType::ScalingParamType; +template +class CudnnIndexType; + +template <> +class CudnnIndexType { + public: + static const cudnnDataType_t type = CUDNN_DATA_INT32; +}; + +template <> +class CudnnIndexType { + public: + static const cudnnDataType_t type = CUDNN_DATA_INT8; +}; + inline GPUDNNDataLayout GetLayoutFromStr(std::string data_format) { if (data_format == "NHWC") { return GPUDNNDataLayout::kNHWC; diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index ed0f40f982d23c..2d3116d5ad69b1 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1082,6 +1082,7 @@ set_tests_properties( test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 220) +set_tests_properties(test_pool_max_op PROPERTIES TIMEOUT 500) set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120) set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 1000) set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 250) diff --git a/test/legacy_test/test_pool_max_op.py b/test/legacy_test/test_pool_max_op.py index 23740d39b8ef31..f2186dee7c3399 100644 --- a/test/legacy_test/test_pool_max_op.py +++ b/test/legacy_test/test_pool_max_op.py @@ -469,5 +469,241 @@ def test_check_grad(self): create_test_bf16_class(TestCastAdaptive2d) +def skip_unit_test(): + return ( + not core.is_compiled_with_cuda() + or not core.is_compiled_with_cudnn_frontend() + or paddle.device.cuda.get_device_capability()[0] < 8 + ) + + +@unittest.skipIf( + skip_unit_test(), + "Only support Ampere or later devices; " + "Paddle should be built with WITH_CUDNN_FRONTEND=ON.", +) +class TestMaxPool2dV2Op(OpTest): + def setUp(self): + self.init_layout() + self.init_test_case() + self.init_global() + self.init_adaptive() + self.init_dtype() + + if self.is_bfloat16_op(): + input = np.random.random(self.shape).astype(np.float32) + input = convert_uint16_to_float( + convert_float_to_uint16(np.round(input * 100.0, 2)) + ) + + else: + input = np.random.random(self.shape).astype(self.dtype) + input = np.round(input * 100.0, 2) + + output, _ = self.pool_forward_naive( + input, + self.ksize, + self.strides, + self.paddings, + self.global_pool, + self.adaptive, + ) + if self.is_bfloat16_op(): + output = output.astype(np.float32) + else: + output = output.astype(self.dtype) + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'kernel_size': self.ksize, + 'data_format': self.data_format, + 'global_pooling': self.global_pool, + 'adaptive': self.adaptive, + } + + if self.data_format == 'NHWC': + input = input.transpose((0, 2, 3, 1)) + output = output.transpose((0, 2, 3, 1)) + + saved_idx = np.zeros(shape=output.shape, dtype=np.int32) + + if self.is_bfloat16_op(): + self.inputs = { + 'x': convert_float_to_uint16( + input, data_format=self.data_format + ) + } + self.outputs = { + 'out': convert_float_to_uint16( + output, data_format=self.data_format + ), + 'saved_idx': saved_idx, + } + self.inputs_fp32 = {'x': input} + + else: + self.inputs = {'x': input} + self.outputs = {'out': output, 'saved_idx': saved_idx} + + def init_layout(self): + self.data_format = "NHWC" + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place( + place, no_check_set=['saved_idx'], check_dygraph=False + ) + + def test_check_grad(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + {'x'}, + ['out'], + max_relative_error=0.05, + check_dygraph=False, + ) + + def init_test_case(self): + self.op_type = "max_pool2d_v2" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + def init_global(self): + self.global_pool = True + + def init_adaptive(self): + self.adaptive = False + + +class TestCase8(TestMaxPool2dV2Op): + def init_global(self): + self.global_pool = False + + def test_check_grad(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + {'x'}, + ['out'], + max_relative_error=0.5, + check_dygraph=False, + ) + + +class TestCase9(TestMaxPool2dV2Op): + def init_test_case(self): + self.op_type = "max_pool2d_v2" + self.pool_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + def init_global(self): + self.global_pool = True + + +class TestCase10(TestCase9): + def init_global(self): + self.global_pool = False + + +def create_test_fp16_class(parent): + class TestMaxPool2dV2FP16(parent): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place( + place, no_check_set=['saved_idx'], check_dygraph=False + ) + + def test_check_grad(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, {'x'}, ['out'], check_dygraph=False + ) + + cls_name = "{}_{}".format(parent.__name__, "FP16OP") + TestMaxPool2dV2FP16.__name__ = cls_name + globals()[cls_name] = TestMaxPool2dV2FP16 + + +create_test_fp16_class(TestMaxPool2dV2Op) +create_test_fp16_class(TestCase8) +create_test_fp16_class(TestCase9) +create_test_fp16_class(TestCase10) + + +def create_test_bf16_class(parent): + @unittest.skipIf( + skip_unit_test() or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", + ) + class TestMaxPool2dV2BF16(parent): + def init_dtype(self): + self.dtype = np.uint16 + + def get_numeric_grad(self, place, check_name): + scope = core.Scope() + self._check_grad_helper() + op = create_op( + scope, self.op_type, self.inputs, self.outputs, self.attrs + ) + return get_numeric_gradient( + place, + scope, + op, + self.inputs_fp32, + check_name, + ['out'], + delta=0.005, + ) + + def test_check_output(self): + place = core.CUDAPlace(0) + if core.is_bfloat16_supported(place): + self.check_output_with_place( + place, no_check_set=['saved_idx'], check_dygraph=False + ) + + def test_check_grad(self): + place = core.CUDAPlace(0) + numeric_grads = self.get_numeric_grad(place, 'x') + if core.is_bfloat16_supported(place): + self.check_grad_with_place( + place, + {'x'}, + ['out'], + user_defined_grads=[numeric_grads], + check_dygraph=False, + ) + + cls_name = "{}_{}".format(parent.__name__, "BF16OP") + TestMaxPool2dV2BF16.__name__ = cls_name + globals()[cls_name] = TestMaxPool2dV2BF16 + + +create_test_bf16_class(TestMaxPool2dV2Op) +create_test_bf16_class(TestCase8) +create_test_bf16_class(TestCase9) +create_test_bf16_class(TestCase10) + + if __name__ == '__main__': unittest.main() diff --git a/test/white_list/no_check_set_white_list.py b/test/white_list/no_check_set_white_list.py index 806b0891ea92e9..16bf755eecf6ef 100644 --- a/test/white_list/no_check_set_white_list.py +++ b/test/white_list/no_check_set_white_list.py @@ -39,4 +39,5 @@ 'rmsprop', 'rrelu', 'layer_norm', + 'max_pool2d_v2', ] diff --git a/test/white_list/op_accuracy_white_list.py b/test/white_list/op_accuracy_white_list.py index 5ad871e071ba46..3027f4960c050d 100644 --- a/test/white_list/op_accuracy_white_list.py +++ b/test/white_list/op_accuracy_white_list.py @@ -41,6 +41,7 @@ 'lrn', 'match_matrix_tensor', 'matmul', + 'max_pool2d_v2', 'max_pool2d_with_index', 'max_pool3d_with_index', 'minus', diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index a482de9074eac9..36675f650578a7 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -110,6 +110,7 @@ parallel_list="^init_phi_test$|\ ^test_gather_nd_op$|\ ^test_index_select_op$|\ ^test_pass_base_list$|\ +^test_pool_max_op$|\ ^test_roll_op$|\ ^test_switch_autotune$|\ ^test_tcp_store$|\ From fa1f901d9e24c03d973648583754cd6c629ad384 Mon Sep 17 00:00:00 2001 From: lijialin03 <124568209+lijialin03@users.noreply.github.com> Date: Mon, 8 Jan 2024 15:44:29 +0800 Subject: [PATCH 140/142] update lbfgs to avoid the randomness caused by paddle.dot() temporarily (#60591) * update lbfgs to avoid the randomness caused by paddle.dot() temporarily * add note --- python/paddle/optimizer/lbfgs.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/python/paddle/optimizer/lbfgs.py b/python/paddle/optimizer/lbfgs.py index 936b71b232d4d9..50065086359287 100644 --- a/python/paddle/optimizer/lbfgs.py +++ b/python/paddle/optimizer/lbfgs.py @@ -23,6 +23,14 @@ __all__ = [] +def dot(x, y): + r""" + NOTE: This is a temporary workaround for unstable result computed by `paddle.dot`, + which will be reverted when the problem is fixed." + """ + return (x * y).sum(axis=-1) + + def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None): r"""Cubic interpolation between (x1, f1, g1) and (x2, f2, g2). Use two points and their gradient to determine a cubic function and get the minimum point @@ -152,7 +160,7 @@ def _strong_wolfe( # evaluate objective and gradient using initial step loss_new, grad_new = obj_func(xk, alpha, d) ls_func_evals = 1 - gtd_new = paddle.dot(grad_new, d) + gtd_new = dot(grad_new, d) # bracket an interval containing a point satisfying the Wolfe criteria t_prev, f_prev, g_prev, gtd_prev = (0, loss, grad, gtd) @@ -205,7 +213,7 @@ def _strong_wolfe( loss_new, grad_new = obj_func(xk, alpha, d) ls_func_evals += 1 - gtd_new = grad_new.dot(d) + gtd_new = dot(grad_new, d) ls_iter += 1 # reached max number of iterations? @@ -265,7 +273,7 @@ def _strong_wolfe( # Evaluate new point loss_new, grad_new = obj_func(xk, alpha, d) ls_func_evals += 1 - gtd_new = grad_new.dot(d) + gtd_new = dot(grad_new, d) ls_iter += 1 if ( @@ -644,7 +652,7 @@ def step(self, closure): # do lbfgs update (update memory) y = flat_grad.subtract(prev_flat_grad) s = d.multiply(paddle.to_tensor(alpha, dtype=d.dtype)) - ys = y.dot(s) + ys = dot(y, s) if ys > 1e-10: # updating memory if len(old_yk) == history_size: @@ -659,7 +667,7 @@ def step(self, closure): ro.append(1.0 / ys) # update scale of initial Hessian approximation - H_diag = ys / y.dot(y) # (y*y) + H_diag = ys / dot(y, y) # (y*y) # compute the approximate (L-BFGS) inverse Hessian # multiplied by the gradient @@ -672,14 +680,14 @@ def step(self, closure): # iteration in L-BFGS loop collapsed to use just one buffer q = flat_grad.neg() for i in range(num_old - 1, -1, -1): - al[i] = old_sk[i].dot(q) * ro[i] + al[i] = dot(old_sk[i], q) * ro[i] paddle.assign(q.add(old_yk[i] * (-al[i])), q) # multiply by initial Hessian # r/d is the final direction d = r = paddle.multiply(q, H_diag) for i in range(num_old): - be_i = old_yk[i].dot(r) * ro[i] + be_i = dot(old_yk[i], r) * ro[i] paddle.assign(r.add(old_sk[i] * (al[i] - be_i)), r) if prev_flat_grad is None: @@ -700,7 +708,7 @@ def step(self, closure): alpha = learning_rate # directional derivative - gtd = flat_grad.dot(d) + gtd = dot(flat_grad, d) # directional derivative is below tolerance if gtd > -tolerance_change: From 5df9cdfdd9b8b3a3a65ff17e0db32e15435d3f88 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:11:18 +0800 Subject: [PATCH 141/142] set_pir_tests_properties for some tests (#60401) * fix * Update CMakeLists.txt * Update pir_op_test_white_list * Update pir_op_test_white_list * Update pir_op_test_white_list --- test/auto_parallel/CMakeLists.txt | 2 ++ test/mkldnn/CMakeLists.txt | 2 ++ test/white_list/pir_op_test_white_list | 40 -------------------------- 3 files changed, 4 insertions(+), 40 deletions(-) diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index ab2b09680c5ad4..b102057f9973c0 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -297,3 +297,5 @@ endif() py_test_modules(test_job_schedule_profiler_range MODULES test_job_schedule_profiler_range) + +set_pit_tests_properties() diff --git a/test/mkldnn/CMakeLists.txt b/test/mkldnn/CMakeLists.txt index 16030200222e83..4f40752c69c873 100644 --- a/test/mkldnn/CMakeLists.txt +++ b/test/mkldnn/CMakeLists.txt @@ -31,3 +31,5 @@ if(WITH_MKLDNN AND NOT WIN32) PROPERTIES TIMEOUT 300) endif() # set_tests_properties(test_flags_mkldnn_ops_on_off PROPERTIES TIMEOUT 120) + +set_pit_tests_properties() diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index d696ab19863a76..d25cbcce075fe2 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -1,6 +1,5 @@ test_accuracy_op test_activation_bf16_mkldnn_op -test_activation_mkldnn_op test_activation_op test_adadelta_op test_adagrad_op @@ -23,7 +22,6 @@ test_auc_single_pred_op test_bce_loss test_bernoulli_op test_bicubic_interp_v2_op -test_bilinear_interp_v2_mkldnn_op test_bilinear_interp_v2_op test_bilinear_tensor_product_op test_bincount_op @@ -34,7 +32,6 @@ test_box_coder_op test_broadcast_error test_broadcast_tensors_op test_c_embedding_op -test_cast_mkldnn_op test_cast_op test_channel_shuffle test_cholesky_op @@ -48,20 +45,13 @@ test_compare_reduce_op test_complex_abs test_complex_op test_complex_view_op -test_concat_bf16_mkldnn_op test_concat_int8_mkldnn_op test_concat_mkldnn_op test_concat_op test_conj_op -test_conv2d_bf16_mkldnn_op -test_conv2d_int8_mkldnn_op -test_conv2d_mkldnn_op test_conv2d_op -test_conv2d_transpose_bf16_mkldnn_op -test_conv2d_transpose_mkldnn_op test_conv2d_transpose_op test_conv2d_transpose_op_depthwise_conv -test_conv3d_mkldnn_op test_conv3d_op test_conv3d_transpose_part2_op test_crop_tensor_op @@ -72,7 +62,6 @@ test_cumprod_op test_cumsum_op test_decayed_adagrad_op test_deformable_conv_op -test_dequantize_mkldnn_op test_determinant_op test_diag_embed test_diag_v2 @@ -91,15 +80,12 @@ test_eigh_op_static_build test_eigvals_op test_eigvalsh_op test_einsum_op -test_elementwise_add_bf16_mkldnn_op test_elementwise_div_op test_elementwise_floordiv_op test_elementwise_heaviside_op test_elementwise_max_op test_elementwise_min_op test_elementwise_mod_op -test_elementwise_mul_bf16_mkldnn_op -test_elementwise_mul_onednn_op test_elementwise_mul_op test_elementwise_pow_op test_erf_op @@ -108,11 +94,8 @@ test_expand_as_v2_op test_expand_v2_op test_exponential_op test_eye_op -test_fc_bf16_mkldnn_op -test_fc_mkldnn_op test_fill_any_like_op test_fill_any_op -test_fill_constant_batch_size_like test_fill_constant_op test_fill_diagonal_tensor_op test_flatten_contiguous_range_op @@ -130,12 +113,6 @@ test_fused_fc_elementwise_layernorm_op test_fused_feedforward_op test_fused_gate_attention_op test_fused_multihead_matmul_op -test_fusion_gru_bf16_mkldnn_op -test_fusion_gru_int8_mkldnn_op -test_fusion_gru_mkldnn_op -test_fusion_lstm_bf16_mkldnn_op -test_fusion_lstm_int8_mkldnn_op -test_fusion_lstm_mkldnn_op test_fusion_seqexpand_concat_fc_op test_fusion_transpose_flatten_concat_op test_gather_nd_op @@ -176,7 +153,6 @@ test_linear_interp_v2_op test_linspace test_log_loss_op test_log_softmax -test_log_softmax_mkldnn_op test_logcumsumexp_op test_logit_op test_logspace @@ -184,13 +160,10 @@ test_logsumexp test_lookup_table_v2_bf16_op test_lookup_table_v2_op test_lookup_table_v2_op_static_build -test_lrn_mkldnn_op test_lu_op test_lu_unpack_op test_margin_cross_entropy_op test_masked_select_op -test_matmul_bf16_mkldnn_op -test_matmul_mkldnn_op test_matmul_v2_op test_matmul_v2_op_static_build test_matrix_nms_op @@ -202,17 +175,14 @@ test_memcpy_op test_meshgrid_op test_mode_op test_momentum_op -test_mul_int8_mkldnn_op test_mul_op test_multi_dot_op test_multi_forward -test_multi_gru_mkldnn_op test_multinomial_op test_multiplex_op test_mv_op test_nanmedian test_nce -test_nearest_interp_mkldnn_op test_nearest_interp_v2_op test_nextafter_op test_nll_loss @@ -230,8 +200,6 @@ test_pass_quantization test_pixel_shuffle_op test_poisson_op test_polygamma_op -test_pool2d_int8_mkldnn_op -test_pool2d_mkldnn_op test_pool2d_op test_pool_max_op test_prelu_mkldnn_op @@ -240,17 +208,13 @@ test_prior_box_op test_psroi_pool_op test_put_along_axis_op test_qr_op -test_quantize_mkldnn_op test_randint_op test_randperm_op test_range -test_reduce_mkldnn_op test_reduce_op test_reduce_op_static_build test_repeat_interleave_op -test_requantize_mkldnn_op test_reshape_bf16_op -test_reshape_mkldnn_op test_reshape_op test_reverse_op test_roi_align_op @@ -258,7 +222,6 @@ test_roi_pool_op test_roll_op test_row_conv_op test_rrelu_op -test_scale_mkldnn_op test_scale_op test_scatter_nd_op test_scatter_op @@ -276,9 +239,7 @@ test_sigmoid_cross_entropy_with_logits_op test_sign_op test_size_op test_slice_op -test_softmax_bf16_mkldnn_op test_softmax_mask_fuse_upper_triangle_op -test_softmax_mkldnn_op test_softmax_op test_solve_op test_sparse_momentum_op @@ -288,7 +249,6 @@ test_split_mkldnn_op test_split_op test_squared_l2_norm_op test_squeeze2_op -test_sum_bf16_mkldnn_op test_sum_mkldnn_op test_svd_op test_take_along_axis_op From 5bb661dd04a8d6ce2e3ce29a08e882e2f2e96c63 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Mon, 8 Jan 2024 16:11:39 +0800 Subject: [PATCH 142/142] Add tests to whitelist (#60522) * fix * add --- test/white_list/pir_op_test_white_list | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/white_list/pir_op_test_white_list b/test/white_list/pir_op_test_white_list index d25cbcce075fe2..7e652790bdece2 100644 --- a/test/white_list/pir_op_test_white_list +++ b/test/white_list/pir_op_test_white_list @@ -6,6 +6,7 @@ test_adagrad_op test_adagrad_op_static_build test_adamax_op test_addmm_op +test_affine_grid_op test_allclose_op test_amp_check_finite_and_scale_op test_angle_op @@ -201,6 +202,7 @@ test_pixel_shuffle_op test_poisson_op test_polygamma_op test_pool2d_op +test_pool3d_op test_pool_max_op test_prelu_mkldnn_op test_prelu_op