Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use shared ptr for net link. #144

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 39 additions & 33 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ bool AllreduceBase::Init(int argc, char* argv[]) {
bool AllreduceBase::Shutdown(void) {
try {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
all_links[i]->sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
Expand Down Expand Up @@ -321,10 +321,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
if (!all_links[i]->sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i]->rank));
} else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
if (!all_links[i]->sock.IsClosed()) all_links[i]->sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
Expand All @@ -340,7 +340,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
auto r = std::make_shared<LinkRecord>();
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
Expand All @@ -349,24 +349,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");

r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
r->sock.Create();
if (!r->sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1;
r.sock.Close();
r->sock.Close();
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
Assert(r->sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
Assert(r->sock.RecvAll(&r->rank, sizeof(r->rank)) == sizeof(r->rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
utils::Check(hrank == r->rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
if (all_links[i]->rank == hrank) {
Assert(all_links[i]->sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
all_links[i]->sock = r->sock;
match = true;
break;
}
Expand All @@ -383,52 +383,58 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
auto r = std::make_shared<LinkRecord>();
r->sock = sock_listen.Accept();
Assert(r->sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
Assert(r->sock.RecvAll(&r->rank, sizeof(r->rank)) == sizeof(r->rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
if (all_links[i]->rank == r->rank) {
utils::Assert(all_links[i]->sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
all_links[i]->sock = r->sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
if (!match) {
all_links.push_back(r);
}
}
sock_listen.Close();
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
int tcpNoDelay = 1;
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
utils::Assert(!all_links[i]->sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
all_links[i]->sock.SetNonBlock(true);
all_links[i]->sock.SetKeepAlive(true);
if (rabit_enable_tcp_no_delay) {
setsockopt(all_links[i].sock, IPPROTO_TCP,
setsockopt(all_links[i]->sock, IPPROTO_TCP,
TCP_NODELAY, reinterpret_cast<void *>(&tcpNoDelay), sizeof(tcpNoDelay));
}
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
if (tree_neighbors.count(all_links[i]->rank) != 0) {
if (all_links[i]->rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_links[i]);
tree_links.plinks.push_back(all_links[i]);
}
if (all_links[i]->rank == prev_rank) {
ring_prev = all_links[i];
}
if (all_links[i]->rank == next_rank) {
ring_next = all_links[i];
}
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
}
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
Assert(prev_rank == -1 || ring_prev != nullptr,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
Assert(next_rank == -1 || ring_next != nullptr,
"cannot find next ring in the link");
return true;
} catch (const std::exception& e) {
Expand Down
8 changes: 4 additions & 4 deletions src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,11 @@ class AllreduceBase : public IEngine {
* but takes reference instead of space
*/
struct RefLinkVector {
std::vector<LinkRecord*> plinks;
std::vector<std::shared_ptr<LinkRecord>> plinks;
inline LinkRecord &operator[](size_t i) {
return *plinks[i];
}
inline size_t size(void) const {
inline size_t size() const {
return plinks.size();
}
};
Expand Down Expand Up @@ -536,13 +536,13 @@ class AllreduceBase : public IEngine {
// rank of parent node, can be -1
int parent_rank;
// sockets of all links this connects to
std::vector<LinkRecord> all_links;
std::vector<std::shared_ptr<LinkRecord>> all_links;
// used to record the link where things go wrong
LinkRecord *err_link;
// all the links in the reduction tree connection
RefLinkVector tree_links;
// pointer to links in the ring
LinkRecord *ring_prev, *ring_next;
std::shared_ptr<LinkRecord> ring_prev, ring_next;
//----- meta information-----
// list of environment variables that are of possible interest
std::vector<std::string> env_vars;
Expand Down
Loading