diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index ee0cba74dd3b..f1b0e4ac9eaa 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -65,7 +65,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = - MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); + MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false); auto func = MakeTupleSumReducer(); auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 52986feef377..575f4a0fb0d9 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -2599,6 +2599,60 @@ def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.in tvm.ir.assert_structural_equal(mod, Expected) +def test_layer_norm_1d(): + # fmt: off + @I.ir_module + class LayerNorm_1D: + @R.function + def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + layer_norm: R.Tensor((3,), dtype="float32") = R.nn.layer_norm(x, layer_norm_weight, layer_norm_bias, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) + gv: R.Tensor((3,), dtype="float32") = layer_norm + R.output(gv) + return gv + + @I.ir_module + class LayerNorm_1D_Expected: + @T.prim_func(private=True) + def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + x_red_temp_v0 = T.alloc_buffer(()) + x_red_temp_v1 = T.alloc_buffer(()) + for k0 in range(T.int64(3)): + with T.block("x_red_temp"): + v_k0 = T.axis.reduce(T.int64(3), k0) + T.reads(x[v_k0]) + T.writes(x_red_temp_v0[()], x_red_temp_v1[()]) + with T.init(): + x_red_temp_v0[()] = T.float32(0.0) + x_red_temp_v1[()] = T.float32(0.0) + v_x_red_temp_v0: T.float32 = x_red_temp_v0[()] + x[v_k0] + v_x_red_temp_v1: T.float32 = x_red_temp_v1[()] + x[v_k0] * x[v_k0] + x_red_temp_v0[()] = v_x_red_temp_v0 + x_red_temp_v1[()] = v_x_red_temp_v1 + for ax0 in range(T.int64(3)): + with T.block("T_layer_norm"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0]) + T.writes(T_layer_norm[v_ax0]) + T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * T.float32(0.33333333333333331) - x_red_temp_v0[()] * T.float32(0.33333333333333331) * (x_red_temp_v0[()] * T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0] + + @R.function + def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = LayerNorm_1D_Expected + with R.dataflow(): + layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_sinfo=R.Tensor((3,), dtype="float32")) + gv: R.Tensor((3,), dtype="float32") = layer_norm + R.output(gv) + return gv + # fmt: on + mod = LegalizeOps()(LayerNorm_1D) + tvm.ir.assert_structural_equal(mod, LayerNorm_1D_Expected) + + def test_layer_norm_fp16(): # fmt: off @tvm.script.ir_module