Skip to content

Commit 015a748

Browse files
authored
Despecialize make_tracer (#540)
* Despecialize make_tracer * further despec * even more * even more inference stop * more testinfra * fix * more despec * fixes * fix * fix * fix * fix * fixup * fix * more stuff * fix
1 parent 93f9f07 commit 015a748

File tree

11 files changed

+400
-252
lines changed

11 files changed

+400
-252
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
4-
version = "0.2.18"
4+
version = "0.2.19"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/ReactantCUDAExt.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -722,37 +722,40 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
722722
return Core.Typeof(res)(f, res.entry)
723723
end
724724

725-
function Reactant.traced_type(
726-
::Type{A}, seen::ST, ::Val{mode}, track_numbers
727-
) where {A<:CuTracedArray,ST,mode}
725+
Base.@nospecializeinfer function Reactant.traced_type_inner(
726+
@nospecialize(A::Type{<:CuTracedArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
727+
)
728728
return A
729729
end
730730

731-
function Reactant.traced_type(
732-
::Type{A}, seen::ST, ::Val{mode}, track_numbers
733-
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
731+
Base.@nospecializeinfer function Reactant.traced_type_inner(
732+
@nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type)
733+
)
734+
T = eltype(A)
735+
N = ndims(A)
734736
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
735737
return Reactant.ConcreteRArray{T,N}
736738
else
737-
TT = Reactant.traced_type(T, seen, Val(mode), track_numbers)
739+
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers)
738740
if TT === T
739741
return A
740742
else
741-
return Array{traced_type(T, seen, Val(mode), track_numbers),N}
743+
return Array{Reactant.traced_type_inner(T, seen, mode, track_numbers),N}
742744
end
743745
end
744746
end
745747

746748
function Reactant.make_tracer(
747-
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
748-
) where {RT<:CUDA.CuArray}
749+
seen, @nospecialize(prev::CUDA.CuArray), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs...
750+
)
751+
RT = Core.Typeof(prev)
749752
if haskey(seen, prev)
750753
return seen[prev]
751754
end
752755
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
753756
return seen[prev] = Reactant.ConcreteRArray(Array(prev))
754757
end
755-
TT = Reactant.traced_type(eltype(RT), (), Val(mode), track_numbers)
758+
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers)
756759
if TT === eltype(RT)
757760
return prev
758761
end

ext/ReactantOffsetArraysExt.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
module ReactantOffsetArraysExt
22

3+
using OffsetArrays
34
using OffsetArrays: OffsetArray
45
using Reactant: Reactant, MLIR, Ops, TracedRArray
56

6-
function Reactant.traced_type(
7-
::Type{<:OffsetArray{<:Any,N,T}}, seen::ST, ::Val{mode}, track_numbers
8-
) where {T,N,ST,mode}
9-
T2 = Reactant.traced_type(T, seen, Val(mode), track_numbers)
7+
Base.@nospecializeinfer function Reactant.traced_type_inner(
8+
@nospecialize(OA::Type{<:OffsetArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type=Union{})
9+
)
10+
N = ndims(OA)
11+
T = OffsetArrays.parenttype(OA)
12+
T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers)
1013
return OffsetArray{eltype(T2),N,T2}
1114
end
1215

src/Reactant.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,23 @@ else
2323
const ReactantFloat = Union{Float16,Float32,Float64}
2424
end
2525

26-
const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}
26+
@static if isdefined(Core, :BFloat16)
27+
const ReactantComplexFloat = Union{Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}}
28+
else
29+
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}}
30+
end
31+
32+
const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}
33+
34+
const ReactantComplexInt = Union{Complex{Int8},Complex{UInt8},Complex{Int16},Complex{UInt16},Complex{Int32},Complex{UInt32},Complex{Int64},Complex{UInt64},Complex{Int128},Complex{UInt128}}
2735

2836
const ReactantFloatInt = Union{
2937
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
3038
}
3139

3240
const ReactantPrimitive = Union{
33-
Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64}
41+
Bool,Base.uniontypes(ReactantFloatInt)...,
42+
Base.uniontypes(ReactantComplexInt)...,Base.uniontypes(ReactantComplexFloat)...
3443
}
3544

3645
abstract type RNumber{T<:ReactantPrimitive} <: Number end

src/TracedUtils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function make_mlir_fn(
131131
(:args, i),
132132
concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath;
133133
toscalar,
134-
track_numbers=construct_function_without_args ? (Number,) : (),
134+
track_numbers=construct_function_without_args ? Number : Union{},
135135
)
136136
end
137137

@@ -201,7 +201,7 @@ function make_mlir_fn(
201201
result,
202202
(:result,),
203203
concretein ? Reactant.TracedTrack : Reactant.TracedSetPath;
204-
track_numbers=construct_function_without_args ? (Number,) : (),
204+
track_numbers=construct_function_without_args ? Number : Union{},
205205
)
206206

207207
# marks buffers to be donated

0 commit comments

Comments
 (0)