Skip to content

Commit

Permalink
Fix reinstantiation of spectral broadcasted
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 27, 2025
1 parent bbad885 commit e2693e7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/Operators/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ function Base.Broadcast.instantiate(sbc::SpectralBroadcasted)
Base.Broadcast.check_broadcast_axes(axes, args...)
end
end
op = typeof(op)(axes)
# If we've already instantiated, then we need to strip the type parameters,
# for example, `Divergence{()}(axes)`.
op = unionall_type(typeof(op)){()}(axes)
Style = AbstractSpectralStyle(ClimaComms.device(axes))
return SpectralBroadcasted{Style}(op, args, axes)
end
Expand Down Expand Up @@ -1323,6 +1325,7 @@ struct Interpolate{I, S} <: TensorOperator
space::S
end
Interpolate(space) = Interpolate{operator_axes(space), typeof(space)}(space)
Interpolate{()}(space) = Interpolate{operator_axes(space), typeof(space)}(space)

function apply_operator(op::Interpolate{(1,)}, space_out, slabidx, arg)
FT = Spaces.undertype(space_out)
Expand Down Expand Up @@ -1412,6 +1415,7 @@ struct Restrict{I, S} <: TensorOperator
space::S
end
Restrict(space) = Restrict{operator_axes(space), typeof(space)}(space)
Restrict{()}(space) = Restrict{operator_axes(space), typeof(space)}(space)

function apply_operator(op::Restrict{(1,)}, space_out, slabidx, arg)
FT = Spaces.undertype(space_out)
Expand Down
53 changes: 53 additions & 0 deletions test/Operators/unit_reinstantiate_bc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#=
julia --project=.buildkite
using Revise; include("test/Operators/unit_reinstantiate_bc.jl")
=#

# TODO: make this unit test more low-level
using ClimaComms
ClimaComms.@import_required_backends
using ClimaCore.CommonSpaces
using ClimaCore: Spaces, Fields, Geometry, ClimaCore, Operators
using LazyBroadcast: lazy
using Test
using Base.Broadcast: materialize

const divₕ = Operators.Divergence()
const wgradₕ = Operators.WeakGradient()
const curlₕ = Operators.Curl()
const wcurlₕ = Operators.WeakCurl()

using ClimaCore.CommonSpaces

function foo_tendency_uₕ(ᶜuₕ, zmax)
return lazy.(
@. (
wgradₕ(divₕ(ᶜuₕ)) - Geometry.project(
Geometry.Covariant12Axis(),
wcurlₕ(Geometry.project(Geometry.Covariant3Axis(), curlₕ(ᶜuₕ))),
)
)
)
end

@testset "Reinstantiation of SpectralBroadcasted" begin
FT = Float64
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 10,
z_min = 0,
z_max = 1,
radius = 10,
h_elem = 10,
n_quad_points = 4,
staggering = CellCenter(),
)
ᶠspace = Spaces.face_space(ᶜspace)
ᶠz = Fields.coordinate_field(ᶠspace).z
ᶜz = Fields.coordinate_field(ᶜspace).z
ᶜuₕ = map(z -> zero(Geometry.Covariant12Vector{eltype(z)}), ᶜz)
zmax = Spaces.z_max(axes(ᶠz))
vst_uₕ = foo_tendency_uₕ(ᶜuₕ, zmax)
ᶜuₕₜ = zero(ᶜuₕ)
@. ᶜuₕₜ += vst_uₕ
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ UnitTest("Spaces - DSS cubed sphere" ,"Spaces/ddss1_cs.jl"),
UnitTest("Sphere spaces" ,"Spaces/sphere.jl"),
# UnitTest("Terrain warp" ,"Spaces/terrain_warp.jl"), # appears to hang on GHA
UnitTest("Fields" ,"Fields/unit_field.jl"), # has benchmarks
UnitTest("Reinstantiate broadcasted" ,"test/Operators/unit_reinstantiate_bc.jl"),
UnitTest("Spectral elem - rectilinear" ,"Operators/spectralelement/rectilinear.jl"),
UnitTest("Spectral elem - opt" ,"Operators/spectralelement/opt.jl"),
UnitTest("Spectral elem - gradient tensor" ,"Operators/spectralelement/covar_deriv_ops.jl"),
Expand Down

0 comments on commit e2693e7

Please sign in to comment.