Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Fix the parser to avoid treating a list as an integer #17497

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

jikechao
Copy link
Contributor

@jikechao jikechao commented Oct 30, 2024

When the output tensor is in the form of (10), TVM will crash unexpectedly as follows.
This PR adds a rule to convert the (10) to (10,) to avoid such a crash!

Traceback (most recent call last):
  File "/share_container/optfuzz/res/ut_ut_test/res_executions/14832_test.py", line 9, in <module>
    class Module:
  File "/share_container/optfuzz/res/ut_ut_test/res_executions/14832_test.py", line 11, in Module
    def main(x: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32"))) -> R.Tensor((10), dtype="float32"):
                                                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/script/parser/relax/entry.py", line 266, in Tensor
    if shape is not None and not isinstance(shape, Var) and len(shape) == 0:
                                                            ^^^^^^^^^^
TypeError: object of type 'int' has no len()
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10), dtype="float32"))) -> R.Tensor((10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((10,), dtype="float32") = x[0]
            R.output(lv)
        return lv
mod = Module
mod.show()

cc @tqchen @Hzfengsy @Lunderberg @yongwww

@jikechao jikechao changed the title [Relax] Fix the parser to avoid treating a list as an integer. [Relax] Fix the parser to avoid treating a list as an integer Oct 30, 2024
@Hzfengsy
Copy link
Member

Thanks for the contribution!

First, it's not a typical bug. (10) is just an integer rather than a list based on the Python syntax, while [10] and (10, ) are list and tuples. Here, we request a list as input, so the error is expected (based on the current implementation)

Second, it should be a valuable sugar for end users. But please:

  1. consider supporting not only integers but also prim expr, symbolic shapes etc.
  2. please add tests in the parser test file

@jikechao jikechao marked this pull request as draft October 30, 2024 15:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants