
[Bug] RemoveUnusedOutputs gives unexpected results #17247

Closed
Cookiee235 opened this issue Aug 6, 2024 · 5 comments · Fixed by #17249 or #17253
Labels
needs-triage, type: bug

Comments

@Cookiee235
Contributor

Cookiee235 commented Aug 6, 2024

Hi all, the pass RemoveUnusedOutputs seems to give an unexpected optimization result. Due to the lack of detailed documentation for this API (i.e., relax.transform.RemoveUnusedOutputs), I cannot confirm whether the optimization result is wrong.

In addition, there is another bug in the API tvm.ir.assert_structural_equal: when given the exact same mod for both arguments, it judges the structures to be unequal. It is triggered by IRs containing the string "nan".

Actual behavior

## Output IR after RemoveUnusedOutputs
@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
            R.output(res)
        return res
----------------------------------------------------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/assert_structure.py", line 66, in <module>
    tvm.ir.assert_structural_equal(mod, mod)
  File "/software/tvm-lunder/python/tvm/ir/base.py", line 256, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: _ZN3tvm7runtime13PackedFuncObj
  4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  3: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  2: tvm::SEqualHandlerDefault::Impl::RunTasks()
  1: tvm::SEqualHandlerDefault::DispatchSEqualReduce(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  File "/software/tvm-lunder/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("main")].body.blocks[0].bindings[0].value.fields[0].value.value:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
                                                                                                                                                  ^^^^^
            R.output(res)
        return res
and rhs at <root>.functions[I.GlobalVar("main")].body.blocks[0].bindings[0].value.fields[0].value.value:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
                                                                                                                                                  ^^^^^
            R.output(res)
        return res

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def ones(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 1

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(32), T.int64(32)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    @R.function(private=True)
    def func() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        cls = Module
        A = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        B = R.call_tir(cls.ones, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        C = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((32, 32), dtype="int32"))
        return (A, B, C)

    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            res: R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")) = cls.func()
            R.output(res)
        return res


mod = Module
mod.show()

mod = relax.transform.RemoveUnusedOutputs()(mod)
mod.show()  # is this IR correct?
tvm.ir.assert_structural_equal(mod, mod)  # not equal! why?

cc @Lunderberg @junrushao

@Cookiee235 added the needs-triage and type: bug labels on Aug 6, 2024
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Aug 6, 2024
Prior to this commit, `NaN` values did not have any special handling
in either `StructuralEqual` or `StructuralHash`.

`StructuralEqual` checked whether the LHS and RHS were within some
tolerance of each other.  If the LHS and RHS are both `NaN`, this
would evaluate to false.  The updated `StructuralEqual` now checks for
this case, and returns true if both sides are `NaN`.

`StructuralHash` used the bit-pattern of a floating-point number to
compute the hash.  A `NaN` value may have any non-zero value in its
mantissa, and so this could produce distinct hashes for ASTs that
differ only by the choice of non-zero value.  The updated
`StructuralHash` uses the same
`std::numeric_limits<double>::quiet_NaN()` value for all `NaN` values.

With these changes, `StructuralEqual` and `StructuralHash` can now
compare two IR functions, even if they contain `NaN`.

Closes apache#17247
@Lunderberg
Contributor

Looks like the test case can be made even simpler, and isn't limited to RemoveUnusedOutputs. The root cause is that StructuralEqual compared TIR floats by checking abs(lhs - rhs) < 1e-9, which always evaluates to false for NaN values.

@T.prim_func(private=True)
def func_1():
    return T.float32("nan")

@T.prim_func(private=True)
def func_2():
    return T.float32("nan")

tvm.ir.assert_structural_equal(func_1, func_2)
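
To illustrate why that tolerance check can never succeed for NaN, here is a plain-Python analogue (not the actual TVM C++ code):

# Plain-Python analogue of the tolerance-based comparison: every
# comparison involving NaN evaluates to False, so the check never passes.
lhs = float("nan")
rhs = float("nan")
print(abs(lhs - rhs) < 1e-9)  # False: abs(nan - nan) is nan
print(lhs == rhs)             # also False: NaN != NaN under IEEE 754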

I've implemented #17249, which should fix this issue by adding special handling for NaN values to StructuralEqual and StructuralHash.
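
For illustration, a minimal Python sketch of that special handling follows; the real fix lives in the C++ StructuralEqual/StructuralHash code, and the helper names below are hypothetical:

import math
import struct

def float_values_equal(lhs: float, rhs: float, tol: float = 1e-9) -> bool:
    """NaN-aware equality: two NaNs compare equal; otherwise fall back
    to the tolerance-based comparison."""
    if math.isnan(lhs) and math.isnan(rhs):
        return True
    return abs(lhs - rhs) < tol

def float_value_hash(value: float) -> int:
    """Hash the bit pattern, but map every NaN to one canonical NaN so
    that NaNs with different mantissa payloads hash identically."""
    if math.isnan(value):
        value = float("nan")  # canonical quiet NaN
    return hash(struct.pack("<d", value))

# A NaN with a non-standard payload in its mantissa.
odd_nan = struct.unpack("<d", struct.pack("<Q", 0x7FF8000000000001))[0]

assert float_values_equal(float("nan"), odd_nan)
assert float_value_hash(float("nan")) == float_value_hash(odd_nan)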

@Cookiee235
Contributor Author

Cookiee235 commented Aug 7, 2024

@Lunderberg Thanks for fixing the incorrect implementation of assert_structural_equal.
I have another question. For the IR given in my script, an odd IR is produced after applying the RemoveUnusedOutputs optimization. It seems the function func should not be removed. Can you help me check whether this is a bug?


Output IR after RemoveUnusedOutputs

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
            R.output(res)
        return res

@Lunderberg
Contributor

Ooh, I had missed that part, and thought there was a nan inside the original model. Thank you for calling my attention to it.

Lunderberg added a commit to Lunderberg/tvm that referenced this issue Aug 7, 2024
Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass
only marked a tuple element as used if it occurred in a `TupleGetItem`
node.  This ignored use cases where a tuple is used as an aggregate
object, such as returning a tuple from a function.  This would collect
incorrect results for a Relax function that calls a subroutine,
receives a tuple as the return value of the subroutine, then returns
that tuple.

This commit updates `RemoveUnusedOutputs` to look for usage of a tuple
object, not just for usage in `TupleGetItem`.

Closes apache#17247
@Lunderberg
Contributor

The introduction of nan is a bug in RemoveUnusedOutputs. When determining which outputs of a callee are used, it only collected usages in TupleGetItem(out_tuple, index). If the tuple is used in a context that doesn't access a specific element, such as returning from a function, then the usage is skipped.

The NaN values are intended as placeholders, i.e. dummy values that keep the tuple indices aligned. If a callee produces (A,B,C), but B is never used, then the callee would be updated to produce (A,C), and callsites would be updated to replace res = callee() with new_output = callee(); res = (new_output[0], NaN, new_output[1]). The intermediate tuple would then be deconstructed with CanonicalizeBindings. Since nothing ever accessed res[1], the NaN value at that location would drop out altogether.
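
As a plain-Python sketch of that placeholder-and-remap idea (not the actual pass implementation; the helper name is hypothetical):

def rebuild_output_tuple(new_output, used):
    """Reassemble the original-arity tuple from the trimmed callee
    output, filling unused slots with a NaN placeholder."""
    placeholder = float("nan")
    rebuilt = []
    next_index = 0
    for is_used in used:
        if is_used:
            rebuilt.append(new_output[next_index])
            next_index += 1
        else:
            rebuilt.append(placeholder)
    return tuple(rebuilt)

# Callee originally produced (A, B, C); B is unused, so it now returns (A, C).
new_output = ("A", "C")
res = rebuild_output_tuple(new_output, used=[True, False, True])
# res == ("A", nan, "C"); if res[1] is never read, CanonicalizeBindings
# lets the NaN placeholder drop out entirely.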

So, if a function called a subroutine that produces a tuple, then immediately returned that tuple, the usage would fail to be collected, and the tuple elements would be erroneously replaced with NaN. This should be resolved with #17253.

@Lunderberg
Contributor

Re-opening, as the auto-close from #17249 wasn't correct. This issue still requires #17253 to land in order to be resolved.
