From cd250ac9d24a15eba371a295beaa0cdcec7ff358 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 2 Feb 2024 14:57:28 -0800 Subject: [PATCH 1/2] fix: Linter config fix --- examples/int8/training/vgg16/vgg16.py | 4 +- py/torch_tensorrt/dynamo/_compiler.py | 9 ++- .../runtime/_PythonTorchTensorRTModule.py | 57 ++++++++++++------- .../fx/converters/aten_ops_converters.py | 40 ++++++------- py/torch_tensorrt/fx/fx2trt.py | 2 +- py/torch_tensorrt/fx/lower.py | 23 ++++---- .../fx/passes/lower_basic_pass.py | 1 - .../fx/passes/lower_pass_manager_builder.py | 23 ++++---- .../fx/test/converters/acc_op/test_split.py | 16 ++++-- py/torch_tensorrt/fx/tools/common_fx2trt.py | 7 +-- py/torch_tensorrt/fx/trt_module.py | 18 +++--- py/torch_tensorrt/ts/_compile_spec.py | 9 ++- pyproject.toml | 2 +- 13 files changed, 116 insertions(+), 95 deletions(-) diff --git a/examples/int8/training/vgg16/vgg16.py b/examples/int8/training/vgg16/vgg16.py index b371b8e243..379306114b 100644 --- a/examples/int8/training/vgg16/vgg16.py +++ b/examples/int8/training/vgg16/vgg16.py @@ -3,10 +3,12 @@ - [Very Deep Convolutional Networks for Large-Scale Image Recognition]( https://arxiv.org/abs/1409.1556) (ICLR 2015) """ + +from functools import reduce + import torch import torch.nn as nn import torch.nn.functional as F -from functools import reduce class VGG(nn.Module): diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index e705c069d5..be96853e58 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,6 +5,7 @@ from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch +import torch_tensorrt from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt import _enums @@ -66,8 +67,6 @@ to_torch_tensorrt_device, ) -import torch_tensorrt - logger = logging.getLogger(__name__) @@ -217,9 +216,9 @@ def compile( "device": device, "workspace_size": workspace_size, "min_block_size": min_block_size, - "torch_executed_ops": torch_executed_ops - if torch_executed_ops is not None - else set(), + "torch_executed_ops": ( + torch_executed_ops if torch_executed_ops is not None else set() + ), "pass_through_build_failures": pass_through_build_failures, "max_aux_streams": max_aux_streams, "version_compatible": version_compatible, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index db45609123..3a66ed3716 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -6,6 +6,7 @@ import tensorrt as trt import torch +import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt.dynamo.runtime.tools import ( @@ -15,8 +16,6 @@ ) from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter -import torch_tensorrt - logger = logging.getLogger(__name__) @@ -101,9 +100,11 @@ def _initialize(self) -> None: for idx in self.output_binding_indices_in_order ] self.output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.output_binding_indices_in_order ] self.hidden_output_dtypes = [ @@ -113,9 +114,11 @@ def _initialize(self) -> None: for idx in self.hidden_output_binding_indices_in_order ] self.hidden_output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.hidden_output_binding_indices_in_order ] @@ -167,9 +170,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self.context = self.engine.create_execution_context() def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:Forward" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") + if self.profiling_enabled + else nullcontext() + ): self._check_initialized() # If in safe mode, check at each iteration for for whether a switch is required @@ -200,9 +205,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . inputs = tuple([tensor.to(device) for tensor in inputs]) logger.warning(f"Moved all input Tensors to cuda:{device_id}") - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessInputs" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): assert len(inputs) == len( self.input_names ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." @@ -239,9 +248,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . idx, tuple(contiguous_inputs[i].shape) ) - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessOutputs" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): # create output tensors outputs: List[torch.Tensor] = [] @@ -266,9 +279,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . ) bindings[idx] = output.data_ptr() - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:TensorRTRuntime" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): self.context.execute_async_v2( bindings, torch.cuda.current_stream().cuda_stream ) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 17c19eda33..b639ea8ce9 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -3,24 +3,22 @@ import math import operator import warnings -from typing import cast, Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple, Union, cast import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils +from torch.fx.immutable_collections import immutable_list +from torch.fx.node import Argument, Target from torch_tensorrt.fx.converters import acc_ops_converters +from torch_tensorrt.fx.converters.impl import activation, convolution from ..converter_registry import tensorrt_converter - from ..types import * # noqa: F403 -from torch.fx.immutable_collections import immutable_list -from torch.fx.node import Argument, Target - from .converter_utils import * # noqa: F403 -import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -from torch_tensorrt.fx.converters.impl import activation, convolution _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -317,21 +315,19 @@ def aten_ops_max_poolnd( kwargs_new = { "input": args[0], "kernel_size": args[1], - "stride": args[2] - if len(args) > 2 - else (None, None) - if len(args[1]) == 2 - else (None, None, None), - "padding": args[3] - if len(args) > 3 - else (0, 0) - if len(args[1]) == 2 - else (0, 0, 0), - "dilation": args[4] - if len(args) > 4 - else (1, 1) - if len(args[1]) == 2 - else (1, 1, 1), + "stride": ( + args[2] + if len(args) > 2 + else (None, None) + if len(args[1]) == 2 + else (None, None, None) + ), + "padding": ( + args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0) + ), + "dilation": ( + args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1) + ), "ceil_mode": args[5] if len(args) > 5 else False, } return acc_ops_converters.acc_ops_max_poolnd( diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index d7ef976fba..a7df4d10e1 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -17,7 +17,7 @@ from .converter_registry import CONVERTERS from .input_tensor_spec import InputTensorSpec from .observer import Observer -from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks +from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 5f66519e05..fa148ce6cb 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -16,7 +16,6 @@ from .passes.pass_utils import PassFunc, validate_inference from .tools.timing_cache_utils import TimingCacheManager from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting - from .tracer.acc_tracer import acc_tracer from .trt_module import TRTModule from .utils import LowerPrecision @@ -126,9 +125,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: input_specs=self.lower_setting.input_specs, explicit_batch_dimension=self.lower_setting.explicit_batch_dimension, explicit_precision=self.lower_setting.explicit_precision, - logger_level=trt.Logger.VERBOSE - if self.lower_setting.verbose_log - else trt.Logger.WARNING, + logger_level=( + trt.Logger.VERBOSE + if self.lower_setting.verbose_log + else trt.Logger.WARNING + ), ) interp_result: TRTInterpreterResult = interpreter.run( @@ -138,9 +139,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: strict_type_constraints=self.lower_setting.strict_type_constraints, algorithm_selector=algo_selector, timing_cache=cache_data, - profiling_verbosity=trt.ProfilingVerbosity.DETAILED - if self.lower_setting.verbose_profile - else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, + profiling_verbosity=( + trt.ProfilingVerbosity.DETAILED + if self.lower_setting.verbose_profile + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ), tactic_sources=self.lower_setting.tactic_sources, ) @@ -297,10 +300,8 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module: # handle inputs with custom types. By default, just handle # tensors and NoneType. if fp16_conversion_fn is None: - conversion_fn = ( - lambda x: x.half() - if x is not None and x.dtype == torch.float32 - else x + conversion_fn = lambda x: ( + x.half() if x is not None and x.dtype == torch.float32 else x ) else: conversion_fn = fp16_conversion_fn diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index b203bc82e0..4a55e294a5 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -11,7 +11,6 @@ from torch.fx.experimental.const_fold import split_const_subgraphs from ..observer import observable - from ..tracer.acc_tracer import acc_ops from ..tracer.acc_tracer.acc_utils import get_attr from .pass_utils import log_before_after, validate_inference diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 6e6b40d42f..8f3cc576ec 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -5,19 +5,17 @@ import torch from torch import nn -from torch.fx.passes.pass_manager import inplace_wrapper, PassManager +from torch.fx.passes.pass_manager import PassManager, inplace_wrapper from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult +from torch.fx.passes.splitter_base import SplitResult, generate_inputs_for_submodules from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion from torch_tensorrt.fx.utils import LowerPrecision from ..input_tensor_spec import generate_input_specs - from ..lower_setting import LowerSetting from ..observer import Observer from ..passes.remove_duplicate_output_args import remove_duplicate_output_args from .graph_opts import common_subexpression_elimination - from .lower_basic_pass import ( # noqa fix_clamp_numerical_limits_to_fp16, fix_reshape_batch_dim, @@ -26,7 +24,6 @@ run_const_fold, ) - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -196,9 +193,11 @@ def lower_func(split_result: SplitResult) -> nn.Module: self.lower_setting.input_specs = generate_input_specs( submod_inputs, self.lower_setting, - additional_submodule_inputs[submod_name] - if additional_submodule_inputs - else None, + ( + additional_submodule_inputs[submod_name] + if additional_submodule_inputs + else None + ), ) lowered_module = self._lower_func( submod, submod_inputs, self.lower_setting, submod_name @@ -236,9 +235,11 @@ def lower_func(split_result: SplitResult) -> nn.Module: lowering_start_time = datetime.datetime.now() self.lower_setting.additional_inputs = ( - additional_submodule_inputs[submod_name] - if additional_submodule_inputs - else None, + ( + additional_submodule_inputs[submod_name] + if additional_submodule_inputs + else None + ), ) lowered_module = self._lower_func( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py index 29d174d9fd..cf49d028ae 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py @@ -23,9 +23,11 @@ def forward(self, x): Split(), inputs, expected_ops={ - acc_ops.split - if isinstance(split_size_or_sections, int) - else acc_ops.slice_tensor + ( + acc_ops.split + if isinstance(split_size_or_sections, int) + else acc_ops.slice_tensor + ) }, test_explicit_batch_dim=False, ) @@ -70,9 +72,11 @@ def forward(self, x): Split(), input_specs, expected_ops={ - acc_ops.split - if isinstance(split_size_or_sections, int) - else acc_ops.slice_tensor + ( + acc_ops.split + if isinstance(split_size_or_sections, int) + else acc_ops.slice_tensor + ) }, ) diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 6d883a4f62..2ddd832c2a 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -7,7 +7,6 @@ import tensorrt as trt import torch import torch.fx - import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer from torch.fx.experimental.normalize import NormalizeArgs @@ -154,9 +153,9 @@ def run_test_custom_compare_results( self.assert_has_op(mod, expected_ops) interpreter_result = interpreter.run( - lower_precision=LowerPrecision.FP16 - if fp16_mode - else LowerPrecision.FP32 + lower_precision=( + LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32 + ) ) trt_mod = TRTModule( interpreter_result.engine, diff --git a/py/torch_tensorrt/fx/trt_module.py b/py/torch_tensorrt/fx/trt_module.py index ab2d9ac348..c5bab21353 100644 --- a/py/torch_tensorrt/fx/trt_module.py +++ b/py/torch_tensorrt/fx/trt_module.py @@ -4,7 +4,7 @@ import tensorrt as trt import torch -from .utils import unified_dtype_converter, Frameworks +from .utils import Frameworks, unified_dtype_converter class TRTModule(torch.nn.Module): @@ -69,9 +69,11 @@ def _initialize(self): for idx in self.output_binding_indices_in_order ] self.output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.output_binding_indices_in_order ] self.hidden_output_dtypes: Sequence[torch.dtype] = [ @@ -81,9 +83,11 @@ def _initialize(self): for idx in self.hidden_output_binding_indices_in_order ] self.hidden_output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.hidden_output_binding_indices_in_order ] diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index b9a84152e1..37f5fb79e3 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -3,6 +3,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Set +import tensorrt as trt import torch import torch_tensorrt._C.ts as _ts_C from torch_tensorrt import _C, _enums @@ -11,8 +12,6 @@ from torch_tensorrt.logging import Level, log from torch_tensorrt.ts._Input import TorchScriptInput -import tensorrt as trt - def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: clone = torch.classes.tensorrt._Input() @@ -406,9 +405,9 @@ def TensorRTCompileSpec( "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers. - "enabled_precisions": enabled_precisions - if enabled_precisions is not None - else set(), # Enabling FP16 kernels + "enabled_precisions": ( + enabled_precisions if enabled_precisions is not None else set() + ), # Enabling FP16 kernels "refit": refit, # enable refit "debug": debug, # enable debuggable engine "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels diff --git a/pyproject.toml b/pyproject.toml index c987ac1f40..5c42700ef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,7 +177,7 @@ skip = [ [tool.black] #line-length = 120 -target-versions = ["py38", "py39", "py310", "py311", "py312"] +target-version = ["py38", "py39", "py310", "py311", "py312"] force-exclude = """ elu_converter/setup.py """ From 1e36d99a9bbe5bc6f8ef54b078534b41bf95ab33 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 2 Feb 2024 15:48:47 -0800 Subject: [PATCH 2/2] Additional linting fixes --- .pre-commit-config.yaml | 2 +- py/torch_tensorrt/_Device.py | 10 ++++++---- py/torch_tensorrt/_Input.py | 12 ++++++------ .../dynamo/conversion/_TRTInterpreter.py | 12 ++++++------ .../dynamo/conversion/converter_utils.py | 6 ++---- .../dynamo/lowering/_decomposition_groups.py | 12 ++++++------ .../dynamo/lowering/passes/lower_linear.py | 10 ++++------ .../passes/lower_scaled_dot_product_attention.py | 10 ++++------ .../dynamo/lowering/passes/view_to_reshape.py | 10 ++++------ .../fx/converters/aten_ops_converters.py | 4 +--- py/torch_tensorrt/fx/fx2trt.py | 12 ++++++------ py/torch_tensorrt/fx/passes/lower_basic_pass.py | 6 +++--- py/torch_tensorrt/fx/passes/pass_utils.py | 4 +--- 13 files changed, 50 insertions(+), 60 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4738ea80be..f4ac2ab9e2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,7 +47,7 @@ repos: hooks: - id: ruff - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 24.1.1 hooks: - id: black exclude: ^examples/custom_converters/elu_converter/setup.py|^docs diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 0f8ce1e392..6f20b6c84c 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -32,12 +32,14 @@ class Device(object): allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed """ - device_type: Optional[ - trt.DeviceType - ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. + device_type: Optional[trt.DeviceType] = ( + None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. + ) gpu_id: int = -1 #: Device ID for target GPU dla_core: int = -1 #: Core ID for target DLA core - allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed + allow_gpu_fallback: bool = ( + False #: Whether falling back to GPU if DLA cannot support an op should be allowed + ) def __init__(self, *args: Any, **kwargs: Any): """__init__ Method for torch_tensorrt.Device diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 9acb073c62..db36678d17 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -28,12 +28,12 @@ class _ShapeMode(Enum): STATIC = 0 DYNAMIC = 1 - shape_mode: Optional[ - _ShapeMode - ] = None #: Is input statically or dynamically shaped - shape: Optional[ - Tuple[int, ...] | Dict[str, Tuple[int, ...]] - ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` + shape_mode: Optional[_ShapeMode] = ( + None #: Is input statically or dynamically shaped + ) + shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = ( + None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` + ) dtype: _enums.dtype = ( _enums.dtype.unknown ) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 5db9fc183e..06ae596ed0 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -28,9 +28,9 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[ - Callable[[torch.fx.GraphModule], None] -] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( + Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +) class UnsupportedOperatorException(RuntimeError): @@ -92,9 +92,9 @@ def __init__( self._cur_node: Optional[torch.fx.Node] = None self._input_names: List[str] = [] self._output_names: List[str] = [] - self._itensor_to_tensor_meta: Dict[ - trt.tensorrt.ITensor, TensorMetadata - ] = dict() + self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( + dict() + ) self.compilation_settings = compilation_settings # Data types for TRT Module output Tensors diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index f90c869c15..f9d14917f1 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -324,13 +324,11 @@ def get_trt_tensor( @overload -def get_positive_dim(dim: int, dim_size: int) -> int: - ... +def get_positive_dim(dim: int, dim_size: int) -> int: ... @overload -def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: - ... +def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ... def get_positive_dim( diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index af92a9dc50..de791851db 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -7,9 +7,9 @@ aten = torch.ops.aten -_core_aten_decompositions: Dict[ - OpOverload, Callable[[Any], Any] -] = core_aten_decompositions() +_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = ( + core_aten_decompositions() +) torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = { aten._adaptive_avg_pool2d_backward, aten.addcdiv, @@ -180,9 +180,9 @@ } -ENABLED_TORCH_DECOMPOSITIONS: Dict[ - OpOverload, Callable[[Any], Any] -] = get_torch_decompositions(torch_enabled_decompositions) +ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = ( + get_torch_decompositions(torch_enabled_decompositions) +) TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {} diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py index 75ad067a3f..ef2c0531a6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -22,12 +22,10 @@ def lower_linear( return gm -def linear_replacement() -> ( - Tuple[ - torch.fx.GraphModule, - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], - ] -): +def linear_replacement() -> Tuple[ + torch.fx.GraphModule, + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: """Constructs the original and replacement functions for linear""" # Original graph diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py index 74dee9c0c9..161dbbe9df 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -60,12 +60,10 @@ def lower_scaled_dot_product_attention( return gm -def scaled_dot_product_attention_replacement() -> ( - Tuple[ - Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], - ] -): +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: """Constructs the original and replacement functions for efficient attention""" # Efficient Attention original graph diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py index efc836814f..e2ef051f06 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py @@ -22,12 +22,10 @@ def view_to_reshape( return gm -def view_replacement() -> ( - Tuple[ - torch.fx.GraphModule, - Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor], - ] -): +def view_replacement() -> Tuple[ + torch.fx.GraphModule, + Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor], +]: """Constructs the original and replacement functions for view""" # Original graph diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index b639ea8ce9..f11e40a6db 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -318,9 +318,7 @@ def aten_ops_max_poolnd( "stride": ( args[2] if len(args) > 2 - else (None, None) - if len(args[1]) == 2 - else (None, None, None) + else (None, None) if len(args[1]) == 2 else (None, None, None) ), "padding": ( args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0) diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index a7df4d10e1..6a29932b1b 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -21,9 +21,9 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[ - Callable[[torch.fx.GraphModule], None] -] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( + Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +) class TRTInterpreterResult(NamedTuple): @@ -75,9 +75,9 @@ def __init__( self._cur_node_name: Optional[str] = None self._input_names: List[str] = [] self._output_names: List[str] = [] - self._itensor_to_tensor_meta: Dict[ - trt.tensorrt.ITensor, TensorMetadata - ] = dict() + self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( + dict() + ) def validate_input_specs(self): for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index 4a55e294a5..fb75a3e3c3 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -537,9 +537,9 @@ def get_reshape_batch_size_inferred_source( ) if not reshape_batch_size: continue - reshape_batch_size_inferred_source: Optional[ - fx.Node - ] = get_reshape_batch_size_inferred_source(reshape_batch_size) + reshape_batch_size_inferred_source: Optional[fx.Node] = ( + get_reshape_batch_size_inferred_source(reshape_batch_size) + ) if not reshape_batch_size_inferred_source: continue diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 0b8578ffba..2de5c23aaf 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -195,9 +195,7 @@ def pass_with_validation( kwargs2["rtol"] = rtol if atol: kwargs2["atol"] = atol - kwargs2[ - "msg" - ] = ( + kwargs2["msg"] = ( lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" ) # If tensors are on different devices, make sure to compare