
variable axis on which the "one hot" property holds #35

Open
nomadbl opened this issue May 23, 2023 · 3 comments

Comments

@nomadbl

nomadbl commented May 23, 2023

Motivation and description

I am working on a layer that produces one-hot outputs, so I am looking into using OneHotArrays.jl.
My gripe is that the type currently only supports one-hot vectors extending along the first axis.

I thought I'd write up my thoughts and possible implementations of a variable axis, to get some feedback and context from the maintainers and other users here (I am very new to Julia and Flux, coming from Python).

Possible Implementation

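Implementation path 1 (WIP): change the constructors, size and getindex:

```julia
# `axis` is the position of the one-hot (label) dimension of length `nlabels`
struct OneHotArray{T<:Integer,N,var"N+1",I<:Union{T,AbstractArray{T,N}}} <: AbstractArray{Bool,var"N+1"}
  indices::I
  nlabels::Int
  axis::Int
  function OneHotArray{T,N,M,I}(indices, nlabels::Int, axis::Int) where {T,N,M,I}
    1 <= axis <= M || throw(ArgumentError("axis must be in 1:$M, got $axis"))
    return new{T,N,M,I}(indices, nlabels, axis)
  end
end
OneHotArray{T,N,I}(indices, L::Int, axis::Int=1) where {T,N,I} = OneHotArray{T,N,N + 1,I}(indices, L, axis)
OneHotArray(indices::T, L::Int, axis::Int=1) where {T<:Integer} = OneHotArray{T,0,1,T}(indices, L, axis)
OneHotArray(indices::I, L::Int, axis::Int=1) where {T,N,I<:AbstractArray{T,N}} = OneHotArray{T,N,N + 1,I}(indices, L, axis)

# Insert the label dimension into size(indices) at position `axis`
# (tuple slicing also works for scalar indices, whose size is ())
function Base.size(x::OneHotArray)
  s = size(x.indices)
  return (s[1:x.axis-1]..., x.nlabels, s[x.axis:end]...)
end

# Defining getindex for exactly N+1 Ints lets Base's AbstractArray fallbacks
# handle linear indexing and other index counts, so no manual length check
function Base.getindex(x::OneHotArray{<:Any,<:Any,M}, I::Vararg{Int,M}) where {M}
  @boundscheck checkbounds(x, I...)
  # Drop the label coordinate; the remaining coordinates address `indices`
  Ip = (I[1:x.axis-1]..., I[x.axis+1:end]...)
  return x.indices[Ip...] == I[x.axis]
end
```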

The idea is to maintain the sparse nature of the representation for later optimized multiplications, backprop, etc.
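To make the sparse-representation argument concrete, here is a minimal Base-only sketch (no OneHotArrays.jl needed; the names labels, L, W and Y are illustrative). It checks that multiplying by a materialised one-hot matrix gives the same result as indexing the columns of W with the stored labels, which is why keeping only the integer indices can pay off for *.

```julia
using LinearAlgebra  # generic matrix multiplication methods

labels = [2, 3, 1, 2]            # one label per column
L = 3                            # number of classes
W = [1.0 2.0 3.0; 4.0 5.0 6.0]   # 2×L weights

# Dense reference: materialise the L×n Bool matrix and multiply
Y = zeros(Bool, L, length(labels))
for (j, l) in enumerate(labels)
    Y[l, j] = true
end
dense = W * Y

# Sparse route: column j of W*Y is column labels[j] of W, so just index
indexed = W[:, labels]
```

The indexed route touches only n columns of W instead of doing a full L×n multiply-add.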

While working on this I also hit upon path 2: reuse all the original code, and use the new axis parameter to permute the underlying (label-first) array appropriately before computations.
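Path 2 can be sketched with Base's PermutedDimsArray, which permutes dimensions lazily without copying. In this sketch a dense BitArray stands in for the package's lazy type, and onehot_labelfirst and move_label_axis are hypothetical helper names, not OneHotArrays.jl functions.

```julia
# Label-first dense one-hot: out[labels[c], c...] = true
function onehot_labelfirst(labels::AbstractArray{<:Integer}, L::Int)
    out = falses(L, size(labels)...)
    for c in CartesianIndices(labels)
        out[labels[c], Tuple(c)...] = true
    end
    return out
end

# Lazily move the label axis from position 1 to position k (no copy):
# result dim k is the parent's dim 1; the other dims keep their order
function move_label_axis(A::AbstractArray, k::Int)
    n = ndims(A)
    1 <= k <= n || throw(ArgumentError("k must be in 1:$n"))
    perm = ((2:k)..., 1, ((k + 1):n)...)
    return PermutedDimsArray(A, perm)
end

labels = [3 1 2 3; 1 1 2 2]                             # 2×4 labels in 1:3
B = move_label_axis(onehot_labelfirst(labels, 3), 2)    # size (2, 3, 4)
```

Because the wrapper is lazy, every B[i, l, j] is forwarded to the parent's getindex with the indices permuted back.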

I expect to open a PR for this soon, but I'd love to hear your thoughts: do you think the first approach is better (more memory- and compute-efficient)? It is probably also harder to maintain and test.

@mcabbott
Member

Maybe the first question should be: What comes after this layer?

In this package, I think the efficient methods are that *(::Matrix, ::OneHotMatrix) is just indexing, used by Flux.Embedding, and argmax(::OneHotMatrix) just reads the indices. Common uses like Flux.crossentropy(::Matrix, ::OneHotMatrix) don't do anything special; they use broadcasting, which uses getindex. They could specialise, but they are never the bottleneck.
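A small dense sketch of the difference between the generic broadcasting path and a specialised one. loss_broadcast only mimics the shape of a crossentropy computation (it omits Flux's epsilon and aggregation options, so it is not Flux.crossentropy itself), and loss_indexed is a hypothetical specialisation that reads only the stored labels. The last line shows why argmax is cheap for a one-hot type: the answer is the labels themselves.

```julia
labels = [2, 1, 3]
L = 3
Y = zeros(Bool, L, length(labels))   # dense stand-in for a OneHotMatrix
for (j, l) in enumerate(labels)
    Y[l, j] = true
end
p = [0.2 0.5 0.1; 0.7 0.3 0.2; 0.1 0.2 0.7]   # columns are probability vectors

# Generic path: broadcasting reads every element of Y
loss_broadcast = -sum(Y .* log.(p)) / size(p, 2)

# Hypothetical specialised path: one lookup per column
loss_indexed = -sum(log(p[labels[j], j]) for j in eachindex(labels)) / length(labels)

# argmax over the label axis recovers the labels
recovered = [argmax(view(Y, :, j)) for j in axes(Y, 2)]
```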

A variant of path 2 is also just to wrap PermutedDimsArray(OneHotArray(.... If the next operation is broadcasting, this should be about as good. Downstream operations could also specialise on e.g. PermutedDimsArray{..., (2,1,3), ..., <:OneHotArray} to target particular dims.
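To illustrate dispatching on a permuted wrapper, here is a sketch with a hypothetical lazy type MiniOneHot standing in for OneHotArray (it is not the package's type). Base's PermutedDimsArray carries the permutation and its inverse as type parameters (PermutedDimsArray{T,N,perm,iperm,AA}), so a method can target one specific layout; the alias LabelSecond and the function labels_along_dim2 are illustrative names.

```julia
# Hypothetical minimal lazy one-hot array with the label dimension first
struct MiniOneHot{N,M} <: AbstractArray{Bool,M}   # M == N + 1
    labels::Array{Int,N}
    nlabels::Int
end
MiniOneHot(labels::Array{Int,N}, L::Int) where {N} = MiniOneHot{N,N + 1}(labels, L)

Base.size(x::MiniOneHot) = (x.nlabels, size(x.labels)...)
function Base.getindex(x::MiniOneHot{N}, i::Int, rest::Vararg{Int,N}) where {N}
    @boundscheck checkbounds(x, i, rest...)
    return x.labels[rest...] == i
end

# The layout downstream code could target: label axis moved to dim 2
const LabelSecond = PermutedDimsArray{Bool,3,(2,1,3),(2,1,3),<:MiniOneHot}

# Specialised method: the labels are stored, so no scan is needed
labels_along_dim2(P::LabelSecond) = parent(P).labels

# Generic fallback for any 3-d Bool array: position of the `true` along dim 2
labels_along_dim2(A::AbstractArray{Bool,3}) = dropdims(map(c -> c[2], argmax(A; dims=2)); dims=2)

labels = [3 1; 2 2; 1 3]                                  # 3×2 labels in 1:3
P = PermutedDimsArray(MiniOneHot(labels, 3), (2, 1, 3))   # size (3, 3, 2)
```

Both methods return the same labels; only the specialised one avoids reading every element.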

@nomadbl
Author

nomadbl commented May 23, 2023

I expect this layer to be followed by *(...) primarily. I'm not sure exactly what you mean by broadcasting, could you link to a function definition or explain?

I'll try both ways, since I'm nearly done with path 1 and path 2 seems simple at first glance.

For correctness tests, I'm planning to use the Flux and OneHotArrays tests as a first step and add tests as necessary.
Do you have some kind of benchmark that would be useful to compare path 1 vs 2 in terms of speed?
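One way to test either path for correctness is to compare it against a plain dense reference. dense_onehot below is a hypothetical helper, not part of OneHotArrays.jl; it returns a BitArray with the label dimension inserted at position axis.

```julia
# Dense reference: label dimension of length L inserted at position `axis`
function dense_onehot(labels::AbstractArray{<:Integer,N}, L::Int, axis::Int=1) where {N}
    1 <= axis <= N + 1 || throw(ArgumentError("axis must be in 1:$(N + 1), got $axis"))
    s = size(labels)
    out = falses(s[1:axis-1]..., L, s[axis:end]...)
    for c in CartesianIndices(labels)
        t = Tuple(c)
        # Splice the label into the coordinate tuple at position `axis`
        out[t[1:axis-1]..., labels[c], t[axis:end]...] = true
    end
    return out
end
```

With path 1, a check could then look like @test Array(OneHotArray(labels, L, axis)) == dense_onehot(labels, L, axis), where the three-argument constructor is the one proposed in this issue rather than an existing package API.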

Thanks for the help :)

@mcabbott
Member

mcabbott commented May 24, 2023

> expect this layer to be followed by *(...) primarily.

Since * is only for matrices & vectors, perhaps you just want transpose(onehotbatch([1,1,2], 1:4))? That doesn't specialise but it could:

```julia
julia> @which rand(3,4) * onehotbatch([1,1,2], 1:4)
*(A::AbstractMatrix, B::Union{OneHotArray{var"#s13", 1, var"N+1", I}, Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{var"#s13", <:Any, <:Any, I}}} where {var"#s13", var"N+1", I})
     @ OneHotArrays ~/.julia/packages/OneHotArrays/T3yiq/src/linalg.jl:7

julia> @which transpose(onehotbatch([1,1,2], 1:4)) * rand(4,3)
*(A::AbstractMatrix, B::AbstractMatrix)
     @ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:108
```
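Since transpose(onehotbatch(...)) * B falls back to the generic AbstractMatrix method, here is a sketch of what a specialisation could exploit: row j of the product is row labels[j] of B. The dense Bool matrix Y stands in for the output of onehotbatch, and indexed is a hypothetical fast path, not an existing method.

```julia
using LinearAlgebra  # transpose and generic matrix multiplication

labels = [1, 1, 2]
L = 4
Y = zeros(Bool, L, length(labels))   # dense stand-in for onehotbatch(labels, 1:L)
for (j, l) in enumerate(labels)
    Y[l, j] = true
end
B = reshape(collect(1.0:12.0), L, 3)   # L×3

generic = transpose(Y) * B   # n×3 via the generic method
indexed = B[labels, :]       # row j is row labels[j] of B
```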

> I'm not sure exactly what you mean by broadcasting

Operations like .* are broadcasting; see e.g. this or the manual.
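A few broadcasting examples using only Base: a dotted operator applies elementwise, singleton dimensions are expanded to match, and nested dotted calls fuse into a single loop. For a custom AbstractArray without its own broadcast style, broadcasting reads elements through getindex.

```julia
a = [1, 2, 3]
b = [10 20]                # 1×2 row matrix

elementwise = a .* a       # elementwise product of equal shapes
expanded = a .* b          # 3×2: the singleton dimensions are expanded
fused = sqrt.(a .* a)      # dotted calls fuse into one loop, no temporary
```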
