Skip to content

Commit

Permalink
fix #14 (#16)
Browse files Browse the repository at this point in the history
* fix #14

* add FD

* Update test/runtests.jl

Co-authored-by: Mohamed Tarek <mohamed82008@gmail.com>

---------

Co-authored-by: Mohamed Tarek <mohamed82008@gmail.com>
Co-authored-by: ThummeTo <83663542+ThummeTo@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 12, 2023
1 parent 978c246 commit b301b0c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
27 changes: 22 additions & 5 deletions src/ForwardDiffChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,19 @@ import ForwardDiff
import MacroTools
using DifferentiableFlatten

macro ForwardDiff_frule(sig)
_fd_frule(sig)
"""
@ForwardDiff_frule signature mutating
`mutating` indicates whether or not the function is mutating the input argument.
# Example
To define a rule for `LinearAlgebra.exp!`, which is a mutating funciton, we call the macro like this
```julia
@ForwardDiff_frule LinearAlgebra.exp!(A::AbstractMatrix{<:ForwardDiff.Dual}) true
```
"""
macro ForwardDiff_frule(sig, mutating=false)
_fd_frule(sig; mutating)
end
export @ForwardDiff_frule

Expand All @@ -59,7 +70,7 @@ end

const cfg = ForwardDiffRuleConfig()

function _fd_frule(sig)
function _fd_frule(sig; mutating=false)
if MacroTools.@capture(sig, f_(x__; k__))
nothing
else
Expand All @@ -77,15 +88,21 @@ function _fd_frule(sig)
flat_xpartials = reduce(vcat, transpose.(ForwardDiff.partials.(flatx)))

xprimals = unflattenx(flat_xprimals)
xprimals_copy = $mutating ? copy.(xprimals) : xprimals
xpartials1 = unflattenx(flat_xpartials[:,1])
yprimals, ypartials1 = ChainRulesCore.frule(
cfg, (NoTangent(), xpartials1...), f, xprimals...; ks...,
cfg, (NoTangent(), xpartials1...), f, xprimals_copy...; ks...,
)
flat_yprimals, unflatteny = ForwardDiffChainRules.DifferentiableFlatten.flatten(yprimals)
flat_ypartials1, _ = ForwardDiffChainRules.DifferentiableFlatten.flatten(ypartials1)
flat_ypartials = hcat(reshape(flat_ypartials1, :, 1), ntuple(Val(CS - 1)) do i
if $mutating
for (xpc, xp) in zip(xprimals_copy, xprimals)
xpc .= xp # Update copy
end
end
xpartialsi = unflattenx(flat_xpartials[:, i+1])
_, ypartialsi = ChainRulesCore.frule(cfg, (NoTangent(), xpartialsi...), f, xprimals...; ks...)
_, ypartialsi = ChainRulesCore.frule(cfg, (NoTangent(), xpartialsi...), f, xprimals_copy...; ks...)
return ForwardDiffChainRules.DifferentiableFlatten.flatten(ypartialsi)[1]
end...)

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ SOFTWARE.
@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)
@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})
@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})
@ForwardDiff_frule LinearAlgebra.exp!(A::AbstractMatrix{<:ForwardDiff.Dual}) true

f2(x::NamedTuple, y::NamedTuple) = (a = x.a + y.a, b = x.b + y.b)
f2(x::AbstractVector, y::AbstractVector) = f2.(x, y)
Expand Down Expand Up @@ -168,6 +169,13 @@ SOFTWARE.
@test frule_count == 16
@test norm(g - I) < 1e-6
end
@testset "exp!" begin
fexp = x -> sum(LinearAlgebra.exp!(copy(x)))
X = rand(4, 4)
g = ForwardDiff.gradient(fexp, X)
g2 = FiniteDifferences.grad(central_fdm(5, 1), fexp, X)[1]
@test norm(g-g2) < 1e-4
end
@testset "kwargs" begin
fkwarg(x1, x2; a = 2.0) = x1 * x2 * a
frule_count = 0
Expand Down

0 comments on commit b301b0c

Please sign in to comment.