diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f95139ed68..6105bca782 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -158,6 +158,21 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where { return Ops.gather_getindex(a, indices)[1] end +# Needed to prevent method ambiguity +function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {T} + indices = + materialize_traced_array( + reshape( + TracedUtils.promote_to( + TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...)) + ), + 1, + 1, + ), + ) .- 1 + return Ops.gather_getindex(a, indices)[1] +end + function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} indices = TracedUtils.normalize_indices(a, indices...) diff --git a/test/indexing.jl b/test/indexing.jl index c72a18bd11..5d6aec3412 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -326,3 +326,14 @@ end y = Array(y_ra) @test res[:, 1, :] ≈ view(y, :, 1:3) end + +@testset "getindex ambiguity" begin + x = collect(Float32, 1:8) + x_ra = Reactant.to_rarray(x) + + idx = CartesianIndex(1) + + fn(x, idx) = @allowscalar x[idx] + + @test @jit(fn(x_ra, idx)) ≈ fn(x, idx) +end