[Bug] MetaSchedule Causes Precision Loss in LayerNorm Operator #17977

@Cookiee235

Description

Actual behavior

Traceback (most recent call last):
  File "/data/qshenaf/remote_pc/TirFuzz/bugs/topi.nn.layer_norm_3.py", line 33, in <module>
    np.testing.assert_allclose(
  File "/data/qshenaf/miniconda3/envs/tvm/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1715, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/data/qshenaf/miniconda3/envs/tvm/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 921, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-05
An Inconsistency bug detected.
Mismatched elements: 33 / 16384 (0.201%)
Max absolute difference among violations: 0.000977
Max relative difference among violations: 0.001172
 ACTUAL: array([[[[-3.0884e-01,  1.3924e-02,  3.9844e-01, ..., -1.5515e-01,
          -1.9397e-01,  3.7646e-01],
         [ 2.8882e-01,  7.9163e-02, -5.2979e-01, ...,  7.5732e-01,...
 DESIRED: array([[[[-3.0884e-01,  1.3924e-02,  3.9844e-01, ..., -1.5515e-01,
          -1.9397e-01,  3.7646e-01],
         [ 2.8882e-01,  7.9163e-02, -5.2979e-01, ...,  7.5732e-01,...
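For context on the magnitudes above: the reported max absolute difference of 0.000977 equals float16 machine epsilon (2**-10), which is the gap between adjacent float16 values in [1, 2). The sketch below uses only NumPy, not TVM, and is not a confirmed diagnosis. It prints the float16 spacing at a few magnitudes next to the bound that the rtol/atol from this report allows, which may help judge whether the mismatch is on the order of float16 rounding.

```python
import numpy as np

# Tolerance used by the assert_allclose call in this report.
rtol = atol = 1e-5

eps16 = float(np.finfo(np.float16).eps)  # 2**-10
print(f"float16 eps: {eps16:.6f}")

# Gap between adjacent float16 values at a few magnitudes, next to the
# largest difference assert_allclose accepts there (atol + rtol * |x|).
for x in (0.25, 0.5, 1.0):
    gap = float(np.spacing(np.float16(x)))
    bound = atol + rtol * x
    print(f"|x|={x}: float16 gap={gap:.6f}, allowed diff={bound:.6f}")
```

At these magnitudes a single float16 rounding step is already larger than the allowed difference, so a change in accumulation order alone could plausibly trip the assertion.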

Environment

tvm-0.21.dev0

Steps to reproduce

import tvm
from tvm import te, topi, tir
from tvm import meta_schedule as ms
import numpy as np


def compile_mod(mod, np_input_list, output_shape, output_type, opt_level=3):
    """Build `mod` for LLVM, run it on the NumPy inputs, and return the output NDArray."""
    with tvm.transform.PassContext(opt_level=opt_level):
        built_mod = tvm.build(mod, target='llvm')
    # The built function takes the output buffer as its last argument.
    mod_output = tvm.nd.empty(output_shape, dtype=output_type, device=tvm.cpu(0))

    tvm_inputs = [tvm.nd.array(x) for x in np_input_list]
    built_mod(*tvm_inputs, mod_output)
    return mod_output


data = te.placeholder((4, 8, 16, 32), dtype='float16', name='data')
gamma = te.placeholder((8, 16, 32), dtype='float16', name='gamma')
op_output = topi.nn.layer_norm(data, gamma, beta=None, axis=[1,2,3])
np_inputs = [
    np.random.uniform(-1, 1, size=(4, 8, 16, 32)).astype('float16'),
    np.random.uniform(-1, 1, size=(8, 16, 32)).astype('float16'),
]

prim_func = te.create_prim_func([data, gamma, op_output])
sch = tir.Schedule(prim_func.with_attr('target', tvm.target.Target('llvm')))
ref_output = compile_mod(sch.mod, np_inputs, op_output.shape, op_output.dtype, opt_level=0)


database = ms.tir_integration.tune_tir(
    mod=sch.mod,
    target='llvm --num-cores=16',
    work_dir='./tune_tmp',
    max_trials_global=1,
    num_trials_per_iter=1,
)
sch = ms.tir_integration.compile_tir(database, sch.mod, 'llvm --num-cores=16')
opt_mod_output = compile_mod(sch.mod, np_inputs, op_output.shape, op_output.dtype, opt_level=0)

np.testing.assert_allclose(
    ref_output.numpy(), opt_mod_output.numpy(), rtol=1e-5, atol=1e-5, err_msg="An Inconsistency bug detected."
)
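To separate float16 rounding noise from a larger numerical error, the final comparison could use a tolerance scaled to the dtype. The following is a minimal sketch, not part of the original report: `assert_close_for_dtype` is a hypothetical helper, and the default of 4 epsilons is an arbitrary assumption rather than a TVM convention.

```python
import numpy as np


def assert_close_for_dtype(actual, desired, ulps=4):
    """Hypothetical helper: compare with a tolerance scaled to the dtype's epsilon."""
    eps = float(np.finfo(desired.dtype).eps)
    tol = ulps * eps  # assumption: a few epsilons of slack is acceptable
    # Compare in float64 so the comparison itself adds no float16 rounding.
    np.testing.assert_allclose(
        actual.astype(np.float64), desired.astype(np.float64), rtol=tol, atol=tol
    )


# Usage with synthetic data: a one-epsilon offset passes, a large offset fails.
desired = np.array([1.0, 0.5], dtype=np.float16)
assert_close_for_dtype(desired + np.float16(2 ** -10), desired)
```

If the tuned output still fails a check of this kind, the difference exceeds what a few float16 rounding steps would explain.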

Triage

  • needs-triage
  • tune:meta_schedule
