[FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature #14959
Changes from all commits: e88fd89, 4df7ed0, 5f653f0, 5b670cf, 6acb1bb, 44b8861, d899eea, d84e654, 785f753, 8854c03, 64cc656, 74c2d39
First file: the language-model test module (its docstring says to run it with `pytest tests/models/test_models.py`).
@@ -3,7 +3,11 @@

```python
Run `pytest tests/models/test_models.py`.
"""

import pytest
import torch

from vllm.platforms import current_platform

from ...utils import check_logprobs_close
```
@@ -13,7 +17,21 @@

```python
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

# This list contains the models that use AITER kernels.
# Skip models that are not covered by the AITER tests.
# When more AITER kernels are added, this list will no longer be
# needed, as all models will call AITER kernels in parts of
# their operators.
AITER_MODEL_LIST = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "openbmb/MiniCPM3-4B",
    "Qwen/Qwen-7B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "ehristoforu/Falcon3-MoE-2x7B-Insruct",
]


# @maybe_test_rocm_aiter
@pytest.mark.parametrize(
    "model",
    [
```
@@ -69,19 +87,24 @@

```diff
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    monkeypatch,
-) -> None:
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_models(hf_runner, vllm_runner, example_prompts, model: str,
+                dtype: str, max_tokens: int, num_logprobs: int,
+                use_rocm_aiter: bool, monkeypatch) -> None:
 
     if model in REQUIRES_V0:
         monkeypatch.setenv("VLLM_USE_V1", "0")
 
+    if use_rocm_aiter and (model in AITER_MODEL_LIST):
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+    elif use_rocm_aiter and model not in AITER_MODEL_LIST:
+        # Skip models that are not covered by the AITER tests.
+        # When more AITER kernels are added, this list will no longer be
+        # needed, as all models will call AITER kernels in parts of
+        # their operators.
+        pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
 
     with hf_runner(model, dtype=dtype) as hf_model:
         if model.startswith("THUDM/chatglm3"):
             hf_model.model.get_output_embeddings = lambda: \
```
@@ -100,3 +123,10 @@ def test_models(

```python
        name_0="hf",
        name_1="vllm",
    )
    if use_rocm_aiter:
        # this is to ensure that the vllm engine
        # has deallocated the memory before running the next
        # unit test. On ROCm, when using AITER,
        # the memory might not be deallocated completely
        # before running the next test case
        torch.cuda.synchronize()
```

Review discussion on the added `torch.cuda.synchronize()` call:

Reviewer: Is this something we should generally be doing for ROCm, or just when AITER is enabled?

Author: Currently, it seems to be just when AITER is enabled that this situation can occur.

Author (to @SageMoore): the added comment explains it: the synchronize call ensures that the vLLM engine has deallocated its memory before the next unit test runs; on ROCm with AITER, the memory might not be deallocated completely before the next test case.

Reviewer: Good to know. Thanks!
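Since the AITER path is toggled purely through environment variables (as the test does with `monkeypatch.setenv`), the same switch works outside the test suite. A minimal sketch, assuming a ROCm build of vLLM with the `aiter` package installed; the model name is taken from `AITER_MODEL_LIST` above and the prompt is illustrative:

```python
import os

# Enable the ROCm AITER kernels. Both variables are consulted by
# is_rocm_aiter_rmsnorm_enabled() in the layernorm changes below;
# they are set here before vLLM is imported and reads its env vars.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", dtype="half")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```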
Second file: the RMSNorm custom-op implementation (the module that registers the `rms_norm` CustomOp).
@@ -5,7 +5,77 @@

```python
import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
        and envs.VLLM_ROCM_USE_AITER


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    from vllm import _custom_ops as ops
    out = torch.empty_like(x)
    ops.rms_norm(
        out,
        x,
        weight,
        variance_epsilon,
    )
    return out


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    from vllm import _custom_ops as ops
    ops.fused_add_rms_norm(
        x,
        residual,
        weight,
        variance_epsilon,
    )
    return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:

    import aiter as rocm_aiter
    return rocm_aiter.rms_norm(x, weight, variance_epsilon)


def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:

    import aiter as rocm_aiter
```

Review discussion on the `import aiter as rocm_aiter` line:

Reviewer: Given that AITER isn't published on PyPI yet, users will either have to use the Docker container or build from source, so I'd like to have a nicer error message when users try to enable AITER without it being installed. There are a number of ways we can do this; I like the following, but am open to other solutions.

Author: So, we will avoid having a fallback here.
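The reviewer's suggested snippet is not captured in this extraction. As one hypothetical way to satisfy the request (the helper name and message wording are illustrative, not from the PR), the `aiter` import could be guarded so a missing package produces an actionable error rather than a bare `ImportError`:

```python
def _import_aiter():
    """Hypothetical helper: import AITER or fail with an actionable message."""
    try:
        import aiter
        return aiter
    except ImportError as err:
        raise ImportError(
            "VLLM_ROCM_USE_AITER is set, but the 'aiter' package is not "
            "installed. AITER is not published on PyPI yet; use the ROCm "
            "Docker image or build AITER from source.") from err
```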
The diff then continues inside `rocm_aiter_fused_add_rms_norm`:

```python
    # Assuming the correct signature for rmsnorm2d_fwd_with_add
    rocm_aiter.rmsnorm2d_fwd_with_add(
        x,  # output
        x,  # input
        residual,  # residual input
        residual,  # residual output
        weight,
        variance_epsilon,
    )
    return x, residual


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    if add_residual:
        if is_rocm_aiter_rmsnorm_enabled():
            return rocm_aiter_fused_add_rms_norm
        return fused_add_rms_norm

    if is_rocm_aiter_rmsnorm_enabled():
        return rocm_aiter_rms_norm
    return rms_norm


@CustomOp.register("rms_norm")
```
@@ -81,24 +151,14 @@ def forward_cuda(

```diff
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
 
-        from vllm import _custom_ops as ops
+        add_residual = residual is not None
+        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
 
-        if residual is not None:
-            ops.fused_add_rms_norm(
-                x,
-                residual,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
-            x,
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
+        if add_residual:
+            return norm_func(x, residual, self.weight.data,
+                             self.variance_epsilon)
+        else:
+            return norm_func(x, self.weight.data, self.variance_epsilon)
 
     def forward_hpu(
         self,
```
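As a closing reference, all of the dispatched kernels compute the standard RMSNorm. A plain PyTorch sketch of the unfused operation (this mirrors the usual definition of the op, not code from this PR):

```python
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), scaled elementwise by the learned weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)


def fused_add_rms_norm_reference(x, residual, weight, eps):
    # The fused variant adds the residual first, then normalizes,
    # and returns both the normalized output and the updated residual.
    residual = x + residual
    return rms_norm_reference(residual, weight, eps), residual
```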