From 3a266c9bb3abaae6ec33ad596e1b8cfe33dd4993 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Sep 2025 13:14:07 -0400 Subject: [PATCH 1/6] feat: better support for Base.Generators --- test/basic.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 03778fa6c0..a63bda0765 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1598,3 +1598,20 @@ end x_ra = Reactant.to_rarray(x) @test @jit(clamp!(x_ra, 0.5, Inf32)) ≈ clamp!(x, 0.5, Inf32) end + +@testset "Base.Generator" begin end + +using Reactant + +points = [rand(Float32, 2), rand(Float32, 2)] +params = rand(Float32, 4, 2) +points_ra = Reactant.to_rarray(points) +params_ra = Reactant.to_rarray(params) + +function f_generator(points, params) + gen = (params * point for point in points) + @show typeof(gen) + return sum(gen) +end + +@code_hlo f_generator(points_ra, params_ra) From d76936e2e2963ea7b978bc945982a6aa4e4df66a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Sep 2025 16:00:35 -0400 Subject: [PATCH 2/6] feat: use traced_call when unrolling iterators and generators --- src/Overlay.jl | 5 ++++- src/Reactant.jl | 19 ++++++++++++++++--- src/TracedRArray.jl | 6 +++++- test/basic.jl | 22 +++++++++------------- 4 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index ec92063dd7..951ad3c7c1 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -156,7 +156,10 @@ end end @reactant_overlay @noinline function Base.mapreduce( - f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs... + f, + op, + A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator}; + kwargs..., ) if use_overlayed_version(A) return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...) diff --git a/src/Reactant.jl b/src/Reactant.jl index 25d62f6464..a66c550f6a 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -179,10 +179,23 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") use_overlayed_version(x) = false -use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is) +function use_overlayed_version(x::F) where {F<:Function} + return Base.inferencebarrier(any)( + use_overlayed_version, getfield.(Ref(x), fieldnames(F)) + ) +end +function use_overlayed_version(x::Base.Generator) + return use_overlayed_version(x.f) || use_overlayed_version(x.iter) +end +function use_overlayed_version(x::Base.Iterators.Zip) + return Base.inferencebarrier(any)(use_overlayed_version, x.is) +end use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr) -use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter) -use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter)) +use_overlayed_version(x::Vector) = Base.inferencebarrier(any)(use_overlayed_version, x) +use_overlayed_version(iter::Tuple) = Base.inferencebarrier(any)(use_overlayed_version, iter) +function use_overlayed_version(iter::NamedTuple) + return Base.inferencebarrier(any)(use_overlayed_version, values(iter)) +end use_overlayed_version(::TracedRArray) = true use_overlayed_version(::TracedRNumber) = true use_overlayed_version(::Number) = false diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 033329e6df..111dc04fd9 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -686,7 +686,7 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) elseif ElType == Any ElType = eltype(fn(map(first_scalar, bc.args)...)) end - @assert ElType != Any && ElType != Union{} + @assert ElType != Union{} && ElType != Any sim = similar(bc, ElType) return copyto!(sim, bc) end @@ -1527,6 +1527,10 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F} end end +function unwrapped_broadcast(f::F, x::Base.Generator) where {F} + return unwrapped_broadcast_with_iterate(f, Base.Generator(TracedCall(x.f), x.iter)) +end + unwrapped_broadcast(f::F, xs) where {F} = unwrapped_broadcast_with_iterate(f, xs) function unwrapped_broadcast_with_iterate(f::F, itr) where {F} diff --git a/test/basic.jl b/test/basic.jl index a63bda0765..54068a9eec 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1599,19 +1599,15 @@ end @test @jit(clamp!(x_ra, 0.5, Inf32)) ≈ clamp!(x, 0.5, Inf32) end -@testset "Base.Generator" begin end - -using Reactant +@testset "Base.Generator" begin + points = [rand(Float32, 2) for _ in 1:5] + params = rand(Float32, 4, 2) + points_ra = Reactant.to_rarray(points) + params_ra = Reactant.to_rarray(params) -points = [rand(Float32, 2), rand(Float32, 2)] -params = rand(Float32, 4, 2) -points_ra = Reactant.to_rarray(points) -params_ra = Reactant.to_rarray(params) + function f_generator(points, params) + return sum(params * point for point in points) + end -function f_generator(points, params) - gen = (params * point for point in points) - @show typeof(gen) - return sum(gen) + @code_hlo f_generator(points_ra, params_ra) end - -@code_hlo f_generator(points_ra, params_ra) From 8d4f416fd7c17afa49cf290e0a4faa674f3293a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 7 Sep 2025 19:59:42 -0400 Subject: [PATCH 3/6] fix: closure with call working --- src/Reactant.jl | 21 +++++++++++++-------- test/basic.jl | 4 ++-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index a66c550f6a..4b942ccf37 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -178,22 +178,27 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") -use_overlayed_version(x) = false -function use_overlayed_version(x::F) where {F<:Function} +Base.@nospecializeinfer use_overlayed_version(x) = false +Base.@nospecializeinfer function use_overlayed_version( + @nospecialize(x::F) +) where {F<:Function} return Base.inferencebarrier(any)( use_overlayed_version, getfield.(Ref(x), fieldnames(F)) ) end -function use_overlayed_version(x::Base.Generator) +Base.@nospecializeinfer function use_overlayed_version(@nospecialize(x::Base.Generator)) return use_overlayed_version(x.f) || use_overlayed_version(x.iter) end -function use_overlayed_version(x::Base.Iterators.Zip) +Base.@nospecializeinfer function use_overlayed_version(@nospecialize(x::Base.Iterators.Zip)) return Base.inferencebarrier(any)(use_overlayed_version, x.is) end -use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr) -use_overlayed_version(x::Vector) = Base.inferencebarrier(any)(use_overlayed_version, x) -use_overlayed_version(iter::Tuple) = Base.inferencebarrier(any)(use_overlayed_version, iter) -function use_overlayed_version(iter::NamedTuple) +Base.@nospecializeinfer use_overlayed_version(@nospecialize(x::Base.Iterators.Enumerate)) = + use_overlayed_version(x.itr) +Base.@nospecializeinfer use_overlayed_version(@nospecialize(x::Vector)) = + Base.inferencebarrier(any)(use_overlayed_version, x) +Base.@nospecializeinfer use_overlayed_version(@nospecialize(iter::Tuple)) = + Base.inferencebarrier(any)(use_overlayed_version, iter) +Base.@nospecializeinfer function use_overlayed_version(@nospecialize(iter::NamedTuple)) return Base.inferencebarrier(any)(use_overlayed_version, values(iter)) end use_overlayed_version(::TracedRArray) = true diff --git a/test/basic.jl b/test/basic.jl index 54068a9eec..c27e91d88b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1600,7 +1600,7 @@ end end @testset "Base.Generator" begin - points = [rand(Float32, 2) for _ in 1:5] + points = eachcol(rand(Float32, 2, 6)) params = rand(Float32, 4, 2) points_ra = Reactant.to_rarray(points) params_ra = Reactant.to_rarray(params) @@ -1609,5 +1609,5 @@ end return sum(params * point for point in points) end - @code_hlo f_generator(points_ra, params_ra) + @test @jit(f_generator(points_ra, params_ra)) ≈ f_generator(points, params) end From 9cd90127055fa5f15f39954b937bafda20eb528f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Sep 2025 16:26:45 -0400 Subject: [PATCH 4/6] fix: try removing nospecialize --- src/Reactant.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 4b942ccf37..d9b03e4ff6 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -178,27 +178,28 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") -Base.@nospecializeinfer use_overlayed_version(x) = false -Base.@nospecializeinfer function use_overlayed_version( - @nospecialize(x::F) -) where {F<:Function} +use_overlayed_version(x) = false +function use_overlayed_version(x::F) where {F<:Function} return Base.inferencebarrier(any)( use_overlayed_version, getfield.(Ref(x), fieldnames(F)) ) end -Base.@nospecializeinfer function use_overlayed_version(@nospecialize(x::Base.Generator)) +function use_overlayed_version(x::Base.Generator) return use_overlayed_version(x.f) || use_overlayed_version(x.iter) end -Base.@nospecializeinfer function use_overlayed_version(@nospecialize(x::Base.Iterators.Zip)) +function use_overlayed_version(x::Base.Iterators.Zip) return Base.inferencebarrier(any)(use_overlayed_version, x.is) end -Base.@nospecializeinfer use_overlayed_version(@nospecialize(x::Base.Iterators.Enumerate)) = - use_overlayed_version(x.itr) -Base.@nospecializeinfer use_overlayed_version(@nospecialize(x::Vector)) = - Base.inferencebarrier(any)(use_overlayed_version, x) -Base.@nospecializeinfer use_overlayed_version(@nospecialize(iter::Tuple)) = - Base.inferencebarrier(any)(use_overlayed_version, iter) -Base.@nospecializeinfer function use_overlayed_version(@nospecialize(iter::NamedTuple)) +function use_overlayed_version(x::Base.Iterators.Enumerate) + return use_overlayed_version(x.itr) +end +function use_overlayed_version(x::Vector) + return Base.inferencebarrier(any)(use_overlayed_version, x) +end +function use_overlayed_version(iter::Tuple) + return Base.inferencebarrier(any)(use_overlayed_version, iter) +end +function use_overlayed_version(iter::NamedTuple) return Base.inferencebarrier(any)(use_overlayed_version, values(iter)) end use_overlayed_version(::TracedRArray) = true From 8e700b69329720898065a46853b5b5f3a9b32144 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Sep 2025 22:48:03 -0400 Subject: [PATCH 5/6] fix: use a looped version of any to avoid inference issues --- src/Reactant.jl | 38 ++++++++++++++++---------------------- src/TracedRArray.jl | 2 +- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index d9b03e4ff6..9f373f1a05 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -180,33 +180,18 @@ include("ConcreteRArray.jl") use_overlayed_version(x) = false function use_overlayed_version(x::F) where {F<:Function} - return Base.inferencebarrier(any)( - use_overlayed_version, getfield.(Ref(x), fieldnames(F)) - ) -end -function use_overlayed_version(x::Base.Generator) - return use_overlayed_version(x.f) || use_overlayed_version(x.iter) -end -function use_overlayed_version(x::Base.Iterators.Zip) - return Base.inferencebarrier(any)(use_overlayed_version, x.is) -end -function use_overlayed_version(x::Base.Iterators.Enumerate) - return use_overlayed_version(x.itr) -end -function use_overlayed_version(x::Vector) - return Base.inferencebarrier(any)(use_overlayed_version, x) -end -function use_overlayed_version(iter::Tuple) - return Base.inferencebarrier(any)(use_overlayed_version, iter) -end -function use_overlayed_version(iter::NamedTuple) - return Base.inferencebarrier(any)(use_overlayed_version, values(iter)) + return use_overlayed_version(getfield.(Ref(x), fieldnames(F))) end +use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter)) +use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is) +use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr) +use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x) +use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter) +use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter)) use_overlayed_version(::TracedRArray) = true use_overlayed_version(::TracedRNumber) = true use_overlayed_version(::Number) = false use_overlayed_version(::MissingTracedValue) = true -use_overlayed_version(::Vector{<:AnyTracedRArray}) = true use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed) function use_overlayed_version(x::AbstractArray) @@ -215,7 +200,16 @@ function use_overlayed_version(x::AbstractArray) return use_overlayed_version(a) end +## We avoid calling into `any` to avoid triggering the `any` overlay +function looped_any(f::F, itr) where {F} + @inbounds for x in itr + f(x) && return true + end + return false +end + # StdLib Overloads + include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") include("stdlibs/Base.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 111dc04fd9..519115f05d 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -686,7 +686,7 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) elseif ElType == Any ElType = eltype(fn(map(first_scalar, bc.args)...)) end - @assert ElType != Union{} && ElType != Any + @assert ElType != Any && ElType != Union{} sim = similar(bc, ElType) return copyto!(sim, bc) end From 4f1524b202d2e5dcaf7c8c3040f590a1e4d1c78b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 13 Sep 2025 00:11:12 -0400 Subject: [PATCH 6/6] fix: dont overlay inside compiler call --- src/Reactant.jl | 2 +- src/TracedUtils.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 9f373f1a05..da18116d15 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -60,6 +60,7 @@ function _parent end _parent_type(::Type{Array}) = Array _parent_type(::Type{Array{T}}) where {T} = Array{T} _parent_type(::Type{Array{T,N}}) where {T,N} = Array{T,N} +_parent_type(::Type{<:Slices{P}}) where {P} = P include("accelerators/Accelerators.jl") @@ -209,7 +210,6 @@ function looped_any(f::F, itr) where {F} end # StdLib Overloads - include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") include("stdlibs/Base.jl") diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 554791af8f..98b83ea006 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -660,7 +660,7 @@ function finalize_mlir_fn( skipped_results = Reactant.TracedType[] for (k, v) in seen_results v isa Reactant.TracedType || continue - if any(Base.Fix1(===, k), skipped_args) + if Reactant.looped_any(Base.Fix1(===, k), skipped_args) push!(skipped_results, v) _, argpath = get_argidx(v, argprefix)