Merged
16 changes: 15 additions & 1 deletion Dockerfile.rocm_base
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

@@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl

ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter

ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_REPO
@@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
29 changes: 28 additions & 1 deletion tests/model_executor/test_enabled_custom_ops.py
@@ -7,7 +7,10 @@
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.platforms import current_platform


# Registered subclass for test
@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
use_rocm_aiter_norm: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)

if not add_residual:
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_rms_norm
else:
assert rms_norm_func == rms_norm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
else:
assert rms_norm_func == fused_add_rms_norm
50 changes: 40 additions & 10 deletions tests/models/decoder_only/language/test_models.py
@@ -3,7 +3,11 @@

Run `pytest tests/models/test_models.py`.
"""

import pytest
import torch

from vllm.platforms import current_platform

from ...utils import check_logprobs_close

@@ -13,7 +17,21 @@
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

# This list contains the models that use AITER kernels.
# Models not in this list are skipped in the AITER tests.
# Once more AITER kernels are added, this list will no longer be
# needed, as all models will call AITER kernels
# in parts of their operators.
AITER_MODEL_LIST = [
"meta-llama/Llama-3.2-1B-Instruct",
"openbmb/MiniCPM3-4B",
"Qwen/Qwen-7B",
"Qwen/Qwen2.5-0.5B-Instruct",
"ehristoforu/Falcon3-MoE-2x7B-Insruct",
]


# @maybe_test_rocm_aiter
@pytest.mark.parametrize(
"model",
[
@@ -69,19 +87,24 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
monkeypatch,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:

if model in REQUIRES_V0:
monkeypatch.setenv("VLLM_USE_V1", "0")

if use_rocm_aiter and (model in AITER_MODEL_LIST):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
elif use_rocm_aiter and model not in AITER_MODEL_LIST:
# Skip models that are not in the AITER test list.
# Once more AITER kernels are added, this list will no longer be
# needed, as all models will call AITER kernels
# in parts of their operators.
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

with hf_runner(model, dtype=dtype) as hf_model:
if model.startswith("THUDM/chatglm3"):
hf_model.model.get_output_embeddings = lambda: \
@@ -100,3 +123,10 @@ def test_models(
name_0="hf",
name_1="vllm",
)
if use_rocm_aiter:
Contributor:
Is this something we should generally be doing for ROCm or just when AITER is enabled?

Contributor Author:
Currently, it seems to be just when AITER is enabled that this situation could occur.

Contributor Author:
@SageMoore
We have made the description clearer:

        # this is to ensure that vllm engine
        # has deallocated the memory before running the next
+        # unit tests. On ROCm, when using AITER
+        # the memory might not be deallocated completely
+        # before running the next test case
        torch.cuda.synchronize()

Contributor:
Good to know. Thanks!

# This is to ensure that the vLLM engine
# has deallocated its memory before the next
# unit test runs. On ROCm, when using AITER,
# the memory might not be deallocated completely
# before the next test case starts.
torch.cuda.synchronize()
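
The `torch.cuda.synchronize()` call above only waits for in-flight kernels. A fuller teardown between parameterized cases might also drop Python references and release cached allocator blocks; a minimal sketch, where `release_gpu_memory` is a hypothetical helper and not code from this PR:

    # Hypothetical helper (not in this PR): fuller GPU cleanup between tests.
    import gc

    import torch

    def release_gpu_memory() -> None:
        gc.collect()              # drop lingering references to engine objects
        torch.cuda.synchronize()  # wait for in-flight (e.g. AITER) kernels
        torch.cuda.empty_cache()  # return cached blocks to the device allocator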
13 changes: 13 additions & 0 deletions vllm/envs.py
@@ -74,6 +74,8 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
@@ -521,6 +523,17 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),

# Disable AITER ops unless specifically enabled.
# Acts as a parent switch for enabling the individual AITER operations.
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),

# Use the AITER RMSNorm op if AITER ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
("true", "1")),

# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
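Read together, the two lambdas implement the parent-switch semantics described in the comments: AITER RMSNorm is active only when both flags resolve truthy, and the parent defaults to off. A standalone restatement of that gating, for illustration only:

    # Illustrative restatement of the parent-switch gating above (not new API).
    import os

    def aiter_rmsnorm_requested() -> bool:
        parent = os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
        child = os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")
        return parent and child  # both must be truthy; the parent defaults to off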
94 changes: 77 additions & 17 deletions vllm/model_executor/layers/layernorm.py
@@ -5,7 +5,77 @@
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform


def is_rocm_aiter_rmsnorm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
and envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
weight,
variance_epsilon,
)
return out


def fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
ops.fused_add_rms_norm(
x,
residual,
weight,
variance_epsilon,
)
return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:

import aiter as rocm_aiter
return rocm_aiter.rms_norm(x, weight, variance_epsilon)


def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:

import aiter as rocm_aiter
Contributor:
Given that AITER isn't published on PyPI yet, meaning users will either have to use the Docker container or build from source, I'd like to have a nicer error message when users try to enable AITER without it being installed. There are a number of ways we can do this. I like the following but am open to other solutions.

def dispatch_cuda_rmsnorm_func(
    add_residual: bool
) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    if not add_residual:
        return rms_norm
    if current_platform.is_rocm_aiter_rmsnorm_enabled():
        try:
            import aiter as rocm_aiter
            return rocm_aiter_rmsnorm2d_fwd_with_add
        except ImportError:
            logger.warn_once("AITER RMS Norm kernel is enabled, but AITER is not installed. Falling back to the default RMS Norm kernel")
            return fused_add_rms_norm
    return fused_add_rms_norm

Contributor Author:
@SageMoore

  1. `import aiter` conflicts with the `aiter` built-in function in Python, and try/excepting `import aiter` does not show whether AITER is installed, unless we import a kernel function from aiter; if that kernel is not prebuilt, importing it would start a JIT kernel build.
  2. Having a fallback makes it difficult to actually debug and to track down performance differences. In addition, just a warning might be missed by users, who would then complain about performance, as users expect that when the AITER flag is set, AITER kernels are used.

So, we will avoid having a fallback here.
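
For readers who want the nicer error message without a silent fallback, a fail-fast check is one option. This is a hypothetical sketch, not what the PR merged, and `assert_aiter_available` is an invented name; using `find_spec` also sidesteps the JIT-build concern above, since nothing gets imported.

    # Hypothetical fail-fast variant (not merged in this PR): raise a clear
    # error instead of silently falling back when AITER is requested but absent.
    import importlib.util

    def assert_aiter_available() -> None:
        if importlib.util.find_spec("aiter") is None:
            raise ImportError(
                "VLLM_ROCM_USE_AITER is set but the 'aiter' package is not "
                "installed; build it from https://github.com/ROCm/aiter or "
                "use the ROCm Docker image.")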


# Assuming the correct signature for rmsnorm2d_fwd_with_add
rocm_aiter.rmsnorm2d_fwd_with_add(
x, # output
x, # input
residual, # residual input
residual, # residual output
weight,
variance_epsilon,
)
return x, residual


def dispatch_cuda_rmsnorm_func(add_residual: bool):
if add_residual:
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm

if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm
return rms_norm


@CustomOp.register("rms_norm")
@@ -81,24 +151,14 @@ def forward_cuda(
if self.variance_size_override is not None:
return self.forward_native(x, residual)

from vllm import _custom_ops as ops
add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)

if residual is not None:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
if add_residual:
return norm_func(x, residual, self.weight.data,
self.variance_epsilon)
else:
return norm_func(x, self.weight.data, self.variance_epsilon)

def forward_hpu(
self,
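To see both dispatch paths end to end, a short usage sketch (illustrative only; it assumes a ROCm or CUDA device and the `RMSNorm` layer shown in this diff):

    # Illustrative usage (not part of this PR): exercises both dispatch paths.
    import torch

    from vllm.model_executor.layers.layernorm import RMSNorm

    layer = RMSNorm(1024).to(device="cuda", dtype=torch.float16)
    x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)

    out = layer(x)                   # no residual: rms_norm path
    out2, res2 = layer(x, residual)  # with residual: fused_add_rms_norm path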