Commit ef1e2a3

re-apply
1 parent a517d15 commit ef1e2a3

5 files changed (+7, -12 lines)

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,6 @@
 from tilelang.profiler import do_bench
 import tilelang.language as T
 import argparse
-from typing import Optional


 def get_bwd_configs():
@@ -405,7 +404,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
+                sliding_window: int | None = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     key = key.transpose(1, 2).contiguous()
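Every file in this commit makes the same substitution: PEP 604 union syntax (`int | None`) replaces `typing.Optional[int]`, which in turn lets the now-unused `Optional` import be dropped. The two spellings name the same type, as this minimal sketch shows (hypothetical function names, not from the commit):

from typing import Optional

# Both annotations describe "an int or None"; on Python 3.10+ the two
# forms compare equal, so the rewrite is purely cosmetic.
def old_style(sliding_window: Optional[int] = None) -> bool:
    return sliding_window is not None

def new_style(sliding_window: int | None = None) -> bool:
    return sliding_window is not None

assert (int | None) == Optional[int]
assert old_style(128) == new_style(128)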

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,6 @@
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
-from typing import Optional


 def get_configs():
@@ -213,7 +212,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
+                sliding_window: int | None = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     key = key.transpose(1, 2).contiguous()

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,6 @@
 from tilelang.profiler import do_bench
 import tilelang.language as T
 import argparse
-from typing import Optional


 def get_bwd_configs():
@@ -401,7 +400,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
+                sliding_window: int | None = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     query = query.transpose(1, 2).contiguous().unsqueeze(

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 3 additions & 4 deletions
@@ -9,7 +9,6 @@
 from tilelang.layout import make_swizzled_layout
 import itertools
 import argparse
-from typing import Optional


 def get_configs():
@@ -193,7 +192,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
+                sliding_window: int | None = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     query = query.transpose(1, 2).contiguous().unsqueeze(
@@ -306,8 +305,8 @@ def main(batch: int = 1,

     latency = do_bench(
         lambda: ref_program(Q, K, V, sinks, window_size, dtype=torch_dtype), warmup=500)
-    print("Ref: {:.2f} ms".format(latency))
-    print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    print(f"Ref: {latency:.2f} ms")
+    print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
     latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500)
     print(f"Tilelang: {latency:.2f} ms")
     print(f"Tilelang: {total_flops / latency * 1e-9:.2f} TFlops")

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,6 @@
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
-from typing import Optional


 def get_configs():
@@ -206,7 +205,7 @@ def ref_program(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 sinks: torch.Tensor,
-                sliding_window: Optional[int] = None,
+                sliding_window: int | None = None,
                 dtype: torch.dtype = torch.float16) -> torch.Tensor:

     query = query.transpose(1, 2).contiguous().unsqueeze(
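One side effect of the new syntax: a bare `int | None` annotation is evaluated when the function is defined, so after this commit the examples assume Python 3.10 or newer. If older interpreters mattered, a common workaround (not applied here) is postponed evaluation, sketched below with a trimmed-down, hypothetical signature:

# Hypothetical compatibility sketch, not part of this commit.
# With PEP 563 postponed evaluation, annotations are stored as strings
# and never evaluated, so `int | None` parses fine on Python 3.7+.
from __future__ import annotations

def ref_program(sliding_window: int | None = None) -> None:
    ...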
