|
5 | 5 | import torch
|
6 | 6 | from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
|
7 | 7 | clean_up_graph_after_modifications,
|
8 |
| - get_tensor_placeholders, |
9 | 8 | )
|
10 | 9 |
|
11 | 10 | logger = logging.getLogger(__name__)
|
@@ -36,34 +35,13 @@ def efficient_attention_replacement() -> (
|
36 | 35 | ):
|
37 | 36 | """Constructs the original and replacement functions for efficient attention"""
|
38 | 37 |
|
39 |
| - # Empty boilerplate function taking in three Tensors and returning one |
40 |
| - def boilerplate( |
41 |
| - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
42 |
| - ) -> torch.Tensor: |
43 |
| - ... |
44 |
| - |
45 |
| - # Trace boilerplate function and extract placeholder and output nodes |
46 |
| - orig = torch.fx.symbolic_trace(boilerplate) |
47 |
| - q, k, v = get_tensor_placeholders(orig) |
48 |
| - output = [node for node in orig.graph.nodes if node.op == "output"][0] |
49 |
| - |
50 |
| - # Graph types to replace are those which use the _scaled_dot_product_efficient_attention |
51 |
| - # function and extract only the first element |
52 |
| - with orig.graph.inserting_before(output): |
53 |
| - att = orig.graph.call_function( |
54 |
| - torch.ops.aten._scaled_dot_product_efficient_attention.default, |
55 |
| - args=(q, k, v, None, False), |
| 38 | + # Original graph |
| 39 | + def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: |
| 40 | + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( |
| 41 | + q, k, v, None, False |
56 | 42 | )
|
57 |
| - out = orig.graph.call_function( |
58 |
| - operator.getitem, |
59 |
| - args=(att, 0), |
60 |
| - ) |
61 |
| - |
62 |
| - # Assign the output of the graph to be the single getitem output |
63 |
| - output.args = (out,) |
64 |
| - |
65 |
| - orig.graph.lint() |
66 |
| - orig.recompile() |
| 43 | + out = operator.getitem(outputs, 0) |
| 44 | + return out |
67 | 45 |
|
68 | 46 | # Replacement graph consists of the functional version of scaled_dot_product_attention
|
69 | 47 | def replacement(
|
|
0 commit comments