diff --git a/src/onehot.jl b/src/onehot.jl index 0310181854..172591f6f9 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix +Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of) + A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} @@ -18,9 +20,12 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) +Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) + +Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] @@ -94,6 +99,8 @@ julia> onehotbatch([:b, :a, :b], [:a, :b, :c]) onehotbatch(ls, labels, unk...) = OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) +Base.argmax(xs::OneHotVector) = xs.ix + """ onecold(y[, labels = 1:length(y)]) @@ -114,8 +121,11 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) +onecold(y::OneHotMatrix, labels...) = + mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) + function argmax(xs...) - Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax) + Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) return onecold(xs...) end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 1748ed5e07..86e7f2f3b8 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -38,6 +38,12 @@ Flux.back!(sum(l)) end +@testset "onecold gpu" begin + y = Flux.onehotbatch(ones(3), 1:10) |> gpu; + @test Flux.onecold(y) isa CuArray + @test y[3,:] isa CuArray +end + if CuArrays.libcudnn != nothing @info "Testing Flux/CUDNN" include("cudnn.jl") diff --git a/test/onehot.jl b/test/onehot.jl index b0177f3e96..fa2f4d6031 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -11,3 +11,9 @@ using Test @test onecold(a, labels) == 'C' @test onecold(A, labels) == ['C', 'A', 'D'] end + +@testset "onehotbatch indexing" begin + y = Flux.onehotbatch(ones(3), 1:10) + @test y[:,1] isa Flux.OneHotVector + @test y[:,:] isa Flux.OneHotMatrix +end