Skip to content

Commit 7088245

Browse files
committed
feat(//py): Inital introduction of the Python API
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 83e0ed6 commit 7088245

File tree

8 files changed

+523
-0
lines changed

8 files changed

+523
-0
lines changed

py/BUILD

Whitespace-only changes.

py/requirements.txt

Whitespace-only changes.

py/setup.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from setuptools import setup, Extension, find_packages
2+
from setuptools.command.build_ext import build_ext
3+
import sys
4+
import setuptools
5+
import os
6+
from torch.utils import cpp_extension
7+
from shutil import copyfile
8+
9+
dir_path = os.path.dirname(os.path.realpath(__file__))
10+
11+
__version__ = '0.0.1'
12+
13+
def gen_version_file():
14+
if not os.path.exists(dir_path + '/trtorch/version.py'):
15+
os.mknod(dir_path + '/trtorch/version.py')
16+
17+
with open(dir_path + '/trtorch/version.py', 'w') as f:
18+
print("creating version file")
19+
f.write("__version__ = \"" + __version__ + '\"')
20+
21+
def copy_libtrtorch():
22+
if not os.path.exists(dir_path + '/trtorch/lib'):
23+
os.makedirs(dir_path + '/trtorch/lib')
24+
25+
print("copying library into module")
26+
copyfile(dir_path + "/../bazel-bin/cpp/api/lib/libtrtorch.so", dir_path + '/trtorch/lib/libtrtorch.so')
27+
28+
class DevelopCommand(develop):
29+
description = "Builds the package and symlinks it into the PYTHONPATH"
30+
user_options = develop.user_options + plugins_user_options
31+
32+
def initialize_options(self):
33+
develop.initialize_options(self)
34+
35+
def finalize_options(self):
36+
develop.finalize_options(self)
37+
38+
def run(self):
39+
gen_version_file()
40+
copy_libtrtorch()
41+
develop.run(self)
42+
43+
44+
class InstallCommand(install):
45+
description = "Builds the package"
46+
user_options = install.user_options + plugins_user_options
47+
48+
def initialize_options(self):
49+
install.initialize_options(self)
50+
51+
def finalize_options(self):
52+
install.finalize_options(self)
53+
54+
def run(self):
55+
gen_version_file()
56+
copy_libtrtorch()
57+
install.run(self)
58+
59+
class CleanCommand(Command):
60+
"""Custom clean command to tidy up the project root."""
61+
PY_CLEAN_FILES = ['./build', './dist', './trtorch/__pycache__', './*.pyc', './*.tgz', './*.egg-info']
62+
description = "Command to tidy up the project root"
63+
user_options = []
64+
65+
def initialize_options(self):
66+
pass
67+
68+
def finalize_options(self):
69+
pass
70+
71+
def run(self):
72+
for path_spec in self.PY_CLEAN_FILES:
73+
# Make paths absolute and relative to this path
74+
abs_paths = glob.glob(os.path.normpath(os.path.join(dir_path, path_spec)))
75+
for path in [str(p) for p in abs_paths]:
76+
if not path.startswith(root_dir):
77+
# Die if path in CLEAN_FILES is absolute + outside this directory
78+
raise ValueError("%s is not a path inside %s" % (path, root_dir))
79+
print('Removing %s' % os.path.relpath(path))
80+
shutil.rmtree(path)
81+
82+
ext_modules = [
83+
cpp_extension.CUDAExtension('trtorch._C',
84+
['trtorch/csrc/trtorch_py.cpp'],
85+
library_dirs=[
86+
dir_path + '/trtorch/lib/libtrtorch.so',
87+
dir_path + '/trtorch/lib/'
88+
],
89+
libraries=[
90+
"trtorch"
91+
],
92+
include_dirs=[
93+
dir_path + "/../",
94+
],
95+
extra_compile_args=[
96+
"-D_GLIBCXX_USE_CXX11_ABI=0"
97+
],
98+
extra_link_args=[
99+
"-D_GLIBCXX_USE_CXX11_ABI=0"
100+
"-Wl,--no-as-needed",
101+
"-ltrtorch"
102+
],
103+
undef_macros=[ "NDEBUG" ]
104+
)
105+
]
106+
107+
setup(
108+
name='trtorch',
109+
version=__version__,
110+
author='NVIDIA Corporation.',
111+
author_email='narens@nvidia.com',
112+
url='https://github.com/nvidia/trtorch',
113+
description='A compiler backend for PyTorch JIT targeting NVIDIA GPUs',
114+
long_description='',
115+
ext_modules=ext_modules,
116+
install_requires=['pybind11>=2.4'],
117+
setup_requires=['pybind11>=2.4'],
118+
cmdclass={
119+
'install': InstallCommand,
120+
'clean': CleanCommand,
121+
'develop': DevelopCommand,
122+
'build_ext': cpp_extension.BuildExtension
123+
},
124+
zip_safe=False,
125+
license="BSD-3",
126+
packages=find_packages(),
127+
classifiers=["Intended Audience :: Developers",
128+
"Intended Audience :: Science/Research",
129+
"Operating System :: POSIX :: Linux",
130+
"Programming Language :: C++",
131+
"Programming Language :: Python",
132+
"Programming Language :: Python :: Implementation :: CPython",
133+
"Topic :: Scientific/Engineering",
134+
"Topic :: Scientific/Engineering :: Artifical Intelligence",
135+
"Topic :: Software Development",
136+
"Topic :: Software Developement :: Libraries"],
137+
138+
)

py/trtorch/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
import sys
3+
4+
if sys.version_info < (3,):
5+
raise Exception("Python 2 has reached end-of-life and is not supported by TRTorch")
6+
7+
import ctypes
8+
import torch
9+
10+
def _load_trtorch_lib():
11+
lib_name = 'libtrtorch.so'
12+
here = os.path.abspath(__file__)
13+
lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
14+
ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
15+
16+
_load_trtorch_lib()
17+
18+
from .version import __version__
19+
#from trtorch import _C
20+
from trtorch.compiler import *
21+
from trtorch.types import *
22+
23+
def test(mod, data):
24+
_C._test(mod._c, data)

py/trtorch/compiler.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
from typing import List, Dict, Any
2+
import torch
3+
import tensorrt as trt
4+
import trtorch._C
5+
from trtorch import types
6+
from .version import __version__
7+
8+
def _supported_input_size_type(input_size: Any) -> bool:
9+
if isinstance(input_size, torch.Size):
10+
return True
11+
elif isinstance(input_size, tuple):
12+
return True
13+
elif isinstance(input_size, list):
14+
return True
15+
else:
16+
raise TypeError("Input sizes for inputs are required to be a List, tuple or torch.Size or a Dict of three sizes (min, opt, max), found type: " + str(type(input_size)))
17+
18+
def _parse_input_sizes(input_sizes: List) -> List:
19+
20+
if any (not isinstance(i, dict) and not _supported_input_size_type(i) for i in input_sizes):
21+
raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
22+
23+
parsed_input_sizes = []
24+
for i in input_sizes:
25+
if isinstance(i, dict):
26+
if all (k in i for k in ["min", "opt", "min"]):
27+
in_range = trtorch._C.InputRange()
28+
in_range.min = i["min"]
29+
in_range.opt = i["opt"]
30+
in_range.max = i["max"]
31+
32+
parsed_input_sizes.append(in_range.to_internal_input_range())
33+
34+
elif "opt" in i:
35+
in_range = trtorch._C.InputRange()
36+
in_range.min = i["opt"]
37+
in_range.opt = i["opt"]
38+
in_range.max = i["opt"]
39+
40+
parsed_input_sizes.append(in_range.to_internal_input_range())
41+
42+
else:
43+
raise KeyError("An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
44+
45+
elif isinstance(i, list):
46+
in_range = trtorch._C.InputRange()
47+
in_range.min = i
48+
in_range.opt = i
49+
in_range.max = i
50+
51+
parsed_input_sizes.append(in_range.to_internal_input_range())
52+
53+
return parsed_input_sizes
54+
55+
def _parse_op_precision(precision: Any) -> types.dtype:
56+
if isinstance(precision, torch.dtype):
57+
if precision == torch.int8:
58+
return types.dtype.int8
59+
elif precision == torch.half:
60+
return types.dtype.half
61+
elif precision == torch.float:
62+
return types.dtype.float
63+
else:
64+
raise TypeError("Provided an unsupported dtype as operating precision (support: int8, half, float), got: " + str(precision))
65+
66+
elif isinstance(precision, types.DataTypes):
67+
return precision
68+
69+
else:
70+
raise TypeError("Op precision type needs to be specified with a torch.dtype or a trtorch.dtype, got: " + str(type(precision)))
71+
72+
def _parse_device_type(device: Any) -> types.DeviceType:
73+
if isinstance(device, torch.device):
74+
if torch.device.type == 'cuda':
75+
return types.DeviceType.gpu
76+
else:
77+
raise TypeError("Valid device choices are GPU (and DLA if on Jetson platforms) however got device type" + str(device.type))
78+
79+
elif isinstance(device, types.DeviceType):
80+
return device
81+
82+
else:
83+
raise TypeError("Device specification must be of type torch.device or trtorch.DeviceType, but got: " + str(type(device)))
84+
85+
def _parse_extra_info(extra_info: Dict[str, Any]) -> trtorch._C._ExtraInfo:
86+
info = trtorch._C._ExtraInfo()
87+
if "input_shapes" not in extra_info and not isinstance(extra_info["input_shapes"], list):
88+
raise KeyError("Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict")
89+
90+
info.input_ranges = _parse_input_sizes(extra_info["input_shapes"])
91+
92+
if "op_precision" in extra_info:
93+
info.op_precision = _parse_op_precision(extra_info["op_precision"])
94+
95+
if "refit" in extra_info:
96+
assert isinstance(extra_info["refit"], bool)
97+
info.refit = extra_info["refit"]
98+
99+
if "debug" in extra_info:
100+
assert isinstance(extra_info["debug"], bool)
101+
info.debug = extra_info["debug"]
102+
103+
if "strict_types" in extra_info:
104+
assert isinstance(extra_info["strict_types"], bool)
105+
info.strict_types = extra_info["strict_types"]
106+
107+
if "allow_gpu_fallback" in extra_info:
108+
assert isinstance(extra_info["allow_gpu_fallback"], bool)
109+
info.allow_gpu_fallback = extra_info["allow_gpu_fallback"]
110+
111+
if "device" in extra_info:
112+
info.device = _parse_device_type(extra_info["device"])
113+
114+
if "capability" in extra_info:
115+
assert isinstance(extra_info["capability"], type.EngineCapability)
116+
info.capability = extra_info["capability"]
117+
118+
119+
if "num_min_timing_iters" in extra_info:
120+
assert type(extra_info["num_min_timing_iters"]) is int
121+
info.num_min_timing_iters = extra_info["num_min_timing_iters"]
122+
123+
if "num_avg_timing_iters" in extra_info:
124+
assert type(extra_info["num_avg_timing_iters"]) is int
125+
info.num_avg_timing_iters = extra_info["num_avg_timing_iters"]
126+
127+
if "workspace_size" in extra_info:
128+
assert type(extra_info["workspace_size"]) is int
129+
info.workspace_size = extra_info["workspace_size"]
130+
131+
if "max_batch_size" in extra_info:
132+
assert type(extra_info["max_batch_size"]) is int
133+
info.max_batch_size = extra_info["max_batch_size"]
134+
135+
return info
136+
137+
def compile_module(module: torch.jit.ScriptModule, extra_info: Any) -> torch.jit.ScriptModule:
138+
return module
139+
140+
def convert_graph_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, extra_info: Any) -> str:
141+
return trtorch._C._convert_graph_to_trt_engine(module._c, method_name, _parse_extra_info(extra_info))
142+
143+
def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
144+
return trtorch._C._check_method_op_support(module._c, method_name)
145+
146+
def dump_build_info():
147+
print(get_build_info())
148+
149+
def get_build_info() -> str:
150+
build_info = trtorch._C._get_build_info()
151+
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
152+
return build_info
153+

0 commit comments

Comments
 (0)