Commit 26da2c6

bringlein, cyang49, and yewentao256 committed
[V1][Kernel] Add triton implementation for reshape_and_cache_flash (#24503)
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 0081c69 commit 26da2c6

4 files changed: +276 -20 lines changed


benchmarks/kernels/benchmark_reshape_and_cache_flash.py

Lines changed: 67 additions & 11 deletions
@@ -9,6 +9,9 @@
 from tabulate import tabulate

 from vllm import _custom_ops as ops
+from vllm.attention.ops.triton_reshape_and_cache_flash import (
+    triton_reshape_and_cache_flash,
+)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import (
@@ -31,13 +34,23 @@ def run_benchmark(
     kv_cache_dtype: str,
     kv_cache_layout: str,
     num_iters: int,
+    implementation: str,
+    benchmark_mode: str,
     device: str = "cuda",
 ) -> float:
     """Return latency (seconds) for given num_tokens."""

     if kv_cache_dtype == "fp8" and head_size % 16:
         raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

+    if implementation not in ("cuda", "triton"):
+        raise ValueError(
+            f"Unsupported implementation: {implementation}. "
+            "Only 'cuda' and 'triton' are supported."
+        )
+    if implementation == "triton" and kv_cache_layout == "HND":
+        return float("nan")  # Triton does not support HND layout yet.
+
     current_platform.seed_everything(42)
     torch.set_default_device(device)

@@ -65,27 +78,49 @@ def run_benchmark(
         cache_layout=kv_cache_layout,
     )
     key_cache, value_cache = key_caches[0], value_caches[0]
+    # to free unused memory
+    del key_caches, value_caches

     # compute per-kernel scaling factors for fp8 conversion (if used).
     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)

+    if implementation == "cuda":
+        function_under_test = lambda: ops.reshape_and_cache_flash(
+            key,  # noqa: F821
+            value,  # noqa: F821
+            key_cache,  # noqa: F821
+            value_cache,  # noqa: F821
+            slot_mapping,  # noqa: F821
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+    else:
+        function_under_test = lambda: triton_reshape_and_cache_flash(
+            key,  # noqa: F821
+            value,  # noqa: F821
+            key_cache,  # noqa: F821
+            value_cache,  # noqa: F821
+            slot_mapping,  # noqa: F821
+            kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+    if benchmark_mode == "cudagraph":
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            function_under_test()
+        torch.cuda.synchronize()
+        function_under_test = lambda: g.replay()
+
     def run_cuda_benchmark(n_iters: int) -> float:
         nonlocal key, value, key_cache, value_cache, slot_mapping
         torch.cuda.synchronize()
         start = time.perf_counter()
         for _ in range(n_iters):
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                slot_mapping,
-                kv_cache_dtype,
-                k_scale,
-                v_scale,
-            )
-        torch.cuda.synchronize()
+            function_under_test()
+        torch.cuda.synchronize()
         end = time.perf_counter()
         return (end - start) / n_iters

@@ -116,10 +151,16 @@ def main(args):
                 kv_cache_dtype=args.kv_cache_dtype,
                 kv_cache_layout=layout,
                 num_iters=args.iters,
+                implementation=args.implementation,
+                benchmark_mode=args.mode,
                 device="cuda",
             )
             rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

+    print(
+        f"Benchmark results for implementation {args.implementation}"
+        f" (measuring with {args.mode}):"
+    )
     print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


@@ -151,6 +192,21 @@ def main(args):
     )

     parser.add_argument("--iters", type=int, default=100)
+
+    parser.add_argument(
+        "--implementation",
+        type=str,
+        choices=["cuda", "triton"],
+        default="cuda",
+    )
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["cudagraph", "no_graph"],
+        default="cudagraph",
+    )
+
     args = parser.parse_args()

     main(args)
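
The two new flags can be exercised directly from this script; a hypothetical invocation on a CUDA machine (flag names and defaults taken from the parser additions above):

    python benchmarks/kernels/benchmark_reshape_and_cache_flash.py \
        --implementation triton --mode cudagraph

With --mode no_graph each iteration launches the kernel eagerly, while the default cudagraph mode captures the launch once and replays the graph, keeping per-launch CPU overhead out of the measured latency. HND-layout rows report NaN for the Triton implementation, since it currently supports only the NHD layout.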

tests/kernels/attention/test_cache.py

Lines changed: 21 additions & 6 deletions
@@ -39,6 +39,8 @@
 # We assume fp8 is always enabled for testing.
 KV_CACHE_DTYPE = ["auto", "fp8"]

+RESHAPE_FLASH_IMPLEMENTATIONS = ["cuda", "triton"]
+

 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
 @pytest.mark.parametrize("num_layers", NUM_LAYERS)
@@ -223,6 +225,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
+@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -236,9 +239,13 @@ def test_reshape_and_cache_flash(
     device: str,
     kv_cache_dtype: str,
     kv_cache_layout: str,
+    implementation: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
+    assert implementation in ["cuda", "triton"]
+    if implementation == "triton" and kv_cache_layout == "HND":
+        pytest.skip("Triton implementation only supports NHD layout.")

     # fp8 conversion requires contiguous memory buffer. Reduce the number of
     # blocks and tokens to consume less memory.
@@ -298,12 +305,20 @@ def permute_and_compact(x):
     cloned_key_cache = key_cache_compact.clone()
     cloned_value_cache = value_cache_compact.clone()
     # Call the reshape_and_cache kernel.
-    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
-            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
-             k_scale, v_scale),
-            cond=(head_size == HEAD_SIZES[0]))
-    ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
-                                slot_mapping, kv_cache_dtype, k_scale, v_scale)
+    if implementation == "cuda":
+        opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
+                (key, value, key_cache, value_cache, slot_mapping,
+                 kv_cache_dtype, k_scale, v_scale),
+                cond=(head_size == HEAD_SIZES[0]))
+        ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                    slot_mapping, kv_cache_dtype, k_scale,
+                                    v_scale)
+    elif implementation == "triton":
+        from vllm.attention.ops.triton_reshape_and_cache_flash import (
+            triton_reshape_and_cache_flash)
+        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
+                                       slot_mapping, kv_cache_dtype, k_scale,
+                                       v_scale)
     key_cache_compact = permute_and_compact(key_cache)
     value_cache_compact = permute_and_compact(value_cache)

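
For illustration, the parity this new parametrization checks can be sketched outside pytest roughly as follows ("auto" dtype and NHD layout assumed; the helper name check_parity and its shapes are hypothetical, only the two cache-write entry points come from this commit):

    import torch
    from vllm import _custom_ops as ops
    from vllm.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash,
    )

    def check_parity(key, value, key_cache, value_cache, slot_mapping):
        # Unit scales: with kv_cache_dtype="auto" no fp8 scaling is applied.
        k_scale = torch.ones((), dtype=torch.float32, device=key.device)
        v_scale = torch.ones((), dtype=torch.float32, device=key.device)
        # Run the CUDA op on cloned caches as the reference.
        key_ref, value_ref = key_cache.clone(), value_cache.clone()
        ops.reshape_and_cache_flash(key, value, key_ref, value_ref,
                                    slot_mapping, "auto", k_scale, v_scale)
        # Run the Triton implementation in place on the original caches.
        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                       slot_mapping, "auto", k_scale, v_scale)
        torch.testing.assert_close(key_cache, key_ref)
        torch.testing.assert_close(value_cache, value_ref)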

vllm/attention/ops/triton_reshape_and_cache_flash.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.platforms import current_platform
+
+
+@triton.jit
+def reshape_and_cache_kernel_flash(
+    key_ptr,  # [num_tokens, num_heads, head_size]
+    value_ptr,  # [num_tokens, num_heads, head_size]
+    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
+    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
+    slot_mapping_ptr,  # [num_tokens]
+    k_scale,  # float32
+    v_scale,  # float32
+    # strides
+    key_stride: tl.int64,
+    value_stride: tl.int64,
+    block_stride: tl.int64,
+    page_stride: tl.int64,
+    num_heads: tl.constexpr,
+    head_size: tl.constexpr,
+    block_size: tl.constexpr,
+    # FP8 flags
+    FP8_KV_CACHE: tl.constexpr,
+    # tune parameters
+    TILE_SIZE: tl.constexpr,
+):
+
+    token_idx = tl.program_id(axis=0)
+    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
+    if slot_idx < 0:
+        # Padding token that should be ignored.
+        return
+
+    tile_i = tl.program_id(axis=1)
+    tile_offs = tl.arange(0, TILE_SIZE)
+    tile_pos = tile_i * TILE_SIZE + tile_offs
+
+    block_idx = slot_idx // block_size
+    block_offset = slot_idx % block_size
+
+    src_key_idx = token_idx * key_stride
+    src_value_idx = token_idx * value_stride
+
+    tgt_idx = block_idx * block_stride + block_offset * page_stride
+
+    # [TILE_SIZE]
+    key_load = tl.load(key_ptr + src_key_idx + tile_pos,
+                       mask=tile_pos < (num_heads * head_size))
+    if FP8_KV_CACHE:
+        if key_load.dtype.is_fp8():
+            key_tile = key_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the key_cache_ptr.dtype.element_ty
+            key_tile = key_load / tl.load(k_scale)
+    else:
+        key_tile = key_load
+
+    # [TILE_SIZE]
+    value_load = tl.load(value_ptr + src_value_idx + tile_pos,
+                         mask=tile_pos < (num_heads * head_size))
+    if FP8_KV_CACHE:
+        if value_load.dtype.is_fp8():
+            value_tile = value_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the value_cache_ptr.dtype.element_ty
+            value_tile = value_load / tl.load(v_scale)
+    else:
+        value_tile = value_load
+
+    tl.store(
+        key_cache_ptr + tgt_idx + tile_pos,
+        key_tile,
+        mask=tile_pos < (num_heads * head_size),
+    )
+    tl.store(
+        value_cache_ptr + tgt_idx + tile_pos,
+        value_tile,
+        mask=tile_pos < (num_heads * head_size),
+    )
+    return
+
+
+def triton_reshape_and_cache_flash(
+    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    # [num_blocks, block_size, num_heads, head_size]
+    key_cache: torch.Tensor,
+    # [num_blocks, block_size, num_heads, head_size]
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,  # [num_tokens]
+    kv_cache_dtype: str,  # "auto", "fp8"
+    k_scale: torch.Tensor,  # float32
+    v_scale: torch.Tensor,  # float32
+):
+    num_tokens = key.shape[0]
+    num_heads = key.shape[1]
+    head_size = key.shape[2]
+    block_size = key_cache.shape[1]
+    n = num_heads * head_size
+
+    key_stride = key.stride()[0]
+    value_stride = value.stride()[0]
+    block_stride = key_cache.stride()[0]
+    page_stride = key_cache.stride()[1]
+
+    head_stride = key_cache.stride()[2]
+    assert head_stride == head_size, "only contiguous heads are supported"
+
+    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
+        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
+    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
+        kv_cache_dtype.startswith("fp8") else key_cache.dtype
+
+    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
+            "fp8"):
+        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
+        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
+        key_cache = key_cache.view(kv_cache_torch_dtype)
+        value_cache = value_cache.view(kv_cache_torch_dtype)
+        assert kv_cache_torch_dtype != torch.uint8, "explicit fp8 cast and " \
+            "store to uint8 is not supported by triton reshape_and_cache_flash"
+
+    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
+    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
+        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
+        torch.float8_e4m3fnuz], \
+        "unsupported dtype of KV cache tensor, got " \
+        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
+        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
+
+    # heuristics instead of autotuning
+    TILE_SIZE = min(2048, triton.next_power_of_2(n))
+    if torch.version.hip:
+        num_stages = 4
+        num_warps = 8
+    else:  # cuda
+        num_stages = 10
+        num_warps = 16
+        if torch.cuda.get_device_capability(key.device)[0] < 9:
+            TILE_SIZE = min(512, TILE_SIZE)
+
+    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
+    # using cudagraphs
+    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+
+    reshape_and_cache_kernel_flash[grid](
+        key_ptr=key,
+        value_ptr=value,
+        key_cache_ptr=key_cache,
+        value_cache_ptr=value_cache,
+        slot_mapping_ptr=slot_mapping,
+        k_scale=k_scale,
+        v_scale=v_scale,
+        # strides
+        key_stride=key_stride,
+        value_stride=value_stride,
+        block_stride=block_stride,
+        page_stride=page_stride,
+        num_heads=num_heads,
+        head_size=head_size,
+        block_size=block_size,
+        # FP8 flags
+        FP8_KV_CACHE=FP8_KV_CACHE,
+        # autotune parameters
+        TILE_SIZE=TILE_SIZE,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
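
To summarize what the kernel computes: each Triton program handles one token and one TILE_SIZE-wide slice of its flattened [num_heads * head_size] key/value vector; slot_mapping gives the destination slot, which is decomposed into a block index and an in-block offset of the paged NHD cache. A minimal pure-PyTorch sketch of the same write (ignoring the fp8 scaling path; the function name reference_reshape_and_cache_flash is hypothetical) looks like this:

    import torch

    def reference_reshape_and_cache_flash(
        key: torch.Tensor,           # [num_tokens, num_heads, head_size]
        value: torch.Tensor,         # [num_tokens, num_heads, head_size]
        key_cache: torch.Tensor,     # [num_blocks, block_size, num_heads, head_size]
        value_cache: torch.Tensor,   # [num_blocks, block_size, num_heads, head_size]
        slot_mapping: torch.Tensor,  # [num_tokens], negative entries mark padding
    ) -> None:
        block_size = key_cache.shape[1]
        for token_idx, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:
                continue  # padding token; the kernel returns early for these
            block_idx = slot // block_size      # which page of the paged cache
            block_offset = slot % block_size    # position inside that page
            key_cache[block_idx, block_offset] = key[token_idx]
            value_cache[block_idx, block_offset] = value[token_idx]

The Triton version performs the same writes tile by tile with masked loads and stores, and for fp8 caches it additionally divides by k_scale/v_scale before the store so that tl.store's implicit cast produces the scaled fp8 values.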
