From fb2e2d459b1cfea3d270c44794bc80ba3f892506 Mon Sep 17 00:00:00 2001
From: Louis J
Date: Wed, 24 Jul 2019 18:07:57 +0200
Subject: [PATCH] parent 7eb644338ea3770b6d46bb7f30da59d1f5b277c9

author Louis J 1563984477 +0200
committer Guillaume Infantes 1576060297 +0100

parent 7eb644338ea3770b6d46bb7f30da59d1f5b277c9
author Louis J 1563984477 +0200
committer Guillaume Infantes 1576059845 +0100

LOUISJ'S COMMITS:
Move dataset management and model building in separate classes
Add train and test
The fix on txtinputconnector is temporary, vocab generation should be fixed in a more robust way
BERT finetuning with custom number of classes
Add self supervised Masked LM learning
Save solver checkpoint along with model
Ensure label is of correct dimension
Fix masked_lm, add more explicit error message
Add script to trace huggingface models
Add classification on hidden states to be able to use masked lm model for classif
Better API, more features, less memory usage and fix bugs
Add unit tests for training
Move training parameters to solver and net
Add comments
Download tar from deepdetect.com
torch 1.3.1 alone working with caffe patch
correction: add pcaffe/logging.h
force -j8 when building libtorch (default is -j nproc)
points to model traced for torch 131

GUILLAUME COMMITS:
changes for torch 131
Move dataset management and model building in separate classes
Add train and test
The fix on txtinputconnector is temporary, vocab generation should be fixed in a more robust way
BERT finetuning with custom number of classes
Add self supervised Masked LM learning
Save solver checkpoint along with model
Ensure label is of correct dimension
Better API, more features, less memory usage and fix bugs
Move training parameters to solver and net
Add comments
Add inference support for GPT2
Make lower case optional
Add gpt2 training
Add gpt2 demo
rebase all glitches in merge
update to latest transformers from huggingface
gpt2 inference ok
sanitize width vs sequence
remove comment in cmakelist
---
 CMakeLists.txt                            |  12 +-
 demo/gpt2/dd_client.py                    |   1 +
 demo/gpt2/run_gpt2.py                     |  75 ++
 patches/pytorch/pytorch_compile.patch     |  12 +
 patches/pytorch/pytorch_logging.patch     | 818 ++++++++++++++++++++++
 src/backends/torch/torchinputconns.cc     | 262 ++++++-
 src/backends/torch/torchinputconns.h      | 148 +++-
 src/backends/torch/torchlib.cc            | 716 +++++++++++++++++--
 src/backends/torch/torchlib.h             |  50 +-
 src/backends/torch/torchmodel.cc          |  42 +-
 src/backends/torch/torchmodel.h           |   6 +-
 src/txtinputfileconn.cc                   |  25 +-
 src/txtinputfileconn.h                    |  17 +-
 tests/CMakeLists.txt                      |   9 +
 tests/ut-torchapi.cc                      |  70 +-
 tools/torch/README.md                     |   8 +
 tools/torch/trace_pytorch_transformers.py | 160 +++++
 17 files changed, 2288 insertions(+), 143 deletions(-)
 create mode 120000 demo/gpt2/dd_client.py
 create mode 100644 demo/gpt2/run_gpt2.py
 create mode 100644 patches/pytorch/pytorch_compile.patch
 create mode 100644 patches/pytorch/pytorch_logging.patch
 create mode 100755 tools/torch/trace_pytorch_transformers.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 63fcc6d15..bea4911d5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -706,7 +706,8 @@ if (USE_TORCH)
   set(PYTORCH_PATCHES_PATH ${CMAKE_BINARY_DIR}/patches/pytorch)
   set(PYTORCH_PATCHES
-    ${PYTORCH_PATCHES_PATH}/pytorch_log.patch
+    ${PYTORCH_PATCHES_PATH}/pytorch_compile.patch
+    ${PYTORCH_PATCHES_PATH}/pytorch_logging.patch
     )
   include_directories("${PROTOBUF_INCLUDE_DIR}")
@@ -715,7 +716,8 @@ if (USE_TORCH)
   add_definitions(-DUSE_TORCH)
   if (NOT TORCH_LOCATION)
-    set(PYTORCH_COMMIT 0b868b19063645afed59d6d49aff1e43d1665b88)
+    #
below version 1.3.1 + set(PYTORCH_COMMIT ee77ccbb6da4e2efd83673e798acf7081bc03564) set(PYTORCH_COMPLETE ${CMAKE_BINARY_DIR}/CMakeFiles/pytorch-complete) if(NOT USE_CPU_ONLY AND CUDA_FOUND) @@ -734,7 +736,7 @@ if (USE_TORCH) PATCH_COMMAND test -f ${PYTORCH_COMPLETE} && echo Skipping || git apply ${PYTORCH_PATCHES} && echo Applying ${PYTORCH_PATCHES} CONFIGURE_COMMAND "" BUILD_COMMAND "" - COMMAND test -f ${PYTORCH_COMPLETE} && echo Skipping || GLIBCXX_USE_CXX11_ABI=1 BUILD_TEST=0 USE_CUDA=${PYTORCH_USE_CUDA} python3 ../pytorch/tools/build_libtorch.py + COMMAND test -f ${PYTORCH_COMPLETE} && echo Skipping || GLIBCXX_USE_CXX11_ABI=1 BUILD_TEST=0 USE_CUDA=${PYTORCH_USE_CUDA} CAFFE2_LINK_LOCAL_PROTOBUF=0 MAX_JOBS=8 python3 ../pytorch/tools/build_libtorch.py INSTALL_COMMAND "" ) @@ -747,10 +749,10 @@ if (USE_TORCH) # ) #message(STATUS "Libraries are: ${TORCH_LIBRARIES}") - set(TORCH_LIB_DEPS torch caffe2 ${TORCH_LOCATION}/lib/libc10.so) + set(TORCH_LIB_DEPS torch ${TORCH_LOCATION}/lib/libc10.so) if (NOT USE_CPU_ONLY AND CUDA_FOUND) - list(APPEND TORCH_LIB_DEPS caffe2_gpu ${TORCH_LOCATION}/lib/libc10_cuda.so iomp5) + list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libc10_cuda.so iomp5) else() list(APPEND TORCH_LIB_DEPS iomp5) endif() diff --git a/demo/gpt2/dd_client.py b/demo/gpt2/dd_client.py new file mode 120000 index 000000000..86059b6f0 --- /dev/null +++ b/demo/gpt2/dd_client.py @@ -0,0 +1 @@ +../../clients/python/dd_client.py \ No newline at end of file diff --git a/demo/gpt2/run_gpt2.py b/demo/gpt2/run_gpt2.py new file mode 100644 index 000000000..e1c61ce62 --- /dev/null +++ b/demo/gpt2/run_gpt2.py @@ -0,0 +1,75 @@ +import random +import sys +import argparse +from dd_client import DD + +parser = argparse.ArgumentParser(description="Use DeepDetect and GPT-2 to generate text") +parser.add_argument("-r", "--repository", required=True, help="Model repository") +parser.add_argument("--host", type=str, default="localhost") +parser.add_argument("--port", type=int, default=8080) +parser.add_argument("--cpu", action='store_true', help="Force model to run on CPU") +parser.add_argument("--input-size", type=int, default=512) +parser.add_argument("--topk", type=int, default=5, help="How many top predictions should be considered to chose the next token.") +parser.add_argument("--temperature", type=float, default=1, help="Temperature of the predictions. 
The higher, the 'randomer'.") + +args = parser.parse_args() + +# dd global variables +sname = 'gpt-2' +description = 'Inference with GPT-2' +mllib = 'torch' + +dd = DD(args.host, args.port) +dd.set_return_format(dd.RETURN_PYTHON) + +# setting up the ML service +model = {'repository':args.repository} +parameters_input = { + 'connector':'txt', + 'ordered_words': True, + 'wordpiece_tokens': True, + 'punctuation_tokens': True, + 'lower_case': False, + 'width': args.input_size +} +parameters_mllib = {'template':'gpt2', 'gpu':True} +parameters_output = {} +dd.put_service(sname,model,description,mllib, + parameters_input,parameters_mllib,parameters_output) + +# generating text +prompt = input("Enter beggining of sentence >>> ") + +for i in range(0, 256): + data = [prompt] + parameters_input = {'word_start': "Ġ", 'suffix_start': ""} + parameters_mllib = {} + parameters_output = {'best':args.topk} + result = dd.post_predict(sname, data, parameters_input,parameters_mllib,parameters_output) + + # Select result from the returned tokens + word_probs = list() + total_probs = 0 + + for cls in result['body']['predictions'][0]['classes']: + word = cls['cat'].replace("Ġ", " ") + # dede does not support \n character well, so we don't select tokens containing a new line + if 'Ċ' in word: + continue + + prob = pow(cls['prob'], args.temperature) + total_probs += prob + word_probs.append((word, prob)) + + selector = random.uniform(0, total_probs) + total_probs = 0 + + for word, prob in word_probs: + total_probs += prob + if total_probs > selector: + selected_word = word + break + + print(selected_word, sep='', end='') + sys.stdout.flush() + prompt += selected_word diff --git a/patches/pytorch/pytorch_compile.patch b/patches/pytorch/pytorch_compile.patch new file mode 100644 index 000000000..19998a0d4 --- /dev/null +++ b/patches/pytorch/pytorch_compile.patch @@ -0,0 +1,12 @@ +diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py +index 894559ed43..7887147a28 100644 +--- a/tools/setup_helpers/cmake.py ++++ b/tools/setup_helpers/cmake.py +@@ -229,6 +229,7 @@ class CMake: + 'CUDA_NVCC_EXECUTABLE', + 'CUDNN_LIBRARY', + 'CUDNN_INCLUDE_DIR', ++ 'CAFFE2_LINK_LOCAL_PROTOBUF', + 'EXPERIMENTAL_SINGLE_THREAD_POOL', + 'INSTALL_TEST', + 'MKL_THREADING', diff --git a/patches/pytorch/pytorch_logging.patch b/patches/pytorch/pytorch_logging.patch new file mode 100644 index 000000000..137d7dade --- /dev/null +++ b/patches/pytorch/pytorch_logging.patch @@ -0,0 +1,818 @@ +diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp +index 1bcb938cdc..f250828058 100644 +--- a/c10/util/Logging.cpp ++++ b/c10/util/Logging.cpp +@@ -93,204 +93,18 @@ bool LogAPIUsageFakeReturn(const std::string& event) { + + } // namespace c10 + +-#if defined(C10_USE_GFLAGS) && defined(C10_USE_GLOG) +-// When GLOG depends on GFLAGS, these variables are being defined in GLOG +-// directly via the GFLAGS definition, so we will use DECLARE_* to declare +-// them, and use them in Caffe2. +-// GLOG's minloglevel +-DECLARE_int32(minloglevel); +-// GLOG's verbose log value. +-DECLARE_int32(v); +-// GLOG's logtostderr value +-DECLARE_bool(logtostderr); +-#endif // defined(C10_USE_GFLAGS) && defined(C10_USE_GLOG) + +-#if !defined(C10_USE_GLOG) +-// This backward compatibility flags are in order to deal with cases where +-// Caffe2 are not built with glog, but some init flags still pass in these +-// flags. They may go away in the future. 
+-C10_DEFINE_int32(minloglevel, 0, "Equivalent to glog minloglevel"); +-C10_DEFINE_int32(v, 0, "Equivalent to glog verbose"); +-C10_DEFINE_bool(logtostderr, false, "Equivalent to glog logtostderr"); +-#endif // !defined(c10_USE_GLOG) + +-#ifdef C10_USE_GLOG +- +-// Provide easy access to the above variables, regardless whether GLOG is +-// dependent on GFLAGS or not. Note that the namespace (fLI, fLB) is actually +-// consistent between GLOG and GFLAGS, so we can do the below declaration +-// consistently. +-namespace c10 { +-using fLB::FLAGS_logtostderr; +-using fLI::FLAGS_minloglevel; +-using fLI::FLAGS_v; +-} // namespace c10 +- +-C10_DEFINE_int( +- caffe2_log_level, +- google::GLOG_ERROR, +- "The minimum log level that caffe2 will output."); +- +-// Google glog's api does not have an external function that allows one to check +-// if glog is initialized or not. It does have an internal function - so we are +-// declaring it here. This is a hack but has been used by a bunch of others too +-// (e.g. Torch). +-namespace google { +-namespace glog_internal_namespace_ { +-bool IsGoogleLoggingInitialized(); +-} // namespace glog_internal_namespace_ +-} // namespace google + + namespace c10 { + bool InitCaffeLogging(int* argc, char** argv) { +- if (*argc == 0) +- return true; +-#if !defined(_MSC_VER) +- // This trick can only be used on UNIX platforms +- if (!::google::glog_internal_namespace_::IsGoogleLoggingInitialized()) +-#endif +- { +- ::google::InitGoogleLogging(argv[0]); +-#if !defined(_MSC_VER) +- // This is never defined on Windows +- ::google::InstallFailureSignalHandler(); +-#endif +- } +- UpdateLoggingLevelsFromFlags(); + return true; + } + + void UpdateLoggingLevelsFromFlags() { +- // If caffe2_log_level is set and is lower than the min log level by glog, +- // we will transfer the caffe2_log_level setting to glog to override that. +- FLAGS_minloglevel = std::min(FLAGS_caffe2_log_level, FLAGS_minloglevel); +- // If caffe2_log_level is explicitly set, let's also turn on logtostderr. +- if (FLAGS_caffe2_log_level < google::GLOG_ERROR) { +- FLAGS_logtostderr = 1; +- } +- // Also, transfer the caffe2_log_level verbose setting to glog. +- if (FLAGS_caffe2_log_level < 0) { +- FLAGS_v = std::min(FLAGS_v, -FLAGS_caffe2_log_level); +- } +-} +- +-void ShowLogInfoToStderr() { +- FLAGS_logtostderr = 1; +- FLAGS_minloglevel = std::min(FLAGS_minloglevel, google::GLOG_INFO); +-} +-} // namespace c10 +- +-#else // !C10_USE_GLOG +- +-#ifdef ANDROID +-#include +-#endif // ANDROID +- +-C10_DEFINE_int( +- caffe2_log_level, +- ERROR, +- "The minimum log level that caffe2 will output."); +- +-namespace c10 { +-bool InitCaffeLogging(int* argc, char** argv) { +- // When doing InitCaffeLogging, we will assume that caffe's flag paser has +- // already finished. +- if (*argc == 0) +- return true; +- if (!c10::CommandLineFlagsHasBeenParsed()) { +- std::cerr << "InitCaffeLogging() has to be called after " +- "c10::ParseCommandLineFlags. Modify your program to make sure " +- "of this." +- << std::endl; +- return false; +- } +- if (FLAGS_caffe2_log_level > FATAL) { +- std::cerr << "The log level of Caffe2 has to be no larger than FATAL(" +- << FATAL << "). Capping it to FATAL." 
<< std::endl; +- FLAGS_caffe2_log_level = FATAL; +- } +- return true; + } + +-void UpdateLoggingLevelsFromFlags() {} +- + void ShowLogInfoToStderr() { +- FLAGS_caffe2_log_level = INFO; +-} +- +-MessageLogger::MessageLogger(const char* file, int line, int severity) +- : severity_(severity) { +- if (severity_ < FLAGS_caffe2_log_level) { +- // Nothing needs to be logged. +- return; +- } +-#ifdef ANDROID +- tag_ = "native"; +-#else // !ANDROID +- tag_ = ""; +-#endif // ANDROID +- /* +- time_t rawtime; +- struct tm * timeinfo; +- time(&rawtime); +- timeinfo = localtime(&rawtime); +- std::chrono::nanoseconds ns = +- std::chrono::duration_cast( +- std::chrono::high_resolution_clock::now().time_since_epoch()); +- */ +- stream_ << "[" +- << CAFFE2_SEVERITY_PREFIX[std::min(4, FATAL - severity_)] +- //<< (timeinfo->tm_mon + 1) * 100 + timeinfo->tm_mday +- //<< std::setfill('0') +- //<< " " << std::setw(2) << timeinfo->tm_hour +- //<< ":" << std::setw(2) << timeinfo->tm_min +- //<< ":" << std::setw(2) << timeinfo->tm_sec +- //<< "." << std::setw(9) << ns.count() % 1000000000 +- << " " << c10::detail::StripBasename(std::string(file)) << ":" << line +- << "] "; + } +- +-// Output the contents of the stream to the proper channel on destruction. +-MessageLogger::~MessageLogger() { +- if (severity_ < FLAGS_caffe2_log_level) { +- // Nothing needs to be logged. +- return; +- } +- stream_ << "\n"; +-#ifdef ANDROID +- static const int android_log_levels[] = { +- ANDROID_LOG_FATAL, // LOG_FATAL +- ANDROID_LOG_ERROR, // LOG_ERROR +- ANDROID_LOG_WARN, // LOG_WARNING +- ANDROID_LOG_INFO, // LOG_INFO +- ANDROID_LOG_DEBUG, // VLOG(1) +- ANDROID_LOG_VERBOSE, // VLOG(2) .. VLOG(N) +- }; +- int android_level_index = FATAL - std::min(FATAL, severity_); +- int level = android_log_levels[std::min(android_level_index, 5)]; +- // Output the log string the Android log at the appropriate level. +- __android_log_print(level, tag_, "%s", stream_.str().c_str()); +- // Indicate termination if needed. +- if (severity_ == FATAL) { +- __android_log_print(ANDROID_LOG_FATAL, tag_, "terminating.\n"); +- } +-#else // !ANDROID +- if (severity_ >= FLAGS_caffe2_log_level) { +- // If not building on Android, log all output to std::cerr. +- std::cerr << stream_.str(); +- // Simulating the glog default behavior: if the severity is above INFO, +- // we flush the stream so that the output appears immediately on std::cerr. +- // This is expected in some of our tests. +- if (severity_ > INFO) { +- std::cerr << std::flush; +- } +- } +-#endif // ANDROID +- if (severity_ == FATAL) { +- DealWithFatal(); +- } +-} +- + } // namespace c10 + +-#endif // !C10_USE_GLOG +diff --git a/c10/util/Logging.h b/c10/util/Logging.h +index 610dc33a9a..b4b644defa 100644 +--- a/c10/util/Logging.h ++++ b/c10/util/Logging.h +@@ -21,14 +21,8 @@ + #define CAFFE2_LOG_THRESHOLD INT_MIN + #endif // CAFFE2_LOG_THRESHOLD + +-// Below are different implementations for glog and non-glog cases. +-#ifdef C10_USE_GLOG +-#include "c10/util/logging_is_google_glog.h" +-#else // !C10_USE_GLOG + #include "c10/util/logging_is_not_google_glog.h" +-#endif // C10_USE_GLOG + +-C10_DECLARE_int(caffe2_log_level); + C10_DECLARE_bool(caffe2_use_fatal_for_enforce); + + // Some versions of GLOG support less-spammy version of LOG_EVERY_MS. 
If it's +@@ -63,11 +57,7 @@ C10_API C10_NORETURN void ThrowEnforceNotMet( + const void* caller = nullptr); + + constexpr bool IsUsingGoogleLogging() { +-#ifdef C10_USE_GLOG +- return true; +-#else + return false; +-#endif + } + + /** +diff --git a/c10/util/logging_is_not_google_glog.h b/c10/util/logging_is_not_google_glog.h +index 50b0c0d007..b5888afc59 100644 +--- a/c10/util/logging_is_not_google_glog.h ++++ b/c10/util/logging_is_not_google_glog.h +@@ -1,240 +1,93 @@ + #ifndef C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ + #define C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ + +-#include +-#include +-#include +-#include + #include +-#include +-#include +-#include +-#include +-#include ++#include ++#include "caffe/llogging.h" + + #include "c10/util/Flags.h" ++#include "c10/util/ArrayRef.h" ++#include "c10/util/typeid.h" ++#include "c10/core/DeviceType.h" ++ ++ ++#ifndef VLOG ++ ++#define DEBUG "none" ++static const std::string VLogLevels[4] = {DEBUG, DEBUG, DEBUG, DEBUG}; ++#define VLOG(n) LOG(VLogLevels[n]) + +-// Log severity level constants. +-const int FATAL = 3; +-#if !defined(_MSC_VER) || !defined(ERROR) +-// Windows defines the ERROR macro already, and as a result we will +-// simply use that one. The downside is that one will now mix LOG(INFO) +-// and LOG(ERROR) because ERROR is defined to be zero. Anyway, the +-// recommended way is to use glog so fixing this is a low-pri item. +-const int ERROR = 2; + #endif +-const int WARNING = 1; +-const int INFO = 0; +-const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV"; +- +-namespace c10 { +-class C10_API MessageLogger { +- public: +- MessageLogger(const char* file, int line, int severity); +- ~MessageLogger(); +- // Return the stream associated with the logger object. +- std::stringstream& stream() { +- return stream_; +- } + +- private: +- // When there is a fatal log, we simply abort. +- void DealWithFatal() { +- abort(); +- } ++#define LOG_AS_STRING(l, o) \ ++ std::stringstream ss; \ ++ ss << o; \ ++ return l << ss.str(); + +- const char* tag_; +- std::stringstream stream_; +- int severity_; +-}; +- +-// This class is used to explicitly ignore values in the conditional +-// logging macros. This avoids compiler warnings like "value computed +-// is not used" and "statement has no effect". +-class C10_API LoggerVoidify { +- public: +- LoggerVoidify() {} +- // This has to be an operator with a precedence lower than << but +- // higher than ?: +- void operator&(const std::ostream& s) {} +-}; +- +-// Log a message and terminate. +-template +-void LogMessageFatal(const char* file, int line, const T& message) { +- MessageLogger(file, line, FATAL).stream() << message; +-} + +-// Helpers for CHECK_NOTNULL(). Two are necessary to support both raw pointers +-// and smart pointers. 
+-template +-T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) { +- if (t == nullptr) { +- LogMessageFatal(file, line, std::string(names)); +- } +- return t; +-} + +-template +-T* CheckNotNull(const char* file, int line, const char* names, T* t) { +- return CheckNotNullCommon(file, line, names, t); ++inline CaffeLogger &operator<<(CaffeLogger &out, const std::_Setprecision &o) { ++ LOG_AS_STRING(out, o); + } +- +-template +-T& CheckNotNull(const char* file, int line, const char* names, T& t) { +- return CheckNotNullCommon(file, line, names, t); ++inline CaffeLogger &operator<<(CaffeLogger &out, std::ostream&(*o)(std::ostream&)) { ++ LOG_AS_STRING(out, o); ++} ++inline CaffeLogger &operator<<(const CaffeLogger &out, const void *o) { ++ LOG_AS_STRING(out, o); ++} ++inline CaffeLogger &operator<<(CaffeLogger &out, const c10::DeviceType &o) { ++ LOG_AS_STRING(out, o); ++} ++inline CaffeLogger &operator<<(CaffeLogger &out, c10::IntArrayRef &o) { ++ LOG_AS_STRING(out, o); ++} ++inline CaffeLogger &operator<<(CaffeLogger &out, const caffe2::TypeMeta &o) { ++ LOG_AS_STRING(out, o); ++} ++inline CaffeLogger &operator<<(CaffeLogger &out, const google::protobuf::Message &o) { ++ return out << o.SerializeAsString(); + } +-} // namespace c10 +- +-// ---------------------- Logging Macro definitions -------------------------- +- +-static_assert( +- CAFFE2_LOG_THRESHOLD <= FATAL, +- "CAFFE2_LOG_THRESHOLD should at most be FATAL."); +-// If n is under the compile time caffe log threshold, The _CAFFE_LOG(n) +-// should not generate anything in optimized code. +-#define LOG(n) \ +- if (n >= CAFFE2_LOG_THRESHOLD) \ +- ::c10::MessageLogger((char*)__FILE__, __LINE__, n).stream() +-#define VLOG(n) LOG((-n)) +- +-#define LOG_IF(n, condition) \ +- if (n >= CAFFE2_LOG_THRESHOLD && (condition)) \ +- ::c10::MessageLogger((char*)__FILE__, __LINE__, n).stream() +-#define VLOG_IF(n, condition) LOG_IF((-n), (condition)) +- +-#define VLOG_IS_ON(verboselevel) (CAFFE2_LOG_THRESHOLD <= -(verboselevel)) +- +-// Log only if condition is met. Otherwise evaluates to void. +-#define FATAL_IF(condition) \ +- condition ? (void)0 \ +- : ::c10::LoggerVoidify() & \ +- ::c10::MessageLogger((char*)__FILE__, __LINE__, FATAL).stream() +- +-// Check for a given boolean condition. +-#define CHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " " +- +-#ifndef NDEBUG +-// Debug only version of CHECK +-#define DCHECK(condition) FATAL_IF(condition) << "Check failed: " #condition " " +-#else +-// Optimized version - generates no code. +-#define DCHECK(condition) \ +- while (false) \ +- CHECK(condition) +-#endif // NDEBUG +- +-#define CHECK_OP(val1, val2, op) \ +- FATAL_IF((val1 op val2)) << "Check failed: " #val1 " " #op " " #val2 " " +- +-// Check_op macro definitions +-#define CHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==) +-#define CHECK_NE(val1, val2) CHECK_OP(val1, val2, !=) +-#define CHECK_LE(val1, val2) CHECK_OP(val1, val2, <=) +-#define CHECK_LT(val1, val2) CHECK_OP(val1, val2, <) +-#define CHECK_GE(val1, val2) CHECK_OP(val1, val2, >=) +-#define CHECK_GT(val1, val2) CHECK_OP(val1, val2, >) +- +-#ifndef NDEBUG +-// Debug only versions of CHECK_OP macros. 
+-#define DCHECK_EQ(val1, val2) CHECK_OP(val1, val2, ==) +-#define DCHECK_NE(val1, val2) CHECK_OP(val1, val2, !=) +-#define DCHECK_LE(val1, val2) CHECK_OP(val1, val2, <=) +-#define DCHECK_LT(val1, val2) CHECK_OP(val1, val2, <) +-#define DCHECK_GE(val1, val2) CHECK_OP(val1, val2, >=) +-#define DCHECK_GT(val1, val2) CHECK_OP(val1, val2, >) +-#else // !NDEBUG +-// These versions generate no code in optimized mode. +-#define DCHECK_EQ(val1, val2) \ +- while (false) \ +- CHECK_OP(val1, val2, ==) +-#define DCHECK_NE(val1, val2) \ +- while (false) \ +- CHECK_OP(val1, val2, !=) +-#define DCHECK_LE(val1, val2) \ +- while (false) \ +- CHECK_OP(val1, val2, <=) +-#define DCHECK_LT(val1, val2) \ +- while (false) \ +- CHECK_OP(val1, val2, <) +-#define DCHECK_GE(val1, val2) \ +- while (false) \ +- CHECK_OP(val1, val2, >=) +-#define DCHECK_GT(val1, val2) \ +- while (false) \ +- CHECK_OP(val1, val2, >) +-#endif // NDEBUG +- +-// Check that a pointer is not null. +-#define CHECK_NOTNULL(val) \ +- ::c10::CheckNotNull( \ +- __FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val)) +- +-#ifndef NDEBUG +-// Debug only version of CHECK_NOTNULL +-#define DCHECK_NOTNULL(val) \ +- ::c10::CheckNotNull( \ +- __FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val)) +-#else // !NDEBUG +-// Optimized version - generates no code. +-#define DCHECK_NOTNULL(val) \ +- while (false) \ +- CHECK_NOTNULL(val) +-#endif // NDEBUG +- +-// ---------------------- Support for std objects -------------------------- +-// These are adapted from glog to support a limited set of logging capability +-// for STL objects. + +-namespace std { +-// Forward declare these two, and define them after all the container streams +-// operators so that we can recurse from pair -> container -> container -> pair +-// properly. 
+-template +-std::ostream& operator<<(std::ostream& out, const std::pair& p); +-} // namespace std + +-namespace c10 { +-template +-void PrintSequence(std::ostream& ss, Iter begin, Iter end); +-} // namespace c10 + + namespace std { +-#define INSTANTIATE_FOR_CONTAINER(container) \ +- template \ +- std::ostream& operator<<( \ +- std::ostream& out, const container& seq) { \ +- c10::PrintSequence(out, seq.begin(), seq.end()); \ +- return out; \ ++ ++#define LOG_PAIR(s) \ ++ template \ ++ inline s& operator<<(s& out, const std::pair& p) { \ ++ return out << '(' << p.first << ", " << p.second << ')'; \ + } + +-INSTANTIATE_FOR_CONTAINER(std::vector) +-INSTANTIATE_FOR_CONTAINER(std::map) +-INSTANTIATE_FOR_CONTAINER(std::set) +-#undef INSTANTIATE_FOR_CONTAINER +- +-template +-inline std::ostream& operator<<( +- std::ostream& out, +- const std::pair& p) { +- out << '(' << p.first << ", " << p.second << ')'; +- return out; +-} ++LOG_PAIR(std::ostream) ++LOG_PAIR(CaffeLogger) ++ ++#define LOG_CONTAINER(s, c) \ ++ template \ ++ s& operator<<(s& out, const c& seq) { \ ++ int i = 0; \ ++ for (auto it = seq.begin(); it != seq.end(); ++it) { \ ++ if (i++) { \ ++ out << ' '; \ ++ } \ ++ if (i > 100) { \ ++ return out << "..."; \ ++ } \ ++ out << *it; \ ++ } \ ++ return out; \ ++ } ++ ++ ++ LOG_CONTAINER(std::ostream, std::vector) ++ LOG_CONTAINER(std::ostream, std::map) ++ LOG_CONTAINER(std::ostream, std::set) ++ ++ LOG_CONTAINER(CaffeLogger, std::vector) ++ LOG_CONTAINER(CaffeLogger, std::map) ++ LOG_CONTAINER(CaffeLogger, std::set) ++ + } // namespace std + +-namespace c10 { +-template +-inline void PrintSequence(std::ostream& out, Iter begin, Iter end) { +- // Output at most 100 elements -- appropriate if used for logging. +- for (int i = 0; begin != end && i < 100; ++i, ++begin) { +- if (i > 0) +- out << ' '; +- out << *begin; +- } +- if (begin != end) { +- out << " ..."; +- } +-} +-} // namespace c10 ++#define VLOG_IS_ON(n) true + + #endif // C10_UTIL_LOGGING_IS_NOT_GOOGLE_GLOG_H_ +diff --git a/caffe/llogging.h b/caffe/llogging.h +new file mode 100644 +index 0000000000..f92198a9cf +--- /dev/null ++++ b/caffe/llogging.h +@@ -0,0 +1,258 @@ ++/** ++ * Author: Emmanuel Benazera ++ */ ++ ++#ifndef LLOGGING_H ++#define LLOGGING_H ++ ++#include ++#include ++#include ++ ++class DateLogger { ++ public: ++ DateLogger() { ++#if defined(_MSC_VER) ++ _tzset(); ++#endif ++ } ++ const char* HumanDate() { ++#if defined(_MSC_VER) ++ _strtime_s(buffer_, sizeof(buffer_)); ++#else ++ time_t time_value = time(NULL); ++ struct tm *pnow; ++#if !defined(_WIN32) ++ struct tm now; ++ pnow = localtime_r(&time_value, &now); ++#else ++ pnow = localtime(&time_value); // NOLINT(*) ++#endif ++ snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", ++ pnow->tm_hour, pnow->tm_min, pnow->tm_sec); ++#endif ++ return buffer_; ++ } ++ ++ private: ++ char buffer_[9]; ++}; ++ ++// avoid fatal checks from glog ++#define CAFFE_THROW_ON_ERROR ++ ++// make sure we erase definitions by glog if any ++#undef LOG ++#undef LOG_IF ++#undef CHECK ++#undef CHECK_OP_LOG ++#undef CHECK_EQ ++#undef CHECK_LT ++#undef CHECK_GT ++#undef CHECK_LE ++#undef CHECK_GE ++#undef CHECK_EQ ++#undef CHECK_NE ++#undef CHECK_OP_LOG ++#undef CHECK_NOTNULL ++#undef DCHECK ++#undef DCHECK_LT ++#undef DCHECK_GT ++#undef DCHECK_LE ++#undef DCHECK_GE ++#undef DCHECK_EQ ++#undef DCHECK_NE ++#undef DLOG ++#undef DFATAL ++#undef LOG_DFATAL ++#undef LOG_EVERY_N ++ ++#ifdef CAFFE_THROW_ON_ERROR ++#include ++#define SSTR( x ) dynamic_cast< std::ostringstream & >( \ ++ ( 
std::ostringstream() << std::dec << x ) ).str() ++class CaffeErrorException : public std::exception ++{ ++public: ++ CaffeErrorException(const std::string &s):_s(s) {} ++ ~CaffeErrorException() throw() {} ++ const char* what() const throw() { return _s.c_str(); } ++ std::string _s; ++}; ++ ++static std::string INFO="INFO"; ++static std::string WARNING="WARNING"; ++static std::string ERROR="ERROR"; ++static std::string FATAL="FATAL"; ++ ++#define GLOG_NO_ABBREVIATED_SEVERITIES ++ ++#define INFO INFO ++#define WARNING WARNING ++#define ERROR ERROR ++#define FATAL FATAL ++ ++static std::ostream nullstream(0); ++ ++#define CHECK(condition) \ ++ if (!(condition)) \ ++ throw CaffeErrorException(std::string(__FILE__) + ":" + SSTR(__LINE__) + " / Check failed (custom): " #condition ""); \ ++ nullstream \ ++ << "Check failed (custom): " #condition " " ++ ++#define CHECK_LT(x, y) CHECK((x) < (y)) ++#define CHECK_GT(x, y) CHECK((x) > (y)) ++#define CHECK_LE(x, y) CHECK((x) <= (y)) ++#define CHECK_GE(x, y) CHECK((x) >= (y)) ++#define CHECK_EQ(x, y) CHECK((x) == (y)) ++#define CHECK_NE(x, y) CHECK((x) != (y)) ++ ++#define CHECK_OP_LOG(name, op, val1, val2, log) CHECK((val1) op (val2)) ++/* #ifdef DEBUG */ ++/* #define CHECK_EQ(val1,val2) if (0) std::cerr */ ++/* #endif */ ++#endif ++ ++#define CHECK_NOTNULL(x) \ ++ ((x) == NULL ? LOG(FATAL) << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) ++ ++#ifdef NDEBUG ++#define DCHECK(x) \ ++ while (false) CHECK(x) ++#define DCHECK_LT(x, y) \ ++ while (false) CHECK((x) < (y)) ++#define DCHECK_GT(x, y) \ ++ while (false) CHECK((x) > (y)) ++#define DCHECK_LE(x, y) \ ++ while (false) CHECK((x) <= (y)) ++#define DCHECK_GE(x, y) \ ++ while (false) CHECK((x) >= (y)) ++#define DCHECK_EQ(x, y) \ ++ while (false) CHECK((x) == (y)) ++#define DCHECK_NE(x, y) \ ++ while (false) CHECK((x) != (y)) ++#else ++#define DCHECK(x) CHECK(x) ++#define DCHECK_LT(x, y) CHECK((x) < (y)) ++#define DCHECK_GT(x, y) CHECK((x) > (y)) ++#define DCHECK_LE(x, y) CHECK((x) <= (y)) ++#define DCHECK_GE(x, y) CHECK((x) >= (y)) ++#define DCHECK_EQ(x, y) CHECK((x) == (y)) ++#define DCHECK_NE(x, y) CHECK((x) != (y)) ++#endif // NDEBUG ++ ++class CaffeLogger ++{ ++ public: ++ CaffeLogger(const std::string &severity) ++ :_severity(severity) ++ { ++ _console = spdlog::get("caffe"); ++ if (!_console) ++#ifdef USE_SYSLOG ++ _console = spdlog::syslog_logger("caffe"); ++#else ++ _console = spdlog::stdout_logger_mt("caffe"); ++#endif ++ } ++ ++ ~CaffeLogger() ++ { ++ if (_severity == "none" || _str.empty()) // ignore ++ {} ++ else if (_severity == INFO) ++ _console->info(_str); ++ else if (_severity == WARNING) ++ _console->warn(_str); ++ else if (_severity == ERROR) ++ _console->error(_str); ++ } ++ ++ friend CaffeLogger& operator<<(const CaffeLogger &cl, const std::string &rstr) ++ { ++ std::string str = rstr; ++ const_cast(cl)._str += str; ++ return const_cast(cl); ++ } ++ ++ friend CaffeLogger& operator<<(const CaffeLogger &cl, const double &d) ++ { ++ std::string str = std::to_string(d); ++ boost::trim_right_if(str,boost::is_any_of("\n")); ++ const_cast(cl)._str += str; ++ return const_cast(cl); ++ } ++ ++ friend CaffeLogger& operator<<(const CaffeLogger &cl, const std::ostream &out) ++ { ++ std::stringstream sstr; ++ sstr << out.rdbuf(); ++ const_cast(cl)._str += sstr.str(); ++ return const_cast(cl); ++ } ++ ++ std::string _severity = INFO; ++ std::shared_ptr _console; ++ std::string _str; ++}; ++ ++inline CaffeLogger LOG(const std::string &severity) ++{ ++ if (severity != FATAL) ++ { ++ return 
CaffeLogger(severity); ++ } ++ else ++ { ++ throw CaffeErrorException(std::string(__FILE__) + ":" + SSTR(__LINE__) + " / Fatal Caffe error"); // XXX: cannot report the exact location of the trigger... ++ } ++} ++ ++inline CaffeLogger LOG_IF(const std::string &severity,const bool &condition) ++{ ++ if (condition) ++ return LOG(severity); ++ else return CaffeLogger("none"); ++} ++ ++#ifdef NDEBUG ++inline CaffeLogger DFATAL(const std::string &severity) ++{ ++ (void)severity; ++ return CaffeLogger("none"); ++} ++inline CaffeLogger LOG_DFATAL(const std::string &severity) ++{ ++ (void)severity; ++ return CaffeLogger("none"); ++} ++inline CaffeLogger DLOG(const std::string &severity) ++{ ++ (void)severity; ++ return CaffeLogger("none"); ++} ++#else ++inline CaffeLogger DFATAL(const std::string &severity) ++{ ++ (void)severity; ++ return LOG(FATAL); ++} ++inline CaffeLogger LOG_DFATAL(const std::string &severity) ++{ ++ (void)severity; ++ return LOG(FATAL); ++} ++inline CaffeLogger DLOG(const std::string &severity) ++{ ++ return LOG(severity); ++} ++#endif ++ ++// Poor man's version... ++inline CaffeLogger LOG_EVERY_N(const std::string &severity, const int &n) ++{ ++ (void)n; ++ return LOG(severity); ++} ++ ++#endif diff --git a/src/backends/torch/torchinputconns.cc b/src/backends/torch/torchinputconns.cc index b5950f5ac..31d351085 100644 --- a/src/backends/torch/torchinputconns.cc +++ b/src/backends/torch/torchinputconns.cc @@ -2,7 +2,107 @@ namespace dd { +using namespace torch; + +// ===== TorchDataset + +void TorchDataset::add_batch(std::vector data, std::vector target) +{ + _batches.push_back(TorchBatch(data, target)); +} + +void TorchDataset::reset() +{ + _indices.clear(); + + for (int64_t i = 0; i < _batches.size(); ++i) { + _indices.push_back(i); + } + + if (_shuffle) + { + auto seed = _seed == -1 ? static_cast(time(NULL)) : _seed; + std::shuffle(_indices.begin(), _indices.end(), std::mt19937(seed)); + } +} + +// `request` holds the size of the batch +// Data selection and batch construction are done in this method +c10::optional TorchDataset::get_batch(BatchRequestType request) +{ + size_t count = request[0]; + count = count < _indices.size() ? 
count : _indices.size(); + + if (count == 0) { + return torch::nullopt; + } + + std::vector> data, target; + + while(count != 0) { + auto id = _indices.back(); + auto entry = _batches[id]; + + for (int i = 0; i < entry.data.size(); ++i) + { + while (i >= data.size()) + data.emplace_back(); + data[i].push_back(entry.data.at(i)); + } + for (int i = 0; i < entry.target.size(); ++i) + { + while (i >= target.size()) + target.emplace_back(); + target[i].push_back(entry.target.at(i)); + } + + _indices.pop_back(); + count--; + } + + std::vector data_tensors; + for (auto vec : data) + data_tensors.push_back(torch::stack(vec)); + + std::vector target_tensors; + for (auto vec : target) + target_tensors.push_back(torch::stack(vec)); + + return TorchBatch{ data_tensors, target_tensors }; +} + +TorchBatch TorchDataset::get_cached() { + reset(); + auto batch = get_batch({cache_size()}); + if (!batch) + throw InputConnectorInternalException("No data provided"); + return batch.value(); +} + +TorchDataset TorchDataset::split(double start, double stop) +{ + auto datasize = _batches.size(); + auto start_it = _batches.begin() + static_cast(datasize * start); + auto stop_it = _batches.end() - static_cast(datasize * (1 - stop)); + + TorchDataset new_dataset; + new_dataset._batches.insert(new_dataset._batches.end(), start_it, stop_it); + return new_dataset; +} + + +// ===== TxtTorchInputFileConn + +void TxtTorchInputFileConn::fillup_parameters(const APIData &ad_input) +{ + TxtInputFileConn::fillup_parameters(ad_input); +} + void TxtTorchInputFileConn::transform(const APIData &ad) { + // if (_finetuning) + // XXX: Generating vocab from scratch is not currently + _generate_vocab = false; + try { TxtInputFileConn::transform(ad); @@ -12,68 +112,174 @@ void TxtTorchInputFileConn::transform(const APIData &ad) { throw; } + if (!_ordered_words || _characters) + throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); + + // XXX: move in txtinputconn? 
+ make_inv_vocab(); + if (ad.has("parameters") && ad.getobj("parameters").has("input")) { APIData ad_input = ad.getobj("parameters").getobj("input"); - if (ad_input.has("width")) - _width = ad_input.get("width").get(); + fillup_parameters(ad_input); + } + + if (_input_format == "bert") + { + _cls_pos = _vocab.at("[CLS]")._pos; + _sep_pos = _vocab.at("[SEP]")._pos; + _unk_pos = _vocab.at("[UNK]")._pos; + _mask_id = _vocab.at("[MASK]")._pos; + } + else if (_input_format == "gpt2") + { + _eot_pos = _vocab.at("<|endoftext|>")._pos; } - if (!_ordered_words || _characters) - throw InputConnectorBadParamException("Need ordered_words = true with backend torch"); + fill_dataset(_dataset, _txt); + if (!_test_txt.empty()) + fill_dataset(_test_dataset, _test_txt); +} - int cls_pos = _vocab.at("[CLS]")._pos; - int sep_pos = _vocab.at("[SEP]")._pos; - int unk_pos = _vocab.at("[UNK]")._pos; +TorchBatch TxtTorchInputFileConn::generate_masked_lm_batch(const TorchBatch &example) +{ + std::uniform_real_distribution uniform(0, 1); + std::uniform_int_distribution vocab_distrib(0, vocab_size() - 1); + Tensor input_ids = example.data.at(0).clone(); + // lm_labels: n_batch * sequence_length + // equals to input_ids where tokens are masked, and -1 otherwise + Tensor lm_labels = torch::ones_like(input_ids, TensorOptions(kLong)) * -1; - std::vector vids; - std::vector vmask; + // mask random tokens + auto input_acc = input_ids.accessor(); + auto att_mask_acc = example.data.at(2).accessor(); + auto labels_acc = lm_labels.accessor(); + for (int i = 0; i < input_ids.size(0); ++i) + { + int j = 1; // skip [CLS] token + while (j < input_ids.size(1) && att_mask_acc[i][j] != 0) + { + double rand_num = uniform(_rng); + if (rand_num < _lm_params._change_prob && input_acc[i][j] != _sep_pos) + { + labels_acc[i][j] = input_acc[i][j]; - for (auto *te : _txt) + rand_num = uniform(_rng); + if (rand_num < _lm_params._mask_prob) + { + input_acc[i][j] = mask_id(); + } + else if (rand_num < _lm_params._mask_prob + _lm_params._rand_prob) + { + input_acc[i][j] = vocab_distrib(_rng); + } + } + ++j; + } + } + + TorchBatch output; + output.target.push_back(lm_labels); + output.data.push_back(input_ids); + for (int i = 1; i < example.data.size(); ++i) + { + output.data.push_back(example.data[i]); + } + return output; +} + +void TxtTorchInputFileConn::fill_dataset(TorchDataset &dataset, + const std::vector*> &entries) +{ + for (auto *te : entries) { TxtOrderedWordsEntry *tow = static_cast(te); tow->reset(); - + std::string word; + double val; std::vector ids; - ids.push_back(cls_pos); while(tow->has_elt()) { - std::string word; - double val; + if (ids.size() >= _width) + break; + tow->get_next_elt(word, val); std::unordered_map::iterator it; - + if ((it = _vocab.find(word)) != _vocab.end()) { ids.push_back(it->second._pos); } - else + else if (_input_format == "bert") { - ids.push_back(unk_pos); + ids.push_back(_unk_pos); } } - ids.push_back(sep_pos); + // Extract last token (needed by gpt2) + int64_t last_token = 0; + if (tow->has_elt()) + { + tow->get_next_elt(word, val); + std::unordered_map::iterator it; + + if ((it = _vocab.find(word)) != _vocab.end()) + last_token = it->second._pos; + } + + // Post-processing for each model + if (_input_format == "bert") + { + // make room for cls and sep token + while (ids.size() > _width - 2) + ids.pop_back(); + + ids.insert(ids.begin(), _cls_pos); + ids.push_back(_sep_pos); + } + else if (_input_format == "gpt2") + { + if (ids.size() < _width) + { + ids.push_back(_eot_pos); + } + } at::Tensor 
ids_tensor = toLongTensor(ids); at::Tensor mask_tensor = torch::ones_like(ids_tensor); - // at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor); + at::Tensor token_type_ids_tensor = torch::zeros_like(ids_tensor); - int64_t padding_size = _width - ids_tensor.sizes().back(); + int64_t seq_len = ids_tensor.sizes().back(); + int64_t padding_size = _width - seq_len; + _lengths.push_back(seq_len); ids_tensor = torch::constant_pad_nd( ids_tensor, at::IntList{0, padding_size}, 0); mask_tensor = torch::constant_pad_nd( mask_tensor, at::IntList{0, padding_size}, 0); - // token_type_ids_tensor = torch::constant_pad_nd( - // token_type_ids_tensor, at::IntList{0, padding_size}, 0); + token_type_ids_tensor = torch::constant_pad_nd( + token_type_ids_tensor, at::IntList{0, padding_size}, 0); + at::Tensor position_ids = torch::arange(_width, at::kLong); - vids.push_back(ids_tensor); - vmask.push_back(mask_tensor); - } + std::vector target_vec; + int target_val = static_cast(tow->_target); + + if (target_val != -1) + { + Tensor target_tensor = torch::full(1, target_val, torch::kLong); + target_vec.push_back(target_tensor); + } - _in = torch::stack(vids, 0); - _attention_mask = torch::stack(vmask, 0); + if (_input_format == "bert") + dataset.add_batch({ids_tensor, token_type_ids_tensor, mask_tensor}, std::move(target_vec)); + else if (_input_format == "gpt2") + { + std::vector out_vec { ids_tensor.slice(0, 1) }; + out_vec.push_back(torch::full(1, last_token, torch::kLong)); + target_vec.insert(target_vec.begin(), torch::cat(out_vec, 0)); + dataset.add_batch({ids_tensor, position_ids}, std::move(target_vec)); + } + } } -} \ No newline at end of file +} diff --git a/src/backends/torch/torchinputconns.h b/src/backends/torch/torchinputconns.h index 964dac69a..524673c73 100644 --- a/src/backends/torch/torchinputconns.h +++ b/src/backends/torch/torchinputconns.h @@ -31,22 +31,88 @@ namespace dd { + typedef torch::data::Example, std::vector> TorchBatch; + + class TorchDataset : public torch::data::BatchDataset + > + { + private: + bool _shuffle = false; + long _seed = -1; + std::vector _indices; + + public: + /// Vector containing the whole dataset (the "cached data"). + std::vector _batches; + + + TorchDataset() {} + + void add_batch(std::vector data, std::vector target = {}); + + void reset(); + + /// Size of data loaded in memory + size_t cache_size() const { return _batches.size(); } + + c10::optional size() const override { + return cache_size(); + } + + bool empty() const { return cache_size() == 0; } + + c10::optional get_batch(BatchRequestType request) override; + + /// Returns a batch containing all the cached data + TorchBatch get_cached(); + + /// Split a percentage of this dataset + TorchDataset split(double start, double stop); + }; + + + struct MaskedLMParams + { + double _change_prob = 0.15; /**< When masked LM learning, probability of changing a token (mask/randomize/keep). */ + double _mask_prob = 0.8; /**< When masked LM learning, probability of masking a token. */ + double _rand_prob = 0.1; /**< When masked LM learning, probability of randomizing a token. 
*/ + }; + + class TorchInputInterface { public: TorchInputInterface() {} TorchInputInterface(const TorchInputInterface &i) - : _in(i._in), _attention_mask(i._attention_mask) {} + : _finetuning(i._finetuning), + _lm_params(i._lm_params), + _dataset(i._dataset), + _test_dataset(i._test_dataset), + _input_format(i._input_format) { } ~TorchInputInterface() {} torch::Tensor toLongTensor(std::vector &values) { int64_t val_size = values.size(); - return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong); + return torch::from_blob(&values[0], at::IntList{val_size}, at::kLong).clone(); } - at::Tensor _in; - at::Tensor _attention_mask; + TorchBatch generate_masked_lm_batch(const TorchBatch &example) { return {}; } + + int64_t mask_id() const { return 0; } + int64_t vocab_size() const { return 0; } + std::string get_word(int64_t id) const { return ""; } + + + TorchDataset _dataset; + TorchDataset _test_dataset; + + MaskedLMParams _lm_params; + bool _finetuning; + /** Tell which inputs should be provided to the models. + * see*/ + std::string _input_format; + std::vector _lengths;/**< length of each sentence with txt connector. */ }; class ImgTorchInputFileConn : public ImgInputFileConn, public TorchInputInterface @@ -74,7 +140,7 @@ namespace dd { ImgInputFileConn::init(ad); } - + void transform(const APIData &ad) { try @@ -85,29 +151,28 @@ namespace dd { throw; } - + std::vector tensors; std::vector sizes{ _height, _width, 3 }; at::TensorOptions options(at::ScalarType::Byte); for (const cv::Mat &bgr : this->_images) { - at::Tensor imgt = torch::from_blob(bgr.data, at::IntList(sizes), options); - imgt = imgt.toType(at::kFloat).permute({2, 0, 1}); - size_t nchannels = imgt.size(0); - if (_scale != 1.0) - imgt = imgt.mul(_scale); - if (!_mean.empty() && _mean.size() != nchannels) - throw InputConnectorBadParamException("mean vector be of size the number of channels (" + std::to_string(nchannels) + ")"); - for (size_t m=0;m<_mean.size();m++) - imgt[0][m] = imgt[0][m].sub_(_mean.at(m)); - if (!_std.empty() && _std.size() != nchannels) - throw InputConnectorBadParamException("std vector be of size the number of channels (" + std::to_string(nchannels) + ")"); - for (size_t s=0;s<_std.size();s++) - imgt[0][s] = imgt[0][s].div_(_std.at(s)); - tensors.push_back(imgt); + at::Tensor imgt = torch::from_blob(bgr.data, at::IntList(sizes), options); + imgt = imgt.toType(at::kFloat).permute({2, 0, 1}); + size_t nchannels = imgt.size(0); + if (_scale != 1.0) + imgt = imgt.mul(_scale); + if (!_mean.empty() && _mean.size() != nchannels) + throw InputConnectorBadParamException("mean vector be of size the number of channels (" + std::to_string(nchannels) + ")"); + for (size_t m=0;m<_mean.size();m++) + imgt[0][m] = imgt[0][m].sub_(_mean.at(m)); + if (!_std.empty() && _std.size() != nchannels) + throw InputConnectorBadParamException("std vector be of size the number of channels (" + std::to_string(nchannels) + ")"); + for (size_t s=0;s<_std.size();s++) + imgt[0][s] = imgt[0][s].div_(_std.at(s)); + tensors.push_back(imgt); + _dataset.add_batch({imgt}); } - - _in = torch::stack(tensors, 0); } public: @@ -127,6 +192,14 @@ namespace dd _width(i._width), _height(i._height) {} ~TxtTorchInputFileConn() {} + void init(const APIData &ad) + { + TxtInputFileConn::init(ad); + fillup_parameters(ad); + } + + void fillup_parameters(const APIData &ad_input); + // for API info only int width() const { @@ -139,11 +212,42 @@ namespace dd return _height; } + int64_t mask_id() const { return _mask_id; } + + int64_t vocab_size() 
const { return _vocab.size(); } + + std::string get_word(int64_t id) const { + return _inv_vocab.at(id); + } + void transform(const APIData &ad); + TorchBatch generate_masked_lm_batch(const TorchBatch &example); + + void fill_dataset(TorchDataset &dataset, const std::vector*> &entries); public: + /** width of the input tensor */ int _width = 512; int _height = 0; + std::mt19937 _rng; + /// token id to vocabulary word + std::map _inv_vocab; + + int64_t _mask_id = -1; /**< ID of mask token in the vocabulary. */ + int64_t _cls_pos = -1; + int64_t _sep_pos = -1; + int64_t _unk_pos = -1; + int64_t _eot_pos = -1; /**< end of text */ + + + void make_inv_vocab() { + _inv_vocab.clear(); + + for (auto &entry : _vocab) + { + _inv_vocab[entry.second._pos] = entry.first; + } + } }; } // namespace dd diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc index 3e9099892..65a1f0eb7 100644 --- a/src/backends/torch/torchlib.cc +++ b/src/backends/torch/torchlib.cc @@ -22,6 +22,9 @@ #include "torchlib.h" #include +#if !defined(CPU_ONLY) +#include +#endif #include "outputconnectorstrategy.h" @@ -29,9 +32,157 @@ using namespace torch; namespace dd { + inline void empty_cuda_cache() { + #if !defined(CPU_ONLY) + c10::cuda::CUDACachingAllocator::emptyCache(); + #endif + } + + void add_parameters(std::shared_ptr module, std::vector ¶ms, bool requires_grad = true) { + for (const auto &slot : module->get_parameters()) { + Tensor tensor = slot.value().toTensor(); + if (tensor.requires_grad() && requires_grad) + params.push_back(tensor); + } + for (auto child : module->get_modules()) { + add_parameters(std::make_shared(child), params); + } + } + + /// Convert IValue to Tensor and throw an exception if the IValue is not a Tensor. + Tensor to_tensor_safe(const IValue &value) { + if (!value.isTensor()) + throw MLLibInternalException("Expected Tensor, found " + value.tagKind()); + return value.toTensor(); + } + + /// Convert id Tensor to one_hot Tensor + void fill_one_hot(Tensor &one_hot, Tensor ids, int nclasses) + { + one_hot.zero_(); + for (int i = 0; i < ids.size(0); ++i) + { + one_hot[i][ids[i].item()] = 1; + } + } + + Tensor to_one_hot(Tensor ids, int nclasses) + { + Tensor one_hot = torch::zeros(IntList{ids.size(0), nclasses}); + for (int i = 0; i < ids.size(0); ++i) + { + one_hot[i][ids[i].item()] = 1; + } + return one_hot; + } + + // ======= TORCH MODULE + + + TorchModule::TorchModule() : _device{"cpu"} {} + + c10::IValue TorchModule::forward(std::vector source) + { + if (_traced) + { + auto output = _traced->forward(source); + if (output.isTensorList()) { + auto elems = output.toTensorList(); + source = std::vector(elems.begin(), elems.end()); + } + else if (output.isTuple()) { + auto &elems = output.toTuple()->elements(); + source = std::vector(elems.begin(), elems.end()); + } + else { + source = { output }; + } + } + c10::IValue out_val = source.at(_classif_in); + if (_hidden_states) + { + // out_val is a tuple containing tensors of dimension n_batch * sequence_lenght * n_features + // We want a tensor of size n_batch * n_features from the last hidden state + auto &elems = out_val.toTuple()->elements(); + out_val = elems.back().toTensor().slice(1, 0, 1).squeeze(1); + } + if (_classif) + { + out_val = _classif->forward(to_tensor_safe(out_val)); + } + return out_val; + } + + void TorchModule::freeze_traced(bool freeze) + { + if (freeze != _freeze_traced) + { + _freeze_traced = freeze; + std::vector params; + add_parameters(_traced, params, false); + for (auto ¶m : params) + { + 
param.set_requires_grad(!freeze); + } + } + } + + std::vector TorchModule::parameters() + { + std::vector params; + if (_traced) + add_parameters(_traced, params); + if (_classif) + { + auto classif_params = _classif->parameters(); + params.insert(params.end(), classif_params.begin(), classif_params.end()); + } + return params; + } + + void TorchModule::save_checkpoint(TorchModel &model, const std::string &name) + { + if (_traced) + _traced->save(model._repo + "/checkpoint-" + name + ".pt"); + if (_classif) + torch::save(_classif, model._repo + "/checkpoint-" + name + ".ptw"); + } + + void TorchModule::load(TorchModel &model) + { + if (!model._traced.empty()) + _traced = std::make_shared + (torch::jit::load(model._traced, _device)); + if (!model._weights.empty() && _classif) + torch::load(_classif, model._weights); + } + + void TorchModule::eval() { + if (_traced) + _traced->eval(); + if (_classif) + _classif->eval(); + } + + void TorchModule::train() { + if (_traced) + _traced->train(); + if (_classif) + _classif->train(); + } + + void TorchModule::free() + { + _traced = nullptr; + _classif = nullptr; + } + + + // ======= TORCHLIB + template TorchLib::TorchLib(const TorchModel &tmodel) - : MLLib(tmodel) + : MLLib(tmodel) { this->_libname = "torch"; } @@ -41,62 +192,376 @@ namespace dd : MLLib(std::move(tl)) { this->_libname = "torch"; - _traced = std::move(tl._traced); + _module = std::move(tl._module); + _template = tl._template; _nclasses = tl._nclasses; _device = tl._device; - _attention = tl._attention; + _masked_lm = tl._masked_lm; + _seq_training = tl._seq_training; + _finetuning = tl._finetuning; } template - TorchLib::~TorchLib() + TorchLib::~TorchLib() { - + _module.free(); + empty_cuda_cache(); } /*- from mllib -*/ template void TorchLib::init_mllib(const APIData &lib_ad) { + bool classification = false; bool gpu = false; int gpuid = -1; + bool freeze_traced = false; + int embedding_size = 768; + std::string self_supervised = ""; - if (lib_ad.has("gpu")) { + if (lib_ad.has("template")) + _template = lib_ad.get("template").get(); + if (lib_ad.has("gpu")) gpu = lib_ad.get("gpu").get() && torch::cuda::is_available(); - } if (lib_ad.has("gpuid")) gpuid = lib_ad.get("gpuid").get(); - if (lib_ad.has("nclasses")) { + if (lib_ad.has("nclasses")) + { + classification = true; _nclasses = lib_ad.get("nclasses").get(); } + if (lib_ad.has("self_supervised")) + self_supervised = lib_ad.get("self_supervised").get(); + if (lib_ad.has("embedding_size")) + embedding_size = lib_ad.get("embedding_size").get(); + if (lib_ad.has("finetuning")) + _finetuning = lib_ad.get("finetuning").get(); + if (lib_ad.has("freeze_traced")) + freeze_traced = lib_ad.get("freeze_traced").get(); _device = gpu ? 
torch::Device(DeviceType::CUDA, gpuid) : torch::Device(DeviceType::CPU); + _module._device = _device; - if (typeid(TInputConnectorStrategy) == typeid(TxtTorchInputFileConn)) { - _attention = true; + // Create the model + if (this->_mlmodel._traced.empty()) + throw MLLibInternalException("Use of libtorch backend without traced net is not supported yet"); + + this->_inputc._input_format = "bert"; + if (_template == "bert") + { + if (classification) + { + _module._classif = nn::Linear(embedding_size, _nclasses); + _module._classif->to(_device); + _module._hidden_states = true; + _module._classif_in = 1; + } + else if (!self_supervised.empty()) + { + if (self_supervised != "mask") + { + throw MLLibBadParamException("self_supervised"); + } + this->_logger->info("Masked Language model"); + _masked_lm = true; + _seq_training = true; + } + else + { + throw MLLibBadParamException("BERT only supports self-supervised or classification"); + } } + else if (_template == "gpt2") + { + this->_inputc._input_format = "gpt2"; + _seq_training = true; + } + else if (!_template.empty()) + { + throw MLLibBadParamException("template"); + } + + this->_logger->info("Loading ml model from file {}.", this->_mlmodel._traced); + if (!this->_mlmodel._weights.empty()) + this->_logger->info("Loading weights from file {}.", this->_mlmodel._weights); + _module.load(this->_mlmodel); + _module.freeze_traced(freeze_traced); - try - { - _traced = torch::jit::load(this->_mlmodel._model_file, _device); - } - catch (std::exception&) - { - throw MLLibBadParamException("failed loading torch model file " + this->_mlmodel._model_file); - } - - _traced->eval(); + this->_mltype = "classification"; } template - void TorchLib::clear_mllib(const APIData &ad) + void TorchLib::clear_mllib(const APIData &ad) { - (void)ad; + std::vector extensions{".json", ".pt", ".ptw"}; + fileops::remove_directory_files(this->_mlmodel._repo, extensions); + this->_logger->info("Torchlib service cleared"); } template - int TorchLib::train(const APIData &ad, APIData &out) + int TorchLib::train(const APIData &ad, APIData &out) { - (void)out; + this->_tjob_running.store(true); + + TInputConnectorStrategy inputc(this->_inputc); + inputc._train = true; + inputc._finetuning = _finetuning; + + try + { + inputc.transform(ad); + } + catch (...) 
+ { + throw; + } + + APIData ad_mllib = ad.getobj("parameters").getobj("mllib"); + + // solver params + int64_t iterations = 1; + std::string solver_type = "SGD"; + double base_lr = 0.0001; + int64_t batch_size = 1; + int64_t iter_size = 1; + int64_t test_batch_size = 1; + int64_t test_interval = 1; + int64_t save_period = 0; + + // logging parameters + int64_t log_batch_period = 20; + + if (ad_mllib.has("solver")) + { + APIData ad_solver = ad_mllib.getobj("solver"); + if (ad_solver.has("iterations")) + iterations = ad_solver.get("iterations").get(); + if (ad_solver.has("solver_type")) + solver_type = ad_solver.get("solver_type").get(); + if (ad_solver.has("base_lr")) + base_lr = ad_solver.get("base_lr").get(); + if (ad_solver.has("test_interval")) + test_interval = ad_solver.get("test_interval").get(); + if (ad_solver.has("iter_size")) + iter_size = ad_solver.get("iter_size").get(); + if (ad_solver.has("snapshot")) + save_period = ad_solver.get("snapshot").get(); + } + + if (ad_mllib.has("net")) + { + APIData ad_net = ad_mllib.getobj("net"); + if (ad_net.has("batch_size")) + batch_size = ad_net.get("batch_size").get(); + if (ad_net.has("test_batch_size")) + test_batch_size = ad_net.get("test_batch_size").get(); + } + + if (iter_size <= 0) + iter_size = 1; + + // create dataset for evaluation during training + TorchDataset eval_dataset; + if (!inputc._test_dataset.empty()) + { + eval_dataset = inputc._test_dataset; //.split(0, 0.1); + } + + // create solver + std::unique_ptr optimizer; + + if (solver_type == "ADAM") + optimizer = std::unique_ptr( + new optim::Adam(_module.parameters(), optim::AdamOptions(base_lr))); + else if (solver_type == "RMSPROP") + optimizer = std::unique_ptr( + new optim::RMSprop(_module.parameters(), optim::RMSpropOptions(base_lr))); + else if (solver_type == "ADAGRAD") + optimizer = std::unique_ptr( + new optim::Adagrad(_module.parameters(), optim::AdagradOptions(base_lr))); + else + { + if (solver_type != "SGD") + this->_logger->warn("Solver type {} not found, using SGD", solver_type); + optimizer = std::unique_ptr( + new optim::SGD(_module.parameters(), optim::SGDOptions(base_lr))); + } + // reload solver + if (!this->_mlmodel._sstate.empty()) + { + this->_logger->info("Reload solver from {}", this->_mlmodel._sstate); + torch::load(*optimizer, this->_mlmodel._sstate); + } + optimizer->zero_grad(); + _module.train(); + + // create dataloader + auto dataloader = torch::data::make_data_loader( + std::move(inputc._dataset), + data::DataLoaderOptions(batch_size) + ); + + this->_logger->info("Training for {} iterations", iterations); + int it = 0; + int batch_id = 0; + using namespace std::chrono; + + // it is the iteration count (not epoch) + while (it < iterations) + { + if (!this->_tjob_running.load()) + { + break; + } + + double train_loss = 0; + double avg_it_time = 0; + + for (TorchBatch batch : *dataloader) + { + auto tstart = system_clock::now(); + if (_masked_lm) + { + batch = inputc.generate_masked_lm_batch(batch); + } + std::vector in_vals; + for (Tensor tensor : batch.data) + in_vals.push_back(tensor.to(_device)); + Tensor y = batch.target.at(0).to(_device); + + Tensor y_pred; + try + { + y_pred = to_tensor_safe(_module.forward(in_vals)); + } + catch (std::exception &e) + { + throw MLLibInternalException(std::string("Libtorch error:") + e.what()); + } + + // As CrossEntropy is not available (Libtorch 1.1) we use nllloss + log_softmax + Tensor loss; + if (_seq_training) + { + // Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, 
+                // As CrossEntropy is not available (Libtorch 1.1) we use nllloss + log_softmax
+                Tensor loss;
+                if (_seq_training)
+                {
+                    // Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, vocab_size]
+                    // + ignore non-masked tokens (== -1 in target)
+                    loss = torch::nll_loss(
+                        torch::log_softmax(y_pred.view(IntList{-1, y_pred.size(2)}), 1),
+                        y.view(IntList{-1}),
+                        {}, Reduction::Mean, -1
+                    );
+                }
+                else
+                {
+                    loss = torch::nll_loss(torch::log_softmax(y_pred, 1), y.view(IntList{-1}));
+                }
+                if (iter_size > 1)
+                    loss /= iter_size;
+
+                double loss_val = loss.item<double>();
+                train_loss += loss_val;
+                loss.backward();
+                auto tstop = system_clock::now();
+                avg_it_time += duration_cast<milliseconds>(tstop - tstart).count();
+
+                if ((batch_id + 1) % iter_size == 0)
+                {
+                    if (!this->_tjob_running.load())
+                    {
+                        break;
+                    }
+                    optimizer->step();
+                    optimizer->zero_grad();
+                    avg_it_time /= iter_size;
+                    this->add_meas("learning_rate", base_lr);
+                    this->add_meas("iteration", it);
+                    this->add_meas("iter_time", avg_it_time);
+                    this->add_meas("remain_time", avg_it_time * iter_size * (iterations - it) / 1000.0);
+                    this->add_meas("train_loss", train_loss);
+                    this->add_meas_per_iter("learning_rate", base_lr);
+                    this->add_meas_per_iter("train_loss", train_loss);
+                    int64_t elapsed_it = it + 1;
+                    if (log_batch_period != 0 && elapsed_it % log_batch_period == 0)
+                    {
+                        this->_logger->info("Iteration {}/{}: loss is {}", elapsed_it, iterations, train_loss);
+                    }
+                    avg_it_time = 0;
+                    train_loss = 0;
+
+                    if (elapsed_it % test_interval == 0 && elapsed_it != iterations && !eval_dataset.empty())
+                    {
+                        // Free memory
+                        loss = torch::empty(1);
+                        y_pred = torch::empty(1);
+                        y = torch::empty(1);
+                        in_vals.clear();
+
+                        APIData meas_out;
+                        this->_logger->info("Start test");
+                        test(ad, inputc, eval_dataset, test_batch_size, meas_out);
+                        APIData meas_obj = meas_out.getobj("measure");
+                        std::vector<std::string> meas_names = meas_obj.list_keys();
+
+                        for (auto name : meas_names)
+                        {
+                            if (name != "cmdiag" && name != "cmfull" && name != "labels")
+                            {
+                                double mval = meas_obj.get(name).get<double>();
+                                this->_logger->info("{}={}", name, mval);
+                                this->add_meas(name, mval);
+                                this->add_meas_per_iter(name, mval);
+                            }
+                            else if (name == "cmdiag")
+                            {
+                                std::vector<double> mdiag = meas_obj.get(name).get<std::vector<double>>();
+                                std::vector<std::string> cnames;
+                                std::string mdiag_str;
+                                for (size_t i = 0; i < mdiag.size(); i++)
+                                {
+                                    mdiag_str += this->_mlmodel.get_hcorresp(i) + ":" + std::to_string(mdiag.at(i)) + " ";
+                                    this->add_meas_per_iter(name+'_'+this->_mlmodel.get_hcorresp(i), mdiag.at(i));
+                                    cnames.push_back(this->_mlmodel.get_hcorresp(i));
+                                }
+                                this->_logger->info("{}=[{}]", name, mdiag_str);
+                                this->add_meas(name, mdiag, cnames);
+                            }
+                        }
+                    }
+
+                    if ((save_period != 0 && elapsed_it % save_period == 0) || elapsed_it == iterations)
+                    {
+                        this->_logger->info("Saving checkpoint after {} iterations", elapsed_it);
+                        _module.save_checkpoint(this->_mlmodel, std::to_string(elapsed_it));
+                        // Save optimizer
+                        torch::save(*optimizer, this->_mlmodel._repo + "/solver-" + std::to_string(elapsed_it) + ".pt");
+                    }
+                    ++it;
+
+                    if (it >= iterations)
+                        break;
+                }
+
+                ++batch_id;
+            }
+        }
+
+        if (!this->_tjob_running.load())
+        {
+            this->_logger->info("Training job interrupted at iteration {}", it);
+            empty_cuda_cache();
+            return -1;
+        }
+
+        test(ad, inputc, inputc._test_dataset, test_batch_size, out);
+        empty_cuda_cache();
+
+        // Update model after training
+        this->_mlmodel.read_from_repository(this->_logger);
+        this->_mlmodel.read_corresp_file();
+
+        inputc.response_params(out);
+        this->_logger->info("Training done.");
+        return 0;
     }
 
     template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
@@ -114,68 +579,187 @@ namespace dd
         }
         torch::Device cpu("cpu");
-        std::vector<c10::IValue> in_vals;
-        in_vals.push_back(inputc._in.to(_device));
+        _module.eval();
-        if (_attention) {
-            // token_type_ids
-            in_vals.push_back(torch::zeros_like(inputc._in, at::kLong).to(_device));
-            in_vals.push_back(inputc._attention_mask.to(_device));
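+        // Predict has two modes: when parameters.output.measure is set, the (labelled) input is
+        // routed through test() and only measures are returned; otherwise the cached batch built
+        // by the input connector is run through the module and softmax probabilities are reported.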
+        if (output_params.has("measure"))
+        {
+            APIData meas_out;
+            test(ad, inputc, inputc._dataset, 1, meas_out);
+            meas_out.erase("iteration");
+            meas_out.erase("train_loss");
+            out.add("measure", meas_out.getobj("measure"));
+            empty_cuda_cache();
+            return 0;
         }
+        inputc._dataset.reset();
+        TorchBatch batch = inputc._dataset.get_cached();
+
+        std::vector<c10::IValue> in_vals;
+        for (Tensor tensor : batch.data)
+            in_vals.push_back(tensor.to(_device));
         Tensor output;
         try
         {
-            c10::IValue out_val = _traced->forward(in_vals);
-            if (out_val.isTuple()) {
-                out_val = out_val.toTuple()->elements()[0];
+            output = to_tensor_safe(_module.forward(in_vals));
+            if (_template == "gpt2")
+            {
+                // Keep only the prediction for the last token
+                Tensor input_ids = batch.data[0];
+                std::vector<Tensor> outputs;
+                for (int i = 0; i < input_ids.size(0); ++i)
+                {
+                    // output is (n_batch * sequence_length * vocab_size)
+                    // With gpt2, last token is endoftext so we need to take the previous output.
+                    outputs.push_back(output[i][inputc._lengths.at(i) - 2]);
+                }
+                output = torch::stack(outputs);
             }
-            if (!out_val.isTensor()) {
-                throw MLLibInternalException("Model returned an invalid output. Please check your model.");
-            }
-            output = out_val.toTensor().to(at::kFloat);
+            output = torch::softmax(output, 1).to(cpu);
         }
         catch (std::exception &e)
         {
             throw MLLibInternalException(std::string("Libtorch error:") + e.what());
         }
-        output = torch::softmax(output, 1).to(cpu);
-
+        // Output
         std::vector<APIData> results_ads;
-        // classification
-        std::tuple<Tensor, Tensor> sorted_output = output.sort(1, true);
-        auto probs_acc = std::get<0>(sorted_output).accessor<float,2>();
-        auto indices_acc = std::get<1>(sorted_output).accessor<int64_t,2>();
-
-        for (int i = 0; i < output.size(0); ++i)
-        {
-            APIData results_ad;
-            std::vector<double> probs;
-            std::vector<std::string> cats;
-
-            for (int j = 0; j < probs_acc.size(1); ++j)
-            {
-                probs.push_back(probs_acc[i][j]);
-                int index = indices_acc[i][j];
-                cats.push_back(this->_mlmodel.get_hcorresp(index));
-            }
-
-            results_ad.add("uri", inputc._uris.at(results_ads.size()));
-            results_ad.add("loss", 0.0);
-            results_ad.add("cats", cats);
-            results_ad.add("probs", probs);
-
-            results_ads.push_back(results_ad);
-        }
-
-        out.add("nclasses",static_cast<int>(probs_acc.size(1)));
+        if (output_params.has("best"))
+        {
+            const int best_count = output_params.get("best").get<int>();
+            std::tuple<Tensor, Tensor> sorted_output = output.sort(1, true);
+            auto probs_acc = std::get<0>(sorted_output).accessor<float,2>();
+            auto indices_acc = std::get<1>(sorted_output).accessor<int64_t,2>();
+
+            for (int i = 0; i < output.size(0); ++i)
+            {
+                APIData results_ad;
+                std::vector<double> probs;
+                std::vector<std::string> cats;
+
+                for (int j = 0; j < best_count; ++j)
+                {
+                    probs.push_back(probs_acc[i][j]);
+                    int index = indices_acc[i][j];
+                    if (_seq_training)
+                    {
+                        cats.push_back(inputc.get_word(index));
+                    }
+                    else
+                    {
+                        cats.push_back(this->_mlmodel.get_hcorresp(index));
+                    }
+                }
+
+                results_ad.add("uri", inputc._uris.at(results_ads.size()));
+                results_ad.add("loss", 0.0);
+                results_ad.add("cats", cats);
+                results_ad.add("probs", probs);
+                results_ad.add("nclasses", _nclasses);
+
+                results_ads.push_back(results_ad);
+            }
+        }
+
         outputc.add_results(results_ads);
         outputc.finalize(output_params, out, static_cast<MLModel*>(&this->_mlmodel));
         out.add("status", 0);
+        return 0;
+    }
+
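+    /** Run a dataset through the current module and compute the requested measures from the
+     *  softmaxed predictions: one scored entry per labelled example, and for masked LM only
+     *  the masked positions (target != -1), with the vocabulary acting as the class set. */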
ad.getobj("parameters").getobj("output"); + int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses; + + // confusion matrix is irrelevant to masked_lm training + if (_masked_lm && ad_out.has("measure")) + { + auto meas = ad_out.get("measure").get>(); + std::vector::iterator it; + if ((it = std::find(meas.begin(), meas.end(), "cmfull")) != meas.end()) + meas.erase(it); + if ((it = std::find(meas.begin(), meas.end(), "cmdiag")) != meas.end()) + meas.erase(it); + ad_out.add("measure", meas); + } + + auto dataloader = torch::data::make_data_loader( + dataset, + data::DataLoaderOptions(batch_size) + ); + torch::Device cpu("cpu"); + + _module.eval(); + int entry_id = 0; + for (TorchBatch batch : *dataloader) + { + if (_masked_lm) + { + batch = inputc.generate_masked_lm_batch(batch); + } + std::vector in_vals; + for (Tensor tensor : batch.data) + in_vals.push_back(tensor.to(_device)); + + Tensor output; + try + { + output = to_tensor_safe(_module.forward(in_vals)); + } + catch (std::exception &e) + { + throw MLLibInternalException(std::string("Libtorch error:") + e.what()); + } + + if (batch.target.empty()) + throw MLLibBadParamException("Missing label on data while testing"); + Tensor labels = batch.target[0].view(IntList{-1}); + if (_masked_lm) + { + // Convert [n_batch, sequence_length, vocab_size] to [n_batch * sequence_length, vocab_size] + output = output.view(IntList{-1, output.size(2)}); + } + output = torch::softmax(output, 1).to(cpu); + auto output_acc = output.accessor(); + auto labels_acc = labels.accessor(); + + for (int j = 0; j < labels.size(0); ++j) + { + if (_masked_lm && labels_acc[j] == -1) + continue; + + APIData bad; + std::vector predictions; + for (int c = 0; c < nclasses; c++) + { + predictions.push_back(output_acc[j][c]); + } + bad.add("target", static_cast(labels_acc[j])); + bad.add("pred", predictions); + ad_res.add(std::to_string(entry_id), bad); + ++entry_id; + } + // this->_logger->info("Testing: {}/{} entries processed", entry_id, test_size); + } + ad_res.add("iteration",this->get_meas("iteration")); + ad_res.add("train_loss",this->get_meas("train_loss")); + std::vector clnames; + for (int i=0;i< nclasses;i++) + clnames.push_back(this->_mlmodel.get_hcorresp(i)); + ad_res.add("clnames", clnames); + ad_res.add("nclasses", nclasses); + ad_res.add("batch_size", entry_id); // here batch_size = tested entries count + SupervisedOutput::measure(ad_res, ad_out, out); return 0; } diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h index ddb6809b5..441cd069a 100644 --- a/src/backends/torch/torchlib.h +++ b/src/backends/torch/torchlib.h @@ -22,6 +22,8 @@ #ifndef TORCHLIB_H #define TORCHLIB_H +#include + #include #include "apidata.h" @@ -32,6 +34,41 @@ namespace dd { + // TODO: Make TorchModule inherit torch::nn::Module ? And use the TORCH_MODULE macro + class TorchModule { + public: + TorchModule(); + + c10::IValue forward(std::vector source); + + void freeze_traced(bool freeze); + + std::vector parameters(); + + /** Save traced module to checkpoint-[name].pt, and custom parts weights + * to checkpoint-[name].ptw */ + // (Actually only _classif is saved in the .ptw) + void save_checkpoint(TorchModel &model, const std::string &name); + + /** Load traced module from .pt and custom parts weights from .ptw */ + void load(TorchModel &model); + + void eval(); + void train(); + + void free(); + public: + std::shared_ptr _traced; + torch::nn::Linear _classif = nullptr; + + torch::Device _device; + int _classif_in = 0; /**