-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix][TIR]fix mul dtype mismatch #16010
Conversation
CC @Lunderberg @wrongtest-intellif |
It would probably be good to add a unit test. The best way to do so would be to define a TIR function that (before the change) would trigger the mismatched dtype bug when passed through In this case, the buggy output would contain an extra multiplication step, due to the explicit construction of a class TestMultiplicationNodesAreInligned(tvm.testing.CompareBeforeAfter):
transform = tvm.tir.transform.InjectPTXAsyncCopy()
def before(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")):
tx = T.launch_thread("threadIdx.x", 32)
A_flattened = T.Buffer((4096,), "float16", data=A.data)
A_shared = T.decl_buffer([4096], "float16", scope="shared")
T.attr("default", "async_scope", 1)
for i in range(16):
A_shared[tx * 128 + i * 8 : tx * 128 + i * 8 + 8] = A_flattened[
tx * 128 + i * 8 : tx * 128 + i * 8 + 8
]
T.ptx_commit_group()
T.ptx_wait_group(0)
def expected(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")):
tx = T.launch_thread("threadIdx.x", 32)
A_shared = T.decl_buffer([4096], "float16", scope="shared")
for i in range(16):
T.ptx_cp_async(
"float16",
A_shared.data,
T.Mul(tx * 128 + i * 8, 1),
A.data,
tx * 128 + i * 8,
16,
)
T.ptx_commit_group()
T.ptx_wait_group(0) With your fix applied, I suspect that this would fail due to the |
Thank you for your suggestion. There is something confuse me. |
Ah, I didn't realize that was a necessary step of reproducing the bug. The way to use int64 for the integer literals is to wrap each of them with |
Ooh, that's an interesting failure mode, and good debugging on catching it. Looks like the root cause of that extra step is here, where the |
Thanks for the review. Can this PR be merged then? |
Looks like there were a couple of CI steps that required approval to start. I've started them, and after they finish, the PR can be merged. These are compile-only tests, so they should be done relatively quickly. I'll keep an eye out for when they finish, but feel free to ping me if you notice them finish before I do. Also, I wanted to say thank you for the extra work in splitting out the separate PRs and adding the unit tests. It can be a bit tedious, but it is very much appreciated in maintaining a testable code base with history suitable for |
On closer inspection, the |
Yeah i see, it confuse me quite a lot before. Thanks for the explanation |
Another bug occurs in PASS InjectPTXAsyncCopy .
that is dst_offset.dtype could be int64, the dtype of PrimExpr(index_factor) would be set to default to int32.
cause dtype inconsistent when calling tir::Mul.
To reproduce the problem in InjectPTXAsyncCopy, see script here