diff --git a/challenges/hard/59_sliding_window_attn/challenge.html b/challenges/hard/59_sliding_window_attn/challenge.html
new file mode 100644
index 0000000..05c2fd0
--- /dev/null
+++ b/challenges/hard/59_sliding_window_attn/challenge.html
@@ -0,0 +1,148 @@
+

+ Implement Sliding Window Self-Attention for a given set of matrices.
+ Before introducing the sliding window version, let's first recall standard Self-Attention.
+

+ +

1. Standard Softmax Attention

+

+ Given query matrix Q, key matrix K, and value matrix V, each position i attends to all positions j using a softmax-weighted sum: +

+ +

+ \( \text{score}_{i,j} = \frac{Q_i \cdot K_j}{\sqrt{d}} \) +

+ +

+ \( \text{output}_i = \sum_{j=1}^{M} \text{softmax}(\text{score}_{i,*})_j \cdot V_j \) +

+ +

+ In other words, each query computes similarity with all keys, applies a softmax to get attention weights, and then computes a weighted sum of values. +
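As a concrete reference point, the two formulas above translate directly into PyTorch. The following is a minimal sketch only; the function and variable names are illustrative and are not part of the challenge API.

import torch

def standard_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # score[i, j] = (Q_i . K_j) / sqrt(d)
    d = Q.shape[1]
    scores = (Q @ K.T) / (d ** 0.5)
    # softmax over all key positions j, then a weighted sum of the value rows
    attn = torch.softmax(scores, dim=1)
    return attn @ V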

+ +

2. Sliding Window Self-Attention

+

+ Sliding Window Attention modifies standard attention by restricting each query to attend only to a local window around its position. +

+ + + +

+ \( \text{score}_{i,j} = \frac{Q_i \cdot K_j}{\sqrt{d}} \) +

+ + + +

+ \( \text{output}_i = \sum_{j \in [i-\text{window_size}, \, i+\text{window_size}]} \text{softmax}(\text{score}_{i,*})_j \cdot V_j \) +

+ +

+ In short, each query only attends to its nearby neighbors. +
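Concretely, the only change from the standard version is a band mask applied to the score matrix before the softmax. A minimal PyTorch sketch (mirroring the reference implementation in challenge.py below; names are illustrative):

import torch

def sliding_window_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, window_size: int) -> torch.Tensor:
    M, d = Q.shape
    scores = (Q @ K.T) / (d ** 0.5)
    # positions with |i - j| > window_size are excluded from the softmax
    idxs = torch.arange(M, device=Q.device)
    outside = (idxs[None, :] - idxs[:, None]).abs() > window_size
    scores = scores.masked_fill(outside, float('-inf'))
    attn = torch.softmax(scores, dim=1)
    return attn @ V

With window_size >= M - 1 the band covers every pair of positions, and this reduces to the standard attention above.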

+ + +

Implementation Requirements

+ +

Example 1:

+

+Input:
+Q (2×4): +\[ +\begin{bmatrix} +1.0 & 0.0 & 0.0 & 0.0 \\ +0.0 & 1.0 & 0.0 & 0.0 +\end{bmatrix} +\] +K (2×4): +\[ +\begin{bmatrix} +1.0 & 0.0 & 0.0 & 0.0 \\ +0.0 & 1.0 & 0.0 & 0.0 +\end{bmatrix} +\] +V (2×4): +\[ +\begin{bmatrix} +1.0 & 2.0 & 3.0 & 4.0 \\ +5.0 & 6.0 & 7.0 & 8.0 +\end{bmatrix} +\] +window_size: 1 +

+ +

+Output:
+output (2×4): +\[ +\begin{bmatrix} +2.5101628 & 3.5101628 & 4.510163 & 5.510163 \\ +3.4898374 & 4.4898376 & 5.4898376 & 6.489837 +\end{bmatrix} +\] +
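To see where these numbers come from, note that with M = 2 and window_size = 1 every position is inside the window, so this example reduces to full attention. The first row works out as:
\[
\text{score}_{1,*} = \left[\tfrac{1}{\sqrt{4}},\; 0\right] = [0.5,\; 0], \qquad \text{softmax}([0.5,\; 0]) \approx [0.6225,\; 0.3775]
\]
\[
\text{output}_1 \approx 0.6225 \cdot [1, 2, 3, 4] + 0.3775 \cdot [5, 6, 7, 8] \approx [2.510,\; 3.510,\; 4.510,\; 5.510]
\]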

+ + +

Example 2:

+

+ Input:
+ Q (2×3): + \[ + \begin{bmatrix} + 0.0 & 0.0 & 0.0 \\ + 0.0 & 1.0 & 0.0 + \end{bmatrix} + \] + K (2×3): + \[ + \begin{bmatrix} + 1.0 & 0.0 & 0.0 \\ + 0.0 & 1.0 & 0.0 + \end{bmatrix} + \] + V (2×3): + \[ + \begin{bmatrix} + 1.0 & 2.0 & 3.0 \\ + 5.0 & 6.0 & 7.0 + \end{bmatrix} + \] + window_size: 1 +

+ +

+ Output:
+ output (2×3): + \[ + \begin{bmatrix} + 3.0 & 4.0 & 5.0 \\ + 3.5618298 & 4.56183 & 5.5618296 + \end{bmatrix} + \] +
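As a quick check: the first query row is all zeros, so both scores are 0 and the softmax weights are uniform, giving \( \text{output}_1 = \tfrac{1}{2}\left([1, 2, 3] + [5, 6, 7]\right) = [3, 4, 5] \), which matches the first output row.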

+ + + +

Constraints

+
\ No newline at end of file
diff --git a/challenges/hard/59_sliding_window_attn/challenge.py b/challenges/hard/59_sliding_window_attn/challenge.py
new file mode 100644
index 0000000..ff79956
--- /dev/null
+++ b/challenges/hard/59_sliding_window_attn/challenge.py
@@ -0,0 +1,107 @@
+import ctypes
+from typing import Any, List, Dict
+import torch
+from core.challenge_base import ChallengeBase
+
+class Challenge(ChallengeBase):
+    def __init__(self):
+        super().__init__(
+            name="Sliding Window Self-Attention",
+            atol=1e-05,
+            rtol=1e-05,
+            num_gpus=1,
+            access_tier="free"
+        )
+
+    def reference_impl(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor, M: int, d: int, window_size: int):
+        assert Q.shape == K.shape == V.shape == output.shape == (M, d)
+
+        scores = (Q @ K.T) / (d ** 0.5)
+
+        idxs = torch.arange(M)
+        mask = (idxs[None, :] - idxs[:, None]).abs() > window_size
+        mask = mask.to(Q.device)
+        scores.masked_fill_(mask, float('-inf'))
+        attn = torch.softmax(scores, dim=1)
+
+        torch.matmul(attn, V, out=output)
+
+    def get_solve_signature(self) -> Dict[str, Any]:
+        return {
+            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+            "K": (ctypes.POINTER(ctypes.c_float), "in"),
+            "V": (ctypes.POINTER(ctypes.c_float), "in"),
+            "output": (ctypes.POINTER(ctypes.c_float), "out"),
+            "M": (ctypes.c_int, "in"),
+            "d": (ctypes.c_int, "in"),
+            "window_size": (ctypes.c_int, "in"),
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        dtype = torch.float32
+        Q = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device="cuda", dtype=dtype)
+        K = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device="cuda", dtype=dtype)
+        V = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], device="cuda", dtype=dtype)
+        output = torch.empty(2, 4, device="cuda", dtype=dtype)
+        return {"Q": Q, "K": K, "V": V, "output": output, "M": 2, "d": 4, "window_size": 1}
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        dtype = torch.float32
+        tests = []
+
+        # basic_example
+        tests.append({
+            "Q": torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device="cuda", dtype=dtype),
+            "K": torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device="cuda", dtype=dtype),
+            "V": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], device="cuda", dtype=dtype),
+            "output": torch.empty(2, 4, device="cuda", dtype=dtype),
+            "M": 2, "d": 4, "window_size": 1
+        })
+
+        # basic_example
+        tests.append({
+            "Q": torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device="cuda", dtype=dtype),
+            "K": torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device="cuda", dtype=dtype),
+            "V": torch.tensor([[1.0, 2.0, 3.0], [5.0, 6.0, 7.0]], device="cuda", dtype=dtype),
+            "output": torch.empty(2, 3, device="cuda", dtype=dtype),
+            "M": 2, "d": 3, "window_size": 1
+        })
+
+        # zero_matrices
+        tests.append({
+            "Q": torch.zeros((3, 5), device="cuda", dtype=dtype),
+            "K": torch.zeros((3, 5), device="cuda", dtype=dtype),
+            "V": torch.zeros((3, 5), device="cuda", dtype=dtype),
+            "output": torch.empty(3, 5, device="cuda", dtype=dtype),
+            "M": 3, "d": 5, "window_size": 2
+        })
+
+        # mixed_values
+        tests.append({
+            "Q": torch.tensor([[-1.0, 2.0, -3.0], [4.0, -5.0, 6.0], [-7.0, 8.0, -9.0], [10.0, -11.0, 12.0]], device="cuda", dtype=dtype),
+            "K": torch.tensor([[2.0, -1.0, 3.0], [-4.0, 5.0, -6.0], [7.0, -8.0, 9.0], [-10.0, 11.0, -12.0]], device="cuda", dtype=dtype),
+            "V": torch.tensor([[1.0, 0.5, -0.5], [-1.0, 2.0, 3.0], [4.0, -2.0, 1.0], [0.0, 1.0, -1.0]], device="cuda", dtype=dtype),
"output": torch.empty(4, 3, device="cuda", dtype=dtype), + "M": 4, "d": 3, "window_size" : 2 + }) + + # large_matrices + tests.append({ + "Q": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-0.1, 0.1), + "K": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-0.1, 0.1), + "V": torch.empty((128, 32), device="cuda", dtype=dtype).uniform_(-0.1, 0.1), + "output": torch.empty(128, 32, device="cuda", dtype=dtype), + "M": 128, "d": 32, "window_size" : 8 + }) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + dtype = torch.float32 + M, d, window_size = 5000, 64, 16 + Q = torch.empty((M, d), device="cuda", dtype=dtype).uniform_(-100, 100) + K = torch.empty((M, d), device="cuda", dtype=dtype).uniform_(-100, 100) + V = torch.empty((M, d), device="cuda", dtype=dtype).uniform_(-100, 100) + output = torch.empty(M, d, device="cuda", dtype=dtype) + return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "d": d, "window_size" : window_size} \ No newline at end of file diff --git a/challenges/hard/59_sliding_window_attn/starter/starter.cu b/challenges/hard/59_sliding_window_attn/starter/starter.cu new file mode 100644 index 0000000..1568bd6 --- /dev/null +++ b/challenges/hard/59_sliding_window_attn/starter/starter.cu @@ -0,0 +1,6 @@ +#include + +// Q, K, V, output are device pointers +extern "C" void solve(const float* Q, const float* K, const float* V, float* output, int M, int d, int window_size) { + +} diff --git a/challenges/hard/59_sliding_window_attn/starter/starter.cute.py b/challenges/hard/59_sliding_window_attn/starter/starter.cute.py new file mode 100644 index 0000000..73c256f --- /dev/null +++ b/challenges/hard/59_sliding_window_attn/starter/starter.cute.py @@ -0,0 +1,8 @@ +import cutlass +import cutlass.cute as cute + +# Q, K, V, output are tensors on the GPU +@cute.jit +def solve(Q: cute.Tensor, K: cute.Tensor, V: cute.Tensor, output: cute.Tensor, M: int, d: int, window_size: int): + pass + \ No newline at end of file diff --git a/challenges/hard/59_sliding_window_attn/starter/starter.mojo b/challenges/hard/59_sliding_window_attn/starter/starter.mojo new file mode 100644 index 0000000..87114ae --- /dev/null +++ b/challenges/hard/59_sliding_window_attn/starter/starter.mojo @@ -0,0 +1,10 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from memory import UnsafePointer +from math import ceildiv + +# Q, K, V, output are device pointers (i.e. 
pointers to memory on the GPU) +@export +def solve(Q: UnsafePointer[Float32], K: UnsafePointer[Float32], V: UnsafePointer[Float32], + output: UnsafePointer[Float32], M: Int32, d: Int32, window_size: Int32): + pass \ No newline at end of file diff --git a/challenges/hard/59_sliding_window_attn/starter/starter.pytorch.py b/challenges/hard/59_sliding_window_attn/starter/starter.pytorch.py new file mode 100644 index 0000000..89b4620 --- /dev/null +++ b/challenges/hard/59_sliding_window_attn/starter/starter.pytorch.py @@ -0,0 +1,6 @@ +import torch + +# Q, K, V, output are tensors on the GPU +def solve(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor, + M: int, d: int, window_size: int): + pass \ No newline at end of file diff --git a/challenges/hard/59_sliding_window_attn/starter/starter.triton.py b/challenges/hard/59_sliding_window_attn/starter/starter.triton.py new file mode 100644 index 0000000..a8d05ee --- /dev/null +++ b/challenges/hard/59_sliding_window_attn/starter/starter.triton.py @@ -0,0 +1,7 @@ +import torch +import triton +import triton.language as tl + +# Q, K, V, output are tensors on the GPU +def solve(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, output: torch.Tensor, M: int, d: int, window_size: int): + pass \ No newline at end of file