Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/models/language/generation/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"Qwen/Qwen-7B-Chat",
"Qwen/Qwen2.5-0.5B-Instruct",
"TitanML/tiny-mixtral",
"Qwen/Qwen3-8B",
]


Expand Down Expand Up @@ -78,6 +79,9 @@
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
marks=[pytest.mark.core_model],
),
pytest.param(
"Qwen/Qwen3-8B", # qwen (text-only)
),
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param(
Expand Down
15 changes: 11 additions & 4 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:

import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)

return rocm_aiter.rms_norm(x, weight, variance_epsilon)


Expand All @@ -55,16 +61,17 @@ def rocm_aiter_fused_add_rms_norm(

import aiter as rocm_aiter

# Assuming the correct signature for rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
Comment on lines +64 to +65
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why wasn't this copy necessary before? Would making the copy conditional improve performance in cases where a copy isn't required?

Copy link
Contributor Author

@tjtanaa tjtanaa May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337 It is a bugfix of using the AITER commit: 5a77249 . The usage of the ops has changed at some point. However, it hasn't been throwing error in any of the unit tests. To ensure the correctness, we have followed the usage of the kernel based on this commit 5a77249 specifically.

rocm_aiter.rmsnorm2d_fwd_with_add(
x, # output
output, # output
x, # input
residual, # residual input
residual, # residual output
residual_out, # residual output
weight,
variance_epsilon,
)
return x, residual
return output, residual_out


def dispatch_cuda_rmsnorm_func(add_residual: bool):
Expand Down