From 87985be8fbe0164a3f526af929f024fe88c9b491 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 13 Aug 2020 06:14:50 +0800 Subject: [PATCH 1/2] Use shared ptr for net link. --- src/allreduce_base.cc | 72 ++++++++++++++++------------- src/allreduce_base.h | 8 ++-- src/allreduce_robust.cc | 100 ++++++++++++++++++++-------------------- src/allreduce_robust.h | 2 +- 4 files changed, 94 insertions(+), 88 deletions(-) diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index a5b199df..82bfb8f5 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -118,7 +118,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) { bool AllreduceBase::Shutdown(void) { try { for (size_t i = 0; i < all_links.size(); ++i) { - all_links[i].sock.Close(); + all_links[i]->sock.Close(); } all_links.clear(); tree_links.plinks.clear(); @@ -321,10 +321,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { // send over good links std::vector good_link; for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) { - good_link.push_back(static_cast(all_links[i].rank)); + if (!all_links[i]->sock.BadSocket()) { + good_link.push_back(static_cast(all_links[i]->rank)); } else { - if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); + if (!all_links[i]->sock.IsClosed()) all_links[i]->sock.Close(); } } int ngood = static_cast(good_link.size()); @@ -340,7 +340,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { sizeof(num_accept), "ReConnectLink failure 8"); num_error = 0; for (int i = 0; i < num_conn; ++i) { - LinkRecord r; + auto r = std::make_shared(); int hport, hrank; std::string hname; tracker.RecvStr(&hname); @@ -349,24 +349,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); - r.sock.Create(); - if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { + r->sock.Create(); + if (!r->sock.Connect(utils::SockAddr(hname.c_str(), hport))) { num_error += 1; - r.sock.Close(); + r->sock.Close(); continue; } - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), + Assert(r->sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 12"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), + Assert(r->sock.RecvAll(&r->rank, sizeof(r->rank)) == sizeof(r->rank), "ReConnectLink failure 13"); - utils::Check(hrank == r.rank, + utils::Check(hrank == r->rank, "ReConnectLink failure, link rank inconsistent"); bool match = false; for (size_t i = 0; i < all_links.size(); ++i) { - if (all_links[i].rank == hrank) { - Assert(all_links[i].sock.IsClosed(), + if (all_links[i]->rank == hrank) { + Assert(all_links[i]->sock.IsClosed(), "Override a link that is active"); - all_links[i].sock = r.sock; + all_links[i]->sock = r->sock; match = true; break; } @@ -383,23 +383,25 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { tracker.Close(); // listen to incoming links for (int i = 0; i < num_accept; ++i) { - LinkRecord r; - r.sock = sock_listen.Accept(); - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), + auto r = std::make_shared(); + r->sock = sock_listen.Accept(); + Assert(r->sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 15"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), + Assert(r->sock.RecvAll(&r->rank, sizeof(r->rank)) == sizeof(r->rank), "ReConnectLink failure 15"); bool match = false; for (size_t i = 0; i < all_links.size(); ++i) { - if (all_links[i].rank == r.rank) { - utils::Assert(all_links[i].sock.IsClosed(), + if (all_links[i]->rank == r->rank) { + utils::Assert(all_links[i]->sock.IsClosed(), "Override a link that is active"); - all_links[i].sock = r.sock; + all_links[i]->sock = r->sock; match = true; break; } } - if (!match) all_links.push_back(r); + if (!match) { + all_links.push_back(r); + } } sock_listen.Close(); this->parent_index = -1; @@ -407,28 +409,32 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { tree_links.plinks.clear(); int tcpNoDelay = 1; for (size_t i = 0; i < all_links.size(); ++i) { - utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); + utils::Assert(!all_links[i]->sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive - all_links[i].sock.SetNonBlock(true); - all_links[i].sock.SetKeepAlive(true); + all_links[i]->sock.SetNonBlock(true); + all_links[i]->sock.SetKeepAlive(true); if (rabit_enable_tcp_no_delay) { - setsockopt(all_links[i].sock, IPPROTO_TCP, + setsockopt(all_links[i]->sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcpNoDelay), sizeof(tcpNoDelay)); } - if (tree_neighbors.count(all_links[i].rank) != 0) { - if (all_links[i].rank == parent_rank) { + if (tree_neighbors.count(all_links[i]->rank) != 0) { + if (all_links[i]->rank == parent_rank) { parent_index = static_cast(tree_links.plinks.size()); } - tree_links.plinks.push_back(&all_links[i]); + tree_links.plinks.push_back(all_links[i]); + } + if (all_links[i]->rank == prev_rank) { + ring_prev = all_links[i]; + } + if (all_links[i]->rank == next_rank) { + ring_next = all_links[i]; } - if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; - if (all_links[i].rank == next_rank) ring_next = &all_links[i]; } Assert(parent_rank == -1 || parent_index != -1, "cannot find parent in the link"); - Assert(prev_rank == -1 || ring_prev != NULL, + Assert(prev_rank == -1 || ring_prev != nullptr, "cannot find prev ring in the link"); - Assert(next_rank == -1 || ring_next != NULL, + Assert(next_rank == -1 || ring_next != nullptr, "cannot find next ring in the link"); return true; } catch (const std::exception& e) { diff --git a/src/allreduce_base.h b/src/allreduce_base.h index c7ef638f..3ce7de3a 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -401,11 +401,11 @@ class AllreduceBase : public IEngine { * but takes reference instead of space */ struct RefLinkVector { - std::vector plinks; + std::vector> plinks; inline LinkRecord &operator[](size_t i) { return *plinks[i]; } - inline size_t size(void) const { + inline size_t size() const { return plinks.size(); } }; @@ -536,13 +536,13 @@ class AllreduceBase : public IEngine { // rank of parent node, can be -1 int parent_rank; // sockets of all links this connects to - std::vector all_links; + std::vector> all_links; // used to record the link where things goes wrong LinkRecord *err_link; // all the links in the reduction tree connection RefLinkVector tree_links; // pointer to links in the ring - LinkRecord *ring_prev, *ring_next; + std::shared_ptr ring_prev, ring_next; //----- meta information----- // list of enviroment variables that are of possible interest std::vector env_vars; diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index de962055..b83fbbc5 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -568,30 +568,30 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { // number of links const int nlink = static_cast(all_links.size()); for (int i = 0; i < nlink; ++i) { - all_links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); - all_links[i].ResetSize(); + all_links[i]->InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size); + all_links[i]->ResetSize(); } // read and discard data from all channels until pass mark while (true) { for (int i = 0; i < nlink; ++i) { - if (all_links[i].sock.BadSocket()) continue; - if (all_links[i].size_write == 0) { + if (all_links[i]->sock.BadSocket()) continue; + if (all_links[i]->size_write == 0) { char sig = kOOBReset; - ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig), MSG_OOB); + ssize_t len = all_links[i]->sock.Send(&sig, sizeof(sig), MSG_OOB); // error will be filtered in next loop - if (len == sizeof(sig)) all_links[i].size_write = 1; + if (len == sizeof(sig)) all_links[i]->size_write = 1; } - if (all_links[i].size_write == 1) { + if (all_links[i]->size_write == 1) { char sig = kResetMark; - ssize_t len = all_links[i].sock.Send(&sig, sizeof(sig)); - if (len == sizeof(sig)) all_links[i].size_write = 2; + ssize_t len = all_links[i]->sock.Send(&sig, sizeof(sig)); + if (len == sizeof(sig)) all_links[i]->size_write = 2; } } utils::PollHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { - if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) { - rsel.WatchWrite(all_links[i].sock); finished = false; + if (all_links[i]->size_write != 2 && !all_links[i]->sock.BadSocket()) { + rsel.WatchWrite(all_links[i]->sock); finished = false; } } if (finished) break; @@ -599,56 +599,56 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { rsel.Poll(); } for (int i = 0; i < nlink; ++i) { - if (!all_links[i].sock.BadSocket()) { - utils::PollHelper::WaitExcept(all_links[i].sock); + if (!all_links[i]->sock.BadSocket()) { + utils::PollHelper::WaitExcept(all_links[i]->sock); } } while (true) { utils::PollHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { - if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { - rsel.WatchRead(all_links[i].sock); finished = false; + if (all_links[i]->size_read == 0 && !all_links[i]->sock.BadSocket()) { + rsel.WatchRead(all_links[i]->sock); finished = false; } } if (finished) break; rsel.Poll(); for (int i = 0; i < nlink; ++i) { - if (all_links[i].sock.BadSocket()) continue; - if (all_links[i].size_read == 0) { - int atmark = all_links[i].sock.AtMark(); + if (all_links[i]->sock.BadSocket()) continue; + if (all_links[i]->size_read == 0) { + int atmark = all_links[i]->sock.AtMark(); if (atmark < 0) { - _assert(all_links[i].sock.BadSocket(), "must already gone bad"); + _assert(all_links[i]->sock.BadSocket(), "must already gone bad"); } else if (atmark > 0) { - all_links[i].size_read = 1; + all_links[i]->size_read = 1; } else { // no at mark, read and discard data - ssize_t len = all_links[i].sock.Recv(all_links[i].buffer_head, all_links[i].buffer_size); - if (all_links[i].sock.AtMark()) all_links[i].size_read = 1; + ssize_t len = all_links[i]->sock.Recv(all_links[i]->buffer_head, all_links[i]->buffer_size); + if (all_links[i]->sock.AtMark()) all_links[i]->size_read = 1; // zero length, remote closed the connection, close socket - if (len == 0) all_links[i].sock.Close(); + if (len == 0) all_links[i]->sock.Close(); } } } } // start synchronization, use blocking I/O to avoid select for (int i = 0; i < nlink; ++i) { - if (!all_links[i].sock.BadSocket()) { + if (!all_links[i]->sock.BadSocket()) { char oob_mark; - all_links[i].sock.SetNonBlock(false); - ssize_t len = all_links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); + all_links[i]->sock.SetNonBlock(false); + ssize_t len = all_links[i]->sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL); if (len == 0) { - all_links[i].sock.Close(); continue; + all_links[i]->sock.Close(); continue; } else if (len > 0) { _assert(oob_mark == kResetMark, "wrong oob msg"); - _assert(all_links[i].sock.AtMark() != 1, "should already read past mark"); + _assert(all_links[i]->sock.AtMark() != 1, "should already read past mark"); } else { _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // send out ack char ack = kResetAck; while (true) { - len = all_links[i].sock.Send(&ack, sizeof(ack)); + len = all_links[i]->sock.Send(&ack, sizeof(ack)); if (len == sizeof(ack)) break; if (len == -1) { if (errno != EAGAIN && errno != EWOULDBLOCK) break; @@ -658,22 +658,22 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } // wait all ack for (int i = 0; i < nlink; ++i) { - if (!all_links[i].sock.BadSocket()) { + if (!all_links[i]->sock.BadSocket()) { char ack; - ssize_t len = all_links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL); + ssize_t len = all_links[i]->sock.Recv(&ack, sizeof(ack), MSG_WAITALL); if (len == 0) { - all_links[i].sock.Close(); continue; + all_links[i]->sock.Close(); continue; } else if (len > 0) { _assert(ack == kResetAck, "wrong Ack MSG"); } else { _assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG"); } // set back to nonblock mode - all_links[i].sock.SetNonBlock(true); + all_links[i]->sock.SetNonBlock(true); } } for (int i = 0; i < nlink; ++i) { - if (all_links[i].sock.BadSocket()) return kSockError; + if (all_links[i]->sock.BadSocket()) return kSockError; } return kSuccess; } @@ -716,7 +716,7 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { } // simple way, shutdown all links for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + if (!all_links[i]->sock.BadSocket()) all_links[i]->sock.Close(); } // smooth out traffic to tracker std::this_thread::sleep_for(std::chrono::milliseconds(10*rank)); @@ -1067,12 +1067,12 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { "LoadCheckPoint: too many nodes fails, cannot recover local state"); } // do call save model if the checkpoint was lazy - if (role == kHaveData && global_lazycheck != NULL) { + if (role == kHaveData && global_lazycheck != nullptr) { global_checkpoint.resize(0); utils::MemoryBufferStream fs(&global_checkpoint); fs.Write(&version_number, sizeof(version_number)); global_lazycheck->Save(&fs); - global_lazycheck = NULL; + global_lazycheck = nullptr; } // recover global checkpoint size_t size = this->global_checkpoint.length(); @@ -1119,7 +1119,7 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re RecoverType role; if (!requester) { sendrecvbuf = resbuf.Query(seqno, &size); - role = sendrecvbuf != NULL ? kHaveData : kPassData; + role = sendrecvbuf != nullptr ? kHaveData : kPassData; } else { role = kRequestData; } @@ -1226,7 +1226,7 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno, if (!requester) { _assert(req.check_point(), "checkpoint node should be KHaveData role"); buf = resbuf.Query(act.seqno(), &size); - _assert(buf != NULL, "buf should have data from resbuf"); + _assert(buf != nullptr, "buf should have data from resbuf"); _assert(size > 0, "buf size should be greater than 0"); } if (!CheckAndRecover(TryGetResult(buf, size, act.seqno(), requester))) continue; @@ -1348,14 +1348,14 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, succ = RingPassing(BeginPtr(msg_back), 1 * sizeof(int), (n+1) * sizeof(int), 0 * sizeof(int), n * sizeof(int), - ring_next, ring_prev); + ring_next.get(), ring_prev.get()); if (succ != kSuccess) return succ; int msg_forward[2]; msg_forward[0] = nlocal; succ = RingPassing(msg_forward, 1 * sizeof(int), 2 * sizeof(int), 0 * sizeof(int), 1 * sizeof(int), - ring_prev, ring_next); + ring_prev.get(), ring_next.get()); if (succ != kSuccess) return succ; // calculate the number of things we can read from next link int nread_end = nlocal; @@ -1375,7 +1375,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, nread_end * sizeof(size_t), nwrite_start * sizeof(size_t), nread_end * sizeof(size_t), - ring_next, ring_prev); + ring_next.get(), ring_prev.get()); if (succ != kSuccess) return succ; // update rptr rptr.resize(nread_end + 1); @@ -1386,7 +1386,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, // pass data through the link succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], rptr[nwrite_start], rptr[nread_end], - ring_next, ring_prev); + ring_next.get(), ring_prev.get()); if (succ != kSuccess) { rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; } @@ -1402,14 +1402,14 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, succ = RingPassing(BeginPtr(msg_forward), 1 * sizeof(int), (n+1) * sizeof(int), 0 * sizeof(int), n * sizeof(int), - ring_prev, ring_next); + ring_prev.get(), ring_next.get()); if (succ != kSuccess) return succ; int msg_back[2]; msg_back[0] = nlocal; succ = RingPassing(msg_back, 1 * sizeof(int), 2 * sizeof(int), 0 * sizeof(int), 1 * sizeof(int), - ring_next, ring_prev); + ring_next.get(), ring_prev.get()); if (succ != kSuccess) return succ; // calculate the number of things we can read from next link int nread_end = nlocal, nwrite_end = 1; @@ -1439,7 +1439,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, nread_end * sizeof(size_t), nwrite_start * sizeof(size_t), nwrite_end * sizeof(size_t), - ring_prev, ring_next); + ring_prev.get(), ring_next.get()); if (succ != kSuccess) return succ; // update rptr rptr.resize(nread_end + 1); @@ -1450,7 +1450,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, // pass data through the link succ = RingPassing(BeginPtr(chkpt), rptr[nlocal], rptr[nread_end], rptr[nwrite_start], rptr[nwrite_end], - ring_prev, ring_next); + ring_prev.get(), ring_next.get()); if (succ != kSuccess) { rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; } @@ -1491,7 +1491,7 @@ AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, (n + 1) * sizeof(size_t), 0 * sizeof(size_t), n * sizeof(size_t), - ring_prev, ring_next); + ring_prev.get(), ring_next.get()); if (succ != kSuccess) return succ; // update rptr rptr.resize(n + 2); @@ -1503,7 +1503,7 @@ AllreduceRobust::TryCheckinLocalState(std::vector *p_local_rptr, succ = RingPassing(BeginPtr(chkpt), rptr[1], rptr[n + 1], rptr[0], rptr[n], - ring_prev, ring_next); + ring_prev.get(), ring_next.get()); if (succ != kSuccess) { rptr.resize(2); chkpt.resize(rptr.back()); return succ; } @@ -1534,7 +1534,7 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, size_t write_end, LinkRecord *read_link, LinkRecord *write_link) { - if (read_link == NULL || write_link == NULL || read_end == 0) return kSuccess; + if (read_link == nullptr || write_link == nullptr || read_end == 0) return kSuccess; _assert(write_end <= read_end, "RingPassing: boundary check1"); _assert(read_ptr <= read_end, "RingPassing: boundary check2"); diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index a4bee7c5..a6dfd9d6 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -187,7 +187,7 @@ class AllreduceRobust : public AllreduceBase { virtual void InitAfterException(void) { // simple way, shutdown all links for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) all_links[i].sock.Close(); + if (!all_links[i]->sock.BadSocket()) all_links[i]->sock.Close(); } ReConnectLinks("recover"); } From d58f31e3d49d6fc3df2d53f5460d88b68acab475 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 14 Aug 2020 15:26:13 +0800 Subject: [PATCH 2/2] Lint. --- src/allreduce_base.cc | 2 ++ src/allreduce_base.h | 3 ++- src/allreduce_robust.cc | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 82bfb8f5..532420c9 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -7,9 +7,11 @@ */ #define NOMINMAX #include "allreduce_base.h" + #include #include #include +#include #include namespace rabit { diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 3ce7de3a..fed08d6d 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -15,6 +15,8 @@ #include #include #include +#include + #include "rabit/internal/utils.h" #include "rabit/internal/engine.h" #include "rabit/internal/socket.h" @@ -24,7 +26,6 @@ #define protected public #endif // RABIT_CXXTESTDEFS_H - namespace MPI { // MPI data type to be compatible with existing MPI interface class Datatype { diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index b83fbbc5..a49f6a49 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -623,7 +623,8 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { all_links[i]->size_read = 1; } else { // no at mark, read and discard data - ssize_t len = all_links[i]->sock.Recv(all_links[i]->buffer_head, all_links[i]->buffer_size); + ssize_t len = all_links[i]->sock.Recv(all_links[i]->buffer_head, + all_links[i]->buffer_size); if (all_links[i]->sock.AtMark()) all_links[i]->size_read = 1; // zero length, remote closed the connection, close socket if (len == 0) all_links[i]->sock.Close();