Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI][SPIRV] Cast to float32 not float64 before log2 in sort/scan #7669

Merged
merged 8 commits into from
Apr 17, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Mar 16, 2021

This fixes TIR scan + dynamic input shape on VK / SPIRV. I debugged this problem by doing scan on 2 elements, so that up/downsweep run only one iteration.

I found that when scan_axis_size is dynamic whose runtime value is 2, the value of lim is 0 instead of expected 1. Surprisingly this issue was fixed by casting scan_axis_size to float32 instead of 64. I realized that generally GPUs (especially low ends) don't have great support for fp64, so I think this is better.

Now dynamic cumsum, argwhere etc are working with VK.

@mbrookhart

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm kind of stunned by this. I'm not sure how it's possible that ceil(log2(x)) == 0 for x > 1 in any datatype. That feels like a kind of fundamental issue with the intrinsic...

I worry ever so slightly about rounding issues taking a large input size that's just over a power of 2 and casting it to a large number just under a power of 2 in float32 (due to the lack of precision), and then this will return the wrong number.

I wrote a little script to test this:

import numpy as np
for i in range(30):
    n = np.array(2**i + 1).astype("int64")
    f = n.astype("float32")
    n2 = f.astype("int64")
    print(i, n, n2)
    assert n2 >= n

and it asserts at n = 2**24 + 1 = 16,777,217

That's larger than anything we can currently fit in GPU memory, so I don't think it's an issue at the moment, but it's a little uncomfortably close for my tastes.

Maybe we should add a warning/assert if the input size is too big?

@tqchen
Copy link
Member

tqchen commented Mar 16, 2021

Perhaps we should think about other alternatives for such an intrinsics.

see

@masahi
Copy link
Member Author

masahi commented Mar 16, 2021

Ok updated to cast to float32 only in the problematic case, which is VK + dynamic input on TIR scan. I think this is an acceptable solution for now. Of course, the best solution is to implement TIR level CSE, since the host is doing the same compute anyway and there is no point computing log2 etc in device.

Interestingly, TIR mergepath kernel used in sort, which is also littered with glsl log2 and ceil, doesn't cast to float64 before log2 in the GPU IR. If you see the IR dump https://gist.github.com/masahi/c0979c61907af15f9924b3b3d72fe6a7, there is no float64 anywhere. But for TIR scan downsweep kernel, there is a cast to float64. So I removed cast to float32 in TIR sort.

It could also be the case that our SPIRV codegen for int64 to float64 cast is busted, but I haven't checked. Another weird thing is that glsl log2 on fp64 works correctly if the input size is static.

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@masahi masahi marked this pull request as draft March 16, 2021 20:06
@masahi
Copy link
Member Author

masahi commented Mar 16, 2021

Made it a draft while I am reading about clz bit hacks

@masahi
Copy link
Member Author

masahi commented Mar 16, 2021

@mbrookhart @tqchen The SPIRV spec says their log2 intrinsics only support 16 or 32 bit floating point https://www.khronos.org/registry/spir-v/specs/1.0/GLSL.std.450.html

The operand x must be a scalar or vector whose component type is 16-bit or 32-bit floating-point.

@masahi
Copy link
Member Author

masahi commented Mar 16, 2021

Looks like one reasonable way to implement ceil(log2(x) is 32 - clz(x) + (x & (x-1) ? 1 : 0) for 32 bit integers. We need to be careful with 32 bit vs 64 bit and signed vs unsigned.

We need to add intrinsic lowering of tvm.tir.clz for llvm and spirv. I'll do that next week.

tmoreau89 added a commit to tmoreau89/tvm that referenced this pull request Mar 16, 2021
Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
@masahi masahi marked this pull request as ready for review April 16, 2021 13:31
@masahi
Copy link
Member Author

masahi commented Apr 16, 2021

@mbrookhart I'm finally back with this, we can now do integer ceil(log2(x)) without cast to float for vulkan.

Copy link
Contributor

@mbrookhart mbrookhart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Masa! This looks great. I'll merge when it passes CI

@masahi masahi merged commit e082ef5 into apache:main Apr 17, 2021
@masahi
Copy link
Member Author

masahi commented Apr 17, 2021

Thanks @mbrookhart @tqchen

mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 22, 2021
…pache#7669)

* [TOPI] Cast to float32 before log2 in sort/scan

* revert sort change since this seems unnecessary

* only does cast to float32 on vk + dynamic input case

* check against IntImm instead of Var

* revert change

* use clz for ceil_log2 when compiling for vk

* add doc on ceil_log2

* fix pylint

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…pache#7669)

* [TOPI] Cast to float32 before log2 in sort/scan

* revert sort change since this seems unnecessary

* only does cast to float32 on vk + dynamic input case

* check against IntImm instead of Var

* revert change

* use clz for ceil_log2 when compiling for vk

* add doc on ceil_log2

* fix pylint

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…pache#7669)

* [TOPI] Cast to float32 before log2 in sort/scan

* revert sort change since this seems unnecessary

* only does cast to float32 on vk + dynamic input case

* check against IntImm instead of Var

* revert change

* use clz for ceil_log2 when compiling for vk

* add doc on ceil_log2

* fix pylint

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
…pache#7669)

* [TOPI] Cast to float32 before log2 in sort/scan

* revert sort change since this seems unnecessary

* only does cast to float32 on vk + dynamic input case

* check against IntImm instead of Var

* revert change

* use clz for ceil_log2 when compiling for vk

* add doc on ceil_log2

* fix pylint

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
…pache#7669)

* [TOPI] Cast to float32 before log2 in sort/scan

* revert sort change since this seems unnecessary

* only does cast to float32 on vk + dynamic input case

* check against IntImm instead of Var

* revert change

* use clz for ceil_log2 when compiling for vk

* add doc on ceil_log2

* fix pylint

Co-authored-by: Masahiro Masuda <masahi@129@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants