Skip to content

Commit

Permalink
Don't use fix1 for enzyme (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 24, 2022
1 parent 7c3bb0a commit da73c64
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/AD_Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme
x::AbstractVector)
@unpack ℓ, mode, shadow = ∇ℓ
_shadow = shadow === nothing ? Enzyme.onehot(x) : shadow
y, ∂ℓ_∂x = Enzyme.autodiff(mode, Base.Fix1(logdensity, ℓ), Enzyme.BatchDuplicated,
y, ∂ℓ_∂x = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated,
Enzyme.Const(ℓ),
Enzyme.BatchDuplicated(x, _shadow))
return y, collect(∂ℓ_∂x)
end
Expand All @@ -55,7 +56,8 @@ function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme
# Ref: https://github.com/EnzymeAD/Enzyme.jl/issues/107
y = logdensity(ℓ, x)
∂ℓ_∂x = zero(x)
Enzyme.autodiff(mode, Base.Fix1(logdensity, ℓ), Enzyme.Active,
Enzyme.autodiff(mode, logdensity, Enzyme.Active,
Enzyme.Const(ℓ),
Enzyme.Duplicated(x, ∂ℓ_∂x))
y, ∂ℓ_∂x
end

0 comments on commit da73c64

Please sign in to comment.