Skip to content

Commit e82c420

Browse files
authored
fix: prevent method ambiguity for CartesianIndex{1} (#730)
1 parent 4ca5147 commit e82c420

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

src/TracedRArray.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,21 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {
158158
return Ops.gather_getindex(a, indices)[1]
159159
end
160160

161+
# Needed to prevent method ambiguity
162+
function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {T}
163+
indices =
164+
materialize_traced_array(
165+
reshape(
166+
TracedUtils.promote_to(
167+
TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...))
168+
),
169+
1,
170+
1,
171+
),
172+
) .- 1
173+
return Ops.gather_getindex(a, indices)[1]
174+
end
175+
161176
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
162177
indices = TracedUtils.normalize_indices(a, indices...)
163178

test/indexing.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,14 @@ end
326326
y = Array(y_ra)
327327
@test res[:, 1, :] view(y, :, 1:3)
328328
end
329+
330+
@testset "getindex ambiguity" begin
331+
x = collect(Float32, 1:8)
332+
x_ra = Reactant.to_rarray(x)
333+
334+
idx = CartesianIndex(1)
335+
336+
fn(x, idx) = @allowscalar x[idx]
337+
338+
@test @jit(fn(x_ra, idx)) fn(x, idx)
339+
end

0 commit comments

Comments
 (0)