diff --git a/ext/InterpolationsRegridderExt.jl b/ext/InterpolationsRegridderExt.jl index f64447aa..9b90dd1d 100644 --- a/ext/InterpolationsRegridderExt.jl +++ b/ext/InterpolationsRegridderExt.jl @@ -12,6 +12,7 @@ struct InterpolationsRegridder{ SPACE <: ClimaCore.Spaces.AbstractSpace, FIELD <: ClimaCore.Fields.Field, BC, + GITP, } <: Regridders.AbstractRegridder """ClimaCore.Space where the output Field will be defined""" @@ -22,6 +23,14 @@ struct InterpolationsRegridder{ """Tuple of extrapolation conditions as accepted by Interpolations.jl""" extrapolation_bc::BC + + # This is needed because Adapt moves from CPU to GPU and allocates new memory. + """Dictionary of preallocated areas of memory where to store the GPU interpolant (if + needed). Every time new data/dimensions are used in regrid, a new entry in the + dictionary is created. The keys of the dictionary a tuple of tuple + `(size(dimensions), size(data))`, with `dimensions` and `data` defined in `regrid`. + """ + _gpuitps::GITP end # Note, we swap Lat and Long! This is because according to the CF conventions longitude @@ -58,6 +67,8 @@ function Regridders.InterpolationsRegridder( ) coordinates = ClimaCore.Fields.coordinate_field(target_space) + num_dimensions = length(propertynames(coordinates)) + if isnothing(extrapolation_bc) extrapolation_bc = () if eltype(coordinates) <: ClimaCore.Geometry.LatLongPoint @@ -69,9 +80,42 @@ function Regridders.InterpolationsRegridder( end end - return InterpolationsRegridder(target_space, coordinates, extrapolation_bc) + num_dimensions == length(extrapolation_bc) || error( + "Number of boundary conditions does not match the number of dimensions", + ) + + # Let's figure out the type of _gpuitps by creating a simple spline + FT = ClimaCore.Spaces.undertype(target_space) + dimensions = ntuple(_ -> [zero(FT), one(FT)], num_dimensions) + data = zeros(FT, ntuple(_ -> 2, num_dimensions)) + itp = _create_linear_spline(FT, data, dimensions, extrapolation_bc) + fake_gpuitp = Adapt.adapt(ClimaComms.array_type(target_space), itp) + gpuitps = Dict((size.(dimensions), size(data)) => fake_gpuitp) + + return InterpolationsRegridder( + target_space, + coordinates, + extrapolation_bc, + gpuitps, + ) end +""" + _create_linear_spline(regridder::InterpolationsRegridder, data, dimensions) + +Create a linear spline for the given data on the given dimension (on the CPU). +""" +function _create_linear_spline(FT, data, dimensions, extrapolation_bc) + dimensions_FT = map(d -> FT.(d), dimensions) + + # Make a linear spline + return Intp.extrapolate( + Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())), + extrapolation_bc, + ) +end + + """ regrid(regridder::InterpolationsRegridder, data, dimensions)::Field @@ -81,16 +125,31 @@ This function is allocating. """ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions) FT = ClimaCore.Spaces.undertype(regridder.target_space) - dimensions_FT = map(d -> FT.(d), dimensions) - - # Make a linear spline - itp = Intp.extrapolate( - Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())), - regridder.extrapolation_bc, - ) + itp = + _create_linear_spline(FT, data, dimensions, regridder.extrapolation_bc) + + key = (size.(dimensions), size(data)) + + if haskey(regridder._gpuitps, key) + for (k, k_new) in zip( + regridder._gpuitps[key].itp.knots, + Adapt.adapt( + ClimaComms.array_type(regridder.target_space), + itp.itp.knots, + ), + ) + k .= k_new + end + regridder._gpuitps[key].itp.coefs .= Adapt.adapt( + ClimaComms.array_type(regridder.target_space), + itp.itp.coefs, + ) + else + regridder._gpuitps[key] = + Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + end - # Move it to GPU (if needed) - gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + gpuitp = regridder._gpuitps[key] return map(regridder.coordinates) do coord gpuitp(totuple(coord)...) diff --git a/test/TestTools.jl b/test/TestTools.jl index 00490e39..d6a1284d 100644 --- a/test/TestTools.jl +++ b/test/TestTools.jl @@ -15,7 +15,15 @@ function make_spherical_space(FT; context = ClimaComms.context()) boundary_names = (:bottom, :top), ) vertmesh = ClimaCore.Meshes.IntervalMesh(vertdomain, nelems = zelem) - vert_center_space = ClimaCore.Spaces.CenterFiniteDifferenceSpace(vertmesh) + if pkgversion(ClimaCore) >= v"0.14.10" + vert_center_space = ClimaCore.Spaces.CenterFiniteDifferenceSpace( + ClimaComms.device(context), + vertmesh, + ) + else + vert_center_space = + ClimaCore.Spaces.CenterFiniteDifferenceSpace(vertmesh) + end horzdomain = ClimaCore.Domains.SphereDomain(radius) horzmesh = ClimaCore.Meshes.EquiangularCubedSphere(horzdomain, helem) diff --git a/test/data_handling.jl b/test/data_handling.jl index fd0f0f64..89b8019f 100644 --- a/test/data_handling.jl +++ b/test/data_handling.jl @@ -87,13 +87,10 @@ ClimaComms.init(context) target_space; regridder_type = :InterpolationsRegridder, file_reader_kwargs = (; preprocess_func = (data) -> 0.0 * data), - regridder_kwargs = (; - extrapolation_bc = (Intp.Flat(), Intp.Flat(), Intp.Flat()) - ), + regridder_kwargs = (; extrapolation_bc = (Intp.Flat(), Intp.Flat())), ) - @test data_handler.regridder.extrapolation_bc == - (Intp.Flat(), Intp.Flat(), Intp.Flat()) + @test data_handler.regridder.extrapolation_bc == (Intp.Flat(), Intp.Flat()) field = DataHandling.regridded_snapshot(data_handler) @test extrema(field) == (0.0, 0.0) end diff --git a/test/regridders.jl b/test/regridders.jl index 142a2416..1526e865 100644 --- a/test/regridders.jl +++ b/test/regridders.jl @@ -97,6 +97,17 @@ end extrapolation_bc, ) + # Test num_dimensions != length(extrapolation_bc) + @test_throws ErrorException Regridders.InterpolationsRegridder( + hv_center_space; + extrapolation_bc = ( + Interpolations.Periodic(), + Interpolations.Flat(), + Interpolations.Flat(), + Interpolations.Flat(), + ), + ) + regridded_lat = Regridders.regrid(reg_hv, data_lat3D, dimensions3D) regridded_lon = Regridders.regrid(reg_hv, data_lon3D, dimensions3D) regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D)