diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 3743b263db..037294965c 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -14,7 +14,7 @@
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
 )
@@ -35,7 +35,7 @@ def compile(
     debug=DEBUG,
     capability=EngineCapability.default,
     num_avg_timing_iters=1,
-    workspace_size=MAX_WORKSPACE_SIZE,
+    workspace_size=WORKSPACE_SIZE,
     dla_sram_size=1048576,
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
@@ -105,7 +105,7 @@
 def create_backend(
     precision: LowerPrecision = PRECISION,
     debug: bool = DEBUG,
-    workspace_size: int = MAX_WORKSPACE_SIZE,
+    workspace_size: int = WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
@@ -114,10 +114,12 @@ def create_backend(
     """Create torch.compile backend given specified arguments
 
     Args:
-        precision:
-        debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
         precision: Model Layer precision
+        debug: Whether to print out verbose debugging information
+        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size: Minimum number of operators per TRT-Engine Block
+        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
     Returns:
         Backend for torch.compile
     """
diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py
index fe7b5f6b4f..bb34f2dcac 100644
--- a/py/torch_tensorrt/dynamo/backend/_defaults.py
+++ b/py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -3,6 +3,6 @@
 
 PRECISION = LowerPrecision.FP32
 DEBUG = False
-MAX_WORKSPACE_SIZE = 20 << 30
+WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py
index df3212f54a..73bc08a419 100644
--- a/py/torch_tensorrt/dynamo/backend/_settings.py
+++ b/py/torch_tensorrt/dynamo/backend/_settings.py
@@ -5,7 +5,7 @@
 from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
-    MAX_WORKSPACE_SIZE,
+    WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
 )
@@ -15,7 +15,7 @@
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
-    workspace_size: int = MAX_WORKSPACE_SIZE
+    workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index 1644dea547..f359020bfb 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -3,7 +3,7 @@
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt import TRTModuleNext
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
-from torch_tensorrt.fx.fx2trt import (
+from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
     InputTensorSpec,
     TRTInterpreter,
 )
@@ -24,15 +24,15 @@ def convert_module(
 
     Returns:
         TRTModule or TRTModuleNext
     """
-    interp = TRTInterpreter(
+    interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
     )
-    r = interp.run(
-        max_workspace_size=settings.workspace_size,
+    interpreter_result = interpreter.run(
+        workspace_size=settings.workspace_size,
         lower_precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
@@ -41,4 +41,8 @@ def convert_module(
         ),
     )
 
-    return TRTModule(*r)
+    return TRTModule(
+        engine=interpreter_result.engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+    )