41 changes: 41 additions & 0 deletions vllm/compilation/activation_quant_fusion.py
@@ -38,6 +38,32 @@ def silu_mul_replacement_static(result: torch.Tensor,
return at[1]


def silu_mul_mxfp4_gemm_pattern(result: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
result=result_silu_mul,
input=input)
at2 = auto_functionalized(torch.ops.vllm.gemm_with_dynamic_quant.default,
result=result,
x=at1[1],
weight=weight,
weight_scale=scale,
x_scales=None)
return at2[1]


def silu_mul_mxfp4_gemm_replacement(result: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(torch.ops.vllm.silu_and_mul_mxfp4_gemm.default,
result=result,
x=input,
weight=weight,
weight_scale=scale)
return at[1]
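
Note for reviewers: the pattern/replacement pair above rewrites an unfused silu_and_mul followed by gemm_with_dynamic_quant into the single fused op. A minimal sketch of the two call sequences treated as equivalent — shapes are hypothetical, and the snippet assumes a ROCm build with aiter so that both torch.ops.vllm ops below are actually registered:

```python
import torch

# Hypothetical sizes: M tokens, intermediate size d, N output features.
M, d, N = 32, 64, 128
x = torch.empty(M, 2 * d, dtype=torch.bfloat16, device="cuda")       # gate/up halves concatenated
w = torch.empty(N, d // 2, dtype=torch.uint8, device="cuda")         # packed mxfp4 weight (layout illustrative)
w_scale = torch.empty(N, d // 32, dtype=torch.uint8, device="cuda")  # per-1x32 e8m0 scales (layout illustrative)
out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")

# Unfused sequence the pattern matches:
act = torch.empty(M, d, dtype=torch.bfloat16, device="cuda")
torch.ops._C.silu_and_mul(act, x)                                    # SiLU(x[:, :d]) * x[:, d:]
torch.ops.vllm.gemm_with_dynamic_quant(out, act, w, w_scale, None)   # quantize act to mxfp4, then GEMM

# Fused op the replacement emits (activation, quantization and GEMM handled together):
torch.ops.vllm.silu_and_mul_mxfp4_gemm(out, x, w, w_scale)
```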


def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

@@ -51,6 +77,10 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


def empty_fp4(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda")


class ActivationQuantFusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
@@ -76,6 +106,17 @@ def __init__(self, config: VllmConfig):
register_replacement(silu_mul_pattern_static,
silu_mul_replacement_static, inputs, fwd_only,
self.patterns)

inputs = [
empty_bf16(32, 32), # result
empty_bf16(32, 32), # result_silu_mul
empty_bf16(32, 32), # input
empty_fp4(32, 32), # weight
empty_fp4(32, 1), # scale
]
register_replacement(silu_mul_mxfp4_gemm_pattern,
silu_mul_mxfp4_gemm_replacement, inputs, fwd_only,
self.patterns)

def __call__(self, graph: torch.fx.Graph):
self.begin()
@@ -18,74 +18,94 @@
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.ops.triton.activation import act_mul_and_mxfp4_quant

from vllm.utils import direct_register_custom_op
if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
from aiter import gemm_a4w4, per_1x32_f4_quant_hip

def gemm_with_dynamic_quant(
result: torch.Tensor,
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
x_scales: torch.Tensor = None,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
M = x.shape[0]
out_dtype: Optional[torch.dtype] = torch.bfloat16
) -> None:
if envs.VLLM_TRITON_FP4_GEMM_USE_ASM:
M = x.shape[0]
if x_scales is None:
# use hip quant kernel for performance
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
else:
x_q = x
x_s = x_scales

# 32 alignment is enough for dim0 padding of output for
# gemm_a4w4 kernel
y = torch.empty((M + 31) // 32 * 32,
weight.shape[0],
device=x_q.device,
dtype=out_dtype)

gemm_a4w4(x_q,
weight,
x_s,
weight_scale.view(x_s.dtype),
y,
bpreshuffle=True)
return y[:M]
result.copy_(y[:M])
else:
if x_scales is None:
x_q, x_s = dynamic_mxfp4_quant(x)
else:
x_q = x
x_s = x_scales
y = torch.empty(x_q.shape[0],
weight.shape[0],
device=x_q.device,
dtype=out_dtype)

gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, result)
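
Note for reviewers: in the ASM branch, `(M + 31) // 32 * 32` rounds the row count up to a multiple of 32 because gemm_a4w4 needs dim0 of its output padded; the padding rows are dropped again by `y[:M]` before being copied into `result`. A small worked example:

```python
# Round M up to the next multiple of 32 (dim0 padding for gemm_a4w4's output),
# then slice the padding back off before writing into `result`.
M = 100
padded_rows = (M + 31) // 32 * 32   # 131 // 32 * 32 == 4 * 32 == 128
assert padded_rows == 128
```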

def gemm_with_dynamic_quant_fake(
result: torch.Tensor,
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
x_scales: torch.Tensor = None,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
return torch.empty((*x.shape[:-1], weight.shape[0]),
dtype=out_dtype,
device=x.device)
out_dtype: Optional[torch.dtype] = torch.bfloat16
) -> None:
return

direct_register_custom_op(
op_name="gemm_with_dynamic_quant",
op_func=gemm_with_dynamic_quant,
mutates_args=[],
mutates_args=['result'],
fake_impl=gemm_with_dynamic_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
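
Note for reviewers: with mutates_args=['result'] the op now follows an out-parameter convention — callers pre-allocate the output and the op writes into it instead of returning a fresh tensor, which is what lets the fusion pass auto-functionalize it. A hedged usage sketch (ROCm + aiter assumed, shapes illustrative):

```python
import torch

M, K, N = 32, 64, 128
x = torch.empty(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.empty(N, K // 2, dtype=torch.uint8, device="cuda")         # packed mxfp4 weight (layout illustrative)
w_scale = torch.empty(N, K // 32, dtype=torch.uint8, device="cuda")  # e8m0 block scales (layout illustrative)

# The caller now owns the output buffer; the op fills it in place.
result = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
torch.ops.vllm.gemm_with_dynamic_quant(result, x, w, w_scale, None)  # x_scales=None -> quantize x on the fly
```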

def silu_and_mul_mxfp4_gemm(
result: torch.Tensor,
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16
) -> None:
x_fp4, blockscale_e8m0 = act_mul_and_mxfp4_quant(x, 'silu')
gemm_with_dynamic_quant(result, x_fp4, weight, weight_scale, blockscale_e8m0, out_dtype)

def silu_and_mul_mxfp4_gemm_fake(
result: torch.Tensor,
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16
) -> None:
return

direct_register_custom_op(
op_name="silu_and_mul_mxfp4_gemm",
op_func=silu_and_mul_mxfp4_gemm,
mutates_args=['result'],
fake_impl=silu_and_mul_mxfp4_gemm_fake,
dispatch_key=current_platform.dispatch_key,
)
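
Note for reviewers: since the fused op is act_mul_and_mxfp4_quant followed by gemm_with_dynamic_quant, it can be sanity-checked against the unfused pair on a ROCm + aiter machine. A hedged sketch — shapes, weight packing, and the comparison are illustrative, and the two paths quantize the activation at slightly different points, so expect small numerical differences rather than bitwise equality:

```python
import torch

M, d, N = 32, 64, 128
x = torch.randn(M, 2 * d, dtype=torch.bfloat16, device="cuda")
w = torch.randint(0, 256, (N, d // 2), dtype=torch.uint8, device="cuda")   # stand-in for packed mxfp4 weights
w_scale = torch.full((N, d // 32), 127, dtype=torch.uint8, device="cuda")  # e8m0 bytes; 127 ~ scale of 1.0

out_fused = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
torch.ops.vllm.silu_and_mul_mxfp4_gemm(out_fused, x, w, w_scale)

out_ref = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
act = torch.empty(M, d, dtype=torch.bfloat16, device="cuda")
torch.ops._C.silu_and_mul(act, x)
torch.ops.vllm.gemm_with_dynamic_quant(out_ref, act, w, w_scale, None)

print((out_fused.float() - out_ref.float()).abs().max())
```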

except ImportError:
dynamic_mxfp4_quant = gemm_afp4wfp4 = None

@@ -225,5 +245,7 @@ def apply_weights(self,

return F.linear(x, dq_w, bias)
else:
return torch.ops.vllm.gemm_with_dynamic_quant(
x, layer.weight, layer.weight_scale, x_quant_scales, self.out_dtype)
result = torch.empty((*x.shape[:-1], layer.weight.shape[0]), dtype=self.out_dtype, device=x.device)
torch.ops.vllm.gemm_with_dynamic_quant(
result, x, layer.weight, layer.weight_scale, x_quant_scales, self.out_dtype)
return result