feat: Refactor FX APIs under dynamo namespace for parity with TS APIs #1807
Merged · 48 commits · May 2, 2023

Commits
f1098f2
feat: Add sample torch.compile backend for tensorrt aten path
gs-olive Mar 20, 2023
243bf9b
Add decompositions to aot call
gs-olive Mar 21, 2023
76fd3c8
Mark FX2TRT converter as fake tensor unsupported
gs-olive Mar 27, 2023
6a8102c
Minor naming bugfix
gs-olive Mar 29, 2023
5dd1a50
feat: Initial refactoring of fx2trt in dynamo namespace
peri044 Apr 5, 2023
9ce5746
chore: Initial refactoring of FX tests in dynamo namespace
peri044 Apr 5, 2023
80828fb
chore: Initial refactoring of TS changes to unify with FX backend in …
peri044 Apr 5, 2023
b6338a0
chore: Add dynamo backend tests to CircleCI
peri044 Apr 5, 2023
c4a03ff
chore: Linter fixes
peri044 Apr 5, 2023
f8ad31a
chore: add missing ts_input.py file
peri044 Apr 5, 2023
3cbf72c
chore: refactoring
peri044 Apr 6, 2023
516248b
fix: Revamp implementation and replace inplace ops
gs-olive Apr 6, 2023
152bf43
Undo changes to aten tracer
gs-olive Apr 6, 2023
cb6e946
chore: Fix tests and remove implicit batch dim support
peri044 Apr 6, 2023
6153d21
chore: fix tests
peri044 Apr 6, 2023
5ad35e3
chore: remove softmax test using implicit dim
peri044 Apr 6, 2023
f76d2b6
fix: Remove unmodified files from FX, add device support
peri044 Apr 6, 2023
0c5befd
fix: Refactor backend, add sample args
gs-olive Apr 7, 2023
a4047d2
chore: refactoring
peri044 Apr 7, 2023
cd4660d
chore: refactoring
peri044 Apr 7, 2023
b647b5d
chore: Linter fixes
peri044 Apr 7, 2023
685bba1
chore: add workspace_size, disable_tf32, sparse_weights settings
peri044 Apr 7, 2023
5cbf46d
chore: Linter fixes
peri044 Apr 7, 2023
2479300
feat: Add new `convert_module` function
gs-olive Apr 7, 2023
eea3884
fix: Improve `torch_tensorrt` Dynamo path
gs-olive Apr 8, 2023
aa0dda8
fix: Move key functions, fix bugs
gs-olive Apr 8, 2023
a6d3a64
chore: refactor device related code
peri044 Apr 10, 2023
5390259
chore: Linter fixes
peri044 Apr 10, 2023
87e4c77
chore: linter fixes
peri044 Apr 10, 2023
a12141c
chore: Fix dynamo tests
peri044 Apr 10, 2023
6d2e01a
fix: Add test cases and improve backend
gs-olive Apr 11, 2023
48618f4
chore: Fix device dict
peri044 Apr 11, 2023
d890b7d
chore: Nest it to dynamo/fx_ts_compat
peri044 Apr 11, 2023
8b15cdc
chore: Update setup.py
peri044 Apr 11, 2023
8f42a18
chore: Rename ir to fx_ts_compat
peri044 Apr 12, 2023
d49b46c
chore: Fix import
peri044 Apr 12, 2023
cf5bb20
chore: Linter fixes
peri044 Apr 12, 2023
226cc79
fix: Reorganize, add tests
gs-olive Apr 12, 2023
b9433ec
Merge remote-tracking branch 'dynamo_changes' into sample_backend
gs-olive Apr 12, 2023
dc7dd04
Update conftest.py
gs-olive Apr 12, 2023
33255de
Merge pull request #1751 from gs-olive/sample_backend
peri044 Apr 13, 2023
0d68a47
chore: Confine the testing to input and core level
peri044 Apr 14, 2023
a9b0711
chore: Modify dynamo fx_ts_compat tests
peri044 Apr 14, 2023
b387dd5
chore: Modify circle ci to reduce tests
peri044 Apr 14, 2023
13caff0
chore: Refactor code
peri044 Apr 19, 2023
2addf5e
chore: Linter fixes
peri044 Apr 19, 2023
fb41cf7
fix: Add test suite for torch.compile backend (#1849)
gs-olive Apr 26, 2023
b8cd7c3
Improve warning wording
gs-olive May 2, 2023
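
Taken together, these commits move the FX frontend under a new `torch_tensorrt.dynamo` namespace with two entry points: `fx_ts_compat`, a TorchScript-compatible API over FX, and `torch_compile`, a backend for `torch.compile`. The sketch below is a hypothetical usage example, not code from this PR: the `ir="torch_compile"` value mirrors the `--ir torch_compile` flag used in the new end-to-end tests, and the remaining options are standard `torch_tensorrt.compile` arguments.

```python
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 32).cuda()]

# Hypothetical: select the new Dynamo torch_compile path via the ir argument,
# mirroring the --ir torch_compile flag in the E2E tests added below.
trt_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=inputs,
    enabled_precisions={torch.float},
)

print(trt_model(*inputs).shape)  # same output shape as the eager model
```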
110 changes: 110 additions & 0 deletions .circleci/config.yml
@@ -508,6 +508,7 @@ commands:
- store_artifacts:
path: /tmp/testlogs

# =================== FX tests start ======================== #
test-fx_core:
description: "Test the fx core"
steps:
@@ -707,6 +708,61 @@ commands:
- store_artifacts:
path: /tmp/testlogs

# =================== FX tests end ======================== #

# =================== Dynamo tests start ======================== #
test-dynamo-fx_ts:
description: "Test the Dynamo fx_ts_compat path"
steps:
- run:
name: Run Dynamo fx_ts_compat core tests
command: |
cd py/torch_tensorrt/dynamo/fx_ts_compat/test
pushd core/
pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml
popd

- store_test_results:
path: /tmp/artifacts
- store_artifacts:
path: /tmp/testlogs

test-dynamo-torch_compile-core:
description: "Test the Dynamo torch_compile path"
steps:
- run:
name: Run Dynamo torch_compile core tests
command: |
cd py/torch_tensorrt/dynamo/torch_compile
pushd test/
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
popd

- store_test_results:
path: /tmp/artifacts
- store_artifacts:
path: /tmp/testlogs

test-dynamo-torch_compile:
description: "Test the Dynamo torch_compile path"
steps:
- run:
name: Run Dynamo torch_compile E2E tests
command: |
cd py/torch_tensorrt/dynamo/
pushd test/
pip3 install timm
pip3 install transformers
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
popd

- store_test_results:
path: /tmp/artifacts
- store_artifacts:
path: /tmp/testlogs

# =================== Dynamo tests end ======================== #

# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
jobs:
@@ -883,6 +939,39 @@ jobs:
- dump-test-env
- test-fx-no-aten

test-py-dynamo-x86_64-linux:
parameters:
torch-build:
type: string
torch-build-index:
type: string
trt-version-long:
type: string
machine:
image: ubuntu-2004-cuda-11.4:202110-01
resource_class: gpu.nvidia.large
steps:
- checkout
- attach_workspace:
at: /tmp/dist/
- install-torch-from-index:
torch-build: << parameters.torch-build >>
torch-build-index: << parameters.torch-build-index >>
- create-py-env:
trt-version-long: << parameters.trt-version-long >>
- install-cudnn
# - run:
# name: "Set LD_LIBRARY_PATH path to include the installed CUDNN"
# command: export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
- run:
name: "Install torch-tensorrt"
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
- dump-test-env
- test-dynamo-torch_compile
- test-dynamo-torch_compile-core
- test-dynamo-fx_ts

package-x86_64-linux:
parameters:
enabled:
@@ -1261,6 +1350,13 @@ workflows:
requires:
- build-x86_64-linux

- test-py-dynamo-x86_64-linux:
torch-build: << pipeline.parameters.torch-build >>
torch-build-index: << pipeline.parameters.torch-build-index >>
trt-version-long: << pipeline.parameters.trt-version-long >>
requires:
- build-x86_64-linux

- build-x86_64-linux:
name: build-x86_64-linux-legacy
torch-build: << pipeline.parameters.torch-build-legacy >>
@@ -1328,6 +1424,13 @@ workflows:
requires:
- package-x86_64-linux

- test-py-dynamo-x86_64-linux:
torch-build: << pipeline.parameters.torch-build >>
torch-build-index: << pipeline.parameters.torch-build-index >>
trt-version-long: << pipeline.parameters.trt-version-long >>
requires:
- package-x86_64-linux

on-push:
jobs:
- build-x86_64-linux:
@@ -1357,6 +1460,13 @@ workflows:
requires:
- build-x86_64-linux

- test-py-dynamo-x86_64-linux:
torch-build: << pipeline.parameters.torch-build >>
torch-build-index: << pipeline.parameters.torch-build-index >>
trt-version-long: << pipeline.parameters.trt-version-long >>
requires:
- build-x86_64-linux

- build-x86_64-linux-cmake:
torch-build: << pipeline.parameters.torch-build >>
torch-build-index: << pipeline.parameters.torch-build-index >>
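The end-to-end job above drives pytest with a custom `--ir torch_compile` flag, presumably registered by the "Update conftest.py" commit. A hypothetical sketch of such a conftest using only standard pytest hooks (the actual file in this PR may differ):

```python
# conftest.py -- hypothetical sketch of registering the --ir option used above
import pytest

def pytest_addoption(parser):
    parser.addoption(
        "--ir",
        type=str,
        default="torch_compile",
        help="Which Torch-TensorRT frontend/IR to run the tests against",
    )

@pytest.fixture
def ir(request):
    # Tests can request this fixture to branch on the selected IR
    return request.config.getoption("--ir")
```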
67 changes: 41 additions & 26 deletions py/setup.py
@@ -356,6 +356,10 @@ def run(self):
"torch_tensorrt.fx.tools",
"torch_tensorrt.fx.tracer.acc_tracer",
"torch_tensorrt.fx.tracer.dispatch_tracer",
"torch_tensorrt.dynamo",
"torch_tensorrt.dynamo.fx_ts_compat",
"torch_tensorrt.dynamo.fx_ts_compat.passes",
"torch_tensorrt.dynamo.fx_ts_compat.tools",
]
package_dir = {
"torch_tensorrt.fx": "torch_tensorrt/fx",
@@ -364,11 +368,47 @@ def run(self):
"torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
"torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
"torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
"torch_tensorrt.dynamo": "torch_tensorrt/dynamo",
"torch_tensorrt.dynamo.fx_ts_compat": "torch_tensorrt/dynamo/fx_ts_compat",
"torch_tensorrt.dynamo.fx_ts_compat.passes": "torch_tensorrt/dynamo/fx_ts_compat/passes",
"torch_tensorrt.dynamo.fx_ts_compat.tools": "torch_tensorrt/dynamo/fx_ts_compat/tools",
}

with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

if FX_ONLY:
package_data_list = [
"_Input.py",
]
else:
package_data_list = [
"lib/*",
"include/torch_tensorrt/*.h",
"include/torch_tensorrt/core/*.h",
"include/torch_tensorrt/core/conversion/*.h",
"include/torch_tensorrt/core/conversion/conversionctx/*.h",
"include/torch_tensorrt/core/conversion/converters/*.h",
"include/torch_tensorrt/core/conversion/evaluators/*.h",
"include/torch_tensorrt/core/conversion/tensorcontainer/*.h",
"include/torch_tensorrt/core/conversion/var/*.h",
"include/torch_tensorrt/core/ir/*.h",
"include/torch_tensorrt/core/lowering/*.h",
"include/torch_tensorrt/core/lowering/passes/*.h",
"include/torch_tensorrt/core/partitioning/*.h",
"include/torch_tensorrt/core/partitioning/segmentedblock/*.h",
"include/torch_tensorrt/core/partitioning/partitioninginfo/*.h",
"include/torch_tensorrt/core/partitioning/partitioningctx/*.h",
"include/torch_tensorrt/core/plugins/*.h",
"include/torch_tensorrt/core/plugins/impl/*.h",
"include/torch_tensorrt/core/runtime/*.h",
"include/torch_tensorrt/core/util/*.h",
"include/torch_tensorrt/core/util/logging/*.h",
"bin/*",
"BUILD",
"WORKSPACE",
]

setup(
name="torch_tensorrt",
version=__version__,
@@ -412,32 +452,7 @@ def run(self):
python_requires=">=3.7",
include_package_data=True,
package_data={
"torch_tensorrt": [
"lib/*",
"include/torch_tensorrt/*.h",
"include/torch_tensorrt/core/*.h",
"include/torch_tensorrt/core/conversion/*.h",
"include/torch_tensorrt/core/conversion/conversionctx/*.h",
"include/torch_tensorrt/core/conversion/converters/*.h",
"include/torch_tensorrt/core/conversion/evaluators/*.h",
"include/torch_tensorrt/core/conversion/tensorcontainer/*.h",
"include/torch_tensorrt/core/conversion/var/*.h",
"include/torch_tensorrt/core/ir/*.h",
"include/torch_tensorrt/core/lowering/*.h",
"include/torch_tensorrt/core/lowering/passes/*.h",
"include/torch_tensorrt/core/partitioning/*.h",
"include/torch_tensorrt/core/partitioning/segmentedblock/*.h",
"include/torch_tensorrt/core/partitioning/partitioninginfo/*.h",
"include/torch_tensorrt/core/partitioning/partitioningctx/*.h",
"include/torch_tensorrt/core/plugins/*.h",
"include/torch_tensorrt/core/plugins/impl/*.h",
"include/torch_tensorrt/core/runtime/*.h",
"include/torch_tensorrt/core/util/*.h",
"include/torch_tensorrt/core/util/logging/*.h",
"bin/*",
"BUILD",
"WORKSPACE",
],
"torch_tensorrt": package_data_list,
},
exclude_package_data={
"": ["*.cpp"],
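The setup.py change hoists `package_data` into a `package_data_list` that shrinks to just `_Input.py` under `FX_ONLY`, so Python-only builds stop shipping C++ headers, libraries, and Bazel files. A hypothetical sketch of how such a flag is typically consumed before `setup()` parses the command line (the real flag name in this setup.py is an assumption):

```python
import sys

# Hypothetical: strip a custom switch from argv before setup() sees it;
# the actual mechanism and flag name in this setup.py may differ.
FX_ONLY = False
if "--fx-only" in sys.argv:
    FX_ONLY = True
    sys.argv.remove("--fx-only")
```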
34 changes: 24 additions & 10 deletions py/torch_tensorrt/_Device.py
@@ -1,11 +1,17 @@
import torch

from torch_tensorrt import _enums
# from torch_tensorrt import _enums
import tensorrt as trt
from torch_tensorrt import logging
from torch_tensorrt import _C

import warnings

try:
from torch_tensorrt import _C
except:
warnings.warn(
"Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
)


class Device(object):
"""
@@ -51,7 +57,7 @@ def __init__(self, *args, **kwargs):
)
else:
(self.device_type, id) = Device._parse_device_str(args[0])
if self.device_type == _enums.DeviceType.GPU:
if self.device_type == trt.DeviceType.GPU:
self.gpu_id = id
else:
self.dla_core = id
@@ -64,7 +70,7 @@ def __init__(self, *args, **kwargs):
elif len(args) == 0:
if "gpu_id" in kwargs or "dla_core" in kwargs:
if "dla_core" in kwargs:
self.device_type = _enums.DeviceType.DLA
self.device_type = trt.DeviceType.DLA
self.dla_core = kwargs["dla_core"]
if "gpu_id" in kwargs:
self.gpu_id = kwargs["gpu_id"]
@@ -76,7 +82,7 @@ def __init__(self, *args, **kwargs):
)
else:
self.gpu_id = kwargs["gpu_id"]
self.device_type = _enums.DeviceType.GPU
self.device_type = trt.DeviceType.GPU
else:
raise ValueError(
"Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
@@ -97,15 +103,23 @@ def __init__(self, *args, **kwargs):
def __str__(self) -> str:
return (
"Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")"
if self.device_type == _enums.DeviceType.GPU
if self.device_type == trt.DeviceType.GPU
else ", dla_core={}, allow_gpu_fallback={}".format(
self.dla_core, self.allow_gpu_fallback
)
)

def _to_internal(self) -> _C.Device:
internal_dev = _C.Device()
internal_dev.device_type = self.device_type
if self.device_type == trt.DeviceType.GPU:
internal_dev.device_type = _C.DeviceType.GPU
elif self.device_type == trt.DeviceType.DLA:
internal_dev.device_type = _C.DeviceType.DLA
else:
raise ValueError(
"Invalid DeviceType detected while parsing the Device class"
)

internal_dev.gpu_id = self.gpu_id
internal_dev.dla_core = self.dla_core
internal_dev.allow_gpu_fallback = self.allow_gpu_fallback
@@ -136,6 +150,6 @@ def _parse_device_str(s):
s = s.lower()
spec = s.split(":")
if spec[0] == "gpu" or spec[0] == "cuda":
return (_enums.DeviceType.GPU, int(spec[1]))
return (trt.DeviceType.GPU, int(spec[1]))
elif spec[0] == "dla":
return (_enums.DeviceType.DLA, int(spec[1]))
return (trt.DeviceType.DLA, int(spec[1]))
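
With `_enums` swapped for TensorRT's own `trt.DeviceType` and the `_C` import made optional, `Device` can now be constructed without the TorchScript runtime installed. A short hypothetical sketch of the parsing behavior implied by the diff; the kwarg names follow the attributes visible above, but the exact constructor contract is an assumption:

```python
import tensorrt as trt
from torch_tensorrt import Device

# "cuda:N" and "gpu:N" both parse to a GPU device (see _parse_device_str above)
dev = Device("cuda:0")
assert dev.device_type == trt.DeviceType.GPU
assert dev.gpu_id == 0

# DLA devices are selected via the dla_core kwarg; a gpu_id allows fallback
dla = Device(dla_core=0, gpu_id=0, allow_gpu_fallback=True)
assert dla.device_type == trt.DeviceType.DLA
```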