Fixes OneHotMatrix/Vector GPU Performance #612
test/cuda/cuda.jl
Outdated
@@ -38,6 +38,13 @@ Flux.back!(sum(l))

end

@testset "onecold gpu" begin
  CuArrays.allowscalar(false)
This shouldn't be necessary as part of the test (we call it at the beginning of this file I think).
Right, I'll remove it
src/onehot.jl
Outdated
function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
  res = similar(xs, size(xs, 1), 1)
  if length(ot) == size(xs, 1)
    res = xs[:,i]
Why do we need this branch? Are there any cases where they aren't equivalent?
Without this branch: 136.165 ms (50001 allocations: 1.99 MiB)
julia> A = Flux.onehotbatch(1:300, 1:10000) |> gpu;
julia> d = 1
1
julia> a = Base.Slice(axes(A, d))
Base.Slice(Base.OneTo(10000))
julia> A[a, 5]
10000-element Array{Bool,1}:
false
false
false
...
With this branch: 15.930 μs (7 allocations: 10.16 KiB)
julia> A = Flux.onehotbatch(1:300, 1:10000) |> gpu;
julia> a = Base.Slice(axes(A, d))
Base.Slice(Base.OneTo(10000))
julia> A[a, 5]
10000-element Flux.OneHotVector:
false
false
false
...
So the branch is there for performance and to avoid allocating the vector, basically.
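For readers following the thread, here is a minimal sketch of the difference shown in the two REPL sessions above. It uses toy stand-in types on the CPU, not Flux's own definitions: `ToyOneHotVector` and `ToyOneHotMatrix` are hypothetical names introduced only for illustration, and the sketch does not reproduce the GPU timings. It only shows that a specialised column method can hand back the stored column, while the generic fallback reads element by element into a dense array.

```julia
# Toy stand-ins for one-hot types (hypothetical, not Flux's definitions).
struct ToyOneHotVector <: AbstractVector{Bool}
    ix::Int   # position of the single `true`
    of::Int   # total length of the vector
end
Base.size(v::ToyOneHotVector) = (v.of,)
Base.getindex(v::ToyOneHotVector, i::Integer) = i == v.ix

struct ToyOneHotMatrix <: AbstractMatrix{Bool}
    height::Int
    data::Vector{ToyOneHotVector}   # one stored column per entry
end
Base.size(m::ToyOneHotMatrix) = (m.height, length(m.data))
Base.getindex(m::ToyOneHotMatrix, i::Integer, j::Integer) = m.data[j][i]

# Specialised path: a full-column slice returns the stored column directly,
# so no dense Bool array is built.
Base.getindex(m::ToyOneHotMatrix, ::Colon, j::Integer) = m.data[j]

m = ToyOneHotMatrix(10, [ToyOneHotVector(k, 10) for k in 1:3])

col_fast = m[:, 2]      # hits the specialised method above
col_slow = m[1:10, 2]   # generic fallback: elementwise reads into a Vector{Bool}

println(typeof(col_fast))
println(typeof(col_slow))
```

Both results compare equal elementwise; only the container type and the amount of work differ. On a GPU-backed array the elementwise fallback is what triggers scalar indexing.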
Is that also true on CPU? Is the slowdown due to scalar indexing? It seems like this might need to be something that's fixed at the CuArrays level rather than being special cased here.
I can remove it, if that behaviour is expected and should be maintained.
Scalar indexing currently happens when we try to get a column out of the .data field of OneHotMatrix, which does affect performance.
It's not so much about whether the behaviour is expected as where the bug should be filed. If it can be fixed in CuArrays instead then it should be. It's still not clear to me whether or not that's the case, but I'll take a closer look at the code.
src/onehot.jl
Outdated
@@ -22,6 +24,22 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]
Perhaps just change the above definition to
Base.getindex(xs::OneHotMatrix, i::Union{Integer,AbstractVector}, j::Integer)
src/onehot.jl
Outdated
@@ -22,6 +24,22 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]

Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs
I think this definition is already handled above. Probably best to leave it; even though it's faster to avoid the allocation, it's more correct to return an independent array.
test/onehot.jl
Outdated
@@ -15,5 +15,4 @@ end
@testset "onehotbatch indexing" begin
  y = Flux.onehotbatch(ones(3), 1:10)
  @test y[:,1] isa Flux.OneHotVector
  @test y[:,:] isa Flux.OneHotMatrix
These tests should both still pass, no?
This goes into scalar indexing and allocates the whole array. I had that function to avoid it.
Would it be worth having something like getindex(xs::OneHotMatrix, ::Colon, ::Colon) = deepcopy(xs)?
Best to avoid deepcopy; perhaps copy(xs.data) and reconstruct.
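A minimal sketch of the "copy the data and reconstruct" suggestion, assuming a simplified stand-in type: `SketchOneHotMatrix` is a hypothetical name, and storing the hot row index of each column as a plain `Vector{Int}` is an assumption made for illustration, not Flux's actual storage. The point is only that `m[:, :]` keeps the one-hot type while returning an independent array, without a `deepcopy`.

```julia
# Simplified stand-in (hypothetical): data[j] is the hot row of column j.
struct SketchOneHotMatrix <: AbstractMatrix{Bool}
    height::Int
    data::Vector{Int}
end
Base.size(m::SketchOneHotMatrix) = (m.height, length(m.data))
Base.getindex(m::SketchOneHotMatrix, i::Integer, j::Integer) = m.data[j] == i

# Full slice: shallow-copy the storage and rebuild the wrapper.
# The result stays one-hot and does not alias the original.
Base.getindex(m::SketchOneHotMatrix, ::Colon, ::Colon) =
    SketchOneHotMatrix(m.height, copy(m.data))

a = SketchOneHotMatrix(4, [1, 3, 2])
b = a[:, :]

# Mutating the copy's storage leaves the original untouched.
b.data[1] = 4
println(a[1, 1])   # hot entry of column 1 in the original
println(b[4, 1])   # hot entry of column 1 in the modified copy
```

Returning `xs` itself would be cheaper still, but then the result would alias the original, which is the correctness concern raised earlier in this thread.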
My old comment got deleted from the diff, so just to reiterate; there's a line Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] I think if
I am trying to figure out a nice way to do the reduction in the case of an
src/onehot.jl
Outdated
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
# Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))
Is this meant to be commented out?
Ah, I found that the second I pushed, fixed with 2952bcd
Thanks, Mike!
try
Build succeeded
bors r+
612: Fixes OneHotMatrix/Vector GPU Performance r=MikeInnes a=dhairyagandhi96
Added tests in conjunction with changes made to the behaviour of OneHotVector/Matrix
cc @MikeInnes @KristofferC
Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
Hmm.. bors seems to have given up, but the tests have finished https://gitlab.com/JuliaGPU/Flux.jl/pipelines/58592993
I guess the internal error is just the usual spurious one? If so feel free to merge.
Those messages indeed are the ones CuArrays has been showing. Will fix the merge conflicts and merge.
Added tests in conjunction with changes made to the behaviour of OneHotVector/Matrix
cc @MikeInnes @KristofferC