11# Copyright (c) 2024, Tri Dao, Albert Gu.
2+ # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py
23
4+ # ruff: noqa: E501,SIM102
35"""We want triton==2.1.0 or 2.2.0 for this
46"""
57
68import math
7- import torch
8- import torch .nn .functional as F
99
10+ import torch
1011import triton
1112import triton .language as tl
1213
13- from einops import rearrange , repeat
14-
1514
def init_to_zero(names):
    """Build a hook that zeroes the named arguments in place.

    NOTE(review): presumably intended as a ``pre_hook`` for ``triton.Config``;
    no use of it is visible in this part of the file -- confirm with callers.

    Args:
        names: Iterable of argument names to reset. A single name passed as a
            plain string is treated as a one-element collection.

    Returns:
        A callable taking ``nargs``, a mapping from argument name to value.
        Calling it invokes ``zero_()`` on every named entry that is not
        ``None`` and returns the list of those ``zero_()`` results.

    Raises:
        KeyError: From the returned callable, when ``nargs`` is a dict and a
            requested name is missing from it.
    """
    if isinstance(names, str):
        # Iterating a bare string would look up each character as a name.
        names = (names,)

    def _zero_named_args(nargs):
        # Entries whose value is None are skipped rather than zeroed.
        return [
            nargs[name].zero_()
            for name in names
            if nargs[name] is not None
        ]

    return _zero_named_args
