-
Notifications
You must be signed in to change notification settings - Fork 355
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
282 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from ._aten_lowering_pass import * | ||
from ._modify_complex_nodes import modify_complex_nodes | ||
from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple | ||
from .remove_sym_nodes import remove_sym_nodes | ||
from .repair_input_aliasing import repair_input_aliasing |
95 changes: 95 additions & 0 deletions
95
py/torch_tensorrt/dynamo/lowering/passes/_modify_complex_nodes.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import logging | ||
|
||
import torch | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( | ||
clean_up_graph_after_modifications, | ||
) | ||
|
||
|
||
def tensorrt_complex_mul(args0, args1):
    """Complex multiplication over real tensors packed as (..., 2).

    Both inputs are expected to carry (real, imag) pairs in a trailing
    dimension of size 2. Computes (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
    and returns the result in the same packed layout. Only aten ops are used
    so the expression stays representable as FX graph nodes.
    """

    def _unpack(packed):
        # Split the trailing size-2 dim into two size-1 slices, then drop it.
        re_part, im_part = torch.ops.aten.split.Tensor(packed, 1, -1)
        return (
            torch.ops.aten.squeeze.dim(re_part, -1),
            torch.ops.aten.squeeze.dim(im_part, -1),
        )

    a_re, a_im = _unpack(args0)
    b_re, b_im = _unpack(args1)

    # Real part: ac - bd
    out_re = torch.ops.aten.sub(
        torch.ops.aten.mul(a_re, b_re),
        torch.ops.aten.mul(a_im, b_im),
    )
    # Imaginary part: ad + bc
    out_im = torch.ops.aten.add(
        torch.ops.aten.mul(a_re, b_im),
        torch.ops.aten.mul(a_im, b_re),
    )

    # Re-pack real/imag along a new trailing dimension.
    return torch.ops.aten.stack((out_re, out_im), -1)
|
||
|
||
def remove_complex_real_view_nodes(gm: torch.fx.GraphModule):
    """Remove every ``view_as_complex`` / ``view_as_real`` node from ``gm``.

    Each removed node's consumers are rewired to its input, effectively
    treating the complex<->real views as identity once complex tensors are
    represented as real tensors with a trailing (real, imag) dimension.

    Args:
        gm: Graph module to modify in place.
    """
    nodes_to_remove = [
        node
        for node in gm.graph.nodes
        # Matching on node *name* follows the convention used by the other
        # passes in this file.
        if "view_as_complex" in node.name or "view_as_real" in node.name
    ]

    modified_graph = False
    for node in nodes_to_remove:
        if not node.args:
            # A view_as_* node always carries its input as args[0]; if it is
            # somehow missing we cannot rewire consumers, so leave it alone
            # rather than rewiring uses to None as the previous revision did.
            continue
        # replace_all_uses_with rewires positional args, kwargs and nested
        # containers alike — a manual scan over other_node.args misses
        # kwargs and nested structures.
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
        modified_graph = True

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}"
        )
|
||
|
||
def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Append a trailing dimension of size 2 to each complex reshape target.

    Once complex tensors are packed as real tensors of shape (..., 2), any
    reshape in the complex subgraph needs its target shape extended by that
    extra dimension. Slice and transpose nodes keep their arguments unchanged.
    """
    for node in gm.graph.nodes:
        if node not in complex_nodes or "reshape" not in node.name:
            continue
        src, target_shape = node.args[0], node.args[1]
        # Grow the target shape by the packed (real, imag) dimension.
        node.args = (src, tuple(target_shape) + (2,))
|
||
|
||
def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Replace complex ``mul`` nodes with calls to ``tensorrt_complex_mul``.

    Every multiplication belonging to the complex subgraph is swapped for a
    call_function node targeting ``tensorrt_complex_mul``, which spells out
    the complex product using real-valued aten ops.
    """
    graph_changed = False
    # Iterate over a snapshot since nodes are erased during the walk.
    for node in list(gm.graph.nodes):
        if node not in complex_nodes or "mul" not in node.name:
            continue
        with gm.graph.inserting_after(node):
            replacement = gm.graph.create_node(
                op="call_function",
                target=tensorrt_complex_mul,
                args=node.args[:2],
            )
            node.replace_all_uses_with(replacement)
            replacement.meta.update(node.meta)
            graph_changed = True
        gm.graph.erase_node(node)

    if graph_changed:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}"
        )
