Skip to content

Commit 5395887

Browse files
Avik Palavik-pal
authored andcommitted
fix: propagate track_numbers correctly
1 parent 5afd2e9 commit 5395887

File tree

1 file changed

+91
-48
lines changed

1 file changed

+91
-48
lines changed

src/Tracing.jl

Lines changed: 91 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,32 @@
66
TracedSetPath = 5
77
end
88

9-
for T in (
10-
DataType,
11-
Module,
12-
Nothing,
13-
Symbol,
14-
AbstractChar,
15-
AbstractFloat,
16-
Integer,
17-
AbstractString,
18-
RArray,
19-
RNumber,
20-
)
21-
@eval function traced_type(::Type{T}, seen, mode) where {T<:$T}
9+
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArray, RNumber)
10+
@eval function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:$T}
2211
return T
2312
end
2413
end
2514

26-
function traced_type(::Type{C}, seen::ST, mode::Val{Mode}) where {T,C<:Complex{T},ST,Mode}
15+
function traced_type(::Type{T}, seen, mode::Val{Mode}, track_numbers) where {T<:Number,Mode}
16+
if Mode == ArrayToConcrete && any(Base.Fix1(<:, T), track_numbers)
17+
return ConcreteRNumber{T}
18+
end
19+
return T
20+
end
21+
22+
function traced_type(
23+
::Type{C}, seen::ST, mode::Val{Mode}, track_numbers::TN
24+
) where {T,C<:Complex{T},ST,Mode,TN}
2725
if !(C isa UnionAll)
28-
return Complex{traced_type(T, seen, mode)}
26+
return Complex{traced_type(T, seen, mode, track_numbers)}
2927
else
30-
return @invoke traced_type(C::Type{Any}, seen::ST, mode::Val{Mode})
28+
return @invoke traced_type(
29+
C::Type{Any}, seen::ST, mode::Val{Mode}, track_numbers::TN
30+
)
3131
end
3232
end
3333

34-
function traced_type(::Type{T}, seen, mode) where {T<:Function}
34+
function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Function}
3535
# functions are directly returned
3636
if sizeof(T) == 0
3737
return T
@@ -41,7 +41,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function}
4141
N = fieldcount(T)
4242
changed = false
4343
traced_fieldtypes = ntuple(Val(N)) do i
44-
next = traced_type(fieldtype(T, i), seen, mode)
44+
next = traced_type(fieldtype(T, i), seen, mode, track_numbers)
4545
changed |= next != fieldtype(T, i)
4646
next
4747
end
@@ -57,31 +57,34 @@ end
5757
@inline is_concrete_tuple(x::T2) where {T2} =
5858
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)
5959

60-
function traced_type(::Type{T}, seen, mode) where {T<:Tuple}
60+
function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Tuple}
6161
if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll
6262
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
6363
elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters)
6464
# Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...}
6565
throw(AssertionError("Type tuple of vararg $T is not supported"))
6666
end
67-
TT = [traced_type(T.parameters[i], seen, mode) for i in 1:length(T.parameters)]
67+
TT = [
68+
traced_type(T.parameters[i], seen, mode, track_numbers) for
69+
i in 1:length(T.parameters)
70+
]
6871
return Tuple{TT...}
6972
end
7073

71-
function traced_type(::Type{T}, seen, mode) where {N,V,T<:NamedTuple{N,V}}
72-
return NamedTuple{N,traced_type(V, seen, mode)}
74+
function traced_type(::Type{T}, seen, mode, track_numbers) where {N,V,T<:NamedTuple{N,V}}
75+
return NamedTuple{N,traced_type(V, seen, mode, track_numbers)}
7376
end
7477

75-
function traced_type(::Type{T}, seen, mode) where {K,V,T<:AbstractDict{K,V}}
78+
function traced_type(::Type{T}, seen, mode, track_numbers) where {K,V,T<:AbstractDict{K,V}}
7679
dictty = T.name.wrapper
77-
return dictty{K,traced_type(V, seen, mode)}
80+
return dictty{K,traced_type(V, seen, mode, track_numbers)}
7881
end
7982

8083
@inline getmap(::Val{T}) where {T} = nothing
8184
@inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...)
8285
@inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2
8386

