22 commits
fdd9cbd
cherry pick 09133e9833811778240b3c2cc4de2390fd08e470; and only add AI…
vllmellm Feb 26, 2025
668ec2f
cherry pick acc27ffa94e677b8f6fce0f5b593430ce6acbfe4; and only add AI…
vllmellm Mar 5, 2025
8d49d6e
bug fixes and pass unit tests
tjtanaa Mar 17, 2025
43af6c0
add AITER setup steps in Dockerfile.rocm_base
tjtanaa Mar 17, 2025
0c30ce9
remove AITER setup steps in Dockerfile.rocm
tjtanaa Mar 17, 2025
ab73f97
Merge remote-tracking branch 'origin/main' into aiter-linear
tjtanaa Mar 17, 2025
e952b2d
fix missing property from Platform
tjtanaa Mar 17, 2025
6a632ac
skip AITER in AMD CI
tjtanaa Mar 17, 2025
61c92a9
Merge remote-tracking branch 'origin/main' into aiter-linear
tjtanaa Mar 20, 2025
0224eff
merge with main
tjtanaa Apr 16, 2025
d2ed934
revert run-amd-test.sh; update Dockerfile.rocm_base aiter version, re…
tjtanaa Apr 16, 2025
3fec588
clean up spaces and newline; fix typo
tjtanaa Apr 16, 2025
3558099
clean up spaces and newline;
tjtanaa Apr 16, 2025
2bf7206
fix typo
tjtanaa Apr 16, 2025
1f979fa
untested refactoring
tjtanaa Apr 17, 2025
f13746c
fix bug; validated to work V1 AITER unquantized and quantized
tjtanaa Apr 19, 2025
20139af
relocate the linear helper function into aiter_ops and fix unittest
tjtanaa Apr 19, 2025
700ac73
add test_aiter_ops.py to unit test if the ops are registered correctl…
tjtanaa Apr 19, 2025
7dd2812
fix the test to test fake tensor implementation
tjtanaa Apr 20, 2025
d9f0e7b
use current_platform.fp8_dtype(); update aiter commit
tjtanaa Apr 21, 2025
dde9157
merge with main; fix dispatcher and unit tests
tjtanaa Apr 22, 2025
e34712c
remove is_rocm_aiter_xxxx_enabled flag from _aiter_ops.py
tjtanaa Apr 22, 2025
22 changes: 22 additions & 0 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.platforms import current_platform


@@ -96,6 +97,27 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()


@pytest.mark.skipif(not current_platform.is_rocm(),
reason="AITER is a feature exclusive for ROCm")
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_linear", ["0", "1"])
def test_unquantized_linear_dispatch(use_rocm_aiter: str,
use_rocm_aiter_linear: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", use_rocm_aiter_linear)

linear_func = dispatch_unquantized_gemm()
print(f"use_rocm_aiter: {use_rocm_aiter}, " +
f"use_rocm_aiter_linear: {use_rocm_aiter_linear}")
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_linear):
from vllm._aiter_ops import aiter_ops
assert linear_func == aiter_ops.rocm_aiter_tuned_gemm
else:
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
assert linear_func == rocm_unquantized_gemm


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
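For context, a minimal sketch of the dispatch logic this test exercises is shown below. The real dispatch_unquantized_gemm lives in vllm/model_executor/layers/utils.py and is not part of this diff, so details may differ; the sketch (with a hypothetical name) only mirrors what the ROCm-only test above asserts.

import vllm.envs as envs
from vllm.platforms import current_platform


def dispatch_unquantized_gemm_sketch():
    # Mirrors the assertions in test_unquantized_linear_dispatch above.
    if (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_LINEAR):
        from vllm._aiter_ops import aiter_ops
        return aiter_ops.rocm_aiter_tuned_gemm
    from vllm.model_executor.layers.utils import rocm_unquantized_gemm
    return rocm_unquantized_gemm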
1 change: 0 additions & 1 deletion tests/quantization/test_fp8.py
@@ -27,7 +27,6 @@
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
use_rocm_aiter: bool, monkeypatch) -> None:

if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

145 changes: 145 additions & 0 deletions tests/v1/rocm/test_aiter_ops.py
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# This file tests the aiter ops. It checks that the aiter ops are:
# 1. correctly registered as custom ops
# 2. correctly linked to their fake implementations
# 3. usable with torch.compile
# The whole file is skipped if aiter is not installed
# or the platform is not ROCm.
#
# NOTE:
# These unit tests are not meant to verify the numerical
# correctness of the aiter ops. They only check that the
# aiter ops are registered correctly and that torch.compile
# can be used with them.
# The correctness of the aiter ops is tested in
# https://github.com/ROCm/aiter

import importlib.util

import pytest
import torch

from vllm._aiter_ops import aiter_ops
from vllm.platforms import current_platform

# Check if aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None

pytestmark = pytest.mark.skipif(
not (current_platform.is_rocm() and aiter_available),
reason="AITER ops are only available on ROCm with aiter package installed")


def test_rocm_aiter_tuned_gemm_custom_op_registration():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert hasattr(torch.ops.vllm, 'rocm_aiter_tuned_gemm')

# Check if the op is callable
assert callable(torch.ops.vllm.rocm_aiter_tuned_gemm)


def test_rocm_aiter_tuned_gemm_torch_compile_compatibility():
"""Test that the op can be used with torch.compile."""
# Create test tensors
input_tensor = torch.randn(64, 32, dtype=torch.float16, device='cuda')
weight_tensor = torch.randn(16, 32, dtype=torch.float16, device='cuda')

# Define a function that uses the op
def gemm_fn(x, w):
return aiter_ops.rocm_aiter_tuned_gemm(x, w)

# Verify the op's fake implementation
torch.library.opcheck(torch.ops.vllm.rocm_aiter_tuned_gemm,
(input_tensor, weight_tensor),
test_utils=("test_schema", "test_faketensor"))

# Compile the function with appropriate settings based on
# vllm/compilation/wrapper.py
compiled_fn = torch.compile(gemm_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)

# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
result_original = gemm_fn(input_tensor, weight_tensor)
result_compiled = compiled_fn(input_tensor, weight_tensor)

# Verify results match
assert torch.allclose(result_original, result_compiled)


def test_rocm_aiter_tuned_gemm_torch_compile_fp8_compatibility():

input_tensor = torch.randn(64, 32, dtype=torch.float16, device='cuda')
weight_tensor = torch.randn(16, 32, dtype=torch.float16, device='cuda')

input_fp8 = input_tensor.to(current_platform.fp8_dtype())
weight_fp8 = weight_tensor.to(current_platform.fp8_dtype())

scale_a = torch.tensor(10.0, device='cuda')
scale_b = torch.tensor(0.5, device='cuda')

# Define a function that uses the op with FP8 and scales
def gemm_fp8_fn(x, w, scale_a, scale_b):
return aiter_ops.rocm_aiter_tuned_gemm(x,
w,
out_dtype=torch.float16,
scale_a=scale_a,
scale_b=scale_b)

# Verify the op's fake implementation with FP8 inputs
# Disable test_schema as fp8 datatype is not supported by
# torch.library.opcheck
# Related error:
# OpCheckError: opcheck(op, ...): test_schema failed with
# "mul_cuda" not implemented for 'Float8_e4m3fnuz'
torch.library.opcheck(torch.ops.vllm.rocm_aiter_tuned_gemm,
(input_fp8, weight_fp8),
kwargs={
"out_dtype": torch.float16,
"scale_a": scale_a,
"scale_b": scale_b
},
test_utils=("test_faketensor"))

# Compile the function with appropriate settings based on
# vllm/compilation/wrapper.py
compiled_fp8_fn = torch.compile(gemm_fp8_fn,
fullgraph=True,
backend="inductor",
mode="reduce-overhead",
dynamic=False)

# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
result_original = gemm_fp8_fn(input_fp8, weight_fp8, scale_a, scale_b)
result_compiled = compiled_fp8_fn(input_fp8, weight_fp8, scale_a, scale_b)

# Verify results match and have correct properties
assert torch.allclose(result_original, result_compiled)
assert result_original.dtype == torch.float16
assert result_compiled.dtype == torch.float16
assert result_original.shape == (64, 16)
assert result_compiled.shape == (64, 16)

# Get unscaled result
unscaled_result = aiter_ops.rocm_aiter_tuned_gemm(
input_fp8.to(torch.float16),
weight_fp8.to(torch.float16),
out_dtype=torch.float16)

# Verify that scaling was applied correctly
# The scaled result should be approximately equal to the
# unscaled result multiplied by the scales
expected_scaled = unscaled_result * (scale_a * scale_b)
assert torch.allclose(result_original,
expected_scaled,
rtol=1e-2,
atol=1e-2)

# Verify that scaled and unscaled results are different
assert not torch.allclose(
result_original, unscaled_result, rtol=1e-2, atol=1e-2)
76 changes: 76 additions & 0 deletions vllm/_aiter_ops.py
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def rocm_aiter_tuned_gemm_impl(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None) -> torch.Tensor:

# This AITER function can be used for
# - BF16 and FP16 matmul
# e.g. vllm/model_executor/layers/linear.py
# - per-tensor activations + per-tensor weights
# e.g. vllm/model_executor/layers/quantization/utils/w8a8_utils.py
from aiter.tuned_gemm import tgemm as aiter_tgemm

return aiter_tgemm.mm(input,
weight,
otype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)


def rocm_aiter_tuned_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None) -> torch.Tensor:

m = input.shape[0]
n = weight.shape[0]
if out_dtype is None:
out_dtype = input.dtype
return torch.empty((m, n), dtype=out_dtype, device=input.device)


if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_tuned_gemm",
op_func=rocm_aiter_tuned_gemm_impl,
mutates_args=[],
fake_impl=rocm_aiter_tuned_gemm_fake,
dispatch_key=current_platform.dispatch_key,
)


class aiter_ops:

@staticmethod
def rocm_aiter_tuned_gemm(
input: torch.Tensor, # [M, K]
weight: torch.Tensor, # [N, K]
bias: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None) -> torch.Tensor:

return torch.ops.vllm.rocm_aiter_tuned_gemm(
input,
weight,
bias=bias,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
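The impl/fake pair registered above is what allows torch.compile to trace the op without executing the AITER kernel. Purely as an illustration of that pattern (not how vLLM's direct_register_custom_op is implemented), the same effect can be sketched with the public torch.library API; the "demo::" namespace and the plain matmul body below are made up for the example.

import torch


@torch.library.custom_op("demo::tuned_gemm", mutates_args=())
def demo_tuned_gemm(input: torch.Tensor,
                    weight: torch.Tensor) -> torch.Tensor:
    # Eager implementation: a plain matmul stands in for the AITER kernel.
    return input @ weight.t()


@demo_tuned_gemm.register_fake
def _(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used during torch.compile tracing.
    return torch.empty((input.shape[0], weight.shape[0]),
                       dtype=input.dtype,
                       device=input.device)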
10 changes: 4 additions & 6 deletions vllm/envs.py
@@ -528,8 +528,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),

# Disable aiter ops unless specifically enabled.
# Acts as a parent switch to enable the rest of the other operations.
# use aiter ops unless specifically disabled
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),
@@ -541,11 +540,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
("true", "1")),

# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
"VLLM_ROCM_USE_AITER_LINEAR":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
("true", "1")),
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True"
).lower() in ("true", "1")),

# Whether to use aiter moe ops.
# By default is enabled.
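In effect, VLLM_ROCM_USE_AITER now acts as a parent switch for VLLM_ROCM_USE_AITER_LINEAR: the linear flag still defaults to on, but it only evaluates to True when the parent switch is explicitly enabled. A standalone illustration of the gating expressed by the lambda above (example values only):

import os

# Parent switch defaults to off; the linear flag defaults to on.
os.environ["VLLM_ROCM_USE_AITER"] = "1"

use_aiter_linear = (
    os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
    and os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
    ("true", "1"))

print(use_aiter_linear)  # True: parent enabled, linear flag left at default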
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -181,6 +181,10 @@ def apply(self,
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""

def __init__(self):
super().__init__()
self._gemm_func = dispatch_unquantized_gemm()

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
Expand All @@ -199,7 +203,7 @@ def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return dispatch_unquantized_gemm()(x, layer.weight, bias)
return self._gemm_func(x, layer.weight, bias)


class LinearBase(torch.nn.Module):