|
1 |
| -import torch |
2 |
| -import logging |
3 |
| -import collections.abc |
4 |
| -import torch_tensorrt |
5 |
| -from functools import partial |
6 |
| - |
7 |
| -from typing import Any, Optional, Sequence |
8 |
| -from torch_tensorrt import EngineCapability, Device |
9 |
| -from torch_tensorrt.fx.utils import LowerPrecision |
10 |
| - |
11 |
| -from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device |
12 |
| -from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend |
13 |
| -from torch_tensorrt.dynamo._defaults import ( |
14 |
| - PRECISION, |
15 |
| - DEBUG, |
16 |
| - WORKSPACE_SIZE, |
17 |
| - MIN_BLOCK_SIZE, |
18 |
| - PASS_THROUGH_BUILD_FAILURES, |
19 |
| - MAX_AUX_STREAMS, |
20 |
| - VERSION_COMPATIBLE, |
21 |
| - OPTIMIZATION_LEVEL, |
22 |
| - USE_PYTHON_RUNTIME, |
23 |
| -) |
24 |
| - |
25 |
| - |
26 |
| -logger = logging.getLogger(__name__) |
27 |
| - |
28 |
| - |
29 |
def compile(
    gm: torch.nn.Module,
    inputs: Any,
    *,
    device=Device._current_device(),
    disable_tf32=False,
    sparse_weights=False,
    enabled_precisions=None,
    refit=False,
    debug=DEBUG,
    capability=EngineCapability.default,
    num_avg_timing_iters=1,
    workspace_size=WORKSPACE_SIZE,
    dla_sram_size=1048576,
    dla_local_dram_size=1073741824,
    dla_global_dram_size=536870912,
    calibrator=None,
    truncate_long_and_double=False,
    require_full_compilation=False,
    min_block_size=MIN_BLOCK_SIZE,
    torch_executed_ops=None,
    torch_executed_modules=None,
    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
    max_aux_streams=MAX_AUX_STREAMS,
    version_compatible=VERSION_COMPATIBLE,
    optimization_level=OPTIMIZATION_LEVEL,
    use_python_runtime=USE_PYTHON_RUNTIME,
    **kwargs,
):
    """Compile a module through the Torch-TensorRT Dynamo path.

    Builds a ``torch.compile`` backend from the supported subset of the
    given arguments, compiles ``gm`` with it, and triggers compilation by
    invoking the compiled model once on ``inputs``.

    Args:
        gm: Module (e.g. an FX ``GraphModule``) to compile
        inputs: Single input or sequence of inputs used to trace/compile
        enabled_precisions: Collection of permitted layer precisions
            (e.g. ``torch.float16``); empty/None falls back to ``PRECISION``
        debug: Whether to enable verbose (DEBUG-level) logging
        torch_executed_ops: Operations forced to run in Torch
        torch_executed_modules: Modules forced to run in Torch (currently
            unused by the Dynamo path)
        (Remaining TensorRT build options are accepted for interface parity;
        only the subset named in the warning below is honored.)

    Returns:
        The ``torch.compile``-wrapped model, already compiled once.

    Raises:
        ValueError: If ``enabled_precisions`` contains an unsupported precision.
    """
    # Avoid mutable default arguments (shared across calls): normalize
    # the None sentinels to fresh containers here.
    if enabled_precisions is None:
        enabled_precisions = set()
    if torch_executed_ops is None:
        torch_executed_ops = []
    if torch_executed_modules is None:
        torch_executed_modules = []

    if debug:
        logger.setLevel(logging.DEBUG)

    # logging.Logger.warn is deprecated; use warning() instead.
    logger.warning(
        "The Dynamo backend is an experimental feature, for which only the "
        "following arguments are supported: "
        "{enabled_precisions, debug, workspace_size, min_block_size, "
        "torch_executed_ops, pass_through_build_failures}"
    )

    # A single (non-sequence) input is wrapped so downstream code can
    # uniformly iterate over inputs.
    if not isinstance(inputs, collections.abc.Sequence):
        inputs = [inputs]

    inputs = prepare_inputs(inputs, prepare_device(device))

    if not isinstance(enabled_precisions, collections.abc.Collection):
        enabled_precisions = [enabled_precisions]

    # Map the user-specified precision set onto the FX LowerPrecision enum,
    # preferring FP16 when both FP16 and FP32 are requested.
    if (
        torch.float16 in enabled_precisions
        or torch_tensorrt.dtype.half in enabled_precisions
    ):
        lower_precision = LowerPrecision.FP16
    elif (
        torch.float32 in enabled_precisions
        or torch_tensorrt.dtype.float in enabled_precisions
    ):
        lower_precision = LowerPrecision.FP32
    elif len(enabled_precisions) == 0:
        logger.info(f"No precision specified, defaulting to {PRECISION}")
        lower_precision = PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )

    custom_backend = create_backend(
        precision=lower_precision,
        debug=debug,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        pass_through_build_failures=pass_through_build_failures,
        max_aux_streams=max_aux_streams,
        version_compatible=version_compatible,
        optimization_level=optimization_level,
        use_python_runtime=use_python_runtime,
        **kwargs,
    )

    model = torch.compile(gm, backend=custom_backend)

    # Ensure compilation occurs by calling the function with provided inputs
    model(*inputs)

    return model
115 |
| - |
116 |
| - |
117 |
| -from torch_tensorrt.fx.utils import LowerPrecision |
118 |
| - |
119 |
| -logger = logging.getLogger(__name__) |
120 |
| - |
121 |
| - |
122 |
def create_backend(
    precision: LowerPrecision = PRECISION,
    debug: bool = DEBUG,
    workspace_size: int = WORKSPACE_SIZE,
    min_block_size: int = MIN_BLOCK_SIZE,
    torch_executed_ops: Optional[Sequence[str]] = None,
    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
    version_compatible: bool = VERSION_COMPATIBLE,
    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
    **kwargs,
):
    """Create torch.compile backend given specified arguments

    Args:
        precision: Model Layer precision
        debug: Whether to print out verbose debugging information
        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
        min_block_size: Minimum number of operators per TRT-Engine Block
        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
        version_compatible: Provide version forward-compatibility for engine plan files
        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
            searching for more optimization options. TRT defaults to 3
        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
            argument as None
    Returns:
        Backend for torch.compile
    """
    # Avoid a mutable default argument (set() shared across calls): use a
    # None sentinel and create a fresh empty set per call instead.
    if torch_executed_ops is None:
        torch_executed_ops = set()

    # Bind all build options onto the backend callable so torch.compile can
    # invoke it with just (gm, example_inputs).
    return partial(
        torch_tensorrt_backend,
        debug=debug,
        precision=precision,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        pass_through_build_failures=pass_through_build_failures,
        max_aux_streams=max_aux_streams,
        version_compatible=version_compatible,
        optimization_level=optimization_level,
        use_python_runtime=use_python_runtime,
        **kwargs,
    )
| 1 | +from .backends import torch_tensorrt_backend |
| 2 | +from .compile import compile |
0 commit comments