Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions src/DataLayouts/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ Base.@propagate_inbounds @generated function get_struct(
::Val{D},
start_index::CartesianIndex,
) where {T, S, D}
# recursion base case: hit array type is the same as the struct leaf type
if T === S # Use Union-splitting for better latency
return quote
Base.@_propagate_inbounds_meta
@inbounds return array[start_index]
end
end
tup = :(())
for i in 1:fieldcount(S)
push!(
Expand All @@ -201,16 +208,6 @@ Base.@propagate_inbounds @generated function get_struct(
end
end

# Base case of the `get_struct` recursion: the requested type `S` is exactly
# the element type of `array`, so the value is read straight from `array` at
# `start_index`. The `Val{D}` argument is accepted but not used by this method.
Base.@propagate_inbounds function get_struct(
    array::AbstractArray{S},
    ::Type{S},
    ::Val{D},
    start_index::CartesianIndex,
) where {S, D}
    leaf = @inbounds array[start_index]
    return leaf
end

"""
set_struct!(array, val::S, Val(D), start_index)

Expand Down
133 changes: 46 additions & 87 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3759,25 +3759,37 @@ end

# Interior location: rebuild the space from the broadcast's axes and the
# parent space, then evaluate the operator's interior stencil at (idx, hidx).
Base.@propagate_inbounds function getidx(
    parent_space,
    bc::StencilBroadcasted,
    loc::Interior,
    idx,
    hidx,
)
    bc_axes = axes(bc)
    space = reconstruct_placeholder_space(bc_axes, parent_space)
    return stencil_interior(bc.op, loc, space, idx, hidx, bc.args...)
end

Base.@propagate_inbounds function getidx(
parent_space,
bc::StencilBroadcasted,
loc::LeftBoundaryWindow,
bc::Union{StencilBroadcasted, Base.Broadcast.Broadcasted},
loc::Location,
idx,
hidx,
)
# Use Union-splitting here (x isa X) instead of dispatch
# for improved latency.
space = reconstruct_placeholder_space(axes(bc), parent_space)
if bc isa Base.Broadcast.Broadcasted
# Manually call bc.f for small tuples (improved latency)
(; args) = bc
N = length(bc.args)
if N == 1
return bc.f(getidx(space, args[1], loc, idx, hidx))
elseif N == 2
return bc.f(
getidx(space, args[1], loc, idx, hidx),
getidx(space, args[2], loc, idx, hidx),
)
elseif N == 3
return bc.f(
getidx(space, args[1], loc, idx, hidx),
getidx(space, args[2], loc, idx, hidx),
getidx(space, args[3], loc, idx, hidx),
)
end
return call_bc_f(bc.f, space, loc, idx, hidx, args...)
end
op = bc.op
if should_call_left_boundary(idx, space, bc, loc)
if loc isa LeftBoundaryWindow &&
should_call_left_boundary(idx, space, bc, loc)
stencil_left_boundary(
op,
get_boundary(op, loc),
Expand All @@ -3787,22 +3799,8 @@ Base.@propagate_inbounds function getidx(
hidx,
bc.args...,
)
else
# fallback to interior stencil
stencil_interior(op, loc, space, idx, hidx, bc.args...)
end
end

Base.@propagate_inbounds function getidx(
parent_space,
bc::StencilBroadcasted,
loc::RightBoundaryWindow,
idx,
hidx,
)
op = bc.op
space = reconstruct_placeholder_space(axes(bc), parent_space)
if should_call_right_boundary(idx, space, bc, loc)
elseif loc isa RightBoundaryWindow &&
should_call_right_boundary(idx, space, bc, loc)
stencil_right_boundary(
op,
get_boundary(op, loc),
Expand All @@ -3813,11 +3811,11 @@ Base.@propagate_inbounds function getidx(
bc.args...,
)
else
# fallback to interior stencil
stencil_interior(op, loc, space, idx, hidx, bc.args...)
stencil_interior(bc.op, loc, space, idx, hidx, bc.args...)
end
end


# broadcasting a StencilStyle gives a CompositeStencilStyle
Base.Broadcast.BroadcastStyle(
::Type{<:StencilBroadcasted{Style}},
Expand Down Expand Up @@ -3902,48 +3900,25 @@ end
# Throw an error reporting that `idx_type` is not a valid index type for a
# field on a space of type `space_type`. Marked `@noinline`.
@noinline function inferred_getidx_error(idx_type::Type, space_type::Type)
    return error(
        "Invalid index type `$idx_type` for field on space `$space_type`",
    )
end


# Recursively evaluate `getidx` on each broadcast argument: the first element
# of `args` is evaluated here and the remainder is handled by recursing on
# `Base.tail(args)`. Written so the unrolling is statically reducible by the
# optimizer.
Base.@propagate_inbounds function getidx_args(
    space,
    args::Tuple,
    loc::Location,
    idx,
    hidx,
)
    head = getidx(space, args[1], loc, idx, hidx)
    rest = getidx_args(space, Base.tail(args), loc, idx, hidx)
    return (head, rest...)
end
Base.@propagate_inbounds getidx_args(
@generated function call_bc_f(
f::F,
space,
arg::Tuple{Any},
loc::Location,
idx,
hidx,
) = (getidx(space, arg[1], loc, idx, hidx),)
Base.@propagate_inbounds getidx_args(
space,
::Tuple{},
loc::Location,
idx,
hidx,
) = ()

Base.@propagate_inbounds function getidx(
parent_space,
bc::Base.Broadcast.Broadcasted,
loc::Location,
idx,
hidx,
)
space = reconstruct_placeholder_space(axes(bc), parent_space)
_args = getidx_args(space, bc.args, loc, idx, hidx)
bc.f(_args...)
args...,
) where {F}
N = length(args)
return quote
Base.@_propagate_inbounds_meta
Base.Cartesian.@ncall $N f i -> getidx(space, args[i], loc, idx, hidx)
end
end

if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(getidx_args)
for m in methods(call_bc_f)
m.recursion_relation = dont_limit
end
for m in methods(getidx)
Expand Down Expand Up @@ -4123,7 +4098,6 @@ function window_bounds(space, bc)
return (li, lw, rw, ri)
end


Base.@propagate_inbounds function apply_stencil!(
space,
field_out,
Expand All @@ -4135,36 +4109,21 @@ Base.@propagate_inbounds function apply_stencil!(
# left window
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
@inbounds for idx in li:(lw - 1)
setidx!(
space,
field_out,
idx,
hidx,
getidx(space, bc, lbw, idx, hidx),
)
val = getidx(space, bc, lbw, idx, hidx)
setidx!(space, field_out, idx, hidx, val)
end
end
# interior
@inbounds for idx in lw:rw
setidx!(
space,
field_out,
idx,
hidx,
getidx(space, bc, Interior(), idx, hidx),
)
val = getidx(space, bc, Interior(), idx, hidx)
setidx!(space, field_out, idx, hidx, val)
end
if !Topologies.isperiodic(Spaces.vertical_topology(space))
# right window
rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}()
@inbounds for idx in (rw + 1):ri
setidx!(
space,
field_out,
idx,
hidx,
getidx(space, bc, rbw, idx, hidx),
)
val = getidx(space, bc, rbw, idx, hidx)
setidx!(space, field_out, idx, hidx, val)
end
end
return field_out
Expand Down
Loading