From 777d2c512916f4736c952be8c5c16cdf0fe60634 Mon Sep 17 00:00:00 2001 From: Wei Date: Fri, 9 Sep 2022 16:34:36 -0700 Subject: [PATCH 1/3] enable direct call to fx.compile() --- py/torch_tensorrt/fx/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/fx/__init__.py b/py/torch_tensorrt/fx/__init__.py index c1c42c446f..03eb7174b5 100644 --- a/py/torch_tensorrt/fx/__init__.py +++ b/py/torch_tensorrt/fx/__init__.py @@ -11,5 +11,6 @@ from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa from .lower_setting import LowerSetting # noqa from .trt_module import TRTModule # noqa +from .lower import compile # usort: skip #noqa logging.basicConfig(level=logging.INFO) From 5a32c295bea521ccdb2138713b53c15aa881a796 Mon Sep 17 00:00:00 2001 From: Wei Date: Fri, 9 Sep 2022 16:35:39 -0700 Subject: [PATCH 2/3] Update lower_example.py --- examples/fx/lower_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fx/lower_example.py b/examples/fx/lower_example.py index 7f3b374f44..cd9215712b 100644 --- a/examples/fx/lower_example.py +++ b/examples/fx/lower_example.py @@ -4,7 +4,7 @@ import torch import torchvision -from torch_tensorrt.fx.lower import compile +from torch_tensorrt.fx import compile from torch_tensorrt.fx.utils import LowerPrecision From 69a3ba1db6f9cd71dc42edef1ba887cd59a018de Mon Sep 17 00:00:00 2001 From: Wei Date: Fri, 9 Sep 2022 16:36:35 -0700 Subject: [PATCH 3/3] Update _compile.py --- py/torch_tensorrt/_compile.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 8b5f235531..18b9901c56 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -7,7 +7,6 @@ from enum import Enum import torch_tensorrt.fx -import torch_tensorrt.fx.lower from torch_tensorrt.fx.utils import LowerPrecision @@ -140,7 +139,7 @@ def compile( else: raise ValueError(f"Precision {enabled_precisions} not supported on FX") - return torch_tensorrt.fx.lower.compile( + return torch_tensorrt.fx.compile( module, inputs, lower_precision=lower_precision,