Skip to content

Commit 7b2bf0d

Browse files
committed
test: use updated API for the tests
1 parent dab1255 commit 7b2bf0d

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

src/Compiler.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@ import ..Reactant:
1919
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
2020
@inline traced_getfield(@nospecialize(obj::AbstractArray), field) =
2121
Base.getindex(obj, field)
22+
@inline traced_getfield(@nospecialize(obj::RArray), field) = Base.getfield(obj, field)
2223

2324
@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, val, field)
2425
@inline traced_setfield!(@nospecialize(obj::AbstractArray), field, val) =
2526
Base.setindex!(obj, val, field)
27+
@inline traced_setfield!(@nospecialize(obj::RArray), field, val) =
28+
Base.setfield!(obj, val, field)
2629

2730
function create_result(tocopy::T, path, result_stores) where {T}
2831
if !isstructtype(typeof(tocopy))

src/Tracing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArra
1212
end
1313
end
1414

15-
function traced_type(::Type{T}, seen, mode::Val{Mode}, track_numbers) where {T<:Number,Mode}
15+
function traced_type(
16+
::Type{T}, seen, mode::Val{Mode}, track_numbers
17+
) where {T<:Union{AbstractFloat,Integer},Mode}
1618
if Mode == ArrayToConcrete && any(Base.Fix1(<:, T), track_numbers)
1719
return ConcreteRNumber{T}
1820
end

test/tracing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ using Test
9090
(Val{:x}, Val{:x}),
9191
]
9292
tracedty = traced_type(
93-
origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced)
93+
origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced), ()
9494
)
9595
@test tracedty == targetty
9696
end
@@ -102,21 +102,21 @@ using Test
102102
TracedRArray{Float64,3},
103103
]
104104
@test_throws Union{ErrorException,String} traced_type(
105-
type, Reactant.OrderedIdDict(), Val(ConcreteToTraced)
105+
type, Reactant.OrderedIdDict(), Val(ConcreteToTraced), ()
106106
)
107107
end
108108
end
109109
@testset "traced_type exceptions" begin
110110
@test_throws TracedTypeError Reactant.traced_type(
111-
Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete)
111+
Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete), ()
112112
)
113113

114114
struct Node
115115
x::Vector{Float64}
116116
y::Union{Nothing,Node}
117117
end
118118
@test_throws NoFieldMatchError traced_type(
119-
Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete)
119+
Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete), ()
120120
)
121121
end
122122
end

0 commit comments

Comments
 (0)