fix(//cpp/ptq): fixing bad accuracy in just the example code
Signed-off-by: Naren Dasan <naren@naredasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Aug 5, 2021
1 parent b4a2dd6 commit 7efa11d
Showing 3 changed files with 39 additions and 33 deletions.
cpp/ptq/main.cpp (53 changes: 23 additions & 30 deletions)
@@ -42,24 +42,24 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) {
 
   auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
 
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   /// Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   /// Set operating precision to INT8
-  compile_spec.enable_precisions.insert(torch::kI8);
+  compile_spec.enabled_precisions.insert(torch::kF16);
+  compile_spec.enabled_precisions.insert(torch::kI8);
   /// Use the TensorRT Entropy Calibrator
   compile_spec.ptq_calibrator = calibrator;
-  /// Set max batch size for the engine
-  compile_spec.max_batch_size = 32;
   /// Set a larger workspace
   compile_spec.workspace_size = 1 << 28;
 
   mod.eval();
 
 #ifdef SAVE_ENGINE
   std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec);
-  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  std::ofstream out("/tmp/int8_engine_converted_from_jit.trt");
   out << engine;
   out.close();
 #endif
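
Read as a whole, the new compile path looks like the sketch below. This is an assembled view, not lines from the commit: the calibration dataset setup and the function's return sit outside the hunk, so make_calibration_dataloader, the cache path, and the closing trtorch::CompileGraph call are assumptions for illustration; everything else appears verbatim above.

    #include <string>
    #include <utility>
    #include <vector>

    #include "torch/script.h"
    #include "trtorch/ptq.h"
    #include "trtorch/trtorch.h"

    torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) {
      // Assumed stand-ins for the dataset setup above this hunk.
      auto calibration_dataloader = make_calibration_dataloader(data_dir); // hypothetical helper
      const std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache"; // path borrowed from the test file

      auto calibrator = trtorch::ptq::make_int8_calibrator(
          std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/true);

      // Explicit input spec: a fixed 32x3x32x32 FP32 input.
      std::vector<trtorch::CompileSpec::Input> inputs = {
          trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}),
                                      trtorch::CompileSpec::DataType::kFloat)};
      auto compile_spec = trtorch::CompileSpec(inputs);

      // Enabling FP16 alongside INT8 lets TensorRT fall back to FP16 kernels
      // for layers where INT8 is unsupported or slower.
      compile_spec.enabled_precisions.insert(torch::kF16);
      compile_spec.enabled_precisions.insert(torch::kI8);
      compile_spec.ptq_calibrator = calibrator;
      compile_spec.workspace_size = 1 << 28; // 256 MiB

      mod.eval();
      // SAVE_ENGINE branch omitted; the return via CompileGraph is assumed
      // from the function's signature.
      return trtorch::CompileGraph(mod, compile_spec);
    }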
@@ -86,60 +86,53 @@ int main(int argc, const char* argv[]) {
     return -1;
   }
 
+  mod.eval();
+
   /// Create the calibration dataset
   const std::string data_dir = std::string(argv[2]);
+  auto trt_mod = compile_int8_model(data_dir, mod);
 
+  /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
+                          .use_subset(3200)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
   auto eval_dataloader = torch::data::make_data_loader(
       std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2));
 
   /// Check the FP32 accuracy in JIT
-  float correct = 0.0, total = 0.0;
+  torch::Tensor jit_correct = torch::zeros({1}, {torch::kCUDA}), jit_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
     auto outputs = mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
 
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    jit_total += targets.sizes()[0];
+    jit_correct += torch::sum(torch::eq(predictions, targets));
   }
-  std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
 
-  /// Compile Graph
-  auto trt_mod = compile_int8_model(data_dir, mod);
-
   /// Check the INT8 accuracy in TRT
-  correct = 0.0;
-  total = 0.0;
+  torch::Tensor trt_correct = torch::zeros({1}, {torch::kCUDA}), trt_total = torch::zeros({1}, {torch::kCUDA});
   for (auto batch : *eval_dataloader) {
     auto images = batch.data.to(torch::kCUDA);
     auto targets = batch.target.to(torch::kCUDA);
 
+    if (images.sizes()[0] < 32) {
+      /// To handle smaller batches until Optimization profiles work with Int8
+      auto diff = 32 - images.sizes()[0];
+      auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA});
+      auto target_padding = torch::zeros({diff}, {torch::kCUDA});
+      images = torch::cat({images, img_padding}, 0);
+      targets = torch::cat({targets, target_padding}, 0);
+    }
+
     auto outputs = trt_mod.forward({images});
     auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false));
+    predictions = predictions.reshape(predictions.sizes()[0]);
 
-    if (predictions.sizes()[0] != targets.sizes()[0]) {
-      /// To handle smaller batches until Optimization profiles work with Int8
-      predictions = predictions.slice(0, 0, targets.sizes()[0]);
-    }
-
-    total += targets.sizes()[0];
-    correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
+    trt_total += targets.sizes()[0];
+    trt_correct += torch::sum(torch::eq(predictions, targets)).item().toFloat();
   }
-  std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
+  torch::Tensor trt_accuracy = (trt_correct / trt_total) * 100;
+
+  std::cout << "Accuracy of JIT model on test set: " << jit_accuracy.item().toFloat() << "%" << std::endl;
+  std::cout << "Accuracy of quantized model on test set: " << trt_accuracy.item().toFloat() << "%" << std::endl;
 
   /// Time execution in JIT-FP32 and TRT-INT8
   std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
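
A few details in this hunk carry the actual fix and are easy to miss. mod.eval() now runs before the FP32 accuracy loop, so the JIT baseline is measured with the module in eval mode; compilation happens up front because the calibration dataloader is consumed by the calibrator (hence the fresh eval_dataloader, and use_subset(3200), which gives exactly 100 full batches of 32); and the accuracy accumulators are now CUDA tensors, with both percentages printed together at the end (the TRT loop still syncs per batch via .item(), a small leftover). Short trailing batches are zero-padded to the engine's fixed batch size of 32; with 3200 samples that path never triggers, which also avoids padded rows counting toward the totals. A standalone sketch of the padding step, where pad_batch is an illustrative name rather than a helper from the commit:

    #include <utility>

    #include <torch/torch.h>

    // Zero-pad a short trailing batch up to the fixed batch size the TensorRT
    // engine was built for (32x3x32x32 in this example). Illustrative helper;
    // the commit inlines this logic in the evaluation loop.
    std::pair<torch::Tensor, torch::Tensor> pad_batch(torch::Tensor images,
                                                      torch::Tensor targets,
                                                      int64_t batch_size) {
      int64_t diff = batch_size - images.size(0);
      if (diff > 0) {
        // Match the device/dtype of the real batch so torch::cat works on CUDA.
        auto img_padding = torch::zeros({diff, 3, 32, 32}, images.options());
        auto target_padding = torch::zeros({diff}, targets.options());
        images = torch::cat({images, img_padding}, 0);
        targets = torch::cat({targets, target_padding}, 0);
      }
      return {images, targets};
    }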
tests/accuracy/BUILD (10 changes: 10 additions & 0 deletions)
@@ -3,6 +3,11 @@ filegroup(
     srcs = glob(["**/*.jit.pt"]),
 )
 
+filegroup(
+    name = "data",
+    srcs = glob(["datasets/**/*"])
+)
+
 test_suite(
     name = "aarch64_accuracy_tests",
     tests = [
@@ -28,6 +33,7 @@ cc_test(
     srcs = ["test_int8_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",
@@ -40,6 +46,7 @@ cc_test(
     srcs = ["test_fp16_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",
@@ -52,6 +59,7 @@ cc_test(
     srcs = ["test_fp32_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",
@@ -64,6 +72,7 @@ cc_test(
     srcs = ["test_dla_int8_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",
@@ -76,6 +85,7 @@ cc_test(
     srcs = ["test_dla_fp16_accuracy.cpp"],
     data = [
         ":jit_models",
+        ":data"
     ],
     deps = [
         ":accuracy_test",
tests/accuracy/test_int8_accuracy.cpp (9 changes: 6 additions & 3 deletions)
@@ -14,13 +14,16 @@ TEST_P(AccuracyTests, INT8AccuracyIsClose) {
 
   std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache";
 
-  auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
+  auto calibrator =
+      trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, false);
   // auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
 
-  std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
+  std::vector<trtorch::CompileSpec::Input> inputs = {
+      trtorch::CompileSpec::Input(std::vector<int64_t>({32, 3, 32, 32}), trtorch::CompileSpec::DataType::kFloat)};
   // Configure settings for compilation
-  auto compile_spec = trtorch::CompileSpec({input_shape});
+  auto compile_spec = trtorch::CompileSpec(inputs);
   // Set operating precision to INT8
+  compile_spec.enabled_precisions.insert(torch::kF16);
   compile_spec.enabled_precisions.insert(torch::kI8);
   // Use the TensorRT Entropy Calibrator
   compile_spec.ptq_calibrator = calibrator;
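
The substantive change here is the calibrator's final argument flipping from true to false. That boolean is the use_cache flag in TRTorch's PTQ API: with false, the test always recalibrates from the CIFAR10 dataloader and rewrites the cache file as a by-product, so a stale cache plausibly can no longer feed bad scales into the accuracy check. A sketch of the three options this file touches, with the alternatives commented out since the dataloader can only be moved into one calibrator:

    // Always calibrate from data; the cache file is written as a by-product.
    // This is what the test does after this commit.
    auto calibrator = trtorch::ptq::make_int8_calibrator(
        std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/false);

    // Reuse an existing cache and skip calibration when the file is present.
    // auto calibrator = trtorch::ptq::make_int8_calibrator(
    //     std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/true);

    // Cache-only calibrator: no dataloader at all, e.g. for environments where
    // the calibration dataset cannot ship with the test.
    // auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);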
