From de2d6cd86649b6b27d2ce3bab914b7c2f5b0268f Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 20 May 2024 05:46:31 +0530 Subject: [PATCH] chore: don't return structural tangent --- ext/SciMLBaseZygoteExt.jl | 20 ++++++++++++-------- test/downstream/observables_autodiff.jl | 6 +++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 242500983..cc09cc9b1 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -6,9 +6,10 @@ import Zygote: literal_getproperty using SciMLBase using SciMLBase: ODESolution, remake, getobserved, build_solution, EnsembleSolution, - NonlinearSolution, AbstractTimeseriesSolution + NonlinearSolution, AbstractTimeseriesSolution, + SciMLStructures using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, - observed, parameter_values + observed, parameter_values, state_values, current_time using RecursiveArrayTools # This method resolves the ambiguity with the pullback defined in @@ -111,10 +112,13 @@ end function ODESolution_getindex_pullback(Δ) i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym if is_observed(VA, sym) - y, back = Zygote.pullback(VA) do sol - f = observed(sol, sym) - p = parameter_values(sol) - f.(sol.u, Ref(p), sol.t) + f = observed(VA, sym) + p = parameter_values(VA) + tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p) + u = state_values(VA) + t = current_time(VA) + y, back = Zygote.pullback(u, tunables) do u, tunables + f.(u, Ref(tunables), t) end gs = back(Δ) (gs[1], nothing) @@ -154,8 +158,7 @@ function not_obs_grads(VA::ODESolution{T}, sym, not_obss_idx, i, Δ) where {T} end end - nt = Zygote.nt_nothing(VA) - Zygote.accum(nt, (u = Δ′,)) + Δ′ end @adjoint function Base.getindex( @@ -171,6 +174,7 @@ end gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ) a = Zygote.accum(gs_obs[1], gs_not_obs) + (a, nothing) end VA[sym], ODESolution_getindex_pullback diff --git a/test/downstream/observables_autodiff.jl b/test/downstream/observables_autodiff.jl index ecf38f480..e03cb7fc6 100644 --- a/test/downstream/observables_autodiff.jl +++ b/test/downstream/observables_autodiff.jl @@ -35,7 +35,7 @@ sol = solve(prob, Tsit5()) end du_ = [0.0, 1.0, 1.0, 1.0] du = [du_ for _ in sol.u] - @test du == gs.u + @test du == gs # Observable in a vector gs, = gradient(sol) do sol @@ -43,7 +43,7 @@ sol = solve(prob, Tsit5()) end du_ = [0.0, 1.0, 1.0, 2.0] du = [du_ for _ in sol.u] - @test du == gs.u + @test du == gs end # DAE @@ -84,7 +84,7 @@ end end du_ = [0.2, 1.0] du = [du_ for _ in sol.u] - @test gs.u == du + @test gs == du end # @testset "Adjoints with DAE" begin