
Commit b2dc5ca

fmt + lint
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
1 parent 742ae79 commit b2dc5ca

File tree: 9 files changed, +1313 −429 lines

tests/models/decoder_only/language/test_bamba.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,6 @@
 """Compare the outputs of HF and vLLM when using greedy sampling for Mamba.
 
-This actually is really indentical to test_mamba, so maybe we can reuse
+This actually is really identical to test_mamba, so maybe we can reuse
 
 Run `pytest tests/models/decoder_only/language/test_bamba.py`.
 """
@@ -97,6 +97,7 @@ def test_batching(
         name_1="batched_vllm",
     )
 
+
 @pytest.mark.skip("bamba does not support chunked prefill yet")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
@@ -122,6 +123,7 @@ def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
 
+
 @pytest.mark.skip("bamba does not support chunked prefill yet")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
@@ -205,7 +207,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
    # tensor dimensions aren't compatible
-    while len(example_prompts) == VllmConfig.get_graph_batch_size(len(example_prompts)):
+    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+            len(example_prompts)):
         example_prompts.append(example_prompts[0])
 
     try:
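
The reflowed while-loop above only changes line wrapping, not behaviour: the test keeps appending copies of the first prompt until the batch size no longer coincides with a CUDA-graph-captured size, so the cache-padding path is exercised. A rough sketch of that intent, using a hypothetical graph_batch_size stand-in rather than vLLM's real VllmConfig.get_graph_batch_size:

    # Sketch only: `graph_batch_size` is a made-up stand-in that rounds up to a
    # multiple of 8; vLLM's VllmConfig.get_graph_batch_size may behave differently.
    def graph_batch_size(n: int, step: int = 8) -> int:
        return ((n + step - 1) // step) * step

    example_prompts = ["dummy prompt"] * 8
    while len(example_prompts) == graph_batch_size(len(example_prompts)):
        example_prompts.append(example_prompts[0])

    # With this stand-in we stop at 9 prompts while the captured size is 16,
    # so the runner has to pad the mamba cache up to the CG batch size.
    print(len(example_prompts), graph_batch_size(len(example_prompts)))  # 9 16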

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 127 additions & 106 deletions
Large diffs are not rendered by default.
vllm/model_executor/layers/mamba/ops/softplus.py

Lines changed: 8 additions & 2 deletions
@@ -1,15 +1,21 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/softplus.py
+
+# ruff: noqa: E501
+
 import triton
 import triton.language as tl
 from packaging import version
 
 TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
 
-
 if TRITON3:
+
     @triton.jit
     def softplus(dt):
         return tl.math.log(tl.math.exp(dt) + 1)
 else:
+
     @triton.jit
     def softplus(dt):
-        return tl.math.log1p(tl.exp(dt))
+        return tl.math.log1p(tl.exp(dt))
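
Both branches of the diff above implement the same mathematical function, softplus(x) = log(1 + e^x): Triton 3 spells it tl.math.log(tl.math.exp(dt) + 1), while older Triton uses tl.math.log1p(tl.exp(dt)). A quick equivalence check, sketched in plain Python outside of any Triton kernel:

    import math

    def softplus_log(x: float) -> float:
        # Triton 3 branch: log(exp(x) + 1)
        return math.log(math.exp(x) + 1)

    def softplus_log1p(x: float) -> float:
        # pre-Triton-3 branch: log1p(exp(x))
        return math.log1p(math.exp(x))

    for x in (-5.0, -0.5, 0.0, 0.5, 5.0):
        assert math.isclose(softplus_log(x), softplus_log1p(x), rel_tol=1e-12)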

vllm/model_executor/layers/mamba/ops/ssd_bmm.py

Lines changed: 171 additions & 42 deletions
@@ -1,51 +1,134 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/ssd_bmm.py
 
+# ruff: noqa: E501,SIM102
 """We want triton==2.1.0 or 2.2.0 for this
 """
 
 import math
-import torch
-import torch.nn.functional as F
 
+import torch
 import triton
 import triton.language as tl
 
-from einops import rearrange, repeat
-
 
 def init_to_zero(names):
-    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
+    return lambda nargs: [
+        nargs[name].zero_() for name in names if nargs[name] is not None
+    ]
 
 
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 64
+            },
+            num_stages=3,
+            num_warps=8),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 256,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=4),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 32,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=5,
+            num_warps=2),
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32
+            },
+            num_stages=4,
+            num_warps=2),
     ],
     key=['chunk_size', 'K', 'IS_CAUSAL'],
 )
 @triton.jit
 def _bmm_chunk_fwd_kernel(
     # Pointers to matrices
-    a_ptr, b_ptr, out_ptr, seq_idx_ptr,
+    a_ptr,
+    b_ptr,
+    out_ptr,
+    seq_idx_ptr,
     # Matrix dimensions
-    seqlen, chunk_size, K, ngroups,
-    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
-    stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
-    stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
-    stride_seq_idx_batch, stride_seq_idx_seqlen,
+    seqlen,
+    chunk_size,
+    K,
+    ngroups,
+    stride_a_batch,
+    stride_a_seqlen,
+    stride_a_head,
+    stride_ak,
+    stride_b_batch,
+    stride_b_seqlen,
+    stride_b_head,
+    stride_bk,
+    stride_out_batch,
+    stride_out_chunk,
+    stride_out_head,
+    stride_outm,
+    stride_outn,
+    stride_seq_idx_batch,
+    stride_seq_idx_seqlen,
     # Meta-parameters
     IS_CAUSAL: tl.constexpr,
     dot_dtype: tl.constexpr,
     HAS_SEQ_IDX: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
     pid_b = tl.program_id(axis=1)
     pid_ch = tl.program_id(axis=2).to(tl.int64)
@@ -65,14 +148,22 @@ def _bmm_chunk_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
+    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
+                      offs_k[None, :] * stride_ak)
+    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
+                      offs_n[None, :] * stride_b_seqlen)
     chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
 
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
-        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
+        a = tl.load(a_ptrs,
+                    mask=(offs_m[:, None] < chunk_size_limit) &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0).to(dot_dtype)
+        b = tl.load(b_ptrs,
+                    mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
+                    (offs_n[None, :] < chunk_size_limit),
+                    other=0.0).to(dot_dtype)
         acc += tl.dot(a, b)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -81,16 +172,30 @@ def _bmm_chunk_fwd_kernel(
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     if HAS_SEQ_IDX:
         chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
-        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
-        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
+        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
+                            mask=offs_m < chunk_size_limit,
+                            other=-1)
+        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
+                            mask=offs_n < chunk_size_limit,
+                            other=-2)
         acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
     out = acc.to(out_ptr.dtype.element_ty)
 
     out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
-    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
-    tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
+    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
+                          offs_n[None, :] * stride_outn)
+    tl.store(out_ptrs,
+             out,
+             mask=(offs_m[:, None] < chunk_size) &
+             (offs_n[None, :] < chunk_size))
+
 
-def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
+def _bmm_chunk_fwd(a,
+                   b,
+                   chunk_size,
+                   seq_idx=None,
+                   causal=False,
+                   output_dtype=None):
     """
     Argument:
         a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
@@ -117,20 +222,44 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=No
     nchunks = math.ceil(seqlen / chunk_size)
     # Allocates output.
     out_dtype = a.dtype if output_dtype is None else output_dtype
-    out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
-                      device=a.device, dtype=out_dtype)
-    dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
-                 (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
-    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
-                         batch, nchunks if not has_groups else nchunks * ngroups)
+    out = torch.empty(
+        (batch, nchunks, chunk_size, chunk_size) if not has_groups else
+        (batch, nchunks, ngroups, chunk_size, chunk_size),
+        device=a.device,
+        dtype=out_dtype)
+    dot_dtype = (tl.bfloat16
+                 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
+                 (tl.float16 if a.dtype == torch.float16
+                  or b.dtype == torch.float16 else tl.float32))
+    grid = lambda META: (triton.cdiv(
+        chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
+            chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
+                         if not has_groups else nchunks * ngroups)
     with torch.cuda.device(a.device.index):
         _bmm_chunk_fwd_kernel[grid](
-            a, b, out, seq_idx,
-            seqlen, chunk_size, k, ngroups if has_groups else 1,
-            a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
-            b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
-            out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
-            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
+            a,
+            b,
+            out,
+            seq_idx,
+            seqlen,
+            chunk_size,
+            k,
+            ngroups if has_groups else 1,
+            a.stride(0),
+            a.stride(1),
+            0 if not has_groups else a.stride(2),
+            a.stride(-1),
+            b.stride(0),
+            b.stride(1),
+            0 if not has_groups else b.stride(2),
+            b.stride(-1),
+            out.stride(0),
+            out.stride(1),
+            0 if not has_groups else out.stride(2),
+            out.stride(-2),
+            out.stride(-1),
+            *((seq_idx.stride(0),
+               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
             causal,
            dot_dtype,
             HAS_SEQ_IDX=seq_idx is not None,
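
For orientation, the reformatted _bmm_chunk_fwd wrapper above performs a chunked batched matmul: per its docstring, a is (batch, seqlen, k) or (batch, seqlen, ngroups, k), and the torch.empty call allocates the result as (batch, nchunks, chunk_size, chunk_size), with an extra ngroups dimension in the grouped case, where nchunks = ceil(seqlen / chunk_size). A minimal usage sketch, assuming a CUDA device, that b shares a's layout, and that the function returns the allocated out tensor (the return statement falls outside this diff):

    import torch

    from vllm.model_executor.layers.mamba.ops.ssd_bmm import _bmm_chunk_fwd

    # Hypothetical shapes, chosen only for illustration.
    batch, seqlen, ngroups, k = 2, 512, 4, 64
    chunk_size = 128

    a = torch.randn(batch, seqlen, ngroups, k, device="cuda")
    b = torch.randn(batch, seqlen, ngroups, k, device="cuda")

    out = _bmm_chunk_fwd(a, b, chunk_size, causal=True)
    # Expected: (batch, nchunks, ngroups, chunk_size, chunk_size)
    #         = (2, 512 // 128, 4, 128, 128)
    print(out.shape)  # torch.Size([2, 4, 4, 128, 128])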
