Commit
Merge pull request #172 from Visual-Behavior/dev
add tensorrt quantization
thibo73800 authored May 9, 2022
2 parents 33b49e2 + 0f36f91 commit 4caa0d2
Showing 4 changed files with 436 additions and 28 deletions.
12 changes: 8 additions & 4 deletions alonet/torch2trt/TRTEngineBuilder.py
@@ -127,12 +127,12 @@ def get_engine(self):
         Returns
         -------
         trt.ICudaEngine
-            Engine created from ONNX graph
+            Engine created from ONNX graph.

         Raises
         ------
-        NotImplementedError
-            INT8 flag not implemented yet
+        RuntimeError
+            INT8 not supported by the platform.
         Exception
             TRT export engine error. It was not possible to export the engine.
         """
@@ -148,7 +148,11 @@ def get_engine(self):
             config.set_flag(trt.BuilderFlag.FP16)
         # INT8
         if self.INT8_allowed:
-            raise NotImplementedError()
+            if not builder.platform_has_fast_int8:
+                raise RuntimeError('INT8 not supported on this platform')
+            config.set_quantization_flag(trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
+            config.set_flag(trt.BuilderFlag.INT8)
+            config.int8_calibrator = self.calibrator
         if self.strict_type:
             config.set_flag(trt.BuilderFlag.STRICT_TYPES)
         # Add optimization profile (used for dynamic shapes)
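For context, config.int8_calibrator expects an object implementing one of TensorRT's calibrator interfaces; the BaseCalibrator imported in base_exporter.py below presumably wraps this. A minimal sketch of such a calibrator, assuming pycuda for device buffers (illustrative only, not part of this commit):

    import os
    import numpy as np
    import pycuda.autoinit  # noqa: F401 - creates a CUDA context on import
    import pycuda.driver as cuda
    import tensorrt as trt

    class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
        """Feeds calibration batches to TensorRT during an INT8 engine build."""

        def __init__(self, batches, cache_file="calib.cache"):
            super().__init__()
            self._batches = iter(batches)  # iterable of contiguous np.float32 arrays
            self._cache_file = cache_file
            self._device_input = None      # lazily allocated on the first batch

        def get_batch_size(self):
            return 1  # must match the leading dimension of each calibration batch

        def get_batch(self, names):
            try:
                batch = next(self._batches)
            except StopIteration:
                return None  # tells TensorRT the calibration data is exhausted
            if self._device_input is None:
                self._device_input = cuda.mem_alloc(batch.nbytes)
            cuda.memcpy_htod(self._device_input, np.ascontiguousarray(batch))
            return [int(self._device_input)]

        def read_calibration_cache(self):
            if os.path.exists(self._cache_file):
                with open(self._cache_file, "rb") as f:
                    return f.read()

        def write_calibration_cache(self, cache):
            with open(self._cache_file, "wb") as f:
                f.write(cache)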
47 changes: 23 additions & 24 deletions alonet/torch2trt/base_exporter.py
@@ -23,7 +23,9 @@
 from alonet.torch2trt.onnx_hack import scope_name_workaround, get_scope_names, rename_tensors_
 from alonet.torch2trt import TRTEngineBuilder, TRTExecutor, utils
 from alonet.torch2trt.utils import get_nodes_by_op, rename_nodes_
+from alonet.torch2trt.calibrator import BaseCalibrator

+from pytorch_quantization import nn as quant_nn


 class BaseTRTExporter:
@@ -54,7 +56,7 @@ def __init__(
         operator_export_type=None,
         dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = None,
         opt_profiles: Dict[str, Tuple[List[int]]] = None,
-        skip_adapt_graph=False,
+        calibrator: BaseCalibrator = None,
         **kwargs,
     ):
         """
@@ -110,7 +112,6 @@ def __init__(
         self.custom_opset = None  # to be redefined in child class if needed
         self.use_scope_names = use_scope_names
         self.operator_export_type = operator_export_type
-        self.skip_adapt_graph = skip_adapt_graph
         if dynamic_axes is not None:
             assert opt_profiles is not None, "If dynamic_axes are to be used, opt_profiles must be provided"
             assert isinstance(dynamic_axes, dict)
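As a usage note, dynamic_axes and opt_profiles must describe the same tensors: every dynamic dimension needs min/opt/max shapes in its optimization profile, matching the Dict[str, Tuple[List[int]]] annotation above. An illustrative pairing (the tensor name "input" and the shapes are hypothetical):

    # Hypothetical dynamic-batch configuration for a BaseTRTExporter child:
    dynamic_axes = {"input": {0: "batch_size"}}  # dim 0 of "input" is dynamic
    opt_profiles = {
        "input": (
            [1, 3, 224, 224],  # minimum shape
            [4, 3, 224, 224],  # shape TensorRT optimizes for
            [8, 3, 224, 224],  # maximum shape
        )
    }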
@@ -121,22 +122,23 @@
         onnx_file_name = os.path.split(onnx_path)[1]
         model_name = onnx_file_name.split(".")[0]

-        if not self.skip_adapt_graph:
-            self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
-        else:
-            self.adapted_onnx_path = os.path.join(onnx_dir, onnx_file_name)
-
         self.engine_path = os.path.join(onnx_dir, model_name + f"_{precision.lower()}.engine")

         if self.verbose:
             trt_logger = trt.Logger(trt.Logger.VERBOSE)
         else:
             trt_logger = trt.Logger(trt.Logger.WARNING)

-        self.engine_builder = TRTEngineBuilder(self.adapted_onnx_path, logger=trt_logger, opt_profiles=opt_profiles)
+        self.engine_builder = TRTEngineBuilder(self.onnx_path, logger=trt_logger, opt_profiles=opt_profiles, calibrator=calibrator)

         if precision.lower() == "fp32":
             pass
+        elif precision.lower() == "int8":
+            ## set fake quantization to True before torch2onnx
+            quant_nn.TensorQuantizer.use_fb_fake_quant = True
+            self.engine_builder.INT8_allowed = True
+            self.engine_builder.FP16_allowed = True
+            self.engine_builder.strict_type = True
         elif precision.lower() == "fp16":
             self.engine_builder.FP16_allowed = True
             self.engine_builder.strict_type = True
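Setting quant_nn.TensorQuantizer.use_fb_fake_quant = True makes pytorch_quantization's quantizer modules trace through torch's fake-quantize ops, so the subsequent torch.onnx.export emits explicit QuantizeLinear/DequantizeLinear (Q/DQ) nodes that TensorRT's INT8 builder can consume. A sketch of the model-side preparation this path assumes (build_model is a hypothetical stand-in for a network built with quantized layers and already calibrated):

    import torch
    from pytorch_quantization import nn as quant_nn
    from pytorch_quantization import quant_modules

    quant_modules.initialize()  # replace torch.nn layers with quantized counterparts
    model = build_model().eval().cuda()  # hypothetical; now contains TensorQuantizer modules
    # ... run calibration or quantization-aware fine-tuning here ...

    quant_nn.TensorQuantizer.use_fb_fake_quant = True  # emit Q/DQ nodes on export
    dummy_input = torch.randn(1, 3, 224, 224, device="cuda")
    torch.onnx.export(model, dummy_input, "model_qdq.onnx", opset_version=13)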
@@ -156,7 +158,6 @@ def build_torch_model(self):
         pass
         raise Exception("Child class should implement this method")

-
     def adapt_graph(self, graph):
         """Modify ONNX graph to ensure compatibility between ONNX and TensorRT
@@ -165,8 +166,8 @@ def adapt_graph(self, graph):
         graph: onnx_graphsurgeon.Graph
         """
         return graph

-    def _adapt_graph(self, graph):
+    def _adapt_graph(self, graph, **kwargs):
         """Modify ONNX graph to ensure compatibility between ONNX and TensorRT

         Returns
@@ -258,11 +259,9 @@ def _torch2onnx(self):
         else:
             with torch.no_grad():
                 m_outputs = self.model(*inputs, **kwargs)
-
         # Prepare inputs for torch.export.onnx and sanity check
         np_inputs = tuple(np.array(i.cpu()) for i in inputs)
         inputs = (*inputs, kwargs)
-
         onames = m_outputs._fields if hasattr(m_outputs, "_fields") else [f"out_{i}" for i in range(len(m_outputs))]
         np_m_outputs = {key: val.cpu().numpy() for key, val in zip(onames, m_outputs) if isinstance(val, torch.Tensor)}
         # print("Model input shapes:", [val.shape for val in np_inputs])
@@ -275,6 +274,7 @@
             buffer = stack.enter_context(io.StringIO())
             stack.enter_context(redirect_stdout(buffer))
             stack.enter_context(scope_name_workaround())
+
             torch.onnx.export(
                 self.model,  # model being run
                 inputs,  # model input (or a tuple for multiple inputs)
@@ -293,6 +293,16 @@

         if self.use_scope_names:
             onnx_export_log = buffer.getvalue()
+
+
+        graph = gs.import_onnx(onnx.load(self.onnx_path))
+        graph.toposort()
+
+        # === Modify ONNX graph for TensorRT compatibility
+        graph = self._adapt_graph(graph, **kwargs)
+        utils.print_graph_io(graph)
+        # === Export adapted onnx for TRT engine
+        onnx.save(gs.export_onnx(graph), self.onnx_path)

         # rewrite onnx graph with new scope names
         if self.use_scope_names:
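Child exporters override adapt_graph to patch the exported ONNX before the engine build; with this change the adaptation runs inside _torch2onnx and the result is written back over onnx_path. A minimal illustrative override using onnx-graphsurgeon (the specific cleanup steps are examples, not taken from this repository):

    import onnx_graphsurgeon as gs

    class MyExporter(BaseTRTExporter):  # hypothetical child class
        def adapt_graph(self, graph):
            graph.fold_constants()  # pre-compute constant subgraphs
            graph.cleanup().toposort()  # drop dangling nodes, restore topological order
            return graph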
@@ -319,16 +329,6 @@ def _onnx2engine(self, **kwargs):
         if prod_package_error is not None:
             raise prod_package_error

-        if not self.skip_adapt_graph:
-            graph = gs.import_onnx(onnx.load(self.onnx_path))
-            graph.toposort()
-
-            # === Modify ONNX graph for TensorRT compatibility
-            graph = self._adapt_graph(graph, **kwargs)
-            utils.print_graph_io(graph)
-            # === Export adapted onnx for TRT engine
-            onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
-
         # === Build engine
         self.engine_builder.export_engine(self.engine_path)
         return self.engine_builder.engine
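Taken together, the INT8 flow would be driven roughly as follows; everything here except the precision and calibrator arguments is hypothetical, and the public method that chains _torch2onnx and _onnx2engine sits outside the hunks shown:

    calibrator = MyCalibrator(calibration_batches)  # hypothetical BaseCalibrator subclass
    exporter = MyExporter(  # hypothetical BaseTRTExporter child
        onnx_path="mymodel.onnx",
        precision="int8",  # sets INT8_allowed, FP16_allowed and strict types
        calibrator=calibrator,  # forwarded to TRTEngineBuilder
    )
    exporter._torch2onnx()  # ONNX export + graph adaptation (per this diff)
    engine = exporter._onnx2engine()  # builds the .engine via TRTEngineBuilder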
@@ -387,7 +387,6 @@ def add_argparse_args(parent_parser):
             default=None,
             help="/path/onnx/will/be/exported, by default set as ~/.aloception/weights/MODEL/MODEL.onnx",
         )
-        parser.add_argument("--skip_adapt_graph", action="store_true", help="Skip the adapt graph")
         parser.add_argument("--batch_size", type=int, default=1, help="Engine batch size, default = 1")
         parser.add_argument("--precision", type=str, default="fp32", help="fp32/fp16/mix, default FP32")
         parser.add_argument("--verbose", action="store_true", help="Helpful when debugging")