30 changes: 30 additions & 0 deletions tests/compile/backend.py
@@ -88,6 +88,36 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"

def check_before_fused_auto_custom_ops(
self, ops: Sequence[tuple[OpOverload, bool]], fully_replaced=True
):
        # Currently only used for AITER custom ops that are registered with a
        # mutable schema directly on the vllm namespace and are fused together
        # with auto_functionalized ops.

for op, target_op_only in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass, target_op_only)))
num_post = len(
list(find_op_nodes(op, self.graph_post_pass, target_op_only))
)
assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
assert num_pre > num_post, f"All nodes remain for op {op.name()}"
if fully_replaced:
assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph"

def check_after_fused_auto_custom_ops(self, ops: Sequence[tuple[OpOverload, bool]]):
        # Currently only used for AITER custom ops that are registered with a
        # mutable schema directly on the vllm namespace and are fused together
        # with auto_functionalized ops.

for op, target_op_only in ops:
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass, target_op_only)))
num_post = len(
list(find_op_nodes(op, self.graph_post_pass, target_op_only))
)
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"

def op_count(self, op: OpOverload, before=False) -> int:
graph = self.graph_pre_pass if before else self.graph_post_pass
return len(list(find_op_nodes(op, graph)))
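
For reference, a minimal sketch of how the two new helpers are meant to be driven (the wrapper name and its use here are illustrative, not part of the PR): each entry in `ops` pairs an `OpOverload` with the `target_op_only` flag that gets forwarded to `find_op_nodes`.

```python
from collections.abc import Sequence

from torch._ops import OpOverload


def assert_fusion_happened(
    backend: "TestBackend",
    before: Sequence[tuple[OpOverload, bool]],
    after: Sequence[tuple[OpOverload, bool]],
) -> None:
    # Ops in `before` must appear pre-pass and be gone post-pass;
    # ops in `after` must be absent pre-pass and appear post-pass.
    # The bool in each pair is forwarded to find_op_nodes(..., target_op_only=...).
    backend.check_before_fused_auto_custom_ops(before, fully_replaced=True)
    backend.check_after_fused_auto_custom_ops(after)
```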
154 changes: 154 additions & 0 deletions tests/compile/test_rocm_aiter_fusion.py
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence

import pytest
import torch
from torch._ops import OpOverload

import vllm.plugins
from vllm.compilation.fusion import (
QUANT_OPS,
FusedRMSQuantKey,
)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.rocm_aiter_rmsnorm_fusion import (
ROCM_AITER_FUSED_OPS,
RMSNormAiterQuantFusionPass,
)
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
)
from vllm.platforms import current_platform

from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()


class TestModel(torch.nn.Module):
def __init__(
self,
hidden_size: int,
eps: float,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
group_shape = GroupShape.PER_TOKEN
# AITER RMSNorm fusion pass does not support static quantization at the moment.
self.wscale = [
torch.rand(size=(hidden_size, 1), dtype=torch.float32) for _ in range(2)
]
quant_scale = ScaleDesc(torch.float32, static=False, group_shape=group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)

self.scale = [None for _ in range(2)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
]

self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=group_shape,
)

def forward(self, x):
resid = torch.sqrt(x)
y = self.norm[0](x)

x2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
        # make sure resid is used for the replacement to work
y2, resid = self.norm[1](x2, resid)

x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

def ops_in_model_before(self) -> Sequence[tuple[OpOverload, bool]]:
        # Find the fp8 quant ops in the model before fusion via their
        # auto-functionalized form (without directly targeting the op).
return [(QUANT_OPS[self.key], False)]

def ops_in_model_after(self) -> Sequence[tuple[OpOverload, bool]]:
        # Find the AITER fused RMSNorm ops in the model after fusion
        # by directly targeting the op.

return [
(ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, False)], True),
(ROCM_AITER_FUSED_OPS[FusedRMSQuantKey(self.key, True)], True),
]


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [2048])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only test on ROCm")
def test_fusion_rmsnorm_quant(
dtype: torch.dtype,
hidden_size: int,
num_tokens: int,
eps: float,
monkeypatch: pytest.MonkeyPatch,
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths

vllm_config = VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1")
m.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
m.setenv("VLLM_ROCM_USE_AITER_RMSNORM", "1")

        # The no-op elimination pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormAiterQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
model = TestModel(hidden_size, eps)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

result = model(x)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)

ATOL, RTOL = (1e-2, 1e-2)

torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

assert fusion_pass.matched_count == 2

# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_fused_auto_custom_ops(model.ops_in_model_before())

# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_after_fused_auto_custom_ops(model.ops_in_model_after())
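
For intuition, here is a rough, unfused reference of the pattern this test expects the pass to fuse: RMSNorm followed by dynamic per-token fp8 quantization. This is plain PyTorch; `torch.float8_e4m3fn` is used only for illustration, whereas the test picks the platform fp8 dtype via `current_platform.fp8_dtype()`, and the actual AITER kernel may differ in details. `TestModel` contains two such norm→quant pairs feeding the fp8 linears (one plain, one with a residual add), which is why `fusion_pass.matched_count == 2` is asserted.

```python
import torch


def rmsnorm_then_dynamic_fp8_quant(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> tuple[torch.Tensor, torch.Tensor]:
    # RMSNorm in fp32 for stability, cast back to the input dtype.
    rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    y = (x.float() * rms).to(x.dtype) * weight

    # Dynamic per-token fp8 quantization: one scale per row.
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    scale = (y.abs().amax(dim=-1, keepdim=True).float() / fp8_max).clamp(min=1e-12)
    q = (y.float() / scale).clamp(-fp8_max, fp8_max).to(fp8)
    return q, scale
```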
25 changes: 23 additions & 2 deletions vllm/compilation/fx_utils.py
@@ -67,8 +67,29 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:


# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
if not op._schema.is_mutable:
def find_op_nodes(
op: OpOverload, graph: fx.Graph, target_op_only: bool = False
) -> Iterator[fx.Node]:
"""
Yields all nodes in the graph that call the given op.
op (OpOverload):
The operator overload to match within the FX graph.
graph (fx.Graph):
The FX graph to search for nodes calling the specified operator.
target_op_only (bool):
If True, only yields nodes that directly call the specified operator.
If False, also yields nodes that call
the operator via auto_functionalized.
This is useful when `op`
is a mutable or custom-registered operator
that does not have an auto-functionalized version.
"""

    # Mutable ops are normally wrapped by auto_functionalized, but ops such as
    # aiter_rmsnorm_fused_dynamic_quant are registered with a mutable schema
    # directly on the vllm namespace and never auto-functionalized, so they
    # must be matched as direct calls.
if not op._schema.is_mutable or target_op_only:
yield from graph.find_nodes(op="call_function", target=op)

for n in graph.find_nodes(op="call_function", target=auto_functionalized):
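
A rough usage sketch of the extended helper, assuming an environment where `vllm.compilation.fx_utils` is importable; `aten.silu.default` is only a stand-in op to keep the example self-contained.

```python
import torch
from torch import fx

from vllm.compilation.fx_utils import find_op_nodes


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.silu.default(x)


graph: fx.Graph = fx.symbolic_trace(toy).graph

# Immutable ops like silu are matched as direct call_function nodes either way.
# For mutable ops registered straight on the vllm namespace (e.g. the AITER
# fused RMSNorm kernels), target_op_only=True skips the auto_functionalized
# lookup and matches the direct call instead.
direct = list(find_op_nodes(torch.ops.aten.silu.default, graph, target_op_only=True))
assert len(direct) == 1
```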
9 changes: 9 additions & 0 deletions vllm/compilation/pass_manager.py
@@ -18,6 +18,12 @@
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass

if current_platform.is_rocm():
from .rocm_aiter_rmsnorm_fusion import (
RMSNormAiterQuantFusionPass,
is_rocm_aiter_rmsnorm_enabled,
)

if current_platform.is_cuda():
from .collective_fusion import AllReduceFusionPass, AsyncTPPass

@@ -98,6 +104,9 @@ def configure(self, config: VllmConfig):
self.passes += [AllReduceFusionPass(config)]

if self.pass_config.enable_fusion:
if is_rocm_aiter_rmsnorm_enabled():
self.passes += [RMSNormAiterQuantFusionPass(config)]

self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]

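
For completeness, the configuration the new test exercises to activate this pass, condensed into a standalone snippet. Assumption: `is_rocm_aiter_rmsnorm_enabled()` is gated by the same `VLLM_ROCM_USE_AITER*` switches the test sets; its implementation is not shown in this diff, and the environment variables must be set before vLLM reads them.

```python
import os

from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig

# Same switches the test sets via monkeypatch.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"
os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0"

# enable_fusion=True is what makes the pass manager consider the RMSNorm/quant
# fusion passes at all; the AITER variant is added ahead of the generic one.
config = VllmConfig(
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    )
)
```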