diff --git a/docs/_downloads/0e30a6276601af7e5fc4d5166e2e3d37/torch_compile_advanced_usage.py b/docs/_downloads/0e30a6276601af7e5fc4d5166e2e3d37/torch_compile_advanced_usage.py
index 8ebedab111..af7d4b212d 100644
--- a/docs/_downloads/0e30a6276601af7e5fc4d5166e2e3d37/torch_compile_advanced_usage.py
+++ b/docs/_downloads/0e30a6276601af7e5fc4d5166e2e3d37/torch_compile_advanced_usage.py
@@ -4,7 +4,8 @@
 Torch Compile Advanced Usage
 ======================================================
 
-This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/2a9ac10f2667047a7f398d1593b7ca33/torch_export_gpt2.py b/docs/_downloads/2a9ac10f2667047a7f398d1593b7ca33/torch_export_gpt2.py
index cea0f3adf2..4d34c58de4 100644
--- a/docs/_downloads/2a9ac10f2667047a7f398d1593b7ca33/torch_export_gpt2.py
+++ b/docs/_downloads/2a9ac10f2667047a7f398d1593b7ca33/torch_export_gpt2.py
@@ -4,7 +4,8 @@
 Compiling GPT2 using the dynamo backend
 ==========================================================
 
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/3d4d74f6636d986f33167154f6553961/torch_export_cudagraphs.py b/docs/_downloads/3d4d74f6636d986f33167154f6553961/torch_export_cudagraphs.py
index 1671c7783d..fb31766b7c 100644
--- a/docs/_downloads/3d4d74f6636d986f33167154f6553961/torch_export_cudagraphs.py
+++ b/docs/_downloads/3d4d74f6636d986f33167154f6553961/torch_export_cudagraphs.py
@@ -4,7 +4,8 @@
 Torch Export with Cudagraphs
 ======================================================
 
-This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well."""
+This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py b/docs/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py
index 797e41f5fd..5826e28d1e 100644
--- a/docs/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py
+++ b/docs/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py
@@ -4,7 +4,8 @@
 Compiling ResNet using the Torch-TensorRT Dyanmo Frontend
 ==========================================================
 
-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/7b7004dc2ea6f839be532665e16e0426/torch_export_llama2.py b/docs/_downloads/7b7004dc2ea6f839be532665e16e0426/torch_export_llama2.py
index 5cfd1ed61c..2f3e3cba43 100644
--- a/docs/_downloads/7b7004dc2ea6f839be532665e16e0426/torch_export_llama2.py
+++ b/docs/_downloads/7b7004dc2ea6f839be532665e16e0426/torch_export_llama2.py
@@ -4,7 +4,8 @@
 Compiling Llama2 using the dynamo backend
 ==========================================================
 
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/d6e1bb6ec5f884994554d9d12e37a0f6/torch_compile_resnet_example.py b/docs/_downloads/d6e1bb6ec5f884994554d9d12e37a0f6/torch_compile_resnet_example.py
index f852d60158..fb75986099 100644
--- a/docs/_downloads/d6e1bb6ec5f884994554d9d12e37a0f6/torch_compile_resnet_example.py
+++ b/docs/_downloads/d6e1bb6ec5f884994554d9d12e37a0f6/torch_compile_resnet_example.py
@@ -4,7 +4,8 @@
 Compiling ResNet with dynamic shapes using the `torch.compile` backend
 ==========================================================
 
-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/dfa60e8f9850fd7761f3e7da81304d32/torch_compile_transformers_example.py b/docs/_downloads/dfa60e8f9850fd7761f3e7da81304d32/torch_compile_transformers_example.py
index 221ecd4fd1..17cf46e8a3 100644
--- a/docs/_downloads/dfa60e8f9850fd7761f3e7da81304d32/torch_compile_transformers_example.py
+++ b/docs/_downloads/dfa60e8f9850fd7761f3e7da81304d32/torch_compile_transformers_example.py
@@ -4,7 +4,8 @@
 Compiling BERT using the `torch.compile` backend
 ==============================================================
 
-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py b/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
index f73bd1e780..3fb63e8a32 100644
--- a/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
+++ b/docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
@@ -4,7 +4,8 @@
 Dynamo Compile Advanced Usage
 ======================================================
 
-This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py b/docs/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py
index dd7fe2e07a..59319078a4 100644
--- a/docs/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py
+++ b/docs/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py
@@ -4,7 +4,8 @@
 Compiling a Transformer using torch.compile and TensorRT
 ==============================================================
 
-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/v1.4.0/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py b/docs/v1.4.0/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py
index 797e41f5fd..5826e28d1e 100644
--- a/docs/v1.4.0/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py
+++ b/docs/v1.4.0/_downloads/418941399c146271a7b7728ba3059960/dynamo_compile_resnet_example.py
@@ -4,7 +4,8 @@
 Compiling ResNet using the Torch-TensorRT Dyanmo Frontend
 ==========================================================
 
-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a ResNet model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py b/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
index f73bd1e780..3fb63e8a32 100644
--- a/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
+++ b/docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py
@@ -4,7 +4,8 @@
 Dynamo Compile Advanced Usage
 ======================================================
 
-This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.dynamo.compile` works, and how it integrates with the new `torch.compile` API.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/docs/v1.4.0/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py b/docs/v1.4.0/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py
index dd7fe2e07a..59319078a4 100644
--- a/docs/v1.4.0/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py
+++ b/docs/v1.4.0/_downloads/e550c5f53cc43e11aa6da8cfb79b54df/dynamo_compile_transformers_example.py
@@ -4,7 +4,8 @@
 Compiling a Transformer using torch.compile and TensorRT
 ==============================================================
 
-This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model."""
+This interactive script is intended as a sample of the `torch_tensorrt.dynamo.compile` workflow on a transformer-based model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/examples/dynamo/torch_compile_advanced_usage.py b/examples/dynamo/torch_compile_advanced_usage.py
index 8ebedab111..af7d4b212d 100644
--- a/examples/dynamo/torch_compile_advanced_usage.py
+++ b/examples/dynamo/torch_compile_advanced_usage.py
@@ -4,7 +4,8 @@
 Torch Compile Advanced Usage
 ======================================================
 
-This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API."""
+This interactive script is intended as an overview of the process by which `torch_tensorrt.compile(..., ir="torch_compile", ...)` works, and how it integrates with the `torch.compile` API.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/examples/dynamo/torch_compile_resnet_example.py b/examples/dynamo/torch_compile_resnet_example.py
index f852d60158..fb75986099 100644
--- a/examples/dynamo/torch_compile_resnet_example.py
+++ b/examples/dynamo/torch_compile_resnet_example.py
@@ -4,7 +4,8 @@
 Compiling ResNet with dynamic shapes using the `torch.compile` backend
 ==========================================================
 
-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/examples/dynamo/torch_compile_transformers_example.py b/examples/dynamo/torch_compile_transformers_example.py
index 221ecd4fd1..17cf46e8a3 100644
--- a/examples/dynamo/torch_compile_transformers_example.py
+++ b/examples/dynamo/torch_compile_transformers_example.py
@@ -4,7 +4,8 @@
 Compiling BERT using the `torch.compile` backend
 ==============================================================
 
-This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model."""
+This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a BERT model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/examples/dynamo/torch_export_cudagraphs.py b/examples/dynamo/torch_export_cudagraphs.py
index 1671c7783d..fb31766b7c 100644
--- a/examples/dynamo/torch_export_cudagraphs.py
+++ b/examples/dynamo/torch_export_cudagraphs.py
@@ -4,7 +4,8 @@
 Torch Export with Cudagraphs
 ======================================================
 
-This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well."""
+This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py
index cea0f3adf2..4d34c58de4 100644
--- a/examples/dynamo/torch_export_gpt2.py
+++ b/examples/dynamo/torch_export_gpt2.py
@@ -4,7 +4,8 @@
 Compiling GPT2 using the dynamo backend
 ==========================================================
 
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py
index 5cfd1ed61c..2f3e3cba43 100644
--- a/examples/dynamo/torch_export_llama2.py
+++ b/examples/dynamo/torch_export_llama2.py
@@ -4,7 +4,8 @@
 Compiling Llama2 using the dynamo backend
 ==========================================================
 
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model."""
+This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model.
+"""
 
 # %%
 # Imports and Model Definition
diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 126219ee8a..2f953094ca 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -261,7 +261,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
 
     @staticmethod
     def _parse_tensor_domain(
-        domain: Optional[Tuple[float, float]]
+        domain: Optional[Tuple[float, float]],
     ) -> Tuple[float, float]:
         """
         Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index eaefb68ce5..c706c345d6 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -1200,7 +1200,7 @@ def _from(
 
     @classmethod
     def try_from(
-        c: Union[trt.EngineCapability, EngineCapability]
+        c: Union[trt.EngineCapability, EngineCapability],
     ) -> Optional[EngineCapability]:
         """Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.
 
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py b/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py
index 9a1189e44a..9b2755f4c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTBuilderMonitor.py
@@ -53,13 +53,13 @@ def _redraw(self, *, blank_lines: int = 0) -> None:
         if self._render:
 
             def clear_line() -> None:
-                print("\x1B[2K", end="")
+                print("\x1b[2K", end="")
 
             def move_to_start_of_line() -> None:
-                print("\x1B[0G", end="")
+                print("\x1b[0G", end="")
 
             def move_cursor_up(lines: int) -> None:
-                print("\x1B[{}A".format(lines), end="")
+                print("\x1b[{}A".format(lines), end="")
 
             def progress_bar(steps: int, num_steps: int) -> str:
                 INNER_WIDTH = 10
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
index a563118526..eb981f2031 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py
@@ -247,7 +247,7 @@ def hard_sigmoid(
     operation_type = trt.ActivationType.HARD_SIGMOID
 
     def hard_sigmoid_dyn_range_fn(
-        dyn_range: Tuple[float, float]
+        dyn_range: Tuple[float, float],
     ) -> Tuple[float, float]:
         def hard_sigmoid_fn(x: float) -> float:
             return max(0, min(1, alpha * x + beta))
@@ -310,7 +310,7 @@ def thresholded_relu(
     operation_type = trt.ActivationType.THRESHOLDED_RELU
 
     def thresholded_relu_dyn_range_fn(
-        dyn_range: Tuple[float, float]
+        dyn_range: Tuple[float, float],
     ) -> Tuple[float, float]:
         def thresholded_relu_fn(x: float) -> float:
             return x if x > alpha else 0
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 5d6807f33a..6e5bece418 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -465,7 +465,7 @@ def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch
 
 
 def to_torch_tensorrt_device(
-    device: Optional[Union[Device, torch.device, str]]
+    device: Optional[Union[Device, torch.device, str]],
 ) -> Device:
     """Cast a device-type to torch_tensorrt.Device
 
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py
index 72fea70265..1e14b50305 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_where.py
@@ -101,7 +101,7 @@ def __init__(self, x_shape, y_shape):
             def forward(self, condition):
                 return torch.where(condition, self.x, self.y)
 
-        inputs = [(torch.randn(condition_shape) > 0)]
+        inputs = [torch.randn(condition_shape) > 0]
         self.run_test(
             Where(x_shape, y_shape),
             inputs,
diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py
index 9d5576bd63..c8db1b62ef 100644
--- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py
+++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py
@@ -10,7 +10,6 @@
 from typing import (
     Any,
     Callable,
-    cast,
     Dict,
     Iterable,
     Optional,
@@ -19,6 +18,7 @@
     Tuple,
     Type,
     Union,
+    cast,
 )
 
 import torch
@@ -32,7 +32,6 @@
 
 from . import acc_normalizer, acc_ops, acc_shape_prop, acc_utils  # noqa: F401
 
-
 _LOGGER = logging.getLogger(__name__)
 
 
@@ -517,7 +516,7 @@ def _replace_transpose_last_dims_impl(
         changed = False
 
         def _calculate_dim(
-            transpose_dim: Union[torch.fx.Node, int]
+            transpose_dim: Union[torch.fx.Node, int],
         ) -> Union[torch.fx.Node, int]:
             nonlocal transpose_input_node
             nonlocal changed
diff --git a/tests/py/dynamo/partitioning/test_flaky_global_partitioning.py b/tests/py/dynamo/partitioning/test_flaky_global_partitioning.py
new file mode 100644
index 0000000000..2e2013d5e6
--- /dev/null
+++ b/tests/py/dynamo/partitioning/test_flaky_global_partitioning.py
@@ -0,0 +1,108 @@
+from copy import deepcopy
+
+import numpy as np
+import pytest
+import torch
+import torch.nn.functional as F
+import torch_tensorrt
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo import partitioning
+
+from ..testing_utilities import lower_graph_testing
+
+# Note: the following tests were originally part of test_global_partitioning.py but were
+# flaky when the full suite was run, so they were moved into this separate file. The
+# partitioned graphs differ depending on whether these tests run alongside the rest of
+# test_global_partitioning.py or on their own; pytest does not run tests in parallel by
+# default, so the cause of this behavior is currently unknown. When run independently,
+# the partitioned graph is structurally correct and resembles the fast partitioning output.
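+#
+# To observe the standalone behavior described above, run only this file, e.g. via a
+# standard pytest invocation (adjust the path to your checkout):
+#   python -m pytest tests/py/dynamo/partitioning/test_flaky_global_partitioning.py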
+
+
+class TestGlobalPartitioning(TestCase):
+    def test_partition_partially_supported_multi_op(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_1 = torch.ops.aten.add.Tensor(x, y)
+                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
+                sum_ = np.sum(sum_1) + np.sum(sum_2)
+                relu_ = torch.ops.aten.relu.default(sum_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        partitioned_graph, _ = partitioning.global_partition(
+            deepcopy(fx_graph), min_block_size=2
+        )
+        self.assertEqual(
+            len(list(partitioned_graph.named_children())),
+            2,
+            "Unsupported operators interleave supported ones, expected 2 segments",
+        )
+
+    def test_partition_partially_supported_with_torch_executed_ops(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, y):
+                sum_1 = torch.ops.aten.add.Tensor(x, y)
+                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
+                sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2)
+                relu_ = torch.ops.aten.relu.default(sum_)
+                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
+                return pow_
+
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(
+                1,
+                10,
+                (5,),
+            ),
+            torch.randint(
+                1,
+                10,
+                (5,),
+            ),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        (
+            unexpected_ops_seen,
+            _,
+            partitioned_graphs,
+        ) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=2,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+            use_fast_partitioner=False,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEqual(
+            len(list(partitioned_graphs[0].named_children())),
+            1,
+            "Certain operators are set to run in Torch, expected 1 segment",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py
index 80b6716d20..887fa35659 100644
--- a/tests/py/dynamo/partitioning/test_global_partitioning.py
+++ b/tests/py/dynamo/partitioning/test_global_partitioning.py
@@ -117,89 +117,6 @@ def forward(self, x, y):
             "All operators are supported, there should be one segment",
         )
 
-    def test_partition_partially_supported_multi_op(self):
-        class PartiallySupportedMultiOp(torch.nn.Module):
-            def __init__(self, *args, **kwargs) -> None:
-                super().__init__(*args, **kwargs)
-
-            def forward(self, x, y):
-                sum_1 = torch.ops.aten.add.Tensor(x, y)
-                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
-                sum_ = np.sum(sum_1) + np.sum(sum_2)
-                relu_ = torch.ops.aten.relu.default(sum_)
-                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
-                return pow_
-
-        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
-        partitioned_graph, _ = partitioning.global_partition(
-            deepcopy(fx_graph), min_block_size=2
-        )
-        self.assertEqual(
-            len(list(partitioned_graph.named_children())),
-            2,
-            "Unsupported operators interleave supported ones, expected 2 segments",
-        )
-
-    def test_partition_partially_supported_with_torch_executed_ops(self):
-        class PartiallySupportedMultiOp(torch.nn.Module):
-            def __init__(self, *args, **kwargs) -> None:
-                super().__init__(*args, **kwargs)
-
-            def forward(self, x, y):
-                sum_1 = torch.ops.aten.add.Tensor(x, y)
-                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
-                sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2)
-                relu_ = torch.ops.aten.relu.default(sum_)
-                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
-                return pow_
-
-        unexpected_ops = {torch.ops.aten.add.Tensor}
-
-        inputs = [
-            torch.randint(
-                1,
-                10,
-                (5,),
-            ),
-            torch.randint(
-                1,
-                10,
-                (5,),
-            ),
-        ]
-
-        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
-        (
-            unexpected_ops_seen,
-            _,
-            partitioned_graphs,
-        ) = lower_graph_testing(
-            fx_graph,
-            inputs,
-            unexpected_ops=unexpected_ops,
-            min_block_size=2,
-            torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            testing_partitioning=True,
-            use_fast_partitioner=False,
-        )
-
-        self.assertEqual(
-            len(unexpected_ops_seen),
-            0,
-            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
-        )
-
-        self.assertEqual(
-            len(partitioned_graphs),
-            1,
-            "Without control flow breaks, there should only be a single graph",
-        )
-        self.assertEqual(
-            len(list(partitioned_graphs[0].named_children())),
-            1,
-            "Certain operators are set to run in Torch, expected 1 segment",
-        )
-
 
 if __name__ == "__main__":
     run_tests()