Skip to content

Commit

Permalink
switch to using Fields as temporaries
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Jan 25, 2023
1 parent c03f0f5 commit 21bdc33
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 65 deletions.
32 changes: 30 additions & 2 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ indexes the underlying array as `[i,j,k,f,v,h]`
module DataLayouts

import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms

import ..enable_threading, ..slab, ..slab_args, ..column, ..column_args, ..level
Expand Down Expand Up @@ -626,6 +626,16 @@ function IJF{S, Nij}(array::AbstractArray{T, 3}) where {S, Nij, T}
IJF{S, Nij, typeof(array)}(array)
end

function IJF{S, Nij}(::Type{MArray}, ::Type{T}) where {S, Nij, T}
Nf = typesize(T, S)
array = MArray{Tuple{Nij, Nij, Nf}, T, 3, Nij * Nij * Nf}(undef)
IJF{S, Nij}(array)
end
function SArray(ijf::IJF{S, Nij, <:MArray}) where {S, Nij}
IJF{S, Nij}(SArray(parent(ijf)))
end


function replace_basetype(data::IJF{S, Nij}, ::Type{T}) where {S, Nij, T}
array = parent(data)
S′ = replace_basetype(eltype(array), T, S)
Expand Down Expand Up @@ -670,6 +680,9 @@ end
data::IJF{S, Nij},
i::Integer,
j::Integer,
k = nothing,
v = nothing,
h = nothing,
) where {S, Nij}
@boundscheck (1 <= i <= Nij && 1 <= j <= Nij) ||
throw(BoundsError(data, (i, j)))
Expand Down Expand Up @@ -749,6 +762,14 @@ function IF{S, Ni}(array::AbstractArray{T, 2}) where {S, Ni, T}
check_basetype(T, S)
IF{S, Ni, typeof(array)}(array)
end
function IF{S, Ni}(::Type{MArray}, ::Type{T}) where {S, Ni, T}
Nf = typesize(T, S)
array = MArray{Tuple{Ni, Nf}, T, 2, Ni * Nf}(undef)
IF{S, Ni}(array)
end
function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
IF{S, Ni}(SArray(parent(data)))
end

function replace_basetype(data::IF{S, Ni}, ::Type{T}) where {S, Ni, T}
array = parent(data)
Expand Down Expand Up @@ -786,7 +807,14 @@ end
IF{SS, Ni}(dataview)
end

@inline function Base.getindex(data::IF{S, Ni}, i::Integer) where {S, Ni}
@inline function Base.getindex(
data::IF{S, Ni},
i::Integer,
j = nothing,
k = nothing,
v = nothing,
h = nothing,
) where {S, Ni}
@boundscheck (1 <= i <= Ni) || throw(BoundsError(data, (i,)))
@inbounds get_struct(parent(data), S, Val(2), CartesianIndex(i, 1))
end
Expand Down
Loading

0 comments on commit 21bdc33

Please sign in to comment.