Skip to content

Commit 31f4392

Browse files
CuArray tracing (#475)
* CuArray tracing * Update ext/ReactantCUDAExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent adee0e0 commit 31f4392

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,56 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
457457
return res
458458
end
459459

460+
function Reactant.traced_type(
461+
::Type{A}, seen::ST, ::Val{mode}, track_numbers
462+
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
463+
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
464+
return Reactant.ConcreteRArray{T,N}
465+
else
466+
TT = Reactant.traced_type(T, seen, Val(mode), track_numbers)
467+
if TT === T
468+
return A
469+
else
470+
return Array{traced_type(T, seen, Val(mode), track_numbers),N}
471+
end
472+
end
473+
end
474+
475+
function Reactant.make_tracer(
476+
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
477+
) where {RT<:CUDA.CuArray}
478+
if haskey(seen, prev)
479+
return seen[prev]
480+
end
481+
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
482+
return seen[prev] = Reactant.ConcreteRArray(Array(prev))
483+
end
484+
TT = Reactant.traced_type(eltype(RT), (), Val(mode), track_numbers)
485+
if TT === eltype(RT)
486+
return prev
487+
end
488+
newa = Array{TT,ndims(RT)}(undef, size(prev))
489+
seen[prev] = newa
490+
same = true
491+
for I in eachindex(prev)
492+
if isassigned(prev, I)
493+
pv = prev[I]
494+
nv = Reactant.make_tracer(
495+
seen, pv, append_path(path, I), mode; track_numbers, kwargs...
496+
)
497+
if pv !== nv
498+
same = false
499+
end
500+
@inbounds newa[I] = nv
501+
end
502+
end
503+
if same
504+
seen[prev] = prev
505+
return prev
506+
end
507+
return newa
508+
end
509+
460510
function __init__()
461511
if isdefined(CUDA.CUDA_Driver_jll, :libcuda) && CUDA.CUDA_Driver_jll.libcuda !== nothing
462512
handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)

0 commit comments

Comments
 (0)