From 02ae6b54852e438d528452f3944ea91d5e68a3c9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 11 Sep 2025 12:01:05 +0200 Subject: [PATCH] fix: handle constant derivatives with runtime activity for Enzyme --- DifferentiationInterface/CHANGELOG.md | 6 +++ .../DifferentiationInterfaceEnzymeExt.jl | 5 +- .../forward_onearg.jl | 26 +++++++--- .../reverse_onearg.jl | 2 + .../utils.jl | 36 ++++++++++++++ .../test/Back/Enzyme/test.jl | 16 ++++++ .../src/scenarios/modify.jl | 49 +++++++++++++++++++ DifferentiationInterfaceTest/test/weird.jl | 7 +++ 8 files changed, 139 insertions(+), 8 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index c60369666..7fce52a7f 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,8 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.7...main) +### Fixed + + - Handle constant derivatives with runtime activity for Enzyme + ## [0.7.7](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...DifferentiationInterface-v0.7.7) +### Fixed + - Improve support for empty inputs (still not guaranteed) ([#835](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/835)) ## [0.7.6](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...DifferentiationInterface-v0.7.6) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index dbf1e4f5c..407b15835 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -1,7 +1,7 @@ module DifferentiationInterfaceEnzymeExt using ADTypes: ADTypes, AutoEnzyme -using Base: Fix1 +using Base: Fix1, datatype_pointerfree import DifferentiationInterface as DI using EnzymeCore: Active, @@ -42,7 +42,8 @@ using Enzyme: jacobian, make_zero, make_zero!, - onehot + onehot, + runtime_activity DI.check_available(::AutoEnzyme) = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 3dd70bbc7..d5c424380 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -37,6 +37,7 @@ function DI.value_and_pushforward( x_and_dx = Duplicated(x, dx) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...) + dy = runtime_activity_safeguard(backend, y, dy) return y, (dy,) end @@ -54,8 +55,10 @@ function DI.value_and_pushforward( f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) x_and_tx = BatchDuplicated(x, tx) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) - ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...) - return y, values(ty) + ty_nt, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...) + ty = values(ty_nt) + ty = runtime_activity_safeguard(backend, y, ty) + return y, ty end function DI.pushforward( @@ -66,6 +69,9 @@ function DI.pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} + if has_runtime_activity(backend) + return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] + end DI.check_prep(f, prep, backend, x, tx, contexts...) 
(; df, context_shadows) = prep mode = forward_noprimal(backend) @@ -85,14 +91,18 @@ function DI.pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} + if has_runtime_activity(backend) + return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] + end DI.check_prep(f, prep, backend, x, tx, contexts...) (; df, context_shadows) = prep mode = forward_noprimal(backend) f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) x_and_tx = BatchDuplicated(x, tx) annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) - ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)) - return values(ty) + ty_nt = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)) + ty = values(ty_nt) + return ty end function DI.value_and_pushforward!( @@ -168,7 +178,9 @@ function DI.gradient( derivs = gradient( mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows ) - return first(derivs) + deriv = first(derivs) + deriv = runtime_activity_safeguard(backend, x, deriv) + return deriv end function DI.value_and_gradient( @@ -186,7 +198,9 @@ function DI.value_and_gradient( (; derivs, val) = gradient( mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows ) - return val, first(derivs) + deriv = first(derivs) + deriv = runtime_activity_safeguard(backend, x, deriv) + return val, deriv end function DI.gradient!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 67b3989f0..3f24eb51b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -7,6 +7,7 @@ function seeded_autodiff_thunk( ) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N} forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...) 
tape, result, shadow_result = forward(f, args...) + shadow_result = runtime_activity_safeguard(rmode, result, shadow_result) if RA <: Active dinputs = only(reverse(f, args..., dresult, tape)) else @@ -30,6 +31,7 @@ function batch_seeded_autodiff_thunk( rmode_rightwidth = ReverseSplitWidth(rmode, Val(B)) forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...) tape, result, shadow_results = forward(f, args...) + shadow_results = runtime_activity_safeguard(rmode_rightwidth, result, shadow_results) if RA <: Active dinputs = only(reverse(f, args..., dresults, tape)) else diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 991796bb1..b6bc9a9e2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -193,3 +193,39 @@ end batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T} batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B} + +has_runtime_activity(mode::Mode) = runtime_activity(mode) +has_runtime_activity(::AutoEnzyme{Nothing}) = false +has_runtime_activity(backend::AutoEnzyme{<:Mode}) = has_runtime_activity(backend.mode) + +function runtime_activity_safeguard( + backend_or_mode::Union{<:AutoEnzyme,<:Mode}, primal::T, shadow::T +) where {T} + # TODO: improve datatype_pointerfree to take Ptr into account + if has_runtime_activity(backend_or_mode) && + !datatype_pointerfree(T) && + pointer(primal) === pointer(shadow) # TODO: doesn't work beyond arrays + return make_zero(shadow) + else + return shadow + end +end + +function runtime_activity_safeguard( + backend_or_mode::Union{<:AutoEnzyme,<:Mode}, + primal::T, + shadow::Union{NTuple{N,T},NamedTuple}, +) where {T,N} + # TODO: improve datatype_pointerfree to take Ptr into account + if has_runtime_activity(backend_or_mode) && + 
!datatype_pointerfree(T) && + pointer(primal) === pointer(shadow[1]) # TODO: doesn't work beyond arrays + return make_zero(shadow) + else + return shadow + end +end + +function runtime_activity_safeguard(::Union{<:AutoEnzyme,<:Mode}, primal, shadow::Nothing) + return nothing +end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 9f799cbb9..0c90d56ae 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -196,3 +196,19 @@ end excluded=[:jacobian], ) end; + +@testset "Runtime activity" begin + # TODO: higher-level operators not tested + test_differentiation( + AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward)), + DIT.unknown_activity(default_scenarios()); + excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :derivative, :pullback), + logging=LOGGING, + ) + test_differentiation( + AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), + DIT.unknown_activity(default_scenarios()); + excluded=vcat(SECOND_ORDER, :jacobian, :gradient, :derivative, :pushforward), + logging=LOGGING, + ) +end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 76429253c..60d842c16 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -163,6 +163,54 @@ function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} ) end +struct UnknownActivityReturn{pl_fun,F} + f::F +end + +function Base.show(io::IO, f::UnknownActivityReturn) + return print(io, "UnknownActivityReturn($(f.f))") +end + +function (f::UnknownActivityReturn{:out})(x, yc, return_constant::Bool) + if return_constant + return copy(yc) + else + return f.f(x) + end +end + +function (f::UnknownActivityReturn{:in})(y, x, yc, return_constant::Bool) + if return_constant + copyto!(y, copy(yc)) + else + f.f(y, x) + end + return 
nothing +end + +""" + unknown_activity(scen::Scenario) + +Return a new scenario identical to `scen` except that the function now takes two additional constant arguments: the theoretical output, and a boolean flag stating whether a copy of that constant output should be returned instead of recomputing it from the input (the flag is set to `true` here, so the expected results are those of `zero(scen)`). +""" +function unknown_activity(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} + (; f) = deepcopy(scen) + zero_scen = deepcopy(zero(scen)) + @assert isempty(scen.contexts) + unknown_f = UnknownActivityReturn{pl_fun,typeof(f)}(f) + return Scenario{op,pl_op,pl_fun}(; + f=unknown_f, + x=scen.x, + y=scen.y, + t=scen.t, + contexts=(Constant(scen.y), Constant(true)), + res1=zero_scen.res1, + res2=zero_scen.res2, + prep_args=(; scen.prep_args..., contexts=(Constant(scen.y), Constant(true))), + name=isnothing(scen.name) ? nothing : scen.name * " [unknown activity]", + ) +end + struct MultiplyByConstant{pl_fun,F} <: FunctionModifier f::F end @@ -366,6 +414,7 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens) constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens) cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples) constantorcachify(scens::AbstractVector{<:Scenario}) = constantorcachify.(scens) +unknown_activity(scens::AbstractVector{<:Scenario}) = unknown_activity.(scens) ## Compute results with backend diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index f5e88e1ff..898b5a067 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -65,6 +65,13 @@ test_differentiation( logging=LOGGING, ); +test_differentiation( + AutoFiniteDiff(), + unknown_activity(default_scenarios()); + excluded=SECOND_ORDER, + logging=LOGGING, +); + ## Neural nets test_differentiation(