diff --git a/src/Tracing.jl b/src/Tracing.jl index 0c3cec5088..d0668913c9 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -55,7 +55,7 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(track_numbers::Type), @nospecialize(sharding) ) - if Mode == ArrayToConcrete && T <: track_numbers + if mode == ArrayToConcrete && T <: track_numbers return ConcretePJRTNumber{ T,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), 0) } diff --git a/test/tracing.jl b/test/tracing.jl index 59c5eeb30b..2990b4f06c 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -15,6 +15,17 @@ struct Wrapper{A,B} b::B end +struct Descent{T} + eta::T +end + +struct RMSProp{Teta,Trho,Teps,C<:Bool} + eta::Teta + rho::Trho + epsilon::Teps + centred::C +end + @testset "Tracing" begin @testset "trace_type" begin @testset "mode = ConcreteToTraced" begin @@ -242,4 +253,17 @@ end st_traced = Reactant.to_rarray(st; track_numbers=Number) @test st_traced.training isa Val{true} end + + @testset "to_rarray(::AbstractRule)" begin + opt = Descent(0.1) + opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) + @test opt_traced.eta isa ConcreteRNumber{Float64} + + opt = RMSProp(0.1, 0.9, 1e-8, true) + opt_traced = Reactant.to_rarray(opt; track_numbers=AbstractFloat) + @test opt_traced.eta isa ConcreteRNumber{Float64} + @test opt_traced.rho isa ConcreteRNumber{Float64} + @test opt_traced.epsilon isa ConcreteRNumber{Float64} + @test opt_traced.centred isa Bool + end end