84-
function traced_type(::Type{T}, seen, mode) where {T}
87+
function traced_type(::Type{T}, seen, mode, track_numbers) where {T}
8588
if T === Any
8689
return T
8790
end
@@ -110,7 +113,10 @@ function traced_type(::Type{T}, seen, mode) where {T}
110113
end
111114

112115
if T isa Union
113-
return Union{traced_type(T.a, seen, mode),traced_type(T.b, seen, mode)}
116+
return Union{
117+
traced_type(T.a, seen, mode, track_numbers),
118+
traced_type(T.b, seen, mode, track_numbers),
119+
}
114120
end
115121

116122
# if abstract it must be by reference
@@ -133,7 +139,7 @@ function traced_type(::Type{T}, seen, mode) where {T}
133139
subTys = Type[]
134140
for f in 1:fieldcount(T)
135141
subT = fieldtype(T, f)
136-
subTT = traced_type(subT, seen2, mode)
142+
subTT = traced_type(subT, seen2, mode, track_numbers)
137143
changed |= subT != subTT
138144
push!(subTys, subTT)
139145
end
@@ -145,7 +151,7 @@ function traced_type(::Type{T}, seen, mode) where {T}
145151
subParms = []
146152
for SST in T.parameters
147153
if SST isa Type
148-
TrT = traced_type(SST, seen, mode)
154+
TrT = traced_type(SST, seen, mode, track_numbers)
149155
push!(subParms, TrT)
150156
else
151157
push!(subParms, SST)
@@ -163,7 +169,7 @@ function traced_type(::Type{T}, seen, mode) where {T}
163169
for f in 1:fieldcount(T)
164170
subT = fieldtype(T, f)
165171
subT2 = fieldtype(TT2, f)
166-
subTT = traced_type(subT, seen3, mode)
172+
subTT = traced_type(subT, seen3, mode, track_numbers)
167173
if subT2 != subTT
168174
legal = false
169175
break
@@ -178,7 +184,9 @@ function traced_type(::Type{T}, seen, mode) where {T}
178184
throw(NoFieldMatchError(T, TT2))
179185
end
180186

181-
function traced_type(::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}) where {T,mode}
187+
function traced_type(
188+
::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}, track_numbers
189+
) where {T,mode}
182190
if mode == ConcreteToTraced
183191
return TracedRNumber{T}
184192
elseif mode == TracedToConcrete
@@ -188,7 +196,9 @@ function traced_type(::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}) where {T,m
188196
end
189197
end
190198

191-
function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode}
199+
function traced_type(
200+
::Type{T}, seen, ::Val{mode}, track_numbers
201+
) where {T<:ConcreteRArray,mode}
192202
if mode == ConcreteToTraced
193203
@inline base_typet(TV::TT) where {TT<:UnionAll} =
194204
UnionAll(TV.var, base_typet(TV.body))
@@ -201,7 +211,9 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode
201211
end
202212
end
203213

204-
function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,mode}
214+
function traced_type(
215+
::Type{T}, seen::ST, ::Val{mode}, track_numbers
216+
) where {ST,T<:TracedType,mode}
205217
T <: MissingTracedValue && error("TODO")
206218
if mode == ConcreteToTraced
207219
throw("TracedRArray $T cannot be traced")
@@ -218,26 +230,28 @@ function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,m
218230
end
219231
end
220232

221-
function traced_type(::Type{T}, seen, mode) where {T<:XLAArray}
233+
function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:XLAArray}
222234
throw("XLA $T array cannot be traced")
223235
end
224236

225-
function traced_type(::Type{A}, seen::ST, ::Val{mode}) where {T,N,A<:Array{T,N},ST,mode}
237+
function traced_type(
238+
::Type{A}, seen::ST, ::Val{mode}, track_numbers
239+
) where {T,N,A<:Array{T,N},ST,mode}
226240
if mode == ArrayToConcrete && T <: ReactantPrimitive
227241
return ConcreteRArray{T,N}
228242
else
229-
return Array{traced_type(T, seen, Val(mode)),N}
243+
return Array{traced_type(T, seen, Val(mode), track_numbers),N}
230244
end
231245
end
232246

