@@ -234,15 +234,36 @@ function Base.setindex!(
234234 return a
235235end
236236
237+ # Avoid ambiguity
238+ function Base. setindex! (
239+ a:: TracedRArray{T,1} , v, indices:: Union{Int,TracedRNumber{Int}}
240+ ) where {T}
241+ GPUArraysCore. assertscalar (
242+ " setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})"
243+ )
244+ if indices isa Int
245+ indices = TracedUtils. promote_to (TracedRNumber{Int}, indices)
246+ end
247+ indices = scalar_index_to_cartesian (
248+ TracedUtils. broadcast_to_size (indices, (1 ,)), size (a)
249+ )
250+ v = v isa Number ? v : vec (v)
251+ res = Ops. scatter_setindex (a, indices, TracedUtils. broadcast_to_size (v, (1 ,)))
252+ set_mlir_data! (a, get_mlir_data (res))
253+ return a
254+ end
255+
237256function Base. setindex! (a:: TracedRArray{T,N} , v, indices) where {T,N}
238257 if ! (indices isa TracedRArray)
239258 indices = collect (indices)
240259 eltype (indices) <: CartesianIndex && (indices = LinearIndices (size (a))[indices])
241260 indices = TracedUtils. promote_to (TracedRArray{Int,ndims (indices)}, indices)
242261 end
243- v = v isa Number ? v : vec (v)
244- v = materialize_traced_array (TracedUtils. broadcast_to_size (v, size (a)))
245- res = Ops. scatter_setindex (a, scalar_index_to_cartesian (vec (indices), size (a)), v)
262+ res = Ops. scatter_setindex (
263+ a,
264+ scalar_index_to_cartesian (vec (indices), size (a)),
265+ materialize_traced_array (vec (v)),
266+ )
246267 set_mlir_data! (a, get_mlir_data (res))
247268 return a
248269end
@@ -299,7 +320,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
299320 indices_list = map (Base. Fix1 (TracedUtils. promote_to, TracedRArray{Int,1 }), indices)
300321 indices_list = generate_index_list (indices_list... )
301322 res = Ops. scatter_setindex (a, indices_list, Ops. reshape (v, length (v)))
302- a . mlir_data = res. mlir_data
323+ set_mlir_data! (a, get_mlir_data ( res))
303324 return v
304325 end
305326
@@ -330,7 +351,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
330351 ),
331352 1 ,
332353 )
333- a . mlir_data = res
354+ set_mlir_data! (a, res)
334355 return v
335356end
336357
0 commit comments