diff --git a/src/DataLayouts/struct.jl b/src/DataLayouts/struct.jl index fb6d390a8b..a143afbacd 100644 --- a/src/DataLayouts/struct.jl +++ b/src/DataLayouts/struct.jl @@ -179,6 +179,13 @@ Base.@propagate_inbounds @generated function get_struct( ::Val{D}, start_index::CartesianIndex, ) where {T, S, D} + # recursion base case: hit array type is the same as the struct leaf type + if T === S # Use Union-splitting for better latency + return quote + Base.@_propagate_inbounds_meta + @inbounds return array[start_index] + end + end tup = :(()) for i in 1:fieldcount(S) push!( @@ -201,16 +208,6 @@ Base.@propagate_inbounds @generated function get_struct( end end -# recursion base case: hit array type is the same as the struct leaf type -Base.@propagate_inbounds function get_struct( - array::AbstractArray{S}, - ::Type{S}, - ::Val{D}, - start_index::CartesianIndex, -) where {S, D} - @inbounds return array[start_index] -end - """ set_struct!(array, val::S, Val(D), start_index) diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 5f4d6a73ea..d367fc8f49 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3759,25 +3759,37 @@ end Base.@propagate_inbounds function getidx( parent_space, - bc::StencilBroadcasted, - loc::Interior, - idx, - hidx, -) - space = reconstruct_placeholder_space(axes(bc), parent_space) - stencil_interior(bc.op, loc, space, idx, hidx, bc.args...) -end - -Base.@propagate_inbounds function getidx( - parent_space, - bc::StencilBroadcasted, - loc::LeftBoundaryWindow, + bc::Union{StencilBroadcasted, Base.Broadcast.Broadcasted}, + loc::Location, idx, hidx, ) + # Use Union-splitting here (x isa X) instead of dispatch + # for improved latency. space = reconstruct_placeholder_space(axes(bc), parent_space) + if bc isa Base.Broadcast.Broadcasted + # Manually call bc.f for small tuples (improved latency) + (; args) = bc + N = length(bc.args) + if N == 1 + return bc.f(getidx(space, args[1], loc, idx, hidx)) + elseif N == 2 + return bc.f( + getidx(space, args[1], loc, idx, hidx), + getidx(space, args[2], loc, idx, hidx), + ) + elseif N == 3 + return bc.f( + getidx(space, args[1], loc, idx, hidx), + getidx(space, args[2], loc, idx, hidx), + getidx(space, args[3], loc, idx, hidx), + ) + end + return call_bc_f(bc.f, space, loc, idx, hidx, args...) + end op = bc.op - if should_call_left_boundary(idx, space, bc, loc) + if loc isa LeftBoundaryWindow && + should_call_left_boundary(idx, space, bc, loc) stencil_left_boundary( op, get_boundary(op, loc), @@ -3787,22 +3799,8 @@ Base.@propagate_inbounds function getidx( hidx, bc.args..., ) - else - # fallback to interior stencil - stencil_interior(op, loc, space, idx, hidx, bc.args...) - end -end - -Base.@propagate_inbounds function getidx( - parent_space, - bc::StencilBroadcasted, - loc::RightBoundaryWindow, - idx, - hidx, -) - op = bc.op - space = reconstruct_placeholder_space(axes(bc), parent_space) - if should_call_right_boundary(idx, space, bc, loc) + elseif loc isa RightBoundaryWindow && + should_call_right_boundary(idx, space, bc, loc) stencil_right_boundary( op, get_boundary(op, loc), @@ -3813,11 +3811,11 @@ Base.@propagate_inbounds function getidx( bc.args..., ) else - # fallback to interior stencil - stencil_interior(op, loc, space, idx, hidx, bc.args...) + stencil_interior(bc.op, loc, space, idx, hidx, bc.args...) end end + # broadcasting a StencilStyle gives a CompositeStencilStyle Base.Broadcast.BroadcastStyle( ::Type{<:StencilBroadcasted{Style}}, @@ -3902,48 +3900,25 @@ end @noinline inferred_getidx_error(idx_type::Type, space_type::Type) = error("Invalid index type `$idx_type` for field on space `$space_type`") - # recursively unwrap getidx broadcast arguments in a way that is statically reducible by the optimizer -Base.@propagate_inbounds getidx_args( - space, - args::Tuple, - loc::Location, - idx, - hidx, -) = ( - getidx(space, args[1], loc, idx, hidx), - getidx_args(space, Base.tail(args), loc, idx, hidx)..., -) -Base.@propagate_inbounds getidx_args( +@generated function call_bc_f( + f::F, space, - arg::Tuple{Any}, - loc::Location, - idx, - hidx, -) = (getidx(space, arg[1], loc, idx, hidx),) -Base.@propagate_inbounds getidx_args( - space, - ::Tuple{}, - loc::Location, - idx, - hidx, -) = () - -Base.@propagate_inbounds function getidx( - parent_space, - bc::Base.Broadcast.Broadcasted, loc::Location, idx, hidx, -) - space = reconstruct_placeholder_space(axes(bc), parent_space) - _args = getidx_args(space, bc.args, loc, idx, hidx) - bc.f(_args...) + args..., +) where {F} + N = length(args) + return quote + Base.@_propagate_inbounds_meta + Base.Cartesian.@ncall $N f i -> getidx(space, args[i], loc, idx, hidx) + end end if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(getidx_args) + for m in methods(call_bc_f) m.recursion_relation = dont_limit end for m in methods(getidx) @@ -4123,7 +4098,6 @@ function window_bounds(space, bc) return (li, lw, rw, ri) end - Base.@propagate_inbounds function apply_stencil!( space, field_out, @@ -4135,36 +4109,21 @@ Base.@propagate_inbounds function apply_stencil!( # left window lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}() @inbounds for idx in li:(lw - 1) - setidx!( - space, - field_out, - idx, - hidx, - getidx(space, bc, lbw, idx, hidx), - ) + val = getidx(space, bc, lbw, idx, hidx) + setidx!(space, field_out, idx, hidx, val) end end # interior @inbounds for idx in lw:rw - setidx!( - space, - field_out, - idx, - hidx, - getidx(space, bc, Interior(), idx, hidx), - ) + val = getidx(space, bc, Interior(), idx, hidx) + setidx!(space, field_out, idx, hidx, val) end if !Topologies.isperiodic(Spaces.vertical_topology(space)) # right window rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}() @inbounds for idx in (rw + 1):ri - setidx!( - space, - field_out, - idx, - hidx, - getidx(space, bc, rbw, idx, hidx), - ) + val = getidx(space, bc, rbw, idx, hidx) + setidx!(space, field_out, idx, hidx, val) end end return field_out