Skip to content

Commit

Permalink
feat(//py): Allowing people using the PyTorch backend to use TRTorch/TRT
Browse files Browse the repository at this point in the history
INT8 calibrators

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Mar 16, 2021
1 parent fe5654f commit 6c3e0ad
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ bazel-genfiles
bazel-out
bazel-testlogs
bazel-TRTorch
bazel-trtorch-testing
third_party/pytorch
*.jit
*.jit.pt
Expand Down Expand Up @@ -37,4 +38,6 @@ bdist
py/trtorch/_version.py
py/wheelhouse
py/.eggs
notebooks/.ipynb_checkpoints/
notebooks/.ipynb_checkpoints/
*.cache
tests/py/data
1 change: 1 addition & 0 deletions py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,5 +257,6 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
backend_spec.set_workspace_size(parsed_spec.workspace_size)
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())

return backend_spec
1 change: 1 addition & 0 deletions py/trtorch/csrc/register_tensorrt_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void RegisterTRTCompileSpec() {
.def(torch::init<>())
.def("append_input_range", &trtorch::pyapi::CompileSpec::appendInputRange)
.def("set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);

ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);
Expand Down
8 changes: 8 additions & 0 deletions py/trtorch/csrc/tensorrt_classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ struct CompileSpec : torch::CustomClassHolder {
input_ranges.push_back(*ir);
}

int64_t getPTQCalibratorHandle() {
return (int64_t)ptq_calibrator;
}

void setDeviceIntrusive(const c10::intrusive_ptr<Device>& d) {
device = *d;
}

void setPTQCalibratorViaHandle(int64_t handle) {
ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
}

ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
ADD_FIELD_GET_SET(disable_tf32, bool);
ADD_FIELD_GET_SET(refit, bool);
Expand Down
1 change: 1 addition & 0 deletions py/trtorch/csrc/trtorch_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ PYBIND11_MODULE(_C, m) {

py::class_<CompileSpec>(m, "CompileSpec")
.def(py::init<>())
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
.def_readwrite("op_precision", &CompileSpec::op_precision)
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
Expand Down
97 changes: 97 additions & 0 deletions tests/py/test_ptq_to_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import unittest
import trtorch
from trtorch.logging import *
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as transforms
from model_test_case import ModelTestCase


class TestAccuracy(ModelTestCase):

def setUp(self):
self.input = torch.randn((1, 3, 32, 32)).to("cuda")
self.testing_dataset = torchvision.datasets.CIFAR10(root='./data',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
]))

self.testing_dataloader = torch.utils.data.DataLoader(self.testing_dataset,
batch_size=1,
shuffle=False,
num_workers=1)
self.calibrator = trtorch.ptq.DataLoaderCalibrator(self.testing_dataloader,
cache_file='./calibration.cache',
use_cache=False,
algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
device=torch.device('cuda:0'))

self.spec = {
"forward":
trtorch.TensorRTCompileSpec({
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"calibrator": self.calibrator,
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": False,
}
})
}

def compute_accuracy(self, testing_dataloader, model):
total = 0
correct = 0
loss = 0.0
class_probs = []
class_preds = []

with torch.no_grad():
idx = 0
for data, labels in testing_dataloader:
data, labels = data.cuda(), labels.cuda(non_blocking=True)
out = model(data)
preds = torch.max(out, 1)[1]
class_probs.append([F.softmax(i, dim=0) for i in out])
class_preds.append(preds)
total += labels.size(0)
correct += (preds == labels).sum().item()
idx += 1

test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
test_preds = torch.cat(class_preds)
return correct / total

def test_compile_script(self):

fp32_test_acc = self.compute_accuracy(self.testing_dataloader, self.model)
log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))

trt_mod = torch._C._jit_to_backend("tensorrt", self.model, self.spec)
int8_test_acc = self.compute_accuracy(self.testing_dataloader, trt_mod)
log(Level.Info, "[TRT INT8 Backend] Test Acc: {:.2f}%".format(100 * int8_test_acc))
acc_diff = fp32_test_acc - int8_test_acc
self.assertTrue(abs(acc_diff) < 3)


def test_suite():
suite = unittest.TestSuite()
suite.addTest(TestAccuracy.parametrize(TestAccuracy, model=torch.jit.load('./trained_vgg16.jit.pt')))

return suite


suite = test_suite()

runner = unittest.TextTestRunner()
result = runner.run(suite)

exit(int(not result.wasSuccessful()))

0 comments on commit 6c3e0ad

Please sign in to comment.