Skip to content

Commit

Permalink
Try #767:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Feb 23, 2023
2 parents 2507a7f + be284b1 commit 9944fd8
Show file tree
Hide file tree
Showing 18 changed files with 1,209 additions and 510 deletions.
9 changes: 9 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,15 @@ steps:
artifact_paths:
- "examples/sphere/output/cg_sphere_shallowwater_rossby_haurwitz_alpha0/*"

- label: ":flower_playing_cards: CUDA Shallow-water 2D sphere"
key: "cuda_shallowwater_2d_cg_sphere"
command:
- "julia --color=yes --project=examples examples/sphere/shallow_water_cuda.jl"
artifact_paths:
- "examples/sphere/output/cuda_shallowwater_2d_cg_sphere/*"
agents:
slurm_gpu: 1

- group: "Examples hybrid sphere"
steps:

Expand Down
70 changes: 70 additions & 0 deletions examples/sphere/shallow_water_cuda.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using CUDA
using ClimaComms
using DocStringExtensions
using LinearAlgebra
using OrdinaryDiffEq

CUDA.allowscalar(false)

import ClimaCore:
Device,
Expand Down Expand Up @@ -442,6 +446,48 @@ function set_initial_condition(
return Y
end

function rhs!(dYdt, y, parameters, t)
#@nvtx "rhs!" color = colorant"red" begin
(; f, h_s, ghost_buffer, params) = parameters
(; D₄, g) = params

div = Operators.Divergence()
wdiv = Operators.WeakDivergence()
grad = Operators.Gradient()
wgrad = Operators.WeakGradient()
curl = Operators.Curl()
wcurl = Operators.WeakCurl()

# Compute hyperviscosity first
#@nvtx "Hyperviscosity (rhs!)" color = colorant"green" begin
@. dYdt.h = wdiv(grad(y.h))
@. dYdt.u =
wgrad(div(y.u)) -
Geometry.Covariant12Vector(wcurl(Geometry.Covariant3Vector(curl(y.u))))

Spaces.weighted_dss2!(dYdt, ghost_buffer)

@. dYdt.h = -D₄ * wdiv(grad(dYdt.h))
# split to avoid "device kernel image is invalid (code 200, ERROR_INVALID_IMAGE)"
@. dYdt.u =
wgrad(div(dYdt.u)) - Geometry.Covariant12Vector(
wcurl(Geometry.Covariant3Vector(curl(dYdt.u))),
)
@. dYdt.u = -D₄ * dYdt.u
#end
#@nvtx "h and u (rhs!)" color = colorant"blue" begin
# Add in pieces
@. begin
dYdt.h += -wdiv(y.h * y.u)
dYdt.u += -grad(g * (y.h + h_s) + norm(y.u)^2 / 2) #+
dYdt.u += y.u × (f + curl(y.u))
end
Spaces.weighted_dss2!(dYdt, ghost_buffer)
#end
#end
return dYdt
end

function shallow_water_driver_cuda(ARGS, ::Type{FT}) where {FT}
device = Device.device()
context = ClimaComms.SingletonCommsContext(device)
Expand Down Expand Up @@ -478,11 +524,35 @@ function shallow_water_driver_cuda(ARGS, ::Type{FT}) where {FT}
f = set_coriolis_parameter(space, test)
h_s = surface_topography(space, test)
Y = set_initial_condition(space, test)
dYdt = similar(Y)

ghost_buffer = Spaces.create_dss_buffer(Y)
Spaces.weighted_dss_start!(Y, ghost_buffer)
Spaces.weighted_dss_internal!(Y, ghost_buffer)
Spaces.weighted_dss_ghost!(Y, ghost_buffer)

parameters =
(; f = f, h_s = h_s, ghost_buffer = ghost_buffer, params = test.params)
rhs!(dYdt, Y, parameters, 0.0)

#=
# Solve the ODE
dt = 9 * 60
T = 86400 * 2
prob = ODEProblem(rhs!, Y, (0.0, T), parameters)
integrator = OrdinaryDiffEq.init(
prob,
SSPRK33(),
dt = dt,
saveat = dt,
progress = true,
adaptive = false,
progress_message = (dt, u, p, t) -> t,
)
sol = @timev OrdinaryDiffEq.solve!(integrator)
=#
return nothing
end

Expand Down
122 changes: 106 additions & 16 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ indexes the underlying array as `[i,j,k,f,v,h]`
module DataLayouts

import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms

import ..enable_threading, ..slab, ..slab_args, ..column, ..column_args, ..level
Expand Down Expand Up @@ -296,6 +296,19 @@ function replace_basetype(data::IJFH{S, Nij}, ::Type{T}) where {S, Nij, T}
return IJFH{S′, Nij}(similar(array, T))
end

@inline function Base.getindex(data::IJFH{S}, i, j, _, _, h) where {S}
@inbounds get_struct(parent(data), S, Val(3), CartesianIndex(i, j, 1, h))
end
@inline function Base.setindex!(data::IJFH{S}, val, i, j, _, _, h) where {S}
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(3),
CartesianIndex(i, j, 1, h),
)
end


