Skip to content

Commit 96f807e

Browse files
committed
typo
1 parent 1d1f48e commit 96f807e

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,6 @@ def maybe_contiguous(x):
368368
dq = torch.zeros(q_shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
369369
dk = torch.zeros(kv_shape, dtype=torch.float16, device=q.device)
370370
dv = torch.zeros(kv_shape, dtype=torch.float16, device=q.device)
371-
dsinks = torch.empty([BATCH, H], dtype=torch.float32, device=q.device)
372371
kernel(q, k, v, do, lse, delta, dq, dk, dv)
373372
dq = kernel_post(dq)
374373

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,6 @@ def maybe_contiguous(x):
374374
dq = torch.zeros(shape, dtype=torch.float32, device=q.device) # acc for atomicAdd
375375
dk = torch.empty(shape, dtype=torch.float16, device=q.device)
376376
dv = torch.empty(shape, dtype=torch.float16, device=q.device)
377-
dsinks = torch.empty([BATCH, H], dtype=torch.float32, device=q.device)
378377
kernel(q, k, v, do, lse, delta, dq, dk, dv)
379378
dq = kernel_post(dq)
380379

@@ -395,7 +394,7 @@ def ref_program(query: torch.Tensor,
395394
sliding_window: int | None = None) -> torch.Tensor:
396395

397396
query = query.transpose(1, 2).contiguous().unsqueeze(
398-
3) # align with the original function'sinterface
397+
3) # align with the original function's interface
399398
key = key.transpose(1, 2).contiguous()
400399
value = value.transpose(1, 2).contiguous()
401400

0 commit comments

Comments
 (0)