Skip to content

Commit 38916f5

Browse files
authored
feat: add zero and fill! for ConcreteRArray (#420)
* feat: add zero and fill! for ConcreteRArray * test: add tests
1 parent 6571d54 commit 38916f5

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/ConcreteRArray.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
242242
return convert(Array, a)[args...]
243243
end
244244

245-
function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
245+
function mysetindex!(a, v, args::Vararg{Any,N}) where {N}
246246
setindex!(a, v, args...)
247247
return nothing
248248
end
@@ -353,3 +353,28 @@ end
353353
function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T}
354354
return Ops.constant(Base.convert(T, x); kwargs...)
355355
end
356+
357+
Base.zero(x::ConcreteRArray{T,N}) where {T,N} = ConcreteRArray(zeros(T, size(x)...))
358+
359+
function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N}
360+
if a.data == XLA.AsyncEmptyBuffer
361+
throw("Cannot setindex! to empty buffer")
362+
end
363+
364+
XLA.await(a.data)
365+
if buffer_on_cpu(a)
366+
buf = a.data.buffer
367+
GC.@preserve buf begin
368+
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
369+
for start in 1:length(a)
370+
unsafe_store!(ptr, val, start)
371+
end
372+
end
373+
return a
374+
end
375+
376+
idxs = ntuple(Returns(Colon()), N)
377+
fn = compile(mysetindex!, (a, val, idxs...))
378+
fn(a, val, idxs...)
379+
return a
380+
end

test/basic.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,3 +664,15 @@ end
664664
ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0))
665665
) isa ConcreteRNumber{Float32}
666666
end
667+
668+
@testset "fill! and zero on ConcreteRArray" begin
669+
x_ra = Reactant.to_rarray(rand(3, 4))
670+
671+
z = zero(x_ra)
672+
@test z isa ConcreteRArray
673+
@test size(z) == size(x_ra)
674+
@test all(iszero, Array(z))
675+
676+
fill!(z, 1.0)
677+
@test all(==(1.0), Array(z))
678+
end

0 commit comments

Comments
 (0)