
add tensorrt quantization #172

Merged
merged 13 commits on May 9, 2022
12 changes: 8 additions & 4 deletions alonet/torch2trt/TRTEngineBuilder.py
@@ -127,12 +127,12 @@ def get_engine(self):
         Returns
         -------
         trt.ICudaEngine
-            Engine created from ONNX graph
+            Engine created from ONNX graph.

         Raises
         ------
-        NotImplementedError
-            INT8 flag not implemented yet
+        RuntimeError
+            INT8 not supported by the platform.
         Exception
             TRT export engine error. It was not possible to export the engine.
         """
@@ -148,7 +148,11 @@ def get_engine(self):
             config.set_flag(trt.BuilderFlag.FP16)
         # INT8
         if self.INT8_allowed:
-            raise NotImplementedError()
+            if not builder.platform_has_fast_int8:
+                raise RuntimeError('INT8 not supported on this platform')
+            config.set_quantization_flag(trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
+            config.set_flag(trt.BuilderFlag.INT8)
+            config.int8_calibrator = self.calibrator
         if self.strict_type:
             config.set_flag(trt.BuilderFlag.STRICT_TYPES)
         # Add optimization profile (used for dynamic shapes)
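Note: the new INT8 branch hands `self.calibrator` to `config.int8_calibrator`. For context, a TensorRT post-training calibrator typically looks like the sketch below. This is illustrative only: the PR's actual `BaseCalibrator` lives in `alonet/torch2trt/calibrator.py` and is not shown in this diff, so the class name and details here are assumptions.

import os

import numpy as np
import pycuda.autoinit  # noqa: F401 -- creates a CUDA context on import
import pycuda.driver as cuda
import tensorrt as trt


class ExampleEntropyCalibrator(trt.IInt8EntropyCalibrator2):  # hypothetical, not the PR's BaseCalibrator
    """Feeds calibration batches to TensorRT while it builds an INT8 engine."""

    def __init__(self, batches, cache_file="calibration.cache"):
        super().__init__()
        self.batches = batches  # list of contiguous float32 numpy arrays (NCHW)
        self.index = 0
        self.cache_file = cache_file
        self.device_input = cuda.mem_alloc(batches[0].nbytes)  # reused for every batch

    def get_batch_size(self):
        return self.batches[0].shape[0]

    def get_batch(self, names):
        if self.index >= len(self.batches):
            return None  # tells TensorRT that calibration data is exhausted
        cuda.memcpy_htod(self.device_input, np.ascontiguousarray(self.batches[self.index]))
        self.index += 1
        return [int(self.device_input)]

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)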
47 changes: 23 additions & 24 deletions alonet/torch2trt/base_exporter.py
@@ -23,7 +23,9 @@
 from alonet.torch2trt.onnx_hack import scope_name_workaround, get_scope_names, rename_tensors_
 from alonet.torch2trt import TRTEngineBuilder, TRTExecutor, utils
 from alonet.torch2trt.utils import get_nodes_by_op, rename_nodes_
+from alonet.torch2trt.calibrator import BaseCalibrator

+from pytorch_quantization import nn as quant_nn


 class BaseTRTExporter:
@@ -54,7 +56,7 @@ def __init__(
         operator_export_type=None,
         dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = None,
         opt_profiles: Dict[str, Tuple[List[int]]] = None,
-        skip_adapt_graph=False,
+        calibrator: BaseCalibrator = None,
         **kwargs,
     ):
         """
@@ -110,7 +112,6 @@ def __init__(
         self.custom_opset = None  # to be redefine in child class if needed
         self.use_scope_names = use_scope_names
         self.operator_export_type = operator_export_type
-        self.skip_adapt_graph = skip_adapt_graph
         if dynamic_axes is not None:
             assert opt_profiles is not None, "If dynamic_axes are to be used, opt_profiles must be provided"
             assert isinstance(dynamic_axes, dict)
@@ -121,22 +122,23 @@ def __init__(
         onnx_file_name = os.path.split(onnx_path)[1]
         model_name = onnx_file_name.split(".")[0]

-        if not self.skip_adapt_graph:
-            self.adapted_onnx_path = os.path.join(onnx_dir, "trt_" + onnx_file_name)
-        else:
-            self.adapted_onnx_path = os.path.join(onnx_dir, onnx_file_name)
-
         self.engine_path = os.path.join(onnx_dir, model_name + f"_{precision.lower()}.engine")

         if self.verbose:
             trt_logger = trt.Logger(trt.Logger.VERBOSE)
         else:
             trt_logger = trt.Logger(trt.Logger.WARNING)

-        self.engine_builder = TRTEngineBuilder(self.adapted_onnx_path, logger=trt_logger, opt_profiles=opt_profiles)
+        self.engine_builder = TRTEngineBuilder(self.onnx_path, logger=trt_logger, opt_profiles=opt_profiles, calibrator=calibrator)

         if precision.lower() == "fp32":
             pass
+        elif precision.lower() == "int8":
+            ## set fake quantization to True before torch2onnx
+            quant_nn.TensorQuantizer.use_fb_fake_quant = True
+            self.engine_builder.INT8_allowed = True
+            self.engine_builder.FP16_allowed = True
+            self.engine_builder.strict_type = True
         elif precision.lower() == "fp16":
             self.engine_builder.FP16_allowed = True
             self.engine_builder.strict_type = True
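The int8 branch flips `quant_nn.TensorQuantizer.use_fb_fake_quant` before the ONNX export so that quantizer modules serialize as QuantizeLinear/DequantizeLinear pairs. A rough sketch of how NVIDIA's pytorch_quantization library is typically wired up for this, with a placeholder model and calibration step that are not code from this PR:

import torch
import torchvision
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules

# Monkey-patch torch layers (Conv2d, Linear, ...) with quantized equivalents
# BEFORE the model is instantiated.
quant_modules.initialize()

model = torchvision.models.resnet18(pretrained=True).eval()  # placeholder model
# ... run representative data through the model to collect amax statistics ...

# With fake quant enabled, each TensorQuantizer exports as an ONNX
# QuantizeLinear/DequantizeLinear pair that TensorRT folds into INT8 kernels.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "model_qat.onnx", opset_version=13)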
@@ -156,7 +158,6 @@ def build_torch_model(self):
         pass
         raise Exception("Child class should implement this method")

-
     def adapt_graph(self, graph):
         """Modify ONNX graph to ensure compability between ONNX and TensorRT

@@ -165,8 +166,8 @@ def adapt_graph(self, graph):
         graph: onnx_graphsurgeon.Graph
         """
         return graph

-    def _adapt_graph(self, graph):
+    def _adapt_graph(self, graph, **kwargs):
         """Modify ONNX graph to ensure compability between ONNX and TensorRT

         Returns
@@ -258,11 +259,9 @@ def _torch2onnx(self):
         else:
             with torch.no_grad():
                 m_outputs = self.model(*inputs, **kwargs)
-
         # Prepare inputs for torch.export.onnx and sanity check
         np_inputs = tuple(np.array(i.cpu()) for i in inputs)
         inputs = (*inputs, kwargs)
-
         onames = m_outputs._fields if hasattr(m_outputs, "_fields") else [f"out_{i}" for i in range(len(m_outputs))]
         np_m_outputs = {key: val.cpu().numpy() for key, val in zip(onames, m_outputs) if isinstance(val, torch.Tensor)}
         # print("Model input shapes:", [val.shape for val in np_inputs])
@@ -275,6 +274,7 @@ def _torch2onnx(self):
             buffer = stack.enter_context(io.StringIO())
             stack.enter_context(redirect_stdout(buffer))
             stack.enter_context(scope_name_workaround())
+
             torch.onnx.export(
                 self.model,  # model being run
                 inputs,  # model input (or a tuple for multiple inputs)
@@ -293,6 +293,16 @@ def _torch2onnx(self):

         if self.use_scope_names:
             onnx_export_log = buffer.getvalue()
+
+
+        graph = gs.import_onnx(onnx.load(self.onnx_path))
+        graph.toposort()
+
+        # === Modify ONNX graph for TensorRT compability
+        graph = self._adapt_graph(graph, **kwargs)
+        utils.print_graph_io(graph)
+        # === Export adapted onnx for TRT engine
+        onnx.save(gs.export_onnx(graph), self.onnx_path)

         # rewrite onnx graph with new scope names
         if self.use_scope_names:
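With this change, graph adaptation runs inside `_torch2onnx` and overwrites the ONNX file in place, so there is no separate trt_-prefixed file anymore. A minimal sketch of how a child exporter might override the `adapt_graph` hook with onnx_graphsurgeon (the subclass name and the specific edit are illustrative):

import onnx_graphsurgeon as gs


class MyModelExporter(BaseTRTExporter):  # hypothetical child exporter
    def adapt_graph(self, graph: gs.Graph) -> gs.Graph:
        # Example edit: drop unused nodes/tensors, then re-sort topologically
        graph.cleanup().toposort()
        return graph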
@@ -319,16 +329,6 @@ def _onnx2engine(self, **kwargs):
         if prod_package_error is not None:
             raise prod_package_error

-        if not self.skip_adapt_graph:
-            graph = gs.import_onnx(onnx.load(self.onnx_path))
-            graph.toposort()
-
-            # === Modify ONNX graph for TensorRT compability
-            graph = self._adapt_graph(graph, **kwargs)
-            utils.print_graph_io(graph)
-            # === Export adapted onnx for TRT engine
-            onnx.save(gs.export_onnx(graph), self.adapted_onnx_path)
-
         # === Build engine
         self.engine_builder.export_engine(self.engine_path)
         return self.engine_builder.engine
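`export_engine` serializes the built engine to `self.engine_path`. For reference, a serialized engine can be reloaded with the raw TensorRT runtime as sketched below; the repo's TRTExecutor wraps this logic, and its exact API is not part of this diff.

import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(trt_logger)
with open("model_int8.engine", "rb") as f:  # path is a placeholder
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()  # allocate I/O buffers, then execute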
@@ -387,7 +387,6 @@ def add_argparse_args(parent_parser):
             default=None,
             help="/path/onnx/will/be/exported, by default set as ~/.aloception/weights/MODEL/MODEL.onnx",
         )
-        parser.add_argument("--skip_adapt_graph", action="store_true", help="Skip the adapt graph")
         parser.add_argument("--batch_size", type=int, default=1, help="Engine batch size, default = 1")
         parser.add_argument("--precision", type=str, default="fp32", help="fp32/fp16/mix, default FP32")
         parser.add_argument("--verbose", action="store_true", help="Helpful when debugging")