# C++ API

Targets in this module create the user-facing C++ library for the TRTorch core.

## Building libtrtorch.so

### Debug build
``` shell
bazel build //cpp/api:libtrtorch.so --compilation_mode=dbg
```

### Release build
``` shell
bazel build //cpp/api:libtrtorch.so --cxxopt="-DNDEBUG"
```

## Usage
> Temporary, will get real documentation soon

``` c++
namespace trtorch {
/**
 * Settings data structure for TRTorch compilation
 */
struct TRTORCH_API ExtraInfo {
    /**
     * @brief A struct to hold an input range (used by TensorRT Optimization profile)
     *
     * This struct can either hold a single vector representing an input shape, signifying a
     * static input shape, or a set of three input shapes representing the min, optimal and max
     * input shapes allowed for the engine.
     */
    struct TRTORCH_API InputRange {
        std::vector<int64_t> min;
        std::vector<int64_t> opt;
        std::vector<int64_t> max;
        // ...
    };

    /**
     * Supported Data Types that can be used with TensorRT engines
     *
     * This class is compatible with c10::DataTypes (but will check for TRT support)
     * so there should not be a reason that you need to use this type explicitly.
     */
    class DataType {
    public:
        /**
         * ...
         *
         * ex. trtorch::DataType type = DataType::kFloat;
         */
        enum Value : int8_t {
            /// FP32
            kFloat,
            /// FP16
            kHalf,
            /// INT8
            kChar,
        };

        DataType() = default;
        constexpr DataType(Value t) : value(t) {}
        DataType(c10::ScalarType t);
        // ...
    };

    /**
     * ...
     *
     * This class is compatible with c10::DeviceTypes (but will check for TRT support)
     * but the only applicable value is at::kCUDA, which maps to DeviceType::kGPU
     *
     * To use the DeviceType class itself, interface using the enum vs. normal instantiation
     *
     * ex. trtorch::DeviceType type = DeviceType::kGPU;
     */
    class DeviceType {
        // ...
    };

    /**
     * Enum for selecting engine capability
     */
    enum class EngineCapability : int8_t {
        kDEFAULT,
        // ...
    };

    ExtraInfo(std::vector<InputRange> input_ranges)
        : input_ranges(std::move(input_ranges)) {}
    ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
    ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

    // Defaults should reflect TensorRT defaults for BuilderConfig

    /**
     * Sizes for inputs to engine, can either be a single size or a range
     * defined by Min, Optimal, Max sizes
     *
     * Order should match call order
     */
    std::vector<InputRange> input_ranges;

    /**
     * Default operating precision for the engine
     */
    DataType op_precision = DataType::kFloat;

    /**
     * Build a refittable engine
     */
    bool refit = false;

    // ...

    /**
     * Restrict operating type to only set default operation precision (op_precision)
     */
    bool strict_types = false;

    /**
     * (Only used when targeting DLA (device))
     * Lets engine run layers on GPU if they are not supported on DLA
     */
    bool allow_gpu_fallback = true;

    // ...

    /**
     * Maximum size of workspace given to TensorRT
     */
    uint64_t workspace_size = 0;

    /**
     * Maximum batch size (must be >= 1 to be set, 0 means not set)
     */
    uint64_t max_batch_size = 0;

    /**
     * Calibration dataloaders for each input for post training quantization
     */
    nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
};
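
// e.g. (illustrative sketch; the shape and settings below are assumptions, not defaults):
//   trtorch::ExtraInfo info({{1, 3, 224, 224}}); // one static input size
//   info.op_precision = trtorch::ExtraInfo::DataType::kHalf;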

/**
 * ...
 */
TRTORCH_API std::string get_build_info();

/**
 * Dump the version information for TRTorch including base libtorch and TensorRT versions
 * to stdout
 */
TRTORCH_API void dump_build_info();

/**
 * @brief Check to see if a module is fully supported by the compiler
 *
 * @param module: torch::jit::script::Module - Existing TorchScript module
 * @param method_name: std::string - Name of method to check
 *
 * Takes a module and a method name and checks if the method graph contains purely
 * convertible operators
 *
 * Will print out a list of unsupported operators if the graph is unsupported
 */
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name);

/**
 * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
 *
 * @param module: torch::jit::script::Module - Existing TorchScript module
 * @param info: trtorch::ExtraInfo - Compilation settings
 *
 * Takes an existing TorchScript module and a set of settings to configure the compiler
 * and will convert methods to JIT Graphs which call equivalent TensorRT engines
 *
 * Specifically converts the forward method of a TorchScript Module
 */
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info);
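
// e.g. (illustrative, assuming `mod` is a TorchScript module loaded onto the GPU):
//   auto trt_mod = trtorch::CompileGraph(mod, info);
//   auto out = trt_mod.forward({input_tensor});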

/**
 * @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
 *
 * @param module: torch::jit::script::Module - Existing TorchScript module
 * @param method_name: std::string - Name of method to compile
 * @param info: trtorch::ExtraInfo - Compilation settings
 *
 * Takes an existing TorchScript module and a set of settings to configure the compiler
 * and will convert the selected method to a serialized TensorRT engine which can be run with
 * TensorRT
 */
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);

namespace ptq {
/**
 * @brief A factory to build a post training quantization calibrator from a torch dataloader
 *
 * Creates a calibrator to use for post training quantization.
 * If there are multiple inputs, the dataset should produce an example which is a vector (or similar container) of tensors rather than a single tensor
 *
 * By default the returned calibrator uses the TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks.
 * You can override the algorithm selection (such as to use the MinMax Calibrator, recommended for NLP tasks) by calling make_int8_calibrator with
 * the calibrator class as a template parameter.
 *
 * e.g. trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);
 */
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
    return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
}
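
// e.g. (illustrative sketch; the dataloader, cache path and lifetime handling are assumptions):
//   auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader),
//                                                        "calibration.cache", /*use_cache=*/false);
//   info.ptq_calibrator = &calibrator; // the calibrator must outlive compilation
//   info.op_precision = trtorch::ExtraInfo::DataType::kChar;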

/**
 * @brief A factory to build a post training quantization calibrator that only uses the calibration cache
 *
 * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache. This lets
 * you run a cache-generating program that requires a dataloader and a dataset once, save the cache, and use it later in a
 * different program without the dataset dependency. However, the network should be recalibrated if its structure or the
 * input data set changes, and it is the responsibility of the application to ensure this.
 *
 * By default the returned calibrator uses the TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks.
 * You can override the algorithm selection (such as to use the MinMax Calibrator, recommended for NLP tasks) by calling make_int8_cache_calibrator with
 * the calibrator class as a template parameter.
 *
 * e.g. trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);
 */
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
    return Int8CacheCalibrator<Algorithm>(cache_file_path);
}
} // namespace ptq
} // namespace trtorch
```
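
Putting it together, a minimal end-to-end sketch (the header path, model file and input shape are illustrative assumptions, not part of the API above):

``` c++
#include "torch/script.h"
#include "trtorch/trtorch.h" // assumed public header for the API above

int main() {
    // Load an existing TorchScript module and move it to the GPU
    // ("model.ts" is a placeholder path)
    torch::jit::script::Module mod = torch::jit::load("model.ts");
    mod.to(torch::kCUDA);

    // One static input shape, compiling the forward method at FP16
    trtorch::ExtraInfo info({{1, 3, 224, 224}});
    info.op_precision = trtorch::ExtraInfo::DataType::kHalf;
    auto trt_mod = trtorch::CompileGraph(mod, info);

    // Run inference through the TensorRT-backed module
    auto in = torch::randn({1, 3, 224, 224}, torch::kCUDA).to(torch::kHalf);
    auto out = trt_mod.forward({in});
    return 0;
}
```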