Skip to content

Clean @pure issues in "dimensions.jl" and "static.jl" #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.0.2"
version = "3.1"

[deps]
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Expand Down
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ If unknown, it returns `nothing`.

## contiguous_axis_indicator(::Type{T})

Returns a tuple of boolean `Val`s indicating whether that axis is contiguous.
Returns a tuple of boolean `StaticBool`s indicating whether that axis is contiguous.

## contiguous_batch_size(::Type{T})

Expand All @@ -167,7 +167,7 @@ Returns the rank of each stride.

## is_column_major(A)

Returns a `Val{true}()` if `A` is column major, and a `Val{false}()` otherwise.`
Returns a `True` if `A` is column major, and a `True/False` otherwise.

## dense_dims(::Type{T})
Returns a tuple of indicators for whether each axis is dense.
Expand Down Expand Up @@ -208,6 +208,10 @@ For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1)

Is the function `f` whitelisted for `LoopVectorization.@avx`?

## static(x)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we at the point yet where we should just create a Static.jl and depend on it?

We have static Ints, Bools, and now Symbols.
Octavian also has staticfloats.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the same thing. There is already StaticNumbers.jl but I think there is a lot of utility in a very restricted set of static types. I think there has been a tendency with projects that deal with "staticness" to just solve everything with a new static type, which creates another problem.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do think a package like Static.jl should go?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just created SciML/Static.jl and added you as a contributor.

I'm not sure about ArrayInterface's functionality vs StaticNumbers.jl, but ArrayInterface's has definitely been growing.
I should've added static ranges to the list of things we support here.

But I do agree on the idea of having a narrower scope / focusing on being a lightweight dependency that other packages that want such functionality (including of course ArrayInterface.jl) can cheaply depend upon.

Returns a static form of `x`. If `x` is already in a static form then `x` is returned. If
there is no static alternative for `x` then an error is thrown.

## StaticInt(N::Int)

Creates a static integer with value known at compile time. It is a number,
Expand Down
11 changes: 10 additions & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@ using IfElse
using Requires
using LinearAlgebra
using SparseArrays
using Base.Cartesian

using Base: @pure, @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray
using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray

@static if VERSION >= v"1.7.0-DEV.421"
using Base: @aggressive_constprop
else
macro aggressive_constprop(ex)
ex
end
end

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
Expand Down
195 changes: 93 additions & 102 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A
end
out
end
function from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I}
return _val_to_static(Val(I))
end
from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I))

"""
to_parent_dims(::Type{T}) -> Bool
Expand All @@ -51,7 +49,7 @@ Returns the mapping from child dimensions to parent dimensions.
to_parent_dims(x) = to_parent_dims(typeof(x))
to_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
to_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = _val_to_static(Val(I))
to_parent_dims(::Type{<:PermutedDimsArray{T,N,I}}) where {T,N,I} = static(Val(I))
to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I)
@generated function _to_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
out = Expr(:tuple)
Expand Down Expand Up @@ -79,37 +77,44 @@ function has_dimnames(::Type{T}) where {T}
end
end

# this takes the place of dimension names that aren't defined
const SUnderscore = StaticSymbol(:_)

"""
dimnames(::Type{T}) -> Tuple{Vararg{Symbol}}
dimnames(::Type{T}, d) -> Symbol
dimnames(::Type{T}) -> Tuple{Vararg{StaticSymbol}}
dimnames(::Type{T}, dim) -> StaticSymbol

Return the names of the dimensions for `x`.
"""
@inline dimnames(x) = dimnames(typeof(x))
@inline dimnames(x, i::Integer) = dimnames(typeof(x), i)
@inline dimnames(::Type{T}, d::Integer) where {T} = getfield(dimnames(T), to_dims(T, d))
@inline function dimnames(::Type{T}) where {T}
if parent_type(T) <: T
return ntuple(i -> :_, Val(ndims(T)))
@inline dimnames(x, dim::Int) = dimnames(typeof(x), dim)
@inline dimnames(x, dim::StaticInt) = dimnames(typeof(x), dim)
@inline function dimnames(::Type{T}, ::StaticInt{dim}) where {T,dim}
if ndims(T) < dim
return SUnderscore
else
return dimnames(parent_type(T))
return getfield(dimnames(T), dim)
end
end
@inline function dimnames(::Type{T}) where {T<:Union{Transpose,Adjoint}}
return _transpose_dimnames(Val(dimnames(parent_type(T))))
@inline function dimnames(::Type{T}, dim::Int) where {T}
if ndims(T) < dim
return SUnderscore
else
return getfield(dimnames(T), dim)
end
end
# inserting the Val here seems to help inferability; I got a test failure without it.
function _transpose_dimnames(::Val{S}) where {S}
if length(S) == 1
(:_, first(S))
elseif length(S) == 2
(last(S), first(S))
@inline function dimnames(::Type{T}) where {T}
if parent_type(T) <: T
return ntuple(_ -> SUnderscore, Val(ndims(T)))
else
throw("Can't transpose $S of dim $(length(S)).")
return dimnames(parent_type(T))
end
end
@inline _transpose_dimnames(x::Tuple{Symbol,Symbol}) = (last(x), first(x))
@inline _transpose_dimnames(x::Tuple{Symbol}) = (:_, first(x))
@inline function dimnames(::Type{T}) where {T<:Union{Adjoint,Transpose}}
_transpose_dimnames(dimnames(parent_type(T)))
end
@inline _transpose_dimnames(x::Tuple{Any,Any}) = (last(x), first(x))
@inline _transpose_dimnames(x::Tuple{Any}) = (SUnderscore, first(x))

@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}}
return map(i -> dimnames(parent_type(T), i), I)
Expand All @@ -123,7 +128,7 @@ end
for i in 1:length(I)
if I[i] > 0
if nl < i
push!(e.args, QuoteNode(:_))
push!(e.args, :(ArrayInterface.SUnderscore))
else
push!(e.args, QuoteNode(L[i]))
end
Expand All @@ -132,83 +137,79 @@ end
return e
end

"""
to_dims(x[, d])
_to_int(x::Integer) = Int(x)
_to_int(x::StaticInt) = x

This returns the dimension(s) of `x` corresponding to `d`.
"""
to_dims(x, d) = to_dims(dimnames(x), d)
to_dims(x::Tuple{Vararg{Symbol}}, d::Integer) = Int(d)
to_dims(x::Tuple{Vararg{Symbol}}, d::Colon) = d # `:` is the default for most methods that take `dims`
@inline to_dims(x::Tuple{Vararg{Symbol}}, d::Tuple) = map(i -> to_dims(x, i), d)
@inline function to_dims(x::Tuple{Vararg{Symbol}}, d::Symbol)::Int
i = _sym_to_dim(x, d)
if i === 0
throw(ArgumentError("Specified name ($(repr(d))) does not match any dimension name ($(x))"))
end
return i
end
Base.@pure function _sym_to_dim(x::Tuple{Vararg{Symbol,N}}, sym::Symbol) where {N}
for i in 1:N
getfield(x, i) === sym && return i
end
return 0
function no_dimname_error(@nospecialize(x), @nospecialize(dim))
throw(ArgumentError("($(repr(dim))) does not correspond to any dimension of ($(x))"))
end

"""
tuple_issubset
to_dims(::Type{T}, dim) -> Integer

A version of `issubset` sepecifically for `Tuple`s of `Symbol`s, that is `@pure`.
This helps it get optimised out of existance. It is less of an abuse of `@pure` than
most of the stuff for making `NamedTuples` work.
This returns the dimension(s) of `x` corresponding to `d`.
"""
Base.@pure function tuple_issubset(
lhs::Tuple{Vararg{Symbol,N}}, rhs::Tuple{Vararg{Symbol,M}}
) where {N,M}
N <= M || return false
for a in lhs
found = false
for b in rhs
found |= a === b
end
found || return false
end
return true
to_dims(x, dim) = to_dims(typeof(x), dim)
to_dims(::Type{T}, dim::Integer) where {T} = _to_int(dim)
to_dims(::Type{T}, dim::Colon) where {T} = dim
function to_dims(::Type{T}, dim::StaticSymbol) where {T}
i = find_first_eq(dim, dimnames(T))
i === nothing && no_dimname_error(T, dim)
return i
end
@inline function to_dims(::Type{T}, dim::Symbol) where {T}
i = find_first_eq(dim, Symbol.(dimnames(T)))
i === nothing && no_dimname_error(T, dim)
return i
end
to_dims(::Type{T}, dims::Tuple) where {T} = map(i -> to_dims(T, i), dims)

"""
order_named_inds(Val(names); kwargs...)
order_named_inds(Val(names), namedtuple)
#=
order_named_inds(names, namedtuple)
order_named_inds(names, subnames, inds)

Returns the tuple of index values for an array with `names`, when indexed by keywords.
Any dimensions not fixed are given as `:`, to make a slice.
An error is thrown if any keywords are used which do not occur in `nda`'s names.
"""
@inline function order_named_inds(val::Val{L}; kwargs...) where {L}
if isempty(kwargs)
return ()


1. parse into static dimnension names and key words.
2. find each dimnames in key words
3. if nothing is found use Colon()
4. if (ndims - ncolon) === nkwargs then all were found, else error
=#
order_named_inds(x::Tuple, ::NamedTuple{(),Tuple{}}) = ()
function order_named_inds(x::Tuple, nd::NamedTuple{L}) where {L}
return order_named_inds(x, static(Val(L)), Tuple(nd))
end
@aggressive_constprop function order_named_inds(
x::Tuple{Vararg{Any,N}},
nd::Tuple,
inds::Tuple
) where {N}

out = eachop(((x, nd, inds), i) -> order_named_inds(x, nd, inds, i), (x, nd, inds), nstatic(Val(N)))
_order_named_inds_check(out, length(nd))
return out
end
function order_named_inds(x::Tuple, nd::Tuple, inds::Tuple, ::StaticInt{dim}) where {dim}
index = find_first_eq(getfield(x, dim), nd)
if index === nothing
return Colon()
else
return order_named_inds(val, kwargs.data)
end
end
@generated function order_named_inds(val::Val{L}, ni::NamedTuple{K}) where {L,K}
tuple_issubset(K, L) || throw(DimensionMismatch("Expected subset of $L, got $K"))
exs = map(L) do n
if Base.sym_in(n, K)
qn = QuoteNode(n)
:(getfield(ni, $qn))
else
:(Colon())
end
return @inbounds(inds[index])
end
return Expr(:tuple, exs...)
end
@generated function _perm_tuple(::Type{T}, ::Val{P}) where {T,P}
out = Expr(:curly, :Tuple)
for p in P
push!(out.args, T.parameters[p])

ncolon(x::Tuple{Colon,Vararg}, n::Int) = ncolon(tail(x), n + 1)
ncolon(x::Tuple{Any,Vararg}, n::Int) = ncolon(tail(x), n)
ncolon(x::Tuple{Colon}, n::Int) = n + 1
ncolon(x::Tuple{Any}, n::Int) = n
function _order_named_inds_check(inds::Tuple{Vararg{Any,N}}, nkwargs::Int) where {N}
if (N - ncolon(inds, 0)) !== nkwargs
error("Not all keywords matched dimension names.")
end
Expr(:block, Expr(:meta, :inline), out)
return nothing
end

"""
Expand All @@ -226,14 +227,11 @@ function axes_types(::Type{T}) where {T}
return axes_types(parent_type(T))
end
end
function axes_types(::Type{T}) where {T<:Adjoint}
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
function axes_types(::Type{T}) where {T<:MatAdjTrans}
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
end
function axes_types(::Type{T}) where {T<:Transpose}
return _perm_tuple(axes_types(parent_type(T)), Val((2, 1)))
end
function axes_types(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}}
return _perm_tuple(axes_types(parent_type(T)), Val(I1))
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
end
function axes_types(::Type{T}) where {T<:AbstractRange}
if known_length(T) === nothing
Expand Down Expand Up @@ -311,8 +309,6 @@ end
Expr(:block, Expr(:meta, :inline), out)
end



"""
size(A)

Expand All @@ -330,12 +326,7 @@ julia> ArrayInterface.size(A)
@inline size(A) = Base.size(A)
@inline size(A, d::Integer) = size(A)[Int(d)]
@inline size(A, d) = Base.size(A, to_dims(A, d))
@inline function size(x::LinearAlgebra.Adjoint{T,V}) where {T,V<:AbstractVector{T}}
return (One(), static_length(x))
end
@inline function size(x::LinearAlgebra.Transpose{T,V}) where {T,V<:AbstractVector{T}}
return (One(), static_length(x))
end
@inline size(x::VecAdjTrans) = (One(), static_length(x))

function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}}
return _size(size(parent(B)), B.indices, map(static_length, B.indices))
Expand All @@ -357,9 +348,9 @@ end
Expr(:block, Expr(:meta, :inline), t)
end
@inline size(v::AbstractVector) = (static_length(v),)
@inline size(B::MatAdjTrans) = permute(size(parent(B)), Val{(2, 1)}())
@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B))
@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A}
return permute(size(parent(B)), Val{I1}())
return permute(size(parent(B)), to_parent_dims(B))
end
@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N]
@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N]
Expand Down
5 changes: 3 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ Changing indexing based on a given argument from `args` should be done through
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args))
@propagate_inbounds function getindex(A; kwargs...)
if has_dimnames(A)
return A[order_named_inds(Val(dimnames(A)); kwargs...)...]
return A[order_named_inds(dimnames(A), kwargs.data)...]
else
return unsafe_getindex(A, to_indices(A, ()); kwargs...)
end
Expand Down Expand Up @@ -548,7 +548,7 @@ Store the given values at the given key or index within a collection.
end
@propagate_inbounds function setindex!(A, val; kwargs...)
if has_dimnames(A)
A[order_named_inds(Val(dimnames(A)); kwargs...)...] = val
A[order_named_inds(dimnames(A), kwargs.data)...] = val
else
return unsafe_setindex!(A, val, to_indices(A, ()); kwargs...)
end
Expand Down Expand Up @@ -662,3 +662,4 @@ end
) where {N}
return _generate_unsafe_setindex!_body(N)
end

1 change: 1 addition & 0 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,4 @@ function Base.show(io::IO, r::OptionallyStaticRange)
end
print(io, last(r))
end

Loading