@@ -722,37 +722,40 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
722722 return Core. Typeof (res)(f, res. entry)
723723end
724724
725- function Reactant. traced_type (
726- :: Type{A} , seen:: ST , :: Val{mode} , track_numbers
727- ) where {A <: CuTracedArray ,ST,mode}
725+ Base . @nospecializeinfer function Reactant. traced_type_inner (
726+ @nospecialize (A :: Type{<:CuTracedArray} ) , seen, mode :: Reactant.TraceMode , @nospecialize ( track_numbers:: Type )
727+ )
728728 return A
729729end
730730
731- function Reactant. traced_type (
732- :: Type{A} , seen:: ST , :: Val{mode} , track_numbers
733- ) where {T,N,A<: CUDA.CuArray{T,N} ,ST,mode}
731+ Base. @nospecializeinfer function Reactant. traced_type_inner (
732+ @nospecialize (A:: Type{<:CUDA.CuArray} ), seen, mode:: Reactant.TraceMode , @nospecialize (track_numbers:: Type )
733+ )
734+ T = eltype (A)
735+ N = ndims (A)
734736 if mode == Reactant. ArrayToConcrete && T <: Reactant.ReactantPrimitive
735737 return Reactant. ConcreteRArray{T,N}
736738 else
737- TT = Reactant. traced_type (T, seen, Val ( mode) , track_numbers)
739+ TT = Reactant. traced_type_inner (T, seen, mode, track_numbers)
738740 if TT === T
739741 return A
740742 else
741- return Array{traced_type (T, seen, Val ( mode) , track_numbers),N}
743+ return Array{Reactant . traced_type_inner (T, seen, mode, track_numbers),N}
742744 end
743745 end
744746end
745747
746748function Reactant. make_tracer (
747- seen, @nospecialize (prev:: RT ), @nospecialize (path), mode; track_numbers= (), kwargs...
748- ) where {RT<: CUDA.CuArray }
749+ seen, @nospecialize (prev:: CUDA.CuArray ), @nospecialize (path), mode; @nospecialize (track_numbers:: Type = Union{}), kwargs...
750+ )
751+ RT = Core. Typeof (prev)
749752 if haskey (seen, prev)
750753 return seen[prev]
751754 end
752755 if mode == Reactant. ArrayToConcrete && eltype (RT) <: Reactant.ReactantPrimitive
753756 return seen[prev] = Reactant. ConcreteRArray (Array (prev))
754757 end
755- TT = Reactant. traced_type (eltype (RT), (), Val (mode), track_numbers)
758+ TT = Reactant. traced_type (eltype (RT), Val (mode), track_numbers)
756759 if TT === eltype (RT)
757760 return prev
758761 end
0 commit comments