From 55188f8f64a47ff515d5b8281d68ceafae5de655 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Tue, 23 May 2023 11:38:12 +0000 Subject: [PATCH 1/4] Fix custom pass with empty type --- paddle/fluid/framework/ir/generate_pass.cc | 8 ++++++-- paddle/fluid/framework/ir/generate_pass.h | 6 ++++-- paddle/fluid/framework/ir/pass.h | 5 +++-- paddle/fluid/pybind/pybind.cc | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 61c6ce5757aa1..5ba8c0a5ab175 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -420,13 +420,17 @@ GraphPatternDetector::handle_t GetGenerateRewrite( return handler; } -GeneratePass::GeneratePass(const std::string& binary_str) { +GeneratePass::GeneratePass(const std::string& pass_type, + const std::string& binary_str) { + RegisterType(pass_type); multi_pass_desc_.ParseFromString(binary_str); VerifyDesc(); } -GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) +GeneratePass::GeneratePass(const std::string& pass_type, + const proto::MultiPassDesc& multi_pass_desc) : multi_pass_desc_(multi_pass_desc) { + RegisterType(pass_type); VerifyDesc(); } diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index 192c963cfddcb..c20f51c6f0792 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -24,9 +24,11 @@ namespace ir { class GeneratePass : public Pass { public: // from binary_str - explicit GeneratePass(const std::string& binary_str); + explicit GeneratePass(const std::string& pass_type, + const std::string& binary_str); // from PassDesc/MultiPassDesc - explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc); + explicit GeneratePass(const std::string& pass_type, + const proto::MultiPassDesc& multi_pass_desc); protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index c7f76cb0bfdd5..1f59466e1cd80 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -185,6 +185,9 @@ class Pass { // Pass must be placed after this Pass. virtual void CheckPrevPass() const {} + protected: + void RegisterType(const std::string &type) { type_ = type; } + private: template friend struct PassRegistrar; @@ -207,8 +210,6 @@ class Pass { attrs_.insert(default_attr_values.begin(), default_attr_values.end()); } - void RegisterType(const std::string &type) { type_ = type; } - mutable bool applied_{false}; std::string type_; std::unordered_set required_pass_attrs_; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 59052a40cbc5a..8c5d32e8f845d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2276,7 +2276,7 @@ All parameter, weight, gradient are variables in Paddle. py::gil_scoped_acquire guard; std::unique_ptr pass( new framework::ir::GeneratePass( - py::cast(callable()))); + pass_type, py::cast(callable()))); return pass; }); }); From c0f0ee5eea55ad1f116d994d6d9c70de2d67a961 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Wed, 24 May 2023 02:22:41 +0000 Subject: [PATCH 2/4] update --- paddle/fluid/framework/ir/generate_pass.cc | 8 ++++---- paddle/fluid/framework/ir/generate_pass.h | 8 ++++---- paddle/fluid/pybind/pybind.cc | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 5ba8c0a5ab175..0088312c7b98d 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -420,15 +420,15 @@ GraphPatternDetector::handle_t GetGenerateRewrite( return handler; } -GeneratePass::GeneratePass(const std::string& pass_type, - const std::string& binary_str) { +GeneratePass::GeneratePass(const std::string& binary_str, + const std::string& pass_type) { RegisterType(pass_type); multi_pass_desc_.ParseFromString(binary_str); VerifyDesc(); } -GeneratePass::GeneratePass(const std::string& pass_type, - const proto::MultiPassDesc& multi_pass_desc) +GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc, + const std::string& pass_type) : multi_pass_desc_(multi_pass_desc) { RegisterType(pass_type); VerifyDesc(); diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index c20f51c6f0792..3a9d0f1efa71e 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -24,11 +24,11 @@ namespace ir { class GeneratePass : public Pass { public: // from binary_str - explicit GeneratePass(const std::string& pass_type, - const std::string& binary_str); + explicit GeneratePass(const std::string& binary_str, + const std::string& pass_type = ""); // from PassDesc/MultiPassDesc - explicit GeneratePass(const std::string& pass_type, - const proto::MultiPassDesc& multi_pass_desc); + explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc, + const std::string& pass_type = ""); protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8c5d32e8f845d..7e09266271ca7 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2275,8 +2275,8 @@ All parameter, weight, gradient are variables in Paddle. pass_type, [pass_type, callable]() { py::gil_scoped_acquire guard; std::unique_ptr pass( - new framework::ir::GeneratePass( - pass_type, py::cast(callable()))); + new framework::ir::GeneratePass(py::cast(callable()), + pass_type)); return pass; }); }); From b9de760ce431d326cb7d426165aa5d2e50878fcd Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Wed, 24 May 2023 02:43:22 +0000 Subject: [PATCH 3/4] fix pg_custom --- .../custom_device_common_op_registry.cc | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index a75106231e00e..3aa208628f366 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -98,8 +98,28 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { auto task = pg->AllGather(in_tensor, out_tensor); task->Wait(); } else { - PADDLE_THROW(phi::errors::Unavailable( - "CustomDevice c_concat only support ProcessGroup")); + auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType()) + .Get(rid, place); + PADDLE_ENFORCE_EQ( + nranks, + comm->nranks(), + platform::errors::InvalidArgument( + "nranks: %s should equal to %s", nranks, comm->nranks())); + + int64_t send_numel = x->numel(); + const T* send_buff = x->data(); + T* recv_buff = temp_out.data(); + // should ExecutionContext for calc stream. + auto& stream = *reinterpret_cast( + ctx.device_context()) + .GetStream(); + phi::DeviceManager::CCLAllGather(place.GetDeviceType(), + send_buff, + recv_buff, + send_numel, + phi::ccl::ToCCLDataType(x.dtype()), + comm->comm(), + stream); } std::vector inputs; int axis = x->dims().size() - 1; From e48fa3d2069051b83520af9b1cc62fa17da7d625 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Wed, 24 May 2023 03:03:04 +0000 Subject: [PATCH 4/4] update --- .../custom_device_common_op_registry.cc | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index 3aa208628f366..d5ae2f84b4ed1 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -60,6 +60,8 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { int nranks = ctx.Attr("nranks"); int rank = ctx.Attr("rank"); int rid = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_GE(rank, 0, platform::errors::PreconditionNotMet( @@ -110,16 +112,15 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { const T* send_buff = x->data(); T* recv_buff = temp_out.data(); // should ExecutionContext for calc stream. - auto& stream = *reinterpret_cast( - ctx.device_context()) - .GetStream(); - phi::DeviceManager::CCLAllGather(place.GetDeviceType(), - send_buff, - recv_buff, - send_numel, - phi::ccl::ToCCLDataType(x.dtype()), - comm->comm(), - stream); + auto& stream = *dev_ctx.GetStream(); + phi::DeviceManager::CCLAllGather( + place.GetDeviceType(), + reinterpret_cast(const_cast(send_buff)), + recv_buff, + send_numel, + phi::ccl::ToCCLDataType(x->dtype()), + comm->comm(), + stream); } std::vector inputs; int axis = x->dims().size() - 1;