@@ -65,6 +65,9 @@ function Base.getindex(
6565end
6666
6767Base. getindex (a:: TracedRArray{T,0} ) where {T} = TracedRNumber {T} ((), a. mlir_data)
68+ function Base. getindex (a:: TracedRArray{T,0} , :: CartesianIndex{0} ) where {T}
69+ return TracedRNumber {T} ((), a. mlir_data)
70+ end
6871
6972function generate_index_list (i1, is... )
7073 list = reshape (i1, :, 1 ) .- 1
@@ -123,7 +126,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {
123126 indices =
124127 materialize_traced_array (
125128 reshape (
126- TracedUtils. promote_to (TracedRArray{Int,1 }, vcat (Tuple (indices)... )), 1 , N
129+ TracedUtils. promote_to (
130+ TracedRArray{Int,1 }, collect (Int64, vcat (Tuple (indices)... ))
131+ ),
132+ 1 ,
133+ N,
127134 ),
128135 ) .- 1
129136 return Ops. gather_getindex (a, indices)[1 ]
@@ -157,10 +164,19 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
157164 if any (i -> unwrapped_eltype (i) <: Bool , indices)
158165 error (" Boolean indexing with TracedRArrays isn't fully supported yet." )
159166 end
160- indices_list = map (Base. Fix1 (TracedUtils. promote_to, TracedRArray{Int,1 }), indices)
167+ idxs = map (indices) do i
168+ i isa Number && return fill (i, 1 )
169+ return i
170+ end
171+ indices_list = map (Base. Fix1 (TracedUtils. promote_to, TracedRArray{Int,1 }), idxs)
161172 indices_list = generate_index_list (indices_list... )
162173 res = Ops. gather_getindex (a, indices_list)
163- return Ops. reshape (res, length .(indices)... )
174+ res = Ops. reshape (res, length .(idxs)... )
175+ ddims = findall (indices) do idx
176+ return idx isa Integer || idx isa TracedRNumber{<: Integer }
177+ end
178+ isempty (ddims) || return materialize_traced_array (dropdims (res; dims= Tuple (ddims)))
179+ return res
164180 end
165181
166182 start_indices = map (indices) do i
@@ -172,7 +188,9 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
172188 )
173189
174190 x = TracedRArray {T,N} ((), res, Tuple (length .(indices)))
175- ddims = findall (Base. Fix2 (isa, Integer), indices)
191+ ddims = findall (indices) do idx
192+ return idx isa Integer || idx isa TracedRNumber{<: Integer }
193+ end
176194 isempty (ddims) || return materialize_traced_array (dropdims (x; dims= Tuple (ddims)))
177195 return x
178196end
@@ -306,6 +324,11 @@ TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArra
306324function TracedUtils. promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
307325 return TracedUtils. promote_to (TracedRArray{T,N}, rhs)
308326end
327+ function TracedUtils. promote_to (
328+ :: Type{TracedRArray{T,0}} , rhs:: TracedRNumber{T2}
329+ ) where {T,T2}
330+ return TracedRArray {T,0} ((), Ops. convert (TracedRNumber{T}, rhs). mlir_data, ())
331+ end
309332
310333for (jlop, hloop, hlocomp, merge) in
311334 ((:(Base.:(== )), :compare , " EQ" , :all ), (:(Base.:(!= )), :compare , " NE" , :any ))
0 commit comments