Skip to content

Commit

Permalink
Merge #612
Browse files Browse the repository at this point in the history
612: Fixes OneHotMatrix/Vector GPU Performance r=MikeInnes a=dhairyagandhi96

Added tests in conjunction with changes made to the behaviour of OneHotVector/Matrix
cc @MikeInnes @KristofferC 

Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
  • Loading branch information
bors[bot] and Dhairya Gandhi committed Apr 26, 2019
2 parents 13cfcb5 + 2952bcd commit 6e14bf9
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)

Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix

Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of)

A::AbstractMatrix * b::OneHotVector = A[:, b.ix]

struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
Expand All @@ -18,9 +20,12 @@ end

Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))

Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data))

Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)

A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]

Expand Down Expand Up @@ -54,13 +59,18 @@ end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

Base.argmax(xs::OneHotVector) = xs.ix

onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)

onecold(y::OneHotMatrix, labels...) =
mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0)

function argmax(xs...)
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...)
end

Expand Down
6 changes: 6 additions & 0 deletions test/cuda/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ Flux.back!(sum(l))

end

@testset "onecold gpu" begin
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
@test Flux.onecold(y) isa CuArray
@test y[3,:] isa CuArray
end

if CuArrays.libcudnn != nothing
@info "Testing Flux/CUDNN"
include("cudnn.jl")
Expand Down
6 changes: 6 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ using Test
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end

@testset "onehotbatch indexing" begin
y = Flux.onehotbatch(ones(3), 1:10)
@test y[:,1] isa Flux.OneHotVector
@test y[:,:] isa Flux.OneHotMatrix
end

0 comments on commit 6e14bf9

Please sign in to comment.