Base.length(data::IJFH) = size(parent(data), 4)

@generated function _property_view(
Expand Down Expand Up @@ -455,6 +468,18 @@ end
IFH{SS, Ni}(dataview)
end

@inline function Base.getindex(data::IFH{S}, i, _, _, _, h) where {S}
@inbounds get_struct(parent(data), S, Val(2), CartesianIndex(i, 1, h))
end
@inline function Base.setindex!(data::IFH{S}, val, i, _, _, _, h) where {S}
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(3),
CartesianIndex(i, 1, h),
)
end

# ======================
# Data0D DataLayout
# ======================
Expand Down Expand Up @@ -523,15 +548,20 @@ end
end

Base.@propagate_inbounds function Base.getindex(data::DataF{S}) where {S}
@inbounds get_struct(parent(data), S)
@inbounds get_struct(parent(data), S, Val(1), CartesianIndex(1))
end

@propagate_inbounds function Base.getindex(col::Data0D, I::CartesianIndex{5})
@inbounds col[]
end

Base.@propagate_inbounds function Base.setindex!(data::DataF{S}, val) where {S}
@inbounds set_struct!(parent(data), convert(S, val))
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(1),
CartesianIndex(1),
)
end

@propagate_inbounds function Base.setindex!(
Expand Down Expand Up @@ -598,6 +628,14 @@ end

rebuild(data::IJF{S, Nij}, array::A) where {S, Nij, A <: AbstractArray} =
IJF{S, Nij}(array)
function IJF{S, Nij}(::Type{MArray}, ::Type{T}) where {S, Nij, T}
Nf = typesize(T, S)
array = MArray{Tuple{Nij, Nij, Nf}, T, 3, Nij * Nij * Nf}(undef)
IJF{S, Nij}(array)
end
function SArray(ijf::IJF{S, Nij, <:MArray}) where {S, Nij}
IJF{S, Nij}(SArray(parent(ijf)))
end

function replace_basetype(data::IJF{S, Nij}, ::Type{T}) where {S, Nij, T}
array = parent(data)
Expand Down Expand Up @@ -643,11 +681,13 @@ end
data::IJF{S, Nij},
i::Integer,
j::Integer,
k = nothing,
v = nothing,
h = nothing,
) where {S, Nij}
@boundscheck (1 <= i <= Nij && 1 <= j <= Nij) ||
throw(BoundsError(data, (i, j)))
dataview = @inbounds view(parent(data), i, j, :)
@inbounds get_struct(dataview, S)
@inbounds get_struct(parent(data), S, Val(3), CartesianIndex(i, j, 1))
end

@inline function Base.setindex!(
Expand All @@ -658,8 +698,12 @@ end
) where {S, Nij}
@boundscheck (1 <= i <= Nij && 1 <= j <= Nij) ||
throw(BoundsError(data, (i, j)))
dataview = @inbounds view(parent(data), i, j, :)
set_struct!(dataview, convert(S, val))
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(3),
CartesianIndex(i, j, 1),
)
end

