Skip to content

Commit

Permalink
unittests mock, cleanup (#111)
Browse files Browse the repository at this point in the history
* cleanup, fix issue introduced after removing the is_bootstrap parameter

* misc

* clean

* add unittests
  • Loading branch information
chenqin authored and CodingCat committed Oct 1, 2019
1 parent ddcc2d8 commit af7281a
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 35 deletions.
16 changes: 15 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
option(DMLC_ROOT "Specify root of external dmlc core.")

add_library(allreduce_base "")
add_library(allreduce_mock "")

target_sources(
allreduce_base
Expand All @@ -9,9 +10,22 @@ target_sources(
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/allreduce_base.h
)
target_sources(
allreduce_mock
PRIVATE
allreduce_robust.cc
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/allreduce_mock.h
)

target_include_directories(
allreduce_base
PUBLIC
${DMLC_ROOT}/include
${DMLC_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../include)

target_include_directories(
allreduce_mock
PUBLIC
${DMLC_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../include)
19 changes: 13 additions & 6 deletions src/allreduce_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,23 @@ class AllreduceMock : public AllreduceRobust {
size_t count,
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
void *prepare_arg,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
double tstart = utils::GetTime();
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg);
count, reducer, prepare_fun, prepare_arg,
_file, _line, _caller);
tsum_allreduce += utils::GetTime() - tstart;
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root,
const char* _file = _FILE,
const int _line = _LINE,
const char* _caller = _CALLER) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root, _file, _line, _caller);
}
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
Expand Down Expand Up @@ -168,8 +175,8 @@ class AllreduceMock : public AllreduceRobust {
inline void Verify(const MockKey &key, const char *name) {
if (mock_map.count(key) != 0) {
num_trial += 1;
fprintf(stderr, "[%d]@@@Hit Mock Error:%s\n", rank, name);
exit(-2);
// data processing frameworks runs on shared process
utils::Error("[%d]@@@Hit Mock Error:%s\n", rank, name);
}
}
};
Expand Down
23 changes: 8 additions & 15 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ int AllreduceRobust::SetBootstrapCache(const std::string &key, const void *buf,
}

int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
const size_t type_nbytes, const size_t count, const bool byref) {
const size_t type_nbytes, const size_t count) {
// as requester sync with rest of nodes on latest cache content
if (!RecoverExec(NULL, 0, ActionSummary::kLoadBootstrapCache,
seq_counter, cur_cache_seq)) return -1;
Expand All @@ -136,14 +136,7 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
utils::Assert(cur_cache_seq > index, "cur_cache_seq is smaller than lookup cache seq index");
utils::Assert(siz == type_nbytes*count, "cache size stored expected to be same as requested");
utils::Assert(siz > 0, "cache size should be greater than 0");

// immutable cache, save copy time by pointer manipulation
if (byref) {
buf = temp;
} else {
std::memcpy(buf, temp, type_nbytes*count);
}

std::memcpy(buf, temp, type_nbytes*count);
return 0;
}

Expand Down Expand Up @@ -184,7 +177,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,

// try fetch bootstrap allreduce results from cache
if (!checkpoint_loaded && rabit_bootstrap_cache &&
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count, true) != -1) return;
GetBootstrapCache(key, sendrecvbuf_, type_nbytes, count) != -1) return;

double start = utils::GetTime();
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter, cur_cache_seq);
Expand Down Expand Up @@ -244,8 +237,7 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root,
+ std::string(_caller) + "#" +std::to_string(total_size) + "@" + std::to_string(root);
// try fetch bootstrap allreduce results from cache
if (!checkpoint_loaded && rabit_bootstrap_cache &&
GetBootstrapCache(key, sendrecvbuf_, total_size, 1, true) != -1) return;

GetBootstrapCache(key, sendrecvbuf_, total_size, 1) != -1) return;
double start = utils::GetTime();
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter, cur_cache_seq);
// now we are free to remove the last result, if any
Expand Down Expand Up @@ -1171,9 +1163,10 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno,
// if all nodes are requester in load cache, skip
if (act.load_cache(SeqType::kCache)) return false;

