Skip to content

Commit b0bdce7

Browse files
committed
feat: add zero and fill! for ConcreteRArray
1 parent 6571d54 commit b0bdce7

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

src/ConcreteRArray.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
242242
return convert(Array, a)[args...]
243243
end
244244

245-
function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
245+
function mysetindex!(a, v, args::Vararg{Any,N}) where {N}
246246
setindex!(a, v, args...)
247247
return nothing
248248
end
@@ -353,3 +353,32 @@ end
353353
function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T}
354354
return Ops.constant(Base.convert(T, x); kwargs...)
355355
end
356+
357+
Base.zero(x::ConcreteRArray{T,N}) where {T,N} = ConcreteRArray(zeros(T, size(x)...))
358+
359+
function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N}
360+
if a.data == XLA.AsyncEmptyBuffer
361+
throw("Cannot setindex! to empty buffer")
362+
end
363+
364+
XLA.await(a.data)
365+
if buffer_on_cpu(a)
366+
buf = a.data.buffer
367+
GC.@preserve buf begin
368+
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
369+
start = 0
370+
for i in 1:N
371+
start *= size(a, N - i + 1)
372+
start += (args[N - i + 1] - 1)
373+
end
374+
start += 1
375+
unsafe_store!(ptr, val, start)
376+
end
377+
return a
378+
end
379+
380+
idxs = ntuple(Returns(Colon()), N)
381+
fn = compile(mysetindex!, (a, val, idxs...,))
382+
fn(a, val, idxs...)
383+
return a
384+
end

0 commit comments

Comments
 (0)