Skip to content

Commit

Permalink
Merge pull request #514 from jlk9/main
Browse files Browse the repository at this point in the history
Valid index check for gpu in EnzymeExt
  • Loading branch information
wsmoses authored Aug 26, 2024
2 parents 91ada95 + 80b2996 commit 05a3c05
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ext/EnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import KernelAbstractions:
__index_Global_Linear,
__groupsize,
__groupindex,
__validindex,
Backend,
synchronize

Expand Down Expand Up @@ -219,8 +220,10 @@ function gpu_aug_fwd(

# On the GPU: F is a per thread function
# On the GPU: subtape::Vector
I = __index_Global_Linear(ctx)
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
if __validindex(ctx)
I = __index_Global_Linear(ctx)
subtape[I] = forward(Const(f), Const(ctx), args...)[1]
end
return nothing
end

Expand All @@ -241,9 +244,11 @@ function gpu_rev(
Const{Core.Typeof(ctx)},
map(Core.Typeof, args)...,
)
I = __index_Global_Linear(ctx)
tp = subtape[I]
reverse(Const(f), Const(ctx), args..., tp)
if __validindex(ctx)
I = __index_Global_Linear(ctx)
tp = subtape[I]
reverse(Const(f), Const(ctx), args..., tp)
end
return nothing
end

Expand Down

0 comments on commit 05a3c05

Please sign in to comment.