Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fallback for reshape/cat OneHotArray #1459

Merged
merged 5 commits into from
Jan 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
reshape(parent(x).indices, x.dims[2:end])

const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

OneHotVector(idx, L) = OneHotArray(idx, L)
OneHotMatrix(indices, L) = OneHotArray(indices, L)

# use this type so reshaped arrays hit fast paths
# e.g. argmax
const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Man this N+1 is tripping me up, I would say we need to remove this soon. Where is it used exactly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could calculate var"N+1" during runtime?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like it either! It can't be done at runtime since N and var"N+1" are used in the type specification. N is used to specify the type of the index array, and var"N+1" is used to inherit from AbstractArray{Bool, var"N+1"}. Neither is evaluated at runtime.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could change it to another variable. I don't have strong feelings, but a part of me says that at least this naming signals the intent of the type parameter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be fair, I did mean we would have to switch it out during construction, because I don't think it's any better for dispatch to have to do checks on ints than types. To me it suggests that it is a preknown quantity so adding it to the type doesn't win us much.

Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
darsnack marked this conversation as resolved.
Show resolved Hide resolved

Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)

_onehotindex(x, i) = (x == i)
Expand All @@ -28,34 +39,30 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]

_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
if isone(dims)
return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), xs)
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims)
else
return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
end
end

Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)

Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()

Base.argmax(x::OneHotArray; dims = Colon()) =
(dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
argmax(convert(_onehot_bool_type(x), x); dims = dims)
Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)

"""
onehot(l, labels[, unk])
Expand Down Expand Up @@ -135,11 +142,18 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
end

_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = x.indices
function _fast_argmax(x::OneHotLike)
if _isonehot(x)
return _indices(x)
else
return _fast_argmax(convert(_onehot_bool_type(x), x))
end
end
darsnack marked this conversation as resolved.
Show resolved Hide resolved

@nograd OneHotArray, onecold, onehot, onehotbatch

function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
darsnack marked this conversation as resolved.
Show resolved Hide resolved
return A[:, onecold(B)]
end
32 changes: 22 additions & 10 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
end

@testset "OneHotArray" begin
using Flux: OneHotArray, OneHotVector, OneHotMatrix
using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike

ov = OneHotVector(rand(1:10), 10)
om = OneHotMatrix(rand(1:10, 5), 10)
Expand Down Expand Up @@ -74,27 +74,39 @@ end
@testset "Concatenating" begin
# vector cat
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
@test_throws ArgumentError vcat(ov, ov)
@test vcat(ov, ov) == vcat(collect(ov), collect(ov))
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)

# matrix cat
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
@test_throws ArgumentError vcat(om, om)
@test vcat(om, om) == vcat(collect(om), collect(om))
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)

# array cat
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
@test_throws ArgumentError cat(oa, oa; dims = 1)
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)
end

@testset "Base.reshape" begin
# reshape test
@test reshape(oa, 10, 25) isa OneHotArray
@test reshape(oa, 10, :) isa OneHotArray
@test reshape(oa, :, 25) isa OneHotArray
@test_throws ArgumentError reshape(oa, 50, :)
@test_throws ArgumentError reshape(oa, 5, 10, 5)
@test reshape(oa, (10, 25)) isa OneHotArray
@test reshape(oa, 10, 25) isa OneHotLike
@test reshape(oa, 10, :) isa OneHotLike
@test reshape(oa, :, 25) isa OneHotLike
@test reshape(oa, 50, :) isa OneHotLike
@test reshape(oa, 5, 10, 5) isa OneHotLike
@test reshape(oa, (10, 25)) isa OneHotLike

@testset "w/ cat" begin
r = reshape(oa, 10, :)
@test hcat(r, r) isa OneHotArray
@test vcat(r, r) isa Array{Bool}
end

@testset "w/ argmax" begin
r = reshape(oa, 10, :)
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
@test Flux._fast_argmax(r) == collect(reshape(oa.indices, :))
end
end

@testset "Base.argmax" begin
Expand Down