 #include <sys/stat.h>

 int main(int argc, const char* argv[]) {
-    trtorch::logging::set_reportable_log_level(trtorch::logging::kINFO);
+    trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kERROR);
     if (argc < 3) {
         std::cerr << "usage: ptq <path-to-module> <path-to-cifar10>\n";
         return -1;
@@ -50,11 +50,13 @@ int main(int argc, const char* argv[]) {
     // Configure settings for compilation
     auto extra_info = trtorch::ExtraInfo({input_shape});
     // Set operating precision to INT8
-    extra_info.op_precision = torch::kFI8;
+    extra_info.op_precision = torch::kI8;
     // Use the TensorRT Entropy Calibrator
     extra_info.ptq_calibrator = calibrator;
     // Set max batch size for the engine
     extra_info.max_batch_size = 32;
+    // Set a larger workspace (1 << 28 bytes = 256 MiB of scratch space for TensorRT)
+    extra_info.workspace_size = 1 << 28;

     mod.eval();

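For context, the `calibrator` assigned above is constructed earlier in this file from a CIFAR-10 dataloader. A minimal sketch of that setup, assuming the `trtorch::ptq::make_int8_calibrator` factory and the example's `datasets::CIFAR10` helper; the dataset path, subset size, and cache path below are placeholders:

// Sketch only: building the INT8 calibrator fed to extra_info.ptq_calibrator.
// datasets::CIFAR10 and use_subset() come from the TRTorch PTQ example;
// the path, subset size, and cache file are placeholders.
auto calibration_dataset = datasets::CIFAR10("/path/to/cifar10", datasets::CIFAR10::Mode::kTest)
                               .use_subset(320)
                               .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                                         {0.2023, 0.1994, 0.2010}))
                               .map(torch::data::transforms::Stack<>());
auto calibration_dataloader = torch::data::make_data_loader(
    std::move(calibration_dataset),
    torch::data::DataLoaderOptions().batch_size(32).workers(2));
// Defaults to nvinfer1::IInt8EntropyCalibrator2; the bool enables reuse of a cached calibration table
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), "/tmp/calibration.cache", true);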
@@ -82,6 +84,7 @@ int main(int argc, const char* argv[]) {
     std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;

     // Compile Graph
+    std::cout << "Compiling and quantizing module" << std::endl;
     auto trt_mod = trtorch::CompileGraph(mod, extra_info);

     // Check the INT8 accuracy in TRT
@@ -91,22 +94,27 @@ int main(int argc, const char* argv[]) {
         auto images = batch.data.to(torch::kCUDA);
         auto targets = batch.target.to(torch::kCUDA);

+        if (images.sizes()[0] < 32) {
+            // To handle smaller batches until Optimization profiles work with Int8
+            auto diff = 32 - images.sizes()[0];
+            auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
+            auto target_padding = torch::zeros({diff}, {torch::kCUDA});
+            images = torch::cat({images, img_padding}, 0);
+            targets = torch::cat({targets, target_padding}, 0);
+        }
+
         auto outputs = trt_mod.forward({images});
         auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
         predictions = predictions.reshape(predictions.sizes()[0]);

         if (predictions.sizes()[0] != targets.sizes()[0]) {
-            // To handle smaller batches util Optimization profiles work
+            // To handle smaller batches until Optimization profiles work with Int8
             predictions = predictions.slice(0, 0, targets.sizes()[0]);
         }

-        std::cout << predictions << targets << std::endl;
-
         total += targets.sizes()[0];
         correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
-        std::cout << total << " " << correct << std::endl;
     }
-    std::cout << total << " " << correct << std::endl;
     std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;

     // Time execution in INT8
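The diff ends where the example times INT8 execution. A minimal sketch of such a benchmark, assuming the same fixed 32x3x32x32 input shape the padding code enforces above; `cudaDeviceSynchronize()` (from `cuda_runtime_api.h`) fences the GPU so the `<chrono>` timestamps cover the actual inference work, and the loop counts are illustrative:

// Sketch only: rough latency measurement for the quantized module.
auto in = torch::randn({32, 3, 32, 32}, {torch::kCUDA});
for (int i = 0; i < 10; i++) {
    trt_mod.forward({in}); // warm-up so one-time initialization is not timed
}
cudaDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
for (int i = 0; i < 100; i++) {
    trt_mod.forward({in});
}
cudaDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
auto total_ms = std::chrono::duration<double, std::milli>(end - start).count();
std::cout << "Average INT8 latency: " << total_ms / 100.0 << " ms" << std::endl;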