Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for StaticArrayInterface #469

Merged
merged 1 commit into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <elrodc@gmail.com>"]
version = "0.12.150"
version = "0.12.151"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -15,8 +15,6 @@ SpecialFunctionsExt = "SpecialFunctions"
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
ArrayInterfaceOffsetArrays = "015c0d05-e682-4f19-8f0a-679ce4c54826"
ArrayInterfaceStaticArrays = "b0d46f97-bff5-4637-a19a-dd75974142cd"
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"
Expand All @@ -33,15 +31,14 @@ SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

[compat]
ArrayInterface = "6"
ArrayInterface = "7"
ArrayInterfaceCore = "0.1.5"
ArrayInterfaceOffsetArrays = "0.1.2"
ArrayInterfaceStaticArrays = "0.1.2"
CPUSummary = "0.1.3 - 0.1.8, 0.1.11, 0.2.1"
ChainRulesCore = "1"
CloseOpenIntervals = "0.1.10"
Expand All @@ -56,7 +53,8 @@ SIMDTypes = "0.1"
SLEEFPirates = "0.6.23"
SnoopPrecompile = "1"
SpecialFunctions = "1, 2"
Static = "0.7, 0.8"
Static = "0.8.4"
StaticArrayInterface = "1"
ThreadingUtilities = "0.5"
UnPack = "1"
VectorizationBase = "0.21.53"
Expand Down
16 changes: 6 additions & 10 deletions src/LoopVectorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ end
using ArrayInterfaceCore: UpTri, LoTri
using Static: StaticInt, gt, static, Zero, One, reduce_tup
using VectorizationBase,
SLEEFPirates,
UnPack,
OffsetArrays,
ArrayInterfaceOffsetArrays,
ArrayInterfaceStaticArrays
SLEEFPirates, UnPack, OffsetArrays, StaticArrayInterface
const ArrayInterface = StaticArrayInterface
using LayoutPointers:
AbstractStridedPointer,
StridedPointer,
Expand Down Expand Up @@ -155,18 +152,17 @@ using SLEEFPirates:
sincos_fast,
tan_fast

