Skip to content

Commit

Permalink
Fix the custom pass with empty type (#54065)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored May 25, 2023
1 parent 23baa8c commit 43d6bdc
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 10 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/framework/ir/generate_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,17 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
return handler;
}

GeneratePass::GeneratePass(const std::string& binary_str) {
GeneratePass::GeneratePass(const std::string& binary_str,
const std::string& pass_type) {
RegisterType(pass_type);
multi_pass_desc_.ParseFromString(binary_str);
VerifyDesc();
}

GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc)
GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc,
const std::string& pass_type)
: multi_pass_desc_(multi_pass_desc) {
RegisterType(pass_type);
VerifyDesc();
}

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/generate_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ namespace ir {
class GeneratePass : public Pass {
public:
// from binary_str
explicit GeneratePass(const std::string& binary_str);
explicit GeneratePass(const std::string& binary_str,
const std::string& pass_type = "");
// from PassDesc/MultiPassDesc
explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc);
explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc,
const std::string& pass_type = "");

protected:
void ApplyImpl(Graph* graph) const override;
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/ir/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ class Pass {
// Pass must be placed after this Pass.
virtual void CheckPrevPass() const {}

protected:
void RegisterType(const std::string &type) { type_ = type; }

private:
template <typename PassType>
friend struct PassRegistrar;
Expand All @@ -207,8 +210,6 @@ class Pass {
attrs_.insert(default_attr_values.begin(), default_attr_values.end());
}

void RegisterType(const std::string &type) { type_ = type; }

mutable bool applied_{false};
std::string type_;
std::unordered_set<std::string> required_pass_attrs_;
Expand Down
25 changes: 23 additions & 2 deletions paddle/fluid/operators/custom_device_common_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
int nranks = ctx.Attr<int>("nranks");
int rank = ctx.Attr<int>("rank");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();

PADDLE_ENFORCE_GE(rank,
0,
platform::errors::PreconditionNotMet(
Expand Down Expand Up @@ -98,8 +100,27 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
PADDLE_THROW(phi::errors::Unavailable(
"CustomDevice c_concat only support ProcessGroup"));
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));

int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
// should ExecutionContext for calc stream.
auto& stream = *dev_ctx.GetStream();
phi::DeviceManager::CCLAllGather(
place.GetDeviceType(),
reinterpret_cast<void*>(const_cast<T*>(send_buff)),
recv_buff,
send_numel,
phi::ccl::ToCCLDataType(x->dtype()),
comm->comm(),
stream);
}
std::vector<phi::DenseTensor> inputs;
int axis = x->dims().size() - 1;
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2275,8 +2275,8 @@ All parameter, weight, gradient are variables in Paddle.
pass_type, [pass_type, callable]() {
py::gil_scoped_acquire guard;
std::unique_ptr<framework::ir::Pass> pass(
new framework::ir::GeneratePass(
py::cast<std::string>(callable())));
new framework::ir::GeneratePass(py::cast<std::string>(callable()),
pass_type));
return pass;
});
});
Expand Down

0 comments on commit 43d6bdc

Please sign in to comment.