Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] narrow thread extents to 32 bits for GPU lowering #10969

Closed
wants to merge 4 commits into from

Conversation

altanh
Copy link
Contributor

@altanh altanh commented Apr 11, 2022

Occasionally, int64 constants get piped through lowering and end up as thread extents, which can cause a dtype mismatch with the thread IterVar (which should be int32 on GPU). This PR narrows extents to int32 for GPU lowering to avoid the mismatch.

I added a test case for a small broadcast_to -> sum program that fails to compile before this fix.

cc @Lunderberg @mbrookhart

@@ -836,7 +836,7 @@ def broadcast_to(data, shape):
The resulting tensor.
"""
if isinstance(shape, Constant):
shape = list(shape.data.numpy())
shape = [int(i) for i in shape.data.numpy()]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because tvm.runtime.convert handles all integer types as int32, would this cause issues with arrays larger than 4 GB? I don't think we have many of those in practice, but I think it could then cause a similar issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible, although the default handler for a python list of ints would hit the same problem. In either case, I'll push a proper fix for the reduction schedule on CUDA (forcibly cast the extent to int32, since that's needed for CUDA anyway if I understand correctly).

Do you think I should wrap ints using numpy int64 to avoid this problem? I'm slightly worried it will break assumptions elsewhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess one possible option is to change the behavior tvm.runtime.convert to check for overflow and use int64 when necessary

@altanh altanh changed the title [Bug] hotfix edge case for broadcast_to const shape int64 [Bug] narrow thread extents to 32 bits for GPU lowering Apr 12, 2022
Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

I like the verification that the size will fit in an int32.I could imagine edge cases, such as a user declaring an dynamically-sized input buffer with int64 size, then using a schedule that chooses the number of threads based on that size. However, that feels like enough of an edge case that it isn't worth replacing CanProveLess with !CanProveGreaterEqual.

@masahi
Copy link
Member

masahi commented Apr 12, 2022

There is also a PR (looks related) #10983 touching narrow_datatype.cc. Can you discuss which patch to merge? cc @ganler

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

There is also a PR (looks related) #10983 touching narrow_datatype.cc. Can you discuss which patch to merge? cc @ganler

haha great timing, @ganler do you mind trying the test case I added here on your branch? I suspect your fix might miss the case of mismatched int32 var with int64 extent.

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

@masahi @altanh I just tried your test case and #10983 can pass it.

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

OK, IMHO, the difference between this fix and #10983 is that:

  • this one fixes the bug in Visitor by doing a casting for thread extend;
  • [FIX] resolve int64/32 for AttrStmtNode #10983 (theoretically) fixes a bigger scope of similar bugs as it does casting in the Rewrite phase and does not specifically target attr::thread_extent (can be any other incompatible integer attributes).

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

interesting, I'm confused how it does the cast when the IterVar is int32 but the extent is int64 (since then vi_dtype.bits() < var.dtype().bits() is false), there must be some interaction with the DataTypeVisitor that I don't understand. I'll defer to @Lunderberg, I'm fine with either fix

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

@altanh I just tried your PR, which also fixes the issue from a PyTorch-generated ONNX model in #10983. So it seems the main issue is from the thread extend?

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

Yep I think we're solving the same problem haha

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

interesting, I'm confused how it does the cast when the IterVar is int32 but the extent is int64 (since then vi_dtype.bits() < var.dtype().bits() is false), there must be some interaction with the DataTypeVisitor that I don't understand. I'll defer to @Lunderberg, I'm fine with either fix

Very great point and this is where I think your PR is better than mine as I simply assume that extend.dtype.bits <= var.dtype.bits.

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

But I am not sure if CanProveLess is able to consider the integer bits. That said, if the var of IterVal is int64 and extend is int32. They are both symbolic that CanProveLess seems to return false but we can actually cast extend to int64 right? From the tests, it seems results to CanProveLess are often false, causing failures.

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

yeah I see those tests failures now... my operating assumption is that for the thread_extent attr, it's only for GPUs and currently GPUs only support int32 IterVar (e.g. threadIdx.x) and extents. So maybe I should just go ahead and force everything to be int32 here? I also assumed that the extents will be concrete, but maybe that's too strict too?

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

If that's the case, probably we should just force the thread index to be int32. i.e., removing CanProveLess. Or let's just use int64 and let the compiler complain about larger-than-int32 extends?

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

let me try just removing the CanProveLess and see if it can pass CI

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

@ganler it seems like I'm breaking some assumptions elsewhere with this change- some unit tests seem to want the thread extents to explicitly be int64. Maybe we should just go with your change if you're happy with it

@ganler
Copy link
Contributor

ganler commented Apr 12, 2022

@altanh I see. According to previous fixes (#10172 #9582 #10519 #10571 #10584), it seems to be more compatible if we just change the data type when constructing a new IR node. And #10983 did pass the CI (but I made further edits to adapt your suggestion about vi_dtype.bits() < var.dtype().bits()).

@altanh
Copy link
Contributor Author

altanh commented Apr 12, 2022

I'll close this PR once yours passes CI, thx for tracking down the other relevant PRs!

@masahi masahi closed this Apr 13, 2022
@masahi
Copy link
Member

masahi commented Apr 13, 2022

The other PR #10983 has been merged, thanks @altanh @ganler for discussion

@Lunderberg
Copy link
Contributor

I'm fine with either fix as well, and thank you for tracking it down!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants