Skip to content

Commit

Permalink
Fix GPU inference (2065)
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Nov 1, 2024
1 parent 3e95814 commit 908b717
Showing 2 changed files with 16 additions and 14 deletions.
17 changes: 15 additions & 2 deletions src/DataLayouts/non_extruded_broadcasted.jl
Original file line number Diff line number Diff line change
@@ -140,8 +140,21 @@ Base.@propagate_inbounds function _broadcast_getindex(
end
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
f(args...) # not propagate_inbounds
Base.@propagate_inbounds _getindex(args::Tuple, I) =
(_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)


# To fix https://github.com/CliMA/ClimaCore.jl/issues/2065:
# Base.@propagate_inbounds _getindex(args::Tuple, I) =
# (_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
tuple_length(::Type{T}) where {T <: Tuple} = length(T.parameters)
@generated function _getindex(args::T, I) where {T}
quote
Base.Cartesian.@ntuple $(tuple_length(T)) ξ -> begin
_broadcast_getindex(args[ξ], I)
end
end
end


Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) =
(_broadcast_getindex(args[1], I),)
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
13 changes: 1 addition & 12 deletions test/Fields/inference_repro.jl
Original file line number Diff line number Diff line change
@@ -64,16 +64,5 @@ end

using Test
@testset "GPU inference failure" begin
if ClimaComms.device() isa ClimaComms.CUDADevice
@test_broken try
main(Float64)
true
catch e
@assert occursin("GPUCompiler.InvalidIRError", string(e))
@assert occursin("dynamic function invocation", e.errors[1][1])
false
end
else
main(Float64)
end
main(Float64)
end

0 comments on commit 908b717

Please sign in to comment.