1819
1920
2021@triton .autotune (
2122 configs = [
22- triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 256 , 'BLOCK_SIZE_K' : 64 }, num_stages = 3 , num_warps = 8 ),
23- triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 256 , 'BLOCK_SIZE_K' : 32 }, num_stages = 4 , num_warps = 4 ),
24- triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 128 , 'BLOCK_SIZE_K' : 32 }, num_stages = 4 , num_warps = 4 ),
25- triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 64 , 'BLOCK_SIZE_K' : 32 }, num_stages = 4 , num_warps = 4 ),
26- triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 128 , 'BLOCK_SIZE_K' : 32 }, num_stages = 4 , num_warps = 4 ),
27- triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 32 , 'BLOCK_SIZE_K' : 32 }, num_stages = 4 , num_warps = 4 ),
28- triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 32 , 'BLOCK_SIZE_K' : 32 }, num_stages = 5 , num_warps = 2 ),
29- triton .Config ({'BLOCK_SIZE_M' : 32 , 'BLOCK_SIZE_N' : 64 , 'BLOCK_SIZE_K' : 32 }, num_stages = 5 , num_warps = 2 ),
30- triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 64 , 'BLOCK_SIZE_K' : 32 }, num_stages = 4 , num_warps = 2 ),
23+ triton .Config (
24+ {
25+ 'BLOCK_SIZE_M' : 128 ,
26+ 'BLOCK_SIZE_N' : 256 ,
27+ 'BLOCK_SIZE_K' : 64
28+ },
29+ num_stages = 3 ,
30+ num_warps = 8 ),
31+ triton .Config (
32+ {
33+ 'BLOCK_SIZE_M' : 64 ,
34+ 'BLOCK_SIZE_N' : 256 ,
35+ 'BLOCK_SIZE_K' : 32
36+ },
37+ num_stages = 4 ,
38+ num_warps = 4 ),
39+ triton .Config (
40+ {
41+ 'BLOCK_SIZE_M' : 128 ,
42+ 'BLOCK_SIZE_N' : 128 ,
43+ 'BLOCK_SIZE_K' : 32
44+ },
45+ num_stages = 4 ,
46+ num_warps = 4 ),
47+ triton .Config (
48+ {
49+ 'BLOCK_SIZE_M' : 128 ,
50+ 'BLOCK_SIZE_N' : 64 ,
51+ 'BLOCK_SIZE_K' : 32
52+ },
53+ num_stages = 4 ,
54+ num_warps = 4 ),
55+ triton .Config (
56+ {
57+ 'BLOCK_SIZE_M' : 64 ,
58+ 'BLOCK_SIZE_N' : 128 ,
59+ 'BLOCK_SIZE_K' : 32
60+ },
61+ num_stages = 4 ,
62+ num_warps = 4 ),
63+ triton .Config (
64+ {
65+ 'BLOCK_SIZE_M' : 128 ,
66+ 'BLOCK_SIZE_N' : 32 ,
67+ 'BLOCK_SIZE_K' : 32
68+ },
69+ num_stages = 4 ,
70+ num_warps = 4 ),
71+ triton .Config (
72+ {
73+ 'BLOCK_SIZE_M' : 64 ,
74+ 'BLOCK_SIZE_N' : 32 ,
75+ 'BLOCK_SIZE_K' : 32
76+ },
77+ num_stages = 5 ,
78+ num_warps = 2 ),
79+ triton .Config (
80+ {
81+ 'BLOCK_SIZE_M' : 32 ,
82+ 'BLOCK_SIZE_N' : 64 ,
83+ 'BLOCK_SIZE_K' : 32
84+ },
85+ num_stages = 5 ,
86+ num_warps = 2 ),
87+ triton .Config (
88+ {
89+ 'BLOCK_SIZE_M' : 64 ,
90+ 'BLOCK_SIZE_N' : 64 ,
91+ 'BLOCK_SIZE_K' : 32
92+ },
93+ num_stages = 4 ,
94+ num_warps = 2 ),
3195 ],
3296 key = ['chunk_size' , 'K' , 'IS_CAUSAL' ],
3397)
3498@triton .jit
3599def _bmm_chunk_fwd_kernel (
36100 # Pointers to matrices
37- a_ptr , b_ptr , out_ptr , seq_idx_ptr ,
101+ a_ptr ,
102+ b_ptr ,
103+ out_ptr ,
104+ seq_idx_ptr ,
38105 # Matrix dimensions
39- seqlen , chunk_size , K , ngroups ,
40- stride_a_batch , stride_a_seqlen , stride_a_head , stride_ak ,
41- stride_b_batch , stride_b_seqlen , stride_b_head , stride_bk ,
42- stride_out_batch , stride_out_chunk , stride_out_head , stride_outm , stride_outn ,
43- stride_seq_idx_batch , stride_seq_idx_seqlen ,
106+ seqlen ,
107+ chunk_size ,
108+ K ,
109+ ngroups ,
110+ stride_a_batch ,
111+ stride_a_seqlen ,
112+ stride_a_head ,
113+ stride_ak ,
114+ stride_b_batch ,
115+ stride_b_seqlen ,
116+ stride_b_head ,
117+ stride_bk ,
118+ stride_out_batch ,
119+ stride_out_chunk ,
120+ stride_out_head ,
121+ stride_outm ,
122+ stride_outn ,
123+ stride_seq_idx_batch ,
124+ stride_seq_idx_seqlen ,
44125 # Meta-parameters
45126 IS_CAUSAL : tl .constexpr ,
46127 dot_dtype : tl .constexpr ,
47128 HAS_SEQ_IDX : tl .constexpr ,
48- BLOCK_SIZE_M : tl .constexpr , BLOCK_SIZE_N : tl .constexpr , BLOCK_SIZE_K : tl .constexpr ,
129+ BLOCK_SIZE_M : tl .constexpr ,
130+ BLOCK_SIZE_N : tl .constexpr ,
131+ BLOCK_SIZE_K : tl .constexpr ,
49132):
50133 pid_b = tl .program_id (axis = 1 )
51134 pid_ch = tl .program_id (axis = 2 ).to (tl .int64 )
@@ -65,14 +148,22 @@ def _bmm_chunk_fwd_kernel(
65148 offs_m = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
66149 offs_n = pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
67150 offs_k = tl .arange (0 , BLOCK_SIZE_K )
68- a_ptrs = a_ptr + (offs_m [:, None ] * stride_a_seqlen + offs_k [None , :] * stride_ak )
69- b_ptrs = b_ptr + (offs_k [:, None ] * stride_bk + offs_n [None , :] * stride_b_seqlen )
151+ a_ptrs = a_ptr + (offs_m [:, None ] * stride_a_seqlen +
152+ offs_k [None , :] * stride_ak )
153+ b_ptrs = b_ptr + (offs_k [:, None ] * stride_bk +
154+ offs_n [None , :] * stride_b_seqlen )
70155 chunk_size_limit = min (chunk_size , seqlen - pid_c * chunk_size )
71156
72157 acc = tl .zeros ((BLOCK_SIZE_M , BLOCK_SIZE_N ), dtype = tl .float32 )
73158 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
74- a = tl .load (a_ptrs , mask = (offs_m [:, None ] < chunk_size_limit ) & (offs_k [None , :] < K - k * BLOCK_SIZE_K ), other = 0.0 ).to (dot_dtype )
75- b = tl .load (b_ptrs , mask = (offs_k [:, None ] < K - k * BLOCK_SIZE_K ) & (offs_n [None , :] < chunk_size_limit ), other = 0.0 ).to (dot_dtype )
159+ a = tl .load (a_ptrs ,
160+ mask = (offs_m [:, None ] < chunk_size_limit ) &
161+ (offs_k [None , :] < K - k * BLOCK_SIZE_K ),
162+ other = 0.0 ).to (dot_dtype )
163+ b = tl .load (b_ptrs ,
164+ mask = (offs_k [:, None ] < K - k * BLOCK_SIZE_K ) &
165+ (offs_n [None , :] < chunk_size_limit ),
166+ other = 0.0 ).to (dot_dtype )
76167 acc += tl .dot (a , b )
77168 a_ptrs += BLOCK_SIZE_K * stride_ak
78169 b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -81,16 +172,30 @@ def _bmm_chunk_fwd_kernel(
81172 offs_n = pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
82173 if HAS_SEQ_IDX :
83174 chunk_size_limit = min (chunk_size , seqlen - pid_c * chunk_size )
84- seq_idx_m = tl .load (seq_idx_ptr + offs_m * stride_seq_idx_seqlen , mask = offs_m < chunk_size_limit , other = - 1 )
85- seq_idx_n = tl .load (seq_idx_ptr + offs_n * stride_seq_idx_seqlen , mask = offs_n < chunk_size_limit , other = - 2 )
175+ seq_idx_m = tl .load (seq_idx_ptr + offs_m * stride_seq_idx_seqlen ,
176+ mask = offs_m < chunk_size_limit ,
177+ other = - 1 )
178+ seq_idx_n = tl .load (seq_idx_ptr + offs_n * stride_seq_idx_seqlen ,
179+ mask = offs_n < chunk_size_limit ,
180+ other = - 2 )
86181 acc = tl .where (seq_idx_m [:, None ] == seq_idx_n [None , :], acc , 0.0 )
87182 out = acc .to (out_ptr .dtype .element_ty )
88183
89184 out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
90- out_ptrs = out_ptr + (stride_outm * offs_m [:, None ] + offs_n [None , :] * stride_outn )
91- tl .store (out_ptrs , out , mask = (offs_m [:, None ] < chunk_size ) & (offs_n [None , :] < chunk_size ))
185+ out_ptrs = out_ptr + (stride_outm * offs_m [:, None ] +
186+ offs_n [None , :] * stride_outn )
187+ tl .store (out_ptrs ,
188+ out ,
189+ mask = (offs_m [:, None ] < chunk_size ) &
190+ (offs_n [None , :] < chunk_size ))
191+
92192
93- def _bmm_chunk_fwd (a , b , chunk_size , seq_idx = None , causal = False , output_dtype = None ):
193+ def _bmm_chunk_fwd (a ,
194+ b ,
195+ chunk_size ,
196+ seq_idx = None ,
197+ causal = False ,
198+ output_dtype = None ):
94199 """
95200 Argument:
96201 a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
@@ -117,20 +222,44 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No
117222 nchunks = math .ceil (seqlen / chunk_size )
118223 # Allocates output.
119224 out_dtype = a .dtype if output_dtype is None else output_dtype
120- out = torch .empty ((batch , nchunks , chunk_size , chunk_size ) if not has_groups else (batch , nchunks , ngroups , chunk_size , chunk_size ),
121- device = a .device , dtype = out_dtype )
122- dot_dtype = (tl .bfloat16 if a .dtype == torch .bfloat16 or b .dtype == torch .bfloat16 else
123- (tl .float16 if a .dtype == torch .float16 or b .dtype == torch .float16 else tl .float32 ))
124- grid = lambda META : (triton .cdiv (chunk_size , META ['BLOCK_SIZE_M' ]) * triton .cdiv (chunk_size , META ['BLOCK_SIZE_N' ]),
125- batch , nchunks if not has_groups else nchunks * ngroups )
225+ out = torch .empty (
226+ (batch , nchunks , chunk_size , chunk_size ) if not has_groups else
227+ (batch , nchunks , ngroups , chunk_size , chunk_size ),
228+ device = a .device ,
229+ dtype = out_dtype )
230+ dot_dtype = (tl .bfloat16
231+ if a .dtype == torch .bfloat16 or b .dtype == torch .bfloat16 else
232+ (tl .float16 if a .dtype == torch .float16
233+ or b .dtype == torch .float16 else tl .float32 ))
234+ grid = lambda META : (triton .cdiv (
235+ chunk_size , META ['BLOCK_SIZE_M' ]) * triton .cdiv (
236+ chunk_size , META ['BLOCK_SIZE_N' ]), batch , nchunks
237+ if not has_groups else nchunks * ngroups )
126238 with torch .cuda .device (a .device .index ):
127239 _bmm_chunk_fwd_kernel [grid ](
128- a , b , out , seq_idx ,
129- seqlen , chunk_size , k , ngroups if has_groups else 1 ,
130- a .stride (0 ), a .stride (1 ), 0 if not has_groups else a .stride (2 ), a .stride (- 1 ),
131- b .stride (0 ), b .stride (1 ), 0 if not has_groups else b .stride (2 ), b .stride (- 1 ),
132- out .stride (0 ), out .stride (1 ), 0 if not has_groups else out .stride (2 ), out .stride (- 2 ), out .stride (- 1 ),
133- * ((seq_idx .stride (0 ), seq_idx .stride (1 )) if seq_idx is not None else (0 , 0 )),
240+ a ,
241+ b ,
242+ out ,
243+ seq_idx ,
244+ seqlen ,
245+ chunk_size ,
246+ k ,
247+ ngroups if has_groups else 1 ,
248+ a .stride (0 ),
249+ a .stride (1 ),
250+ 0 if not has_groups else a .stride (2 ),
251+ a .stride (- 1 ),
252+ b .stride (0 ),
253+ b .stride (1 ),
254+ 0 if not has_groups else b .stride (2 ),
255+ b .stride (- 1 ),
256+ out .stride (0 ),
257+ out .stride (1 ),
258+ 0 if not has_groups else out .stride (2 ),
259+ out .stride (- 2 ),
260+ out .stride (- 1 ),
261+ * ((seq_idx .stride (0 ),
262+ seq_idx .stride (1 )) if seq_idx is not None else (0 , 0 )),
134263 causal ,
135264 dot_dtype ,
136265 HAS_SEQ_IDX = seq_idx is not None ,
0 commit comments