Commit 3943257

Restore original torch.Parameter behavior in RMSNorm
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
1 parent a3ebf0a commit 3943257

File tree

1 file changed: +3 −1 lines

vllm/model_executor/layers/layernorm.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -170,7 +170,9 @@ def __init__(
         )
         weight_dtype = dtype or torch.get_default_dtype()
         self.has_weight = has_weight
-        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
+        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
+        if self.has_weight:
+            self.weight = nn.Parameter(self.weight)

         if current_platform.is_rocm():
             self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
```
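The effect of the change is that `self.weight` is registered as a trainable `nn.Parameter` only when `has_weight` is true; otherwise it stays a plain tensor of ones and does not show up in the module's parameter list. A minimal sketch of this pattern (a simplified stand-in, not vLLM's actual `RMSNorm` class; the forward pass here is the textbook RMSNorm formula and is assumed for illustration):

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Simplified RMSNorm illustrating the conditional nn.Parameter wrapping
    from the commit: weight is a plain all-ones tensor unless has_weight=True,
    in which case it is wrapped in nn.Parameter and becomes trainable."""

    def __init__(self, hidden_size, eps=1e-6, has_weight=True, dtype=None):
        super().__init__()
        self.eps = eps
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        # Start from a plain tensor of ones...
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        # ...and only register it as a parameter when a learnable scale is wanted.
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward(self, x):
        # Standard RMSNorm: normalize by the root-mean-square, then scale.
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return x * self.weight
```

With `has_weight=False` the module has no trainable parameters, so optimizers and state dicts skip the weight entirely, while the multiplication by an all-ones tensor leaves the normalized output unchanged.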
