Skip to content

Commit

Permalink
Merge pull request #23 from JuliaGPU/tb/wrappers
Browse files Browse the repository at this point in the history
Rework wrapper type
  • Loading branch information
maleadt authored Jun 2, 2020
2 parents db56c24 + 3f39d53 commit a62a256
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Adapt"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.1.0"
version = "2.0.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
42 changes: 1 addition & 41 deletions src/Adapt.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,5 @@
module Adapt

using LinearAlgebra


export WrappedArray

# database of array wrappers
#
# LHS entries are a symbolic type with AT for the array type
#
# RHS entries consist of a closure to reconstruct the wrapper, with as arguments
# a wrapper instance and mutator function to apply to the inner array
const wrappers = (
:(SubArray{<:Any,<:Any,AT}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,AT})=> (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{<:Any,<:Any,AT,<:Any}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(LinearAlgebra.Adjoint{<:Any,AT}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{<:Any,AT}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
:(LinearAlgebra.UnitLowerTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))),
:(LinearAlgebra.UpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))),
:(LinearAlgebra.UnitUpperTriangular{<:Any,AT}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))),
:(LinearAlgebra.Tridiagonal{<:Any,AT}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)),
)

"""
WrappedArray{AT}
Union-type that encodes all array wrappers known by Adapt.jl.
Only use this type for dispatch purposes. To convert instances of an array wrapper, use
[`adapt`](@ref).
"""
const WrappedArray{AT} = @eval Union{$([W for (W,ctor) in Adapt.wrappers]...)} where AT

# XXX: this Unions is a hack, and only works with one level of wrray wrappers. ideally, Base
# would have `Transpose <: WrappedArray <: AbstractArray` and we could define methods
# in terms of `Union{SomeArray, WrappedArray{<:Any, <:SomeArray}}`.
# https://github.com/JuliaLang/julia/pull/31563


export adapt

