From da73c64ebb429c2494f2f7cfa0fe51873ad45655 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 24 Dec 2022 17:03:22 -0500 Subject: [PATCH] Don't use fix1 for enzyme (#4) --- src/AD_Enzyme.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/AD_Enzyme.jl b/src/AD_Enzyme.jl index 376e532..b7e8cfa 100644 --- a/src/AD_Enzyme.jl +++ b/src/AD_Enzyme.jl @@ -43,7 +43,8 @@ function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme x::AbstractVector) @unpack ℓ, mode, shadow = ∇ℓ _shadow = shadow === nothing ? Enzyme.onehot(x) : shadow - y, ∂ℓ_∂x = Enzyme.autodiff(mode, Base.Fix1(logdensity, ℓ), Enzyme.BatchDuplicated, + y, ∂ℓ_∂x = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated, + Enzyme.Const(ℓ), Enzyme.BatchDuplicated(x, _shadow)) return y, collect(∂ℓ_∂x) end @@ -55,7 +56,8 @@ function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme # Ref: https://github.com/EnzymeAD/Enzyme.jl/issues/107 y = logdensity(ℓ, x) ∂ℓ_∂x = zero(x) - Enzyme.autodiff(mode, Base.Fix1(logdensity, ℓ), Enzyme.Active, + Enzyme.autodiff(mode, logdensity, Enzyme.Active, + Enzyme.Const(ℓ), Enzyme.Duplicated(x, ∂ℓ_∂x)) y, ∂ℓ_∂x end