Skip to content

Commit f013943

Browse files
committed
feat: support Base.stack
1 parent 9bf471b commit f013943

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

src/Overlay.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ for (cT, aT, bT) in (
127127
@reactant_overlay @noinline function LinearAlgebra.mul!(
128128
C::$cT, A::$aT, B::$bT, α::Number, β::Number
129129
)
130-
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
130+
if use_overlayed_version((C, A, B))
131131
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
132132
else
133133
LinearAlgebra._mul!(C, A, B, α, β)
@@ -142,3 +142,12 @@ for (cT, aT, bT) in (
142142
end
143143
end
144144
end
145+
146+
# Base overloads
147+
@reactant_overlay @noinline function Base._stack(dims::Union{Integer,Colon}, iter)
148+
if use_overlayed_version(iter)
149+
return TracedRArrayOverrides.overloaded_stack(dims, iter)
150+
else
151+
return Base._stack(dims, Base.IteratorSize(iter), iter)
152+
end
153+
end

src/Reactant.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,20 @@ mutable struct TracedRNG <: Random.AbstractRNG
185185
const algorithm::String
186186
end
187187

188+
use_overlayed_version(iter) = any(use_overlayed_version, iter)
189+
190+
use_overlayed_version(::TracedRArray) = true
191+
use_overlayed_version(::TracedRNumber) = true
192+
use_overlayed_version(::Number) = false
193+
use_overlayed_version(::MissingTracedValue) = true
194+
use_overlayed_version(::TracedRNG) = true
195+
196+
function use_overlayed_version(x::AbstractArray)
197+
a = ancestor(x)
198+
a === x && return false
199+
return use_overlayed_version(a)
200+
end
201+
188202
# StdLib Overloads
189203
include("stdlibs/LinearAlgebra.jl")
190204
include("stdlibs/Random.jl")

src/TracedRArray.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,4 +605,17 @@ function Base._RepeatInnerOuter.repeat_inner(
605605
return materialize_traced_array(reshape(x_broadcasted, final_size...))
606606
end
607607

608+
# stack
609+
function overloaded_stack(dims::Union{Integer,Colon}, xs)
610+
@assert allequal(ndims, xs) "All arrays must have the same number of dimensions..."
611+
dims = dims isa Colon ? ndims(first(xs)) + 1 : dims
612+
res = map(xs) do x
613+
new_shape = ntuple(
614+
i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1
615+
)
616+
return materialize_traced_array(reshape(x, new_shape))
617+
end
618+
return cat(res...; dims)
619+
end
620+
608621
end

src/stdlibs/LinearAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T}
269269
# <unknown>:0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64>
270270
length(indices) 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[])
271271

272-
return Ops.gather_getindex(x, TracedUtils.promote_to(TracedRArray{Int,2}, indices))
272+
return Ops.gather_getindex(y, TracedUtils.promote_to(TracedRArray{Int,2}, indices))
273273
end
274274

275275
function LinearAlgebra._diagm(

0 commit comments

Comments
 (0)