[Bug] Inconsistent module structure and InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed #17494

Open · Thrsu opened this issue Oct 27, 2024 · 0 comments
Labels: needs-triage, type: bug

Thrsu (Contributor) commented Oct 27, 2024

Applying the LiftTransformParams() transformation produces an inconsistent module structure depending on whether it is run inside a tvm.transform.Sequential pass (mod_seq) or invoked directly (mod). In addition, building the transformed module crashes.

The error may be related to how the symbolic dimension m is handled: shape lowering treats it as a value that must be computed at runtime, but it is apparently never resolved during the transformation and build process.
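As a diagnostic sketch (my assumption, not part of the original report), printing the lifted module from the reproduce script below can show where m still appears once LiftTransformParams has removed the weights from main's runtime inputs:

# Diagnostic sketch (assumes the Module defined in "Steps to reproduce" below):
# with num_input=1, LiftTransformParams moves the weights into a separate
# transform_params function, so main can no longer read m from their shapes
# at runtime; inspecting the printed module helps confirm this.
lifted = relax.transform.LiftTransformParams()(Module)
print(lifted.script())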

Actual behavior

  File "/software/tvm/src/relax/backend/vm/vm_shape_lower.cc", line 463
InternalError: Check failed: (!require_value_computed) is false: PrimExpr m is not computed

Steps to reproduce

import tvm
from tvm import relax
import numpy as np

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def tir_acos(var_x: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m = T.int64()
        x = T.match_buffer(var_x, (T.int64(16), m, T.int64(3), T.int64(3)))
        compute = T.match_buffer(var_compute, (T.int64(16), m, T.int64(3), T.int64(3)))
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(16), m, T.int64(3), T.int64(3)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(x[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.acos(x[v_i0, v_i1, v_i2, v_i3])

    @R.function
    def main(x: R.Tensor((1, 16, 224, "n"), dtype="float32"), w1: R.Tensor((16, "m", 3, 3), dtype="float32"), w2: R.Tensor((16, "m", 3, 3), dtype="float32")) -> R.Tensor((16, "m", 3, 3), dtype="float32"):
        m = T.int64()
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.tir_acos, (w1,), out_sinfo=R.Tensor((16, m, 3, 3), dtype="float32"))
            R.output(gv)
        return gv

mod = Module
mod_seq = tvm.transform.Sequential([relax.transform.LiftTransformParams()])(mod)
mod = relax.transform.LiftTransformParams()(mod)
ex = relax.build(mod, target="llvm")  # crashes here with the InternalError above
tvm.ir.assert_structural_equal(mod_seq, mod)  # never reached; run before the build to observe the structural mismatch
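One possible workaround sketch, not verified against this TVM revision: bind the symbolic dimension m to a concrete value before building, assuming relax.transform.BindSymbolicVars is available in this build (the value 8 below is arbitrary):

# Workaround sketch (assumption, not from the original report): pinning m to a
# constant removes the need for vm_shape_lower to compute it at runtime, at the
# cost of losing the dynamic dimension.
mod_static = relax.transform.BindSymbolicVars({"m": 8})(mod)
ex = relax.build(mod_static, target="llvm")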
Thrsu added the needs-triage and type: bug labels on Oct 27, 2024