diff --git a/src/onehot.jl b/src/onehot.jl index b879e5cda2..e30a0bb321 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -5,8 +5,8 @@ struct OneHotArray{T<:Integer, L, N, var"N+1", I<:Union{T, AbstractArray{T, N}}} indices::I end OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1, I}(indices) -OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}(indices) -OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices) +OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices) +OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices) _indices(x::OneHotArray) = x.indices _indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) = @@ -75,6 +75,12 @@ end Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2) Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1) +# optimized concatenation for matrices and vectors of same parameters +Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} = + OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) +Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} = + OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) + batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L) Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L) diff --git a/test/onehot.jl b/test/onehot.jl index ce30534ec9..5991d7c521 100644 --- a/test/onehot.jl +++ b/test/onehot.jl @@ -38,7 +38,9 @@ end using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike ov = OneHotVector(rand(1:10), 10) + ov2 = OneHotVector(rand(1:11), 11) om = OneHotMatrix(rand(1:10, 5), 10) + om2 = OneHotMatrix(rand(1:11, 5), 11) oa = OneHotArray(rand(1:10, 5, 5), 10) # sizes @@ -74,17 +76,24 @@ end @testset "Concatenating" begin # vector cat @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10) + @test hcat(ov, ov) isa OneHotMatrix @test vcat(ov, ov) == vcat(collect(ov), collect(ov)) @test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10) # matrix cat @test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10) + @test hcat(om, om) isa OneHotMatrix @test vcat(om, om) == vcat(collect(om), collect(om)) @test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10) # array cat @test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10) + @test cat(oa, oa; dims = 3) isa OneHotArray @test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1) + + # proper error handling of inconsistent sizes + @test_throws DimensionMismatch hcat(ov, ov2) + @test_throws DimensionMismatch hcat(om, om2) end @testset "Base.reshape" begin