# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform


@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
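    # Launch grid: axis 0 selects the token, axis 1 selects a TILE_SIZE-wide
    # tile of the flattened (num_heads * head_size) dimension.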
    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    tile_i = tl.program_id(axis=1)
    tile_offs = tl.arange(0, TILE_SIZE)
    tile_pos = tile_i * TILE_SIZE + tile_offs

    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size
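    # block_idx / block_offset locate the slot inside the paged KV cache:
    # the physical block it lives in and its position within that block.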

    src_key_idx = token_idx * key_stride
    src_value_idx = token_idx * value_stride

    tgt_idx = block_idx * block_stride + block_offset * page_stride
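    # tgt_idx is the element offset of this token's slot in the cache; heads
    # are stored contiguously, so tile_pos indexes straight into the
    # flattened [num_heads, head_size] payload of that slot.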

    # [TILE_SIZE]
    key_load = tl.load(key_ptr + src_key_idx + tile_pos,
                       mask=tile_pos < (num_heads * head_size))
    if FP8_KV_CACHE:
        if key_load.dtype.is_fp8():
            key_tile = key_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the key_cache_ptr.dtype.element_ty
            key_tile = key_load / tl.load(k_scale)
    else:
        key_tile = key_load

    # [TILE_SIZE]
    value_load = tl.load(value_ptr + src_value_idx + tile_pos,
                         mask=tile_pos < (num_heads * head_size))
    if FP8_KV_CACHE:
        if value_load.dtype.is_fp8():
            value_tile = value_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            value_tile = value_load / tl.load(v_scale)
    else:
        value_tile = value_load

    tl.store(
        key_cache_ptr + tgt_idx + tile_pos,
        key_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    tl.store(
        value_cache_ptr + tgt_idx + tile_pos,
        value_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    return


def triton_reshape_and_cache_flash(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size]
    # [num_blocks, block_size, num_heads, head_size]
    key_cache: torch.Tensor,
    # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
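    """Scatter per-token keys and values into the paged KV caches at the
    slots given by ``slot_mapping``. With an fp8 ``kv_cache_dtype`` the
    inputs are scaled by ``k_scale``/``v_scale`` and stored as fp8;
    negative slot indices (padding tokens) are skipped."""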
    num_tokens = key.shape[0]
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]
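    # block_stride / page_stride are the cache strides over the num_blocks
    # and block_size dimensions, i.e. the element distance between physical
    # blocks and between consecutive slots within a block.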

    head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only contiguous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    kv_cache_torch_dtype = current_platform.fp8_dtype() if \
        kv_cache_dtype.startswith("fp8") else key_cache.dtype

    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith(
            "fp8"):
        # to avoid an erroneous implicit cast in the triton kernel
        # (tl.store to uint8); e.g. an explicit cast to fp8e4m3fnuz is not
        # supported in triton 3.4
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
        assert kv_cache_torch_dtype != torch.uint8, \
            "explicit fp8 cast and store to uint8 is not supported by " \
            "triton reshape_and_cache_flash"

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8,
        torch.float8_e4m3fnuz], \
        "unsupported dtype of KV cache tensor, got " \
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
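    # n <= TILE_SIZE gives a single tile per token; larger flattened head
    # dimensions are split across programs along the second grid axis.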
    if torch.version.hip:
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
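

# Minimal usage sketch (illustrative only): the shapes, dtypes, and the
# "auto" cache dtype below are assumptions, and the final checks rely on
# token i being mapped to slot i, i.e. offset i of physical block 0.
if __name__ == "__main__":
    if torch.cuda.is_available():
        num_tokens, num_heads, head_size = 4, 8, 128
        num_blocks, block_size = 16, 16
        key = torch.randn(num_tokens, num_heads, head_size,
                          device="cuda", dtype=torch.float16)
        value = torch.randn_like(key)
        key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                                device="cuda", dtype=torch.float16)
        value_cache = torch.zeros_like(key_cache)
        # Token i is written to slot i (block 0, offset i).
        slot_mapping = torch.arange(num_tokens, device="cuda",
                                    dtype=torch.int64)
        k_scale = torch.ones(1, device="cuda", dtype=torch.float32)
        v_scale = torch.ones(1, device="cuda", dtype=torch.float32)
        triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
                                       slot_mapping, "auto", k_scale, v_scale)
        # The first num_tokens slots of block 0 now hold the inputs verbatim.
        assert torch.equal(key_cache[0, :num_tokens], key)
        assert torch.equal(value_cache[0, :num_tokens], value)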