@@ -42,24 +42,24 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) {
 
   auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
 
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   /// Set operating precision to INT8
-  compile_spec.enable_precisions.insert(torch::kI8);
+  compile_spec.enabled_precisions.insert(torch::kF16);
+  compile_spec.enabled_precisions.insert(torch::kI8);
   /// Use the TensorRT Entropy Calibrator
   compile_spec.ptq_calibrator = calibrator;
   /// Set max batch size for the engine
   compile_spec.max_batch_size = 32;
   /// Set a larger workspace
   compile_spec.workspace_size = 1 << 28;
 
-  mod.eval();
-
 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif
@@ -86,60 +86,53 @@ int main(int argc, const char* argv[]) {
     return -1;
   }
 
+  mod.eval();
+
   /// Create the calibration dataset
   const std::string data_dir = std::string(argv[2]);
-  auto trt_mod = compile_int8_model(data_dir, mod);
 
   /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
   auto eval_dataloader = torch::data::make_data_loader(
       std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));
 
   /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
     auto outputs = mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
 
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
+
+  /// Compile Graph
+  auto trt_mod = compile_int8_model(data_dir, mod);
 
   /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
-    if (images.sizes()[0] < 32) {
-      /// To handle smaller batches until Optimization profiles work with Int8
-      auto diff = 32 - images.sizes()[0];
-      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
-      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
-      images = torch::cat({images, img_padding}, 0);
-      targets = torch::cat({targets, target_padding}, 0);
-    }
-
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
     predictions = predictions.reshape(predictions.sizes()[0]);
 
-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches until Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
  }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;
 
   /// Time execution in JIT-FP32 and TRT-INT8
   std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
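Note: the hunk is truncated just before the benchmarking code. As a rough sketch of what timing JIT-FP32 against TRT-INT8 over a {32, 3, 32, 32} input involves, the helper below averages forward-pass latency; the name time_module, the warm-up count, and the run count are illustrative assumptions, not the example's actual benchmark code.

#include <chrono>

/// Hypothetical helper: average forward-pass latency in milliseconds.
/// Warm-up iterations exclude CUDA context creation and lazy allocation
/// from the measurement; synchronize() makes timings cover the full kernel.
static float time_module(torch::jit::Module& mod, at::Tensor input, int runs = 100) {
  for (int i = 0; i < 10; i++) {
    mod.forward({input});
  }
  torch::cuda::synchronize();
  auto start = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < runs; i++) {
    mod.forward({input});
  }
  torch::cuda::synchronize();
  auto end = std::chrono::high_resolution_clock::now();
  return std::chrono::duration<float, std::milli>(end - start).count() / runs;
}

/// Usage against the modules above:
///   auto in = torch::randn({32, 3, 32, 32}, {torch::kCUDA});
///   std::cout << "JIT-FP32: " << time_module(mod, in) << " ms" << std::endl;
///   std::cout << "TRT-INT8: " << time_module(trt_mod, in) << " ms" << std::endl;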