-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI][SPIRV] Cast to float32 not float64 before log2 in sort/scan #7669
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm kind of stunned by this. I'm not sure how it's possible that ceil(log2(x)) == 0
for x > 1 in any datatype. That feels like a kind of fundamental issue with the intrinsic...
I worry ever so slightly about rounding issues taking a large input size that's just over a power of 2 and casting it to a large number just under a power of 2 in float32 (due to the lack of precision), and then this will return the wrong number.
I wrote a little script to test this:
import numpy as np
for i in range(30):
n = np.array(2**i + 1).astype("int64")
f = n.astype("float32")
n2 = f.astype("int64")
print(i, n, n2)
assert n2 >= n
and it asserts at n = 2**24 + 1 = 16,777,217
That's larger than anything we can currently fit in GPU memory, so I don't think it's an issue at the moment, but it's a little uncomfortably close for my tastes.
Maybe we should add a warning/assert if the input size is too big?
Perhaps we should think about other alternatives for such an intrinsics. see |
Ok updated to cast to float32 only in the problematic case, which is VK + dynamic input on TIR scan. I think this is an acceptable solution for now. Of course, the best solution is to implement TIR level CSE, since the host is doing the same compute anyway and there is no point computing log2 etc in device. Interestingly, TIR mergepath kernel used in sort, which is also littered with glsl log2 and ceil, doesn't cast to float64 before log2 in the GPU IR. If you see the IR dump https://gist.github.com/masahi/c0979c61907af15f9924b3b3d72fe6a7, there is no It could also be the case that our SPIRV codegen for int64 to float64 cast is busted, but I haven't checked. Another weird thing is that glsl log2 on fp64 works correctly if the input size is static. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Made it a draft while I am reading about clz bit hacks |
@mbrookhart @tqchen The SPIRV spec says their log2 intrinsics only support 16 or 32 bit floating point https://www.khronos.org/registry/spir-v/specs/1.0/GLSL.std.450.html
|
Looks like one reasonable way to implement We need to add intrinsic lowering of |
Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
@mbrookhart I'm finally back with this, we can now do integer ceil(log2(x)) without cast to float for vulkan. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, Masa! This looks great. I'll merge when it passes CI
Thanks @mbrookhart @tqchen |
…pache#7669) * [TOPI] Cast to float32 before log2 in sort/scan * revert sort change since this seems unnecessary * only does cast to float32 on vk + dynamic input case * check against IntImm instead of Var * revert change * use clz for ceil_log2 when compiling for vk * add doc on ceil_log2 * fix pylint Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
…pache#7669) * [TOPI] Cast to float32 before log2 in sort/scan * revert sort change since this seems unnecessary * only does cast to float32 on vk + dynamic input case * check against IntImm instead of Var * revert change * use clz for ceil_log2 when compiling for vk * add doc on ceil_log2 * fix pylint Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
…pache#7669) * [TOPI] Cast to float32 before log2 in sort/scan * revert sort change since this seems unnecessary * only does cast to float32 on vk + dynamic input case * check against IntImm instead of Var * revert change * use clz for ceil_log2 when compiling for vk * add doc on ceil_log2 * fix pylint Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
…pache#7669) * [TOPI] Cast to float32 before log2 in sort/scan * revert sort change since this seems unnecessary * only does cast to float32 on vk + dynamic input case * check against IntImm instead of Var * revert change * use clz for ceil_log2 when compiling for vk * add doc on ceil_log2 * fix pylint Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
…pache#7669) * [TOPI] Cast to float32 before log2 in sort/scan * revert sort change since this seems unnecessary * only does cast to float32 on vk + dynamic input case * check against IntImm instead of Var * revert change * use clz for ceil_log2 when compiling for vk * add doc on ceil_log2 * fix pylint Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
This fixes TIR scan + dynamic input shape on VK / SPIRV. I debugged this problem by doing scan on 2 elements, so that up/downsweep run only one iteration.
I found that when
scan_axis_size
is dynamic whose runtime value is 2, the value oflim
is 0 instead of expected 1. Surprisingly this issue was fixed by castingscan_axis_size
to float32 instead of 64. I realized that generally GPUs (especially low ends) don't have great support for fp64, so I think this is better.Now dynamic cumsum, argwhere etc are working with VK.
@mbrookhart