From 59073602575669aff3dc1b3a687421fd78e8100e Mon Sep 17 00:00:00 2001
From: yrx <421626388@qq.com>
Date: Sun, 25 May 2025 19:33:45 +0800
Subject: [PATCH 1/3] [TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes

---
 include/tvm/topi/nn/layer_norm.h              |  2 +-
 .../relax/test_transform_legalize_ops_nn.py   | 54 +++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h
index ee0cba74dd3b..f1b0e4ac9eaa 100644
--- a/include/tvm/topi/nn/layer_norm.h
+++ b/include/tvm/topi/nn/layer_norm.h
@@ -65,7 +65,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
   auto reduce_axes = MakeReduceAxes(real_axis, data);
   auto target_shape =
-      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
+      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false);
   auto func = MakeTupleSumReducer();
 
   auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 52986feef377..4ba6a680fb3f 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2598,6 +2598,60 @@ def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.in
     mod = LegalizeOps()(LayerNorm)
     tvm.ir.assert_structural_equal(mod, Expected)
 
+def test_layer_norm_1d():
+    # fmt: off
+    @I.ir_module
+    class LayerNorm_1D:
+        @R.function
+        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                layer_norm: R.Tensor((3,), dtype="float32") = R.nn.layer_norm(x, layer_norm_weight, layer_norm_bias, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
+                gv: R.Tensor((3,), dtype="float32") = layer_norm
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class LayerNorm_1D_Expected:
+        @T.prim_func(private=True)
+        def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            x_red_temp_v0 = T.alloc_buffer(())
+            x_red_temp_v1 = T.alloc_buffer(())
+            for k0 in range(T.int64(3)):
+                with T.block("x_red_temp"):
+                    v_k0 = T.axis.reduce(T.int64(3), k0)
+                    T.reads(x[v_k0])
+                    T.writes(x_red_temp_v0[()], x_red_temp_v1[()])
+                    with T.init():
+                        x_red_temp_v0[()] = T.float32(0.0)
+                        x_red_temp_v1[()] = T.float32(0.0)
+                    v_x_red_temp_v0: T.float32 = x_red_temp_v0[()] + x[v_k0]
+                    v_x_red_temp_v1: T.float32 = x_red_temp_v1[()] + x[v_k0] * x[v_k0]
+                    x_red_temp_v0[()] = v_x_red_temp_v0
+                    x_red_temp_v1[()] = v_x_red_temp_v1
+            for ax0 in range(T.int64(3)):
+                with T.block("T_layer_norm"):
+                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                    T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
+                    T.writes(T_layer_norm[v_ax0])
+                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * T.float32(0.33333333333333331) - x_red_temp_v0[()] * T.float32(0.33333333333333331) * (x_red_temp_v0[()] * T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]
+
+        @R.function
+        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = LayerNorm_1D_Expected
+            with R.dataflow():
+                layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_sinfo=R.Tensor((3,), dtype="float32"))
+                gv: R.Tensor((3,), dtype="float32") = layer_norm
+                R.output(gv)
+            return gv
+
+    # fmt: on
+    mod = LegalizeOps()(LayerNorm_1D)
+    tvm.ir.assert_structural_equal(mod, LayerNorm_1D_Expected)
+
 def test_layer_norm_fp16():
     # fmt: off

From d7cbe880bf776b168236250d8fcfce787d982c3b Mon Sep 17 00:00:00 2001
From: yrx <421626388@qq.com>
Date: Sun, 25 May 2025 20:19:58 +0800
Subject: [PATCH 2/3] change code style

---
 tests/python/relax/test_transform_legalize_ops_nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 4ba6a680fb3f..f3065e07a6d1 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2598,6 +2598,7 @@ def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.in
     mod = LegalizeOps()(LayerNorm)
     tvm.ir.assert_structural_equal(mod, Expected)
 
+
 def test_layer_norm_1d():
     # fmt: off
     @I.ir_module
     class LayerNorm_1D:
         @R.function
@@ -2647,7 +2648,6 @@ def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,)
                 gv: R.Tensor((3,), dtype="float32") = layer_norm
                 R.output(gv)
             return gv
-
     # fmt: on
     mod = LegalizeOps()(LayerNorm_1D)
     tvm.ir.assert_structural_equal(mod, LayerNorm_1D_Expected)

From e9634ad69289551875437b7c998d12eefb2369d0 Mon Sep 17 00:00:00 2001
From: yrx <421626388@qq.com>
Date: Sun, 25 May 2025 20:29:58 +0800
Subject: [PATCH 3/3] rechange code style

---
 tests/python/relax/test_transform_legalize_ops_nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index f3065e07a6d1..575f4a0fb0d9 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2611,7 +2611,7 @@ def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,)
                 gv: R.Tensor((3,), dtype="float32") = layer_norm
                 R.output(gv)
             return gv
-
+
     @I.ir_module
     class LayerNorm_1D_Expected:
         @T.prim_func(private=True)
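
Note (not part of the patch series): the one-line TOPI change flips atleast1d from true to false in MakeReduceTargetShape, so when the layer_norm axes cover every dimension of the input the reduced sum and sum-of-squares become genuine 0-D tensors (the T.alloc_buffer(()) buffers in the expected TIR above) rather than being padded out to shape (1,), which is the reduce-only-axes case the patch reports as failing before the fix. Below is a minimal reproducer sketch in Python, assuming a TVM build that includes Relax and LegalizeOps; the module, function, and variable names (Mod, forward, x/w/b) are illustrative and not taken from the patch.

    # Minimal sketch: legalize a 1-D layer_norm whose axes cover the whole
    # input, the case exercised by test_layer_norm_1d above.
    from tvm.relax.transform import LegalizeOps
    from tvm.script import ir as I
    from tvm.script import relax as R


    @I.ir_module
    class Mod:
        @R.function
        def forward(
            x: R.Tensor((3,), "float32"),
            w: R.Tensor((3,), "float32"),
            b: R.Tensor((3,), "float32"),
        ) -> R.Tensor((3,), "float32"):
            with R.dataflow():
                # Normalizing over axis -1 of a 1-D input reduces every axis,
                # which is the shape case the atleast1d change addresses.
                y = R.nn.layer_norm(x, w, b, axes=[-1])
                R.output(y)
            return y


    # With the fix, legalization lowers this to the scalar-mean TIR shown in
    # the new test instead of erroring out.
    mod = LegalizeOps()(Mod)
    mod.show()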