using ArrayInterface
using ArrayInterface:
using StaticArrayInterface:
OptionallyStaticUnitRange,
OptionallyStaticRange,
StaticBool,
True,
False,
indices,
strides,
static_strides,
offsets,
size,
axes,
static_size,
static_axes,
StrideIndex
using CloseOpenIntervals: AbstractCloseOpen, CloseOpen#, SafeCloseOpen
# @static if VERSION ≥ v"1.6.0-rc1" #TODO: delete `else` when dropping 1.5 support
Expand Down
24 changes: 13 additions & 11 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ end
@inline ArrayInterface.parent_type(
::Type{LowDimArray{D,T,N,A}}
) where {T,D,N,A} = A
@inline Base.strides(A::LowDimArray) = map(Int, strides(A))
@inline Base.strides(A::LowDimArray) = map(Int, static_strides(A))
@inline ArrayInterface.device(::LowDimArray) = ArrayInterface.CPUPointer()
@generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N}
@generated function ArrayInterface.static_size(
A::LowDimArray{D,T,N}
) where {D,T,N}
t = Expr(:tuple)
for n ∈ 1:N
if n > length(D) || D[n]
Expand Down Expand Up @@ -105,11 +107,13 @@ end
@inline forbroadcast(A) = A
# @inline forbroadcast(A::Adjoint) = forbroadcast(parent(A))
# @inline forbroadcast(A::Transpose) = forbroadcast(parent(A))
@inline function ArrayInterface.strides(A::Union{LowDimArray,ForBroadcast})
@inline function ArrayInterface.static_strides(
A::Union{LowDimArray,ForBroadcast}
)
B = parent(A)
_strides(
size(A),
strides(B),
static_size(A),
static_strides(B),
VectorizationBase.val_stride_rank(B),
VectorizationBase.val_dense_dims(B)
)
Expand Down Expand Up @@ -145,10 +149,10 @@ end
) where {D,T,N,A}
_lowdimfilter(Val(D), ArrayInterface.dense_dims(A))
end
@inline function ArrayInterface.strides(
@inline function ArrayInterface.static_strides(
fb::LowDimArrayForBroadcast{D}
) where {D}
_lowdimfilter(Val(D), strides(parent(fb)))
_lowdimfilter(Val(D), static_strides(parent(fb)))
end
@inline function ArrayInterface.offsets(
fb::LowDimArrayForBroadcast{D}
Expand Down Expand Up @@ -225,11 +229,9 @@ function _strides_expr(
sₙ_value::Int = 0
for n ∈ Nrange
xₙ_type = x[n]
# xₙ_type = typeof(x).parameters[n]
xₙ_static = xₙ_type <: StaticInt
xₙ_value::Int = xₙ_static ? (xₙ_type.parameters[1])::Int : 0
s_type = s[n]
# s_type = typeof(s).parameters[n]
sₙ_static = s_type <: StaticInt
if sₙ_static
sₙ_value = s_type.parameters[1]
Expand Down Expand Up @@ -365,7 +367,7 @@ function add_broadcast!(
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
pushprepreamble!(
ls,
Expr(:(=), Klen, Expr(:call, getfield, Expr(:call, :size, mB), 1))
Expr(:(=), Klen, Expr(:call, getfield, Expr(:call, :static_size, mB), 1))
)
pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen)))
k = gensym!(ls, "k")
Expand Down Expand Up @@ -587,7 +589,7 @@ function add_broadcast_loops!(
destsym::Symbol
)
axes_tuple = Expr(:tuple)
pushpreamble!(ls, Expr(:(=), axes_tuple, Expr(:call, :axes, destsym)))
pushpreamble!(ls, Expr(:(=), axes_tuple, Expr(:call, :static_axes, destsym)))
for itersym ∈ loopsyms
Nrange = gensym!(ls, "N")
Nlower = gensym!(ls, "N")
Expand Down
6 changes: 3 additions & 3 deletions src/condense_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,11 @@ val(x) = Expr(:call, Expr(:curly, :Val, x))
p, li = VectorizationBase.tdot(
x,
(vsub_nsw(getfield(i, 1), one($I)),),
strides(x)
static_strides(x)
)
ptr = gep(p, li)
si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}(
(getfield(strides(x), $ri),),
(getfield(static_strides(x), $ri),),
(Zero(),)
)
stridedpointer(ptr, si, StaticInt{$(B === 1 ? 1 : 0)}())
Expand All @@ -415,7 +415,7 @@ end
quote
$(Expr(:meta, :inline))
si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}(
(getfield(strides(x), $ri),),
(getfield(static_strides(x), $ri),),
(getfield(offsets(x), $ri),)
)
stridedpointer(pointer(x), si, StaticInt{$(B == 1 ? 1 : 0)}())
Expand Down
16 changes: 9 additions & 7 deletions src/modeling/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -995,10 +995,12 @@ function makestatic!(expr)
if ex isa Int
expr.args[i] = staticexpr(ex)
elseif ex isa Symbol
if ex === :length
expr.args[i] = GlobalRef(ArrayInterface, :static_length)
elseif Base.sym_in(ex, (:axes, :size))
expr.args[i] = GlobalRef(ArrayInterface, ex)
j = findfirst(==(ex), (:axes, :size, :length))
if j !== nothing
expr.args[i] = GlobalRef(
ArrayInterface,
(:static_axes, :static_size, :static_length)[j]
)
end
elseif ex isa Expr
makestatic!(ex)
Expand Down Expand Up @@ -1215,7 +1217,7 @@ function indices_loop!(ls::LoopSet, r::Expr, itersym::Symbol)::Loop
axsym,
Expr(
:call,
GlobalRef(ArrayInterface, :axes),
GlobalRef(ArrayInterface, :static_axes),
a_s,
staticexpr(dims::Int)
)
Expand Down Expand Up @@ -1280,7 +1282,7 @@ function indices_loop!(ls::LoopSet, r::Expr, itersym::Symbol)::Loop
axsym,
Expr(
:call,
GlobalRef(ArrayInterface, :axes),
GlobalRef(ArrayInterface, :static_axes),
a_s,
staticexpr(mdim)
)
Expand Down Expand Up @@ -1351,7 +1353,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
)
indices_loop!(ls, r, itersym)
else
(f === :axes) && (r.args[1] = lv(:axes))
(f === :axes) && (r.args[1] = lv(:static_axes))
misc_loop!(ls, r, itersym, (f === :eachindex) | (f === :axes))
end
elseif isa(r, Symbol)
Expand Down
5 changes: 4 additions & 1 deletion src/reconstruct_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ function _add_mref!(
offsets = gensym(:offsets)
strides = gensym(:strides)
pushpreamble!(ls, Expr(:(=), offsets, Expr(:call, lv(:offsets), tmpsp)))
pushpreamble!(ls, Expr(:(=), strides, Expr(:call, lv(:strides), tmpsp)))
pushpreamble!(
ls,
Expr(:(=), strides, Expr(:call, lv(:static_strides), tmpsp))
)
for (i, p) ∈ enumerate(sp)
push!(strd_tup.args, Expr(:call, gf, strides, p, false))
push!(offsets_tup.args, Expr(:call, gf, offsets, p, false))
Expand Down
4 changes: 2 additions & 2 deletions src/simdfunctionals/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ for (op, init) in zip((:+, :max, :min), (:zero, :typemin, :typemax))
Base.Cartesian.@nif 5 d -> (d <= ndims(arg) && dims == d) d -> begin
Rpre = CartesianIndices(ntuple(i -> axes_arg[i], d - 1))
Rpost = CartesianIndices(ntuple(i -> axes_arg[i+d], ndims(arg) - d))
_vreduce_dims!(out, $op, Rpre, 1:size(arg, dims), Rpost, arg)
_vreduce_dims!(out, $op, Rpre, static_axes(arg, dims), Rpost, arg)
end d -> begin
Rpre = CartesianIndices(axes_arg[1:dims-1])
Rpost = CartesianIndices(axes_arg[dims+1:end])
_vreduce_dims!(out, $op, Rpre, 1:size(arg, dims), Rpost, arg)
_vreduce_dims!(out, $op, Rpre, static_axes(arg, dims), Rpost, arg)
end
end

Expand Down
6 changes: 3 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function test_broadcast(::Type{T}) where {T}
b = rand(R, 99, 99, 1)
bl = LowDimArray{(true, true, false)}(b)
@test size(bl) == size(b)
@test LoopVectorization.ArrayInterface.size(bl) ===
@test LoopVectorization.static_size(bl) ===
(size(b, 1), size(b, 2), LoopVectorization.StaticInt(1))

br = reshape(b, (99, 99))
Expand All @@ -29,7 +29,7 @@ function test_broadcast(::Type{T}) where {T}
br = reshape(b, (99, 1, 99))
bl = LowDimArray{(true, false, true)}(br)
@test size(bl) == size(br)
@test LoopVectorization.ArrayInterface.size(bl) ===
@test LoopVectorization.static_size(bl) ===
(size(br, 1), LoopVectorization.StaticInt(1), size(br, 3))
@. c1 = a + br
fill!(c2, 99999)
Expand All @@ -41,7 +41,7 @@ function test_broadcast(::Type{T}) where {T}
br = reshape(b, (1, 99, 99))
bl = LowDimArray{(false,)}(br)
@test size(bl) == size(br)
@test LoopVectorization.ArrayInterface.size(bl) ===
@test LoopVectorization.static_size(bl) ===
(LoopVectorization.StaticInt(1), size(br, 2), size(br, 3))
@. c1 = a + br
fill!(c2, 99999)
Expand Down
18 changes: 4 additions & 14 deletions test/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@
return C
end
function dense!(f::F, C, A, B) where {F}
Kp1 = LoopVectorization.size(A, LoopVectorization.StaticInt(2))
Kp1 = LoopVectorization.static_size(A, LoopVectorization.StaticInt(2))
K = Kp1 - LoopVectorization.StaticInt(1)
@turbo for n ∈ indices((B, C), 2), m ∈ indices((A, C), 1)
Cmn = zero(eltype(C))
Expand Down Expand Up @@ -733,7 +733,7 @@
Base.@propagate_inbounds Base.setindex!(A::TestSizedMatrix, v, i::Int, j::Int) =
setindex!(parent(A), v, i + 1, j + 1)
Base.size(::TestSizedMatrix{M,N}) where {M,N} = (M, N)
LoopVectorization.ArrayInterface.size(::TestSizedMatrix{M,N}) where {M,N} =
LoopVectorization.static_size(::TestSizedMatrix{M,N}) where {M,N} =
(LoopVectorization.StaticInt{M}(), LoopVectorization.StaticInt{N}())
function Base.axes(::TestSizedMatrix{M,N}) where {M,N}
(
Expand All @@ -757,7 +757,7 @@
end
Base.unsafe_convert(::Type{Ptr{T}}, A::TestSizedMatrix{M,N,T}) where {M,N,T} =
pointer(A.data)
LoopVectorization.ArrayInterface.strides(::TestSizedMatrix{M}) where {M} =
LoopVectorization.static_strides(::TestSizedMatrix{M}) where {M} =
(LoopVectorization.StaticInt{1}(), LoopVectorization.StaticInt{M}())
LoopVectorization.ArrayInterface.contiguous_axis(::Type{<:TestSizedMatrix}) =
LoopVectorization.One()
Expand All @@ -771,17 +771,7 @@
LoopVectorization.ArrayInterface.dense_dims(
::Type{TestSizedMatrix{M,N,T}},
) where {M,N,T} = LoopVectorization.ArrayInterface.dense_dims(Matrix{T})
# struct ZeroInitializedArray{T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
# data::A
# end
# Base.size(A::ZeroInitializedArray) = size(A.data)
# Base.length(A::ZeroInitializedArray) = length(A.data)
# Base.axes(A::ZeroInitializedArray, i) = axes(A.data, i)
# @inline Base.getindex(A::ZeroInitializedArray{T}) where {T} = zero(T)
# Base.@propagate_inbounds Base.setindex!(A::ZeroInitializedArray, v, i...) = setindex!(A.data, v, i...)
# function LoopVectorization.VectorizationBase.stridedpointer(A::ZeroInitializedArray)
# LoopVectorization.VectorizationBase.ZeroInitializedStridedPointer(LoopVectorization.VectorizationBase.stridedpointer(A.data))
# end


@testset "Matmuls" begin
for T ∈ (Float32, Float64, Int32, Int64)
Expand Down
9 changes: 6 additions & 3 deletions test/offsetarrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LoopVectorization, ArrayInterface, OffsetArrays, Test
using LoopVectorization, OffsetArrays, Test
using LoopVectorization: ArrayInterface
using LoopVectorization: StaticInt
# T = Float64; r = -1:1;
# T = Float32; r = -1:1;
Expand Down Expand Up @@ -109,10 +110,12 @@ using LoopVectorization: StaticInt
ArrayInterface.contiguous_batch_size(::Type{<:SizedOffsetMatrix}) = ArrayInterface.Zero()
ArrayInterface.stride_rank(::Type{<:SizedOffsetMatrix}) =
(ArrayInterface.StaticInt(1), ArrayInterface.StaticInt(2))
function ArrayInterface.strides(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC}
function LoopVectorization.static_strides(
::SizedOffsetMatrix{T,LR,UR,LC,UC},
) where {T,LR,UR,LC,UC}
(StaticInt{1}(), (StaticInt{UR}() - StaticInt{LR}() + StaticInt{1}()))
end
ArrayInterface.offsets(A::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} =
ArrayInterface.offsets(::SizedOffsetMatrix{T,LR,UR,LC,UC}) where {T,LR,UR,LC,UC} =
(StaticInt{LR}(), StaticInt{LC}())
ArrayInterface.dense_dims(::Type{<:SizedOffsetMatrix{T}}) where {T} =
ArrayInterface.dense_dims(Matrix{T})
Expand Down
3 changes: 2 additions & 1 deletion test/parsing_inputs.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LoopVectorization, Test, ArrayInterface
using LoopVectorization, Test
using LoopVectorization: ArrayInterface
using LoopVectorization: check_inputs!

# macros for generating loops whose body is not a block
Expand Down