@@ -242,7 +242,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
242242 return convert (Array, a)[args... ]
243243end
244244
245- function mysetindex! (a, v, args:: Vararg{Int ,N} ) where {N}
245+ function mysetindex! (a, v, args:: Vararg{Any ,N} ) where {N}
246246 setindex! (a, v, args... )
247247 return nothing
248248end
353353function Ops. constant (x:: ConcreteRNumber{T} ; kwargs... ) where {T}
354354 return Ops. constant (Base. convert (T, x); kwargs... )
355355end
356+
357+ Base. zero (x:: ConcreteRArray{T,N} ) where {T,N} = ConcreteRArray (zeros (T, size (x)... ))
358+
359+ function Base. fill! (a:: ConcreteRArray{T,N} , val) where {T,N}
360+ if a. data == XLA. AsyncEmptyBuffer
361+ throw (" Cannot setindex! to empty buffer" )
362+ end
363+
364+ XLA. await (a. data)
365+ if buffer_on_cpu (a)
366+ buf = a. data. buffer
367+ GC. @preserve buf begin
368+ ptr = Base. unsafe_convert (Ptr{T}, XLA. UnsafeBufferPointer (buf))
369+ start = 0
370+ for i in 1 : N
371+ start *= size (a, N - i + 1 )
372+ start += (args[N - i + 1 ] - 1 )
373+ end
374+ start += 1
375+ unsafe_store! (ptr, val, start)
376+ end
377+ return a
378+ end
379+
380+ idxs = ntuple (Returns (Colon ()), N)
381+ fn = compile (mysetindex!, (a, val, idxs... ,))
382+ fn (a, val, idxs... )
383+ return a
384+ end
0 commit comments