From 832f227b92f0d0772ad70b28f951f3ae9b17e4b9 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Tue, 20 Jun 2023 22:54:14 +0000
Subject: [PATCH 01/12] change fix

---
 py/torch_tensorrt/fx/fx2trt.py            | 71 +++++++++++++++++-
 py/torch_tensorrt/fx/lower.py             | 11 +++
 py/torch_tensorrt/fx/lower_setting.py     | 10 +++
 py/torch_tensorrt/fx/passes/pass_utils.py | 87 ++++++++++++-----------
 4 files changed, 134 insertions(+), 45 deletions(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 96f1f1cadd..6b9640669d 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -1,6 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

 import numpy
@@ -40,6 +41,7 @@ def __init__(
         explicit_batch_dimension: bool = False,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)
@@ -77,7 +79,8 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
-
+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
@@ -163,6 +166,11 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
         tactic_sources=None,
+        max_aux_streams=None,
+        version_compatible=False,
+        tactic_heuristic=False,
+        optimization_level=None,
+        faster_dynamic_shapes=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -182,7 +190,7 @@ def run(
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overridden by specifying output_dtypes.
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
@@ -219,12 +227,31 @@ def run(
             cache = builder_config.create_timing_cache(b"")
             builder_config.set_timing_cache(cache, False)

-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
+
+        if version.parse(trt.__version__) >= version.parse("8.6"):
+            if max_aux_streams is not None:
+                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
+                builder_config.max_aux_streams = max_aux_streams
+            if version_compatible:
+                _LOGGER.info("Using version compatible")
+                builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
+            if optimization_level is not None:
+                _LOGGER.info(f"Using optimization level {optimization_level}")
+                builder_config.builder_optimization_level = optimization_level
+            if faster_dynamic_shapes is not None:
+                builder_config.set_preview_feature(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805, faster_dynamic_shapes)
+
+        if version.parse(trt.__version__) >= version.parse("8.5"):
+            if tactic_heuristic:
+                _LOGGER.info("Setting builder flag ENABLE_TACTIC_HEURISTIC")
+                builder_config.set_flag(trt.BuilderFlag.ENABLE_TACTIC_HEURISTIC)
+
         if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
@@ -251,6 +278,34 @@ def run(
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine

+        import os
+        def get_file_name(org):
+            file_name = org
+            i = 0
+            while os.path.exists(os.path.abspath(file_name)):
+                i += 1
+                file_name = org + str(i)
+            return file_name
+
+        engine_file = os.environ.get('TORCH_FX_DUMP_ENGINE')
+        if engine_file:
+            dump_file = get_file_name(engine_file)
+            print(f'Dumping engine to {dump_file}')
+            s = engine.serialize()
+            with open(dump_file, 'wb') as f:
+                f.write(s)
+        engine_info_file = os.environ.get('TORCH_FX_DUMP_ENGINE_INFO')
+        if engine_info_file:
+            inspector = engine.create_engine_inspector()
+            engine_info = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
+            if engine_info is None or len(engine_info) == 0:
+                raise Exception('Engine info is empty')
+            else:
+                dump_file = get_file_name(engine_info_file)
+                print(f'Dumping engine info to {dump_file}')
+                with open(dump_file, 'w') as f:
+                    f.write(engine_info)
+
         serialized_cache = (
             bytearray(cache.serialize())
             if builder_config.get_timing_cache()
@@ -259,6 +314,9 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
+        _LOGGER.info(
+            f"TRT Engine uses: {engine.device_memory_size} bytes of device memory"
+        )

         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
@@ -346,6 +404,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")

+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -370,6 +433,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index 61bd232421..c5de82f63e 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -119,6 +119,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}")
             cache_data = None

+        module_outputs = mod(*input)
+        if not isinstance(module_outputs, (list, tuple)):
+            module_outputs = [module_outputs]
+        output_dtypes = list(output.dtype for output in module_outputs)
+
         interpreter = TRTInterpreter(
             mod,
             input_specs=self.lower_setting.input_specs,
@@ -127,6 +132,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             logger_level=trt.Logger.VERBOSE
             if self.lower_setting.verbose_log
             else trt.Logger.WARNING,
+            output_dtypes=output_dtypes,
         )

         interp_result: TRTInterpreterResult = interpreter.run(
@@ -140,6 +146,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
             tactic_sources=self.lower_setting.tactic_sources,
+            max_aux_streams=self.lower_setting.max_aux_streams,
+            version_compatible=self.lower_setting.version_compatible,
+            tactic_heuristic=self.lower_setting.tactic_heuristic,
+            optimization_level=self.lower_setting.optimization_level,
+            faster_dynamic_shapes=self.lower_setting.faster_dynamic_shapes,
         )

         # Update timing cache file if needed
diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py
index a47f8c77c5..bf9357cbd5 100644
--- a/py/torch_tensorrt/fx/lower_setting.py
+++ b/py/torch_tensorrt/fx/lower_setting.py
@@ -74,6 +74,11 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: absolute tolerance for correctness check
     correctness_rtol: relative tolerance for correctness check
     use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+    max_aux_streams: max number of aux streams to use
+    version_compatible: enable version compatible feature
+    tactic_heuristic: enable tactic heuristic
+    optimization_level: builder optimization level
+    faster_dynamic_shapes: enable/disable faster dynamic shapes
     """

     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -97,3 +102,8 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
     use_experimental_rt: bool = False
+    max_aux_streams: Optional[int] = None
+    version_compatible: bool = False
+    tactic_heuristic: bool = False
+    optimization_level: Optional[int] = None
+    faster_dynamic_shapes: Optional[bool] = None
diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index 78e9ec1b22..2cc0426747 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -99,7 +99,7 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:

 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
-def validate_inference(rtol=None, atol=None):
+def validate_inference(rtol=None, atol=None, suppress_accuracy_check_failure=True):
     def _validate_inference(pass_: PassFunc) -> PassFunc:
         """
         Wraps a pass function to validate that its inference results before and
@@ -113,48 +113,51 @@ def pass_with_validation(
         *args,
         **kwargs,
     ) -> fx.GraphModule:
-        res0 = module(*input)
-        processed_module = pass_(module, input, *args, **kwargs)
-        res1 = processed_module(*input)
-
-        tensor_res_0 = _collect_tensors(res0)
-        tensor_res_1 = _collect_tensors(res1)
-        relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
-
-        for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
-            kwargs2 = {"equal_nan": True}
-            if rtol:
-                kwargs2["rtol"] = rtol
-            if atol:
-                kwargs2["atol"] = atol
-            kwargs2[
-                "msg"
-            ] = (
-                lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
-            )
-            # If tensors are on different devices, make sure to compare
-            # their copies that are on the same device.
-            if x.get_device() != y.get_device():
-                x = x.cpu()
-                y = y.cpu()
-            try:
-                torch.testing.assert_close(x, y, **kwargs2)
-            except Exception as e:
-                if relax_accuracy_check_failure:
-                    _LOGGER.error(f"{e}")
-                    kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
-                    kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
-                    new_atol = kwargs2["atol"]
-                    new_rtol = kwargs2["rtol"]
-                    _LOGGER.info(
-                        f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
-                    )
+        if suppress_accuracy_check_failure:
+            return pass_(module, input, *args, **kwargs)
+        else:
+            res0 = module(*input)
+            processed_module = pass_(module, input, *args, **kwargs)
+            res1 = processed_module(*input)
+
+            tensor_res_0 = _collect_tensors(res0)
+            tensor_res_1 = _collect_tensors(res1)
+            relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
+
+            for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
+                kwargs2 = {"equal_nan": True}
+                if rtol:
+                    kwargs2["rtol"] = rtol
+                if atol:
+                    kwargs2["atol"] = atol
+                kwargs2[
+                    "msg"
+                ] = (
+                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
+                )
+                # If tensors are on different devices, make sure to compare
+                # their copies that are on the same device.
+                if x.get_device() != y.get_device():
+                    x = x.cpu()
+                    y = y.cpu()
+                try:
                     torch.testing.assert_close(x, y, **kwargs2)
-                    return processed_module
-                else:
-                    raise e
-
-        return processed_module
+                except Exception as e:
+                    if relax_accuracy_check_failure:
+                        _LOGGER.error(f"{e}")
+                        kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
+                        kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
+                        new_atol = kwargs2["atol"]
+                        new_rtol = kwargs2["rtol"]
+                        _LOGGER.info(
+                            f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
+                        )
+                        torch.testing.assert_close(x, y, **kwargs2)
+                        return processed_module
+                    else:
+                        raise e
+
+            return processed_module

     return pass_with_validation
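[Note: the patch above threads five new builder options from LowerSetting through lower.py into TRTInterpreter.run(). A minimal usage sketch follows; only the field names come from the diff — the chosen values are hypothetical, and in the normal lowering flow the LowerSetting is usually constructed for you rather than by hand.

    # Sketch only: driving the PATCH 01/12 settings via LowerSetting.
    # All values below are illustrative; fields are from lower_setting.py above.
    from torch_tensorrt.fx.lower_setting import LowerSetting

    setting = LowerSetting(
        max_aux_streams=2,           # TRT >= 8.6: cap auxiliary streams
        version_compatible=True,     # TRT >= 8.6: VERSION_COMPATIBLE engine
        optimization_level=4,        # TRT >= 8.6: builder optimization level
        tactic_heuristic=True,       # TRT >= 8.5: ENABLE_TACTIC_HEURISTIC
        faster_dynamic_shapes=True,  # preview flag FASTER_DYNAMIC_SHAPES_0805
    )
    # lower.py (above) forwards each field to TRTInterpreter.run(), which
    # applies it to the TensorRT builder config behind a version check.
]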
From d8759849102fbb17bde61100355a0a3a34b6c277 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Tue, 20 Jun 2023 23:14:27 +0000
Subject: [PATCH 02/12] revert

---
 py/torch_tensorrt/fx/fx2trt.py        | 54 ---------------------------
 py/torch_tensorrt/fx/lower.py         |  5 ---
 py/torch_tensorrt/fx/lower_setting.py | 10 -----
 3 files changed, 69 deletions(-)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index ab8ac2fc8f..95dc784cde 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -167,11 +167,6 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
         tactic_sources=None,
-        max_aux_streams=None,
-        version_compatible=False,
-        tactic_heuristic=False,
-        optimization_level=None,
-        faster_dynamic_shapes=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -240,24 +235,6 @@ def run(
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         )

-        if version.parse(trt.__version__) >= version.parse("8.6"):
-            if max_aux_streams is not None:
-                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
-                builder_config.max_aux_streams = max_aux_streams
-            if version_compatible:
-                _LOGGER.info("Using version compatible")
-                builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
-            if optimization_level is not None:
-                _LOGGER.info(f"Using optimization level {optimization_level}")
-                builder_config.builder_optimization_level = optimization_level
-            if faster_dynamic_shapes is not None:
-                builder_config.set_preview_feature(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805, faster_dynamic_shapes)
-
-        if version.parse(trt.__version__) >= version.parse("8.5"):
-            if tactic_heuristic:
-                _LOGGER.info("Setting builder flag ENABLE_TACTIC_HEURISTIC")
-                builder_config.set_flag(trt.BuilderFlag.ENABLE_TACTIC_HEURISTIC)
-
         if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
@@ -284,34 +261,6 @@ def run(
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine

-        import os
-        def get_file_name(org):
-            file_name = org
-            i = 0
-            while os.path.exists(os.path.abspath(file_name)):
-                i += 1
-                file_name = org + str(i)
-            return file_name
-
-        engine_file = os.environ.get('TORCH_FX_DUMP_ENGINE')
-        if engine_file:
-            dump_file = get_file_name(engine_file)
-            print(f'Dumping engine to {dump_file}')
-            s = engine.serialize()
-            with open(dump_file, 'wb') as f:
-                f.write(s)
-        engine_info_file = os.environ.get('TORCH_FX_DUMP_ENGINE_INFO')
-        if engine_info_file:
-            inspector = engine.create_engine_inspector()
-            engine_info = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
-            if engine_info is None or len(engine_info) == 0:
-                raise Exception('Engine info is empty')
-            else:
-                dump_file = get_file_name(engine_info_file)
-                print(f'Dumping engine info to {dump_file}')
-                with open(dump_file, 'w') as f:
-                    f.write(engine_info)
-
         serialized_cache = (
             bytearray(cache.serialize())
             if builder_config.get_timing_cache()
@@ -320,9 +269,6 @@ def get_file_name(org):
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(
-            f"TRT Engine uses: {engine.device_memory_size} bytes of device memory"
-        )

         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index 388030b93f..f7c0d92c2e 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -148,11 +148,6 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
             tactic_sources=self.lower_setting.tactic_sources,
-            max_aux_streams=self.lower_setting.max_aux_streams,
-            version_compatible=self.lower_setting.version_compatible,
-            tactic_heuristic=self.lower_setting.tactic_heuristic,
-            optimization_level=self.lower_setting.optimization_level,
-            faster_dynamic_shapes=self.lower_setting.faster_dynamic_shapes,
         )

         # Update timing cache file if needed
diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py
index 6ed77b10af..07e7bf0dac 100644
--- a/py/torch_tensorrt/fx/lower_setting.py
+++ b/py/torch_tensorrt/fx/lower_setting.py
@@ -74,11 +74,6 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: absolute tolerance for correctness check
     correctness_rtol: relative tolerance for correctness check
     use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
-    max_aux_streams: max number of aux streams to use
-    version_compatible: enable version compatible feature
-    tactic_heuristic: enable tactic heuristic
-    optimization_level: builder optimization level
-    faster_dynamic_shapes: enable/disable faster dynamic shapes
     """

     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -106,8 +101,3 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
     use_experimental_rt: bool = False
-    max_aux_streams: Optional[int] = None
-    version_compatible: bool = False
-    tactic_heuristic: bool = False
-    optimization_level: Optional[int] = None
-    faster_dynamic_shapes: Optional[bool] = None

From e31cf0e3070f6f45c88769d990e0bcab08dda211 Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Tue, 20 Jun 2023 23:21:03 +0000
Subject: [PATCH 03/12] revert

---
 py/torch_tensorrt/fx/passes/pass_utils.py | 87 +++++++++++------------
 1 file changed, 42 insertions(+), 45 deletions(-)

diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index 4cba4d5fab..d9fa24c2c6 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -181,51 +181,48 @@ def pass_with_validation(
         *args,
         **kwargs,
     ) -> fx.GraphModule:
-        if suppress_accuracy_check_failure:
-            return pass_(module, input, *args, **kwargs)
-        else:
-            res0 = module(*input)
-            processed_module = pass_(module, input, *args, **kwargs)
-            res1 = processed_module(*input)
-
-            tensor_res_0 = _collect_tensors(res0)
-            tensor_res_1 = _collect_tensors(res1)
-            relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
-
-            for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
-                kwargs2 = {"equal_nan": True}
-                if rtol:
-                    kwargs2["rtol"] = rtol
-                if atol:
-                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
{kk}:\n{msg}" - ) - # If tensors are on different devices, make sure to compare - # their copies that are on the same device. - if x.get_device() != y.get_device(): - x = x.cpu() - y = y.cpu() - try: - torch.testing.assert_close(x, y, **kwargs2) - except Exception as e: - if relax_accuracy_check_failure: - _LOGGER.error(f"{e}") - kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER - kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER - new_atol = kwargs2["atol"] - new_rtol = kwargs2["rtol"] - _LOGGER.info( - f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" - ) - torch.testing.assert_close(x, y, **kwargs2) - return processed_module - else: - raise e - - return processed_module + res0 = module(*input) + processed_module = pass_(module, input, *args, **kwargs) + res1 = processed_module(*input) + + tensor_res_0 = _collect_tensors(res0) + tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE + + for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): + kwargs2 = {"equal_nan": True} + if rtol: + kwargs2["rtol"] = rtol + if atol: + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) + # If tensors are on different devices, make sure to compare + # their copies that are on the same device. + if x.get_device() != y.get_device(): + x = x.cpu() + y = y.cpu() + try: + torch.testing.assert_close(x, y, **kwargs2) + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" + ) + torch.testing.assert_close(x, y, **kwargs2) + return processed_module + else: + raise e + + return processed_module return pass_with_validation From 711f2d9953c0c74714aaf75aab69b568ee43559b Mon Sep 17 00:00:00 2001 From: tinyinl Date: Tue, 20 Jun 2023 23:21:48 +0000 Subject: [PATCH 04/12] revert --- py/torch_tensorrt/fx/passes/pass_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index d9fa24c2c6..6f9ce8b34b 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -181,7 +181,7 @@ def pass_with_validation( *args, **kwargs, ) -> fx.GraphModule: - res0 = module(*input) + res0 = module(*input) processed_module = pass_(module, input, *args, **kwargs) res1 = processed_module(*input) From b41af829ea1150807155e1118914ef54e1f0c017 Mon Sep 17 00:00:00 2001 From: tinyinl Date: Tue, 20 Jun 2023 23:24:53 +0000 Subject: [PATCH 05/12] revert --- py/torch_tensorrt/fx/passes/pass_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 6f9ce8b34b..0b8578ffba 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -217,7 +217,7 @@ def pass_with_validation( _LOGGER.info( f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" ) - torch.testing.assert_close(x, y, **kwargs2) + torch.testing.assert_close(x, y, **kwargs2) return processed_module else: raise e From 744621c1d683c7b326aff788258901bcee908dba Mon Sep 17 00:00:00 2001 From: tinyinl Date: 
From 744621c1d683c7b326aff788258901bcee908dba Mon Sep 17 00:00:00 2001
From: tinyinl
Date: Thu, 22 Jun 2023 00:38:25 +0000
Subject: [PATCH 06/12] format

---
 py/torch_tensorrt/fx/fx2trt.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 95dc784cde..bb20be5b64 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -82,6 +82,7 @@ def __init__(
         ] = dict()
         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes
+
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
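[Note: the engine-dump hook that PATCH 01/12 introduces in TRTInterpreter.run() (and PATCH 02/12 removes again) is controlled entirely through environment variables. A sketch of triggering it while that code is present — both paths are placeholders:

    # Sketch only: enabling the TORCH_FX_DUMP_ENGINE hook from PATCH 01/12.
    import os

    os.environ["TORCH_FX_DUMP_ENGINE"] = "/tmp/split_mod.engine"
    os.environ["TORCH_FX_DUMP_ENGINE_INFO"] = "/tmp/split_mod_info.json"
    # Each engine built afterwards is serialized to a fresh file (a numeric
    # suffix is appended if the name already exists), and the engine
    # inspector JSON is written alongside it.
]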