Skip to content

Commit 53644c9

Browse files
authored
fix: inconsistent return dims (#558)
* fix: inconsistent return dims * test: inconsistent indexing * fix: inconsistent dimensions inside gather getindex
1 parent 532fd73 commit 53644c9

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

src/TracedRArray.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ function Base.getindex(
6565
end
6666

6767
Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)
68+
function Base.getindex(a::TracedRArray{T,0}, ::CartesianIndex{0}) where {T}
69+
return TracedRNumber{T}((), a.mlir_data)
70+
end
6871

6972
function generate_index_list(i1, is...)
7073
list = reshape(i1, :, 1) .- 1
@@ -123,7 +126,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {
123126
indices =
124127
materialize_traced_array(
125128
reshape(
126-
TracedUtils.promote_to(TracedRArray{Int,1}, vcat(Tuple(indices)...)), 1, N
129+
TracedUtils.promote_to(
130+
TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...))
131+
),
132+
1,
133+
N,
127134
),
128135
) .- 1
129136
return Ops.gather_getindex(a, indices)[1]
@@ -157,10 +164,19 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
157164
if any(i -> unwrapped_eltype(i) <: Bool, indices)
158165
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
159166
end
160-
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
167+
idxs = map(indices) do i
168+
i isa Number && return fill(i, 1)
169+
return i
170+
end
171+
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), idxs)
161172
indices_list = generate_index_list(indices_list...)
162173
res = Ops.gather_getindex(a, indices_list)
163-
return Ops.reshape(res, length.(indices)...)
174+
res = Ops.reshape(res, length.(idxs)...)
175+
ddims = findall(indices) do idx
176+
return idx isa Integer || idx isa TracedRNumber{<:Integer}
177+
end
178+
isempty(ddims) || return materialize_traced_array(dropdims(res; dims=Tuple(ddims)))
179+
return res
164180
end
165181

166182
start_indices = map(indices) do i
@@ -172,7 +188,9 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
172188
)
173189

174190
x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
175-
ddims = findall(Base.Fix2(isa, Integer), indices)
191+
ddims = findall(indices) do idx
192+
return idx isa Integer || idx isa TracedRNumber{<:Integer}
193+
end
176194
isempty(ddims) || return materialize_traced_array(dropdims(x; dims=Tuple(ddims)))
177195
return x
178196
end
@@ -306,6 +324,11 @@ TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArra
306324
function TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N}
307325
return TracedUtils.promote_to(TracedRArray{T,N}, rhs)
308326
end
327+
function TracedUtils.promote_to(
328+
::Type{TracedRArray{T,0}}, rhs::TracedRNumber{T2}
329+
) where {T,T2}
330+
return TracedRArray{T,0}((), Ops.convert(TracedRNumber{T}, rhs).mlir_data, ())
331+
end
309332

310333
for (jlop, hloop, hlocomp, merge) in
311334
((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any))

src/TracedUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ end
374374
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
375375
if all(iszero ndims, args)
376376
scalar_args = map(args) do arg
377-
return promote_to(TracedRNumber{eltype(arg)}, arg)
377+
return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg)
378378
end
379379
return f(scalar_args...)
380380
end

test/basic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,3 +1051,16 @@ end
10511051
@test isfinite(Reactant.to_rarray(0.0; track_numbers=Number))
10521052
@test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number))
10531053
end
1054+
1055+
@testset "inconsistent indexing" begin
1056+
x_ra = Reactant.to_rarray(rand(3, 4, 3))
1057+
idx_ra = Reactant.to_rarray(1; track_numbers=Number)
1058+
1059+
fn1(x) = x[:, :, 1]
1060+
fn2(x, idx) = x[:, :, idx]
1061+
fn3(x, idx) = x[idx, :, 1]
1062+
1063+
@test ndims(@jit(fn1(x_ra))) == 2
1064+
@test ndims(@jit(fn2(x_ra, idx_ra))) == 2
1065+
@test ndims(@jit(fn3(x_ra, idx_ra))) == 1
1066+
end

0 commit comments

Comments
 (0)