|
||
|
||
def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Lower the complex-valued nodes in ``complex_nodes`` to real-tensor ops.

    Runs three passes in sequence: reshape targets are widened with a trailing
    (real, imag) dimension, the view_as_complex/view_as_real bridge nodes are
    removed, and complex ``mul`` nodes are rewritten in terms of real
    arithmetic via ``tensorrt_complex_mul``.

    Args:
        gm: Graph module modified in place.
        complex_nodes: Nodes identified as operating on complex values.
    """
    modify_reshape_nodes(gm, complex_nodes)
    remove_complex_real_view_nodes(gm)
    modify_mul_nodes(gm, complex_nodes)
112 changes: 112 additions & 0 deletions
112
py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import logging | ||
|
||
import torch | ||
from torch.fx.node import _get_qualified_name | ||
from torch_tensorrt.dynamo._settings import CompilationSettings | ||
from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check | ||
|
||
# dead-code elimination, linting, and recompilation for graph, in-place | ||
from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( | ||
clean_up_graph_after_modifications, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# NOTE: for now this pass rewrites the placeholder node's metadata in place;
# ideally the change would be made where the input itself is specified.
|
||
|
||
def replace_complex_placeholder_to_tuple(
    gm: torch.fx.GraphModule, inputListindices
) -> torch.fx.GraphModule:
    """Rewrite complex placeholder metadata as real tensors with a (..., 2) dim.

    For every placeholder whose positional index appears in
    ``inputListindices``, the fake tensor in ``node.meta["val"]`` is replaced
    by a real-valued one of the same shape plus a trailing size-2 dimension
    holding (real, imag). The shape change is then propagated to downstream
    nodes.

    Args:
        gm: Graph module rewritten in place.
        inputListindices: Indices of the complex inputs; placeholders are
            matched by the dynamo naming convention ``arg<i>_1``.

    Returns:
        The (possibly cleaned-up) graph module.
    """
    # Imported here once, instead of inside the loop as before; FakeTensor
    # itself was imported but never used.
    from torch._subclasses.fake_tensor import FakeTensorMode

    modified_graph = False
    input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices]
    # One shared mode is sufficient; the previous revision allocated two fresh
    # FakeTensorMode objects per matching placeholder.
    fake_mode = FakeTensorMode()

    for node in gm.graph.nodes:
        if node.op != "placeholder" or node.target not in input_arg_list:
            continue

        node_shape = node.meta["val"].size()
        new_node_shape = node_shape + (2,)
        # complex64 unpacks to float32 pairs; anything else (presumably
        # complex128) to float64.
        if node.meta["val"].dtype == torch.complex64:
            new_node_dtype = torch.float32
        else:
            new_node_dtype = torch.float64

        real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype)
        with fake_mode:
            node.meta["val"] = fake_mode.from_tensor(real_tensor)
        modified_graph = True
        # Propagate the metadata change to the downstream ops.
        # TODO: check whether this is required in all cases.
        propogate_shape_change(gm, node, fake_mode)

    # If the graph was modified, clean it up.
    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        # Previous revision logged a message copy-pasted from the
        # wait_tensor/distributed fusion pass.
        logger.debug(
            f"Graph after replacing complex placeholders with real tuples:\n{gm.graph}"
        )

    return gm
|
||
|
||
def infer_slice_shape(node):
    """Infer the output shape of an ``aten.slice.Tensor`` node.

    Mirrors aten slice semantics on the sliced dimension: ``None`` bounds mean
    the full range, negative bounds count from the end, and both bounds are
    clamped to the dimension size (aten.slice commonly receives INT64_MAX as
    the open-ended ``end``, which previously produced an absurd size).

    Args:
        node: FX node whose args are (input, dim, start, end[, step]).

    Returns:
        Tuple of ints: the inferred output shape.
    """
    input_shape = node.args[0].meta["val"].shape
    dim = node.args[1]
    start = node.args[2]
    end = node.args[3]
    # Optional step defaults to 1 (also when passed explicitly as None) —
    # inline equivalent of converter_utils.args_bounds_check.
    step = node.args[4] if len(node.args) > 4 and node.args[4] is not None else 1

    size = input_shape[dim]
    # Normalize None / negative bounds, then clamp into [0, size].
    start = 0 if start is None else (start + size if start < 0 else start)
    end = size if end is None else (end + size if end < 0 else end)
    start = max(0, min(start, size))
    end = max(start, min(end, size))

    new_shape = list(input_shape)
    # Ceiling division: number of kept indices with the given step.
    new_shape[dim] = (end - start + step - 1) // step
    return tuple(new_shape)
|
||
|
||
def infer_reshape_shape(node):
    """Return the target shape of an ``aten.reshape`` node (its second arg)."""
    _input, target_shape = node.args[0], node.args[1]
    return target_shape
|
||
|
||
# Dispatch table: qualified aten op name -> shape inference helper.
# Defined exactly once (the dict was previously duplicated verbatim).
shape_inference_funcs = {
    "torch.ops.aten.slice.Tensor": infer_slice_shape,
    "torch.ops.aten.reshape.default": infer_reshape_shape,
}
|
||
|
||
def propogate_shape_change(node, start_node, fake_mode):
    """Propagate a metadata shape change from ``start_node`` through its users.

    Performs an iterative DFS from ``start_node``, refreshing each visited
    node's ``meta["val"]`` via ``update_node_meta``. Traversal does not
    continue past ``aten.mul.Tensor`` users — presumably because complex muls
    are rewritten by a separate pass (confirm against the mul-lowering pass).

    NOTE(review): the first parameter is never read before being shadowed; the
    visible caller passes the GraphModule here, but only ``start_node`` and
    ``fake_mode`` are actually used. The name is kept for interface stability.
    """
    seen = set()
    pending = [start_node]
    while pending:
        node = pending.pop()
        if node in seen:
            continue
        seen.add(node)
        update_node_meta(node, fake_mode)
        for consumer in node.users:
            is_complex_mul = (
                consumer.op == "call_function"
                and _get_qualified_name(consumer.target) == "torch.ops.aten.mul.Tensor"
            )
            if not is_complex_mul:
                pending.append(consumer)
|
||
|
||
def update_node_meta(node, fake_mode):
    """Recompute ``node.meta["val"]`` using the registered shape-inference func.

    Looks up the node's target (qualified name for call_function nodes, e.g.
    ``"torch.ops.aten.slice.Tensor"``) in ``shape_inference_funcs`` and, when
    found, replaces the stored fake tensor with one of the inferred shape and
    the original dtype. Nodes without a registered function are only logged.
    """
    op_name = node.name
    op_target = node.target

    if node.op == "call_function":
        # call_function targets are callables; the dispatch table is keyed by
        # their qualified names.
        op_target = _get_qualified_name(node.target)

    if op_target in shape_inference_funcs:
        new_shape = shape_inference_funcs[op_target](node)
        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
        node.meta["val"] = fake_mode.from_tensor(real_tensor)
    else:
        # Previous revision printed a *set literal* to stdout:
        #   print("No shape for the inference function", {op_name})
        # Use the module logger with lazy %-formatting instead.
        logger.debug("No shape inference function registered for node %s", op_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters