Skip to content

Commit 858d41d

Browse files
committed
fix: Add temporary workaround for precisions
- torch compile precisions are currently not being reflected due to recent API changes. This update honors specified precisions
1 parent b397eb6 commit 858d41d

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from typing import Any, Callable, Dict, Optional, Sequence
44

55
import torch
6+
import torch_tensorrt
67
from torch_tensorrt._Device import Device
78
from torch_tensorrt._Input import Input
89
from torch_tensorrt.dynamo import CompilationSettings
10+
from torch_tensorrt.dynamo._defaults import PRECISION
911

1012
from packaging import version
1113

@@ -159,6 +161,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
159161
if settings.debug:
160162
logger.setLevel(logging.DEBUG)
161163

164+
# TODO: Remove once Dynamo precisions refactoring is complete
165+
if "enabled_precisions" in kwargs:
166+
enabled_precisions = kwargs["enabled_precisions"]
167+
168+
if (
169+
torch.float16 in enabled_precisions
170+
or torch_tensorrt.dtype.half in enabled_precisions
171+
):
172+
settings.precision = torch.float16
173+
elif (
174+
torch.float32 in enabled_precisions
175+
or torch_tensorrt.dtype.float in enabled_precisions
176+
):
177+
settings.precision = torch.float32
178+
elif len(enabled_precisions) == 0:
179+
logger.info(f"No precision specified, defaulting to {PRECISION}")
180+
settings.precision = PRECISION
181+
else:
182+
raise ValueError(
183+
f"Precision {enabled_precisions} not supported in the Dynamo Path"
184+
)
185+
162186
# Parse input runtime specification
163187
settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
164188

0 commit comments

Comments
 (0)