
Commit

changes to make Llama example work
apbose committed Dec 23, 2024
1 parent bba4153 commit fc74cff
Showing 7 changed files with 282 additions and 9 deletions.
30 changes: 27 additions & 3 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -14,12 +14,15 @@
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
    modify_complex_nodes,
    post_lowering,
    remove_detach,
    remove_sym_nodes,
    repair_input_aliasing,
    replace_complex_placeholder_to_tuple,
)
from torch_tensorrt.dynamo.utils import (
    find_complex_nodes,
    parse_dynamo_kwargs,
    prepare_inputs,
    set_log_level,
@@ -61,9 +64,15 @@ def aot_torch_tensorrt_aten_backend(
    settings_aot_autograd["decompositions"] = get_decompositions(
        settings.enable_experimental_decompositions
    )
-    return aot_autograd(fw_compiler=_pretraced_backend_autograd)(
-        gm, sample_inputs, **settings_aot_autograd
-    )
+    # Drop the detach decomposition: lowering detach produces alias nodes,
+    # which fail with "View operation returned a tensor that is the same as
+    # the input base tensor" (see the nop_decompositions table in
+    # torch/_decomp/decompositions.py).
+    if aten.detach in settings_aot_autograd["decompositions"]:
+        del settings_aot_autograd["decompositions"][aten.detach]
+    return aot_autograd(
+        fw_compiler=_pretraced_backend_autograd,
+        decompositions=settings_aot_autograd["decompositions"],
+    )(gm, sample_inputs)


def _pretraced_backend(
Expand Down Expand Up @@ -103,6 +112,16 @@ def _pretraced_backend(
            # Remove detach nodes
            remove_detach(gm, settings)

            # TensorRT has no complex dtype: record which inputs are complex
            # and repack each one as a real tensor with a trailing
            # (real, imag) dimension of size 2
            complexInputIndices = []
            for i, torch_input in enumerate(torch_inputs):
                if torch_inputs[i].dtype == torch.complex64:
                    complexInputIndices.append(i)
                    torch_input_real = torch_inputs[i].real
                    torch_input_imaginary = torch_inputs[i].imag
                    torch_inputs[i] = torch.stack(
                        (torch_input_real, torch_input_imaginary), dim=-1
                    )

            # Invoke AOTAutograd to translate operators to aten
            if settings.use_aot_joint_export:
                gm = aot_export_joint_simple(
@@ -120,6 +139,11 @@ def _pretraced_backend(

            logger.debug("Lowered Input graph:\n " + str(gm.graph))

            # Rewrite any graph that carries complex-valued nodes to operate
            # on the stacked (real, imag) representation instead
            complex_nodes = find_complex_nodes(gm)
            if complex_nodes:
                replace_complex_placeholder_to_tuple(gm, complexInputIndices)
                modify_complex_nodes(gm, complex_nodes)

            torchtrt_inputs = prepare_inputs(
                torch_inputs, disable_memory_format_check=True
            )
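The repacking above matches the layout produced by torch.view_as_real, so it can be sanity-checked in plain PyTorch without TensorRT. A minimal sketch (the shapes are illustrative, not from the commit):

import torch

# A complex64 input, e.g. the frequency tensor from rotary position embeddings
x = torch.randn(2, 3, dtype=torch.complex64)

# The backend's repacking: real and imaginary parts stacked on a new last dim
stacked = torch.stack((x.real, x.imag), dim=-1)  # float32, shape (2, 3, 2)

# Round-trips back to the original complex tensor
assert torch.equal(torch.view_as_complex(stacked.contiguous()), x)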
7 changes: 1 addition & 6 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -16,6 +16,7 @@
    has_static_shapes_in_args,
)
from torch_tensorrt.dynamo.conversion.converter_utils import (
+    args_bounds_check,
    enforce_tensor_types,
    get_positive_dim,
    is_only_operator_on_placeholder,
@@ -25,12 +26,6 @@
_LOGGER: logging.Logger = logging.getLogger(__name__)


-def args_bounds_check(
-    args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
-) -> Any:
-    return args[i] if len(args) > i and args[i] is not None else replacement
-
-

def get_ir(target: Target) -> SourceIR:
    target_module = getattr(target, "__module__", "None")
    if any(
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -913,3 +913,9 @@ def set_layer_name(
        else f"{source_ir}_ops.{target.__name__}"
    )
    layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"


def args_bounds_check(
    args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
) -> Any:
    return args[i] if len(args) > i and args[i] is not None else replacement
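Moving args_bounds_check into converter_utils lets the new lowering pass import it as well. Its semantics, shown as a small sketch with a made-up argument tuple:

from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check

args = ("x", 1, None)
assert args_bounds_check(args, 1) == 1             # index in range, value not None
assert args_bounds_check(args, 2, 0) == 0          # value is None -> replacement
assert args_bounds_check(args, 5, "pad") == "pad"  # index out of range -> replacement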
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -1,3 +1,5 @@
from ._aten_lowering_pass import *
from ._modify_complex_nodes import modify_complex_nodes
from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple
from .remove_sym_nodes import remove_sym_nodes
from .repair_input_aliasing import repair_input_aliasing
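With these re-exports, the two new passes are reachable at the package level, which is exactly how backends.py imports them above:

from torch_tensorrt.dynamo.lowering import (
    modify_complex_nodes,
    replace_complex_placeholder_to_tuple,
)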
95 changes: 95 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/_modify_complex_nodes.py
@@ -0,0 +1,95 @@
import logging

import torch

from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def tensorrt_complex_mul(args0, args1):
    # Inputs hold (real, imag) in a trailing dimension of size 2: split them
    # apart and apply (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    args0_real, args0_imag = torch.ops.aten.split.Tensor(args0, 1, -1)
    args1_real, args1_imag = torch.ops.aten.split.Tensor(args1, 1, -1)

    args0_real = torch.ops.aten.squeeze.dim(args0_real, -1)
    args0_imag = torch.ops.aten.squeeze.dim(args0_imag, -1)
    args1_real = torch.ops.aten.squeeze.dim(args1_real, -1)
    args1_imag = torch.ops.aten.squeeze.dim(args1_imag, -1)

    complex_mul_real = torch.ops.aten.sub(
        torch.ops.aten.mul(args0_real, args1_real),
        torch.ops.aten.mul(args0_imag, args1_imag),
    )
    complex_mul_imag = torch.ops.aten.add(
        torch.ops.aten.mul(args0_real, args1_imag),
        torch.ops.aten.mul(args0_imag, args1_real),
    )

    return torch.ops.aten.stack((complex_mul_real, complex_mul_imag), -1)


def remove_complex_real_view_nodes(gm: torch.fx.GraphModule):
    modified_graph = False
    nodes_to_remove = []
    for node in gm.graph.nodes:
        if "view_as_complex" in node.name or "view_as_real" in node.name:
            nodes_to_remove.append(node)

    for node in nodes_to_remove:
        input_node = node.args[0] if node.args else None

        # Rewire every consumer of the view node directly to its input
        for other_node in gm.graph.nodes:
            new_args = tuple(
                input_node if arg is node else arg for arg in other_node.args
            )
            other_node.args = new_args

        gm.graph.erase_node(node)
        modified_graph = True

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}"
        )


def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes):
    for node in gm.graph.nodes:
        if node in complex_nodes:
            # slice and transpose nodes remain unchanged
            if "reshape" in node.name:
                # Append the trailing (real, imag) dimension to the target shape
                new_shape = list(node.args[1]) + [2]
                node.args = (node.args[0], tuple(new_shape))


def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes):
    modified_graph = False
    for node in gm.graph.nodes:
        if node in complex_nodes:
            if "mul" in node.name:
                complex_mul_args = (node.args[0], node.args[1])
                with gm.graph.inserting_after(node):
                    replacement_node = gm.graph.create_node(
                        op="call_function",
                        target=tensorrt_complex_mul,
                        args=complex_mul_args,
                    )
                node.replace_all_uses_with(replacement_node)
                replacement_node.meta.update(node.meta)
                modified_graph = True
                gm.graph.erase_node(node)

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after replacing complex multiplications with tensorrt_complex_mul:\n{gm.graph}"
        )


def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes):
    modify_reshape_nodes(gm, complex_nodes)
    remove_complex_real_view_nodes(gm)
    modify_mul_nodes(gm, complex_nodes)
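The decomposition in tensorrt_complex_mul should agree with PyTorch's native complex multiply once both operands are in the stacked layout. A quick verification sketch, assuming the module is importable as committed:

import torch
from torch_tensorrt.dynamo.lowering.passes._modify_complex_nodes import (
    tensorrt_complex_mul,
)

a = torch.randn(4, 8, dtype=torch.complex64)
b = torch.randn(4, 8, dtype=torch.complex64)

# Repack into the stacked (real, imag) layout used by the lowering passes
a_stacked = torch.stack((a.real, a.imag), dim=-1)
b_stacked = torch.stack((b.real, b.imag), dim=-1)

result = tensorrt_complex_mul(a_stacked, b_stacked)
expected = torch.view_as_real(a * b)  # native complex multiply, same layout
assert torch.allclose(result, expected, atol=1e-5)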
112 changes: 112 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py
@@ -0,0 +1,112 @@
import logging

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.node import _get_qualified_name
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check

# clean_up_graph_after_modifications runs dead-code elimination, linting,
# and recompilation on the graph, in place
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)

# For now this pass rewrites the placeholder's metadata; ideally the input
# itself would be modified upstream.


def replace_complex_placeholder_to_tuple(
    gm: torch.fx.GraphModule, inputListindices
) -> torch.fx.GraphModule:
    modified_graph = False
    input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices]
    for node in gm.graph.nodes:
        if node.op == "placeholder" and node.target in input_arg_list:
            node_shape = node.meta["val"].size()
            # The complex tensor is represented as a real tensor with an
            # extra trailing dimension of size 2 holding (real, imag)
            new_node_shape = node_shape + (2,)
            if node.meta["val"].dtype == torch.complex64:
                new_node_dtype = torch.float32
            else:
                new_node_dtype = torch.float64

            real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype)
            with FakeTensorMode() as fake_mode:
                new_placeholder_tuple = fake_mode.from_tensor(real_tensor)
            node.meta["val"] = new_placeholder_tuple
            modified_graph = True
            # Propagate the metadata change to the downstream ops
            # TODO: check whether this is required in all cases
            propagate_shape_change(gm, node, fake_mode)

    # If the graph was modified, clean it up
    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after replacing complex placeholders with stacked real/imag placeholders:\n{gm.graph}"
        )

    return gm


def infer_slice_shape(node):
    input_shape = node.args[0].meta["val"].shape
    slice_args = node.args
    dim = slice_args[1]
    start = slice_args[2]
    end = slice_args[3]
    step = args_bounds_check(slice_args, 4, replacement=1)
    new_shape = list(input_shape)
    new_shape[dim] = (end - start + step - 1) // step
    return tuple(new_shape)


def infer_reshape_shape(node):
    return node.args[1]


shape_inference_funcs = {
    "torch.ops.aten.slice.Tensor": infer_slice_shape,
    "torch.ops.aten.reshape.default": infer_reshape_shape,
}


def propagate_shape_change(gm, start_node, fake_mode):
    # Walk the users of start_node depth-first, refreshing their shape metadata
    visited_nodes = set()
    stack = [start_node]
    while stack:
        node = stack.pop()
        if node in visited_nodes:
            continue
        visited_nodes.add(node)
        update_node_meta(node, fake_mode)
        for user in node.users:
            if (
                user.op == "call_function"
                and _get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor"
            ):
                continue
            stack.append(user)


def update_node_meta(node, fake_mode):
    op_name = node.name
    op_target = node.target

    if node.op == "call_function":
        op_target = _get_qualified_name(node.target)

    if op_target in shape_inference_funcs:
        new_shape = shape_inference_funcs[op_target](node)
        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
        node.meta["val"] = fake_mode.from_tensor(real_tensor)
    else:
        logger.debug(f"No shape inference function registered for {op_name}")
39 changes: 39 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -780,3 +780,42 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
            f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
        )
    return output_dtypes


def find_complex_nodes(gm: torch.fx.GraphModule):
    complex_nodes = []
    complexNodes = {}
    for node in gm.graph.nodes:
        if is_node_complex(node, complexNodes):
            complex_nodes.append(node)
    return complex_nodes


def is_node_complex(node: torch.fx.Node, complexNodes):
    if not isinstance(node, torch.fx.Node):
        return False
    if node.name in complexNodes:
        return True
    if node.op == "call_function" and node.args is not None:
        for arg in node.args:
            if isinstance(arg, int):
                continue
            elif isinstance(arg, (list, tuple)):
                for eachNode in arg:
                    if is_node_complex(eachNode, complexNodes):
                        complexNodes[node.name] = True
                        return True

            elif hasattr(arg, "meta") and "val" in arg.meta:
                if isinstance(arg.meta["val"], (list, tuple)):
                    for eachFakeTensorMeta in arg.meta["val"]:
                        if eachFakeTensorMeta.dtype in (
                            torch.complex64,
                            torch.complex128,
                        ):
                            complexNodes[node.name] = True
                            return True
                elif arg.meta["val"].dtype in (torch.complex64, torch.complex128):
                    complexNodes[node.name] = True
                    return True
    return False
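A short sketch of the new helper in action; make_fx with fake tracing populates node.meta["val"], which is what is_node_complex inspects (the traced function is illustrative):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch_tensorrt.dynamo.utils import find_complex_nodes

def rotate(x, freqs):
    return x * freqs  # complex elementwise multiply

gm = make_fx(rotate, tracing_mode="fake")(
    torch.randn(2, 4, dtype=torch.complex64),
    torch.randn(2, 4, dtype=torch.complex64),
)
# The mul node consumes complex-valued placeholders, so it is reported
print(find_complex_nodes(gm))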
