fix onehot gpu #1441

Closed
wants to merge 4 commits into from

@CarloLucibello CarloLucibello commented Dec 27, 2020

Fix #556, fix #582, using a GPU-friendly reduction.
Benchmarked with the following script (onecold2 is the new version):

using CUDA
using Flux
using Flux: onehotbatch, onecold, OneHotMatrix
using BenchmarkTools
using Statistics: mean  # mean is used by the accuracy functions below
CUDA.allowscalar(false)

onecold2(y::AbstractMatrix, labels=1:size(y,1)) =
  vec(map(x -> labels[x[1]], argmax(y; dims=1)))

onecold2(y::OneHotMatrix, labels...) = 
    map(x -> Flux.onecold(x, labels...), y.data)

function accuracy_v1a(oh, ŷ)
    mean(onecold(oh) .== onecold(ŷ))
end

function accuracy_v1b(oh, ŷ)
    mean(onecold(cpu(oh)) .== onecold(cpu(ŷ)))
end

function accuracy_v1c(y, ŷ)
    mean(cpu(y) .== onecold(cpu(ŷ)))
end

function accuracy_v2a(oh, ŷ)
    mean(onecold2(oh) .== onecold2(ŷ))
end

function accuracy_v2b(oh, ŷ)
    mean(onecold2(cpu(oh)) .== onecold2(cpu(ŷ)))
end

function accuracy_v2c(y, ŷ)
    mean(y .== onecold2(ŷ))
end

function accuracy_v3(y, ŷ)
    mean(y .== mapslices(argmax, ŷ, dims=1))
end

function accuracy_v4(oh, ŷ)
    mean(maximum(oh .* ŷ, dims=1) .== maximum(ŷ, dims=1))
end

ŷ = rand(Float32, 100, 1000)
y = rand(1:100, 1000)
oh = onehotbatch(y, 1:100)
ŷg, yg, ohg = gpu.([ŷ, y, oh]) 

println("V1A")
@btime accuracy_v1a(oh, ŷ) # 755.393 μs (9516 allocations: 248.97 KiB)
# @btime CUDA.@sync accuracy_v1a(ohg, ŷg) # Error scalar indexing

println("\nV1B")
@btime accuracy_v1b(oh, ŷ)  #   728.873 μs (9524 allocations: 249.81 KiB)
@btime CUDA.@sync accuracy_v1b(ohg, ŷg) #   1.027 ms (9542 allocations: 648.70 KiB)

println("\nV1C")
@btime accuracy_v1c(y, ŷ) # 771.765 μs (9519 allocations: 241.75 KiB)
@btime CUDA.@sync accuracy_v1c(yg, ŷg) # 1.022 ms (9537 allocations: 640.64 KiB)

println("\nV2A")
@btime accuracy_v2a(oh, ŷ) #   511.169 μs (10 allocations: 40.20 KiB)
@btime CUDA.@sync accuracy_v2a(ohg, ŷg) #   72.631 μs (261 allocations: 8.03 KiB)

println("\nV2B")
@btime accuracy_v2b(oh, ŷ) #  513.488 μs (20 allocations: 41.11 KiB)
@btime CUDA.@sync accuracy_v2b(ohg, ŷg) #   792.764 μs (38 allocations: 440.00 KiB)


println("\nV2C")
@btime accuracy_v2c(y, ŷ) # 524.095 μs (9 allocations: 32.27 KiB)
@btime CUDA.@sync accuracy_v2c(yg, ŷg)  #  64.536 μs (236 allocations: 7.38 KiB)

println("\nV3")
@btime accuracy_v3(y, ŷ)    # 1.802 ms (9510 allocations: 362.91 KiB)
# @btime CUDA.@sync accuracy_v3(yg, ŷg) # Error scalar indexing

println("\nV4")
@btime accuracy_v4(oh, ŷ) #   612.915 μs (14 allocations: 403.50 KiB)
@btime CUDA.@sync accuracy_v4(ohg, ŷg) # 85.776 μs (154 allocations: 4.23 KiB)
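
The reduction behind onecold2 can be tried without a GPU. The following is a minimal CPU-only sketch, not the code from this PR: `argmax_onecold` and `percolumn_onecold` are hypothetical names, and the per-column version is only a stand-in for the slice-by-slice approach. It shows that a single `argmax(y; dims=1)` call over the whole matrix gives the same labels as looping over columns.

```julia
# CPU-only sketch with hypothetical names; needs neither Flux nor CUDA.
# One argmax reduction over the whole matrix yields a 1×N matrix of
# CartesianIndex values; idx[1] is the row (class) index of each column maximum.
argmax_onecold(y::AbstractMatrix, labels = 1:size(y, 1)) =
    vec(map(idx -> labels[idx[1]], argmax(y; dims = 1)))

# Per-column reference: visits the columns one at a time.
percolumn_onecold(y::AbstractMatrix, labels = 1:size(y, 1)) =
    [labels[argmax(col)] for col in eachcol(y)]

scores = Float32[0.1 0.7 0.2;
                 0.8 0.2 0.3;
                 0.1 0.1 0.5]

a = argmax_onecold(scores)                # [2, 1, 3]
b = percolumn_onecold(scores)             # [2, 1, 3]
c = argmax_onecold(scores, [:x, :y, :z])  # [:y, :x, :z]
```

The two functions agree on the CPU. The timings above suggest that the single-reduction form is the one that carries over to GPU arrays without scalar indexing.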

src/onehot.jl (outdated)

- onecold(y::AbstractMatrix, labels...) =
-   dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
+ onecold(y::AbstractMatrix, labels = 1:size(y,1)) =
Member

Why remove the splatted version?

Member Author

Because previously the default argument was handled by the specialization on AbstractVector; now it has to be handled here.
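
The difference between the two signatures can be sketched with toy functions. All names here (`pick`, `pick_splat`, `pick_col`) are hypothetical and only illustrate the dispatch pattern; they are not Flux's methods. With a splatted signature the matrix method forwards whatever it received, so the default labels come from the vector method. Once the matrix method stops delegating to the vector method, it needs its own default.

```julia
# Hypothetical names, for illustration of the dispatch pattern only.

# Splatted version: forwards zero or one label arguments to the vector method,
# which is where the default labels are defined.
pick_col(col::AbstractVector, labels = 1:length(col)) = labels[argmax(col)]
pick_splat(y::AbstractMatrix, labels...) =
    [pick_col(col, labels...) for col in eachcol(y)]

# Non-delegating version: it never calls pick_col, so the default
# has to be declared in its own signature.
pick(y::AbstractMatrix, labels = 1:size(y, 1)) =
    vec(map(idx -> labels[idx[1]], argmax(y; dims = 1)))

m = [1 9;
     5 2]

r1 = pick_splat(m)               # default supplied by pick_col
r2 = pick(m)                     # default supplied by pick itself
r3 = pick(m, ["cat", "dog"])     # explicit labels
```

Both calls without labels return the row index of each column maximum, so the change in signature is invisible to callers.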

@DhairyaLGandhi
Member

#764 fixed the performance IIRC, so that issue should be closed

@CarloLucibello
Member Author

it didn't, as you can see from the benchmark I posted

@CarloLucibello
Member Author

closing in favor of #1447

@CarloLucibello CarloLucibello deleted the cl/onecold branch January 7, 2021 08:46

Successfully merging this pull request may close these issues.

Typical accuracy function using onecold with a OneHotMatrix fails to compile on GPU
onecold is very slow
2 participants