Skip to content

Commit

Permalink
update Enzyme version
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Oct 7, 2024
1 parent c7dcee3 commit 33e06a5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Combinatorics = "1"
DataStructures = "0.18"
Distances = "0.10"
Distributions = "0.23, 0.24, 0.25"
Enzyme = "0.12"
Enzyme = "0.13"
EzXML = "1"
FLoops = "0.2"
GLMakie = "0.8, 0.9, 0.10"
Expand Down
21 changes: 11 additions & 10 deletions test/gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,20 +191,20 @@ end
if forward
grad_enzyme = (
autodiff(
Forward, loss, Duplicated, Duplicated(σ, one(T)), Const(r0),
Duplicated(copy(coords), zero(coords)),
set_runtime_activity(Forward), loss, Duplicated,
Duplicated(σ, one(T)), Const(r0), Duplicated(copy(coords), zero(coords)),
Duplicated(copy(velocities), zero(velocities)), const_args...,
)[2],
)[1],
autodiff(
Forward, loss, Duplicated, Const(σ), Duplicated(r0, one(T)),
Duplicated(copy(coords), zero(coords)),
set_runtime_activity(Forward), loss, Duplicated,
Const(σ), Duplicated(r0, one(T)), Duplicated(copy(coords), zero(coords)),
Duplicated(copy(velocities), zero(velocities)), const_args...,
)[2],
)[1],
)
else
grad_enzyme = autodiff(
Reverse, loss, Active, Active(σ), Active(r0),
Duplicated(copy(coords), zero(coords)),
set_runtime_activity(Reverse), loss, Active,
Active(σ), Active(r0), Duplicated(copy(coords), zero(coords)),
Duplicated(copy(velocities), zero(velocities)), const_args...,
)[1][1:2]
end
Expand Down Expand Up @@ -426,8 +426,9 @@ end
n_threads = parallel ? Threads.nthreads() : 1
grads_enzyme = Dict(k => 0.0 for k in keys(params_dic))
autodiff(
Reverse, test_fn, Active, Duplicated(params_dic, grads_enzyme),
Const(sys_ref), Duplicated(copy(sys_ref.coords), zero(sys_ref.coords)),
set_runtime_activity(Reverse), test_fn, Active,
Duplicated(params_dic, grads_enzyme), Const(sys_ref),
Duplicated(copy(sys_ref.coords), zero(sys_ref.coords)),
Duplicated(sys_ref.neighbor_finder, sys_ref.neighbor_finder),
Const(n_threads),
)
Expand Down
3 changes: 0 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ const openmm_dir = joinpath(data_dir, "openmm_6mrr")
const temp_fp_pdb = tempname(cleanup=true) * ".pdb"
const temp_fp_viz = tempname(cleanup=true) * ".mp4"

# Required for gradient tests
Enzyme.API.runtimeActivity!(true)

if GROUP in ("All", "NotGradients")
# Some failures due to dependencies but there is an unbound args error
Aqua.test_all(
Expand Down

0 comments on commit 33e06a5

Please sign in to comment.