
Commit 95e81f6

add kernelagent generated triton kernels
1 parent 7a7a841 commit 95e81f6

File tree: 249 files changed, +28659 −0 lines changed

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.jit
def hinge_loss_kernel(
    predictions_ptr,
    targets_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-program kernel: BLOCK_SIZE covers all elements, so one block
    # computes the full mean hinge loss.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    p = tl.load(predictions_ptr + offsets, mask=mask, other=0.0)
    t = tl.load(targets_ptr + offsets, mask=mask, other=0.0)

    element = 1.0 - p * t
    # Zero out both negative hinge terms and out-of-range lanes so that
    # padding past n_elements never contributes to the sum.
    clamped = tl.where((element > 0) & mask, element, 0.0)

    total = tl.sum(clamped, axis=0)
    if pid == 0:
        mean_val = total / n_elements
        tl.store(output_ptr, mean_val)

class ModelNew(nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, predictions, targets):
        total_elements = predictions.numel()
        if total_elements == 0:
            return torch.tensor(0.0, device=predictions.device)

        predictions_flat = predictions.view(-1)
        targets_flat = targets.view(-1)
        output = torch.empty(1, device=predictions.device)

        # One program handles the whole input; BLOCK_SIZE must be a power of
        # two for tl.arange, hence next_power_of_2.
        grid = (1,)
        BLOCK_SIZE = triton.next_power_of_2(total_elements)
        hinge_loss_kernel[grid](
            predictions_flat,
            targets_flat,
            output,
            total_elements,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return output.squeeze(0)

batch_size = 128
input_shape = (1,)
dim = 1

def get_inputs():
    return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]

def get_init_inputs():
    return []
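A quick way to sanity-check this kernel is to compare it against the equivalent eager PyTorch expression, mean(clamp(1 - p * t, 0)). The snippet below is only a sketch: it assumes a CUDA device is available, reuses get_inputs from this file, and the tolerance is illustrative rather than part of the original code.

# Minimal correctness check against an eager PyTorch reference (assumes CUDA).
if __name__ == "__main__":
    preds, targs = [x.cuda() for x in get_inputs()]
    out = ModelNew()(preds, targs)
    ref = torch.mean(torch.clamp(1.0 - preds * targs, min=0.0))
    print("triton:", out.item(), "reference:", ref.item())
    assert torch.allclose(out, ref, atol=1e-5)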
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_L': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_L': 128, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_L': 64, 'BLOCK_K': 64}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_L': 128, 'BLOCK_K': 64}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_L': 64, 'BLOCK_K': 128}, num_warps=4, num_stages=3),
    ],
    key=['M', 'K', 'L'],
)
@triton.jit
def _matmul_kernel(
    A_ptr, B_ptr, C_ptr,
    N, M, K, L,
    stride_An, stride_Am, stride_Ak,
    stride_Bk, stride_Bl,
    stride_Cn, stride_Cm, stride_Cl,
    BLOCK_M: tl.constexpr, BLOCK_L: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_l = tl.program_id(2)

    if pid_n >= N:
        return

    # Create block offsets with proper masking
    m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
    k_offsets = tl.arange(0, BLOCK_K)

    # Initialize accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_L), dtype=tl.float32)

    # Compute pointer bases
    a_base = A_ptr + pid_n * stride_An
    b_base = B_ptr
    c_base = C_ptr + pid_n * stride_Cn

    # Blocked matrix multiplication
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Compute current K block
        k_start = k * BLOCK_K

        # Load A block with coalesced access
        a_ptrs = a_base + m_offsets[:, None] * stride_Am + (k_start + k_offsets[None, :]) * stride_Ak
        a_mask = (m_offsets[:, None] < M) & ((k_start + k_offsets[None, :]) < K)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)

        # Load B block with coalesced access
        b_ptrs = b_base + (k_start + k_offsets[:, None]) * stride_Bk + l_offsets[None, :] * stride_Bl
        b_mask = ((k_start + k_offsets[:, None]) < K) & (l_offsets[None, :] < L)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        # Accumulate matrix product with full FP32 precision
        acc += tl.dot(a, b, allow_tf32=False)

    # Store result with masking
    c_ptrs = c_base + m_offsets[:, None] * stride_Cm + l_offsets[None, :] * stride_Cl
    c_mask = (m_offsets[:, None] < M) & (l_offsets[None, :] < L)
    tl.store(c_ptrs, acc, mask=c_mask)

class ModelNew(nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, A, B):
        N, M, K = A.shape
        L = B.shape[1]
        A = A.contiguous()
        B = B.contiguous()
        C = torch.empty((N, M, L), device=A.device, dtype=A.dtype)

        # Dynamic grid using autotuner meta-parameters
        grid = lambda meta: (
            N,
            triton.cdiv(M, meta['BLOCK_M']),
            triton.cdiv(L, meta['BLOCK_L'])
        )

        # Launch kernel without overriding autotuned parameters
        _matmul_kernel[grid](
            A, B, C,
            N, M, K, L,
            A.stride(0), A.stride(1), A.stride(2),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1), C.stride(2),
        )
        return C

N = 16
M = 1024
K = 2048
L = 768

def get_inputs():
    A = torch.randn(N, M, K, device='cuda')
    B = torch.randn(K, L, device='cuda')
    return [A, B]

def get_init_inputs():
    return []
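For reference, the same computation in eager PyTorch is a single batched matmul: torch.matmul(A, B) with A of shape (N, M, K) and B of shape (K, L) yields (N, M, L). A sanity check along those lines might look like the sketch below; it assumes a CUDA device and the tolerances are illustrative only, since accumulation order differs between the Triton kernel and cuBLAS.

# Compare the Triton kernel against torch.matmul (assumes CUDA is available;
# tolerances are illustrative, not part of the original file).
if __name__ == "__main__":
    A, B = get_inputs()          # already created on 'cuda' in this file
    C = ModelNew()(A, B)
    C_ref = torch.matmul(A, B)   # (N, M, K) @ (K, L) -> (N, M, L)
    print("max abs error:", (C - C_ref).abs().max().item())
    assert torch.allclose(C, C_ref, rtol=1e-2, atol=1e-2)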
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _triton_matmul(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        a_mask = (rm[:, None] < M) & (rk[None, :] < K)
        b_mask = (rk[:, None] < K) & (rn[None, :] < N)

        a = tl.load(a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak,
                    mask=a_mask, other=0.0)
        b = tl.load(b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn,
                    mask=b_mask, other=0.0)
        acc += tl.dot(a, b, allow_tf32=False, out_dtype=tl.float32)

    c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    c_mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)

class ModelNew(nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, A, B):
        b, i, j, l = A.shape
        k = B.shape[1]

        # Ensure contiguous memory layout
        A_flat = A.reshape(-1, l).contiguous()
        B = B.contiguous()

        M, K = A_flat.shape
        N = k

        C_flat = torch.empty((M, N), device=A.device, dtype=A.dtype)

        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']),
            triton.cdiv(N, META['BLOCK_SIZE_N']),
        )

        _triton_matmul[grid](
            A_flat, B, C_flat,
            M, N, K,
            A_flat.stride(0), A_flat.stride(1),
            B.stride(0), B.stride(1),
            C_flat.stride(0), C_flat.stride(1),
        )

        return C_flat.reshape(b, i, j, k)

# Test code
b = 16
i = 256
j = 512
l = 256
k = 768

def get_inputs():
    A = torch.randn(b, i, j, l)
    B = torch.randn(l, k)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed
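The reference computation here is the contraction 'bijl,lk->bijk', i.e. a matmul over the last dimension of A. A minimal check is sketched below; it assumes a CUDA device and moves the inputs there explicitly, since get_inputs in this file creates CPU tensors, and the tolerances are illustrative.

# Sanity check against torch.einsum (assumes a CUDA device; inputs from
# get_inputs above are on the CPU, so they are moved here).
if __name__ == "__main__":
    A, B = [x.cuda() for x in get_inputs()]
    C = ModelNew()(A, B)
    C_ref = torch.einsum('bijl,lk->bijk', A, B)
    print("max abs error:", (C - C_ref).abs().max().item())
    assert torch.allclose(C, C_ref, rtol=1e-2, atol=1e-2)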
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import triton
import triton.language as tl

@triton.jit
def diag_matmul_kernel(
    A_ptr,
    B_ptr,
    C_ptr,
    N, M,
    stride_B0, stride_B1,
    stride_C0, stride_C1,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= N:
        return

    a_val = tl.load(A_ptr + pid)
    row_start_B = pid * stride_B0
    row_start_C = pid * stride_C0

    for col_block in range(0, tl.cdiv(M, BLOCK_SIZE)):
        col_offset = col_block * BLOCK_SIZE
        col_indices = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_indices < M

        b_vals = tl.load(
            B_ptr + row_start_B + col_indices * stride_B1,
            mask=mask,
            other=0.0
        )
        c_vals = a_val * b_vals
        tl.store(
            C_ptr + row_start_C + col_indices * stride_C1,
            c_vals,
            mask=mask
        )

class ModelNew(nn.Module):
    def __init__(self):
        super(ModelNew, self).__init__()

    def forward(self, A, B):
        N, M = B.shape
        C = torch.empty_like(B)

        if B.numel() == 0:
            return C

        BLOCK_SIZE = 1024
        grid = (N,)
        diag_matmul_kernel[grid](
            A, B, C,
            N, M,
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            BLOCK_SIZE=BLOCK_SIZE
        )
        return C

M = 4096
N = 4096

def get_inputs():
    A = torch.randn(N)
    B = torch.randn(N, M)
    return [A, B]

def get_init_inputs():
    return []  # No special initialization inputs needed
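This kernel computes C = diag(A) @ B by scaling each row of B with the matching entry of A, so the eager equivalent is simply A.unsqueeze(1) * B (or torch.diag(A) @ B, which materializes the N x N diagonal matrix). A small check under the same assumptions as above, a CUDA device and CPU inputs from get_inputs, might be:

# Sanity check: row-wise scaling should match the broadcasted reference
# for diag(A) @ B (assumes a CUDA device).
if __name__ == "__main__":
    A, B = [x.cuda() for x in get_inputs()]
    C = ModelNew()(A, B)
    C_ref = A.unsqueeze(1) * B   # broadcasted equivalent of torch.diag(A) @ B
    assert torch.allclose(C, C_ref)
    print("diag matmul kernel matches broadcasted reference")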
