diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 08ac2ae0ec45..23f7339605df 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -40,7 +40,7 @@ namespace contrib {
 
 TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, const std::vector<const DLTensor*>& data_entry,
                                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
-                                 int batch_size)
+                                 int batch_size, nvinfer1::IInt8Calibrator* calibrator)
     : data_entry_(data_entry),
       max_workspace_size_(max_workspace_size),
       use_implicit_batch_(use_implicit_batch),
@@ -48,6 +48,8 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
       batch_size_(batch_size) {
   // Create TRT builder and network.
   builder_ = nvinfer1::createInferBuilder(*logger);
+  use_int8_ = false;
+
 #if TRT_VERSION_GE(6, 0, 1)
   // Use INetworkV2.
   auto flags =
@@ -56,9 +58,15 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger,
     flags = 0U;
     builder_->setMaxBatchSize(batch_size_);
   }
+  this->calibrator_ = calibrator;
+  if (calibrator != nullptr) {
+    use_int8_ = true;
+    builder_->setFp16Mode(true);
+    builder_->setInt8Mode(true);
+    builder_->setInt8Calibrator(calibrator);
+  }
   network_ = builder_->createNetworkV2(flags);
 #else
-  // Use INetwork with implicit batch.
   builder_->setMaxBatchSize(batch_size_);
   builder_->setMaxWorkspaceSize(max_workspace_size_);
   builder_->setFp16Mode(use_fp16_);
@@ -158,6 +166,13 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   if (use_fp16_) {
     config_->setFlag(nvinfer1::BuilderFlag::kFP16);
   }
+
+  if (use_int8_) {
+    config_->setFlag(nvinfer1::BuilderFlag::kINT8);
+    config_->setInt8Calibrator(calibrator_);
+    LOG(INFO) << "Builder config set up with INT8 calibrator ... ";
+  }
+
   // Add profiles.
   if (!use_implicit_batch_) {
     auto profile = builder_->createOptimizationProfile();
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h
index 0b1c3997ec57..bf74630bce7f 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h
@@ -72,8 +72,8 @@ class TensorRTBuilder {
    * \param batch_size If use_implicit_batch,
    */
   TensorRTBuilder(TensorRTLogger* logger, const std::vector<const DLTensor*>& data_entry,
-                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16,
-                  int batch_size);
+                  size_t max_workspace_size, bool use_implicit_batch, bool use_fp16, int batch_size,
+                  nvinfer1::IInt8Calibrator* calibrator = nullptr);
 
   /*!
    * \brief Add TensorRT input(s) for input node in network definition.
@@ -153,6 +153,9 @@ class TensorRTBuilder {
   /*! \brief Whether to automatically convert model to 16-bit floating point precision. */
   bool use_fp16_;
 
+  /*! \brief Whether to automatically convert the model to INT8 precision. */
+  bool use_int8_;
+
   /*! \brief Batch size to optimize for. */
   int batch_size_;
 
@@ -161,6 +164,10 @@ class TensorRTBuilder {
 
   /*! \brief Output names. */
   std::vector<std::string> network_output_names_;
+
+  /*! \brief Calibrator pointer, used to add batch data when running in INT8 mode. */
+  /*! \brief The pointer will be nullptr for FP16 or FP32 precision. */
+  nvinfer1::IInt8Calibrator* calibrator_;
 };
 
 }  // namespace contrib
diff --git a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h
new file mode 100755
index 000000000000..1e340d287629
--- /dev/null
+++ b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file runtime/contrib/tensorrt/tensorrt_calibrator.h
+ * \brief Contains the TensorRTCalibrator class, which collects calibration
+ * data used to build an INT8 TensorRT engine.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
+
+#include <string>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+#include "NvInfer.h"
+
+namespace tvm {
+namespace runtime {
+
+class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
+ public:
+  TensorRTCalibrator(int batch_size, const std::vector<std::string>& input_names)
+      : batch_size_(batch_size), num_batches_calibrated_(0), input_names_(input_names) {}
+
+  ~TensorRTCalibrator() {
+    // Free calibration data
+    for (auto& inputs : data_) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        delete[] inputs[i];
+      }
+    }
+    // Free buffers
+    for (size_t i = 0; i < buffers_.size(); ++i) {
+      CUDA_CALL(cudaFree(buffers_[i]));
+    }
+  }
+
+  void AddBatchData(const std::vector<void*>& bindings, const std::vector<size_t>& binding_sizes) {
+    // Copy data from GPU
+    std::vector<float*> data_host(bindings.size(), nullptr);
+    for (size_t i = 0; i < bindings.size(); ++i) {
+      data_host[i] = new float[batch_size_ * binding_sizes[i]];
+      CUDA_CALL(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i],
+                           batch_size_ * binding_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost));
+    }
+    data_.push_back(data_host);
+    data_sizes_.push_back(binding_sizes);
+  }
+
+  int getBatchSize() const override { return batch_size_; }
+
+  /*!
+   * \brief TensorRT will call this method to get the next batch of data to
+   * calibrate with.
+   */
+  bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
+    AllocateBuffersIfNotAllocated();
+    CHECK_EQ(input_names_.size(), nbBindings);
+    for (size_t i = 0; i < input_names_.size(); ++i) {
+      CHECK_EQ(input_names_[i], names[i]);
+      CUDA_CALL(cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i],
+                           batch_size_ * data_sizes_[num_batches_calibrated_][i] * sizeof(float),
+                           cudaMemcpyHostToDevice));
+      bindings[i] = buffers_[i];
+    }
+    num_batches_calibrated_++;
+    // TODO(trevmorr): Free data from previous batch?
+    return (num_batches_calibrated_ < data_.size());
+  }
+
+  const void* readCalibrationCache(size_t& length) override {
+    if (calibration_cache_.empty()) return nullptr;
+    length = calibration_cache_.size();
+    return calibration_cache_.data();
+  }
+
+  void writeCalibrationCache(const void* cache, size_t length) override {
+    calibration_cache_.assign(static_cast<const char*>(cache), length);
+  }
+
+ private:
+  /*! \brief Batch size. */
+  int batch_size_;
+  /*! \brief Number of batches already fed to calibrator. */
+  int num_batches_calibrated_;
+  /*! \brief Storage for calibration cache. */
+  std::string calibration_cache_;
+
+  /*! \brief Data to be used for calibration. */
+  std::vector<std::vector<float*>> data_;
+  /*! \brief Number of elements for data to be used for calibration. */
+  std::vector<std::vector<size_t>> data_sizes_;
+
+  /*! \brief Device buffers to be used for calibration. */
+  std::vector<void*> buffers_;
+
+  /*! \brief Names of inputs. */
+  const std::vector<std::string> input_names_;
+
+  /*! \brief Allocate device memory buffers. data_sizes_ must already have one
+   * entry. */
+  void AllocateBuffersIfNotAllocated() {
+    if (!buffers_.empty()) return;
+    CHECK_GE(data_sizes_.size(), 1);
+    const int num_inputs = data_sizes_[0].size();
+    buffers_.assign(num_inputs, nullptr);
+    for (int i = 0; i < num_inputs; ++i) {
+      CUDA_CALL(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float)));
+    }
+  }
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 5562f853383c..a5779f739dac 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -27,6 +27,10 @@
 #include
 #include
+#include
+#include
+#include
+#include
 
 #include "../../file_utils.h"
 #include "../json/json_node.h"
@@ -35,6 +39,8 @@
 #ifdef TVM_GRAPH_EXECUTOR_TENSORRT
 #include "NvInfer.h"
 #include "tensorrt_builder.h"
+#include "tensorrt_calibrator.h"
+#include "tensorrt_utils.h"
 #endif
 
 namespace tvm {
@@ -66,7 +72,22 @@ class TensorRTRuntime : public JSONRuntimeBase {
         use_implicit_batch_(true),
         max_workspace_size_(size_t(1) << 30),
         max_batch_size_(-1),
-        multi_engine_mode_(false) {}
+        multi_engine_mode_(false) {
+    const bool use_int8 = dmlc::GetEnv("TVM_TENSORRT_USE_INT8", false);
+    multi_engine_mode_ = dmlc::GetEnv("TVM_TENSORRT_MULTI_ENGINE", false);
+    num_calibration_batches_remaining_ = dmlc::GetEnv("TENSORRT_NUM_CALI_INT8", 0);
+    if (use_int8) {
+      ICHECK(num_calibration_batches_remaining_ != 0)
+          << "When using INT8 mode, the "
+          << "environment variable TENSORRT_NUM_CALI_INT8 "
+          << "must also be set to specify the number of "
+          << "calibration batches.";
+      LOG(INFO) << "Setting up " << num_calibration_batches_remaining_
+                << " batches of sample data for INT8 calibration ... ";
+      ICHECK(multi_engine_mode_ == false) << "When using INT8 mode, "
+                                          << "multi-engine mode is not allowed.";
+    }
+  }
 
   /*!
    * \brief The type key of the module.
@@ -87,7 +108,6 @@ class TensorRTRuntime : public JSONRuntimeBase {
     LoadGlobalAttributes();
     if (GetCachedEnginesFromDisk()) return;
     SetupConstants(consts);
-    multi_engine_mode_ = dmlc::GetEnv("TVM_TENSORRT_MULTI_ENGINE", false);
   }
 
   void LoadGlobalAttributes() {
@@ -130,7 +150,9 @@ class TensorRTRuntime : public JSONRuntimeBase {
     if (batch_size == 0) return;
     auto engine = engine_and_context.engine;
     auto context = engine_and_context.context;
-    std::vector<void*> bindings(engine->getNbBindings(), nullptr);
+    const int num_bindings = engine->getNbBindings();
+    std::vector<void*> bindings(num_bindings, nullptr);
+    std::vector<size_t> binding_sizes(num_bindings, 0);
     // Setup input bindings.
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
       auto nid = input_nodes_[i];
@@ -153,9 +175,26 @@ class TensorRTRuntime : public JSONRuntimeBase {
           device_buffer.CopyFrom(data_entry_[eid]);
           bindings[binding_index] = device_buffer->data;
         }
+
+        auto dims = engine->getBindingDimensions(binding_index);
+        int num_elements = 1;
+        for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i];
+        binding_sizes[binding_index] = num_elements;
       }
     }
+
+    // Add batch data to the calibrator while calibration batches remain.
+    if (num_calibration_batches_remaining_ > 0) {
+      if (calibrator_ != nullptr) {
+        LOG(INFO) << "Adding batch data to the calibrator; " << num_calibration_batches_remaining_
+                  << " calibration batches remaining.";
+        calibrator_->AddBatchData(bindings, binding_sizes);
+        num_calibration_batches_remaining_--;
+      }
+      return;
+    }
+
     // Setup output bindings.
     for (size_t i = 0; i < outputs_.size(); ++i) {
       uint32_t eid = EntryID(outputs_[i]);
@@ -225,10 +264,16 @@ class TensorRTRuntime : public JSONRuntimeBase {
   TensorRTEngineAndContext& GetOrBuildEngine() {
     int batch_size = GetBatchSize();
     int compatible_engine_batch_size = -1;
-    if (FindCompatibleEngine(batch_size, &compatible_engine_batch_size)) {
+    bool find_engine_flag = FindCompatibleEngine(batch_size, &compatible_engine_batch_size);
+    const bool use_int8 = (dmlc::GetEnv("TVM_TENSORRT_USE_INT8", 0) != 0);
+    const bool int8_calibration_not_used_or_not_complete =
+        (calibrator_ != nullptr && num_calibration_batches_remaining_ != 0);
+    if (find_engine_flag &&
+        (!use_int8 || calibrator_ == nullptr || int8_calibration_not_used_or_not_complete)) {
       // A compatible engine already exists.
       return trt_engine_cache_.at(std::make_pair(symbol_name_, compatible_engine_batch_size));
     }
+
     // For single engine mode, remove previous engine and update max_batch_size.
     if (!multi_engine_mode_) {
       DestroyEngines();
@@ -236,11 +281,32 @@ class TensorRTRuntime : public JSONRuntimeBase {
     }
     DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
                << " with batch size " << batch_size;
+
+    // Build engine.
+    if (calibrator_ != nullptr && num_calibration_batches_remaining_ == 0) {
+      // Calibration is complete; rebuild the engine in INT8 mode.
+      BuildEngineFromJson(batch_size);
+      calibrator_.reset(nullptr);
+    } else {
+      // Build a new engine.
+      BuildEngineFromJson(batch_size);
+      TensorRTEngineAndContext& engine_and_context =
+          trt_engine_cache_[std::make_pair(symbol_name_, batch_size)];
+      if (use_int8) {
+        this->CreateInt8Calibrator(engine_and_context);
+      }
+    }
+
+    LOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
+              << " with batch size " << batch_size;
+    CacheEngineToDisk();
+    return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size));
+  }
+
+  void BuildEngineFromJson(int batch_size) {
     const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
     TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_,
-                            use_fp16, batch_size);
-
-    // Add inputs and constants.
+                            use_fp16, batch_size, calibrator_.get());
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
       auto nid = input_nodes_[i];
       const auto& node = nodes_[nid];
@@ -266,12 +332,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
       builder.AddOutput(outputs_[i], EntryID(outputs_[i]));
     }
 
-    // Build engine.
-    trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = builder.BuildEngine();
-    DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
-               << " with batch size " << batch_size;
-    CacheEngineToDisk();
-    return trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size));
+    TensorRTEngineAndContext engine_and_context = builder.BuildEngine();
+    trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context;
   }
 
   /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will check that directory for
@@ -286,7 +348,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
     // Check if engine is in the cache.
     std::ifstream infile(path, std::ios::binary);
     if (!infile.good()) return false;
-    DLOG(INFO) << "Loading cached TensorRT engine from " << path;
+    LOG(INFO) << "Loading cached TensorRT engine from " << path;
     infile.close();
     std::string serialized_engine;
     LoadBinaryFromFile(path, &serialized_engine);
@@ -308,6 +370,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
     helper.ReadAllFields(&reader);
     const int batch_size = GetBatchSize();
     trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context;
+    LOG(INFO) << "Finished loading cached engine and context ... ";
     return true;
   }
 
@@ -369,10 +432,24 @@ class TensorRTRuntime : public JSONRuntimeBase {
     return device_buffers_.at(binding_index);
   }
 
+  void CreateInt8Calibrator(const TensorRTEngineAndContext& engine_and_context) {
+    // Get input names in binding order.
+    std::vector<std::string> input_names;
+    for (size_t i = 0; i < engine_and_context.inputs.size(); i++) {
+      std::string ele = engine_and_context.inputs[i];
+      input_names.push_back(ele);
+    }
+    const int batch_size = GetBatchSize();
+    calibrator_.reset(new TensorRTCalibrator(batch_size, input_names));
+  }
+
   /*! \brief Map of function name and max batch size to TRT engine if built already. */
   std::unordered_map<std::pair<std::string, int>, TensorRTEngineAndContext, PairHash>
       trt_engine_cache_;
 
+  /*! \brief Calibrator for INT8 mode. */
+  std::unique_ptr<TensorRTCalibrator> calibrator_;
+
   /*! \brief Map of binding index to GPU buffers for inputs and outputs. Only used when target device
    * is not "cuda". Since TensorRT execution can only read data from GPU, we need to copy data from
    * the runtime device to these buffers first. These will be allocated for the highest batch size
@@ -402,6 +479,9 @@ class TensorRTRuntime : public JSONRuntimeBase {
 
   size_t max_workspace_size_;
 
+  /*! \brief Number of calibration batches remaining until calibration is done. */
+  int num_calibration_batches_remaining_;
+
   /*! \brief Highest batch size that an engine has been built for, used in single-engine mode only
    * (multi_engine_mode=false). */
   int max_batch_size_;
diff --git a/tests/python/contrib/test_tensorrt_int8_exp.py b/tests/python/contrib/test_tensorrt_int8_exp.py
new file mode 100644
index 000000000000..84360e92d33b
--- /dev/null
+++ b/tests/python/contrib/test_tensorrt_int8_exp.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import os
+import numpy as np
+
+import tvm
+import tvm.relay.testing
+import tvm.contrib.graph_executor
+from tvm import relay
+from tvm.contrib.download import download_testdata
+from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt
+from tvm.relay.op.contrib import tensorrt
+
+
+def skip_codegen_test():
+    """Skip test if TensorRT and CUDA codegen are not present"""
+    if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
+        print("Skip because CUDA is not enabled.")
+        return True
+    if not tvm.get_global_func("relay.ext.tensorrt", True):
+        print("Skip because TensorRT codegen is not available.")
+        return True
+    return False
+
+
+def skip_runtime_test():
+    if not tvm.runtime.enabled("cuda") or not tvm.cuda(0).exist:
+        print("Skip because CUDA is not enabled.")
+        return True
+    if not tensorrt.is_tensorrt_runtime_enabled():
+        print("Skip because TensorRT runtime is not available.")
+        return True
+    return False
+
+
+def test_trt_int8():
+    """
+    Compile a ResNet-34 model with TensorRT in INT8 mode and compare the cosine
+    distance between the output of the original PyTorch model and the TVM
+    TensorRT INT8 output.
+    """
+    if skip_codegen_test() or skip_runtime_test():
+        return
+
+    try:
+        from PIL import Image
+        from scipy.spatial import distance
+    except ImportError:
+        print("please install scipy and pillow python packages")
+        return
+
+    try:
+        import torch
+        import torchvision
+        from torchvision import transforms
+    except ImportError:
+        print("please install pytorch python package")
+        return
+
+    os.environ["TVM_TENSORRT_USE_INT8"] = "1"
+    os.environ["TENSORRT_NUM_CALI_INT8"] = "10"
+    model_name = "resnet34"
+    model = getattr(torchvision.models, model_name)(pretrained=True)
+    model = model.eval()
+
+    # We grab the TorchScripted model via tracing
+    input_shape = [1, 3, 224, 224]
+    input_data = torch.randn(input_shape)
+    scripted_model = torch.jit.trace(model, input_data).eval()
+
+    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
+    img_path = download_testdata(img_url, "cat.png", module="data")
+    img = Image.open(img_path).resize((224, 224))
+    my_preprocess = transforms.Compose(
+        [
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ]
+    )
+    img = my_preprocess(img)
+    img = np.expand_dims(img, 0)
+
+    input_name = "input0"
+    shape_list = [(input_name, img.shape)]
+    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+
+    # compile the model
+    target = "cuda"
+    dev = tvm.cuda(1)
+    mod, config = partition_for_tensorrt(mod, params)
+    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+        lib = relay.build(mod, target=target, params=params)
+
+    dtype = "float32"
+    gen_module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    num_cali_int8 = int(os.environ["TENSORRT_NUM_CALI_INT8"])
+    if num_cali_int8 != 0:
+        print("start calibrating data ... ")
+        for i in range(num_cali_int8):
+            tvm_data = tvm.nd.array(img)
+            gen_module.set_input(input_name, tvm_data)
+            gen_module.run()
+        print("finished calibrating data ... ")
+
+    # get output of tvm model
+    print("rebuild engine and test to run ... ")
+    tvm_data = tvm.nd.array(img)
+    gen_module.set_input(input_name, tvm_data)
+    gen_module.run()
+    out = gen_module.get_output(0)
+
+    # check that the outputs of the TVM and PyTorch models are close
+    torch_data = torch.from_numpy(img)
+    model = scripted_model.eval()
+    torch_output = model(torch_data)
+
+    cosine_distance_res = distance.cosine(out.numpy(), torch_output.detach().cpu().numpy())
+    assert cosine_distance_res <= 0.01
+
+    # Evaluate inference time
+    print("Evaluate inference time cost...")
+    ftimer = gen_module.module.time_evaluator("run", dev, repeat=10, min_repeat_ms=500)
+    prof_res = np.array(ftimer().results) * 1e3  # convert to milliseconds
+    message = "Mean inference time (std dev): %.2f ms (%.2f ms)" % (
+        np.mean(prof_res),
+        np.std(prof_res),
+    )
+    print(message)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
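Usage sketch (not part of the patch): a minimal outline of how the INT8 calibration flow above might be driven from Python. The environment variable names come from the runtime changes in this diff, while `lib`, `calibration_batches`, `sample`, and the input name "input0" are hypothetical placeholders for a TensorRT-partitioned build and user-supplied data.

    import os
    import tvm
    from tvm.contrib import graph_executor

    # Enable the INT8 path and tell the runtime how many calibration batches to expect.
    os.environ["TVM_TENSORRT_USE_INT8"] = "1"
    os.environ["TENSORRT_NUM_CALI_INT8"] = "10"

    # `lib` is assumed to be the result of relay.build() on a module partitioned
    # with partition_for_tensorrt(); `calibration_batches` and `sample` are
    # placeholder numpy arrays supplied by the user.
    dev = tvm.cuda(0)
    module = graph_executor.GraphModule(lib["default"](dev))

    # While calibration batches remain, each run is captured by the calibrator
    # instead of producing outputs.
    for batch in calibration_batches:
        module.set_input("input0", tvm.nd.array(batch))
        module.run()

    # Once the calibration budget is exhausted, the next run rebuilds the engine
    # in INT8 mode and executes it normally.
    module.set_input("input0", tvm.nd.array(sample))
    module.run()
    out = module.get_output(0).numpy()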