Commit 1d5b9d0 (parent 30aaa8c)

update getting orig graph

py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py

Lines changed: 6 additions & 28 deletions
@@ -5,7 +5,6 @@
 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
-    get_tensor_placeholders,
 )

 logger = logging.getLogger(__name__)
@@ -36,34 +35,13 @@ def efficient_attention_replacement() -> (
 ):
     """Constructs the original and replacement functions for efficient attention"""

-    # Empty boilerplate function taking in three Tensors and returning one
-    def boilerplate(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        ...
-
-    # Trace boilerplate function and extract placeholder and output nodes
-    orig = torch.fx.symbolic_trace(boilerplate)
-    q, k, v = get_tensor_placeholders(orig)
-    output = [node for node in orig.graph.nodes if node.op == "output"][0]
-
-    # Graph types to replace are those which use the _scaled_dot_product_efficient_attention
-    # function and extract only the first element
-    with orig.graph.inserting_before(output):
-        att = orig.graph.call_function(
-            torch.ops.aten._scaled_dot_product_efficient_attention.default,
-            args=(q, k, v, None, False),
+    # Original graph
+    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q, k, v, None, False
         )
-        out = orig.graph.call_function(
-            operator.getitem,
-            args=(att, 0),
-        )
-
-    # Assign the output of the graph to be the single getitem output
-    output.args = (out,)
-
-    orig.graph.lint()
-    orig.recompile()
+        out = operator.getitem(outputs, 0)
+        return out

     # Replacement graph consists of the functional version of scaled_dot_product_attention
     def replacement(
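
For context, pattern/replacement pairs like orig and replacement are typically handed to torch.fx's subgraph rewriter, which traces both functions, matches the pattern graph inside a GraphModule, and splices in the replacement. That is why this commit can drop the hand-built graph (placeholder extraction plus call_function nodes inserted into a traced boilerplate): symbolically tracing a plain orig function yields the same pattern graph. The sketch below is illustrative only; the Attn module and the call site are assumptions for demonstration, not part of this commit.

    # Minimal, self-contained sketch (assumed usage, not from this commit) of how
    # such a pattern/replacement pair is consumed by torch.fx.subgraph_rewriter.
    import operator

    import torch
    import torch.nn.functional as F
    from torch.fx import subgraph_rewriter, symbolic_trace

    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Pattern: the efficient-attention op, keeping only its first output
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
            q, k, v, None, False
        )
        return operator.getitem(outputs, 0)

    def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Replacement: the functional scaled_dot_product_attention
        return F.scaled_dot_product_attention(q, k, v)

    class Attn(torch.nn.Module):  # hypothetical module containing the pattern
        def forward(self, q, k, v):
            out = torch.ops.aten._scaled_dot_product_efficient_attention.default(
                q, k, v, None, False
            )
            return out[0]

    gm = symbolic_trace(Attn())
    # replace_pattern traces orig, finds it in gm's graph, and swaps in replacement
    matches = subgraph_rewriter.replace_pattern(gm, orig, replacement)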
