From 88abd9606d10e70cc67bbaffebf84a615b639585 Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Wed, 28 Aug 2024 07:01:58 +0000 Subject: [PATCH] fix: fix rms_norm params --- lmdeploy/pytorch/kernels/ascend/rms_norm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/kernels/ascend/rms_norm.py b/lmdeploy/pytorch/kernels/ascend/rms_norm.py index 5949bfb116..d3c6630fc8 100644 --- a/lmdeploy/pytorch/kernels/ascend/rms_norm.py +++ b/lmdeploy/pytorch/kernels/ascend/rms_norm.py @@ -3,5 +3,13 @@ from torch import Tensor -def rms_norm(hidden_states: Tensor, weight: Tensor, epsilon: float = 1e-6): - return ext_ops.rms_norm(hidden_states, weight, epsilon) +def rms_norm(hidden_states: Tensor, + weight: Tensor, + eps: float = 1e-6, + out: Tensor = None): + rms_norm_out = ext_ops.rms_norm(hidden_states, weight, eps) + if out is None: + out = rms_norm_out + else: + out.copy_(rms_norm_out) + return rms_norm_out