[Vulkan] Support passing 64 bit scalar #7572
Does this work for f64? I think that was the error I triggered the other night.
Can you clarify why we didn't just replace ArgUnion?
Line 143 in 1831c17
I was not sure / haven't looked at whether we can replace this with 64. Since
cc @vinx13 @zxy844288792 @jwfromm can you please help check whether the change works for the Metal GPU backend? Thank you!
Just found that Metal codegen generates hardcoded tvm/src/target/source/codegen_metal.cc Lines 49 to 51 in 053347c
Do we need to change this to something like this?
And I don't understand why float32 is not there.
@masahi Actually I don't think tvm/src/target/source/codegen_metal.cc Lines 104 to 109 in 053347c
@vinx13 @jwfromm @junrushao1994 It would be great if any of you who has a MacBook could help double-check the Metal implementation. I think we only need to change tvm/src/target/source/codegen_metal.cc Line 104 in 053347c
@masahi I think your change to We need to update the generation of the argument buffer struct for 32-bit and 64-bit values in L104
tvm/src/runtime/metal/metal_module.mm Line 200 in 6cd9626
encoder assumes every argument is 64 bit, I think we need to remove this block tvm/src/target/source/codegen_metal.cc Lines 104 to 109 in 053347c
TVMArgUnion64 for argument values
We should still keep the 32-bit path, generating an int32[2] type instead, as in our ArgUnion.
I can let @vinx13 directly push to my branch |
I can help with it. @tqchen are you suggesting separating 32-bit and 64-bit values when passing args?
@vinx13 we can follow the same convention, e.g. pass int32 as int32[2] and waste the second element |
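For readers following the thread, the convention above (every scalar argument occupies one 8-byte slot, and a 32-bit value uses only the first half) can be sketched roughly as follows. This is a minimal illustration, not TVM's actual definition: the names ArgSlot64, Pack, and Unpack are hypothetical, and the sketch copies bytes with memcpy into a raw 8-byte slot instead of using a union, purely to keep the example free of type-punning concerns.

```cpp
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical fixed-size argument slot (not TVM's real type).
// Every scalar, 32-bit or 64-bit, takes exactly 8 bytes.
struct ArgSlot64 {
  alignas(8) unsigned char bytes[8];
};
static_assert(sizeof(ArgSlot64) == 8, "each argument slot must be 8 bytes");

// Pack a scalar into a slot. Values narrower than 8 bytes fill the
// leading bytes; the remaining bytes stay zero (the "wasted" half).
template <typename T>
ArgSlot64 Pack(T value) {
  static_assert(sizeof(T) <= sizeof(ArgSlot64), "scalar must fit in a slot");
  static_assert(std::is_trivially_copyable<T>::value, "scalar must be POD-like");
  ArgSlot64 slot{};  // zero-initialize all 8 bytes
  std::memcpy(slot.bytes, &value, sizeof(T));
  return slot;
}

// Read a scalar of type T back out of the leading bytes of a slot.
template <typename T>
T Unpack(const ArgSlot64& slot) {
  static_assert(sizeof(T) <= sizeof(ArgSlot64), "scalar must fit in a slot");
  static_assert(std::is_trivially_copyable<T>::value, "scalar must be POD-like");
  T value;
  std::memcpy(&value, slot.bytes, sizeof(T));
  return value;
}
```

With a layout like this, the host side can write arguments at a uniform 8-byte stride, and the generated kernel-side struct only has to agree on that stride, whether a given argument is int32, int64, or f64.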
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Implements suggestion in #7457 (comment)
Seems to work on Vulkan. Enabled get_valid_counts and cumsum tests on Vulkan, which use TIR scan and require passing an int64 scalar. The Metal runtime is also updated but not tested at all, since I don't have a Mac.
@tqchen @tmoreau89 @jwfromm