[Bug] InternalError: Check failed: (it != slot_map_.end()) is false: Var m is not defined in the function but is referenced by m * n during VM Shape Lowering #17493

Thrsu (Contributor) opened this issue on Oct 27, 2024
After applying the LiftTransformParams transformation, the Relax VM build fails during the VM Shape Lowering phase with the following error:

File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 310
InternalError: Check failed: (it != slot_map_.end()) is false: Var m is not defined in the function but is referenced by m * n

The error points to a variable-scoping issue: m is referenced during shape lowering but is not defined in the function being lowered. In main, m and n are defined by the shape of x and are passed to the PrimFunc via tir_vars, yet shape resolution still fails. A likely cause is that LiftTransformParams moves the call on weight into a lifted main_transform_params function whose only shape information is the product m * n, from which m and n cannot be recovered individually.
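One way to check this (a sketch, reusing the Module and pass pipeline from the repro below, stopping before relax.build) is to print the transformed IR and inspect the lifted function:

# Diagnostic sketch: after the LiftTransformParams line in the repro below,
# print the module instead of building it. In main_transform_params,
# weight's struct info should still read R.Tensor((m * n,)), with no
# binding that defines m or n individually.
mod.show()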

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_weight: T.handle, var_T_add: T.handle, m: T.int64, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        weight = T.match_buffer(var_weight, (m * n,))
        T_add = T.match_buffer(var_T_add, (m * n,))
        for ax0 in range(m * n):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(m * n, ax0)
                T.reads(weight[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = weight[v_ax0] + T.float32(1)

    @R.function
    def main(x: R.Tensor(("m", "n"), dtype="float32"), weight: R.Tensor(("m * n",), dtype="float32")) -> R.Tensor(("m * n", 1, 1, 1), dtype="float32"):
        m = T.int64()
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.add, (weight,), out_sinfo=R.Tensor((m * n,), dtype="float32"), tir_vars=R.shape([m, n]))
            R.output(gv)
        return gv

mod = Module
with tvm.transform.PassContext(disabled_pass=["RemoveUnusedParameters"]):
    mod = relax.transform.FuseTIR()(mod)
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.LambdaLift()(mod)
mod = relax.transform.LiftTransformParams()(mod)
ex = relax.build(mod, target='llvm')
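As a cross-check (an assumption, not verified here), the same pipeline without LiftTransformParams would be expected to build cleanly, which would isolate the failure to that pass:

# Sketch: identical pipeline minus LiftTransformParams. If the diagnosis
# above is right, this build succeeds because main still defines m and n
# from the shape of x.
mod2 = Module
with tvm.transform.PassContext(disabled_pass=["RemoveUnusedParameters"]):
    mod2 = relax.transform.FuseTIR()(mod2)
mod2 = tvm.relax.transform.LegalizeOps()(mod2)
mod2 = relax.transform.LambdaLift()(mod2)
ex2 = relax.build(mod2, target="llvm")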