fix: Add support for truncate_long_and_double in Dynamo [8 / x] #1983

Merged · 1 commit · Jul 24, 2023
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -9,3 +9,4 @@
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
USE_PYTHON_RUNTIME = None
TRUNCATE_LONG_AND_DOUBLE = False
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -11,6 +11,7 @@
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
TRUNCATE_LONG_AND_DOUBLE,
)


@@ -26,3 +27,4 @@ class CompilationSettings:
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
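
Not part of the diff — a minimal sketch of how the new field surfaces on the settings object. The import path mirrors the file touched above; construction by keyword follows the `CompilationSettings(**compilation_options)` call in compile.py below.

# Sketch only: the new field defaults to False and can be overridden like any other entry.
from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(truncate_long_and_double=True)
assert settings.truncate_long_and_double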
11 changes: 10 additions & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -16,7 +16,10 @@
get_submod_inputs,
)
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
from torch_tensorrt.dynamo.conversion import convert_module
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)

from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler

@@ -135,6 +138,12 @@ def _compile_module(
partitioned_module, submodule, sample_inputs
)

# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module, submodule, submodule_inputs, name
)

# Create TRT Module from submodule
trt_mod = convert_module(
submodule,
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/compile.py
@@ -30,6 +30,7 @@
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
USE_PYTHON_RUNTIME,
TRUNCATE_LONG_AND_DOUBLE,
)


@@ -53,7 +54,7 @@ def compile(
dla_local_dram_size=1073741824,
dla_global_dram_size=536870912,
calibrator=None,
truncate_long_and_double=False,
truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation=False,
min_block_size=MIN_BLOCK_SIZE,
torch_executed_ops=[],
@@ -109,6 +110,7 @@
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
}

settings = CompilationSettings(**compilation_options)
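
Not part of the diff — a usage sketch showing how the flag flows from the `compile()` kwargs into `CompilationSettings`. The call shape mirrors the tests added below; `model` and `inputs` are placeholders for any traceable module and its sample inputs.

# Usage sketch (mirrors the test calls in test_backend_compiler.py below).
import torch_tensorrt

trt_model = torch_tensorrt.compile(
    model,
    "torch_compile",
    inputs,
    min_block_size=1,
    truncate_long_and_double=True,  # enable the new Dynamo-path option
)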
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,2 +1,3 @@
from .trt_interpreter import *
from .conversion import *
from .truncate_long_and_double import repair_long_or_double_inputs
207 changes: 207 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py
@@ -0,0 +1,207 @@
import torch
from torch.fx.node import _get_qualified_name
from typing import Optional, Sequence, Union


def _extract_downstream_get_nodes(
module_node: torch.fx.Node, output_indices: Sequence[int]
) -> Sequence[torch.fx.Node]:
"""Extracts downstream users of a node which get the item at a particular index

Certain module-type nodes have multiple outputs (tuple of outputs). This function
returns downstream nodes which call the _operator.getitem function, which extracts
the element at a particular index in the tuple

Args:
module_node: FX module-type node to analyze
output_indices: Indices in the module node output to search for
Returns:
List of nodes which get the item at the specified index in the module node output
"""
get_nodes = []

# Iterate over all downstream users of the node object
for user in module_node.users:
# If the user is a "get" node accessing the specified index, store it
if _get_qualified_name(user.target) == "_operator.getitem" and (
user.args[1] in output_indices
):
get_nodes.append(user)

return get_nodes


def _repair_64bit_input(
gm: torch.fx.GraphModule,
position: int,
submodule_name: str,
submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
dtype: torch.dtype,
):
"""Fixes a single Long/Double input to a TRT-accelerated subgraph

In-Place modifies the provided graph

Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
inserts an upcast back to the 64-bit type for subsequent Torch operations

Args:
gm: FX GraphModule enclosing the TRT subgraph
position: Index in the submodule inputs at which the long or double input is found
submodule_name: Name of TRT-accelerated subgraph module in FX graph
submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
dtype: Data type of tensor at position in submodule (double/long)
"""
assert dtype in (
torch.int64,
torch.float64,
), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"

# Determine target data type in 32 and 64 bit forms
dtype_64bit = dtype
dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32

# Find the node representing the submodule in the graph
module_node = None

# Iterate over all nodes in the graph, seeking target module name match
for n in gm.graph.nodes:
if n.op == "call_module" and str(n.target) == submodule_name:
module_node = n
break

if module_node is None:
raise AssertionError(
f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
)

# Extract the 64-bit node of the input
node_64bit = module_node.all_input_nodes[position]

# Prior to the module, insert a cast to the 32-bit equivalent node
with gm.graph.inserting_before(module_node):
node_32bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(node_64bit,),
kwargs={"dtype": dtype_32bit},
)

# Replace 64-bit input to TRT module with new 32-bit cast node
module_node.replace_input_with(node_64bit, node_32bit)

output_positions_64bit = set()
outputs_list = (
[submodule_outputs]
if isinstance(submodule_outputs, torch.Tensor)
else submodule_outputs
)

# Determine if any outputs of the model are 64-bit type and store their indices
if submodule_outputs is not None:
for output_position, output in enumerate(outputs_list):
if output.dtype == dtype_64bit:
output_positions_64bit.add(output_position)

# Only enter this code block if there exists a 64-bit output
# This implies a cast is needed, since TRT cannot output 64-bit tensors
if output_positions_64bit:
# Determine whether the outputs of the module are tuple-type or not
is_collection_output = False
if isinstance(submodule_outputs, tuple):
is_collection_output = True

if not is_collection_output:
# If the output is a single tensor, insert a cast back to int64
with gm.graph.inserting_after(module_node):
cast_node_64bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(module_node,),
kwargs={"dtype": dtype_64bit},
)

# Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
module_node.replace_all_uses_with(
cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
)

else:
# If the output is a tuple of tensors, extract downstream users for each 64-bit output
get_nodes = _extract_downstream_get_nodes(
module_node, output_positions_64bit
)

# For each downstream user, append a cast node back to the 64-bit precision
for get_node in get_nodes:
with gm.graph.inserting_after(get_node):
cast_node_64bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(get_node,),
kwargs={"dtype": torch.int64},
)

get_node.replace_all_uses_with(
cast_node_64bit,
delete_user_cb=lambda user: (user != cast_node_64bit),
)

# Clean up graph and ensure invariants are preserved
gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()


def repair_long_or_double_inputs(
parent_graph: torch.fx.GraphModule,
submodule: torch.fx.GraphModule,
submodule_inputs: Sequence[torch.Tensor],
submodule_name: Optional[str] = None,
) -> Sequence[torch.Tensor]:
"""Fixes all Long/Double type inputs to a TRT-accelerated subgraph

In-Place modifies the provided graph

Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
inserts an upcast back to the 64-bit type for subsequent Torch operations

Args:
parent_graph: FX GraphModule enclosing the TRT subgraph
submodule: Child submodule to repair inputs on
submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
submodule_name: Optionally specify the name of the submodule target in the parent graph
Returns:
New submodule inputs, updated accordingly with long/double truncation
"""
num_submodule_inputs = len(submodule_inputs)
repaired_outputs_once = False

# For each input to the TRT subgraph, check if its type is long/double
for position in range(num_submodule_inputs):
param = submodule_inputs[position]

# If the data type of the input is long/double, insert necessary
# casts to replace the operation
if param.dtype in (torch.int64, torch.float64):
# Ensure outputs are only repaired once per submodule to avoid
# unnecessary ops showing up in the graph
if not repaired_outputs_once:
submodule_outputs = submodule(*submodule_inputs)

_repair_64bit_input(
parent_graph,
position,
submodule_name if submodule_name is not None else submodule._get_name(),
None if repaired_outputs_once else submodule_outputs,
param.dtype,
)

repaired_outputs_once = True

# Repair submodule inputs in accordance with inserted casts
dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
submodule_inputs = (
submodule_inputs[:position]
+ (param.to(dtype_32bit),)
+ submodule_inputs[position + 1 :]
)

return submodule_inputs
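
Not part of the diff — a tensor-level illustration of the cast pattern this pass inserts around a TRT submodule (downcast 64-bit inputs via `aten._to_copy` before the call, upcast afterward when downstream Torch ops expect 64-bit). It is run outside the compiler and does not invoke the pass itself.

# Illustration only: mimics the inserted casts, with an addition standing in for the TRT block.
import torch

x64 = torch.randint(-5, 5, (4,), dtype=torch.int64)
x32 = torch.ops.aten._to_copy.default(x64, dtype=torch.int32)  # downcast inserted before the TRT submodule
y32 = x32 + 1                                                   # stands in for the TRT-executed subgraph
y64 = torch.ops.aten._to_copy.default(y32, dtype=torch.int64)  # upcast re-inserted for downstream Torch ops
assert y64.dtype == torch.int64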
113 changes: 113 additions & 0 deletions tests/py/dynamo/backend/test_backend_compiler.py
@@ -171,5 +171,118 @@ def forward(self, x, y):
)


class Test64BitInput(TestCase):
def test_float64_input_full_support(self):
class FullySupportedMultiOp(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.mean.dim(
torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
)

fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)

self.assertEquals(
len(list(partitioned_graph.named_children())),
1,
"All operators are supported, there should be one segment",
)

inputs = [
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
]

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
truncate_long_and_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"TRT outputs don't match with the original model.",
)

def test_int64_input_partial_support(self):
class PartiallySupportedMultiOp(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.div.Tensor_mode(
x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
)

fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
unexpected_ops = {torch.ops.aten.add.Tensor}

inputs = [
torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
]

(unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
fx_graph,
inputs,
unexpected_ops=unexpected_ops,
min_block_size=1,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
testing_partitioning=True,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)
self.assertEquals(
len(partitioned_graphs),
1,
"Without control flow breaks, there should only be a single graph",
)
self.assertEquals(
len(list(partitioned_graphs[0].named_children())),
1,
"Certain operators are set to run in Torch, expected 1 segment",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
truncate_long_and_double=True,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()