Dev attention[SiliconFlow] #236

Open · wants to merge 12 commits into master
48 changes: 48 additions & 0 deletions benchmark/test_attention_perf.py
@@ -0,0 +1,48 @@
from typing import Generator

import torch

from .performance_utils import Benchmark


class AttentionBenchmark(Benchmark):
    """
    benchmark for attention
    """

    def __init__(self, *args, input_fn, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_fn = input_fn

    def get_input_iter(self, cur_dtype) -> Generator:
        for seq_len in [1024, 2048, 3072, 4096]:
            yield from self.input_fn(cur_dtype, seq_len)


def test_perf_scaled_dot_product_attention():
    def scaled_dot_product_attention_kwargs(dtype, seq_len):
        num_heads = 8
        head_size = 128
        batch = 4

        query = torch.randn(
            (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
        )
        key = torch.randn(
            (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
        )
        value = torch.randn(
            (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
        )
        yield query, key, value, None, 0.0, True

    bench = AttentionBenchmark(
        op_name="scaled_dot_product_attention",
        input_fn=scaled_dot_product_attention_kwargs,
        torch_op=torch.nn.functional.scaled_dot_product_attention,
        dtypes=[
            # torch.float32,
            torch.float16,
        ],
    )
    bench.run()
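
Note (not part of the diff): the tuple yielded by scaled_dot_product_attention_kwargs maps positionally onto torch.nn.functional.scaled_dot_product_attention. A minimal sketch of the equivalent direct call, with illustrative shapes matching the benchmark, would be:

import torch
import torch.nn.functional as F

# Shapes mirror the benchmark inputs: (batch, num_heads, seq_len, head_size).
query = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)
key = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)
value = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)

# The yielded tuple (query, key, value, None, 0.0, True) corresponds to
# attn_mask=None, dropout_p=0.0, is_causal=True.
out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
)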
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -159,6 +159,7 @@ def enable(lib=aten_lib):
lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")
lib.impl("vstack", vstack, "CUDA")
lib.impl("repeat_interleave.Tensor", repeat_interleave_tensor, "CUDA")
lib.impl("scaled_dot_product_attention", scaled_dot_product_attention, "CUDA")
lib.impl("repeat_interleave.self_Tensor", repeat_interleave_self_tensor, "CUDA")


2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -6,6 +6,7 @@
from .any import any, any_dim, any_dims
from .arange import arange, arange_start
from .argmax import argmax
from .attention import scaled_dot_product_attention
from .bitwise_and import (
bitwise_and_scalar,
bitwise_and_scalar_tensor,
@@ -247,5 +248,6 @@
"repeat_interleave_self_int",
"vstack",
"repeat_interleave_tensor",
"scaled_dot_product_attention",
"repeat_interleave_self_tensor",
]
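
Note (not part of the diff): with the export and registration above in place, the intended usage is to enable flag_gems and then call the stock PyTorch API. A brief sketch, assuming the enable() entry point shown in src/flag_gems/__init__.py:

import torch
import flag_gems

flag_gems.enable()  # installs the CUDA overrides, including scaled_dot_product_attention

q = torch.randn(4, 8, 1024, 128, device="cuda", dtype=torch.float16)
out = torch.nn.functional.scaled_dot_product_attention(
    q, torch.randn_like(q), torch.randn_like(q), is_causal=True
)
# `out` is now produced by the flag_gems attention implementation.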