Commit b382a7f

[BugFix] Make FP8 Linear compatible with torch.compile (#13918)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent: 4cb6fa0

File tree

vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py

2 files changed: +21 −4 lines

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 1 addition & 4 deletions

@@ -369,12 +369,9 @@ def apply(self,
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        # Note: lazy import to avoid triton import error.
-        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-            apply_w8a8_block_fp8_linear)
         if self.block_quant:
             assert self.quant_config.weight_block_size is not None
-            return apply_w8a8_block_fp8_linear(
+            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
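
The call site now resolves through the torch.ops.vllm namespace instead of a direct Python-level import. The companion change in fp8_utils.py below registers apply_w8a8_block_fp8_linear as a custom op with a shape-only fake implementation, which is what allows torch.compile to treat the block-FP8 GEMM as a single opaque op rather than tracing into its Triton kernels; a standalone sketch of the pattern follows that diff.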

vllm/model_executor/layers/quantization/utils/fp8_utils.py
Lines changed: 20 additions & 0 deletions

@@ -17,6 +17,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
+def apply_w8a8_block_fp8_linear_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+    return torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+
+direct_register_custom_op(
+    op_name="apply_w8a8_block_fp8_linear",
+    op_func=apply_w8a8_block_fp8_linear,
+    mutates_args=[],
+    fake_impl=apply_w8a8_block_fp8_linear_fake,
+)
+
+
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
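
For reference, the following is a minimal, self-contained sketch (not vLLM code) of the same pattern, written against PyTorch's public torch.library.custom_op API (available in PyTorch 2.4+): an eager implementation is registered together with a shape-only fake implementation so torch.compile can trace the op without executing it. The op name "demo::block_fp8_linear" and the plain matmul body are illustrative stand-ins; vLLM's direct_register_custom_op helper performs an equivalent registration.

# A minimal sketch (not vLLM code) of the custom-op pattern this commit uses.
import torch


@torch.library.custom_op("demo::block_fp8_linear", mutates_args=())
def block_fp8_linear(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Eager implementation. In vLLM this is where the Triton/CUTLASS block-FP8
    # kernels would run; a plain matmul stands in here.
    return input @ weight.t()


@block_fp8_linear.register_fake
def _(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Shape-only "fake" implementation, analogous to
    # apply_w8a8_block_fp8_linear_fake above: it lets torch.compile infer the
    # output shape and dtype without running the real kernel.
    output_shape = [*input.shape[:-1], weight.shape[0]]
    return torch.empty(output_shape, dtype=input.dtype, device=input.device)


@torch.compile
def forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # The compiled graph treats the registered op as a single opaque node,
    # the same way fp8.py now calls torch.ops.vllm.apply_w8a8_block_fp8_linear.
    return torch.ops.demo.block_fp8_linear(x, w)


x, w = torch.randn(4, 8), torch.randn(16, 8)
print(forward(x, w).shape)  # torch.Size([4, 16])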
