fix: Linter + config fix #2636


Merged · 2 commits · Feb 5, 2024

Changes from all commits

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -47,7 +47,7 @@ repos:
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.1.1
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
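
Every reformatting hunk below follows from this one-line bump: black 24.1.1 is the first release of the 2024 stable style, which, among other changes, parenthesizes the right-hand side of an annotated assignment instead of splitting inside the annotation's brackets (the split only triggers once a line exceeds the 88-column limit). A minimal before/after sketch, using a hypothetical `timeout` attribute rather than code from this repo:

from typing import Optional

# black 23.x split inside the annotation's subscript:
timeout: Optional[
    float
] = None  # seconds; None disables the timeout

# black 24.1.1 keeps the annotation on one line and parenthesizes the value:
timeout: Optional[float] = (
    None  # seconds; None disables the timeout
)

Re-running `pre-commit run --all-files` after a bump like this regenerates the churn mechanically, which is presumably how the remaining hunks were produced.
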
4 changes: 3 additions & 1 deletion examples/int8/training/vgg16/vgg16.py
@@ -3,10 +3,12 @@
 - [Very Deep Convolutional Networks for Large-Scale Image Recognition](
 https://arxiv.org/abs/1409.1556) (ICLR 2015)
 """
+
+from functools import reduce
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import reduce
 
 
 class VGG(nn.Module):
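
This hunk is import ordering rather than black: the standard-library `functools` import moves into its own group ahead of the third-party `torch` imports, matching isort-style grouping (the ruff hook above can enforce this via its `I` rules, though that depends on the repo's ruff configuration). A minimal sketch of the convention:

# Standard-library imports form the first group...
from functools import reduce

# ...and third-party imports follow after one blank line.
import torch

# e.g. multiply out a tensor's dimensions with the imported helper
print(reduce(lambda a, b: a * b, torch.ones(2, 3).shape))  # 6
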
10 changes: 6 additions & 4 deletions py/torch_tensorrt/_Device.py
@@ -32,12 +32,14 @@ class Device(object):
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """
 
-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
     gpu_id: int = -1  #: Device ID for target GPU
     dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )
 
     def __init__(self, *args: Any, **kwargs: Any):
         """__init__ Method for torch_tensorrt.Device
12 changes: 6 additions & 6 deletions py/torch_tensorrt/_Input.py
@@ -28,12 +28,12 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1
 
-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
     dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
9 changes: 4 additions & 5 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -5,6 +5,7 @@
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
+import torch_tensorrt
 from torch.export import ExportedProgram
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -66,8 +67,6 @@
     to_torch_tensorrt_device,
 )
 
-import torch_tensorrt
-
 logger = logging.getLogger(__name__)
 
 
@@ -217,9 +216,9 @@ def compile(
         "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
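
The `torch_executed_ops` hunk shows the 2024 style's handling of conditional expressions: when one has to wrap, black now parenthesizes the whole expression instead of letting `if`/`else` continue at the same indent as the dict key. A standalone sketch with hypothetical names (lines this short would really stay unsplit; they are shown expanded for clarity):

user_ops = None  # hypothetical user-supplied set of op names

# black 23.x let the conditional dangle under the key:
options = {
    "torch_executed_ops": user_ops
    if user_ops is not None
    else set(),
}

# black 24.1.1 parenthesizes the conditional when it must split:
options = {
    "torch_executed_ops": (
        user_ops if user_ops is not None else set()
    ),
}
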
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -28,9 +28,9 @@
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)
 
 
 class UnsupportedOperatorException(RuntimeError):
@@ -92,9 +92,9 @@ def __init__(
         self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
         self.compilation_settings = compilation_settings
 
         # Data types for TRT Module output Tensors
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -324,13 +324,11 @@ def get_trt_tensor(
 
 
 @overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...
 
 
 @overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...
 
 
 def get_positive_dim(
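
The 2024 style also formats dummy implementations compactly: a body consisting only of `...` now sits on the `def` line itself, which is most visible on `typing.overload` stubs like the ones above. A self-contained sketch (a hypothetical `clamp` helper, not this module's API):

from typing import Sequence, Tuple, Union, overload


@overload
def clamp(value: int, bound: int) -> int: ...


@overload
def clamp(value: Sequence[int], bound: int) -> Tuple[int, ...]: ...


def clamp(
    value: Union[int, Sequence[int]], bound: int
) -> Union[int, Tuple[int, ...]]:
    """Clamp an int, or each element of a sequence, into [0, bound)."""
    if isinstance(value, int):
        return max(0, min(value, bound - 1))
    return tuple(max(0, min(v, bound - 1)) for v in value)
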
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -7,9 +7,9 @@
 
 aten = torch.ops.aten
 
-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
 torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten._adaptive_avg_pool2d_backward,
     aten.addcdiv,
@@ -180,9 +180,9 @@
 }
 
 
-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
 TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
 
 
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -22,12 +22,10 @@ def lower_linear(
     return gm
 
 
-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for linear"""
 
     # Original graph
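
Here the 2024 style prefers splitting inside the return annotation's brackets over wrapping the whole annotation in parentheses, so the `-> (` ... `):` shell disappears. A runnable sketch with hypothetical names:

from typing import Callable, Tuple

import torch


# black 23.x parenthesized the entire return annotation; black 24.1.1
# splits inside the Tuple[...] subscript instead, as in the hunk above.
def build_patterns() -> Tuple[
    torch.fx.GraphModule,
    Callable[[torch.Tensor], torch.Tensor],
]:
    """Return a traced identity module and a matching callable (toy example)."""
    gm = torch.fx.symbolic_trace(torch.nn.Identity())
    return gm, lambda x: x
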
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -60,12 +60,10 @@ def lower_scaled_dot_product_attention(
     return gm
 
 
-def scaled_dot_product_attention_replacement() -> (
-    Tuple[
-        Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for efficient attention"""
 
     # Efficient Attention original graph
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -22,12 +22,10 @@ def view_to_reshape(
     return gm
 
 
-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
     """Constructs the original and replacement functions for view"""
 
     # Original graph
57 changes: 37 additions & 20 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -6,6 +6,7 @@
 
 import tensorrt as trt
 import torch
+import torch_tensorrt
 from torch.nn import Module
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo.runtime.tools import (
@@ -15,8 +16,6 @@
 )
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-import torch_tensorrt
-
 logger = logging.getLogger(__name__)
 
 
@@ -101,9 +100,11 @@ def _initialize(self) -> None:
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.output_binding_indices_in_order
         ]
         self.hidden_output_dtypes = [
@@ -113,9 +114,11 @@ def _initialize(self) -> None:
             for idx in self.hidden_output_binding_indices_in_order
         ]
         self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.hidden_output_binding_indices_in_order
         ]
 
@@ -167,9 +170,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         self.context = self.engine.create_execution_context()
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             self._check_initialized()
 
             # If in safe mode, check at each iteration for for whether a switch is required
@@ -200,9 +205,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             inputs = tuple([tensor.to(device) for tensor in inputs])
             logger.warning(f"Moved all input Tensors to cuda:{device_id}")
 
-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -239,9 +248,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                         idx, tuple(contiguous_inputs[i].shape)
                     )
 
-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
@@ -266,9 +279,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
                 bindings[idx] = output.data_ptr()
 
-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
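
The four `forward` hunks apply the 2024 style's parenthesized `with` statements: the conditional context manager moves inside parentheses so the `if`/`else` arms indent under `with`. Because there is a single context manager here, the parentheses are ordinary expression grouping and remain valid on older Python versions. A runnable sketch under that assumption:

from contextlib import nullcontext

import torch

profiling_enabled = False  # hypothetical flag mirroring self.profiling_enabled

# black 24.1.1 formatting of a conditionally enabled profiling scope:
with (
    torch.autograd.profiler.record_function("Example:Forward")
    if profiling_enabled
    else nullcontext()
):
    x = torch.ones(2, 2) @ torch.ones(2, 2)  # stand-in for the real work
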
38 changes: 16 additions & 22 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -3,24 +3,22 @@
 import math
 import operator
 import warnings
-from typing import cast, Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
+from torch.fx.immutable_collections import immutable_list
+from torch.fx.node import Argument, Target
 from torch_tensorrt.fx.converters import acc_ops_converters
+from torch_tensorrt.fx.converters.impl import activation, convolution
 
 from ..converter_registry import tensorrt_converter
-
 from ..types import *  # noqa: F403
-from torch.fx.immutable_collections import immutable_list
-from torch.fx.node import Argument, Target
-
 from .converter_utils import *  # noqa: F403
-import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-from torch_tensorrt.fx.converters.impl import activation, convolution
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -317,21 +315,17 @@ def aten_ops_max_poolnd(
     kwargs_new = {
         "input": args[0],
         "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
         "ceil_mode": args[5] if len(args) > 5 else False,
     }
     return acc_ops_converters.acc_ops_max_poolnd(
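
The `max_poolnd` hunk combines the conditional-parenthesization rule with the 2024 style's layout for nested conditionals, whose `else ... if ... else ...` tail now folds onto one line when it fits. A sketch with a hypothetical stride-defaulting helper in the same shape as the code above:

from typing import Optional, Sequence, Tuple, Union


def default_stride(
    args: Sequence[object], kernel_size: Sequence[int]
) -> Union[object, Tuple[Optional[int], ...]]:
    """Return an explicit stride if given, else a 2D or 3D all-None default."""
    return (
        args[2]
        if len(args) > 2
        else (None, None) if len(kernel_size) == 2 else (None, None, None)
    )


print(default_stride([0, (3, 3)], (3, 3)))  # -> (None, None)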