"""
Expand Down Expand Up @@ -84,5 +43,6 @@ adapt_structure(to, x) = adapt_storage(to, x)
adapt_storage(to, x) = x

include("base.jl")
include("wrappers.jl")

end # module
12 changes: 0 additions & 12 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
# predefined adaptors for working with types from the Julia standard library

## Base

adapt_structure(to, xs::Union{Tuple,NamedTuple}) = map(x->adapt(to,x), xs)


## Array wrappers

permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm

for (W, ctor) in wrappers
mut = :(A -> adapt(to, A))
@eval adapt_structure(to, wrapper::$W where {AT <: Any}) = $ctor(wrapper, $mut)
end


## Broadcast

import Base.Broadcast: Broadcasted, Extruded
Expand Down
65 changes: 65 additions & 0 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# adaptors and type aliases for working with array wrappers

using LinearAlgebra

permutation(::PermutedDimsArray{T,N,perm}) where {T,N,perm} = perm


export WrappedArray

# database of array wrappers
const _wrappers = (
:(SubArray{T,N,<:Src}) => (A,mut)->SubArray(mut(parent(A)), mut(parentindices(A))),
:(PermutedDimsArray{T,N,<:Any,<:Any,<:Src}) => (A,mut)->PermutedDimsArray(mut(parent(A)), permutation(A)),
:(Base.ReshapedArray{T,N,<:Src}) => (A,mut)->Base.reshape(mut(parent(A)), size(A)),
:(Base.ReinterpretArray{T,N,<:Src}) => (A,mut)->Base.reinterpret(eltype(A), mut(parent(A))),
:(LinearAlgebra.Adjoint{T,<:Dst}) => (A,mut)->LinearAlgebra.adjoint(mut(parent(A))),
:(LinearAlgebra.Transpose{T,<:Dst}) => (A,mut)->LinearAlgebra.transpose(mut(parent(A))),
:(LinearAlgebra.LowerTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.LowerTriangular(mut(parent(A))),
:(LinearAlgebra.UnitLowerTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UnitLowerTriangular(mut(parent(A))),
:(LinearAlgebra.UpperTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UpperTriangular(mut(parent(A))),
:(LinearAlgebra.UnitUpperTriangular{T,<:Dst}) => (A,mut)->LinearAlgebra.UnitUpperTriangular(mut(parent(A))),
:(LinearAlgebra.Diagonal{T,<:Dst}) => (A,mut)->LinearAlgebra.Diagonal(mut(parent(A))),
:(LinearAlgebra.Tridiagonal{T,<:Dst}) => (A,mut)->LinearAlgebra.Tridiagonal(mut(A.dl), mut(A.d), mut(A.du)),
)

for (W, ctor) in _wrappers
mut = :(A -> adapt(to, A))
@eval adapt_structure(to, wrapper::$W where {T,N,Src,Dst}) = $ctor(wrapper, $mut)
end

"""
WrappedArray{T,N,Src,Dst}
Union-type that encodes all array wrappers known by Adapt.jl. Typevars `T` and `N` encode
the type and dimensionality of the resulting container.
Two additional typevars are used to encode the parent array type: `Src` when the wrapper
uses the parent array as a source, but changes its properties (e.g.
`SubArray{T,1,Array{T,2}` changes `N`), and `Dst` when those properties are copied and thus
are identical to the destination wrapper's properties (e.g. `Transpose{T,Array{T,N}}` has
the same dimensionality as the inner array). When creating an alias for this type, e.g.
`WrappedSomeArray{T,N} = WrappedArray{T,N,...}` the `Dst` typevar should typically be set to
`SomeArray{T,N}` while `Src` should be more lenient, e.g., `SomeArray`.
Only use this type for dispatch purposes. To convert instances of an array wrapper, use
[`adapt`](@ref).
"""
const WrappedArray{T,N,Src,Dst} = @eval Union{$([W for (W,ctor) in Adapt._wrappers]...)} where {T,N,Src,Dst}

# XXX: this Union is a hack:
# - only works with one level of wrappi ng
# - duplication of Src and Dst typevars (without it, we get `WrappedGPUArray{T,N,AT{T,N}}`
# not matching `SubArray{T,1,AT{T,2}}`, and leaving out `{T,N}` makes it impossible to
# match e.g. `Diagonal{T,AT}` and get `N` out of that). alternatively, computed types
# would make it possible to do `SubArray{T,N,<:AT.name.wrapper}` or `Diagonal{T,AT{T,N}}`.
#
# ideally, Base would have, e.g., `Transpose <: WrappedArray`, and we could use
# `Union{SomeArray, WrappedArray{<:Any, <:SomeArray}}` for dispatch.
# https://github.com/JuliaLang/julia/pull/31563

# accessors for extracting information about the wrapper type
Base.ndims(::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(N) ? N : ndims(Dst)
Base.eltype(::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(T) ? T : ndims(Dst)
Base.parent(W::Type{<:WrappedArray{T,N,Src,Dst}}) where {T,N,Src,Dst} = @isdefined(Dst) ? Dst.name.wrapper : Src.name.wrapper

58 changes: 34 additions & 24 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,67 +5,77 @@ using Test
# custom array type

struct CustomArray{T,N} <: AbstractArray{T,N}
arr::AbstractArray
arr::Array
end

CustomArray(x::AbstractArray{T,N}) where {T,N} = CustomArray{T,N}(x)
Adapt.adapt_storage(::Type{<:CustomArray}, xs::AbstractArray) = CustomArray(xs)
CustomArray(x::Array{T,N}) where {T,N} = CustomArray{T,N}(x)
Adapt.adapt_storage(::Type{<:CustomArray}, xs::Array) = CustomArray(xs)

Base.size(x::CustomArray, y...) = size(x.arr, y...)
Base.getindex(x::CustomArray, y...) = getindex(x.arr, y...)


const val = CustomArray{Float64,2}(rand(2,2))
const mat = CustomArray{Float64,2}(rand(2,2))
const vec = CustomArray{Float64,1}(rand(2))

macro test_adapt(to, src, dst)
quote
@test adapt($to, $src) == $dst
@test typeof(adapt($to, $src)) == typeof($dst)
end
end


# basic adaption
@test adapt(CustomArray, val.arr) == val
@test adapt(CustomArray, val.arr) isa CustomArray
@test_adapt CustomArray mat.arr mat

# idempotency
@test adapt(CustomArray, val) == val
@test adapt(CustomArray, val) isa CustomArray
@test_adapt CustomArray mat mat

# custom wrapper
struct Wrapper{T}
arr::T
end
Wrapper(x::T) where T = Wrapper{T}(x)
Adapt.adapt_structure(to, xs::Wrapper) = Wrapper(adapt(to, xs.arr))
@test adapt(CustomArray, Wrapper(val.arr)) == Wrapper(val)
@test adapt(CustomArray, Wrapper(val.arr)) isa Wrapper{<:CustomArray}
@test_adapt CustomArray Wrapper(mat.arr) Wrapper(mat)


## base wrappers

@test @inferred(adapt(nothing, NamedTuple())) == NamedTuple()
@test adapt(CustomArray, (val.arr,)) == (val,)
@test_adapt CustomArray (mat.arr,) (mat,)
@test @allocated(adapt(nothing, ())) == 0
@test @allocated(adapt(nothing, (1,))) == 0
@test @allocated(adapt(nothing, (1,2,3,4,5,6,7,8,9,10))) == 0

@test adapt(CustomArray, (a=val.arr,)) == (a=val,)
@test_adapt CustomArray (a=mat.arr,) (a=mat,)

@test adapt(CustomArray, view(val.arr,:,:)) == view(val,:,:)
@test_adapt CustomArray view(mat.arr,:,:) view(mat,:,:)
const inds = CustomArray{Int,1}([1,2])
@test adapt(CustomArray, view(val.arr,inds.arr,:)) == view(val,inds,:)
@test_adapt CustomArray view(mat.arr,inds.arr,:) view(mat,inds,:)

# NOTE: manual creation of PermutedDimsArray because permutedims collects
@test adapt(CustomArray, PermutedDimsArray(val.arr,(2,1))) == PermutedDimsArray(val,(2,1))
@test_adapt CustomArray PermutedDimsArray(mat.arr,(2,1)) PermutedDimsArray(mat,(2,1))

# NOTE: manual creation of ReshapedArray because Base.Array has an optimized `reshape`
@test adapt(CustomArray, Base.ReshapedArray(val.arr,(2,2),())) == reshape(val,(2,2))
@test_adapt CustomArray Base.ReshapedArray(mat.arr,(2,2),()) reshape(mat,(2,2))


using LinearAlgebra

@test adapt(CustomArray, val.arr') == val'
@test_adapt CustomArray mat.arr' mat'

@test_adapt CustomArray transpose(mat.arr) transpose(mat)

@test adapt(CustomArray, transpose(val.arr)) == transpose(val)
@test_adapt CustomArray LowerTriangular(mat.arr) LowerTriangular(mat)
@test_adapt CustomArray UnitLowerTriangular(mat.arr) UnitLowerTriangular(mat)
@test_adapt CustomArray UpperTriangular(mat.arr) UpperTriangular(mat)
@test_adapt CustomArray UnitUpperTriangular(mat.arr) UnitUpperTriangular(mat)

@test adapt(CustomArray, LowerTriangular(val.arr)) == LowerTriangular(val)
@test adapt(CustomArray, UnitLowerTriangular(val.arr)) == UnitLowerTriangular(val)
@test adapt(CustomArray, UpperTriangular(val.arr)) == UpperTriangular(val)
@test adapt(CustomArray, UnitUpperTriangular(val.arr)) == UnitUpperTriangular(val)
@test_adapt CustomArray Diagonal(vec.arr) Diagonal(vec)

@test adapt(CustomArray, Diagonal(val.arr)) == Diagonal(val)
@test adapt(CustomArray, Tridiagonal(val.arr)) == Tridiagonal(val)
const dl = CustomArray{Float64,1}(rand(2))
const du = CustomArray{Float64,1}(rand(2))
const d = CustomArray{Float64,1}(rand(3))
@test_adapt CustomArray Tridiagonal(dl.arr, d.arr, du.arr) Tridiagonal(dl, d, du)

2 comments on commit a62a256

@maleadt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/16140

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.0.0 -m "<description of version>" a62a2568f1199d0e7b154eb6001afa629d31038e
git push origin v2.0.0

Please sign in to comment.