[Bugfix/Core] Flashinfer k_scale and v_scale (#9861)
pavanimajety authored Nov 1, 2024
1 parent aff1fd8 · commit 598b6d7
Showing 3 changed files with 25 additions and 12 deletions.
tests/kernels/test_cache.py (21 changes: 14 additions & 7 deletions)
@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
     del key_caches
     del value_caches
 
+    k_scale = key.amax().item() / 256
+    v_scale = value.amax().item() / 256
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
         cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache)
+        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
         cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache)
+        ops.convert_fp8(cloned_value_cache, value_cache, v_scale,
+                        kv_cache_dtype)
     else:
         cloned_key_cache = key_cache.clone()
         cloned_value_cache = value_cache.clone()
 
-    # Using default kv_scale
-    k_scale = v_scale = 1.0
-
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
 
     if kv_cache_dtype == "fp8":
         result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(result_key_cache, key_cache)
+        ops.convert_fp8(result_key_cache,
+                        key_cache,
+                        k_scale,
+                        kv_dtype=kv_cache_dtype)
         result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(result_value_cache, value_cache)
+        ops.convert_fp8(result_value_cache,
+                        value_cache,
+                        v_scale,
+                        kv_dtype=kv_cache_dtype)
 
     # Run the reference implementation.
     block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
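Note on the new scale heuristic: the test now derives a per-tensor scale from the tensor's maximum, relies on reshape_and_cache_flash to quantize into the fp8 cache, and uses the scaled convert_fp8 calls to dequantize it back. The sketch below is plain PyTorch, not vLLM code; the tensor shape and the float8_e4m3fn dtype are assumptions used only to illustrate the scaled round trip.

import torch

key = torch.randn(16, 8, 64, dtype=torch.float16)

# Same heuristic as the test: amax() / 256 keeps the scaled values
# well inside the fp8 e4m3 range.
k_scale = key.amax().item() / 256

# Quantize: divide by the scale, cast down to fp8.
key_fp8 = (key.float() / k_scale).to(torch.float8_e4m3fn)

# Dequantize: cast back up and multiply by the same scale.
key_deq = key_fp8.to(torch.float16) * k_scale

# The round trip is lossy; the printed value is the worst-case absolute error.
print((key - key_deq).abs().max().item())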
vllm/attention/backends/flashinfer.py (9 changes: 6 additions & 3 deletions)
@@ -759,8 +759,6 @@ def forward(
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
-        assert k_scale == 1.0 and v_scale == 1.0, (
-            "key/v_scale is not supported in FlashInfer.")
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
@@ -874,7 +872,12 @@ def unified_flash_infer(
             assert prefill_meta is not None
             assert prefill_meta.prefill_wrapper is not None
             prefill_output = prefill_meta.prefill_wrapper.forward(
-                query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
+                query,
+                kv_cache,
+                logits_soft_cap=logits_soft_cap,
+                causal=True,
+                k_scale=k_scale,
+                v_scale=v_scale)
     if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
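With the assertion removed above, k_scale and v_scale now reach the prefill wrapper's forward call shown in this hunk. As a point of reference only (this is an assumption about the semantics, not FlashInfer's implementation, and the helper name is hypothetical), applying the scales roughly amounts to dequantizing the fp8 KV cache before ordinary attention:

import math
import torch

def scaled_kv_attention(q: torch.Tensor,
                        k_fp8: torch.Tensor,
                        v_fp8: torch.Tensor,
                        k_scale: float,
                        v_scale: float) -> torch.Tensor:
    # Dequantize the cached keys/values with their per-tensor scales.
    k = k_fp8.to(q.dtype) * k_scale
    v = v_fp8.to(q.dtype) * v_scale
    # Standard scaled-dot-product attention over the dequantized cache.
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v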
vllm/model_executor/layers/quantization/modelopt.py (7 changes: 5 additions & 2 deletions)
@@ -141,8 +141,11 @@ def create_weights(
         layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        max_w_scale, weight = requantize_with_max_scale(
-            layer.weight, layer.weight_scale, layer.logical_widths)
+        weight = layer.weight
+        max_w_scale = layer.weight_scale.max()
+        if not (layer.weight_scale == layer.weight_scale[0]).all():
+            max_w_scale, weight = requantize_with_max_scale(
+                layer.weight, layer.weight_scale, layer.logical_widths)
         layer.weight = Parameter(weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(),
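The modelopt change above avoids a redundant requantization: when every logical shard already shares the same weight scale, requantizing to the max scale would rewrite the weight with identical values, so the weight is kept as-is and only the max of the (equal) scales is stored. The sketch below illustrates the skipped step under stated assumptions; it is not vLLM's requantize_with_max_scale, and the fp8 dtype and shard layout along dim 0 are assumptions.

from typing import List, Tuple

import torch

def requantize_with_max_scale_sketch(
        weight_fp8: torch.Tensor,    # quantized weight, shards stacked on dim 0
        scales: torch.Tensor,        # one scale per logical shard
        logical_widths: List[int],   # rows belonging to each shard
) -> Tuple[torch.Tensor, torch.Tensor]:
    max_scale = scales.max()
    out = torch.empty_like(weight_fp8)
    start = 0
    for shard_scale, width in zip(scales.tolist(), logical_widths):
        end = start + width
        # Dequantize with the shard's own scale, requantize with the max.
        deq = weight_fp8[start:end].to(torch.float32) * shard_scale
        out[start:end] = (deq / max_scale).to(weight_fp8.dtype)
        start = end
    return max_scale, out

When all entries of scales equal scales[0], each shard is rewritten with the values it already holds, which is exactly the case the new branch skips.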