// only restore when at least one pair of max_seq are different
if (act.diff_seq(SeqType::kCache)) {
// if restore cache failed, retry from what's left
// bootstrap cache always restore before loadcheckpoint
// requester always have seq diff with non requester
if (act.diff_seq()) {
// restore cache failed, retry from what's left
if (TryRestoreCache(req.load_cache(), act.seqno(), act.seqno(SeqType::kCache))
!= kSuccess) continue;
}
Expand Down
17 changes: 6 additions & 11 deletions src/allreduce_robust.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AllreduceRobust : public AllreduceBase {
* \param buflen total number of bytes
*/
int GetBootstrapCache(const std::string &key, void *buf, const size_t type_nbytes,
const size_t count, const bool byref = false);
const size_t count);
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
Expand Down Expand Up @@ -255,9 +255,8 @@ class AllreduceRobust : public AllreduceBase {
return (code & kCheckAck) != 0;
}
// whether the operation set contains different sequence number
inline bool diff_seq(SeqType t = SeqType::kSeq) const {
int code = t == SeqType::kSeq ? seqcode : maxseqcode;
return (code & kDiffSeq) != 0;
inline bool diff_seq() const {
return (seqcode & kDiffSeq) != 0;
}
// returns the operation flag of the result
inline int flag(SeqType t = SeqType::kSeq) const {
Expand All @@ -266,11 +265,10 @@ class AllreduceRobust : public AllreduceBase {
}
// print flags in user friendly way
inline void print_flags(int rank, std::string prefix ) {
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|%d|\n",
utils::HandleLogInfo("[%d] %s - |%lu|%d|%d|%d|%d| - |%lu|%d|\n",
rank, prefix.c_str(),
seqno(), check_point(), check_ack(), load_cache(),
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache),
diff_seq(SeqType::kCache));
diff_seq(), seqno(SeqType::kCache), load_cache(SeqType::kCache));
}
// reducer for Allreduce, get the result ActionSummary from all nodes
inline static void Reducer(const void *src_, void *dst_,
Expand All @@ -286,12 +284,9 @@ class AllreduceRobust : public AllreduceBase {
int role_flag = src[i].flag(SeqType::kCache) & dst[i].flag(SeqType::kCache);
// if seqno is different in src and destination
int seq_diff_flag = src[i].seqno() != dst[i].seqno() ? kDiffSeq : 0;
// if cache seqno is different in src and destination
int cache_diff_flag =
src[i].seqno(SeqType::kCache) != dst[i].seqno(SeqType::kCache) ? kDiffSeq : 0;
// apply or to both seq diff flag as well as cache seq diff flag
dst[i] = ActionSummary(action_flag | seq_diff_flag,
role_flag | cache_diff_flag, min_seqno, max_seqno);
role_flag, min_seqno, max_seqno);
}
}

Expand Down
3 changes: 2 additions & 1 deletion test/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ find_package(GTest REQUIRED)
add_executable(
unit_tests
allreduce_base_test.cpp
allreduce_mock_test.cpp
test_main.cpp)

target_link_libraries(
unit_tests
GTest::GTest GTest::Main
rabit_base)
rabit_base rabit_mock)

target_include_directories(unit_tests PUBLIC
"$<BUILD_INTERFACE:${rabit_SOURCE_DIR}/include>"
Expand Down
36 changes: 36 additions & 0 deletions test/cpp/allreduce_mock_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#define RABIT_CXXTESTDEFS_H
#include <gtest/gtest.h>

#include <string>
#include <iostream>
#include "../../src/allreduce_mock.h"

// Death test: an Allreduce call on an AllreduceMock configured with a
// matching "mock=" entry is expected to terminate the process with exit
// code 255.
TEST(allreduce_mock, mock_allreduce)
{
  rabit::engine::AllreduceMock m;

  // NOTE(review): the four fields presumably map to
  // rank, version_number, seq_counter, num_trial (the MockKey used by
  // AllreduceMock::Verify) -- confirm against the "mock" parameter parser.
  std::string mock_str = "mock=0,0,0,0";

  // Back argv with the string's own storage: since C++11 it is contiguous,
  // mutable through operator[], and null-terminated. This replaces the
  // previous `char cmd[mock_str.size()+1]`, a variable-length array that is
  // not standard C++ and is rejected by some compilers.
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);

  // Pin the rank so the mock key above matches this instance.
  m.rank = 0;
  EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
}

// Death test: a Broadcast call on an AllreduceMock whose state matches the
// configured "mock=" entry is expected to terminate the process with exit
// code 255.
TEST(allreduce_mock, mock_broadcast)
{
  rabit::engine::AllreduceMock m;

  // NOTE(review): the four fields presumably map to
  // rank, version_number, seq_counter, num_trial (the MockKey used by
  // AllreduceMock::Verify) -- confirm against the "mock" parameter parser.
  std::string mock_str = "mock=0,1,2,0";

  // Back argv with the string's own storage: since C++11 it is contiguous,
  // mutable through operator[], and null-terminated. This replaces the
  // previous `char cmd[mock_str.size()+1]`, a variable-length array that is
  // not standard C++ and is rejected by some compilers.
  char* argv[] = {&mock_str[0]};
  m.Init(1, argv);

  // Set the engine state to the values listed in the mock entry
  // (rank 0, version 1, sequence 2) so the Broadcast below hits it.
  m.rank = 0;
  m.version_number=1;
  m.seq_counter=2;
  EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
}
2 changes: 1 addition & 1 deletion test/model_recover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ int main(int argc, char *argv[]) {
std::string name = rabit::GetProcessorName();

int max_rank = rank;
rabit::Allreduce<op::Max>(&max_rank, sizeof(int));
rabit::Allreduce<op::Max>(&max_rank, 1);
utils::Check(max_rank == nproc - 1, "max rank is world size-1");

Model model;
Expand Down

0 comments on commit af7281a

Please sign in to comment.