diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0d73b10936..0b51791381 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -730,3 +730,27 @@ end Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x) Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x) + +# outer repeat +function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M} + P = max(N, M) # potentially padded + + # (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1) + interleaved_size = ones(Int, 2P) + interleaved_size[1:2:2N] .= size(x) + + x_interleaved = reshape(x, interleaved_size...) + + # (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP) + broadcast_target_size = interleaved_size + broadcast_target_size[2:2:2M] .= counts + + x_broadcasted = broadcast_to_size(x_interleaved, broadcast_target_size) + + # (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP) + final_size = vec(prod(reshape(broadcast_target_size, 2, :), dims=1)) + + x_final = reshape(x_broadcasted, final_size...) + + return x_final +end diff --git a/test/basic.jl b/test/basic.jl index 3aef500201..9c246ac818 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -364,6 +364,16 @@ end end end +@testset "repeat" begin + @testset for (size, counts) in Iterators.product( + [(2,), (2,3), (2,3,4), (2,3,4,5)], + [(), (1,), (2,), (2,1), (1,2), (2,2), (2,2,2), (1,1,1,1,1)] + ) + x = rand(size...) + @test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...) + end +end + function update_on_copy(x) y = x[1:2, 2:4, :] y[1:1, 1:1, :] = ones(1, 1, 3)