Skip to content

Commit afbf199

Browse files
committed
test: use updated API for the tests
1 parent dab1255 commit afbf199

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

src/Compiler.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import ..Reactant:
1010
ConcreteRNumber,
1111
TracedRArray,
1212
TracedRNumber,
13+
RArray,
1314
OrderedIdDict,
1415
make_tracer,
1516
TracedToConcrete,
@@ -19,10 +20,13 @@ import ..Reactant:
1920
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
2021
@inline traced_getfield(@nospecialize(obj::AbstractArray), field) =
2122
Base.getindex(obj, field)
23+
@inline traced_getfield(@nospecialize(obj::RArray), field) = Base.getfield(obj, field)
2224

23-
@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, val, field)
25+
@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, field, val)
2426
@inline traced_setfield!(@nospecialize(obj::AbstractArray), field, val) =
2527
Base.setindex!(obj, val, field)
28+
@inline traced_setfield!(@nospecialize(obj::RArray), field, val) =
29+
Base.setfield!(obj, field, val)
2630

2731
function create_result(tocopy::T, path, result_stores) where {T}
2832
if !isstructtype(typeof(tocopy))

src/Tracing.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArra
1212
end
1313
end
1414

15-
function traced_type(::Type{T}, seen, mode::Val{Mode}, track_numbers) where {T<:Number,Mode}
15+
function traced_type(
16+
::Type{T}, seen, mode::Val{Mode}, track_numbers
17+
) where {T<:Union{AbstractFloat,Integer},Mode}
1618
if Mode == ArrayToConcrete && any(Base.Fix1(<:, T), track_numbers)
1719
return ConcreteRNumber{T}
1820
end

test/tracing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ using Test
9090
(Val{:x}, Val{:x}),
9191
]
9292
tracedty = traced_type(
93-
origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced)
93+
origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced), ()
9494
)
9595
@test tracedty == targetty
9696
end
@@ -102,21 +102,21 @@ using Test
102102
TracedRArray{Float64,3},
103103
]
104104
@test_throws Union{ErrorException,String} traced_type(
105-
type, Reactant.OrderedIdDict(), Val(ConcreteToTraced)
105+
type, Reactant.OrderedIdDict(), Val(ConcreteToTraced), ()
106106
)
107107
end
108108
end
109109
@testset "traced_type exceptions" begin
110110
@test_throws TracedTypeError Reactant.traced_type(
111-
Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete)
111+
Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete), ()
112112
)
113113

114114
struct Node
115115
x::Vector{Float64}
116116
y::Union{Nothing,Node}
117117
end
118118
@test_throws NoFieldMatchError traced_type(
119-
Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete)
119+
Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete), ()
120120
)
121121
end
122122
end

0 commit comments

Comments
 (0)