Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge logging facilities for rabit and xgboost. #6101

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,25 @@ if (GOOGLE_TEST)
PASS_REGULAR_EXPRESSION ".*test-rmse:0.087.*")
endif (GOOGLE_TEST)

if (GOOGLE_TEST AND (NOT WIN32))
# rabit mock based integration tests
set(tests lazy_recover local_recover model_recover)

foreach(test ${tests})
add_executable(${test} rabit/test/${test}.cc)
set_output_directory(${test} ${xgboost_BINARY_DIR})
target_link_libraries(${test} rabit_mock_static xgboost)
set_target_properties(${test} PROPERTIES CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON)
add_test(NAME ${test} COMMAND ${test} WORKING_DIRECTORY ${xgboost_BINARY_DIR})
endforeach()

if(RABIT_BUILD_MPI)
add_executable(speed_test_mpi test/speed_test.cc)
target_link_libraries(speed_test_mpi rabit_mpi)
add_test(NAME speed_test_mpi COMMAND speed_test_mpi WORKING_DIRECTORY ${xgboost_BINARY_DIR})
endif(RABIT_BUILD_MPI)
endif (GOOGLE_TEST AND (NOT WIN32))

# For MSVC: Call msvc_use_static_runtime() once again to completely
# replace /MD with /MT. See https://github.com/dmlc/xgboost/issues/4462
# for issues caused by mixing of /MD and /MT flags
Expand Down
24 changes: 2 additions & 22 deletions rabit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,10 @@ endif((CMAKE_CONFIGURATION_TYPES STREQUAL "Debug") AND (CMAKE_CXX_COMPILER_ID MA
foreach(lib ${rabit_libs})
target_include_directories(${lib} PUBLIC
"$<BUILD_INTERFACE:${xgboost_SOURCE_DIR}/rabit/include>"
"$<BUILD_INTERFACE:${xgboost_SOURCE_DIR}/dmlc-core/include>")
"$<BUILD_INTERFACE:${xgboost_SOURCE_DIR}/dmlc-core/include>"
"$<BUILD_INTERFACE:${xgboost_SOURCE_DIR}/include>")
endforeach()

if (GOOGLE_TEST AND (NOT WIN32))
enable_testing()

# rabit mock based integration tests
list(REMOVE_ITEM rabit_libs "rabit_mock_static") # remove here to avoid installing it
set(tests lazy_recover local_recover model_recover)

foreach(test ${tests})
add_executable(${test} test/${test}.cc)
target_link_libraries(${test} rabit_mock_static)
set_target_properties(${test} PROPERTIES CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON)
add_test(NAME ${test} COMMAND ${test} WORKING_DIRECTORY ${xgboost_BINARY_DIR})
endforeach()

if(RABIT_BUILD_MPI)
add_executable(speed_test_mpi test/speed_test.cc)
target_link_libraries(speed_test_mpi rabit_mpi)
add_test(NAME speed_test_mpi COMMAND speed_test_mpi WORKING_DIRECTORY ${xgboost_BINARY_DIR})
endif(RABIT_BUILD_MPI)
endif (GOOGLE_TEST AND (NOT WIN32))

# Headers:
set(include_install_dir "include")
install(
Expand Down
13 changes: 8 additions & 5 deletions rabit/include/rabit/internal/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
*/
#ifndef RABIT_INTERNAL_IO_H_
#define RABIT_INTERNAL_IO_H_
#include <dmlc/io.h>

#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <algorithm>
#include <numeric>
#include <limits>
#include "rabit/internal/utils.h"

#include "rabit/serializable.h"
#include "rabit/internal/utils.h"
#include "xgboost/logging.h"

namespace rabit {
namespace utils {
Expand All @@ -41,8 +45,7 @@ struct MemoryFixSizeBuffer : public SeekStream {
}
void Write(const void *ptr, size_t size) override {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
CHECK_LE(curr_ptr_ + size, buffer_size_) << "write position exceed fixed buffer size";
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
Expand Down Expand Up @@ -78,8 +81,8 @@ struct MemoryBufferStream : public SeekStream {
}
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
CHECK_LE(curr_ptr_, p_buffer_->length())
<< "read can not have position excceed buffer length";
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
Expand Down
7 changes: 4 additions & 3 deletions rabit/include/rabit/internal/rabit-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// use engine for implementation
#include <vector>
#include <string>
#include "xgboost/logging.h"
#include "rabit/internal/io.h"
#include "rabit/internal/utils.h"
#include "rabit/rabit.h"
Expand Down Expand Up @@ -330,7 +331,7 @@ struct SerializeReduceClosure {
inline void Run() {
if (prepare_fun != nullptr) prepare_fun(prepare_arg);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
utils::MemoryFixSizeBuffer fs(dmlc::BeginPtr(*p_buffer) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Save(fs);
}
}
Expand All @@ -352,11 +353,11 @@ inline void SerializeReducer<DType>::Allreduce(DType *sendrecvobj,
c.sendrecvobj = sendrecvobj; c.max_nbyte = max_nbyte; c.count = count;
c.prepare_fun = prepare_fun; c.prepare_arg = prepare_arg; c.p_buffer = &buffer_;
// invoke here
handle_.Allreduce(BeginPtr(buffer_), max_nbyte, count,
handle_.Allreduce(dmlc::BeginPtr(buffer_), max_nbyte, count,
SerializeReduceClosure<DType>::Invoke, &c,
_file, _line, _caller);
for (size_t i = 0; i < count; ++i) {
utils::MemoryFixSizeBuffer fs(BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
utils::MemoryFixSizeBuffer fs(dmlc::BeginPtr(buffer_) + i * max_nbyte, max_nbyte);
sendrecvobj[i].Load(fs);
}
}
Expand Down
11 changes: 6 additions & 5 deletions rabit/include/rabit/internal/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cstring>
#include <vector>
#include <unordered_map>
#include "xgboost/logging.h"
#include "utils.h"

#if defined(_WIN32) || defined(__MINGW32__)
Expand Down Expand Up @@ -71,8 +72,8 @@ struct SockAddr {
hints.ai_protocol = SOCK_STREAM;
addrinfo *res = nullptr;
int sig = getaddrinfo(host, nullptr, &hints, &res);
Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host);
Check(res->ai_family == AF_INET, "Does not support IPv6");
CHECK(sig == 0 && res != nullptr) << "cannot obtain address of: " << host;
CHECK(res->ai_family == AF_INET) << "Does not support IPv6";
memcpy(&addr, res->ai_addr, res->ai_addrlen);
addr.sin_port = htons(port);
freeaddrinfo(res);
Expand All @@ -91,7 +92,7 @@ struct SockAddr {
const char *s = inet_ntop(AF_INET, &addr.sin_addr,
&buf[0], buf.length());
#endif // _WIN32
Assert(s != nullptr, "cannot decode address");
CHECK(s) << "cannot decode address";
return std::string(s);
}
};
Expand Down Expand Up @@ -251,9 +252,9 @@ class Socket {
inline static void Error(const char *msg) {
int errsv = GetLastError();
#ifdef _WIN32
utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv);
LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
utils::Error("Socket %s Error:%s", msg, strerror(errsv));
LOG(FATAL) << "Socket Error:" << msg << " " << strerror(errsv);
#endif
}

Expand Down
111 changes: 4 additions & 107 deletions rabit/include/rabit/internal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@

#include <rabit/base.h>
#include <cstring>
#include <cstdarg>
#include <cstdio>
#include <string>
#include <cstdlib>
#include <stdexcept>
#include <vector>
#include "dmlc/io.h"

#ifndef RABIT_STRICT_CXX98_
#include <cstdarg>
#endif // RABIT_STRICT_CXX98_

#if !defined(__GNUC__) || defined(__FreeBSD__)
#define fopen64 std::fopen
#endif // !defined(__GNUC__) || defined(__FreeBSD__)
Expand Down Expand Up @@ -71,71 +68,6 @@ inline bool StringToBool(const char* s) {
return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0;
}

#ifndef RABIT_CUSTOMIZE_MSG_
/*!
* \brief handling of Assert error, caused by inappropriate input
* \param msg error message
*/
inline void HandleAssertError(const char *msg) {
fprintf(stderr,
"AssertError:%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
/*!
* \brief handling of Check error, caused by inappropriate input
* \param msg error message
*/
inline void HandleCheckError(const char *msg) {
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
inline void HandlePrint(const char *msg) {
printf("%s", msg);
}

inline void HandleLogInfo(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
fprintf(stdout, "%s", msg.c_str());
fflush(stdout);
}
#else
#ifndef RABIT_STRICT_CXX98_
// include declarations, some one must implement this
void HandleAssertError(const char *msg);
void HandleCheckError(const char *msg);
void HandlePrint(const char *msg);
#endif // RABIT_STRICT_CXX98_
#endif // RABIT_CUSTOMIZE_MSG_
#ifdef RABIT_STRICT_CXX98_
// these function pointers are to be assigned
extern "C" void (*Printf)(const char *fmt, ...);
extern "C" int (*SPrintf)(char *buf, size_t size, const char *fmt, ...);
extern "C" void (*Assert)(int exp, const char *fmt, ...);
extern "C" void (*Check)(int exp, const char *fmt, ...);
extern "C" void (*Error)(const char *fmt, ...);
#else
/*! \brief printf, prints messages to the console */
inline void Printf(const char *fmt, ...) {
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandlePrint(msg.c_str());
}
/*! \brief portable version of snprintf */
inline int SPrintf(char *buf, size_t size, const char *fmt, ...) {
va_list args;
va_start(args, fmt);
int ret = vsnprintf(buf, size, fmt, args);
va_end(args);
return ret;
}

/*! \brief assert a condition is true, use this to handle debug information */
inline void Assert(bool exp, const char *fmt, ...) {
if (!exp) {
Expand All @@ -144,7 +76,7 @@ inline void Assert(bool exp, const char *fmt, ...) {
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleAssertError(msg.c_str());
LOG(FATAL) << msg;
}
}

Expand All @@ -156,7 +88,7 @@ inline void Check(bool exp, const char *fmt, ...) {
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleCheckError(msg.c_str());
LOG(FATAL) << msg;
}
}

Expand All @@ -168,44 +100,9 @@ inline void Error(const char *fmt, ...) {
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
HandleCheckError(msg.c_str());
LOG(FATAL) << msg;
}
}
#endif // RABIT_STRICT_CXX98_

/*! \brief replace fopen, report error when the file open fails */
inline std::FILE *FopenCheck(const char *fname, const char *flag) {
std::FILE *fp = fopen64(fname, flag);
Check(fp != nullptr, "can not open file \"%s\"\n", fname);
return fp;
}
} // namespace utils
// easy utils that can be directly accessed in xgboost
/*! \brief get the beginning address of a vector */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return nullptr;
} else {
return &vec[0];
}
}
/*! \brief get the beginning address of a vector */
template<typename T>
inline const T *BeginPtr(const std::vector<T> &vec) { // NOLINT(*)
if (vec.size() == 0) {
return nullptr;
} else {
return &vec[0];
}
}
inline char* BeginPtr(std::string &str) { // NOLINT(*)
if (str.length() == 0) return nullptr;
return &str[0];
}
inline const char* BeginPtr(const std::string &str) {
if (str.length() == 0) return nullptr;
return &str[0];
}
} // namespace rabit
#endif // RABIT_INTERNAL_UTILS_H_
Loading