diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
index b5165c6f2d..e4298600cb 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -163,6 +163,9 @@ def run(
         timing_cache=None,
         profiling_verbosity=None,
         tactic_sources=None,
+        max_aux_streams=None,
+        version_compatible=False,
+        optimization_level=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -227,6 +230,18 @@ def run(
             if profiling_verbosity
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         )
+
+        if trt.__version__ >= "8.6":
+            if max_aux_streams is not None:
+                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
+                builder_config.max_aux_streams = max_aux_streams
+            if version_compatible:
+                _LOGGER.info("Using version compatible")
+                builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
+            if optimization_level is not None:
+                _LOGGER.info(f"Using optimization level {optimization_level}")
+                builder_config.builder_optimization_level = optimization_level
+
         if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
 
@@ -264,6 +279,7 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
+        _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of memory")
 
         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
index 60ace0f12a..8131edb540 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
@@ -181,6 +181,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
             tactic_sources=self.lower_setting.tactic_sources,
+            max_aux_streams=self.lower_setting.max_aux_streams,
+            version_compatible=self.lower_setting.version_compatible,
+            optimization_level=self.lower_setting.optimization_level,
         )
 
         # Update timing cache file if needed
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
index 9008bbe8e9..64fa1bf267 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -70,6 +70,9 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: absolute tolerance for correctness check
     correctness_rtol: relative tolerance for correctness check
     use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+    max_aux_streams: max number of aux streams to use
+    version_compatible: enable version compatible feature
+    optimization_level: builder optimization level
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -96,3 +99,6 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
     use_experimental_rt: bool = False
+    max_aux_streams: Optional[int] = None
+    version_compatible: bool = False
+    optimization_level: Optional[int] = None
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py
index 96fa96cfae..7d3046d617 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py
@@ -126,7 +126,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
 def validate_inference(
-    rtol=None, atol=None, device=torch.device(torch.cuda.current_device())
+    rtol=None,
+    atol=None,
+    device=torch.device(torch.cuda.current_device()),
+    suppress_accuracy_check_failure=True,
 ):
     def _validate_inference(pass_: PassFunc) -> PassFunc:
         """
@@ -141,48 +144,51 @@ def pass_with_validation(
             *args,
             **kwargs,
         ) -> fx.GraphModule:
-            input_tensors = extract_example_tensors_from_input(input, device)
-            res0 = module(*input_tensors)
-            processed_module = pass_(module, input, *args, **kwargs)
-            res1 = processed_module(*input_tensors)
-            tensor_res_0 = _collect_tensors(res0)
-            tensor_res_1 = _collect_tensors(res1)
-            relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
-
-            for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
-                kwargs2 = {"equal_nan": True}
-                if rtol:
-                    kwargs2["rtol"] = rtol
-                if atol:
-                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
-                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
-                )
-                # If tensors are on different devices, make sure to compare
-                # their copies that are on the same device.
-                if x.get_device() != y.get_device():
-                    x = x.cpu()
-                    y = y.cpu()
-                try:
-                    torch.testing.assert_close(x, y, **kwargs2)
-                except Exception as e:
-                    if relax_accuracy_check_failure:
-                        _LOGGER.error(f"{e}")
-                        kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
-                        kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
-                        new_atol = kwargs2["atol"]
-                        new_rtol = kwargs2["rtol"]
-                        _LOGGER.info(
-                            f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
-                        )
+            if suppress_accuracy_check_failure:
+                return pass_(module, input, *args, **kwargs)
+            else:
+                input_tensors = extract_example_tensors_from_input(input, device)
+                res0 = module(*input_tensors)
+                processed_module = pass_(module, input, *args, **kwargs)
+                res1 = processed_module(*input_tensors)
+                tensor_res_0 = _collect_tensors(res0)
+                tensor_res_1 = _collect_tensors(res1)
+                relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
+
+                for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
+                    kwargs2 = {"equal_nan": True}
+                    if rtol:
+                        kwargs2["rtol"] = rtol
+                    if atol:
+                        kwargs2["atol"] = atol
+                    kwargs2[
+                        "msg"
+                    ] = (
+                        lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
+                    )
+                    # If tensors are on different devices, make sure to compare
+                    # their copies that are on the same device.
+                    if x.get_device() != y.get_device():
+                        x = x.cpu()
+                        y = y.cpu()
+                    try:
                         torch.testing.assert_close(x, y, **kwargs2)
-                        return processed_module
-                    else:
-                        raise e
-
-            return processed_module
+                    except Exception as e:
+                        if relax_accuracy_check_failure:
+                            _LOGGER.error(f"{e}")
+                            kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
+                            kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
+                            new_atol = kwargs2["atol"]
+                            new_rtol = kwargs2["rtol"]
+                            _LOGGER.info(
+                                f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
+                            )
+                            torch.testing.assert_close(x, y, **kwargs2)
+                            return processed_module
+                        else:
+                            raise e
+
+                return processed_module
 
         return pass_with_validation
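
For reference, a minimal usage sketch for the options this patch introduces (not part of the diff itself). It assumes TensorRT >= 8.6 is installed, and relies only on the LowerSetting fields added above and on the TensorRT 8.6 IBuilderConfig attributes the interpreter forwards to; the concrete values are placeholders.

import tensorrt as trt

from torch_tensorrt.dynamo.fx_ts_compat.lower_setting import LowerSetting

# The three new knobs travel on LowerSetting; the lower.py hunk above
# forwards them to TRTInterpreter.run(), which applies them to the
# builder config.
setting = LowerSetting(
    max_aux_streams=2,        # becomes builder_config.max_aux_streams
    version_compatible=True,  # sets trt.BuilderFlag.VERSION_COMPATIBLE
    optimization_level=5,     # becomes builder_config.builder_optimization_level
)

# Equivalent raw TensorRT calls, mirroring the guarded block added in
# fx2trt.py. The guard is needed because these builder-config attributes
# only exist in TensorRT 8.6 and newer.
if trt.__version__ >= "8.6":
    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    builder_config = builder.create_builder_config()
    builder_config.max_aux_streams = 2
    builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    builder_config.builder_optimization_level = 5

One caveat worth noting: trt.__version__ >= "8.6" is a lexicographic string comparison, so it behaves correctly for 8.6 through 8.9 but would incorrectly skip the block on a release such as "8.10" or "10.0"; a parsed-version comparison would be more robust.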