[Vulkan] Support passing 64 bit scalar #7572

Merged
merged 13 commits into from
Mar 5, 2021

masahi
Member

@masahi masahi commented Mar 2, 2021

Implements suggestion in #7457 (comment)

Seems to work on Vulkan. Enabled the get_valid_counts and cumsum tests on Vulkan, which use TIR scan and require passing an int64 scalar.

The Metal runtime is also updated but not tested at all, since I don't have a Mac.

@tqchen @tmoreau89 @jwfromm

@jroesch
Member

jroesch commented Mar 3, 2021

Does this work for f64? I think that was the error I triggered the other night

src/runtime/pack_args.h (outdated)
jroesch
jroesch previously requested changes Mar 3, 2021
Member

@jroesch jroesch left a comment

Can you clarify why we didn't just replace ArgUnion?

@masahi
Member Author

masahi commented Mar 3, 2021

ArgUnion is also used by other runtimes (CUDA/OpenCL, I think):

TempArray<ArgUnion, N> holder_(num_args);

I was not sure (and haven't looked into) whether we can replace this with a 64-bit union. Since ArgUnion only supports 32-bit values and I haven't seen any issues with it, I'm not sure that replacing it with a 64-bit union is a good idea. Thoughts?
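For readers following the ArgUnion question, here is a minimal sketch of the two layouts under discussion. The names `ArgUnion32Sketch`, `ArgUnion64Sketch`, `FitsInInt32`, and `PackInt64` are hypothetical and chosen for illustration; the member lists are an assumption, not a copy of what `pack_args.h` contains.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical stand-in for a 32-bit-only argument slot.
union ArgUnion32Sketch {
  int32_t v_int32;
  uint32_t v_uint32;
  float v_float32;
};

// Hypothetical 64-bit-capable slot. A 32-bit value would live in
// element 0 of the two-element arrays; element 1 is padding.
union ArgUnion64Sketch {
  int32_t v_int32[2];
  uint32_t v_uint32[2];
  float v_float32[2];
  int64_t v_int64;
  uint64_t v_uint64;
  double v_float64;
};

static_assert(sizeof(ArgUnion32Sketch) == 4, "32-bit slot is 4 bytes");
static_assert(sizeof(ArgUnion64Sketch) == 8, "64-bit slot is 8 bytes");

// True when an int64 scalar could be passed through the 32-bit slot
// without losing information.
inline bool FitsInInt32(int64_t value) {
  return value >= std::numeric_limits<int32_t>::min() &&
         value <= std::numeric_limits<int32_t>::max();
}

// Stores an int64 scalar in a 64-bit slot.
inline ArgUnion64Sketch PackInt64(int64_t value) {
  ArgUnion64Sketch slot;
  slot.v_int64 = value;
  return slot;
}
```

The sketch only shows why a 4-byte slot cannot carry an int64 scalar; whether the existing union can be swapped out for CUDA/OpenCL is the open question above.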

@masahi masahi marked this pull request as ready for review March 3, 2021 09:21
@tqchen
Member

tqchen commented Mar 3, 2021

cc @vinx13 @zxy844288792 @jwfromm can you please help if the change works for metal GPU backend? Thank you!

@masahi
Member Author

masahi commented Mar 3, 2021

Just found that the Metal codegen generates a hardcoded __TVMArgUnion union:

decl_stream << "union __TVMArgUnion {\n"
<< " int v_int;\n"
<< "};\n\n";

Do we need to change this to something like this?

  decl_stream << "union __TVMArgUnion64 {\n"
              << " int32_t v_int32[2];\n"
              << " int64_t v_int64;\n"    
              << "};\n\n";

And I don't understand why float32 is not there.

@vinx13
Member

vinx13 commented Mar 3, 2021

@masahi Actually I don't think int (and float32) should be there, since 32-bit args are captured here

if (v.dtype().bits() == 32) {
decl_stream << " ";
PrintType(v.dtype(), decl_stream);
decl_stream << " " << vid << ";\n";
vref << varg << "." << vid;
} else {

@masahi
Member Author

masahi commented Mar 3, 2021

@tqchen @vinx13 @jroesch @jwfromm I can leave the Metal runtime as it was and only update the Vulkan runtime, if that sounds better. Unless someone looks into the Metal issue in detail for me, there is nothing I can do.

@tqchen
Member

tqchen commented Mar 3, 2021

@vinx13 @jwfromm @junrushao1994 would be great if any of you who have a macbook can help double check the metal impl

I think we only need to change

if (v.dtype().bits() == 32) {
to print specific types.

@masahi I think your change to __TVMArgUnion is right. Note that we only use ArgUnion to pass in values that are not 32 bits wide (e.g. an int8 that needs to be passed).

We need to update the generation of the argument buffer struct for 32bit and 64bit values in L104

@vinx13
Member

vinx13 commented Mar 4, 2021

@tqchen @masahi

length:num_pack_args_ * sizeof(ArgUnion64)
Here the encoder assumes every argument is 64 bits wide, so I think we need to remove this block
if (v.dtype().bits() == 32) {
decl_stream << " ";
PrintType(v.dtype(), decl_stream);
decl_stream << " " << vid << ";\n";
vref << varg << "." << vid;
} else {
and use TVMArgUnion64 for argument values

@tqchen
Member

tqchen commented Mar 4, 2021

We should still keep the 32-bit path, but generate an int32[2] type instead, as in our arg union.
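A rough sketch of what keeping the 32-bit path could look like on the codegen side: a 32-bit argument is declared as a two-element array and referenced through element 0, while a 64-bit argument is declared and referenced directly. `ArgFieldSketch` and `EmitArgField` are hypothetical helpers written for this illustration and are not part of the Metal codegen; other bit widths are deliberately left out.

```cpp
#include <sstream>
#include <string>

// Hypothetical result type: the field declaration emitted into the
// generated argument struct, and the expression used to read the value.
struct ArgFieldSketch {
  std::string decl;
  std::string ref;
};

// Hypothetical helper mirroring the shape of the 32-bit branch quoted
// in this thread. Widths other than 32 and 64 are outside this sketch
// and produce empty strings.
ArgFieldSketch EmitArgField(const std::string& type_name, int bits,
                            const std::string& arg_struct,
                            const std::string& field) {
  std::ostringstream decl;
  std::ostringstream ref;
  if (bits == 32) {
    // Declare two elements so the field occupies a full 64-bit slot;
    // only element 0 carries the value.
    decl << "  " << type_name << " " << field << "[2];\n";
    ref << arg_struct << "." << field << "[0]";
  } else if (bits == 64) {
    decl << "  " << type_name << " " << field << ";\n";
    ref << arg_struct << "." << field;
  }
  return {decl.str(), ref.str()};
}
```

With this shape every field takes 8 bytes, which would match an encoder that sizes the buffer as a multiple of `sizeof(ArgUnion64)`.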

@tqchen
Member

tqchen commented Mar 4, 2021

@masahi please still incorporate the Metal codegen change, otherwise the constant path would result in an error (since we already pass in ArgUnion64). Perhaps @vinx13 can follow up to suggest a change.

@masahi
Member Author

masahi commented Mar 4, 2021

I can let @vinx13 directly push to my branch

@vinx13
Member

vinx13 commented Mar 4, 2021

I can help with it. @tqchen are you suggesting separating 32bit and 64 bit when passing args?

@tqchen
Member

tqchen commented Mar 4, 2021

@vinx13 we can follow the same convention, e.g. pass int32 as int32[2] and waste the second element
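The convention described above can be sketched on the host side as follows: every scalar gets one 8-byte slot, and a 32-bit value fills the first element while the second is wasted. `ArgSlot64`, `ScalarArg`, `PackArgs`, and `PackedByteLength` are hypothetical names for this illustration, not the runtime's actual API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 64-bit slot, modeled on the union proposed in this thread.
union ArgSlot64 {
  int32_t v_int32[2];
  int64_t v_int64;
};

// Hypothetical description of one integer scalar argument.
struct ScalarArg {
  int bits;       // 32 or 64
  int64_t value;  // value to pass; must fit in `bits`
};

// Packs every argument into a uniform 64-bit slot.
std::vector<ArgSlot64> PackArgs(const std::vector<ScalarArg>& args) {
  std::vector<ArgSlot64> slots(args.size());
  for (std::size_t i = 0; i < args.size(); ++i) {
    assert(args[i].bits == 32 || args[i].bits == 64);
    if (args[i].bits == 32) {
      slots[i].v_int32[0] = static_cast<int32_t>(args[i].value);
      slots[i].v_int32[1] = 0;  // wasted second element
    } else {
      slots[i].v_int64 = args[i].value;
    }
  }
  return slots;
}

// Byte length of the packed buffer: one 64-bit slot per argument.
std::size_t PackedByteLength(std::size_t num_args) {
  return num_args * sizeof(ArgSlot64);
}
```

The cost is 4 bytes of padding per 32-bit argument; the benefit is that the host and the generated kernel agree on a fixed 8-byte stride without tracking per-argument offsets.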

@masahi
Member Author

masahi commented Mar 5, 2021

Ready to merge? @vinx13 @tqchen @jroesch

@tqchen tqchen merged commit c0b9688 into apache:main Mar 5, 2021
@tqchen
Member

tqchen commented Mar 5, 2021

Thanks @masahi @vinx13 @jroesch !

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021

Co-authored-by: Wuwei Lin <wuwei@apache.org>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021

Co-authored-by: Wuwei Lin <wuwei@apache.org>