Skip to content

Commit

Permalink
try simplifying multiply_matrix_at_index
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 28, 2024
1 parent f292195 commit 92a638b
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 30 deletions.
3 changes: 2 additions & 1 deletion ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ function copyto_stencil_kernel!(
Operators.getidx(space, bc, lwindow, idx, hidx),
)
elseif idx > rw
rwindow = RightBoundaryWindow{Spaces.right_boundary_name(space)}()
rwindow =
RightBoundaryWindow{Spaces.right_boundary_name(space)}()
setidx!(
space,
out,
Expand Down
41 changes: 26 additions & 15 deletions src/MatrixFields/matrix_multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,11 @@ Base.@propagate_inbounds function multiply_matrix_at_index_mat_mat(
# of as a map from boundary_modified_ld1 to boundary_modified_ud1. For
# simplicity, use zero padding for rows that are outside the matrix.
# Wrap the rows in a BandMatrixRow so that they can be easily indexed.
nt_mr = ntuple(ζ->ld1+ζ-1, Val(ud1-ld1+1))
nt_mr = ntuple(ζ -> ld1 + ζ - 1, Val(ud1 - ld1 + 1))
matrix2_rows = unrolled_map(nt_mr) do d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if isnothing(bc) ||
boundary_modified_ld1 <= d <= boundary_modified_ud1
if isnothing(bc) || boundary_modified_ld1 <= d <= boundary_modified_ud1
@inbounds Operators.getidx(space, matrix2, loc, idx + d, hidx)
else
zero(eltype(matrix2)) # This row is outside the matrix.
Expand All @@ -430,20 +429,34 @@ Base.@propagate_inbounds function multiply_matrix_at_index_mat_mat(
# to boundary_modified_prod_ud. For simplicity, use zero padding for
# entries that are outside the matrix. Wrap the entries in a
# BandMatrixRow before returning them.
N = prod_ud-prod_ld+1
nt_pe = ntuple(ξ->prod_ld+ξ-1, Val(N))
N = prod_ud - prod_ld + 1
nt_pe = ntuple(ξ -> prod_ld + ξ - 1, Val(N))
prod_entries = unrolled_map(nt_pe) do prod_d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if isnothing(bc) ||
boundary_modified_prod_ld <= prod_d <= boundary_modified_prod_ud
boundary_modified_prod_ld <=
prod_d <=
boundary_modified_prod_ud
prod_entry = zero_value
min_d = max(boundary_modified_ld1, prod_d - ud2)
max_d = min(boundary_modified_ud1, prod_d - ld2)
min_d = max(
boundary_modified_ld1,
prod_d - ud2,
)
max_d = min(
boundary_modified_ud1,
prod_d - ld2,
)
@inbounds for d in min_d:max_d
value1 = matrix1_row[d]
value2 = matrix2_rows_wrapper[d][prod_d - d]
value2_lg = Geometry.LocalGeometry(space, idx + d, hidx)
value2 =
matrix2_rows_wrapper[d][prod_d - d]
value2_lg =
Geometry.LocalGeometry(
space,
idx + d,
hidx,
)
prod_entry = radd(
prod_entry,
rmul_with_projection(value1, value2, value2_lg),
Expand All @@ -453,7 +466,7 @@ Base.@propagate_inbounds function multiply_matrix_at_index_mat_mat(
else
zero_value # This entry is outside the matrix.
end
end::NTuple{N,eltype(prod_type)}
end::NTuple{N, eltype(prod_type)}
return BandMatrixRow{prod_ld}(prod_entries...)
end

Expand Down Expand Up @@ -484,10 +497,8 @@ Base.@propagate_inbounds function multiply_matrix_at_index_mat_vec(
value1 = matrix1_row[d]
value2 = Operators.getidx(space, vector, loc, idx + d, hidx)
value2_lg = Geometry.LocalGeometry(space, idx + d, hidx)
prod_value = radd(
prod_value,
rmul_with_projection(value1, value2, value2_lg),
)
prod_value =
radd(prod_value, rmul_with_projection(value1, value2, value2_lg))
end # Using a for-loop is currently faster than using mapreduce.
return prod_value
end
Expand Down
22 changes: 16 additions & 6 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ const AllCenterFiniteDifferenceSpace = Union{
left_center_boundary_idx(space)
@inline right_idx(space::AllCenterFiniteDifferenceSpace) =
right_center_boundary_idx(space)
@inline left_idx(space::AllFaceFiniteDifferenceSpace) = left_face_boundary_idx(space)
@inline right_idx(space::AllFaceFiniteDifferenceSpace) = right_face_boundary_idx(space)
@inline left_idx(space::AllFaceFiniteDifferenceSpace) =
left_face_boundary_idx(space)
@inline right_idx(space::AllFaceFiniteDifferenceSpace) =
right_face_boundary_idx(space)

@inline left_center_boundary_idx(space::AllFiniteDifferenceSpace) = 1
@inline right_center_boundary_idx(space::AllFiniteDifferenceSpace) = size(
Expand Down Expand Up @@ -604,8 +606,12 @@ Base.@propagate_inbounds stencil_interior(
arg,
) = right_idx(space)

@inline left_interior_idx(space::AbstractSpace, ::LeftBiasedF2C, ::SetValue, arg) =
left_idx(space) + 1
@inline left_interior_idx(
space::AbstractSpace,
::LeftBiasedF2C,
::SetValue,
arg,
) = left_idx(space) + 1
Base.@propagate_inbounds function stencil_left_boundary(
::LeftBiasedF2C,
bc::SetValue,
Expand Down Expand Up @@ -852,8 +858,12 @@ Base.@propagate_inbounds stencil_interior(
arg,
) = right_idx(space)

@inline right_interior_idx(space::AbstractSpace, ::RightBiasedF2C, ::SetValue, arg) =
right_idx(space) - 1
@inline right_interior_idx(
space::AbstractSpace,
::RightBiasedF2C,
::SetValue,
arg,
) = right_idx(space) - 1
Base.@propagate_inbounds function stencil_right_boundary(
::RightBiasedF2C,
bc::SetValue,
Expand Down
24 changes: 16 additions & 8 deletions src/Operators/operator2stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,23 @@ extrapolation_increases_bandwidth_error(op_type::Type) = throw(
),
)

@inline has_boundary(op::Operator2Stencil, bw::LeftBoundaryWindow{name}) where {name} =
has_boundary(op.op, bw)
@inline has_boundary(op::Operator2Stencil, bw::RightBoundaryWindow{name}) where {name} =
has_boundary(op.op, bw)
@inline has_boundary(
op::Operator2Stencil,
bw::LeftBoundaryWindow{name},
) where {name} = has_boundary(op.op, bw)
@inline has_boundary(
op::Operator2Stencil,
bw::RightBoundaryWindow{name},
) where {name} = has_boundary(op.op, bw)

@inline get_boundary(op::Operator2Stencil, bw::LeftBoundaryWindow{name}) where {name} =
get_boundary(op.op, bw)
@inline get_boundary(op::Operator2Stencil, bw::RightBoundaryWindow{name}) where {name} =
get_boundary(op.op, bw)
@inline get_boundary(
op::Operator2Stencil,
bw::LeftBoundaryWindow{name},
) where {name} = get_boundary(op.op, bw)
@inline get_boundary(
op::Operator2Stencil,
bw::RightBoundaryWindow{name},
) where {name} = get_boundary(op.op, bw)

function return_eltype(op::Operator2Stencil, args...)
lbw, ubw = stencil_interior_width(op.op, args...)[1]
Expand Down

0 comments on commit 92a638b

Please sign in to comment.