@inline function column(data::IJF{S, Nij}, i, j) where {S, Nij}
Expand Down Expand Up @@ -719,6 +763,14 @@ function IF{S, Ni}(array::AbstractArray{T, 2}) where {S, Ni, T}
check_basetype(T, S)
IF{S, Ni, typeof(array)}(array)
end
function IF{S, Ni}(::Type{MArray}, ::Type{T}) where {S, Ni, T}
Nf = typesize(T, S)
array = MArray{Tuple{Ni, Nf}, T, 2, Ni * Nf}(undef)
IF{S, Ni}(array)
end
function SArray(data::IF{S, Ni, <:MArray}) where {S, Ni}
IF{S, Ni}(SArray(parent(data)))
end

function replace_basetype(data::IF{S, Ni}, ::Type{T}) where {S, Ni, T}
array = parent(data)
Expand Down Expand Up @@ -756,16 +808,26 @@ end
IF{SS, Ni}(dataview)
end

@inline function Base.getindex(data::IF{S, Ni}, i::Integer) where {S, Ni}
@inline function Base.getindex(
data::IF{S, Ni},
i::Integer,
j = nothing,
k = nothing,
v = nothing,
h = nothing,
) where {S, Ni}
@boundscheck (1 <= i <= Ni) || throw(BoundsError(data, (i,)))
dataview = @inbounds view(parent(data), i, :)
@inbounds get_struct(dataview, S)
@inbounds get_struct(parent(data), S, Val(2), CartesianIndex(i, 1))
end

@inline function Base.setindex!(data::IF{S, Ni}, val, i::Integer) where {S, Ni}
@boundscheck (1 <= i <= Ni) || throw(BoundsError(data, (i,)))
dataview = @inbounds view(parent(data), i, :)
set_struct!(dataview, convert(S, val))
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(2),
CartesianIndex(i, 1),
)
end

@inline function column(data::IF{S, Ni}, i) where {S, Ni}
Expand Down Expand Up @@ -851,8 +913,7 @@ end
@inline function Base.getindex(data::VF{S}, v::Integer) where {S}
@boundscheck 1 <= v <= size(parent(data), 1) ||
throw(BoundsError(data, (v,)))
dataview = @inbounds view(parent(data), v, :)
@inbounds get_struct(dataview, S)
@inbounds get_struct(parent(data), S, Val(2), CartesianIndex(v, 1))
end

@propagate_inbounds function Base.getindex(
Expand All @@ -873,8 +934,12 @@ end
@inline function Base.setindex!(data::VF{S}, val, v::Integer) where {S}
@boundscheck (1 <= v <= length(parent(data))) ||
throw(BoundsError(data, (v,)))
dataview = @inbounds view(parent(data), v, :)
@inbounds set_struct!(dataview, convert(S, val))
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(2),
CartesianIndex(v, 1),
)
end

@inline function column(data::VF, i, h)
Expand Down Expand Up @@ -1039,6 +1104,19 @@ function gather(
nothing
end
end

@inline function Base.getindex(data::VIJFH{S}, i, j, _, v, h) where {S}
@inbounds get_struct(parent(data), S, Val(4), CartesianIndex(v, i, j, 1, h))
end
@inline function Base.setindex!(data::VIJFH{S}, val, i, j, _, v, h) where {S}
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(4),
CartesianIndex(v, i, j, 1, h),
)
end

# ======================
# Data1DX DataLayout
# ======================
Expand Down Expand Up @@ -1166,6 +1244,18 @@ end
IFH{S, Nij}(dataview)
end

@inline function Base.getindex(data::VIFH{S}, i, _, _, v, h) where {S}
@inbounds get_struct(parent(data), S, Val(3), CartesianIndex(v, i, 1, h))
end
@inline function Base.setindex!(data::VIFH{S}, val, i, _, _, v, h) where {S}
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(3),
CartesianIndex(v, i, 1, h),
)
end

# =========================================
# Special DataLayouts for regular gridding
# =========================================
Expand Down
Loading

0 comments on commit 9944fd8

Please sign in to comment.