-
-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix for onecold broadcast bug #764
Conversation
bors try |
tryBuild succeeded |
What was the bug? Could the existing test be tweaked or added to in order to catch it? |
I think we already have the test that is supposed to check for this Line 43 in e991228
Maybe we should add an extra one that checks for the broadcasting just like in the Edit: I should've written the description properly, and I shall change that to reflect what changes are made and why. |
The diff looks odd to me -- if this can be done with a simple If there's no behaviour change then this is certainly good to go. |
We needed to hit an optimised kernel with the GPU at the time (iirc, because it's been a while since this has been up), but things have moved on massively since then and the regular way works fine now. No change in behaviour |
Bump |
I'm not totally clear on the testing situation; if there's a test for this already, why isn't it failing currently? Can we add the test that would fail without this patch? |
A test that would fail right now is the following identity: using Flux: onehotbatch, onecold
using Test
data = [:b, :a, :c]
labels = [:a, :b, :c]
hot = onehotbatch(data, labels)
cold = onecold(hot, labels)
@test cold == data This is something that works with the new However, I don't know why the the |
Perhaps another relevant data point: When the |
I think the original reason was for some CUDA optimization. But this PR really should get some attention. |
@DhairyaLGandhi could you add the test from #764 (comment) and get this merged? |
bors r+ |
764: Fix for onecold broadcast bug r=DhairyaLGandhi a=DhairyaLGandhi Using `mapreduce` was returning an `Array` when used with `OneHotMatrix{CuArray}`. This causes issues down the chain. `map` avoids that, and also avoids the reduction giving slightly better speed (~40x). Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
Build failed: |
bors r+ |
Build succeeded: |
Using
mapreduce
was returning anArray
when used withOneHotMatrix{CuArray}
. This causes issues down the chain.map
avoids that, and also avoids the reduction giving slightly better speed (~40x).