Skip to content

Commit a57d94f

Browse files
committed
feat: split out non-generator changes from #1642
1 parent 1149cb0 commit a57d94f

File tree

5 files changed

+62
-13
lines changed

5 files changed

+62
-13
lines changed

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ use_overlayed_version(::TracedRArray) = true
187187
use_overlayed_version(::TracedRNumber) = true
188188
use_overlayed_version(::Number) = false
189189
use_overlayed_version(::MissingTracedValue) = true
190+
use_overlayed_version(::Vector{<:AnyTracedRArray}) = true
190191
use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true
191192
use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed)
192193
function use_overlayed_version(x::AbstractArray)

src/TracedRArray.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,29 +1500,53 @@ end
15001500

15011501
struct BroadcastIterator{F}
15021502
f::F
1503+
1504+
BroadcastIterator{F}(f::F) where {F} = new{F}(f)
1505+
BroadcastIterator(f::F) where {F} = new{F}(f)
15031506
end
15041507

1505-
(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
1508+
(fn::BroadcastIterator)(args...) = fn.f((args...,))
15061509

15071510
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
15081511
min_length = Base.inferencebarrier(minimum)(length, x.is)
15091512
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
15101513
if any(Base.Fix2(isa, AnyTracedRArray), itrs)
1511-
return (BroadcastIterator(f)).(itrs...)
1514+
return broadcast(BroadcastIterator(f), itrs...)
15121515
else
1513-
fn = BroadcastIterator(f)
1514-
return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
1516+
return unwrapped_broadcast_with_iterate(f, x)
15151517
end
15161518
end
15171519

15181520
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
15191521
if x.itr isa AnyTracedRArray
1520-
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
1522+
return broadcast(
1523+
BroadcastIterator(f), Reactant.promote_to(TracedRArray, 1:length(x.itr)), x.itr
1524+
)
15211525
else
1522-
return [f((i, x.itr[i])) for i in 1:length(x.itr)]
1526+
return unwrapped_broadcast_with_iterate(f, x)
15231527
end
15241528
end
15251529

1526-
unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
1530+
unwrapped_broadcast(f::F, xs) where {F} = unwrapped_broadcast_with_iterate(f, xs)
1531+
1532+
function unwrapped_broadcast_with_iterate(f::F, itr) where {F}
1533+
y = Reactant.call_with_reactant(iterate, itr)
1534+
y === nothing && return []
1535+
1536+
first, state = y
1537+
res_first = @opcall call(f, first)
1538+
result = [res_first]
1539+
1540+
while true
1541+
y = Reactant.call_with_reactant(iterate, itr, state)
1542+
y === nothing && break
1543+
1544+
val, state = y
1545+
res = @opcall call(f, val)
1546+
push!(result, res)
1547+
end
1548+
1549+
return result
1550+
end
15271551

15281552
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct CallWithReactant{F}
1+
struct CallWithReactant{F} <: Function
22
f::F
33
end
44

test/autodiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ end
132132

133133
@testset "Forward Gradient" begin
134134
x = Reactant.to_rarray(3.1 * ones(2, 2))
135-
res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x)
135+
res = @jit gw(x)
136136
# TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132
137137
# to make sure this gets merged as a tracedrarray
138138
@test res isa Tuple{<:Enzyme.TupleArray{<:ConcreteRNumber{Float64},(2, 2),4,2}}

test/basic.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ end
926926

927927
ra = Reactant.to_rarray(x)
928928
@jit dip!(ra)
929-
ra[:a] (2.7 * 2) * ones(4)
929+
@test ra[:a] (2.7 * 3.1) * ones(4)
930930
end
931931

932932
@testset "@code_xla" begin
@@ -1429,7 +1429,10 @@ end
14291429
end
14301430

14311431
zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
1432+
zip_iterator2(a, b) = mapreduce(splat(.-), +, zip(a, b))
14321433
enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
1434+
enumerate_iterator2(a) = mapreduce(splat(.-), +, enumerate(a))
1435+
mapreduce_vector(a) = mapreduce(-, +, a)
14331436

14341437
function nested_mapreduce_zip(x, y)
14351438
return mapreduce(+, zip(eachcol(x), eachcol(y)); init=0.0f0) do (x, y)
@@ -1448,18 +1451,30 @@ end
14481451
@testset "Base.Iterators" begin
14491452
@testset "zip" begin
14501453
N = 10
1451-
a = range(1.0, 5.0; length=N)
1452-
x = range(10.0, 15.0; length=N + 2)
1454+
a = collect(range(1.0, 5.0; length=N))
1455+
x = collect(range(10.0, 15.0; length=N + 2))
14531456
x_ra = Reactant.to_rarray(x)
14541457

14551458
@test @jit(zip_iterator(a, x_ra)) zip_iterator(a, x)
1459+
1460+
a = [rand(Float32, 2, 3) for _ in 1:10]
1461+
x = [rand(Float32, 2, 3) for _ in 1:10]
1462+
a_ra = Reactant.to_rarray(a)
1463+
x_ra = Reactant.to_rarray(x)
1464+
1465+
@test @jit(zip_iterator2(a_ra, x_ra)) zip_iterator2(a, x)
14561466
end
14571467

14581468
@testset "enumerate" begin
1459-
x = range(1.0, 5.0; length=10)
1469+
x = collect(range(1.0, 5.0; length=10))
14601470
x_ra = Reactant.to_rarray(x)
14611471

14621472
@test @jit(enumerate_iterator(x_ra)) enumerate_iterator(x)
1473+
1474+
x = [rand(Float32, 2, 3) for _ in 1:10]
1475+
x_ra = Reactant.to_rarray(x)
1476+
1477+
@test @jit(enumerate_iterator2(x_ra)) enumerate_iterator2(x)
14631478
end
14641479

14651480
@testset "nested mapreduce" begin
@@ -1481,6 +1496,15 @@ end
14811496

14821497
@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) nested_mapreduce_hcat(x, y)
14831498
end
1499+
1500+
@testset "mapreduce vector" begin
1501+
x = [rand(Float32, 2, 3) for _ in 1:10]
1502+
x_ra = Reactant.to_rarray(x)
1503+
1504+
@test @jit(mapreduce_vector(x_ra)) mapreduce_vector(x)
1505+
hlo = repr(@code_hlo optimize = false mapreduce_vector(x_ra))
1506+
@test contains(hlo, "call")
1507+
end
14841508
end
14851509

14861510
@testset "compilation cache" begin

0 commit comments

Comments
 (0)