[Codegen][CUDA] Fix: cuda codegen vectorize cast #7561
}
if (!fail) {
  return;
}
Missing `break` here?
fixed
@@ -511,8 +511,8 @@ def check(t0, t1):

     # schedule
     s = tvm.te.create_schedule(C.op)
-    ob, ib = s[C].split(s[C].op.axis[0], nparts=32)
-    _, iib = s[C].split(ib, factor=4)
+    ob, ib = s[C].split(s[C].op.axis[0], nparts=n // factor)
We can also directly say `factor=factor` here.
like this?
ob, ib = s[C].split(s[C].op.axis[0], factor=factor)
# _, iib = s[C].split(ib, factor=factor)
s[C].vectorize(ib)
Yeah.
changed
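The equivalence behind this suggestion can be checked with a small model of how a loop split computes its extents. This is a sketch only: it does not call TVM, and the helper names `split_by_nparts` and `split_by_factor` are hypothetical. It assumes the usual ceiling-division behaviour, where `nparts` fixes the outer extent and `factor` fixes the inner extent.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive operands."""
    return -(-a // b)


def split_by_nparts(extent, nparts):
    """Model of split(axis, nparts=...): the outer extent is fixed."""
    return nparts, ceil_div(extent, nparts)


def split_by_factor(extent, factor):
    """Model of split(axis, factor=...): the inner extent is fixed."""
    return ceil_div(extent, factor), factor


# Hypothetical sizes, chosen only for illustration.
n, factor = 128, 4

# When factor divides n, both spellings give the same (outer, inner) pair.
print(split_by_nparts(n, n // factor))
print(split_by_factor(n, factor))

# When factor does not divide n, only factor=... keeps the inner
# (vectorized) extent equal to the requested vector width.
print(split_by_nparts(10, 10 // 4))
print(split_by_factor(10, 4))
```

Under these assumptions, `factor=factor` states the intent (an inner loop of exactly `factor` lanes) directly, which is what matters when the inner axis is then vectorized.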
* fix: cuda codegen vectorize cast
* style: fix python coding style
* fix: missing break
* refactor: directly split by factor

Co-authored-by: jiangchengquan <jiangchengquan@bytedance.com>
Data types such as float32x8 and int32x8 are not supported in CUDA, which results in errors like "Cannot convert type float32x8 to CUDA type!" during code generation. I tried to fix this by storing two 32-bit values in one 64-bit value.
Could you please help review this fix? @Laurawly
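The packing idea in the description can be illustrated at the bit level. This is a minimal sketch in plain Python, not the generated CUDA code or the codegen change itself: the helper names are hypothetical, and placing the even lane in the low 32 bits is an assumption made for illustration.

```python
import struct

MASK32 = 0xFFFFFFFF


def float_bits(x):
    """Reinterpret a float32 value as its 32-bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_float(b):
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def pack_pair(lo, hi):
    """Store two 32-bit patterns in one 64-bit word (lo in the low half)."""
    return ((hi & MASK32) << 32) | (lo & MASK32)


def unpack_pair(packed):
    """Recover the two 32-bit patterns from a 64-bit word."""
    return packed & MASK32, (packed >> 32) & MASK32


def pack_lanes(lanes):
    """Pack an even number of 32-bit lanes into half as many 64-bit words."""
    return [pack_pair(lanes[i], lanes[i + 1]) for i in range(0, len(lanes), 2)]


# Eight float32 lanes become four 64-bit words, so a 4-wide 64-bit
# vector type is enough to carry a float32x8 value.
lanes = [float_bits(v) for v in (0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)]
words = pack_lanes(lanes)
print(len(words))

# Round trip of the first pair back to float32 values.
lo, hi = unpack_pair(words[0])
print(bits_float(lo), bits_float(hi))
```

The same arithmetic applies to int32 lanes, where the 32-bit pattern is the integer itself masked to 32 bits.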