From ab41f43d8a81445cd91289f95beeee6327ef7478 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sun, 4 Feb 2024 21:14:11 -0800
Subject: [PATCH 01/42] fix compile error on mac x86

---
 cpp/CMakeLists.txt              |  2 +-
 cpp/build.sh                    | 55 +++++++++++++++++++++++++++++----
 cpp/src/backends/CMakeLists.txt |  2 +-
 cpp/src/examples/CMakeLists.txt |  6 ++--
 cpp/src/utils/CMakeLists.txt    |  2 +-
 5 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 76ca7b29e5..4a7ad7eec8 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
 project(torchserve_cpp VERSION 0.1)
 
 set(CMAKE_CXX_STANDARD 17)
diff --git a/cpp/build.sh b/cpp/build.sh
index 165cf17cbb..ef4fc8b118 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -28,7 +28,7 @@ function install_folly() {
     echo -e "${COLOR_GREEN}[ INFO ] Cloning folly repo ${COLOR_OFF}"
     git clone https://github.com/facebook/folly.git "$FOLLY_SRC_DIR"
     cd $FOLLY_SRC_DIR
-    git checkout tags/v2022.06.27.00
+    git checkout tags/v2024.01.29.00
   fi
 
   if [ ! -d "$FOLLY_BUILD_DIR" ] ; then
@@ -75,7 +75,18 @@ function install_kineto() {
 
 function install_libtorch() {
   if [ "$PLATFORM" = "Mac" ]; then
-    echo -e "${COLOR_GREEN}[ INFO ] Skip install libtorch on Mac ${COLOR_OFF}"
+    cd "$DEPS_DIR" || exit
+    if [[ $(uname -m) == 'x86_64' ]]; then
+      echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac x86_64 ${COLOR_OFF}"
+      wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
+      unzip libtorch-macos-x86_64-2.2.0.zip
+      rm libtorch-macos-x86_64-2.2.0.zip
+    else
+      echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac arm64 ${COLOR_OFF}"
+      wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.2.0.zip
+      unzip libtorch-macos-arm64-2.2.0.zip
+      rm libtorch-macos-arm64-2.2.0.zip
+    fi
   elif [ ! -d "$DEPS_DIR/libtorch" ] ; then
     cd "$DEPS_DIR" || exit
     if [ "$PLATFORM" = "Linux" ]; then
@@ -113,7 +124,7 @@ function install_yaml_cpp() {
     echo -e "${COLOR_GREEN}[ INFO ] Cloning yaml-cpp repo ${COLOR_OFF}"
     git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR"
     cd $YAML_CPP_SRC_DIR
-    git checkout tags/yaml-cpp-0.7.0
+    git checkout tags/0.8.0
   fi
 
   if [ ! -d "$YAML_CPP_BUILD_DIR" ] ; then
@@ -136,11 +147,42 @@ function install_yaml_cpp() {
   cd "$BWD" || exit
 }
 
+function install_sentencepiece() {
+  SENTENCEPIECE_SRC_DIR=$BASE_DIR/third-party/sentencepiece
+  SENTENCEPIECE_BUILD_DIR=$DEPS_DIR/sentencepiece-build
+
+  if [ ! -d "$SENTENCEPIECE_SRC_DIR" ] ; then
+    echo -e "${COLOR_GREEN}[ INFO ] Cloning sentencepiece repo ${COLOR_OFF}"
+    git clone https://github.com/google/sentencepiece.git "$SENTENCEPIECE_SRC_DIR"
+    cd $SENTENCEPIECE_SRC_DIR
+    git checkout tags/v0.1.99
+  fi
+
+  if [ ! -d "$SENTENCEPIECE_BUILD_DIR" ] ; then
+    echo -e "${COLOR_GREEN}[ INFO ] Building sentencepiece ${COLOR_OFF}"
+
+    mkdir $SENTENCEPIECE_BUILD_DIR
+    cd $SENTENCEPIECE_BUILD_DIR
+    cmake $SENTENCEPIECE_SRC_DIR
+    make -i $(nproc)
+    if [ "$PLATFORM" = "Linux" ]; then
+      sudo make install
+      sudo ldconfig -v
+    elif [ "$PLATFORM" = "Mac" ]; then
+      make install
+    fi
+
+    echo -e "${COLOR_GREEN}[ INFO ] sentencepiece is installed ${COLOR_OFF}"
+  fi
+
+  cd "$BWD" || exit
+}
+
 function build_llama_cpp() {
   BWD=$(pwd)
   LLAMA_CPP_SRC_DIR=$BASE_DIR/third-party/llama.cpp
   cd "${LLAMA_CPP_SRC_DIR}"
-  make
+  make LLAMA_METAL=OFF
   cd "$BWD" || exit
 }
 
@@ -191,7 +233,7 @@ function build() {
     fi
   elif [ "$PLATFORM" = "Mac" ]; then
     cmake \
-      -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)');$DEPS_DIR;$FOLLY_CMAKE_DIR;$YAML_CPP_CMAKE_DIR" \
+      -DCMAKE_PREFIX_PATH="$DEPS_DIR;$FOLLY_CMAKE_DIR;$YAML_CPP_CMAKE_DIR;$DEPS_DIR/libtorch" \
      -DCMAKE_INSTALL_PREFIX="$PREFIX" \
       "$MAYBE_BUILD_QUIC" \
       "$MAYBE_BUILD_TESTS" \
@@ -225,7 +267,7 @@ function build() {
 
 function symlink_torch_libs() {
   if [ "$PLATFORM" = "Linux" ]; then
-    ln -sf ${DEPS_DIR}/libtorch/lib/*.so* ${BUILD_DIR}/libs/
+    ln -sf ${DEPS_DIR}/libtorch/lib/*.so* ${LIBS_DIR}
   fi
 }
 
@@ -315,6 +357,7 @@ install_folly
 install_kineto
 install_libtorch
 install_yaml_cpp
+install_sentencepiece
 build_llama_cpp
 build
 symlink_torch_libs
diff --git a/cpp/src/backends/CMakeLists.txt b/cpp/src/backends/CMakeLists.txt
index 9824d41f62..8f17339ef8 100644
--- a/cpp/src/backends/CMakeLists.txt
+++ b/cpp/src/backends/CMakeLists.txt
@@ -22,7 +22,7 @@ list(APPEND BACKEND_SOURCE_FILES ${TS_BACKENDS_SRC_DIR}/handler/base_handler.cc)
 list(APPEND BACKEND_SOURCE_FILES ${TS_BACKENDS_SRC_DIR}/handler/torch_scripted_handler.cc)
 add_library(ts_backends_core SHARED ${BACKEND_SOURCE_FILES})
 target_include_directories(ts_backends_core PUBLIC ${TS_BACKENDS_CORE_SRC_DIR})
-target_link_libraries(ts_backends_core PUBLIC ts_utils ts_backends_protocol ${FOLLY_LIBRARIES})
+target_link_libraries(ts_backends_core PUBLIC ts_utils ts_backends_protocol ${FOLLY_LIBRARIES} ${TORCH_LIBRARIES})
 install(TARGETS ts_backends_core DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/libs)
 
 # build exe model_worker_socket
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index a313616270..09e9f710e2 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -1,6 +1,6 @@
-add_subdirectory("../../../examples/cpp/babyllama/" "../../../test/resources/examples/babyllama/babyllama_handler/")
+add_subdirectory("../../../examples/cpp/babyllama/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/babyllama/babyllama_handler/")
 
-add_subdirectory("../../../examples/cpp/llamacpp/" "../../../test/resources/examples/llamacpp/llamacpp_handler/")
+add_subdirectory("../../../examples/cpp/llamacpp/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/llamacpp/llamacpp_handler/")
 
-add_subdirectory("../../../examples/cpp/mnist/" "../../../test/resources/examples/mnist/mnist_handler/")
+add_subdirectory("../../../examples/cpp/mnist/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/mnist/mnist_handler/")
diff --git a/cpp/src/utils/CMakeLists.txt b/cpp/src/utils/CMakeLists.txt
index ab7940fab4..666b26a5eb 100644
--- a/cpp/src/utils/CMakeLists.txt
+++ b/cpp/src/utils/CMakeLists.txt
@@ -12,7 +12,7 @@ list(APPEND TS_UTILS_SOURCE_FILES ${TS_UTILS_SRC_DIR}/metrics/registry.cc)
 add_library(ts_utils SHARED ${TS_UTILS_SOURCE_FILES})
 target_include_directories(ts_utils PUBLIC ${TS_UTILS_SRC_DIR})
 target_include_directories(ts_utils PRIVATE ${Boost_INCLUDE_DIRS})
-target_link_libraries(ts_utils ${FOLLY_LIBRARIES} ${CMAKE_DL_LIBS} ${Boost_LIBRARIES} yaml-cpp)
+target_link_libraries(ts_utils ${FOLLY_LIBRARIES} ${CMAKE_DL_LIBS} ${Boost_LIBRARIES} yaml-cpp::yaml-cpp)
 install(TARGETS ts_utils DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/libs)
 
 list(APPEND FOO_SOURCE_FILES ${TS_UTILS_SRC_DIR}/ifoo.hh)

From 9f85b38ce4ea91c06beb5e2b109ca9731db629f3 Mon Sep 17 00:00:00 2001
From: lxning
Date: Mon, 5 Feb 2024 10:48:31 -0800
Subject: [PATCH 02/42] update install libtorch

---
 cpp/build.sh | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/cpp/build.sh b/cpp/build.sh
index ef4fc8b118..e9f4a4d3d2 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -74,22 +74,21 @@ function install_kineto() {
 }
 
 function install_libtorch() {
-  if [ "$PLATFORM" = "Mac" ]; then
+  if [ ! -d "$DEPS_DIR/libtorch" ] ; then
     cd "$DEPS_DIR" || exit
-    if [[ $(uname -m) == 'x86_64' ]]; then
-      echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac x86_64 ${COLOR_OFF}"
-      wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
-      unzip libtorch-macos-x86_64-2.2.0.zip
-      rm libtorch-macos-x86_64-2.2.0.zip
-    else
-      echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac arm64 ${COLOR_OFF}"
-      wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.2.0.zip
-      unzip libtorch-macos-arm64-2.2.0.zip
-      rm libtorch-macos-arm64-2.2.0.zip
-    fi
-  elif [ ! -d "$DEPS_DIR/libtorch" ] ; then
-    cd "$DEPS_DIR" || exit
-    if [ "$PLATFORM" = "Linux" ]; then
+    if [ "$PLATFORM" = "Mac" ]; then
+      if [[ $(uname -m) == 'x86_64' ]]; then
+        echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac x86_64 ${COLOR_OFF}"
+        wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
+        unzip libtorch-macos-x86_64-2.2.0.zip
+        rm libtorch-macos-x86_64-2.2.0.zip
+      else
+        echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Mac arm64 ${COLOR_OFF}"
+        wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.2.0.zip
+        unzip libtorch-macos-arm64-2.2.0.zip
+        rm libtorch-macos-arm64-2.2.0.zip
+      fi
+    elif [ "$PLATFORM" = "Linux" ]; then
       echo -e "${COLOR_GREEN}[ INFO ] Install libtorch on Linux ${COLOR_OFF}"
       if [ "$CUDA" = "cu118" ]; then
         wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip

From b1fadcaedf3121239aa9b252835214da506b0665 Mon Sep 17 00:00:00 2001
From: lxning
Date: Wed, 7 Feb 2024 10:30:31 -0800
Subject: [PATCH 03/42] fmt

---
 cpp/src/examples/CMakeLists.txt               |   2 +
 examples/cpp/aot_inductor/bert/CMakeLists.txt |   3 +
 .../cpp/aot_inductor/bert/src/bert_handler.cc | 221 ++++++++++++++++++
 .../cpp/aot_inductor/bert/src/bert_handler.hh |  57 +++++
 4 files changed, 283 insertions(+)
 create mode 100644 examples/cpp/aot_inductor/bert/CMakeLists.txt
 create mode 100644 examples/cpp/aot_inductor/bert/src/bert_handler.cc
 create mode 100644 examples/cpp/aot_inductor/bert/src/bert_handler.hh

diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index 09e9f710e2..13122976f0 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -4,3 +4,5 @@ add_subdirectory("../../../examples/cpp/babyllama/" "${CMAKE_CURRENT_BINARY_DIR}
 add_subdirectory("../../../examples/cpp/llamacpp/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/llamacpp/llamacpp_handler/")
 
 add_subdirectory("../../../examples/cpp/mnist/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/mnist/mnist_handler/")
+
+add_subdirectory("../../../examples/cpp/aot_inductor/bert" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/bert/")
diff --git a/examples/cpp/aot_inductor/bert/CMakeLists.txt b/examples/cpp/aot_inductor/bert/CMakeLists.txt
new file mode 100644
index 0000000000..b1e9f8147c
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/CMakeLists.txt
@@ -0,0 +1,3 @@
+find_library(SENTENCEPIECE sentencepiece)
+add_library(bert_handler SHARED src/bert_handler.cc)
+target_link_libraries(bert_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES} ${SENTENCEPIECE})
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
new file mode 100644
index 0000000000..844e13a782
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -0,0 +1,221 @@
+#include "bert_handler.hh"
+
+#include
+
+namespace bert {
+std::unique_ptr<folly::dynamic> BertCppHandler::LoadJsonFile(const std::string& file_path) {
+  std::string content;
+  if (!folly::readFile(file_path.c_str(), content)) {
+    TS_LOGF(ERROR, "{}} not found", file_path);
+    throw;
+  }
+  return std::make_unique<folly::dynamic>(folly::parseJson(content));
+}
+
+const folly::dynamic& BertCppHandler::GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key) {
+  if (json->find(key) != json->items().end()) {
+    return (*json)[key];
+  } else {
+    TS_LOG(ERROR, "Required field {} not found in JSON.", key);
+    throw ;
+  }
+}
+
+std::pair<std::shared_ptr<void>, std::shared_ptr<torch::Device>>
+BertCppHandler::LoadModel(
+    std::shared_ptr<torchserve::LoadModelRequest>& load_model_request) {
+  try {
+    auto device = GetTorchDevice(load_model_request);
+
+    const std::string mapFilePath =
+        fmt::format("{}/{}", load_model_request->model_dir, "index_to_name.json");
+    mapping_json_ = LoadJsonFile(mapFilePath);
+
+    const std::string configFilePath =
+        fmt::format("{}/{}", load_model_request->model_dir, "config.json");
+    config_json_ = LoadJsonFile(configFilePath);
+    max_length_ = static_cast<int>(GetJsonValue(config_json_, "max_length").asInt());
+
+    bool lower_case = GetJsonValue(config_json_, "do_lower_case").asBool();
+    std::string tokenizer_path = GetJsonValue(config_json_, "tokenizer_path").asString();
+    auto status = sentence_piece_.Load(tokenizer_path);
+    if (!status.ok()) {
+      throw std::runtime_error(fmt::format(
+        "loading tokenizer: {}, error: {}", tokenizer_path, status.ToString()
+      ));
+    }
+    /*
+    if (lower_case) {
+      sentence_piece_.SetNormalizer(
+        std::make_unique(
+          sentencepiece::SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc_cf")));
+    }
+    */
+
+    std::string model_so_path = folly::parseJson("model_so_path").asString();;
+    c10::InferenceMode mode;
+
+    if (device->is_cuda()) {
+      return std::make_pair(
+        std::make_shared<torch::inductor::AOTIModelContainerRunnerCuda>(model_so_path.c_str(), 1, device->str().c_str()),
+        device);
+    } else {
+      return std::make_pair(
+        std::make_shared<torch::inductor::AOTIModelContainerRunnerCpu>(model_so_path.c_str()),
+        device);
+    }
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
+            load_model_request->model_name, load_model_request->gpu_id,
+            e.msg());
+    throw e;
+  } catch (const std::runtime_error& e) {
+    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
+            load_model_request->model_name, load_model_request->gpu_id,
+            e.what());
+    throw e;
+  }
+}
+
+c10::IValue BertCppHandler::Preprocess(
+    std::shared_ptr<torch::Device> &device,
+    std::pair<std::string &, std::map<uint8_t, std::string> &> &idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceRequestBatch> &request_batch,
+    std::shared_ptr<torchserve::InferenceResponseBatch> &response_batch) {
+  auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get());
+  uint8_t idx = 0;
+  for (auto& request : *request_batch) {
+    try {
+      (*response_batch)[request.request_id] =
+          std::make_shared<torchserve::InferenceResponse>(request.request_id);
+      idx_to_req_id.first += idx_to_req_id.first.empty()
+                                 ? request.request_id
+                                 : "," + request.request_id;
+
+      auto data_it = request.parameters.find(
+          torchserve::PayloadType::kPARAMETER_NAME_DATA);
+      auto dtype_it =
+          request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
+      if (data_it == request.parameters.end()) {
+        data_it = request.parameters.find(
+            torchserve::PayloadType::kPARAMETER_NAME_BODY);
+        dtype_it = request.headers.find(
+            torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
+      }
+
+      if (data_it == request.parameters.end() ||
+          dtype_it == request.headers.end()) {
+        TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id);
+        (*response_batch)[request.request_id]->SetResponse(
+            500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
+            "Empty payload");
+        continue;
+      }
+
+      std::string msg = torchserve::Converter::VectorToStr(data_it->second);
+
+      // tokenization
+      std::vector<int> token_ids;
+      sentence_piece_.Encode(msg, &token_ids);
+      int cur_token_ids_length = (int)token_ids.size();
+
+      if (cur_token_ids_length > max_length_) {
+        TS_LOGF(ERROR, "prompt too long ({} tokens, max {})", cur_token_ids_length, max_length_);
+      } else if (cur_token_ids_length < max_length_) {
+        // padding token ids
+        token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, sentence_piece_.pad_id());
+      }
+      auto options = torch::TensorOptions().dtype(torch::kInt64);
+      batch_ivalue.emplace_back(torch::from_blob(token_ids.data(), max_length_, options));
+      idx_to_req_id.second[idx++] = request.request_id;
+    } catch (const std::runtime_error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              request.request_id, e.what());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to load tensor");
+    } catch (const c10::Error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}",
+              request.request_id, e.msg());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to load tensor");
+    }
+  }
+
+  return batch_ivalue;
+}
+
+c10::IValue BertCppHandler::Inference(
+    std::shared_ptr<void> model, c10::IValue &inputs,
+    std::shared_ptr<torch::Device> &device,
+    std::pair<std::string &, std::map<uint8_t, std::string> &> &idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch> &response_batch) {
+  c10::InferenceMode mode;
+  try {
+    std::shared_ptr<torch::inductor::AOTIModelContainerRunner> runner;
+    if (device->is_cuda()) {
+      runner = std::static_pointer_cast<torch::inductor::AOTIModelContainerRunnerCuda>(model);
+    } else {
+      runner = std::static_pointer_cast<torch::inductor::AOTIModelContainerRunnerCpu>(model);
+    }
+
+    auto batch_output_tensor_vector = runner->run(inputs.toTensorVector());
+    return c10::IValue(batch_output_tensor_vector[0]);
+  } catch (std::runtime_error& e) {
+    TS_LOG(ERROR, e.what());
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "Failed to apply inference on input, c10 error:{}", e.msg());
+  }
+}
+
+void BertCppHandler::Postprocess(
+    c10::IValue &inputs,
+    std::pair<std::string &, std::map<uint8_t, std::string> &> &idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch> &response_batch) {
+  auto& data = inputs.toTensor();
+  for (const auto &kv : idx_to_req_id.second) {
+    try {
+      auto out = data[kv.first].unsqueeze(0);
+      auto y_hat = torch::argmax(out, 1).item<int>();
+      auto predicted_idx = std::to_string(y_hat);
+      auto response = (*response_batch)[kv.second];
+
+      response->SetResponse(200, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            (*mapping_json_)[predicted_idx].asString());
+    } catch (const std::runtime_error &e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              kv.second, e.what());
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to postprocess tensor");
+    } catch (const c10::Error &e) {
+      TS_LOGF(ERROR,
+              "Failed to postprocess tensor for request id: {}, error: {}",
+              kv.second, e.msg());
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to postprocess tensor");
+    }
+  }
+}
+}  // namespace bert
+
+#if defined(__linux__) || defined(__APPLE__)
+extern "C" {
+torchserve::BaseHandler *allocatorBertCppHandler() {
+  return new bert::BertCppHandler();
+}
+
+void deleterBertCppHandler(torchserve::BaseHandler *p) {
+  if (p != nullptr) {
+    delete static_cast<bert::BertCppHandler *>(p);
+  }
+}
+}
+#endif
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh
new file mode 100644
index 0000000000..b3c3f605e1
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh
@@ -0,0 +1,57 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "src/backends/handler/base_handler.hh"
+
+namespace bert {
+class BertCppHandler : public torchserve::BaseHandler {
+ public:
+  // NOLINTBEGIN(bugprone-exception-escape)
+  BertCppHandler() = default;
+  // NOLINTEND(bugprone-exception-escape)
+  ~BertCppHandler() noexcept = default;
+
+  std::pair<std::shared_ptr<void>, std::shared_ptr<torch::Device>> LoadModel(
+      std::shared_ptr<torchserve::LoadModelRequest>& load_model_request)
+      override;
+
+  c10::IValue Preprocess(
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  c10::IValue Inference(
+      std::shared_ptr<void> model, c10::IValue& inputs,
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  void Postprocess(
+      c10::IValue& data,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+private:
+  std::unique_ptr<folly::dynamic> LoadJsonFile(const std::string& file_path);
+  const folly::dynamic& GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key);
+
+  std::unique_ptr<folly::dynamic> config_json_;
+  std::unique_ptr<folly::dynamic> mapping_json_;
+  sentencepiece::SentencePieceProcessor sentence_piece_;
+  int max_length_;
+};
+}  // namespace bert

From 0addfffa451b3243f8f36324d258d7e9daf064c6 Mon Sep 17 00:00:00 2001
From: lxning
Date: Wed, 7 Feb 2024 17:24:48 -0800
Subject: [PATCH 04/42] fmt

---
 .../aot_inductor/bert/aot_compile_export.py | 111 ++++++++++++++++++
 .../cpp/aot_inductor/bert/setup_config.json |  12 ++
 2 files changed, 123 insertions(+)
 create mode 100644 examples/cpp/aot_inductor/bert/aot_compile_export.py
 create mode 100644 examples/cpp/aot_inductor/bert/setup_config.json

diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
new file mode 100644
index 0000000000..bc57c1f7cb
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -0,0 +1,111 @@
+import json
+import os
+import sys
+
+import torch
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    set_seed,
+)
+
+set_seed(1)
+
+
+def transformers_model_dowloader(
+    mode,
+    pretrained_model_name,
+    num_labels,
+    do_lower_case,
+    max_length,
+    batch_size,
+):
+    print("Download model and tokenizer", pretrained_model_name)
+    # loading pre-trained model and tokenizer
+    if mode == "sequence_classification":
+        config = AutoConfig.from_pretrained(
+            pretrained_model_name, num_labels=num_labels, torchscript=torchscript
+        )
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_name, config=config
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name, do_lower_case=do_lower_case
+        )
+    else:
+        sys.exit(f"mode={mode} has not been implemented in this cpp example yet.")
+
+    NEW_DIR = "./Transformer_model"
+    try:
+        os.mkdir(NEW_DIR)
+    except OSError:
+        print("Creation of directory %s failed" % NEW_DIR)
+    else:
+        print("Successfully created directory %s " % NEW_DIR)
+
+    print(
+        "Save model and tokenizer model based on the setting from setup_config",
+        pretrained_model_name,
+        "in directory",
+        NEW_DIR,
+    )
+
+    model.save_pretrained(NEW_DIR)
+    tokenizer.save_pretrained(NEW_DIR)
+
+    with torch.no_grad():
+        model.eval()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        model = model.to(device=device)
+        dummy_input = "This is a dummy input for torch jit trace"
+        inputs = tokenizer.encode_plus(
+            dummy_input,
+            max_length=int(max_length),
+            pad_to_max_length=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        input_ids = torch.cat([inputs["input_ids"]] * batch_size, 0).to(device)
+        attention_mask = torch.cat([inputs["attention_mask"]] * batch_size, 0).to(
+            device
+        )
+        batch_dim = torch.export.Dim("batch", min=1, max=batch_size)
+        torch._C._GLIBCXX_USE_CXX11_ABI = True
+        model_so_path = torch._export.aot_compile(
+            model,
+            (inputs, attention_mask),
+            dynamic_shapes={"x": {0: batch_dim}},
+            options={
+                "aot_inductor.output_path": os.path.join(os.getcwd(), "bert-seq.so"),
+                "max_autotune": True,
+            },
+        )
+
+    return
+
+
+if __name__ == "__main__":
+    dirname = os.path.dirname(__file__)
+    if len(sys.argv) > 1:
+        filename = os.path.join(dirname, sys.argv[1])
+    else:
+        filename = os.path.join(dirname, "setup_config.json")
+    f = open(filename)
+    settings = json.load(f)
+    mode = settings["mode"]
+    model_name = settings["model_name"]
+    num_labels = int(settings["num_labels"])
+    do_lower_case = settings["do_lower_case"]
+    max_length = settings["max_length"]
+    batch_size = int(settings.get("batch_size", "1"))
+
+    transformers_model_dowloader(
+        mode,
+        model_name,
+        num_labels,
+        do_lower_case,
+        max_length,
+        batch_size,
+    )
diff --git a/examples/cpp/aot_inductor/bert/setup_config.json b/examples/cpp/aot_inductor/bert/setup_config.json
new file mode 100644
index 0000000000..fe09860dce
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/setup_config.json
@@ -0,0 +1,12 @@
+{
+ "model_name":"bert-base-uncased",
+ "mode":"sequence_classification",
+ "do_lower_case":true,
+ "num_labels":"2",
+ "max_length":"150",
+ "captum_explanation":false,
+ "embedding_name": "bert",
+ "FasterTransformer":false,
+ "BetterTransformer":false,
+ "model_parallel":false
+}

From e64ef7d608da6dd65b69ff9ded43cc392887b288 Mon Sep 17 00:00:00 2001
From: lxning
Date: Thu, 8 Feb 2024 00:18:36 -0800
Subject: [PATCH 05/42] fmt

---
 examples/cpp/aot_inductor/bert/aot_compile_export.py | 11 ++++++-----
 examples/cpp/aot_inductor/bert/setup_config.json     |  3 ++-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
index bc57c1f7cb..73baee28ae 100644
--- a/examples/cpp/aot_inductor/bert/aot_compile_export.py
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -25,7 +25,7 @@ def transformers_model_dowloader(
     # loading pre-trained model and tokenizer
     if mode == "sequence_classification":
         config = AutoConfig.from_pretrained(
-            pretrained_model_name, num_labels=num_labels, torchscript=torchscript
+            pretrained_model_name, num_labels=num_labels, torchscript=False
         )
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_name, config=config
@@ -63,7 +63,7 @@ def transformers_model_dowloader(
         inputs = tokenizer.encode_plus(
             dummy_input,
             max_length=int(max_length),
-            pad_to_max_length=True,
+            padding=True,
             add_special_tokens=True,
             return_tensors="pt",
         )
@@ -71,12 +71,13 @@ def transformers_model_dowloader(
         attention_mask = torch.cat([inputs["attention_mask"]] * batch_size, 0).to(
             device
         )
-        batch_dim = torch.export.Dim("batch", min=1, max=batch_size)
+        batch_dim = torch.export.Dim("batch", min=2, max=8)
         torch._C._GLIBCXX_USE_CXX11_ABI = True
         model_so_path = torch._export.aot_compile(
             model,
-            (inputs, attention_mask),
-            dynamic_shapes={"x": {0: batch_dim}},
+            (input_ids, attention_mask),
+            # dynamic_shapes={"input_ids": {0, batch_dim}, "attention_mask": {0, batch_dim}},
+            constraints=[batch_dim],
             options={
                 "aot_inductor.output_path": os.path.join(os.getcwd(), "bert-seq.so"),
                 "max_autotune": True,
diff --git a/examples/cpp/aot_inductor/bert/setup_config.json b/examples/cpp/aot_inductor/bert/setup_config.json
index fe09860dce..cf8078a167 100644
--- a/examples/cpp/aot_inductor/bert/setup_config.json
+++ b/examples/cpp/aot_inductor/bert/setup_config.json
@@ -8,5 +8,6 @@
  "embedding_name": "bert",
  "FasterTransformer":false,
  "BetterTransformer":false,
- "model_parallel":false
+ "model_parallel":false,
+ "batch_size": "4"
 }

From ce0e65a910780356af9eb58e573f07f8c6c62928 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Fri, 9 Feb 2024 19:53:19 +0000
Subject: [PATCH 06/42] Set return type of bert model and dynamic shapes

---
 .../aot_inductor/bert/aot_compile_export.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
index 73baee28ae..47ababf092 100644
--- a/examples/cpp/aot_inductor/bert/aot_compile_export.py
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -25,7 +25,10 @@ def transformers_model_dowloader(
     # loading pre-trained model and tokenizer
     if mode == "sequence_classification":
         config = AutoConfig.from_pretrained(
-            pretrained_model_name, num_labels=num_labels, torchscript=False
+            pretrained_model_name,
+            num_labels=num_labels,
+            torchscript=False,
+            return_dict=False,
         )
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_name, config=config
@@ -62,7 +65,7 @@ def transformers_model_dowloader(
         dummy_input = "This is a dummy input for torch jit trace"
         inputs = tokenizer.encode_plus(
             dummy_input,
-            max_length=int(max_length),
+            max_length=max_length,
             padding=True,
             add_special_tokens=True,
             return_tensors="pt",
@@ -71,13 +74,16 @@ def transformers_model_dowloader(
         attention_mask = torch.cat([inputs["attention_mask"]] * batch_size, 0).to(
            device
         )
-        batch_dim = torch.export.Dim("batch", min=2, max=8)
+        batch_dim = torch.export.Dim("batch", min=1, max=8)
+        seq_len_dim = torch.export.Dim("seq_len", min=1, max=max_length)
         torch._C._GLIBCXX_USE_CXX11_ABI = True
         model_so_path = torch._export.aot_compile(
             model,
             (input_ids, attention_mask),
-            # dynamic_shapes={"input_ids": {0, batch_dim}, "attention_mask": {0, batch_dim}},
-            constraints=[batch_dim],
+            dynamic_shapes={
+                "input_ids": (batch_dim, seq_len_dim),
+                "attention_mask": (batch_dim, seq_len_dim),
+            },
             options={
                 "aot_inductor.output_path": os.path.join(os.getcwd(), "bert-seq.so"),
                 "max_autotune": True,
@@ -99,7 +105,7 @@ if __name__ == "__main__":
     model_name = settings["model_name"]
     num_labels = int(settings["num_labels"])
     do_lower_case = settings["do_lower_case"]
-    max_length = settings["max_length"]
+    max_length = int(settings["max_length"])
     batch_size = int(settings.get("batch_size", "1"))
 
     transformers_model_dowloader(

From e827f2f69c95929187cd8d7628b6e244cf77d9ca Mon Sep 17 00:00:00 2001
From: lxning
Date: Fri, 9 Feb 2024 13:23:43 -0800
Subject: [PATCH 07/42] fix json value

---
 examples/cpp/aot_inductor/bert/src/bert_handler.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
index 844e13a782..1fb80b35b3 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -52,7 +52,7 @@ BertCppHandler::LoadModel(
     }
     */
 
-    std::string model_so_path = folly::parseJson("model_so_path").asString();;
+    std::string model_so_path = GetJsonValue(config_json_, "model_so_path").asString();;
     c10::InferenceMode mode;
 
     if (device->is_cuda()) {
@@ -185,7 +185,7 @@ void BertCppHandler::Postprocess(
 
       response->SetResponse(200, "data_type",
                             torchserve::PayloadType::kDATA_TYPE_STRING,
-                            (*mapping_json_)[predicted_idx].asString());
+                            GetJsonValue(mapping_json_, predicted_idx).asString());
     } catch (const std::runtime_error &e) {
       TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
               kv.second, e.what());

From d0315a61385f51da275b4e0a83645343422d3054 Mon Sep 17 00:00:00 2001
From: lxning
Date: Fri, 9 Feb 2024 13:30:36 -0800
Subject: [PATCH 08/42] fix build on linux

---
 cpp/CMakeLists.txt | 2 +-
 cpp/build.sh       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 4a7ad7eec8..5176245a14 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
+cmake_minimum_required(VERSION 3.16 FATAL_ERROR)
 project(torchserve_cpp VERSION 0.1)
 
 set(CMAKE_CXX_STANDARD 17)
diff --git a/cpp/build.sh b/cpp/build.sh
index 6f6dbf81e9..b6a368ddaa 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -171,7 +171,7 @@ function install_sentencepiece() {
 
     mkdir $SENTENCEPIECE_BUILD_DIR
     cd $SENTENCEPIECE_BUILD_DIR
-    cmake $SENTENCEPIECE_SRC_DIR
+    cmake -DSPM_ENABLE_TCMALLOC=OFF $SENTENCEPIECE_SRC_DIR
     make -i $(nproc)
     if [ "$PLATFORM" = "Linux" ]; then
       sudo make install

From 2692dd2f9e18ed8d7d371b35cc4bf9459ea26715 Mon Sep 17 00:00:00 2001
From: lxning
Date: Fri, 9 Feb 2024 21:12:35 -0800
Subject: [PATCH 09/42] add linux dependency

---
 examples/cpp/aot_inductor/bert/setup_config.json   | 4 +++-
 examples/cpp/aot_inductor/bert/src/bert_handler.cc | 6 +++---
 ts_scripts/install_dependencies.py                 | 2 ++
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/setup_config.json b/examples/cpp/aot_inductor/bert/setup_config.json
index cf8078a167..e7acb6c955 100644
--- a/examples/cpp/aot_inductor/bert/setup_config.json
+++ b/examples/cpp/aot_inductor/bert/setup_config.json
@@ -9,5 +9,7 @@
  "FasterTransformer":false,
  "BetterTransformer":false,
- "model_parallel":false,
- "batch_size": "4"
+ "model_parallel":false,
+ "batch_size": "4",
+ "tokenizer_path": "Transformer_model/model.safetensors",
+ "model_so_path": "bert-seq.so"
 }
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
index 1fb80b35b3..749b892e17 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -37,8 +37,8 @@ BertCppHandler::LoadModel(
     max_length_ = static_cast<int>(GetJsonValue(config_json_, "max_length").asInt());
 
     bool lower_case = GetJsonValue(config_json_, "do_lower_case").asBool();
-    std::string tokenizer_path = GetJsonValue(config_json_, "tokenizer_path").asString();
-    auto status = sentence_piece_.Load(tokenizer_path);
+    std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString());
+    auto status = sentence_piece_.LoadFromSerializedProto(tokenizer_path);
     if (!status.ok()) {
       throw std::runtime_error(fmt::format(
         "loading tokenizer: {}, error: {}", tokenizer_path, status.ToString()
@@ -52,7 +52,7 @@ BertCppHandler::LoadModel(
     }
     */
 
-    std::string model_so_path = GetJsonValue(config_json_, "model_so_path").asString();;
+    std::string model_so_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "model_so_path").asString());
     c10::InferenceMode mode;
 
     if (device->is_cuda()) {
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py
index 7b31664474..4bc6a367be 100644
--- a/ts_scripts/install_dependencies.py
+++ b/ts_scripts/install_dependencies.py
@@ -45,6 +45,8 @@
     "ninja-build",
     "clang-tidy",
     "clang-format",
+    "build-essential",
+    "libgoogle-perftools-dev",
 )
 
 CPP_DARWIN_DEPENDENCIES = (

From 92a62386e7b33fd8b74c8aa1f73edb8388453c12 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sat, 10 Feb 2024 09:19:13 -0800
Subject: [PATCH 10/42] replace sentenepice with tokenizers-cpp

---
 cpp/build.sh                                       | 32 ++++-----------
 cpp/src/utils/CMakeLists.txt                       |  7 +++-
 examples/cpp/aot_inductor/bert/CMakeLists.txt      |  5 +--
 .../cpp/aot_inductor/bert/src/bert_handler.cc      | 19 +++------
 .../cpp/aot_inductor/bert/src/bert_handler.hh      |  5 ++-
 ts_scripts/install_dependencies.py                 |  1 +
 6 files changed, 23 insertions(+), 46 deletions(-)

diff --git a/cpp/build.sh b/cpp/build.sh
index b6a368ddaa..9ac9c0f2b7 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -155,32 +155,14 @@ function install_yaml_cpp() {
   cd "$BWD" || exit
 }
 
-function install_sentencepiece() {
-  SENTENCEPIECE_SRC_DIR=$BASE_DIR/third-party/sentencepiece
-  SENTENCEPIECE_BUILD_DIR=$DEPS_DIR/sentencepiece-build
+function install_tokenizer_cpp() {
+  TOKENIZERS_CPP_SRC_DIR=$BASE_DIR/third-party/tokenizers-cpp
 
-  if [ ! -d "$SENTENCEPIECE_SRC_DIR" ] ; then
+  if [ ! -d "$TOKENIZERS_CPP_SRC_DIR" ] ; then
     echo -e "${COLOR_GREEN}[ INFO ] Cloning sentencepiece repo ${COLOR_OFF}"
-    git clone https://github.com/google/sentencepiece.git "$SENTENCEPIECE_SRC_DIR"
-    cd $SENTENCEPIECE_SRC_DIR
-    git checkout tags/v0.1.99
-  fi
-
-  if [ ! -d "$SENTENCEPIECE_BUILD_DIR" ] ; then
-    echo -e "${COLOR_GREEN}[ INFO ] Building sentencepiece ${COLOR_OFF}"
-
-    mkdir $SENTENCEPIECE_BUILD_DIR
-    cd $SENTENCEPIECE_BUILD_DIR
-    cmake -DSPM_ENABLE_TCMALLOC=OFF $SENTENCEPIECE_SRC_DIR
-    make -i $(nproc)
-    if [ "$PLATFORM" = "Linux" ]; then
-      sudo make install
-      sudo ldconfig -v
-    elif [ "$PLATFORM" = "Mac" ]; then
-      make install
-    fi
-
-    echo -e "${COLOR_GREEN}[ INFO ] sentencepiece is installed ${COLOR_OFF}"
+    git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR"
+    cd $TOKENIZERS_CPP_SRC_DIR
+    git checkout tags/v0.1.0
   fi
 
   cd "$BWD" || exit
@@ -401,7 +383,7 @@ install_folly
 install_kineto
 install_libtorch
 install_yaml_cpp
-install_sentencepiece
+install_tokenizer_cpp
 build_llama_cpp
 prepare_test_files
 build
diff --git a/cpp/src/utils/CMakeLists.txt b/cpp/src/utils/CMakeLists.txt
index 666b26a5eb..89c18772cd 100644
--- a/cpp/src/utils/CMakeLists.txt
+++ b/cpp/src/utils/CMakeLists.txt
@@ -12,7 +12,12 @@ list(APPEND TS_UTILS_SOURCE_FILES ${TS_UTILS_SRC_DIR}/metrics/registry.cc)
 add_library(ts_utils SHARED ${TS_UTILS_SOURCE_FILES})
 target_include_directories(ts_utils PUBLIC ${TS_UTILS_SRC_DIR})
 target_include_directories(ts_utils PRIVATE ${Boost_INCLUDE_DIRS})
-target_link_libraries(ts_utils ${FOLLY_LIBRARIES} ${CMAKE_DL_LIBS} ${Boost_LIBRARIES} yaml-cpp::yaml-cpp)
+if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
+  target_link_libraries(ts_utils ${FOLLY_LIBRARIES} ${CMAKE_DL_LIBS} ${Boost_LIBRARIES} yaml-cpp::yaml-cpp)
+else()
+  target_link_libraries(ts_utils ${FOLLY_LIBRARIES} ${CMAKE_DL_LIBS} ${Boost_LIBRARIES} yaml-cpp)
+endif()
+
 install(TARGETS ts_utils DESTINATION ${torchserve_cpp_SOURCE_DIR}/_build/libs)
 
 list(APPEND FOO_SOURCE_FILES ${TS_UTILS_SRC_DIR}/ifoo.hh)
diff --git a/examples/cpp/aot_inductor/bert/CMakeLists.txt b/examples/cpp/aot_inductor/bert/CMakeLists.txt
index b1e9f8147c..e006fcceb1 100644
--- a/examples/cpp/aot_inductor/bert/CMakeLists.txt
+++ b/examples/cpp/aot_inductor/bert/CMakeLists.txt
@@ -1,3 +1,4 @@
-find_library(SENTENCEPIECE sentencepiece)
+set(TOKENZIER_CPP_PATH $CMAKE_INSTALL_PREFIX/../third-party/tokenizers-cpp)
 add_library(bert_handler SHARED src/bert_handler.cc)
-target_link_libraries(bert_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES} ${SENTENCEPIECE})
+target_include_directories(bert_handler PRIVATE ${TOKENZIER_CPP_PATH}/include)
+target_link_libraries(bert_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES} tokenizers_cpp)
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
index 749b892e17..bb95711b61 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -38,19 +38,8 @@ BertCppHandler::LoadModel(
 
     bool lower_case = GetJsonValue(config_json_, "do_lower_case").asBool();
     std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString());
-    auto status = sentence_piece_.LoadFromSerializedProto(tokenizer_path);
-    if (!status.ok()) {
-      throw std::runtime_error(fmt::format(
-        "loading tokenizer: {}, error: {}", tokenizer_path, status.ToString()
-      ));
-    }
-    /*
-    if (lower_case) {
-      sentence_piece_.SetNormalizer(
-        std::make_unique(
-          sentencepiece::SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc_cf")));
-    }
-    */
+    auto tokenizer_blob = LoadJsonFile(tokenizer_path)->asString();
+    tokenizer_ = tokenizers::Tokenizer.FromBlobJSON(tokenizer_blob);
 
     std::string model_so_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "model_so_path").asString());
     c10::InferenceMode mode;
@@ -116,14 +105,14 @@ c10::IValue BertCppHandler::Preprocess(
 
       // tokenization
-      std::vector<int> token_ids;
-      sentence_piece_.Encode(msg, &token_ids);
+      std::vector<int> token_ids = tokenizer_.Encode(msg);;
       int cur_token_ids_length = (int)token_ids.size();
 
       if (cur_token_ids_length > max_length_) {
         TS_LOGF(ERROR, "prompt too long ({} tokens, max {})", cur_token_ids_length, max_length_);
       } else if (cur_token_ids_length < max_length_) {
         // padding token ids
-        token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, sentence_piece_.pad_id());
+        token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_.TokenToId("<pad>"));
       }
       auto options = torch::TensorOptions().dtype(torch::kInt64);
       batch_ivalue.emplace_back(torch::from_blob(token_ids.data(), max_length_, options));
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh
index b3c3f605e1..91b6352c89 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.hh
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh
@@ -5,8 +5,7 @@
 #include
 #include
 #include
-#include
-#include
+#include
 #include
 #include
 #include
@@ -51,7 +50,7 @@ private:
 
   std::unique_ptr<folly::dynamic> config_json_;
   std::unique_ptr<folly::dynamic> mapping_json_;
-  sentencepiece::SentencePieceProcessor sentence_piece_;
+  tokenizers::Tokenizer tokenizer_;
   int max_length_;
 };
 }  // namespace bert
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py
index 4bc6a367be..cd0488100c 100644
--- a/ts_scripts/install_dependencies.py
+++ b/ts_scripts/install_dependencies.py
@@ -47,6 +47,7 @@
     "clang-format",
     "build-essential",
     "libgoogle-perftools-dev",
+    "rustc",
 )
 
 CPP_DARWIN_DEPENDENCIES = (

From 94d4309568853557c0da489d64aadb3595d2e058 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sun, 11 Feb 2024 00:36:28 -0800
Subject: [PATCH 11/42] update dependency

---
 cpp/README.md                      | 1 +
 cpp/build.sh                       | 2 +-
 ts_scripts/install_dependencies.py | 1 +
 3 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/cpp/README.md b/cpp/README.md
index 70b96339b9..dae97087eb 100644
--- a/cpp/README.md
+++ b/cpp/README.md
@@ -2,6 +2,7 @@
 ## Requirements
 * C++17
 * GCC version: gcc-9
+* cmake version: 3.18+
 
 ## Installation and Running TorchServe CPP
 ### Install dependencies
diff --git a/cpp/build.sh b/cpp/build.sh
index 9ac9c0f2b7..aa28f3f5c5 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -159,7 +159,7 @@ function install_tokenizer_cpp() {
   TOKENIZERS_CPP_SRC_DIR=$BASE_DIR/third-party/tokenizers-cpp
 
   if [ ! -d "$TOKENIZERS_CPP_SRC_DIR" ] ; then
-d "$TOKENIZERS_CPP_SRC_DIR" ] ; then - echo -e "${COLOR_GREEN}[ INFO ] Cloning sentencepiece repo ${COLOR_OFF}" + echo -e "${COLOR_GREEN}[ INFO ] Cloning tokenizers-cpp repo ${COLOR_OFF}" git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR" cd $TOKENIZERS_CPP_SRC_DIR git checkout tags/v0.1.0 diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index cd0488100c..1b3d1c0b35 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -48,6 +48,7 @@ "build-essential", "libgoogle-perftools-dev", "rustc", + "cargo", ) CPP_DARWIN_DEPENDENCIES = ( From fd5e1453af953c290e2f347c512b42da4a2fd1e1 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 11 Feb 2024 20:43:20 -0800 Subject: [PATCH 12/42] add attention mask --- examples/cpp/aot_inductor/bert/CMakeLists.txt | 3 ++- examples/cpp/aot_inductor/bert/src/bert_handler.cc | 10 ++++++---- examples/cpp/aot_inductor/bert/src/bert_handler.hh | 7 +------ 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/cpp/aot_inductor/bert/CMakeLists.txt b/examples/cpp/aot_inductor/bert/CMakeLists.txt index e006fcceb1..a4f48301fc 100644 --- a/examples/cpp/aot_inductor/bert/CMakeLists.txt +++ b/examples/cpp/aot_inductor/bert/CMakeLists.txt @@ -1,4 +1,5 @@ -set(TOKENZIER_CPP_PATH $CMAKE_INSTALL_PREFIX/../third-party/tokenizers-cpp) +set(TOKENZIER_CPP_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../../cpp/third-party/tokenizers-cpp) +add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL) add_library(bert_handler SHARED src/bert_handler.cc) target_include_directories(bert_handler PRIVATE ${TOKENZIER_CPP_PATH}/include) target_link_libraries(bert_handler PRIVATE ts_backends_core ts_utils ${TORCH_LIBRARIES} tokenizers_cpp) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index bb95711b61..2fbc300ff2 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -36,7 +36,6 @@ BertCppHandler::LoadModel( config_json_ = LoadJsonFile(configFilePath); max_length_ = static_cast(GetJsonValue(config_json_, "max_length").asInt()); - bool lower_case = GetJsonValue(config_json_, "do_lower_case").asBool(); std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString()); auto tokenizer_blob = LoadJsonFile(tokenizer_path)->asString(); tokenizer_ = tokenizers::Tokenizer.FromBlobJSON(tokenizer_blob); @@ -72,6 +71,8 @@ c10::IValue BertCppHandler::Preprocess( std::shared_ptr &request_batch, std::shared_ptr &response_batch) { auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); + auto tokens_ivalue = c10::impl::GenericList(torch::TensorType::get()); + auto attention_mask = torch::ones({request_batch->size(), max_length_}, torch::kI32); uint8_t idx = 0; for (auto& request : *request_batch) { try { @@ -104,8 +105,7 @@ c10::IValue BertCppHandler::Preprocess( std::string msg = torchserve::Converter::VectorToStr(data_it->second); // tokenization - std::vector token_ids; - tokenizer_.Encode(msg, &token_ids); + std::vector token_ids = tokenizer_.Encode(msg);; int cur_token_ids_length = (int)token_ids.size(); if (cur_token_ids_length > max_length_) { @@ -115,7 +115,7 @@ c10::IValue BertCppHandler::Preprocess( token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_.TokenToId("")); } auto options = torch::TensorOptions().dtype(torch::kInt64); - 
batch_ivalue.emplace_back(torch::from_blob(token_ids.data(), max_length_, options)); + tokens_ivalue.emplace_back(torch::from_blob(token_ids.data(), max_length_, options)); idx_to_req_id.second[idx++] = request.request_id; } catch (const std::runtime_error& e) { TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", @@ -133,6 +133,8 @@ c10::IValue BertCppHandler::Preprocess( "c10 error, failed to load tensor"); } } + batch_ivalue.emplace_back(torch::from_blob(tokens_ivalue.toTensorVec(), {request_batch->size(), max_length_})); + batch_ivalue.emplace_back(attention_mask); return batch_ivalue; } diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh index 91b6352c89..d1cdda2f09 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.hh +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh @@ -15,11 +15,6 @@ namespace bert { class BertCppHandler : public torchserve::BaseHandler { public: - // NOLINTBEGIN(bugprone-exception-escape) - BertCppHandler() = default; - // NOLINTEND(bugprone-exception-escape) - ~BertCppHandler() noexcept = default; - std::pair, std::shared_ptr> LoadModel( std::shared_ptr& load_model_request) override; @@ -50,7 +45,7 @@ private: std::unique_ptr config_json_; std::unique_ptr mapping_json_; - tokenizers::Tokenizer tokenizer_; + std::unique_ptr tokenizer_; int max_length_; }; } // namespace bert From 203f19eac82d26569c0c9ffec6e5e02e13e97470 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 11 Feb 2024 21:13:46 -0800 Subject: [PATCH 13/42] fix compile error --- cpp/CMakeLists.txt | 2 +- cpp/build.sh | 1 - examples/cpp/aot_inductor/bert/src/bert_handler.cc | 8 ++++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5176245a14..e81a272894 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.16 FATAL_ERROR) +cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) project(torchserve_cpp VERSION 0.1) set(CMAKE_CXX_STANDARD 17) diff --git a/cpp/build.sh b/cpp/build.sh index aa28f3f5c5..32e00b2f40 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -162,7 +162,6 @@ function install_tokenizer_cpp() { echo -e "${COLOR_GREEN}[ INFO ] Cloning tokenizers-cpp repo ${COLOR_OFF}" git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR" cd $TOKENIZERS_CPP_SRC_DIR - git checkout tags/v0.1.0 fi cd "$BWD" || exit diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 2fbc300ff2..9c2d8e51e5 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -72,7 +72,7 @@ c10::IValue BertCppHandler::Preprocess( std::shared_ptr &response_batch) { auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); auto tokens_ivalue = c10::impl::GenericList(torch::TensorType::get()); - auto attention_mask = torch::ones({request_batch->size(), max_length_}, torch::kI32); + auto attention_mask = torch::ones({static_cast(request_batch->size()), max_length_}, torch::kI32); uint8_t idx = 0; for (auto& request : *request_batch) { try { @@ -105,14 +105,14 @@ c10::IValue BertCppHandler::Preprocess( std::string msg = torchserve::Converter::VectorToStr(data_it->second); // tokenization - std::vector token_ids = tokenizer_.Encode(msg);; + std::vector token_ids = tokenizer_->Encode(msg);; int cur_token_ids_length = (int)token_ids.size(); if (cur_token_ids_length > 
max_length_) { TS_LOGF(ERROR, "prompt too long ({} tokens, max {})", cur_token_ids_length, max_length_); } else if (cur_token_ids_length < max_length_) { // padding token ids - token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_.TokenToId("")); + token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_->TokenToId("")); } auto options = torch::TensorOptions().dtype(torch::kInt64); tokens_ivalue.emplace_back(torch::from_blob(token_ids.data(), max_length_, options)); @@ -133,7 +133,7 @@ c10::IValue BertCppHandler::Preprocess( "c10 error, failed to load tensor"); } } - batch_ivalue.emplace_back(torch::from_blob(tokens_ivalue.toTensorVec(), {request_batch->size(), max_length_})); + batch_ivalue.emplace_back(torch::from_blob(tokens_ivalue.toTensorVector(), {static_cast(request_batch->size()), max_length_})); batch_ivalue.emplace_back(attention_mask); return batch_ivalue; From 0d8f50584201e1cff0923141f4e5f5d609bccb84 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 11 Feb 2024 21:39:59 -0800 Subject: [PATCH 14/42] fix compile error --- examples/cpp/aot_inductor/bert/src/bert_handler.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 9c2d8e51e5..86930dfe5c 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -38,7 +38,7 @@ BertCppHandler::LoadModel( std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString()); auto tokenizer_blob = LoadJsonFile(tokenizer_path)->asString(); - tokenizer_ = tokenizers::Tokenizer.FromBlobJSON(tokenizer_blob); + tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(tokenizer_blob); std::string model_so_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "model_so_path").asString()); c10::InferenceMode mode; @@ -71,7 +71,7 @@ c10::IValue BertCppHandler::Preprocess( std::shared_ptr &request_batch, std::shared_ptr &response_batch) { auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); - auto tokens_ivalue = c10::impl::GenericList(torch::TensorType::get()); + std::vector tokens_ivalue; auto attention_mask = torch::ones({static_cast(request_batch->size()), max_length_}, torch::kI32); uint8_t idx = 0; for (auto& request : *request_batch) { @@ -133,7 +133,7 @@ c10::IValue BertCppHandler::Preprocess( "c10 error, failed to load tensor"); } } - batch_ivalue.emplace_back(torch::from_blob(tokens_ivalue.toTensorVector(), {static_cast(request_batch->size()), max_length_})); + batch_ivalue.emplace_back(torch::from_blob(tokens_ivalue, {static_cast(request_batch->size()), max_length_})); batch_ivalue.emplace_back(attention_mask); return batch_ivalue; From 45ae6b29dd2aa8ee1288c91e8d4cb2555d3c3598 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 12 Feb 2024 11:56:12 -0800 Subject: [PATCH 15/42] fmt --- examples/cpp/aot_inductor/bert/src/bert_handler.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 86930dfe5c..58c1aeed00 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -70,8 +70,8 @@ c10::IValue BertCppHandler::Preprocess( std::pair &> &idx_to_req_id, std::shared_ptr &request_batch, std::shared_ptr 
&response_batch) { - auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); - std::vector tokens_ivalue; + auto options = torch::TensorOptions().dtype(torch::kInt64); + std::vector batch_tokens; auto attention_mask = torch::ones({static_cast(request_batch->size()), max_length_}, torch::kI32); uint8_t idx = 0; for (auto& request : *request_batch) { @@ -114,8 +114,7 @@ c10::IValue BertCppHandler::Preprocess( // padding token ids token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_->TokenToId("")); } - auto options = torch::TensorOptions().dtype(torch::kInt64); - tokens_ivalue.emplace_back(torch::from_blob(token_ids.data(), max_length_, options)); + batch_tokens.emplace_back(torch::from_blob(token_ids.data(), max_length_, options)); idx_to_req_id.second[idx++] = request.request_id; } catch (const std::runtime_error& e) { TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", @@ -133,7 +132,8 @@ c10::IValue BertCppHandler::Preprocess( "c10 error, failed to load tensor"); } } - batch_ivalue.emplace_back(torch::from_blob(tokens_ivalue, {static_cast(request_batch->size()), max_length_})); + auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); + batch_ivalue.emplace_back(torch::from_blob(batch_tokens, {static_cast(request_batch->size()), max_length_}, options)); batch_ivalue.emplace_back(attention_mask); return batch_ivalue; From 558df11e045986ce166d39398d6bdeb51daa711c Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 12 Feb 2024 15:06:41 -0800 Subject: [PATCH 16/42] Fmt --- cpp/build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/build.sh b/cpp/build.sh index 32e00b2f40..bda4383caf 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -133,6 +133,7 @@ function install_yaml_cpp() { git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR" cd $YAML_CPP_SRC_DIR git checkout tags/0.8.0 + git submodule update --init --recursive fi if [ ! -d "$YAML_CPP_BUILD_DIR" ] ; then From 9c2cdf368fb49a7298f47b2a43dfec0e8222ecbf Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 12 Feb 2024 15:18:16 -0800 Subject: [PATCH 17/42] tockenizer-cpp git submodule --- cpp/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/build.sh b/cpp/build.sh index bda4383caf..73338df2ab 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -133,7 +133,6 @@ function install_yaml_cpp() { git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR" cd $YAML_CPP_SRC_DIR git checkout tags/0.8.0 - git submodule update --init --recursive fi if [ ! 
-d "$YAML_CPP_BUILD_DIR" ] ; then @@ -163,6 +162,7 @@ function install_tokenizer_cpp() { echo -e "${COLOR_GREEN}[ INFO ] Cloning tokenizers-cpp repo ${COLOR_OFF}" git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR" cd $TOKENIZERS_CPP_SRC_DIR + git submodule update --init --recursive fi cd "$BWD" || exit From a7a551f135e3c937df51da22e542185a0df3cd88 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 12 Feb 2024 15:49:32 -0800 Subject: [PATCH 18/42] update handler --- examples/cpp/aot_inductor/bert/src/bert_handler.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 58c1aeed00..e2b4e95bd4 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -133,7 +133,7 @@ c10::IValue BertCppHandler::Preprocess( } } auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); - batch_ivalue.emplace_back(torch::from_blob(batch_tokens, {static_cast(request_batch->size()), max_length_}, options)); + batch_ivalue.emplace_back(torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), max_length_}, options)); batch_ivalue.emplace_back(attention_mask); return batch_ivalue; From 748b73464244348c1e61da73ba60619d35cc12b0 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 13 Feb 2024 13:20:31 -0800 Subject: [PATCH 19/42] fmt --- .../cpp/aot_inductor/bert/src/bert_handler.cc | 55 +++++++++++++++++-- .../cpp/aot_inductor/bert/src/bert_handler.hh | 6 ++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index e2b4e95bd4..8d77e4b39f 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -3,10 +3,26 @@ #include namespace bert { + +std::string BertCppHandler::LoadBytesFromFile(const std::string& path) { + std::ifstream fs(path, std::ios::in | std::ios::binary); + if (fs.fail()) { + TS_LOGF(ERROR, "Cannot open tokenizer file {}", path); + throw; + } + std::string data; + fs.seekg(0, std::ios::end); + size_t size = static_cast(fs.tellg()); + fs.seekg(0, std::ios::beg); + data.resize(size); + fs.read(data.data(), size); + return data; +} + std::unique_ptr BertCppHandler::LoadJsonFile(const std::string& file_path) { std::string content; if (!folly::readFile(file_path.c_str(), content)) { - TS_LOGF(ERROR, "{}} not found", file_path); + TS_LOGF(ERROR, "{} not found", file_path); throw; } return std::make_unique(folly::parseJson(content)); @@ -25,22 +41,31 @@ std::pair, std::shared_ptr> BertCppHandler::LoadModel( std::shared_ptr& load_model_request) { try { + TS_LOG(INFO, "start LoadModel"); auto device = GetTorchDevice(load_model_request); + TS_LOG(INFO, "Found device id"); const std::string mapFilePath = fmt::format("{}/{}", load_model_request->model_dir, "index_to_name.json"); mapping_json_ = LoadJsonFile(mapFilePath); + TS_LOG(INFO, "Load index_to_name.json"); const std::string configFilePath = fmt::format("{}/{}", load_model_request->model_dir, "config.json"); config_json_ = LoadJsonFile(configFilePath); + TS_LOG(INFO, "Load config.json"); max_length_ = static_cast(GetJsonValue(config_json_, "max_length").asInt()); + TS_LOG(INFO, "Get max_length"); std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString()); - auto 
From 748b73464244348c1e61da73ba60619d35cc12b0 Mon Sep 17 00:00:00 2001
From: lxning
Date: Tue, 13 Feb 2024 13:20:31 -0800
Subject: [PATCH 19/42] fmt

---
 .../cpp/aot_inductor/bert/src/bert_handler.cc | 55 +++++++++++++++++--
 .../cpp/aot_inductor/bert/src/bert_handler.hh |  6 ++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
index e2b4e95bd4..8d77e4b39f 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -3,10 +3,26 @@
 #include

 namespace bert {
+
+std::string BertCppHandler::LoadBytesFromFile(const std::string& path) {
+  std::ifstream fs(path, std::ios::in | std::ios::binary);
+  if (fs.fail()) {
+    TS_LOGF(ERROR, "Cannot open tokenizer file {}", path);
+    throw;
+  }
+  std::string data;
+  fs.seekg(0, std::ios::end);
+  size_t size = static_cast(fs.tellg());
+  fs.seekg(0, std::ios::beg);
+  data.resize(size);
+  fs.read(data.data(), size);
+  return data;
+}
+
 std::unique_ptr BertCppHandler::LoadJsonFile(const std::string& file_path) {
   std::string content;
   if (!folly::readFile(file_path.c_str(), content)) {
-    TS_LOGF(ERROR, "{}} not found", file_path);
+    TS_LOGF(ERROR, "{} not found", file_path);
     throw;
   }
   return std::make_unique(folly::parseJson(content));
@@ -25,22 +41,31 @@ std::pair, std::shared_ptr>
 BertCppHandler::LoadModel(
     std::shared_ptr& load_model_request) {
   try {
+    TS_LOG(INFO, "start LoadModel");
     auto device = GetTorchDevice(load_model_request);
+    TS_LOG(INFO, "Found device id");

     const std::string mapFilePath =
         fmt::format("{}/{}", load_model_request->model_dir, "index_to_name.json");
     mapping_json_ = LoadJsonFile(mapFilePath);
+    TS_LOG(INFO, "Load index_to_name.json");

     const std::string configFilePath =
         fmt::format("{}/{}", load_model_request->model_dir, "config.json");
     config_json_ = LoadJsonFile(configFilePath);
+    TS_LOG(INFO, "Load config.json");
     max_length_ = static_cast(GetJsonValue(config_json_, "max_length").asInt());
+    TS_LOG(INFO, "Get max_length");

     std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString());
-    auto tokenizer_blob = LoadJsonFile(tokenizer_path)->asString();
+    auto tokenizer_blob = LoadBytesFromFile(tokenizer_path);
+    TS_LOG(INFO, "Load tokenizer");
+
     tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(tokenizer_blob);
+
     std::string model_so_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "model_so_path").asString());
+    TS_LOGF(INFO, "Get model_so_path {}", model_so_path);

     c10::InferenceMode mode;
     if (device->is_cuda()) {
@@ -73,6 +98,7 @@ c10::IValue BertCppHandler::Preprocess(
   auto options = torch::TensorOptions().dtype(torch::kInt64);
   std::vector batch_tokens;
   auto attention_mask = torch::ones({static_cast(request_batch->size()), max_length_}, torch::kI32);
+  TS_LOG(INFO, "start Preprocess");
   uint8_t idx = 0;
   for (auto& request : *request_batch) {
     try {
@@ -84,8 +110,10 @@ c10::IValue BertCppHandler::Preprocess(

       auto data_it = request.parameters.find(
           torchserve::PayloadType::kPARAMETER_NAME_DATA);
+      TS_LOG(INFO, "get data_it ");
       auto dtype_it =
           request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
+      TS_LOG(INFO, "get data_it ");
       if (data_it == request.parameters.end()) {
         data_it = request.parameters.find(
             torchserve::PayloadType::kPARAMETER_NAME_BODY);
@@ -103,10 +131,12 @@ c10::IValue BertCppHandler::Preprocess(
       }

       std::string msg = torchserve::Converter::VectorToStr(data_it->second);
+      TS_LOGF(INFO, "receive msg {}", msg);

       // tokenization
       std::vector token_ids = tokenizer_->Encode(msg);;
       int cur_token_ids_length = (int)token_ids.size();
+      TS_LOGF(INFO, "cur_token_ids_length {}", cur_token_ids_length);

       if (cur_token_ids_length > max_length_) {
         TS_LOGF(ERROR, "prompt too long ({} tokens, max {})",
                 cur_token_ids_length, max_length_);
@@ -114,7 +144,10 @@ c10::IValue BertCppHandler::Preprocess(
         // padding token ids
         token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_->TokenToId(""));
       }
+      TS_LOG(INFO, "pad token_ids");
       batch_tokens.emplace_back(torch::from_blob(token_ids.data(), max_length_, options));
+      TS_LOG(INFO, "add token_ids to batch_tokens");
+
       idx_to_req_id.second[idx++] = request.request_id;
     } catch (const std::runtime_error& e) {
       TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
@@ -133,8 +166,10 @@ c10::IValue BertCppHandler::Preprocess(
     }
   }
   auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get());
-  batch_ivalue.emplace_back(torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), max_length_}, options));
-  batch_ivalue.emplace_back(attention_mask);
+  batch_ivalue.emplace_back(torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), max_length_}, options).to(*device));
+  TS_LOG(INFO, "add batch tokens to batch_ivalue");
+  batch_ivalue.emplace_back(attention_mask.to(*device));
+  TS_LOG(INFO, "add batch mask to batch_ivalue");

   return batch_ivalue;
 }
@@ -146,14 +181,22 @@ c10::IValue BertCppHandler::Inference(
     std::shared_ptr &response_batch) {
   c10::InferenceMode mode;
   try {
+    TS_LOG(INFO, "start Inference");
     std::shared_ptr runner;
     if (device->is_cuda()) {
       runner = std::static_pointer_cast(model);
     } else {
       runner = std::static_pointer_cast(model);
     }
-
-    auto batch_output_tensor_vector = runner->run(inputs.toTensorVector());
+    TS_LOG(INFO, "cast model to runner");
+    auto vec = inputs.toTensorVector();
+    for (ulong i=0; i < vec.size(); i++) {
+      std::cout << "item " << i << ", tensor:" << vec[i] << std::endl;
+    }
+    TS_LOG(INFO, "convert ivalue to TensorVector");
+    //auto batch_output_tensor_vector = 
runner->run(inputs.toTensorVector()); + auto batch_output_tensor_vector = runner->run(vec); + TS_LOG(INFO, "get batch_output_tensor_vector"); return c10::IValue(batch_output_tensor_vector[0]); } catch (std::runtime_error& e) { TS_LOG(ERROR, e.what()); diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh index d1cdda2f09..604bf56696 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.hh +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh @@ -15,6 +15,11 @@ namespace bert { class BertCppHandler : public torchserve::BaseHandler { public: + // NOLINTBEGIN(bugprone-exception-escape) + BertCppHandler() = default; + // NOLINTEND(bugprone-exception-escape) + ~BertCppHandler() noexcept = default; + std::pair, std::shared_ptr> LoadModel( std::shared_ptr& load_model_request) override; @@ -40,6 +45,7 @@ class BertCppHandler : public torchserve::BaseHandler { override; private: + std::string LoadBytesFromFile(const std::string& path); std::unique_ptr LoadJsonFile(const std::string& file_path); const folly::dynamic& GetJsonValue(std::unique_ptr& json, const std::string& key); From 0bbfc18abc11b92b37687916970ca5f9d3864f29 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 13 Feb 2024 17:38:18 -0800 Subject: [PATCH 20/42] fmt --- .../cpp/aot_inductor/bert/src/bert_handler.cc | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 8d77e4b39f..35c18a76ef 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -95,9 +95,9 @@ c10::IValue BertCppHandler::Preprocess( std::pair &> &idx_to_req_id, std::shared_ptr &request_batch, std::shared_ptr &response_batch) { - auto options = torch::TensorOptions().dtype(torch::kInt64); - std::vector batch_tokens; - auto attention_mask = torch::ones({static_cast(request_batch->size()), max_length_}, torch::kI32); + auto options = torch::TensorOptions().dtype(torch::kInt32); + std::vector batch_tokens; + auto attention_mask = torch::zeros({static_cast(request_batch->size()), max_length_}, torch::kInt32); TS_LOG(INFO, "start Preprocess"); uint8_t idx = 0; for (auto& request : *request_batch) { @@ -134,8 +134,12 @@ c10::IValue BertCppHandler::Preprocess( TS_LOGF(INFO, "receive msg {}", msg); // tokenization - std::vector token_ids = tokenizer_->Encode(msg);; + std::vector token_ids = tokenizer_->Encode(msg);; int cur_token_ids_length = (int)token_ids.size(); + for (int i = 0; i < cur_token_ids_length; i++) { + TS_LOGF(INFO, "token: {}, id: {}", i, token_ids[i]); + attention_mask[idx][i] = 1; + } TS_LOGF(INFO, "cur_token_ids_length {}", cur_token_ids_length); if (cur_token_ids_length > max_length_) { @@ -145,7 +149,7 @@ c10::IValue BertCppHandler::Preprocess( token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_->TokenToId("")); } TS_LOG(INFO, "pad token_ids"); - batch_tokens.emplace_back(torch::from_blob(token_ids.data(), max_length_, options)); + batch_tokens.insert(batch_tokens.end(), token_ids.begin(), token_ids.end()); TS_LOG(INFO, "add token_ids to batch_tokens"); idx_to_req_id.second[idx++] = request.request_id; @@ -166,8 +170,10 @@ c10::IValue BertCppHandler::Preprocess( } } auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); + std::cout << "batch_tokens.data blob" << torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), 
max_length_}) << std::endl;
   batch_ivalue.emplace_back(torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), max_length_}, options).to(*device));
   TS_LOG(INFO, "add batch tokens to batch_ivalue");
+  std::cout << "mask: " << attention_mask << std::endl;
   batch_ivalue.emplace_back(attention_mask.to(*device));
   TS_LOG(INFO, "add batch mask to batch_ivalue");
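The mask handling above moves from a `torch::ones` prefill to a zero-filled tensor with explicit per-token writes, so padding positions are excluded from attention instead of being attended to. A minimal sketch of that scheme, assuming only libtorch; the sequence lengths are illustrative, and the `torch::kLong` dtype anticipates the type fix that lands a few patches later:

```cpp
#include <torch/torch.h>

#include <algorithm>
#include <iostream>
#include <vector>

// Zero-init the mask, then set 1 only where a real (non-padding) token sits.
torch::Tensor make_attention_mask(const std::vector<int64_t>& lengths,
                                  int64_t max_length) {
  auto mask = torch::zeros({static_cast<int64_t>(lengths.size()), max_length},
                           torch::kLong);
  for (int64_t row = 0; row < static_cast<int64_t>(lengths.size()); ++row) {
    for (int64_t col = 0; col < std::min(lengths[row], max_length); ++col) {
      mask[row][col] = 1;  // per-element writes are fine at these batch sizes
    }
  }
  return mask;
}

int main() {
  std::cout << make_attention_mask({3, 1}, 5) << std::endl;
  return 0;
}
```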
From 4b2a1cee92685c2fa38def7e547c04ec47ead74f Mon Sep 17 00:00:00 2001
From: lxning
Date: Wed, 14 Feb 2024 15:22:41 -0800
Subject: [PATCH 21/42] fmt

---
 cpp/build.sh | 24 ++++++++++++++-----
 cpp/src/examples/CMakeLists.txt | 9 ++++---
 cpp/test/examples/examples_test.cc | 23 ++++++++++++++++++
 .../bert_handler/MAR-INF/MANIFEST.json | 10 ++++++++
 .../aot_inductor/bert_handler/config.json | 2 +-
 .../bert_handler/index_to_name.json | 4 ++++
 .../aot_inductor/bert_handler/sample_text.txt | 1 +
 .../aot_inductor/bert/aot_compile_export.py | 2 +-
 examples/cpp/aot_inductor/bert/config.json | 15 ++++++++++++
 .../cpp/aot_inductor/bert/src/bert_handler.cc | 4 ++--
 10 files changed, 81 insertions(+), 13 deletions(-)
 create mode 100644 cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
 rename examples/cpp/aot_inductor/bert/setup_config.json => cpp/test/resources/examples/aot_inductor/bert_handler/config.json (84%)
 create mode 100644 cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json
 create mode 100644 cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
 create mode 100644 examples/cpp/aot_inductor/bert/config.json

diff --git a/cpp/build.sh b/cpp/build.sh
index 73338df2ab..8cd067b335 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -190,14 +190,26 @@ function prepare_test_files() {
   if [ ! -f "${EX_DIR}/babyllama/babyllama_handler/stories15M.bin" ]; then
     wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin -O "${EX_DIR}/babyllama/babyllama_handler/stories15M.bin"
   fi
-  if [ ! -f "${EX_DIR}/aot_inductor/llama_handler/stories15M.so" ]; then
-    local HANDLER_DIR=${EX_DIR}/aot_inductor/llama_handler/
-    if [ ! -f "${HANDLER_DIR}/stories15M.pt" ]; then
-      wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt?download=true -O "${HANDLER_DIR}/stories15M.pt"
+  # PT2.2 torch.export does not support Mac
+  if [ "$PLATFORM" = "Linux" ]; then
+    if [ ! -f "${EX_DIR}/aot_inductor/llama_handler/stories15M.so" ]; then
+      local HANDLER_DIR=${EX_DIR}/aot_inductor/llama_handler/
+      if [ ! -f "${HANDLER_DIR}/stories15M.pt" ]; then
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt?download=true -O "${HANDLER_DIR}/stories15M.pt"
+      fi
+      local LLAMA_SO_DIR=${BASE_DIR}/third-party/llama2.so/
+      PYTHONPATH=${LLAMA_SO_DIR}:${PYTHONPATH} python ${BASE_DIR}/../examples/cpp/aot_inductor/llama2/compile.py --checkpoint ${HANDLER_DIR}/stories15M.pt ${HANDLER_DIR}/stories15M.so
     fi
+    if [ ! -f "${EX_DIR}/aot_inductor/bert_handler/bert-seq.so" ]; then
+      local HANDLER_DIR=${EX_DIR}/aot_inductor/bert_handler/
+      export TOKENIZERS_PARALLELISM=false
+      cd ${BASE_DIR}/../examples/cpp/aot_inductor/bert/
+      python aot_compile_export.py
+      mv bert-seq.so ${HANDLER_DIR}/bert-seq.so
+      mv Transformer_model ${HANDLER_DIR}/Transformer_model
+    fi
-    local LLAMA_SO_DIR=${BASE_DIR}/third-party/llama2.so/
-    PYTHONPATH=${LLAMA_SO_DIR}:${PYTHONPATH} python ${BASE_DIR}/../examples/cpp/aot_inductor/llama2/compile.py --checkpoint ${HANDLER_DIR}/stories15M.pt ${HANDLER_DIR}/stories15M.so
   fi
+  cd "$BWD" || exit
 }

 function build() {
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index b2897486ff..aab31e8c78 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -1,10 +1,13 @@
 add_subdirectory("../../../examples/cpp/babyllama/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/babyllama/babyllama_handler/")

-add_subdirectory("../../../examples/cpp/aot_inductor/llama2/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/llama_handler/")
-
 add_subdirectory("../../../examples/cpp/llamacpp/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/llamacpp/llamacpp_handler/")

 add_subdirectory("../../../examples/cpp/mnist/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/mnist/mnist_handler/")

-add_subdirectory("../../../examples/cpp/aot_inductor/bert" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/bert/")
+# PT2.2 torch.export does not support Mac
+if(CMAKE_SYSTEM_NAME MATCHES "Linux")
+  add_subdirectory("../../../examples/cpp/aot_inductor/llama2/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/llama_handler/")
+
+  add_subdirectory("../../../examples/cpp/aot_inductor/resnet" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/bert_handler/")
+endif()
diff --git a/cpp/test/examples/examples_test.cc b/cpp/test/examples/examples_test.cc
index 00e5135715..f55ec0e994 100644
--- a/cpp/test/examples/examples_test.cc
+++ b/cpp/test/examples/examples_test.cc
@@ -59,3 +59,26 @@ TEST_F(ModelPredictTest, TestLoadPredictLlamaCppHandler) {
           base_dir + "llamacpp_handler", "llamacpp", -1, "", "", 1, false),
       base_dir + "llamacpp_handler", base_dir + "prompt.txt", "llm_ts", 200);
 }
+
+TEST_F(ModelPredictTest, TestLoadPredictAotInductorBertHandler) {
+  std::string base_dir = "_build/test/resources/examples/aot_inductor/";
+  std::string file1 = base_dir + "bert_handler/bert-seq.so";
+  std::string file2 = base_dir + "bert_handler/Transformer_model/tokenizer.json";
+
+  std::ifstream f1(file1);
+  std::ifstream f2(file2);
+
+  if (!f1.good() || !f2.good())
+    GTEST_SKIP() << "Skipping TestLoadPredictAotInductorBertHandler because "
+                    "of missing files: "
+                 << file1 << " or " << file2;
+
+  this->LoadPredict(
+      std::make_shared(
+          base_dir + "bert_handler", "bert_aot",
+          torch::cuda::is_available() ? 
0 : -1, "", "", 1, false),
+      base_dir + "bert_handler",
+      base_dir + "bert_handler/sample_text.txt",
+      "bert_ts",
+      200);
+}
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
new file mode 100644
index 0000000000..c9c205603a
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
@@ -0,0 +1,10 @@
+{
+  "createdOn": "12/02/2024 21:09:26",
+  "runtime": "LSP",
+  "model": {
+    "modelName": "bertcppaot",
+    "handler": "libbert_handler:BertCppHandler",
+    "modelVersion": "1.0"
+  },
+  "archiverVersion": "0.9.0"
+}
diff --git a/examples/cpp/aot_inductor/bert/setup_config.json b/cpp/test/resources/examples/aot_inductor/bert_handler/config.json
similarity index 84%
rename from examples/cpp/aot_inductor/bert/setup_config.json
rename to cpp/test/resources/examples/aot_inductor/bert_handler/config.json
index e7acb6c955..f593f2a05b 100644
--- a/examples/cpp/aot_inductor/bert/setup_config.json
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/config.json
@@ -10,6 +10,6 @@
  "BetterTransformer":false,
  "model_parallel":false,
  "batch_size": "4",
- "tokenizer_path": "Transformer_model/model.safetensors",
+ "tokenizer_path": "Transformer_model/tokenizer.json",
  "model_so_path": "bert-seq.so"
 }
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json b/cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json
new file mode 100644
index 0000000000..9ccff719f6
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/index_to_name.json
@@ -0,0 +1,4 @@
+{
+ "0":"Not Accepted",
+ "1":"Accepted"
+}
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt b/cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
new file mode 100644
index 0000000000..4c15a88ad2
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
@@ -0,0 +1 @@
+Bloomberg has decided to publish a new report on the global economy.
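The `index_to_name.json` fixture added above is what the handler's postprocess step uses to turn an argmax index into a readable label. A minimal sketch of that lookup with folly, which is already a dependency of this backend; the inline JSON literal mirrors the fixture and the predicted index is illustrative:

```cpp
#include <folly/dynamic.h>
#include <folly/json.h>

#include <iostream>
#include <string>

int main() {
  // Same shape as the index_to_name.json file created above.
  folly::dynamic mapping =
      folly::parseJson(R"({"0": "Not Accepted", "1": "Accepted"})");

  int predicted_idx = 1;  // e.g. argmax over the two classification logits
  std::cout << mapping[std::to_string(predicted_idx)].asString() << std::endl;
  return 0;
}
```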
diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
index 47ababf092..49fd4058ce 100644
--- a/examples/cpp/aot_inductor/bert/aot_compile_export.py
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -98,7 +98,7 @@ def transformers_model_dowloader(
     if len(sys.argv) > 1:
         filename = os.path.join(dirname, sys.argv[1])
     else:
-        filename = os.path.join(dirname, "setup_config.json")
+        filename = os.path.join(dirname, "config.json")
     f = open(filename)
     settings = json.load(f)
     mode = settings["mode"]
diff --git a/examples/cpp/aot_inductor/bert/config.json b/examples/cpp/aot_inductor/bert/config.json
new file mode 100644
index 0000000000..f593f2a05b
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/config.json
@@ -0,0 +1,15 @@
+{
+ "model_name":"bert-base-uncased",
+ "mode":"sequence_classification",
+ "do_lower_case":true,
+ "num_labels":"2",
+ "max_length":"150",
+ "captum_explanation":false,
+ "embedding_name": "bert",
+ "FasterTransformer":false,
+ "BetterTransformer":false,
+ "model_parallel":false,
+ "batch_size": "4",
+ "tokenizer_path": "Transformer_model/tokenizer.json",
+ "model_so_path": "bert-seq.so"
+}
diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
index 35c18a76ef..ce3bba6e6d 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -137,8 +137,8 @@ c10::IValue BertCppHandler::Preprocess(
       std::vector token_ids = tokenizer_->Encode(msg);;
       int cur_token_ids_length = (int)token_ids.size();
       for (int i = 0; i < cur_token_ids_length; i++) {
-      TS_LOGF(INFO, "token: {}, id: {}", i, token_ids[i]);
-      attention_mask[idx][i] = 1;
+        TS_LOGF(INFO, "token: {}, id: {}", i, token_ids[i]);
+        attention_mask[idx][i] = 1;
       }
       TS_LOGF(INFO, "cur_token_ids_length {}", cur_token_ids_length);

From 060453376e3355a8fd0698b0843694293b4a7e9e Mon Sep 17 00:00:00 2001
From: lxning
Date: Wed, 14 Feb 2024 15:59:40 -0800
Subject: [PATCH 22/42] unset env

---
 cpp/build.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cpp/build.sh b/cpp/build.sh
index 8cd067b335..f455b61b0c 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -207,6 +207,7 @@ function prepare_test_files() {
       python aot_compile_export.py
       mv bert-seq.so ${HANDLER_DIR}/bert-seq.so
       mv Transformer_model ${HANDLER_DIR}/Transformer_model
+      export TOKENIZERS_PARALLELISM=""
     fi
   fi
   cd "$BWD" || exit

From 9922a9997aed213f8ebc3ac2cf63b8435faa0eb9 Mon Sep 17 00:00:00 2001
From: lxning
Date: Wed, 14 Feb 2024 16:08:45 -0800
Subject: [PATCH 23/42] fix path

---
 cpp/src/examples/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index aab31e8c78..51126210bd 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -9,5 +9,5 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
   add_subdirectory("../../../examples/cpp/aot_inductor/llama2/" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/llama_handler/")

-  add_subdirectory("../../../examples/cpp/aot_inductor/resnet" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/bert_handler/")
+  add_subdirectory("../../../examples/cpp/aot_inductor/bert" "${CMAKE_CURRENT_BINARY_DIR}/../../test/resources/examples/aot_inductor/bert_handler/")
 endif()
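The patch that follows corrects a tensor type error, and the underlying rule is worth spelling out: PyTorch embedding layers index with 64-bit integers, and a model compiled ahead of time with AOTInductor is specialized to the input dtypes used at export, so int32 token ids where int64 was exported will generally fail at run time. A minimal sketch of allocating both inputs as `torch::kLong`, assuming only libtorch; the pad id and token values are illustrative:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  int64_t batch = 2, max_length = 8;
  int64_t pad_id = 0;  // illustrative; the real [PAD] id comes from the tokenizer

  // Token ids pre-filled with padding, mask zeroed; both int64 ("Long").
  auto tokens = torch::full({batch, max_length}, pad_id, torch::kLong);
  auto mask = torch::zeros({batch, max_length}, torch::kLong);

  tokens[0][0] = 101;  // overwrite padding with real ids, position by position
  mask[0][0] = 1;

  std::cout << tokens.dtype() << " " << mask.dtype() << std::endl;  // int64 for both
  return 0;
}
```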
From 0e81e4f22f449ac2b1ff47c18e08178d000179de Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Thu, 15 Feb 2024 19:47:32 +0000
Subject: [PATCH 24/42] Fix type error in bert aot example

---
 .../cpp/aot_inductor/bert/src/bert_handler.cc | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
index ce3bba6e6d..09f457a150 100644
--- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc
+++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc
@@ -95,9 +95,9 @@ c10::IValue BertCppHandler::Preprocess(
     std::pair &> &idx_to_req_id,
     std::shared_ptr &request_batch,
     std::shared_ptr &response_batch) {
-  auto options = torch::TensorOptions().dtype(torch::kInt32);
-  std::vector batch_tokens;
-  auto attention_mask = torch::zeros({static_cast(request_batch->size()), max_length_}, torch::kInt32);
+  auto options = torch::TensorOptions().dtype(torch::kLong);
+  auto attention_mask = torch::zeros({static_cast(request_batch->size()), max_length_}, torch::kLong);
+  auto batch_tokens = torch::full({static_cast(request_batch->size()), max_length_}, tokenizer_->TokenToId(""), torch::kLong);
   TS_LOG(INFO, "start Preprocess");
   uint8_t idx = 0;
   for (auto& request : *request_batch) {
@@ -139,17 +139,14 @@ c10::IValue BertCppHandler::Preprocess(
       for (int i = 0; i < cur_token_ids_length; i++) {
         TS_LOGF(INFO, "token: {}, id: {}", i, token_ids[i]);
         attention_mask[idx][i] = 1;
+        batch_tokens[idx][i] = token_ids[i];
       }
       TS_LOGF(INFO, "cur_token_ids_length {}", cur_token_ids_length);

       if (cur_token_ids_length > max_length_) {
         TS_LOGF(ERROR, "prompt too long ({} tokens, max {})",
                 cur_token_ids_length, max_length_);
-      } else if (cur_token_ids_length < max_length_) {
-        // padding token ids
-        token_ids.insert(token_ids.end(), max_length_ - cur_token_ids_length, tokenizer_->TokenToId(""));
       }
       TS_LOG(INFO, "pad token_ids");
-      batch_tokens.insert(batch_tokens.end(), token_ids.begin(), token_ids.end());
       TS_LOG(INFO, "add token_ids to batch_tokens");

       idx_to_req_id.second[idx++] = request.request_id;
@@ -170,8 +167,8 @@ c10::IValue BertCppHandler::Preprocess(
     }
   }
   auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get());
-  std::cout << "batch_tokens.data blob" << torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), max_length_}) << std::endl;
-  batch_ivalue.emplace_back(torch::from_blob(batch_tokens.data(), {static_cast(request_batch->size()), max_length_}, options).to(*device));
+  batch_ivalue.emplace_back(batch_tokens.to(*device));
+  std::cout << "input_ids: " << batch_tokens << std::endl;
   TS_LOG(INFO, "add batch tokens to batch_ivalue");
   std::cout << "mask: " << attention_mask << std::endl;
   batch_ivalue.emplace_back(attention_mask.to(*device));

From ac08078b826e1fda321bfaa6e83a6428e0f0ec50 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sat, 17 Feb 2024 13:58:26 -0800
Subject: [PATCH 25/42] fmt

---
 cpp/CMakeLists.txt | 2 +-
 cpp/build.sh | 3 +-
 cpp/test/examples/examples_test.cc | 2 +-
 .../bert_handler/MAR-INF/MANIFEST.json | 3 +-
 .../aot_inductor/bert_handler/config.json | 15 -----
 .../bert_handler/model-config.yaml | 13 ++++
 examples/cpp/aot_inductor/bert/README.md | 66 +++++++++++++++++++
 .../aot_inductor/bert/aot_compile_export.py | 19 +++---
 examples/cpp/aot_inductor/bert/config.json | 15 -----
 .../cpp/aot_inductor/bert/index_to_name.json | 4 ++
 .../cpp/aot_inductor/bert/model-config.yaml | 13 ++++
 .../cpp/aot_inductor/bert/src/bert_handler.cc | 63 ++++++------
 .../cpp/aot_inductor/bert/src/bert_handler.hh | 3 +-
 13 files changed, 135 insertions(+), 86 deletions(-)
 delete mode 100644 cpp/test/resources/examples/aot_inductor/bert_handler/config.json
 create mode 100644 cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml
 create mode 100644 examples/cpp/aot_inductor/bert/README.md
 delete mode 100644 examples/cpp/aot_inductor/bert/config.json
 create mode 100644 examples/cpp/aot_inductor/bert/index_to_name.json
 create mode 100644 examples/cpp/aot_inductor/bert/model-config.yaml

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index e81a272894..f466ee6a6b 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -30,7 +30,7 @@ find_package(folly REQUIRED)
 find_package(fmt REQUIRED)
 find_package(gflags REQUIRED)
 find_package(Torch REQUIRED)
-find_package(yaml-cpp REQUIRED NO_CMAKE_PATH)
+

 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 include_directories(${TORCH_INCLUDE_DIRS})
diff --git a/cpp/build.sh b/cpp/build.sh
index f455b61b0c..b4f9ae84fe 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -201,12 +201,13 @@ function prepare_test_files() {
       PYTHONPATH=${LLAMA_SO_DIR}:${PYTHONPATH} python ${BASE_DIR}/../examples/cpp/aot_inductor/llama2/compile.py --checkpoint ${HANDLER_DIR}/stories15M.pt ${HANDLER_DIR}/stories15M.so
     fi
     if [ ! -f "${EX_DIR}/aot_inductor/bert_handler/bert-seq.so" ]; then
+      pip install transformers
       local HANDLER_DIR=${EX_DIR}/aot_inductor/bert_handler/
       export TOKENIZERS_PARALLELISM=false
       cd ${BASE_DIR}/../examples/cpp/aot_inductor/bert/
       python aot_compile_export.py
       mv bert-seq.so ${HANDLER_DIR}/bert-seq.so
-      mv Transformer_model ${HANDLER_DIR}/Transformer_model
+      mv Transformer_model/tokenizer_config.json ${HANDLER_DIR}/tokenizer_config.json
       export TOKENIZERS_PARALLELISM=""
     fi
diff --git a/cpp/test/examples/examples_test.cc b/cpp/test/examples/examples_test.cc
index f55ec0e994..afce1fb271 100644
--- a/cpp/test/examples/examples_test.cc
+++ b/cpp/test/examples/examples_test.cc
@@ -63,7 +63,7 @@ TEST_F(ModelPredictTest, TestLoadPredictAotInductorBertHandler) {
   std::string base_dir = "_build/test/resources/examples/aot_inductor/";
   std::string file1 = base_dir + "bert_handler/bert-seq.so";
-  std::string file2 = base_dir + "bert_handler/Transformer_model/tokenizer.json";
+  std::string file2 = base_dir + "bert_handler/tokenizer.json";

   std::ifstream f1(file1);
   std::ifstream f2(file2);
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
index c9c205603a..c5f5d519f3 100644
--- a/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/MAR-INF/MANIFEST.json
@@ -4,7 +4,8 @@
   "model": {
     "modelName": "bertcppaot",
     "handler": "libbert_handler:BertCppHandler",
-    "modelVersion": "1.0"
+    "modelVersion": "1.0",
+    "configFile": "model-config.yaml"
   },
   "archiverVersion": "0.9.0"
 }
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/config.json b/cpp/test/resources/examples/aot_inductor/bert_handler/config.json
deleted file mode 100644
index f593f2a05b..0000000000
--- a/cpp/test/resources/examples/aot_inductor/bert_handler/config.json
+++ /dev/null
@@ -1,15 +0,0 @@
-{
- "model_name":"bert-base-uncased",
- "mode":"sequence_classification",
- "do_lower_case":true,
- "num_labels":"2",
- "max_length":"150",
- "captum_explanation":false,
- "embedding_name": "bert",
- "FasterTransformer":false,
- "BetterTransformer":false,
- "model_parallel":false,
- "batch_size": "4",
- "tokenizer_path": "Transformer_model/tokenizer.json",
- "model_so_path": "bert-seq.so"
-}
diff --git a/cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml b/cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml
new file mode 100644
index 0000000000..f44839848c
--- /dev/null
+++ b/cpp/test/resources/examples/aot_inductor/bert_handler/model-config.yaml
@@ -0,0 +1,13 @@
+minWorkers: 1
+maxWorkers: 1
+batchSize: 2
+
+handler:
+  model_so_path: "bert-seq.so"
+  tokenizer_path: "tokenizer.json"
+  mapping: "index_to_name.json"
+  model_name: "bert-base-uncased"
+  mode: "sequence_classification"
+  do_lower_case: true
+  num_labels: 2
+  max_length: 150
diff --git a/examples/cpp/aot_inductor/bert/README.md b/examples/cpp/aot_inductor/bert/README.md
new file mode 100644
index 0000000000..8a54ade2ce
--- /dev/null
+++ b/examples/cpp/aot_inductor/bert/README.md
@@ -0,0 +1,66 @@
+This example uses AOTInductor to compile the [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) model into an so file which is then executed using libtorch.
+The handler C++ source code for this example can be found [here](src).
+
+### Setup
+1. Follow the instructions in [README.md](../../../../cpp/README.md) to build the TorchServe C++ backend.
+
+```
+cd serve/cpp
+./build.sh
+```
+
+The build script will create the necessary artifact for this example.
+To recreate these by hand you can follow the prepare_test_files function of the [build.sh](../../../../cpp/build.sh) script.
+We will need the handler .so file as well as the bert-seq.so and tokenizer.json.
+
+2. Create a [model-config.yaml](model-config.yaml)
+
+```yaml
+minWorkers: 1
+maxWorkers: 1
+batchSize: 2
+
+handler:
+  model_so_path: "bert-seq.so"
+  tokenizer_path: "tokenizer.json"
+  mapping: "index_to_name.json"
+  model_name: "bert-base-uncased"
+  mode: "sequence_classification"
+  do_lower_case: true
+  num_labels: 2
+  max_length: 150
+```
+
+### Generate Model Artifact Folder
+
+```bash
+torch-model-archiver --model-name bertcppaot --version 1.0 --handler ../../../../cpp/_build/test/resources/examples/aot_inductor/bert/libbert_handler:BertCppHandler --runtime LSP --extra-files index_to_name.json,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/bert-seq.so,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/tokenizer.json --config-file model-config.yaml --archive-format no-archive
+```
+
+Create model store directory and move the folder `bertcppaot`
+
+```
+mkdir model_store
+mv bertcppaot model_store/
+```
+
+### Inference
+
+Start torchserve using the following command
+
+```
+torchserve --ncs --model-store model_store/ --models bertcppaot
+```
+
+Infer the model using the following command
+
+```
+curl http://localhost:8080/predictions/bertcppaot -T ../../../../cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt
+{
+  "lens_cap": 0.0022578993812203407,
+  "lynx": 0.0032067005522549152,
+  "Egyptian_cat": 0.046274684369564056,
+  "tiger_cat": 0.13740436732769012,
+  "tabby": 0.2724998891353607
+}
+```
diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
index 49fd4058ce..dd6b078ed2 100644
--- a/examples/cpp/aot_inductor/bert/aot_compile_export.py
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -1,8 +1,8 @@
-import json
 import os
 import sys

 import torch
+import yaml from transformers import ( AutoConfig, AutoModelForSequenceClassification, @@ -98,14 +98,15 @@ def transformers_model_dowloader( if len(sys.argv) > 1: filename = os.path.join(dirname, sys.argv[1]) else: - filename = os.path.join(dirname, "config.json") - f = open(filename) - settings = json.load(f) - mode = settings["mode"] - model_name = settings["model_name"] - num_labels = int(settings["num_labels"]) - do_lower_case = settings["do_lower_case"] - max_length = int(settings["max_length"]) + filename = os.path.join(dirname, "model-config.yaml") + with open(filename, "r") as f: + settings = yaml.load(f) + + mode = settings["handler"]["mode"] + model_name = settings["handler"]["model_name"] + num_labels = int(settings["handler"]["num_labels"]) + do_lower_case = settings["handler"]["do_lower_case"] + max_length = int(settings["handler"]["max_length"]) batch_size = int(settings.get("batch_size", "1")) transformers_model_dowloader( diff --git a/examples/cpp/aot_inductor/bert/config.json b/examples/cpp/aot_inductor/bert/config.json deleted file mode 100644 index f593f2a05b..0000000000 --- a/examples/cpp/aot_inductor/bert/config.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "model_name":"bert-base-uncased", - "mode":"sequence_classification", - "do_lower_case":true, - "num_labels":"2", - "max_length":"150", - "captum_explanation":false, - "embedding_name": "bert", - "FasterTransformer":false, - "BetterTransformer":false, - "model_parallel":false, - "batch_size": "4", - "tokenizer_path": "Transformer_model/tokenizer.json", - "model_so_path": "bert-seq.so" -} diff --git a/examples/cpp/aot_inductor/bert/index_to_name.json b/examples/cpp/aot_inductor/bert/index_to_name.json new file mode 100644 index 0000000000..9ccff719f6 --- /dev/null +++ b/examples/cpp/aot_inductor/bert/index_to_name.json @@ -0,0 +1,4 @@ +{ + "0":"Not Accepted", + "1":"Accepted" +} diff --git a/examples/cpp/aot_inductor/bert/model-config.yaml b/examples/cpp/aot_inductor/bert/model-config.yaml new file mode 100644 index 0000000000..f44839848c --- /dev/null +++ b/examples/cpp/aot_inductor/bert/model-config.yaml @@ -0,0 +1,13 @@ +minWorkers: 1 +maxWorkers: 1 +batchSize: 2 + +handler: + model_so_path: "bert-seq.so" + tokenizer_path: "tokenizer.json" + mapping: "index_to_name.json" + model_name: "bert-base-uncased" + mode: "sequence_classification" + do_lower_case: true + num_labels: 2 + max_length: 150 diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 09f457a150..2ba7a1b8e5 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -43,29 +43,28 @@ BertCppHandler::LoadModel( try { TS_LOG(INFO, "start LoadModel"); auto device = GetTorchDevice(load_model_request); - TS_LOG(INFO, "Found device id"); + + const std::string modelConfigYamlFilePath = + fmt::format("{}/{}", load_model_request->model_dir, "model-config.yaml"); + model_config_yaml_ = std::make_unique(YAML::LoadFile(modelConfigYamlFilePath)); const std::string mapFilePath = - fmt::format("{}/{}", load_model_request->model_dir, "index_to_name.json"); + fmt::format("{}/{}", load_model_request->model_dir, + (*model_config_yaml_)["handler"]["mapping"].as()); mapping_json_ = LoadJsonFile(mapFilePath); - TS_LOG(INFO, "Load index_to_name.json"); - const std::string configFilePath = - fmt::format("{}/{}", load_model_request->model_dir, "config.json"); - config_json_ = LoadJsonFile(configFilePath); - TS_LOG(INFO, "Load config.json"); - 
max_length_ = static_cast(GetJsonValue(config_json_, "max_length").asInt()); - TS_LOG(INFO, "Get max_length"); + max_length_ = (*model_config_yaml_)["handler"]["max_length"].as(); - std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "tokenizer_path").asString()); + std::string tokenizer_path = + fmt::format("{}/{}", load_model_request->model_dir, + (*model_config_yaml_)["handler"]["tokenizer_path"].as()); auto tokenizer_blob = LoadBytesFromFile(tokenizer_path); - TS_LOG(INFO, "Load tokenizer"); - tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(tokenizer_blob); + std::string model_so_path = + fmt::format("{}/{}", load_model_request->model_dir, + (*model_config_yaml_)["handler"]["model_so_path"].as()); - std::string model_so_path = fmt::format("{}/{}", load_model_request->model_dir, GetJsonValue(config_json_, "model_so_path").asString()); - TS_LOGF(INFO, "Get model_so_path {}", model_so_path); c10::InferenceMode mode; if (device->is_cuda()) { @@ -98,7 +97,7 @@ c10::IValue BertCppHandler::Preprocess( auto options = torch::TensorOptions().dtype(torch::kLong); auto attention_mask = torch::zeros({static_cast(request_batch->size()), max_length_}, torch::kLong); auto batch_tokens = torch::full({static_cast(request_batch->size()), max_length_}, tokenizer_->TokenToId(""), torch::kLong); - TS_LOG(INFO, "start Preprocess"); + uint8_t idx = 0; for (auto& request : *request_batch) { try { @@ -110,10 +109,8 @@ c10::IValue BertCppHandler::Preprocess( auto data_it = request.parameters.find( torchserve::PayloadType::kPARAMETER_NAME_DATA); - TS_LOG(INFO, "get data_it "); auto dtype_it = request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE); - TS_LOG(INFO, "get data_it "); if (data_it == request.parameters.end()) { data_it = request.parameters.find( torchserve::PayloadType::kPARAMETER_NAME_BODY); @@ -123,7 +120,6 @@ c10::IValue BertCppHandler::Preprocess( if (data_it == request.parameters.end() || dtype_it == request.headers.end()) { - TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id); (*response_batch)[request.request_id]->SetResponse( 500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT, "Empty payload"); @@ -131,23 +127,16 @@ c10::IValue BertCppHandler::Preprocess( } std::string msg = torchserve::Converter::VectorToStr(data_it->second); - TS_LOGF(INFO, "receive msg {}", msg); - // tokenization std::vector token_ids = tokenizer_->Encode(msg);; int cur_token_ids_length = (int)token_ids.size(); - for (int i = 0; i < cur_token_ids_length; i++) { - TS_LOGF(INFO, "token: {}, id: {}", i, token_ids[i]); - attention_mask[idx][i] = 1; - batch_tokens[idx][i] = token_ids[i]; - } - TS_LOGF(INFO, "cur_token_ids_length {}", cur_token_ids_length); - if (cur_token_ids_length > max_length_) { TS_LOGF(ERROR, "prompt too long ({} tokens, max {})", cur_token_ids_length, max_length_); } - TS_LOG(INFO, "pad token_ids"); - TS_LOG(INFO, "add token_ids to batch_tokens"); + for (int i = 0; i < std::min(cur_token_ids_length, max_length_); i++) { + attention_mask[idx][i] = 1; + batch_tokens[idx][i] = token_ids[i]; + } idx_to_req_id.second[idx++] = request.request_id; } catch (const std::runtime_error& e) { @@ -168,11 +157,7 @@ c10::IValue BertCppHandler::Preprocess( } auto batch_ivalue = c10::impl::GenericList(torch::TensorType::get()); batch_ivalue.emplace_back(batch_tokens.to(*device)); - std::cout << "input_ids: " << batch_tokens << std::endl; - TS_LOG(INFO, "add batch tokens to batch_ivalue"); - std::cout << "mask: " << 
attention_mask << std::endl; batch_ivalue.emplace_back(attention_mask.to(*device)); - TS_LOG(INFO, "add batch mask to batch_ivalue"); return batch_ivalue; } @@ -184,22 +169,16 @@ c10::IValue BertCppHandler::Inference( std::shared_ptr &response_batch) { c10::InferenceMode mode; try { - TS_LOG(INFO, "start Inference"); std::shared_ptr runner; if (device->is_cuda()) { runner = std::static_pointer_cast(model); } else { runner = std::static_pointer_cast(model); } - TS_LOG(INFO, "cast model to runner"); - auto vec = inputs.toTensorVector(); - for (ulong i=0; i < vec.size(); i++) { - std::cout << "item " << i << ", tensor:" << vec[i] << std::endl; - } - TS_LOG(INFO, "convert ivalue to TensorVector"); - //auto batch_output_tensor_vector = runner->run(inputs.toTensorVector()); + + auto batch_output_tensor_vector = runner->run(inputs.toTensorVector()); auto batch_output_tensor_vector = runner->run(vec); - TS_LOG(INFO, "get batch_output_tensor_vector"); + return c10::IValue(batch_output_tensor_vector[0]); } catch (std::runtime_error& e) { TS_LOG(ERROR, e.what()); diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh index 604bf56696..e8fbce62e0 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.hh +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh @@ -9,6 +9,7 @@ #include #include #include +#include #include "src/backends/handler/base_handler.hh" @@ -49,9 +50,9 @@ private: std::unique_ptr LoadJsonFile(const std::string& file_path); const folly::dynamic& GetJsonValue(std::unique_ptr& json, const std::string& key); - std::unique_ptr config_json_; std::unique_ptr mapping_json_; std::unique_ptr tokenizer_; + std::unique_ptr model_config_yaml_; int max_length_; }; } // namespace bert From 869b9a3e709eafc99a9ca3183bfd01ae9bc2e70d Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 17 Feb 2024 19:03:01 -0800 Subject: [PATCH 26/42] fmt --- cpp/build.sh | 2 +- examples/cpp/aot_inductor/bert/README.md | 10 ++-------- examples/cpp/aot_inductor/bert/aot_compile_export.py | 8 ++++---- examples/cpp/aot_inductor/bert/src/bert_handler.cc | 2 -- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/cpp/build.sh b/cpp/build.sh index b4f9ae84fe..86c3269cdb 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -207,7 +207,7 @@ function prepare_test_files() { cd ${BASE_DIR}/../examples/cpp/aot_inductor/bert/ python aot_compile_export.py mv bert-seq.so ${HANDLER_DIR}/bert-seq.so - mv Transformer_model/tokenizer_config.json ${HANDLER_DIR}/tokenizer_config.json + mv Transformer_model/tokenizer.json ${HANDLER_DIR}/tokenizer.json export TOKENIZERS_PARALLELISM="" fi fi diff --git a/examples/cpp/aot_inductor/bert/README.md b/examples/cpp/aot_inductor/bert/README.md index 8a54ade2ce..f4734ff3f1 100644 --- a/examples/cpp/aot_inductor/bert/README.md +++ b/examples/cpp/aot_inductor/bert/README.md @@ -34,7 +34,7 @@ handler: ### Generate Model Artifact Folder ```bash -torch-model-archiver --model-name bertcppaot --version 1.0 --handler ../../../../cpp/_build/test/resources/examples/aot_inductor/bert/libbert_handler:BertCppHandler --runtime LSP --extra-files index_to_name.json,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/bert-seq.so,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/tokenizer.json --config-file model-config.yaml --archive-format no-archive +torch-model-archiver --model-name bertcppaot --version 1.0 --handler 
../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/libbert_handler:BertCppHandler --runtime LSP --extra-files index_to_name.json,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/bert-seq.so,../../../../cpp/_build/test/resources/examples/aot_inductor/bert_handler/tokenizer.json --config-file model-config.yaml --archive-format no-archive ``` Create model store directory and move the folder `bertcppaot` @@ -56,11 +56,5 @@ Infer the model using the following command ``` curl http://localhost:8080/predictions/bertcppaot -T ../../../../cpp/test/resources/examples/aot_inductor/bert_handler/sample_text.txt -{ - "lens_cap": 0.0022578993812203407, - "lynx": 0.0032067005522549152, - "Egyptian_cat": 0.046274684369564056, - "tiger_cat": 0.13740436732769012, - "tabby": 0.2724998891353607 -} +Not Accepted ``` diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py index dd6b078ed2..b697503751 100644 --- a/examples/cpp/aot_inductor/bert/aot_compile_export.py +++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py @@ -11,6 +11,7 @@ ) set_seed(1) +MAX_BATCH_SIZE = 15 def transformers_model_dowloader( @@ -100,15 +101,14 @@ def transformers_model_dowloader( else: filename = os.path.join(dirname, "model-config.yaml") with open(filename, "r") as f: - settings = yaml.load(f) + settings = yaml.safe_load(f) mode = settings["handler"]["mode"] model_name = settings["handler"]["model_name"] num_labels = int(settings["handler"]["num_labels"]) - do_lower_case = settings["handler"]["do_lower_case"] + do_lower_case = bool(settings["handler"]["do_lower_case"]) max_length = int(settings["handler"]["max_length"]) - batch_size = int(settings.get("batch_size", "1")) - + batch_size = int(settings["batchSize"]) transformers_model_dowloader( mode, model_name, diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 2ba7a1b8e5..4e522a8185 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -177,8 +177,6 @@ c10::IValue BertCppHandler::Inference( } auto batch_output_tensor_vector = runner->run(inputs.toTensorVector()); - auto batch_output_tensor_vector = runner->run(vec); - return c10::IValue(batch_output_tensor_vector[0]); } catch (std::runtime_error& e) { TS_LOG(ERROR, e.what()); From d6fa80813450a974f2eb0b9730a1bdab1ec57780 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 17 Feb 2024 19:25:50 -0800 Subject: [PATCH 27/42] update max setting --- examples/cpp/aot_inductor/bert/aot_compile_export.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py index b697503751..1bb2ff0104 100644 --- a/examples/cpp/aot_inductor/bert/aot_compile_export.py +++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py @@ -12,6 +12,7 @@ set_seed(1) MAX_BATCH_SIZE = 15 +MAX_LENGTH = 1024 def transformers_model_dowloader( @@ -75,8 +76,8 @@ def transformers_model_dowloader( attention_mask = torch.cat([inputs["attention_mask"]] * batch_size, 0).to( device ) - batch_dim = torch.export.Dim("batch", min=1, max=8) - seq_len_dim = torch.export.Dim("seq_len", min=1, max=max_length) + batch_dim = torch.export.Dim("batch", min=1, max=MAX_BATCH_SIZE) + seq_len_dim = torch.export.Dim("seq_len", min=1, max=MAX_LENGTH) torch._C._GLIBCXX_USE_CXX11_ABI = True model_so_path = 
torch._export.aot_compile(
         model,

From c33f66c3602e91c41185bf09383ecafb1899c370 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sat, 17 Feb 2024 21:32:16 -0800
Subject: [PATCH 28/42] fix lint

---
 ts_scripts/spellcheck_conf/wordlist.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index 3cfa9e6840..600ddb39df 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1187,3 +1187,4 @@ FxGraphCache
 TorchInductor
 fx
 locustapache
+bertcppaot

From caa5042135c6bf1cde88a2cadee3281c4c1a1110 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sat, 17 Feb 2024 21:49:49 -0800
Subject: [PATCH 29/42] add limitation

---
 examples/cpp/aot_inductor/bert/aot_compile_export.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
index 1bb2ff0104..cd54ce8b27 100644
--- a/examples/cpp/aot_inductor/bert/aot_compile_export.py
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -11,8 +11,9 @@
 )

 set_seed(1)
+# PT2.2 has limitation on the max
 MAX_BATCH_SIZE = 15
-MAX_LENGTH = 1024
+MAX_LENGTH = 511


 def transformers_model_dowloader(

From d39ba51fe02ce4f6e8f7254beb3ed2de6f7157b1 Mon Sep 17 00:00:00 2001
From: lxning
Date: Thu, 22 Feb 2024 20:33:38 -0800
Subject: [PATCH 30/42] pinned folly to v2024.02.19.00

---
 .gitmodules | 3 +++
 cpp/third-party/folly | 1 +
 2 files changed, 4 insertions(+)
 create mode 160000 cpp/third-party/folly

diff --git a/.gitmodules b/.gitmodules
index 3125a3b997..6a015e8902 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,3 +10,6 @@
 [submodule "cpp/third-party/llama2.so"]
 	path = cpp/third-party/llama2.so
 	url = https://github.com/mreso/llama2.so.git
+[submodule "cpp/third-party/folly"]
+	path = cpp/third-party/folly
+	url = https://github.com/facebook/folly.git
diff --git a/cpp/third-party/folly b/cpp/third-party/folly
new file mode 160000
index 0000000000..323e467e23
--- /dev/null
+++ b/cpp/third-party/folly
@@ -0,0 +1 @@
+Subproject commit 323e467e2375e535e10bda62faf2569e8f5c9b19

From 0e1d7739d10d69a213220692c307cf5d44211fb9 Mon Sep 17 00:00:00 2001
From: lxning
Date: Thu, 22 Feb 2024 21:02:20 -0800
Subject: [PATCH 31/42] pinned yaml-cpp with tags/0.8.0

---
 .gitmodules | 3 +++
 cpp/third-party/yaml-cpp | 1 +
 2 files changed, 4 insertions(+)
 create mode 160000 cpp/third-party/yaml-cpp

diff --git a/.gitmodules b/.gitmodules
index 6a015e8902..cb90a5f329 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -13,3 +13,6 @@
 [submodule "cpp/third-party/folly"]
 	path = cpp/third-party/folly
 	url = https://github.com/facebook/folly.git
+[submodule "cpp/third-party/yaml-cpp"]
+	path = cpp/third-party/yaml-cpp
+	url = https://github.com/jbeder/yaml-cpp.git
diff --git a/cpp/third-party/yaml-cpp b/cpp/third-party/yaml-cpp
new file mode 160000
index 0000000000..76dc671573
--- /dev/null
+++ b/cpp/third-party/yaml-cpp
@@ -0,0 +1 @@
+Subproject commit 76dc6715734295ff1866bfc32872ff2278258fc8

From f8c71d4168589ee894bb9d1576784a172e78b99f Mon Sep 17 00:00:00 2001
From: lxning
Date: Sat, 24 Feb 2024 21:11:38 -0800
Subject: [PATCH 32/42] pinned yaml-cpp 0.8.0

---
 .gitmodules | 3 +++
 cpp/third-party/tokenizers-cpp | 1 +
 2 files changed, 4 insertions(+)
 create mode 160000 cpp/third-party/tokenizers-cpp

diff --git a/.gitmodules b/.gitmodules
index cb90a5f329..c19f73f813 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -16,3 +16,6 @@
 [submodule "cpp/third-party/yaml-cpp"]
 	path = 
cpp/third-party/yaml-cpp url = https://github.com/jbeder/yaml-cpp.git +[submodule "cpp/third-party/tokenizers-cpp"] + path = cpp/third-party/tokenizers-cpp + url = https://github.com/mlc-ai/tokenizers-cpp.git diff --git a/cpp/third-party/tokenizers-cpp b/cpp/third-party/tokenizers-cpp new file mode 160000 index 0000000000..27dbe17d72 --- /dev/null +++ b/cpp/third-party/tokenizers-cpp @@ -0,0 +1 @@ +Subproject commit 27dbe17d7268801ec720569167af905c88d3db50 From be81439f3fe9082aa7fda4eaaedb447534385b99 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 24 Feb 2024 21:23:03 -0800 Subject: [PATCH 33/42] update build.sh --- cpp/build.sh | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/build.sh b/cpp/build.sh index 86c3269cdb..f752e9f889 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -24,12 +24,12 @@ function install_folly() { FOLLY_SRC_DIR=$BASE_DIR/third-party/folly FOLLY_BUILD_DIR=$DEPS_DIR/folly-build - if [ ! -d "$FOLLY_SRC_DIR" ] ; then - echo -e "${COLOR_GREEN}[ INFO ] Cloning folly repo ${COLOR_OFF}" - git clone https://github.com/facebook/folly.git "$FOLLY_SRC_DIR" - cd $FOLLY_SRC_DIR - git checkout tags/v2024.01.29.00 - fi + # if [ ! -d "$FOLLY_SRC_DIR" ] ; then + # echo -e "${COLOR_GREEN}[ INFO ] Cloning folly repo ${COLOR_OFF}" + # git clone https://github.com/facebook/folly.git "$FOLLY_SRC_DIR" + # cd $FOLLY_SRC_DIR + # git checkout tags/v2024.01.29.00 + # fi if [ ! -d "$FOLLY_BUILD_DIR" ] ; then echo -e "${COLOR_GREEN}[ INFO ] Building Folly ${COLOR_OFF}" @@ -128,12 +128,12 @@ function install_yaml_cpp() { YAML_CPP_SRC_DIR=$BASE_DIR/third-party/yaml-cpp YAML_CPP_BUILD_DIR=$DEPS_DIR/yaml-cpp-build - if [ ! -d "$YAML_CPP_SRC_DIR" ] ; then - echo -e "${COLOR_GREEN}[ INFO ] Cloning yaml-cpp repo ${COLOR_OFF}" - git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR" - cd $YAML_CPP_SRC_DIR - git checkout tags/0.8.0 - fi + # if [ ! -d "$YAML_CPP_SRC_DIR" ] ; then + # echo -e "${COLOR_GREEN}[ INFO ] Cloning yaml-cpp repo ${COLOR_OFF}" + # git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR" + # cd $YAML_CPP_SRC_DIR + # git checkout tags/0.8.0 + # fi if [ ! -d "$YAML_CPP_BUILD_DIR" ] ; then echo -e "${COLOR_GREEN}[ INFO ] Building yaml-cpp ${COLOR_OFF}" @@ -158,12 +158,12 @@ function install_yaml_cpp() { function install_tokenizer_cpp() { TOKENIZERS_CPP_SRC_DIR=$BASE_DIR/third-party/tokenizers-cpp - if [ ! -d "$TOKENIZERS_CPP_SRC_DIR" ] ; then - echo -e "${COLOR_GREEN}[ INFO ] Cloning tokenizers-cpp repo ${COLOR_OFF}" - git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR" - cd $TOKENIZERS_CPP_SRC_DIR - git submodule update --init --recursive - fi + # if [ ! 
-d "$TOKENIZERS_CPP_SRC_DIR" ] ; then + # echo -e "${COLOR_GREEN}[ INFO ] Cloning tokenizers-cpp repo ${COLOR_OFF}" + # git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR" + # cd $TOKENIZERS_CPP_SRC_DIR + # git submodule update --init --recursive + # fi cd "$BWD" || exit } From 71deb7038191d97b96c9ad3d304cd85b35034470 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 24 Feb 2024 21:31:50 -0800 Subject: [PATCH 34/42] pinned yaml-cpp v0.8.0 --- cpp/third-party/yaml-cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/third-party/yaml-cpp b/cpp/third-party/yaml-cpp index 76dc671573..f732014112 160000 --- a/cpp/third-party/yaml-cpp +++ b/cpp/third-party/yaml-cpp @@ -1 +1 @@ -Subproject commit 76dc6715734295ff1866bfc32872ff2278258fc8 +Subproject commit f7320141120f720aecc4c32be25586e7da9eb978 From ebbf11993bdff763e9af5463851e990af6318f05 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 25 Feb 2024 15:09:15 -0800 Subject: [PATCH 35/42] fmt --- cpp/src/utils/file_system.cc | 39 ++++++++++++++++++- cpp/src/utils/file_system.hh | 6 ++- .../aot_inductor/bert/aot_compile_export.py | 2 +- .../cpp/aot_inductor/bert/src/bert_handler.cc | 13 +++++-- .../cpp/aot_inductor/bert/src/bert_handler.hh | 12 ++---- 5 files changed, 56 insertions(+), 16 deletions(-) diff --git a/cpp/src/utils/file_system.cc b/cpp/src/utils/file_system.cc index 7ba9b13501..945f57a0c8 100644 --- a/cpp/src/utils/file_system.cc +++ b/cpp/src/utils/file_system.cc @@ -1,4 +1,8 @@ #include "src/utils/file_system.hh" +#include "src/utils/logging.hh" + +#include +#include namespace torchserve { std::unique_ptr FileSystem::GetStream( @@ -10,4 +14,37 @@ std::unique_ptr FileSystem::GetStream( } return file_stream; } -} // namespace torchserve \ No newline at end of file + +std::string FileSystem::LoadBytesFromFile(const std::string& path) { + std::ifstream fs(path, std::ios::in | std::ios::binary); + if (fs.fail()) { + TS_LOGF(ERROR, "Cannot open tokenizer file {}", path); + throw; + } + std::string data; + fs.seekg(0, std::ios::end); + size_t size = static_cast(fs.tellg()); + fs.seekg(0, std::ios::beg); + data.resize(size); + fs.read(data.data(), size); + return data; +} + +std::unique_ptr FileSystem::LoadJsonFile(const std::string& file_path) { + std::string content; + if (!folly::readFile(file_path.c_str(), content)) { + TS_LOGF(ERROR, "{} not found", file_path); + throw; + } + return std::make_unique(folly::parseJson(content)); +} + +const folly::dynamic& FileSystem::GetJsonValue(std::unique_ptr& json, const std::string& key) { + if (json->find(key) != json->items().end()) { + return (*json)[key]; + } else { + TS_LOG(ERROR, "Required field {} not found in JSON.", key); + throw ; + } +} +} // namespace torchserve diff --git a/cpp/src/utils/file_system.hh b/cpp/src/utils/file_system.hh index 352ccdcbb8..dd21fcbf7b 100644 --- a/cpp/src/utils/file_system.hh +++ b/cpp/src/utils/file_system.hh @@ -1,8 +1,7 @@ #ifndef TS_CPP_UTILS_FILE_SYSTEM_HH_ #define TS_CPP_UTILS_FILE_SYSTEM_HH_ -#include - +#include #include #include #include @@ -11,6 +10,9 @@ namespace torchserve { class FileSystem { public: static std::unique_ptr GetStream(const std::string& path); + static std::string LoadBytesFromFile(const std::string& path); + static std::unique_ptr LoadJsonFile(const std::string& file_path); + static const folly::dynamic& GetJsonValue(std::unique_ptr& json, const std::string& key); }; } // namespace torchserve #endif // TS_CPP_UTILS_FILE_SYSTEM_HH_ diff --git 
a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py index cd54ce8b27..7a193420a8 100644 --- a/examples/cpp/aot_inductor/bert/aot_compile_export.py +++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py @@ -65,7 +65,7 @@ def transformers_model_dowloader( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device=device) - dummy_input = "This is a dummy input for torch jit trace" + dummy_input = "This is a dummy input for torch compile export" inputs = tokenizer.encode_plus( dummy_input, max_length=max_length, diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index 4e522a8185..51ec18cc39 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -1,10 +1,15 @@ #include "bert_handler.hh" +#include "src/utils/file_system.hh" +#include +#include +#include +#include #include namespace bert { -std::string BertCppHandler::LoadBytesFromFile(const std::string& path) { +/* std::string BertCppHandler::LoadBytesFromFile(const std::string& path) { std::ifstream fs(path, std::ios::in | std::ios::binary); if (fs.fail()) { TS_LOGF(ERROR, "Cannot open tokenizer file {}", path); @@ -35,7 +40,7 @@ const folly::dynamic& BertCppHandler::GetJsonValue(std::unique_ptr, std::shared_ptr> BertCppHandler::LoadModel( @@ -51,14 +56,14 @@ BertCppHandler::LoadModel( const std::string mapFilePath = fmt::format("{}/{}", load_model_request->model_dir, (*model_config_yaml_)["handler"]["mapping"].as()); - mapping_json_ = LoadJsonFile(mapFilePath); + mapping_json_ = torchserve::FileSystem::LoadJsonFile(mapFilePath); max_length_ = (*model_config_yaml_)["handler"]["max_length"].as(); std::string tokenizer_path = fmt::format("{}/{}", load_model_request->model_dir, (*model_config_yaml_)["handler"]["tokenizer_path"].as()); - auto tokenizer_blob = LoadBytesFromFile(tokenizer_path); + auto tokenizer_blob = torchserve::FileSystem::LoadBytesFromFile(tokenizer_path); tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(tokenizer_blob); std::string model_so_path = diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh index e8fbce62e0..1276bad13d 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.hh +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh @@ -2,13 +2,9 @@ #include #include -#include -#include -#include + #include #include -#include -#include #include #include "src/backends/handler/base_handler.hh" @@ -46,9 +42,9 @@ class BertCppHandler : public torchserve::BaseHandler { override; private: - std::string LoadBytesFromFile(const std::string& path); - std::unique_ptr LoadJsonFile(const std::string& file_path); - const folly::dynamic& GetJsonValue(std::unique_ptr& json, const std::string& key); + // std::string LoadBytesFromFile(const std::string& path); + // std::unique_ptr LoadJsonFile(const std::string& file_path); + // const folly::dynamic& GetJsonValue(std::unique_ptr& json, const std::string& key); std::unique_ptr mapping_json_; std::unique_ptr tokenizer_; From a0c710b4eb20945f3c7cb682d02b35f8acae271b Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 25 Feb 2024 15:58:31 -0800 Subject: [PATCH 36/42] fix typo --- examples/cpp/aot_inductor/bert/src/bert_handler.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc 
index 51ec18cc39..de9f8518a8 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -204,7 +204,7 @@ void BertCppHandler::Postprocess( response->SetResponse(200, "data_type", torchserve::PayloadType::kDATA_TYPE_STRING, - GetJsonValue(mapping_json_, predicted_idx).asString()); + torchserve::FileSystem::GetJsonValue(mapping_json_, predicted_idx).asString()); } catch (const std::runtime_error &e) { TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", kv.second, e.what()); From 20d879936883c4d8c30f6f1d828fdf0690f8a8c3 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 25 Feb 2024 16:59:19 -0800 Subject: [PATCH 37/42] add submodule kineto --- .gitmodules | 3 +++ cpp/third-party/kineto | 1 + 2 files changed, 4 insertions(+) create mode 160000 cpp/third-party/kineto diff --git a/.gitmodules b/.gitmodules index c19f73f813..f24d0431c1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -19,3 +19,6 @@ [submodule "cpp/third-party/tokenizers-cpp"] path = cpp/third-party/tokenizers-cpp url = https://github.com/mlc-ai/tokenizers-cpp.git +[submodule "cpp/third-party/kineto"] + path = cpp/third-party/kineto + url = https://github.com/pytorch/kineto.git diff --git a/cpp/third-party/kineto b/cpp/third-party/kineto new file mode 160000 index 0000000000..594c63c50d --- /dev/null +++ b/cpp/third-party/kineto @@ -0,0 +1 @@ +Subproject commit 594c63c50dd9684a592ad7670ecdef6dd5e36be7 From 6accaf4d44c47c2288e70bad4f443bca52fcde4e Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 25 Feb 2024 17:14:22 -0800 Subject: [PATCH 38/42] fmt --- cpp/build.sh | 31 +---------------- .../cpp/aot_inductor/bert/src/bert_handler.cc | 34 ------------------- .../cpp/aot_inductor/bert/src/bert_handler.hh | 4 --- .../aot_inductor/resnet/src/resnet_handler.cc | 29 +++++----------- .../aot_inductor/resnet/src/resnet_handler.hh | 7 ---- 5 files changed, 9 insertions(+), 96 deletions(-) diff --git a/cpp/build.sh b/cpp/build.sh index 59b7f66162..5371ecae26 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -24,13 +24,6 @@ function install_folly() { FOLLY_SRC_DIR=$BASE_DIR/third-party/folly FOLLY_BUILD_DIR=$DEPS_DIR/folly-build - # if [ ! -d "$FOLLY_SRC_DIR" ] ; then - # echo -e "${COLOR_GREEN}[ INFO ] Cloning folly repo ${COLOR_OFF}" - # git clone https://github.com/facebook/folly.git "$FOLLY_SRC_DIR" - # cd $FOLLY_SRC_DIR - # git checkout tags/v2024.01.29.00 - # fi - if [ ! -d "$FOLLY_BUILD_DIR" ] ; then echo -e "${COLOR_GREEN}[ INFO ] Building Folly ${COLOR_OFF}" cd $FOLLY_SRC_DIR @@ -60,9 +53,7 @@ function install_kineto() { elif [ "$PLATFORM" = "Mac" ]; then KINETO_SRC_DIR=$BASE_DIR/third-party/kineto - if [ ! -d "$KINETO_SRC_DIR" ] ; then - echo -e "${COLOR_GREEN}[ INFO ] Cloning kineto repo ${COLOR_OFF}" - git clone --recursive https://github.com/pytorch/kineto.git "$KINETO_SRC_DIR" + if [ ! -d "$KINETO_SRC_DIR/libkineto/build" ] ; then cd $KINETO_SRC_DIR/libkineto mkdir build && cd build cmake .. @@ -128,13 +119,6 @@ function install_yaml_cpp() { YAML_CPP_SRC_DIR=$BASE_DIR/third-party/yaml-cpp YAML_CPP_BUILD_DIR=$DEPS_DIR/yaml-cpp-build - # if [ ! -d "$YAML_CPP_SRC_DIR" ] ; then - # echo -e "${COLOR_GREEN}[ INFO ] Cloning yaml-cpp repo ${COLOR_OFF}" - # git clone https://github.com/jbeder/yaml-cpp.git "$YAML_CPP_SRC_DIR" - # cd $YAML_CPP_SRC_DIR - # git checkout tags/0.8.0 - # fi - if [ ! 
-d "$YAML_CPP_BUILD_DIR" ] ; then echo -e "${COLOR_GREEN}[ INFO ] Building yaml-cpp ${COLOR_OFF}" @@ -155,19 +139,6 @@ function install_yaml_cpp() { cd "$BWD" || exit } -function install_tokenizer_cpp() { - TOKENIZERS_CPP_SRC_DIR=$BASE_DIR/third-party/tokenizers-cpp - - # if [ ! -d "$TOKENIZERS_CPP_SRC_DIR" ] ; then - # echo -e "${COLOR_GREEN}[ INFO ] Cloning tokenizers-cpp repo ${COLOR_OFF}" - # git clone https://github.com/mlc-ai/tokenizers-cpp.git "$TOKENIZERS_CPP_SRC_DIR" - # cd $TOKENIZERS_CPP_SRC_DIR - # git submodule update --init --recursive - # fi - - cd "$BWD" || exit -} - function build_llama_cpp() { BWD=$(pwd) LLAMA_CPP_SRC_DIR=$BASE_DIR/third-party/llama.cpp diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.cc b/examples/cpp/aot_inductor/bert/src/bert_handler.cc index de9f8518a8..3fdf950a1f 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.cc +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.cc @@ -8,40 +8,6 @@ #include namespace bert { - -/* std::string BertCppHandler::LoadBytesFromFile(const std::string& path) { - std::ifstream fs(path, std::ios::in | std::ios::binary); - if (fs.fail()) { - TS_LOGF(ERROR, "Cannot open tokenizer file {}", path); - throw; - } - std::string data; - fs.seekg(0, std::ios::end); - size_t size = static_cast(fs.tellg()); - fs.seekg(0, std::ios::beg); - data.resize(size); - fs.read(data.data(), size); - return data; -} - -std::unique_ptr BertCppHandler::LoadJsonFile(const std::string& file_path) { - std::string content; - if (!folly::readFile(file_path.c_str(), content)) { - TS_LOGF(ERROR, "{} not found", file_path); - throw; - } - return std::make_unique(folly::parseJson(content)); -} - -const folly::dynamic& BertCppHandler::GetJsonValue(std::unique_ptr& json, const std::string& key) { - if (json->find(key) != json->items().end()) { - return (*json)[key]; - } else { - TS_LOG(ERROR, "Required field {} not found in JSON.", key); - throw ; - } -} */ - std::pair, std::shared_ptr> BertCppHandler::LoadModel( std::shared_ptr& load_model_request) { diff --git a/examples/cpp/aot_inductor/bert/src/bert_handler.hh b/examples/cpp/aot_inductor/bert/src/bert_handler.hh index 1276bad13d..80a3d68cb0 100644 --- a/examples/cpp/aot_inductor/bert/src/bert_handler.hh +++ b/examples/cpp/aot_inductor/bert/src/bert_handler.hh @@ -42,10 +42,6 @@ class BertCppHandler : public torchserve::BaseHandler { override; private: - // std::string LoadBytesFromFile(const std::string& path); - // std::unique_ptr LoadJsonFile(const std::string& file_path); - // const folly::dynamic& GetJsonValue(std::unique_ptr& json, const std::string& key); - std::unique_ptr mapping_json_; std::unique_ptr tokenizer_; std::unique_ptr model_config_yaml_; diff --git a/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc b/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc index a7d55381bb..351ac7e216 100644 --- a/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc +++ b/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc @@ -1,30 +1,17 @@ #include "resnet_handler.hh" +#include "src/utils/file_system.hh" +#include +#include +#include +#include #include namespace resnet { -std::unique_ptr ResnetCppHandler::LoadJsonFile(const std::string& file_path) { - std::string content; - if (!folly::readFile(file_path.c_str(), content)) { - TS_LOGF(ERROR, "{}} not found", file_path); - throw; - } - return std::make_unique(folly::parseJson(content)); -} - -const folly::dynamic& ResnetCppHandler::GetJsonValue(std::unique_ptr& json, const std::string& key) { - if (json->find(key) != 
diff --git a/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc b/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc
index a7d55381bb..351ac7e216 100644
--- a/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc
+++ b/examples/cpp/aot_inductor/resnet/src/resnet_handler.cc
@@ -1,30 +1,17 @@
 #include "resnet_handler.hh"

+#include "src/utils/file_system.hh"
+#include
+#include
+#include
+#include
 #include

 namespace resnet {
-std::unique_ptr<folly::dynamic> ResnetCppHandler::LoadJsonFile(const std::string& file_path) {
-  std::string content;
-  if (!folly::readFile(file_path.c_str(), content)) {
-    TS_LOGF(ERROR, "{}} not found", file_path);
-    throw;
-  }
-  return std::make_unique<folly::dynamic>(folly::parseJson(content));
-}
-
-const folly::dynamic& ResnetCppHandler::GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key) {
-  if (json->find(key) != json->items().end()) {
-    return (*json)[key];
-  } else {
-    TS_LOG(ERROR, "Required field {} not found in JSON.", key);
-    throw ;
-  }
-}
-
 std::string ResnetCppHandler::MapClassToLabel(const torch::Tensor& classes, const torch::Tensor& probs) {
   folly::dynamic map = folly::dynamic::object;
   for (int i = 0; i < classes.sizes()[0]; i++) {
-    auto class_value = GetJsonValue(mapping_json_, std::to_string(classes[i].item<long>()));
+    auto class_value = torchserve::FileSystem::GetJsonValue(mapping_json_, std::to_string(classes[i].item<long>()));
     map[class_value[1].asString()] = probs[i].item<float>();
   }

@@ -44,12 +31,12 @@
 ResnetCppHandler::LoadModel(
     std::shared_ptr<torchserve::LoadModelRequest>& load_model_request) {
   const std::string mapFilePath =
       fmt::format("{}/{}", load_model_request->model_dir,
                   (*model_config_yaml_)["handler"]["mapping"].as<std::string>());
-  mapping_json_ = LoadJsonFile(mapFilePath);
+  mapping_json_ = torchserve::FileSystem::LoadJsonFile(mapFilePath);

   std::string model_so_path =
       fmt::format("{}/{}", load_model_request->model_dir,
                   (*model_config_yaml_)["handler"]["model_so_path"].as<std::string>());
-  mapping_json_ = LoadJsonFile(mapFilePath);
+  mapping_json_ = torchserve::FileSystem::LoadJsonFile(mapFilePath);

   c10::InferenceMode mode;
   if (device->is_cuda()) {
diff --git a/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh b/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh
index db20b551fe..4e43ea9fad 100644
--- a/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh
+++ b/examples/cpp/aot_inductor/resnet/src/resnet_handler.hh
@@ -2,12 +2,7 @@

 #include
 #include
-#include
-#include
-#include
 #include
-#include
-#include
 #include

 #include "src/backends/handler/base_handler.hh"
@@ -45,8 +40,6 @@ class ResnetCppHandler : public torchserve::BaseHandler {
       override;

  private:
-  std::unique_ptr<folly::dynamic> LoadJsonFile(const std::string& file_path);
-  const folly::dynamic& GetJsonValue(std::unique_ptr<folly::dynamic>& json, const std::string& key);
   std::string MapClassToLabel(const torch::Tensor& classes, const torch::Tensor& probs);

   std::unique_ptr<folly::dynamic> mapping_json_;

From ee74ad9adec125846d9c19cf87fb70cf6e0cf580 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sun, 25 Feb 2024 18:20:45 -0800
Subject: [PATCH 39/42] fix workflow

---
 .github/workflows/ci-cpu-cpp.yml | 2 +-
 cpp/build.sh                     | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/.github/workflows/ci-cpu-cpp.yml b/.github/workflows/ci-cpu-cpp.yml
index ec0a9c88d8..5e5bfc301e 100644
--- a/.github/workflows/ci-cpu-cpp.yml
+++ b/.github/workflows/ci-cpu-cpp.yml
@@ -31,4 +31,4 @@ jobs:
         python ts_scripts/install_dependencies.py --environment=dev --cpp
     - name: Build
       run: |
-        cd cpp && ./build.sh --install-dependencies
+        cd cpp && ./build.sh
diff --git a/cpp/build.sh b/cpp/build.sh
index 5371ecae26..74d2ddd971 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -373,7 +373,6 @@ install_folly
 install_kineto
 install_libtorch
 install_yaml_cpp
-install_tokenizer_cpp
 build_llama_cpp
 prepare_test_files
 build

From 9b6736428ba7d03e3ee62ed99fabe87728271a9b Mon Sep 17 00:00:00 2001
From: lxning
Date: Sun, 25 Feb 2024 19:07:15 -0800
Subject: [PATCH 40/42] fix workflow

---
 cpp/build.sh                       | 2 +-
 ts_scripts/install_dependencies.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/cpp/build.sh b/cpp/build.sh
index 74d2ddd971..b986163852 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -370,7 +370,7 @@ cd $BASE_DIR
 git submodule update --init --recursive

 install_folly
-install_kineto
+#install_kineto
 install_libtorch
 install_yaml_cpp
 build_llama_cpp
diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py
index f4fa4572b8..ff9ae4c25e 100644
--- a/ts_scripts/install_dependencies.py
+++ b/ts_scripts/install_dependencies.py
@@ -49,6 +49,7 @@
     "libgoogle-perftools-dev",
     "rustc",
     "cargo",
+    "libunwind-dev",
 )

 CPP_DARWIN_DEPENDENCIES = (

From 9c1a33a7e5538c8f92056cdc21fc5d08329c3ef9 Mon Sep 17 00:00:00 2001
From: lxning
Date: Sun, 25 Feb 2024 19:50:23 -0800
Subject: [PATCH 41/42] fix ubuntu version

---
 .github/workflows/ci-cpu-cpp.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/ci-cpu-cpp.yml b/.github/workflows/ci-cpu-cpp.yml
index 5e5bfc301e..0b3a6529c2 100644
--- a/.github/workflows/ci-cpu-cpp.yml
+++ b/.github/workflows/ci-cpu-cpp.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macOS-latest]
+        os: [ubuntu-20.04, macOS-latest]
     steps:
     - name: Checkout TorchServe
       uses: actions/checkout@v2

From 9ab5336ef31d7b1aab4e159b8bd9bf98c535bc04 Mon Sep 17 00:00:00 2001
From: lxning
Date: Tue, 27 Feb 2024 10:42:10 -0800
Subject: [PATCH 42/42] update readme

---
 examples/cpp/aot_inductor/bert/README.md             | 5 +++--
 examples/cpp/aot_inductor/bert/aot_compile_export.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/cpp/aot_inductor/bert/README.md b/examples/cpp/aot_inductor/bert/README.md
index f4734ff3f1..f10e106f86 100644
--- a/examples/cpp/aot_inductor/bert/README.md
+++ b/examples/cpp/aot_inductor/bert/README.md
@@ -1,5 +1,6 @@
-This example uses AOTInductor to compile the [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) into an so file which is then executed using libtorch.
-The handler C++ source code for this examples can be found [here](src).
+This example uses AOTInductor to compile [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) into an .so file (see the script [aot_compile_export.py](aot_compile_export.py)). In PyTorch 2.2, the maximum `MAX_SEQ_LENGTH` this script supports is 511.
+
+The example then loads the model and runs prediction using libtorch. The C++ source code for the handler can be found [here](src).

 ### Setup
 1. Follow the instructions in [README.md](../../../../cpp/README.md) to build the TorchServe C++ backend.
diff --git a/examples/cpp/aot_inductor/bert/aot_compile_export.py b/examples/cpp/aot_inductor/bert/aot_compile_export.py
index 7a193420a8..2a01ad1c21 100644
--- a/examples/cpp/aot_inductor/bert/aot_compile_export.py
+++ b/examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -13,7 +13,7 @@
 set_seed(1)
 # PT2.2 has limitation on the max
 MAX_BATCH_SIZE = 15
-MAX_LENGTH = 511
+MAX_SEQ_LENGTH = 511


 def transformers_model_dowloader(
@@ -78,7 +78,7 @@ def transformers_model_dowloader(
         device
     )
     batch_dim = torch.export.Dim("batch", min=1, max=MAX_BATCH_SIZE)
-    seq_len_dim = torch.export.Dim("seq_len", min=1, max=MAX_LENGTH)
+    seq_len_dim = torch.export.Dim("seq_len", min=1, max=MAX_SEQ_LENGTH)
     torch._C._GLIBCXX_USE_CXX11_ABI = True
     model_so_path = torch._export.aot_compile(
         model,
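The updated README states that the example loads the compiled artifact and runs prediction with libtorch. As a rough illustration of that load path, here is a minimal sketch using libtorch's AOTInductor runner. The `torch::inductor::AOTIModelContainerRunnerCpu` class and its header location reflect my understanding of PyTorch 2.2 and should be verified against the libtorch build in use; the `bert-base-uncased.so` filename and the input shapes are likewise assumptions, not taken from the patches above.

```cpp
// Hedged sketch: load the AOT-compiled .so and run one forward pass on CPU.
// Assumes PyTorch 2.2's torch::inductor::AOTIModelContainerRunnerCpu API and an
// artifact named bert-base-uncased.so; both are assumptions.
#include <torch/csrc/inductor/aoti_model_container_runner.h>
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  torch::inductor::AOTIModelContainerRunnerCpu runner("bert-base-uncased.so");

  // Dummy token ids and attention mask; the sequence length may vary up to
  // MAX_SEQ_LENGTH (511) thanks to the dynamic seq_len_dim exported above.
  auto input_ids = torch::ones({1, 16}, torch::kLong);
  auto attention_mask = torch::ones({1, 16}, torch::kLong);
  std::vector<torch::Tensor> inputs{input_ids, attention_mask};

  c10::InferenceMode guard;  // mirrors the handlers' use of c10::InferenceMode
  std::vector<torch::Tensor> outputs = runner.run(inputs);
  std::cout << outputs[0].sizes() << std::endl;
  return 0;
}
```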