Features: update the fx model for output dtype match #2040

Status: Closed · wants to merge 14 commits

16 changes: 14 additions & 2 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -2,6 +2,7 @@
 import os
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
 
 import numpy
@@ -41,6 +42,7 @@ def __init__(
         explicit_batch_dimension: bool = False,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)
 
@@ -78,6 +80,8 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
 
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
@@ -183,7 +187,7 @@ def run(
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overridden by specifying output_dtypes.
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
@@ -225,12 +229,13 @@ def run(
             cache = builder_config.create_timing_cache(b"")
         builder_config.set_timing_cache(cache, False)
 
-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
+
         if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
 
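Note on the version guard above: the old code compared version strings lexicographically, so a hypothetical TensorRT 8.10 would sort before 8.2 and skip the branch; packaging.version compares release components numerically. A minimal standalone illustration:

from packaging import version

# Lexicographic string comparison misorders multi-digit components:
assert not "8.10" >= "8.2"
# packaging.version compares the release components numerically:
assert version.parse("8.10") >= version.parse("8.2")
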
@@ -357,6 +362,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -381,6 +391,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
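The new elif branch relies on torch_dtype_to_trt to translate each requested torch dtype into a TensorRT dtype before assigning it to the marked output. For orientation, a minimal sketch of what such a mapping looks like; the dtype coverage and error behavior here are assumptions for illustration, not this repo's verbatim implementation:

import torch
import tensorrt as trt

# Illustrative torch -> TensorRT dtype table (assumed coverage):
_TORCH_TO_TRT_DTYPE = {
    torch.bool: trt.bool,
    torch.int32: trt.int32,
    torch.float16: trt.float16,
    torch.float32: trt.float32,
}

def torch_dtype_to_trt_sketch(dtype: torch.dtype) -> trt.DataType:
    # Fail loudly on dtypes TensorRT cannot represent.
    try:
        return _TORCH_TO_TRT_DTYPE[dtype]
    except KeyError:
        raise TypeError(f"{dtype} is not supported by TensorRT") from None
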
6 changes: 6 additions & 0 deletions py/torch_tensorrt/fx/lower.py
@@ -121,6 +121,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
                 logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}")
                 cache_data = None
 
+        module_outputs = mod(*input)
+        if not isinstance(module_outputs, (list, tuple)):
+            module_outputs = [module_outputs]
+        output_dtypes = list(output.dtype for output in module_outputs)
+
         interpreter = TRTInterpreter(
             mod,
             input_specs=self.lower_setting.input_specs,
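The probe added above simply executes the module once on the sample input and records one torch dtype per output, normalizing a single-tensor return into a list first. A self-contained illustration with a toy module (module and shapes are invented for the example):

import torch

class ToyModule(torch.nn.Module):
    def forward(self, x):
        # Two outputs with different dtypes: an fp32 reduction and a bool mask.
        return x.sum(dim=1), x > 0

mod = ToyModule().eval()
sample_input = [torch.randn(2, 3)]

module_outputs = mod(*sample_input)
if not isinstance(module_outputs, (list, tuple)):
    module_outputs = [module_outputs]
output_dtypes = list(output.dtype for output in module_outputs)
print(output_dtypes)  # [torch.float32, torch.bool]
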
@@ -129,6 +134,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             logger_level=trt.Logger.VERBOSE
             if self.lower_setting.verbose_log
             else trt.Logger.WARNING,
+            output_dtypes=output_dtypes,
         )
 
         interp_result: TRTInterpreterResult = interpreter.run(
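Taken together, the lowering path now pins every TRT engine output to the dtype the original fx module produced, instead of letting FP16 lowering silently flip fp32 outputs. Callers constructing the interpreter themselves could pass output_dtypes directly; the sketch below is hedged: the import paths, acc_tracer.trace call, InputTensorSpec.from_tensors usage, and bare run() call are assumptions about the surrounding fx API, not lines from this PR:

import torch
import tensorrt as trt
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.fx2trt import TRTInterpreter

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

inputs = [torch.randn(2, 3)]
traced = acc_tracer.trace(Toy().eval(), inputs)

interpreter = TRTInterpreter(
    traced,
    input_specs=InputTensorSpec.from_tensors(inputs),
    explicit_batch_dimension=True,
    logger_level=trt.Logger.WARNING,
    output_dtypes=[torch.float16],  # pin the single output to fp16
)
interp_result = interpreter.run()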