66 TracedSetPath = 5
77end
88
9- for T in (
10- DataType,
11- Module,
12- Nothing,
13- Symbol,
14- AbstractChar,
15- AbstractFloat,
16- Integer,
17- AbstractString,
18- RArray,
19- RNumber,
20- )
21- @eval function traced_type (:: Type{T} , seen, mode) where {T<: $T }
9+ for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArray, RNumber)
10+ @eval function traced_type (:: Type{T} , seen, mode, track_numbers) where {T<: $T }
2211 return T
2312 end
2413end
2514
26- function traced_type (:: Type{C} , seen:: ST , mode:: Val{Mode} ) where {T,C<: Complex{T} ,ST,Mode}
15+ function traced_type (:: Type{T} , seen, mode:: Val{Mode} , track_numbers) where {T<: Number ,Mode}
16+ if Mode == ArrayToConcrete && any (Base. Fix1 (< :, T), track_numbers)
17+ return ConcreteRNumber{T}
18+ end
19+ return T
20+ end
21+
22+ function traced_type (
23+ :: Type{C} , seen:: ST , mode:: Val{Mode} , track_numbers:: TN
24+ ) where {T,C<: Complex{T} ,ST,Mode,TN}
2725 if ! (C isa UnionAll)
28- return Complex{traced_type (T, seen, mode)}
26+ return Complex{traced_type (T, seen, mode, track_numbers )}
2927 else
30- return @invoke traced_type (C:: Type{Any} , seen:: ST , mode:: Val{Mode} )
28+ return @invoke traced_type (
29+ C:: Type{Any} , seen:: ST , mode:: Val{Mode} , track_numbers:: TN
30+ )
3131 end
3232end
3333
34- function traced_type (:: Type{T} , seen, mode) where {T<: Function }
34+ function traced_type (:: Type{T} , seen, mode, track_numbers ) where {T<: Function }
3535 # functions are directly returned
3636 if sizeof (T) == 0
3737 return T
@@ -41,7 +41,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function}
4141 N = fieldcount (T)
4242 changed = false
4343 traced_fieldtypes = ntuple (Val (N)) do i
44- next = traced_type (fieldtype (T, i), seen, mode)
44+ next = traced_type (fieldtype (T, i), seen, mode, track_numbers )
4545 changed |= next != fieldtype (T, i)
4646 next
4747 end
5757@inline is_concrete_tuple (x:: T2 ) where {T2} =
5858 (x <: Tuple ) && ! (x === Tuple) && ! (x isa UnionAll)
5959
60- function traced_type (:: Type{T} , seen, mode) where {T<: Tuple }
60+ function traced_type (:: Type{T} , seen, mode, track_numbers ) where {T<: Tuple }
6161 if ! Base. isconcretetype (T) || ! is_concrete_tuple (T) || T isa UnionAll
6262 throw (AssertionError (" Type $T is not concrete type or concrete tuple" ))
6363 elseif is_concrete_tuple (T) && any (T2 isa Core. TypeofVararg for T2 in T. parameters)
6464 # Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...}
6565 throw (AssertionError (" Type tuple of vararg $T is not supported" ))
6666 end
67- TT = [traced_type (T. parameters[i], seen, mode) for i in 1 : length (T. parameters)]
67+ TT = [
68+ traced_type (T. parameters[i], seen, mode, track_numbers) for
69+ i in 1 : length (T. parameters)
70+ ]
6871 return Tuple{TT... }
6972end
7073
71- function traced_type (:: Type{T} , seen, mode) where {N,V,T<: NamedTuple{N,V} }
72- return NamedTuple{N,traced_type (V, seen, mode)}
74+ function traced_type (:: Type{T} , seen, mode, track_numbers ) where {N,V,T<: NamedTuple{N,V} }
75+ return NamedTuple{N,traced_type (V, seen, mode, track_numbers )}
7376end
7477
75- function traced_type (:: Type{T} , seen, mode) where {K,V,T<: AbstractDict{K,V} }
78+ function traced_type (:: Type{T} , seen, mode, track_numbers ) where {K,V,T<: AbstractDict{K,V} }
7679 dictty = T. name. wrapper
77- return dictty{K,traced_type (V, seen, mode)}
80+ return dictty{K,traced_type (V, seen, mode, track_numbers )}
7881end
7982
8083@inline getmap (:: Val{T} ) where {T} = nothing
8184@inline getmap (:: Val{T} , a, b, args... ) where {T} = getmap (Val (T), args... )
8285@inline getmap (:: Val{T} , :: Val{T} , :: Val{T2} , args... ) where {T,T2} = T2
8386
84- function traced_type (:: Type{T} , seen, mode) where {T}
87+ function traced_type (:: Type{T} , seen, mode, track_numbers ) where {T}
8588 if T === Any
8689 return T
8790 end
@@ -110,7 +113,10 @@ function traced_type(::Type{T}, seen, mode) where {T}
110113 end
111114
112115 if T isa Union
113- return Union{traced_type (T. a, seen, mode),traced_type (T. b, seen, mode)}
116+ return Union{
117+ traced_type (T. a, seen, mode, track_numbers),
118+ traced_type (T. b, seen, mode, track_numbers),
119+ }
114120 end
115121
116122 # if abstract it must be by reference
@@ -133,7 +139,7 @@ function traced_type(::Type{T}, seen, mode) where {T}
133139 subTys = Type[]
134140 for f in 1 : fieldcount (T)
135141 subT = fieldtype (T, f)
136- subTT = traced_type (subT, seen2, mode)
142+ subTT = traced_type (subT, seen2, mode, track_numbers )
137143 changed |= subT != subTT
138144 push! (subTys, subTT)
139145 end
@@ -145,7 +151,7 @@ function traced_type(::Type{T}, seen, mode) where {T}
145151 subParms = []
146152 for SST in T. parameters
147153 if SST isa Type
148- TrT = traced_type (SST, seen, mode)
154+ TrT = traced_type (SST, seen, mode, track_numbers )
149155 push! (subParms, TrT)
150156 else
151157 push! (subParms, SST)
@@ -163,7 +169,7 @@ function traced_type(::Type{T}, seen, mode) where {T}
163169 for f in 1 : fieldcount (T)
164170 subT = fieldtype (T, f)
165171 subT2 = fieldtype (TT2, f)
166- subTT = traced_type (subT, seen3, mode)
172+ subTT = traced_type (subT, seen3, mode, track_numbers )
167173 if subT2 != subTT
168174 legal = false
169175 break
@@ -178,7 +184,9 @@ function traced_type(::Type{T}, seen, mode) where {T}
178184 throw (NoFieldMatchError (T, TT2))
179185end
180186
181- function traced_type (:: Type{<:ConcreteRNumber{T}} , seen, :: Val{mode} ) where {T,mode}
187+ function traced_type (
188+ :: Type{<:ConcreteRNumber{T}} , seen, :: Val{mode} , track_numbers
189+ ) where {T,mode}
182190 if mode == ConcreteToTraced
183191 return TracedRNumber{T}
184192 elseif mode == TracedToConcrete
@@ -188,7 +196,9 @@ function traced_type(::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}) where {T,m
188196 end
189197end
190198
191- function traced_type (:: Type{T} , seen, :: Val{mode} ) where {T<: ConcreteRArray ,mode}
199+ function traced_type (
200+ :: Type{T} , seen, :: Val{mode} , track_numbers
201+ ) where {T<: ConcreteRArray ,mode}
192202 if mode == ConcreteToTraced
193203 @inline base_typet (TV:: TT ) where {TT<: UnionAll } =
194204 UnionAll (TV. var, base_typet (TV. body))
@@ -201,7 +211,9 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode
201211 end
202212end
203213
204- function traced_type (:: Type{T} , seen:: ST , :: Val{mode} ) where {ST,T<: TracedType ,mode}
214+ function traced_type (
215+ :: Type{T} , seen:: ST , :: Val{mode} , track_numbers
216+ ) where {ST,T<: TracedType ,mode}
205217 T <: MissingTracedValue && error (" TODO" )
206218 if mode == ConcreteToTraced
207219 throw (" TracedRArray $T cannot be traced" )
@@ -218,26 +230,28 @@ function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,m
218230 end
219231end
220232
221- function traced_type (:: Type{T} , seen, mode) where {T<: XLAArray }
233+ function traced_type (:: Type{T} , seen, mode, track_numbers ) where {T<: XLAArray }
222234 throw (" XLA $T array cannot be traced" )
223235end
224236
225- function traced_type (:: Type{A} , seen:: ST , :: Val{mode} ) where {T,N,A<: Array{T,N} ,ST,mode}
237+ function traced_type (
238+ :: Type{A} , seen:: ST , :: Val{mode} , track_numbers
239+ ) where {T,N,A<: Array{T,N} ,ST,mode}
226240 if mode == ArrayToConcrete && T <: ReactantPrimitive
227241 return ConcreteRArray{T,N}
228242 else
229- return Array{traced_type (T, seen, Val (mode)),N}
243+ return Array{traced_type (T, seen, Val (mode), track_numbers ),N}
230244 end
231245end
232246
233247for P in (Ptr, Core. LLVMPtr, Base. RefValue)
234- @eval function traced_type (:: Type{P} , seen, mode) where {T,P<: $P{T} }
235- return $ P{traced_type (T, seen, mode)}
248+ @eval function traced_type (:: Type{P} , seen, mode, track_numbers ) where {T,P<: $P{T} }
249+ return $ P{traced_type (T, seen, mode, track_numbers )}
236250 end
237251end
238252
239- function traced_type (:: Type{Val{T}} , seen, mode) where {T}
240- if traced_type (typeof (T), seen, mode) == typeof (T)
253+ function traced_type (:: Type{Val{T}} , seen, mode, track_numbers ) where {T}
254+ if traced_type (typeof (T), seen, mode, track_numbers ) == typeof (T)
241255 return Val{T}
242256 end
243257 throw (" Val type $(Val{T}) cannot be traced" )
@@ -274,12 +288,13 @@ function make_tracer(
274288 mode;
275289 toscalar= false ,
276290 tobatch= nothing ,
291+ track_numbers= (),
277292 kwargs... ,
278293) where {RT}
279294 if haskey (seen, prev)
280295 return seen[prev]
281296 end
282- TT = traced_type (RT, (), Val (mode))
297+ TT = traced_type (RT, (), Val (mode), track_numbers )
283298 @assert ! Base. isabstracttype (RT)
284299 @assert Base. isconcretetype (RT)
285300 nf = fieldcount (RT)
@@ -295,7 +310,16 @@ function make_tracer(
295310 for i in 1 : nf
296311 if isdefined (prev, i)
297312 xi = Base. getfield (prev, i)
298- xi2 = make_tracer (seen, xi, append_path (path, i), mode; toscalar, tobatch)
313+ xi2 = make_tracer (
314+ seen,
315+ xi,
316+ append_path (path, i),
317+ mode;
318+ toscalar,
319+ tobatch,
320+ track_numbers,
321+ kwargs... ,
322+ )
299323 if xi != = xi2
300324 changed = true
301325 end
@@ -318,7 +342,16 @@ function make_tracer(
318342 for i in 1 : nf
319343 if isdefined (prev, i)
320344 xi = Base. getfield (prev, i)
321- xi2 = make_tracer (seen, xi, append_path (path, i), mode; toscalar, tobatch)
345+ xi2 = make_tracer (
346+ seen,
347+ xi,
348+ append_path (path, i),
349+ mode;
350+ toscalar,
351+ tobatch,
352+ track_numbers,
353+ kwargs... ,
354+ )
322355 if xi != = xi2
323356 changed = true
324357 end
@@ -543,22 +576,22 @@ function make_tracer(
543576end
544577
545578function make_tracer (
546- seen, @nospecialize (prev:: RT ), @nospecialize (path), mode; kwargs...
579+ seen, @nospecialize (prev:: RT ), @nospecialize (path), mode; track_numbers = (), kwargs...
547580) where {RT<: Array }
548581 if haskey (seen, prev)
549582 return seen[prev]
550583 end
551584 if mode == ArrayToConcrete && eltype (RT) <: ReactantPrimitive
552585 return seen[prev] = ConcreteRArray (prev)
553586 end
554- TT = traced_type (eltype (RT), (), Val (mode))
587+ TT = traced_type (eltype (RT), (), Val (mode), track_numbers )
555588 newa = Array {TT,ndims(RT)} (undef, size (prev))
556589 seen[prev] = newa
557590 same = true
558591 for I in eachindex (prev)
559592 if isassigned (prev, I)
560593 pv = prev[I]
561- nv = make_tracer (seen, pv, append_path (path, I), mode; kwargs... )
594+ nv = make_tracer (seen, pv, append_path (path, I), mode; track_numbers, kwargs... )
562595 if pv != = nv
563596 same = false
564597 end
@@ -584,12 +617,22 @@ function make_tracer(
584617end
585618
586619function make_tracer (
587- seen, @nospecialize (prev:: NamedTuple{A,RT} ), @nospecialize (path), mode; kwargs...
620+ seen,
621+ @nospecialize (prev:: NamedTuple{A,RT} ),
622+ @nospecialize (path),
623+ mode;
624+ track_numbers= (),
625+ kwargs... ,
588626) where {A,RT}
589- return NamedTuple {A,traced_type(RT, (), Val(mode))} ((
627+ return NamedTuple {A,traced_type(RT, (), Val(mode), track_numbers )} ((
590628 (
591629 make_tracer (
592- seen, Base. getfield (prev, i), append_path (path, i), mode; kwargs...
630+ seen,
631+ Base. getfield (prev, i),
632+ append_path (path, i),
633+ mode;
634+ track_numbers,
635+ kwargs... ,
593636 ) for i in 1 : length (A)
594637 ). .. ,
595638 ))
0 commit comments