diff --git a/Project.toml b/Project.toml index 82fb9a26..3545861d 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ Combinatorics = "1" DataStructures = "0.18" Distances = "0.10" Distributions = "0.23, 0.24, 0.25" -Enzyme = "0.12" +Enzyme = "0.13" EzXML = "1" FLoops = "0.2" GLMakie = "0.8, 0.9, 0.10" diff --git a/test/gradients.jl b/test/gradients.jl index 25e3322c..f70f01e2 100644 --- a/test/gradients.jl +++ b/test/gradients.jl @@ -191,20 +191,20 @@ end if forward grad_enzyme = ( autodiff( - Forward, loss, Duplicated, Duplicated(σ, one(T)), Const(r0), - Duplicated(copy(coords), zero(coords)), + set_runtime_activity(Forward), loss, Duplicated, + Duplicated(σ, one(T)), Const(r0), Duplicated(copy(coords), zero(coords)), Duplicated(copy(velocities), zero(velocities)), const_args..., - )[2], + )[1], autodiff( - Forward, loss, Duplicated, Const(σ), Duplicated(r0, one(T)), - Duplicated(copy(coords), zero(coords)), + set_runtime_activity(Forward), loss, Duplicated, + Const(σ), Duplicated(r0, one(T)), Duplicated(copy(coords), zero(coords)), Duplicated(copy(velocities), zero(velocities)), const_args..., - )[2], + )[1], ) else grad_enzyme = autodiff( - Reverse, loss, Active, Active(σ), Active(r0), - Duplicated(copy(coords), zero(coords)), + set_runtime_activity(Reverse), loss, Active, + Active(σ), Active(r0), Duplicated(copy(coords), zero(coords)), Duplicated(copy(velocities), zero(velocities)), const_args..., )[1][1:2] end @@ -426,8 +426,9 @@ end n_threads = parallel ? Threads.nthreads() : 1 grads_enzyme = Dict(k => 0.0 for k in keys(params_dic)) autodiff( - Reverse, test_fn, Active, Duplicated(params_dic, grads_enzyme), - Const(sys_ref), Duplicated(copy(sys_ref.coords), zero(sys_ref.coords)), + set_runtime_activity(Reverse), test_fn, Active, + Duplicated(params_dic, grads_enzyme), Const(sys_ref), + Duplicated(copy(sys_ref.coords), zero(sys_ref.coords)), Duplicated(sys_ref.neighbor_finder, sys_ref.neighbor_finder), Const(n_threads), ) diff --git a/test/runtests.jl b/test/runtests.jl index e7aa5223..732bcbe5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,9 +66,6 @@ const openmm_dir = joinpath(data_dir, "openmm_6mrr") const temp_fp_pdb = tempname(cleanup=true) * ".pdb" const temp_fp_viz = tempname(cleanup=true) * ".mp4" -# Required for gradient tests -Enzyme.API.runtimeActivity!(true) - if GROUP in ("All", "NotGradients") # Some failures due to dependencies but there is an unbound args error Aqua.test_all(