From 35cd9761a875a2d299fa6cee028230c86cfaecb2 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 9 Feb 2019 22:32:02 +0530 Subject: [PATCH 1/8] adding tests --- src/onehot.jl | 24 +++++++++++++++++++++++- test/cuda/cuda.jl | 6 ++++++ test/onehot.jl | 6 ++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index cd29f14e9d..79679ff5aa 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix +Base.getindex(xs::OneHotVector, ::Colon) = xs + A::AbstractMatrix * b::OneHotVector = A[:, b.ix] struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool} @@ -22,6 +24,22 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) +Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j] + +Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs +Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) + +# handle special case when we want the whole column +function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray} + res = similar(xs, size(xs, 1), 1) + if length(ot) == size(xs, 1) + res = xs[:,i] + else + res = xs[1:length(ot),i] + end + res +end + A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -54,13 +72,17 @@ end onehotbatch(ls, labels, unk...) = OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) +Base.argmax(xs::OneHotVector) = xs.ix + onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) +onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data) + function argmax(xs...) - Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax) + Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) return onecold(xs...) end diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index f7a085031b..0704b98f5f 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -38,6 +38,12 @@ Flux.back!(sum(l)) end +@testset "onecold gpu" begin + y = Flux.onehotbatch(ones(3), 1:10) |> gpu; + @test Flux.onecold(y) isa CuArray + @test y[3,:] isa CuArray +end + if CuArrays.libcudnn != nothing @info "Testing Flux/CUDNN" include("cudnn.jl") diff --git a/test/onehot.jl b/test/onehot.jl index b0177f3e96..fa2f4d6031 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -11,3 +11,9 @@ using Test @test onecold(a, labels) == 'C' @test onecold(A, labels) == ['C', 'A', 'D'] end + +@testset "onehotbatch indexing" begin + y = Flux.onehotbatch(ones(3), 1:10) + @test y[:,1] isa Flux.OneHotVector + @test y[:,:] isa Flux.OneHotMatrix +end From 1ada9afe81fb8cdc494c293ac936f37c63339b42 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 9 Feb 2019 22:38:49 +0530 Subject: [PATCH 2/8] assert no scalar indexing for onecold --- test/cuda/cuda.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 0704b98f5f..a3ed62b883 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -39,6 +39,7 @@ Flux.back!(sum(l)) end @testset "onecold gpu" begin + CuArrays.allowscalar(false) y = Flux.onehotbatch(ones(3), 1:10) |> gpu; @test Flux.onecold(y) isa CuArray @test y[3,:] isa CuArray From d16ef75b1c5cb165113a7867e9aa847bdaceed6f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 11 Feb 2019 20:32:23 +0530 Subject: [PATCH 3/8] remove duplicate allowscalar call --- test/cuda/cuda.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index a3ed62b883..69730975ac 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -39,9 +39,9 @@ Flux.back!(sum(l)) end @testset "onecold gpu" begin - CuArrays.allowscalar(false) y = Flux.onehotbatch(ones(3), 1:10) |> gpu; @test Flux.onecold(y) isa CuArray + @test y[:,:] isa Flux.OneHotMatrix{<:CuArray} @test y[3,:] isa CuArray end From 2ec35861b5a6b1e0ac03cd42a43f3543452c595a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Mon, 11 Feb 2019 21:22:32 +0530 Subject: [PATCH 4/8] removing non-allocating functions and tests --- src/onehot.jl | 3 --- test/cuda/cuda.jl | 1 - test/onehot.jl | 1 - 3 files changed, 5 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 79679ff5aa..4294538870 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -24,9 +24,6 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) -Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j] - -Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) # handle special case when we want the whole column diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 69730975ac..0704b98f5f 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -41,7 +41,6 @@ end @testset "onecold gpu" begin y = Flux.onehotbatch(ones(3), 1:10) |> gpu; @test Flux.onecold(y) isa CuArray - @test y[:,:] isa Flux.OneHotMatrix{<:CuArray} @test y[3,:] isa CuArray end diff --git a/test/onehot.jl b/test/onehot.jl index fa2f4d6031..6e0057ebbc 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -15,5 +15,4 @@ end @testset "onehotbatch indexing" begin y = Flux.onehotbatch(ones(3), 1:10) @test y[:,1] isa Flux.OneHotVector - @test y[:,:] isa Flux.OneHotMatrix end From 6825639f793ec83decf8d43bebc3aedb3876df4e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 28 Feb 2019 09:17:18 +0530 Subject: [PATCH 5/8] mapreduce for onehotmatrix --- src/onehot.jl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/onehot.jl b/src/onehot.jl index 4294538870..ef326650d4 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -20,23 +20,12 @@ end Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) -Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i] +Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) -# handle special case when we want the whole column -function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray} - res = similar(xs, size(xs, 1), 1) - if length(ot) == size(xs, 1) - res = xs[:,i] - else - res = xs[1:length(ot),i] - end - res -end - A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)] Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...]) @@ -76,7 +65,8 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)] onecold(y::AbstractMatrix, labels...) = dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1) -onecold(y::OneHotMatrix, labels...) = map(x -> onecold(x, labels...), y.data) +onecold(y::OneHotMatrix, labels...) = + mapreduce(x -> Flux.onecold(x, labels...), |, y.data, dims = 2, init = 0) function argmax(xs...) Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax) From 4f1336905fa2c7dd3f178f302c02e34a92f8bd45 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 4 Apr 2019 19:16:14 +0530 Subject: [PATCH 6/8] fix colon indexing --- src/onehot.jl | 1 + test/onehot.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/onehot.jl b/src/onehot.jl index ef326650d4..488167e2dd 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -23,6 +23,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) +# Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data) diff --git a/test/onehot.jl b/test/onehot.jl index 6e0057ebbc..fa2f4d6031 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -15,4 +15,5 @@ end @testset "onehotbatch indexing" begin y = Flux.onehotbatch(ones(3), 1:10) @test y[:,1] isa Flux.OneHotVector + @test y[:,:] isa Flux.OneHotMatrix end From 5b9c53439b9f791c5dae8a4c28d029e7e9d9536c Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 4 Apr 2019 19:19:47 +0530 Subject: [PATCH 7/8] recreate OHV --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 488167e2dd..12a77ecd66 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -9,7 +9,7 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),) Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix -Base.getindex(xs::OneHotVector, ::Colon) = xs +Base.getindex(xs::OneHotVector, ::Colon) = OneHotVector(xs.ix, xs.of) A::AbstractMatrix * b::OneHotVector = A[:, b.ix] From 2952bcdab16b196bf79c7c71aebc2178b02feaaa Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 4 Apr 2019 19:28:40 +0530 Subject: [PATCH 8/8] fixes --- src/onehot.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 12a77ecd66..0cd145f867 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -23,7 +23,7 @@ Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data)) Base.getindex(xs::OneHotMatrix, i::Union{Integer, AbstractVector}, j::Integer) = xs.data[j][i] Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i] Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i]) -# Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) +Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy(xs.data)) Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)