diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 5834cf14d200..a38b31c9bb00 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1813,7 +1813,7 @@ def rms_norm( .. math:: - out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias + out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight Parameters ---------- @@ -1823,9 +1823,6 @@ def rms_norm( weight : relax.Expr The scale factor. - bias : relax.Expr - The offset factor. - axes : Union[int, List[int]] The axes that along which the normalization is applied. diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 344c9bc7a347..3597b16a5bcc 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -848,13 +848,12 @@ InferLayoutOutput InferLayoutRMSNorm(const Call& call, LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = make_object(*attrs); - std::vector new_axis; + std::vector new_axes; for (const auto& axis : attrs->axes) { - new_axis.push_back(FindAxis(layout->layout, axis->value)); + new_axes.push_back(FindAxis(layout->layout, axis->value)); } - new_attrs->axes = std::move(new_axis); - return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout}, - Attrs(new_attrs)); + new_attrs->axes = std::move(new_axes); + return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs)); } TVM_REGISTER_OP("relax.nn.rms_norm")