Skip to content

Commit

Permalink
Debug dist remap
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 14, 2025
1 parent 0ab0602 commit 1b70088
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 0 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,6 @@ steps:
- label: "Unit: distributed remapping with CUDA (1 process)"
key: distributed_remapping_gpu_1proc
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_DEVICE: "CUDA"
env:
CLIMACOMMS_DEVICE: "CUDA"
agents:
Expand Down
12 changes: 6 additions & 6 deletions ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,22 @@ function set_interpolated_values_kernel!(
num_vert = length(vert_bounding_indices)
num_fields = length(field_values)

(size(out, 1) == num_horiz) || error("Incorrect input size")
(size(out, 2) == num_vert) || error("Incorrect input size")
(size(out, 3) == num_fields) || error("Incorrect input size")

hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z

totalThreadsX = gridDim().x * blockDim().x
totalThreadsY = gridDim().y * blockDim().y
totalThreadsZ = gridDim().z * blockDim().z

_, Nq = size(I1)
CI = CartesianIndex
for i in hindex:totalThreadsX:num_horiz
h = local_horiz_indices[i]
for j in vindex:totalThreadsY:num_vert
for j in 1:num_vert
v_lo, v_hi = vert_bounding_indices[j]
A, B = vert_interpolation_weights[j]
for k in findex:totalThreadsZ:num_fields
for k in 1:num_fields
if i num_horiz && j num_vert && k num_fields
out[i, j, k] = 0
for t in 1:Nq, s in 1:Nq
Expand Down
4 changes: 4 additions & 0 deletions test/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ClimaCore:
Operators,
Spaces,
Quadratures,
DataLayouts,
Topologies,
Remapping,
Hypsography
Expand Down Expand Up @@ -531,6 +532,9 @@ end
[sind(y) for x in longpts, y in latpts, z in zpts] rtol = 0.01
end

z_fv = Fields.field_values(coords.z)
@show z_fv
@show DataLayouts.universal_size(z_fv)
interp_z = Remapping.interpolate(remapper, coords.z)
expected_z = [z for x in longpts, y in latpts, z in zpts]
if ClimaComms.iamroot(context)
Expand Down

0 comments on commit 1b70088

Please sign in to comment.