cherry-pick: Attention converter and linting fixes #2641

Merged · 2 commits · Feb 5, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -47,7 +47,7 @@ repos:
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 24.1.1
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2243,7 +2243,14 @@ def tensorrt_scaled_dot_product_attention(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.attention.scaled_dot_product_attention(
-        ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
+        ctx,
+        target,
+        SourceIR.TORCHTRT_LOWERED,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        kwargs.get("scale", None),
     )


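With this change the converter forwards the optional "scale" keyword argument instead of dropping it. For orientation, a minimal eager-mode sketch of the two behaviors being distinguished (an illustration only, assuming a PyTorch build where torch.nn.functional.scaled_dot_product_attention accepts the scale keyword, i.e. 2.1 or newer):

import torch
import torch.nn.functional as F

q = torch.rand(2, 4, 8, 16)
k = torch.rand(2, 4, 8, 16)
v = torch.rand(2, 4, 8, 16)

# Default behavior: attention scores are divided by sqrt(head_dim).
out_default = F.scaled_dot_product_attention(q, k, v)

# With an explicit scale, the scores are multiplied by that factor instead;
# this is the value the converter now receives via kwargs.get("scale", None).
out_scaled = F.scaled_dot_product_attention(q, k, v, scale=0.15)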
29 changes: 20 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -17,6 +17,7 @@ def scaled_dot_product_attention(
     query: TRTTensor,
     key: TRTTensor,
     value: TRTTensor,
+    scale: Optional[float],
 ) -> TRTTensor:
     mm = impl.matmul.matrix_multiply(
         ctx,
@@ -27,16 +28,26 @@ def scaled_dot_product_attention(
         key,
         other_matrix_op=trt.MatrixOperation.TRANSPOSE,
     )
-    div = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        name + "_scale",
-        mm,
-        math.sqrt(query.shape[-1]),
-    )
+    if scale is None:
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            math.sqrt(query.shape[-1]),
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
     softmax = impl.normalization.softmax(
-        ctx, target, source_ir, name + "_softmax", div, -1
+        ctx, target, source_ir, name + "_softmax", scaled, -1
     )
     out = impl.matmul.matrix_multiply(
         ctx,
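For reference, the branching above corresponds to the following eager-mode computation (a minimal sketch without attention-mask or dropout handling; reference_sdpa is a hypothetical helper, not part of this PR):

import math
from typing import Optional

import torch


def reference_sdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: Optional[float] = None,
) -> torch.Tensor:
    # Raw attention scores: Q @ K^T
    scores = torch.matmul(query, key.transpose(-2, -1))
    if scale is None:
        # Default scaling divides by sqrt(d_k), mirroring impl.elementwise.div above.
        scores = scores / math.sqrt(query.shape[-1])
    else:
        # An explicit scale multiplies the scores, mirroring impl.elementwise.mul above.
        scores = scores * scale
    return torch.matmul(torch.softmax(scores, dim=-1), value)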
@@ -5,8 +5,8 @@

 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
-from .lower_efficient_attention import lower_efficient_attention
 from .lower_linear import lower_linear
+from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -18,7 +18,7 @@
         remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
-        lower_efficient_attention,
+        lower_scaled_dot_product_attention,
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,

This file was deleted.

@@ -0,0 +1,123 @@
import logging
import operator
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)
REPLACEABLE_ATEN_OPS = {
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def lower_scaled_dot_product_attention(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""Replace specific versions of scaled_dot_product_attention with an equivalent
implementation which can be easily converted to TRT
"""
original_fns, replacement = scaled_dot_product_attention_replacement()
replaced_nodes = []

# For each original function, search for it in the graph and replace
for original in original_fns:
replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters(
gm,
original,
replacement,
ignore_literals=True,
)

if replaced_nodes:
# Repair instances which use the kwargs field (specifically the "scale" kwarg)
for match in replaced_nodes:
attention_node_replaced = None
# Seek the attention operator being replaced
for node in match.nodes_map:
if node.target in REPLACEABLE_ATEN_OPS:
attention_node_replaced = match.nodes_map[node]
break

assert attention_node_replaced is not None

# If the attention operator had keyword-args, copy them to the new node
if attention_node_replaced.kwargs:
assert len(match.replacements) == 1
new_attention_node = match.replacements[0]
assert (
new_attention_node.target
== torch.nn.functional.scaled_dot_product_attention
)
new_attention_node.kwargs = {**attention_node_replaced.kwargs}

gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

return gm


def scaled_dot_product_attention_replacement() -> Tuple[
Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Constructs the original and replacement functions for efficient attention"""

# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
q,
k,
v,
None,
False,
)
out = operator.getitem(outputs, 0)
return out

# Flash Attention original graph
def flash(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
)
out = operator.getitem(outputs, 0)
return out

# Efficient Attention w/Scale original graph
def efficient_scale(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
q,
k,
v,
None,
False,
scale=1.0,
)
out = operator.getitem(outputs, 0)
return out

# Flash Attention w/Scale original graph
def flash_scale(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
outputs = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
scale=1.0,
)
out = operator.getitem(outputs, 0)
return out

# Replacement graph consists of the functional version of scaled_dot_product_attention
def replacement(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
return torch.nn.functional.scaled_dot_product_attention(query, key, value)

return (efficient, flash, efficient_scale, flash_scale), replacement
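As a usage sketch, the new pass can be exercised directly on a traced module (assuming it is exported from torch_tensorrt.dynamo.lowering.passes, as the __init__.py change above suggests):

import torch

from torch_tensorrt.dynamo.lowering.passes import lower_scaled_dot_product_attention


class FlashAttention(torch.nn.Module):
    def forward(self, q, k, v):
        out = torch.ops.aten._scaled_dot_product_flash_attention.default(
            q, k, v, scale=0.15
        )
        return out[0]


gm = torch.fx.symbolic_trace(FlashAttention())
sample_inputs = [torch.rand(1, 3, 6, 8) for _ in range(3)]

# After the pass, the graph calls torch.nn.functional.scaled_dot_product_attention,
# with the original "scale" kwarg copied onto the replacement node.
gm = lower_scaled_dot_product_attention(gm, sample_inputs)
print(gm.graph)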
117 changes: 117 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,6 +267,123 @@ def forward(self, q, k, v):
torch._dynamo.reset()


class TestLowerFlashAttention(TestCase):
def test_lower_flash_attention(self):
class FlashAttention(torch.nn.Module):
def forward(self, q, k, v):
attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
scale=0.15,
)
return attn[0]

inputs = [
torch.rand(8, 4, 16, 8).half().cuda(),
torch.rand(8, 4, 16, 8).half().cuda(),
torch.rand(8, 4, 16, 8).half().cuda(),
]

fx_graph = torch.fx.symbolic_trace(FlashAttention())
expected_ops = {torch.nn.functional.scaled_dot_product_attention}
unexpected_ops = {torch.ops.aten._scaled_dot_product_flash_attention.default}

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)
torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = torch.cat(
[tensor.detach().cpu() for tensor in optimized_model(*inputs)]
)
torch_model_results = torch.cat(
[tensor.detach().cpu() for tensor in fx_graph(*inputs)]
)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
# Remove 1 decimal from the requirement for FP16
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT - 1,
msg=f"FlashAttention TRT outputs don't match with the original model.",
)
torch._dynamo.reset()

def test_flash_attention_converter(self):
class FlashAttention(torch.nn.Module):
def forward(self, q, k, v):
attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
q,
k,
v,
scale=0.25,
)
return attn[0]

inputs = [
torch.rand(1, 3, 6, 8).half().cuda(),
torch.rand(1, 3, 2, 8).half().cuda(),
torch.rand(1, 3, 2, 8).half().cuda(),
]

fx_graph = torch.fx.symbolic_trace(FlashAttention())

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = torch.cat(
[tensor.detach().cpu() for tensor in optimized_model(*inputs)]
)
torch_model_results = torch.cat(
[tensor.detach().cpu() for tensor in fx_graph(*inputs)]
)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
# Remove 1 decimal from the requirement for FP16
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT - 1,
msg=f"FlashAttention TRT outputs don't match with the original model.",
)
torch._dynamo.reset()


class TestLowerLinear(TestCase):
def test_lower_linear(self):
class Linear(torch.nn.Module):