
[Bug] [Relax] Check failed: (it != this->var_arg_map_.end()) is false: Var is not defined #17231

Closed
Cookiee235 opened this issue Aug 1, 2024 · 3 comments · May be fixed by #17232
Labels
needs-triage · type: bug

Comments

@Cookiee235 (Contributor)

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/simple/res_undefined.py", line 49, in <module>
    compiled_after = compile_mod(relax.transform.LiftTransformParams()(mod))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share_container/optfuzz/res/bugs/simple/res_undefined.py", line 41, in compile_mod
    ex = relax.build(mod, target="llvm")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 340, in build
    mod = _vmcodegen(builder, mod, exec_mode)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 176, in _vmcodegen
    return _ffi_api.VMCodeGen(builder, mod)  # type:ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::relax::ExecBuilder, tvm::IRModule)>::AssignTypedLambda<tvm::IRModule (*)(tvm::relax::ExecBuilder, tvm::IRModule)>(tvm::IRModule (*)(tvm::relax::ExecBuilder, tvm::IRModule), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: tvm::relax::relax_vm::VMCodeGen(tvm::relax::ExecBuilder, tvm::IRModule)
  5: tvm::relax::relax_vm::CodeGenVM::Run(tvm::relax::ExecBuilder, tvm::IRModule)
  4: tvm::relax::relax_vm::CodeGenVM::Codegen(tvm::relax::Function const&)
  3: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  2: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::SeqExprNode const*)
  1: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  0: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::VarNode const*)
  File "/software/tvm-lunder/src/relax/backend/vm/codegen_vm.cc", line 232
InternalError: Check failed: (it != this->var_arg_map_.end()) is false: Var w1_t is not defined

Steps to reproduce

import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def transpose(w1: T.Buffer((T.int64(256), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(256), T.int64(256)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(w1[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = w1[v_ax1, v_ax0]

    @R.function(private=False)
    def main(x: R.Tensor((256, 256), dtype="float32"), w1: R.Tensor((256, 256), dtype="float32")) -> R.Tensor((256, 256), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            w1_t = R.call_tir(cls.transpose, (w1,), out_sinfo=R.Tensor((256, 256), dtype="float32"))
            R.output(w1_t)
        return w1_t

mod = Module
mod.show()
mod = tvm.relax.transform.LegalizeOps()(mod)


input_0 = tvm.nd.array(10 * np.random.random([256, 256]).astype('float32'))
input_1 = tvm.nd.array(10 * np.random.random([256, 256]).astype('float32'))

def compile_mod(mod):
    mod = relax.transform.FuseTIR()(mod)
    mod = relax.transform.LambdaLift()(mod)
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    return vm


compiled_before = compile_mod(mod)
before_outputs = compiled_before["main"](input_0, input_1)

compiled_after = compile_mod(relax.transform.LiftTransformParams()(mod))
transformed_weights = compiled_after["main_transform_params"]([input_1])
after_outputs = compiled_after["main"](input_0, *transformed_weights)
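To make the failure easier to see, here is a plain-Python sketch (illustrative only, not TVM IR; the function names mirror the Relax functions for clarity) of the split that `LiftTransformParams` is expected to produce with `num_input=1`: the transpose of `w1` moves into a compile-time `main_transform_params`, and the runtime `main` receives the pre-transformed weight. The crash arises because the lifted value `w1_t` is also the function's output, a case the pass did not account for.

```python
# Illustrative Python only -- not TVM IR. Mirrors the intended split of
# LiftTransformParams for the module above, where num_input=1 marks x as
# the sole runtime input and w1 as a liftable weight.

def transpose(m):
    # Stand-in for the TIR transpose prim_func.
    return [list(row) for row in zip(*m)]

def main_transform_params(params):
    # Compile-time function: transforms the weights once, ahead of time.
    (w1,) = params
    return (transpose(w1),)

def main(x, w1_t):
    # Runtime function: in this repro it simply returns the lifted value,
    # which is exactly the case the pass mishandled.
    return w1_t

w1 = [[1, 2], [3, 4]]
x = [[0, 0], [0, 0]]
transformed = main_transform_params((w1,))
result = main(x, *transformed)   # [[1, 3], [2, 4]]
```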

cc @Lunderberg @junrushao

@Cookiee235 added the needs-triage and type: bug labels on Aug 1, 2024
Lunderberg added a commit to Lunderberg/tvm that referenced this issue Aug 1, 2024
Prior to this commit, the `relax.transform.LiftTransformParams` pass
inspected the expression in each `relax::Binding` for variables that
were required at runtime, but did not inspect the function's output.
As a result, any value that could be computed at compile-time, and was
either the function output or used in the function's output tuple,
would be undefined in the inference function.

This commit updates `LiftTransformParams` to collect variables from
both the bound value of `relax::Binding`, and the function's output.

While this error only impacted the `shared_transform=False` branch of
`LiftTransformParams`, this commit also adds regression tests for the
`shared_transform=True` use case of `LiftTransformParams`.

Closes apache#17231
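The root cause described in the commit message can be reduced to a small dependency-collection sketch (plain Python with hypothetical names such as `Binding` and `collect_runtime_vars`; this is not the actual pass code): if required variables are gathered only from binding expressions, then a variable that is produced by a binding but consumed only by the function's output is never marked as required, and shows up undefined at codegen.

```python
# Hypothetical sketch of the dependency collection in LiftTransformParams.
# `Binding` and `collect_runtime_vars` are illustrative names, not TVM APIs.

from dataclasses import dataclass

@dataclass
class Binding:
    var: str          # variable defined by this binding
    uses: list        # variables referenced by the bound expression

def collect_runtime_vars(bindings, output_vars, include_output):
    """Gather every variable the generated function must have defined."""
    required = set()
    for b in bindings:
        required.update(b.uses)       # original behavior: scan bound values only
    if include_output:
        required.update(output_vars)  # the fix: also scan the function output
    return required

# Models main(x, w1): w1_t = transpose(w1); return w1_t
bindings = [Binding("w1_t", ["w1"])]
output_vars = ["w1_t"]

buggy = collect_runtime_vars(bindings, output_vars, include_output=False)
fixed = collect_runtime_vars(bindings, output_vars, include_output=True)
# "w1_t" is missing from `buggy` but present in `fixed`
```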
@Lunderberg (Contributor)

Looks like a bug in the LiftTransformParams implementation: it only determines the variables required at runtime from the contents of each VarBinding, not from the output of a Function. Should be fixed in #17232.

@Cookiee235 (Contributor, Author)

@Lunderberg The test case now runs correctly under the given PR (#17232). Thanks for your efforts!

@Lunderberg (Contributor)

No problem, and thank you for the high-quality bug reports! Running into any of these failure modes in larger use cases can be very difficult to debug. My personal rule of thumb is that every IRModule should either be caught as ill-formed or should compile without issue. The errors you've been uncovering show that this clearly isn't yet the case, but fixing them helps move toward that ideal.

(With some exceptions for uncatchable issues, such as incorrect arguments used for external functions.)
