feat: Add tensor type enforcement for converters #2324

Merged (3 commits) on Sep 29, 2023
Changes from all commits
19 changes: 19 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.fx.types import TRTNetwork


@dataclass
class ConversionContext:
"""Class representing the context for conversion of a particular network

Args:
net: TensorRT Network being built
compilation_settings: Settings selected by the user for compilation
"""

net: TRTNetwork
compilation_settings: CompilationSettings = field(
default_factory=CompilationSettings
)
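A minimal sketch (not part of the diff) of how this context is constructed and consumed, mirroring the _TRTInterpreter hunks below; the input name and shape are purely illustrative and a default CompilationSettings is assumed:

    import tensorrt as trt
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion import ConversionContext

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    # The interpreter now wraps the freshly created network and the user's
    # settings in a single context instead of holding a bare self.network.
    ctx = ConversionContext(builder.create_network(flag), CompilationSettings())

    # Everything that previously touched self.network goes through ctx.net.
    x = ctx.net.add_input(name="x", shape=(1, 3, 224, 224), dtype=trt.float32)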
63 changes: 42 additions & 21 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -13,7 +13,13 @@
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._python_dispatch import _disable_current_modes
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_node_name,
get_trt_tensor,
)
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -46,6 +52,7 @@ def __init__(
input_specs: List[Input],
logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
output_dtypes: Optional[List[torch.dtype]] = None,
compilation_settings: CompilationSettings = CompilationSettings(),
):
super().__init__(module)

@@ -59,7 +66,9 @@ def __init__(
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
flag |= EXPLICIT_BATCH

self.network = self.builder.create_network(flag)
self.ctx = ConversionContext(
self.builder.create_network(flag), compilation_settings
)

missing_ops = self.validate_conversion()
if missing_ops:
@@ -95,14 +104,14 @@ def validate_conversion(self) -> Set[str]:
missing_converters: Set[str] = set()

for node in self.module.graph.nodes:
if node.op == "call_function" and not CONVERTERS.get(node):
if node.op == "call_function" and CONVERTERS.get(node) is None:
missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}")
elif node.op == "call_method" and not CONVERTERS.get(node):
elif node.op == "call_method" and CONVERTERS.get(node) is None:
missing_converters.add(f"{node.op} torch.Tensor.{node.target}")
elif node.op == "call_module":
submod = self.fetch_attr(node.target)
submod_type = getattr(submod, "_base_class_origin", type(submod))
if not CONVERTERS.get(node):
if CONVERTERS.get(node) is None:
missing_converters.add(f"{node.op} {torch.typename(submod_type)}")

return missing_converters
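The registry lookup now returns a packet rather than a bare callable, so the truthiness checks (`not CONVERTERS.get(node)`) are replaced with explicit `is None` comparisons that only treat "nothing registered" as missing. A compact sketch of the pattern; the (converter, CallingConvention) packet shape is taken from the call_* methods further down:

    converter_packet = CONVERTERS.get(node)
    if converter_packet is None:
        # Genuinely unsupported: no converter registered for this node.
        missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}")
    else:
        converter, calling_convention = converter_packet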
@@ -221,7 +230,7 @@ def run(
if tactic_sources is not None:
builder_config.set_tactic_sources(tactic_sources=tactic_sources)

engine = self.builder.build_engine(self.network, builder_config)
engine = self.builder.build_engine(self.ctx.net, builder_config)
assert engine

serialized_cache = (
Expand Down Expand Up @@ -291,7 +300,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
f"Unable to access shape spec for input: {target} (got: {current_input})"
)

return self.network.add_input(
return self.ctx.net.add_input(
name=target,
shape=tuple(shape),
dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
@@ -303,30 +312,40 @@ def call_module(
assert isinstance(target, str)
submod = self.fetch_attr(target)
submod_type = getattr(submod, "_base_class_origin", type(submod))
converter = CONVERTERS.get(self._cur_node)
converter_packet = CONVERTERS.get(self._cur_node)

if not converter:
if converter_packet is None:
raise UnsupportedOperatorException(
f"Conversion of module of type {submod_type} not currently supported!"
)

converter, calling_convention = converter_packet

assert self._cur_node_name is not None
return converter(self.network, submod, args, kwargs, self._cur_node_name)
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
else:
return converter(self.ctx, submod, args, kwargs, self._cur_node_name)

def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
# TODO: Why is this stateful? We should be able to take in the inputs
converter = CONVERTERS.get(self._cur_node)
if not converter:
converter_packet = CONVERTERS.get(self._cur_node)
if converter_packet is None:
raise UnsupportedOperatorException(
f"Conversion of function {torch.typename(target)} not currently supported!"
)

converter, calling_convention = converter_packet

assert self._cur_node_name is not None
return converter(self.network, target, args, kwargs, self._cur_node_name)
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
else:
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
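A hedged sketch of the two converter signatures this dispatch supports; the registration decorator is not part of this diff, and the ReLU converters below are purely illustrative:

    from typing import Any, Dict, Tuple

    import tensorrt as trt
    from torch.fx.node import Target
    from torch_tensorrt.dynamo.conversion import ConversionContext

    # CallingConvention.LEGACY: the converter receives the raw TRT network.
    def legacy_relu_converter(
        net: trt.INetworkDefinition,
        target: Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: str,
    ) -> trt.ITensor:
        layer = net.add_activation(args[0], trt.ActivationType.RELU)
        layer.name = name
        return layer.get_output(0)

    # New convention: the converter receives the full ConversionContext, so it
    # can consult ctx.compilation_settings in addition to building on ctx.net.
    def ctx_relu_converter(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: str,
    ) -> trt.ITensor:
        layer = ctx.net.add_activation(args[0], trt.ActivationType.RELU)
        layer.name = name
        return layer.get_output(0)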

def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
with _disable_current_modes():
from torch_tensorrt.fx.converters import to_numpy
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy

frozen_attr = self.fetch_attr(target)

@@ -341,15 +360,19 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:

def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
assert isinstance(target, str)
converter = CONVERTERS.get(self._cur_node)
converter_packet = CONVERTERS.get(self._cur_node)

if not converter:
if converter_packet is None:
raise UnsupportedOperatorException(
f"Conversion of method {target} not currently supported!"
)
converter, calling_convention = converter_packet

assert self._cur_node_name is not None
return converter(self.network, target, args, kwargs, self._cur_node_name)
if calling_convention is CallingConvention.LEGACY:
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
else:
return converter(self.ctx, target, args, kwargs, self._cur_node_name)

def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
assert len(args) == 1
@@ -361,12 +384,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
outputs = (args[0],)

for output_idx in range(len(outputs)):
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

output = outputs[output_idx]

if not isinstance(output, trt.tensorrt.ITensor):
new_output = get_trt_tensor(self.network, output, target)
new_output = get_trt_tensor(self.ctx, output, target)
outputs = (
outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
)
@@ -400,7 +421,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
output_bool = False
name = f"output{i}"
output.name = name
self.network.mark_output(output)
self.ctx.net.mark_output(output)
if output_bool:
output.dtype = trt.bool
elif self.output_dtypes is not None:
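The output handler above now promotes non-ITensor outputs through get_trt_tensor with the context as its first argument (previously the raw network). A small sketch, reusing ctx from the earlier snippet and assuming a NumPy constant output purely for illustration:

    import numpy as np
    from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

    # Promote a constant graph output to an ITensor so it can be marked as a
    # network output; ctx replaces the old network argument.
    const_out = get_trt_tensor(ctx, np.ones((1, 4), dtype=np.float32), "output_constant")
    ctx.net.mark_output(const_out)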
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,3 +1,4 @@
from ._ConversionContext import ConversionContext
from ._TRTInterpreter import * # noqa: F403
from .aten_ops_converters import * # noqa: F403
from .conversion import * # noqa: F403
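With this re-export, downstream converter code can import the context directly from the conversion package rather than from the private module:

    from torch_tensorrt.dynamo.conversion import ConversionContext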