Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ActorMsg::user_data_ #9762

Merged
merged 12 commits into from
Jan 28, 2023
28 changes: 12 additions & 16 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,26 @@ IBVerbsCommNet::~IBVerbsCommNet() {
}

void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
ActorMsg new_msg = msg;
IBVerbsActorMsgWrapper msg_wrapper;
msg_wrapper.msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
CHECK_EQ(msg.user_data_size(), 0);
auto* mem_desc = reinterpret_cast<IBVerbsMemDesc*>(msg.regst()->comm_net_token());
CHECK(mem_desc != nullptr);
IBVerbsCommNetRMADesc rma_desc{};
rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
rma_desc.mem_size = mem_desc->mem_size();
rma_desc.mr_rkey = mem_desc->mr()->rkey;
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
msg_wrapper.rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
msg_wrapper.rma_desc.mem_size = mem_desc->mem_size();
msg_wrapper.rma_desc.mr_rkey = mem_desc->mr()->rkey;
}
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
qp_vec_.at(dst_machine_id)->PostSendRequest(msg_wrapper);
}

void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
void IBVerbsCommNet::RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper) {
ActorMsg new_msg = msg_wrapper.msg;
if (msg_wrapper.msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
reinterpret_cast<uint64_t>(msg.regst()))];
auto& desc = remote_regst2rma_desc_[std::make_pair(
msg_wrapper.msg.src_actor_id(), reinterpret_cast<uint64_t>(msg_wrapper.msg.regst()))];
if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); }
CHECK_EQ(msg.user_data_size(), sizeof(IBVerbsCommNetRMADesc));
std::memcpy(desc.get(), msg.user_data(), sizeof(IBVerbsCommNetRMADesc));
*desc = msg_wrapper.rma_desc;
new_msg.set_comm_net_token(desc.get());
}
Singleton<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
Expand Down
8 changes: 1 addition & 7 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,13 @@ limitations under the License.

namespace oneflow {

struct IBVerbsCommNetRMADesc {
uint64_t mem_ptr;
uint64_t mem_size;
uint32_t mr_rkey;
};

class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);
~IBVerbsCommNet();

void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void RecvActorMsg(const ActorMsg& msg);
void RecvActorMsg(const IBVerbsActorMsgWrapper& msg_wrapper);

private:
friend class Singleton<IBVerbsCommNet>;
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
}
}

void IBVerbsQP::PostSendRequest(const ActorMsg& msg) {
void IBVerbsQP::PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper) {
ActorMsgMR* msg_mr = GetOneSendMsgMRFromBuf();
msg_mr->set_msg(msg);
msg_mr->set_msg(msg_wrapper);
WorkRequestId* wr_id = NewWorkRequestId();
wr_id->msg_mr = msg_mr;
ibv_send_wr wr{};
Expand Down
19 changes: 15 additions & 4 deletions oneflow/core/comm_network/ibverbs/ibverbs_qp.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,30 @@ limitations under the License.

namespace oneflow {

struct IBVerbsCommNetRMADesc {
uint64_t mem_ptr;
uint64_t mem_size;
uint32_t mr_rkey;
};

struct IBVerbsActorMsgWrapper final {
ActorMsg msg;
IBVerbsCommNetRMADesc rma_desc;
};

class ActorMsgMR final {
public:
OF_DISALLOW_COPY_AND_MOVE(ActorMsgMR);
ActorMsgMR() = delete;
ActorMsgMR(ibv_pd* pd) { mem_desc_.reset(new IBVerbsMemDesc(pd, &msg_, sizeof(msg_))); }
~ActorMsgMR() { mem_desc_.reset(); }

const ActorMsg& msg() const { return msg_; }
void set_msg(const ActorMsg& val) { msg_ = val; }
const IBVerbsActorMsgWrapper& msg() const { return msg_; }
void set_msg(const IBVerbsActorMsgWrapper& val) { msg_ = val; }
const IBVerbsMemDesc& mem_desc() const { return *mem_desc_; }

private:
ActorMsg msg_;
IBVerbsActorMsgWrapper msg_;
std::unique_ptr<IBVerbsMemDesc> mem_desc_;
};

Expand Down Expand Up @@ -64,7 +75,7 @@ class IBVerbsQP final {

void PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, const IBVerbsMemDesc& local_mem,
void* read_id);
void PostSendRequest(const ActorMsg& msg);
void PostSendRequest(const IBVerbsActorMsgWrapper& msg_wrapper);

void ReadDone(WorkRequestId*);
void SendDone(WorkRequestId*);
Expand Down
11 changes: 0 additions & 11 deletions oneflow/core/lazy/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,6 @@ int64_t ActorMsg::eord_regst_desc_id() const {
return eord_regst_desc_id_;
}

void ActorMsg::AddUserData(uint8_t size, const void* data) {
CHECK_EQ(user_data_size_, 0);
CHECK_LE(size, kActorMsgUserDataMaxSize);
user_data_size_ = size;
std::memcpy(user_data_, data, size);
}

uint8_t ActorMsg::user_data_size() const { return user_data_size_; }

const void* ActorMsg::user_data() const { return user_data_; }

bool ActorMsg::IsDataRegstMsgToConsumer() const {
return msg_type_ == ActorMsgType::kRegstMsg && regst_wrapper_.is_data_regst_to_consumer;
}
Expand Down
7 changes: 0 additions & 7 deletions oneflow/core/lazy/actor/actor_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ enum class ActorCmd {

enum class ActorMsgType : int8_t { kRegstMsg = 0, kEordMsg, kCmdMsg };

constexpr uint8_t kActorMsgUserDataMaxSize = 32;

class ActorMsg final {
public:
ActorMsg() = default;
Expand All @@ -54,9 +52,6 @@ class ActorMsg final {
void set_comm_net_token(void* token);
bool has_sole_empty_blob() const;
int64_t eord_regst_desc_id() const;
void AddUserData(uint8_t size, const void* data);
uint8_t user_data_size() const;
const void* user_data() const;
bool IsDataRegstMsgToConsumer() const;
int64_t comm_net_sequence_number() const;
void set_comm_net_sequence_number(int64_t sequence_number);
Expand Down Expand Up @@ -91,8 +86,6 @@ class ActorMsg final {
int64_t eord_regst_desc_id_;
};
ActorMsgType msg_type_;
uint8_t user_data_size_;
unsigned char user_data_[kActorMsgUserDataMaxSize];
};

template<typename StreamT>
Expand Down