@@ -457,6 +457,56 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
457457 return res
458458end
459459
# Compute the type a `CUDA.CuArray{T,N}` maps to under tracing mode `mode`.
#
# - In `Reactant.ArrayToConcrete` mode with a `Reactant.ReactantPrimitive`
#   element type, the result is `Reactant.ConcreteRArray{T,N}`.
# - Otherwise the element type is traced; if it is unchanged the original
#   array type `A` is returned as-is, else a host `Array{TT,N}` of the traced
#   element type is returned.
#
# `seen` and `track_numbers` are forwarded untouched to the element-type query.
function Reactant.traced_type(
    ::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
    if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
        return Reactant.ConcreteRArray{T,N}
    end
    TT = Reactant.traced_type(T, seen, Val(mode), track_numbers)
    # Reuse `TT` rather than re-running the query through an unqualified
    # `traced_type`, which is not guaranteed to be in scope in this module.
    return TT === T ? A : Array{TT,N}
end
474+
# Trace a `CUDA.CuArray` value `prev` under tracing mode `mode`.
#
# `seen` acts as a cache keyed by the original object: a hit is returned
# immediately, and results produced here are stored back into it.
#
# - In `Reactant.ArrayToConcrete` mode with a `Reactant.ReactantPrimitive`
#   element type, the data is copied to a host `Array` and wrapped in a
#   `Reactant.ConcreteRArray`.
# - If tracing leaves the element type unchanged, `prev` is returned as-is.
# - Otherwise every element is traced recursively into a new host `Array`;
#   if no element actually changed, `prev` itself is returned instead.
function Reactant.make_tracer(
    seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
) where {RT<:CUDA.CuArray}
    if haskey(seen, prev)
        return seen[prev]
    end
    if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
        return seen[prev] = Reactant.ConcreteRArray(Array(prev))
    end
    TT = Reactant.traced_type(eltype(RT), (), Val(mode), track_numbers)
    if TT === eltype(RT)
        return prev
    end
    newa = Array{TT,ndims(RT)}(undef, size(prev))
    # Registered before descending into the elements — presumably so that a
    # recursive lookup of `prev` resolves to the array under construction.
    seen[prev] = newa
    # Copy the device data to the host once; reading `prev[I]` element by
    # element would be scalar indexing on a GPU array.
    host = Array(prev)
    same = true
    for I in eachindex(host)
        if isassigned(host, I)
            pv = host[I]
            nv = Reactant.make_tracer(
                seen, pv, Reactant.append_path(path, I), mode; track_numbers, kwargs...
            )
            if pv !== nv
                same = false
            end
            @inbounds newa[I] = nv
        end
    end
    if same
        # Nothing changed: keep the original device array and fix up the cache.
        seen[prev] = prev
        return prev
    end
    return newa
end
509+
460510function __init__ ()
461511 if isdefined (CUDA. CUDA_Driver_jll, :libcuda ) && CUDA. CUDA_Driver_jll. libcuda != = nothing
462512 handle = Reactant. XLA. Libdl. dlopen (CUDA. CUDA_Driver_jll. libcuda; throw_error= false )
0 commit comments