233247
for P in (Ptr, Core.LLVMPtr, Base.RefValue)
234-
@eval function traced_type(::Type{P}, seen, mode) where {T,P<:$P{T}}
235-
return $P{traced_type(T, seen, mode)}
248+
@eval function traced_type(::Type{P}, seen, mode, track_numbers) where {T,P<:$P{T}}
249+
return $P{traced_type(T, seen, mode, track_numbers)}
236250
end
237251
end
238252

239-
function traced_type(::Type{Val{T}}, seen, mode) where {T}
240-
if traced_type(typeof(T), seen, mode) == typeof(T)
253+
function traced_type(::Type{Val{T}}, seen, mode, track_numbers) where {T}
254+
if traced_type(typeof(T), seen, mode, track_numbers) == typeof(T)
241255
return Val{T}
242256
end
243257
throw("Val type $(Val{T}) cannot be traced")
@@ -274,12 +288,13 @@ function make_tracer(
274288
mode;
275289
toscalar=false,
276290
tobatch=nothing,
291+
track_numbers=(),
277292
kwargs...,
278293
) where {RT}
279294
if haskey(seen, prev)
280295
return seen[prev]
281296
end
282-
TT = traced_type(RT, (), Val(mode))
297+
TT = traced_type(RT, (), Val(mode), track_numbers)
283298
@assert !Base.isabstracttype(RT)
284299
@assert Base.isconcretetype(RT)
285300
nf = fieldcount(RT)
@@ -295,7 +310,16 @@ function make_tracer(
295310
for i in 1:nf
296311
if isdefined(prev, i)
297312
xi = Base.getfield(prev, i)
298-
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
313+
xi2 = make_tracer(
314+
seen,
315+
xi,
316+
append_path(path, i),
317+
mode;
318+
toscalar,
319+
tobatch,
320+
track_numbers,
321+
kwargs...,
322+
)
299323
if xi !== xi2
300324
changed = true
301325
end
@@ -318,7 +342,16 @@ function make_tracer(
318342
for i in 1:nf
319343
if isdefined(prev, i)
320344
xi = Base.getfield(prev, i)
321-
xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch)
345+
xi2 = make_tracer(
346+
seen,
347+
xi,
348+
append_path(path, i),
349+
mode;
350+
toscalar,
351+
tobatch,
352+
track_numbers,
353+
kwargs...,
354+
)
322355
if xi !== xi2
323356
changed = true
324357
end
@@ -543,22 +576,22 @@ function make_tracer(
543576
end
544577

545578
function make_tracer(
546-
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
579+
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
547580
) where {RT<:Array}
548581
if haskey(seen, prev)
549582
return seen[prev]
550583
end
551584
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
552585
return seen[prev] = ConcreteRArray(prev)
553586
end
554-
TT = traced_type(eltype(RT), (), Val(mode))
587+
TT = traced_type(eltype(RT), (), Val(mode), track_numbers)
555588
newa = Array{TT,ndims(RT)}(undef, size(prev))
556589
seen[prev] = newa
557590
same = true
558591
for I in eachindex(prev)
559592
if isassigned(prev, I)
560593
pv = prev[I]
561-
nv = make_tracer(seen, pv, append_path(path, I), mode; kwargs...)
594+
nv = make_tracer(seen, pv, append_path(path, I), mode; track_numbers, kwargs...)
562595
if pv !== nv
563596
same = false
564597
end
@@ -584,12 +617,22 @@ function make_tracer(
584617
end
585618

586619
function make_tracer(
587-
seen, @nospecialize(prev::NamedTuple{A,RT}), @nospecialize(path), mode; kwargs...
620+
seen,
621+
@nospecialize(prev::NamedTuple{A,RT}),
622+
@nospecialize(path),
623+
mode;
624+
track_numbers=(),
625+
kwargs...,
588626
) where {A,RT}
589-
return NamedTuple{A,traced_type(RT, (), Val(mode))}((
627+
return NamedTuple{A,traced_type(RT, (), Val(mode), track_numbers)}((
590628
(
591629
make_tracer(
592-
seen, Base.getfield(prev, i), append_path(path, i), mode; kwargs...
630+
seen,
631+
Base.getfield(prev, i),
632+
append_path(path, i),
633+
mode;
634+
track_numbers,
635+
kwargs...,
593636
) for i in 1:length(A)
594637
)...,
595638
))

0 commit comments

Comments
 (0)