
Commit e06c8fd: Support Enzyme

Committed on Oct 1, 2024 · 1 parent: 7ff795b

7 files changed: +72 −49 lines

Project.toml (+1 −1)

@@ -1,7 +1,7 @@
 name = "TaylorDiff"
 uuid = "b36ab563-344f-407b-a36a-4f200bebf99c"
 authors = ["Songchen Tan <i@tansongchen.com>"]
-version = "0.2.4"
+version = "0.2.5"
 
 [deps]
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"

benchmark/groups/pinn.jl (+3 −3)

@@ -1,10 +1,10 @@
-using Lux, Random, Zygote
+using Lux, Zygote
 
 const input = 2
 const hidden = 16
 
-model = Chain(Dense(input => hidden, exp),
-    Dense(hidden => hidden, exp),
+model = Chain(Dense(input => hidden, Lux.relu),
+    Dense(hidden => hidden, Lux.relu),
     Dense(hidden => 1),
     first)

src/chainrules.jl (+10 −0)

@@ -86,3 +86,13 @@ for f in (
     @eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
 end
 end
+
+# Multi-argument functions
+
+@opt_out frule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
+@opt_out rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
+
+@opt_out frule(
+    ::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
+@opt_out rrule(
+    ::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
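
For context, ChainRulesCore's @opt_out declares that no frule/rrule applies to a signature, so rule-respecting backends such as Zygote fall back to differentiating the method's own implementation; here, presumably, that keeps generic multi-argument * rules from intercepting TaylorScalar arithmetic. A minimal sketch of the mechanism, with MyScalar as a hypothetical stand-in for TaylorScalar:

using ChainRulesCore

struct MyScalar <: Real   # hypothetical stand-in, not part of TaylorDiff
    x::Float64
end

Base.:*(a::MyScalar, b::MyScalar) = MyScalar(a.x * b.x)

# Declare that no rrule exists for this signature; backends that honor
# ChainRules opt-outs will differentiate Base.:* for MyScalar directly
# instead of applying a generic rule for *.
@opt_out rrule(::typeof(*), x::MyScalar, y::MyScalar)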

test/Project.toml (+8 −1)

@@ -1,5 +1,12 @@
 [deps]
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
+[compat]
+Enzyme = "0.13"
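
With these test dependencies declared, the new downstream tests run under the standard test entry point; a sketch of the local invocation, assuming a clone of the repository:

using Pkg
Pkg.activate(".")   # from the TaylorDiff repository root
Pkg.test()          # runs test/runtests.jl, which now includes downstream.jl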

test/downstream.jl (+49, new file)

@@ -0,0 +1,49 @@
+using LinearAlgebra
+import DifferentiationInterface
+using DifferentiationInterface: AutoZygote, AutoEnzyme
+import Zygote, Enzyme
+using FiniteDiff: finite_difference_derivative
+
+DI = DifferentiationInterface
+backend = AutoZygote()
+# backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)
+
+@testset "Zygote-over-TaylorDiff on same variable" begin
+    # Scalar functions
+    some_number = 0.7
+    some_numbers = [0.3, 0.4, 0.1]
+    for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
+        @test DI.derivative(x -> derivative(f, x, 2), backend, some_number) ≈
+              derivative(f, some_number, 3)
+        @test DI.jacobian(x -> derivative.(f, x, 2), backend, some_numbers) ≈
+              diagm(derivative.(f, some_numbers, 3))
+    end
+
+    # Vector functions
+    g(x) = x[1] * x[1] + x[2] * x[2]
+    @test DI.gradient(x -> derivative(g, x, [1.0, 0.0], 1), backend, [1.0, 2.0]) ≈
+          [2.0, 0.0]
+
+    # Matrix functions
+    some_matrix = [0.7 0.1; 0.4 0.2]
+    f(x) = sum(exp.(x), dims = 1)
+    dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
+    dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
+    res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
+    grad = DI.gradient(res, backend, some_matrix)
+    @test grad ≈ [1 0; 0 2] * exp.(some_matrix)
+end
+
+@testset "Zygote-over-TaylorDiff on different variable" begin
+    linear_model(x, p, b) = exp.(b + p * x + b)[1]
+    loss_taylor(x, p, b, v) = derivative(x -> linear_model(x, p, b), x, v, 1)
+    ε = cbrt(eps(Float64))
+    loss_finite(x, p, b, v) = (linear_model(x + ε * v, p, b) -
+                               linear_model(x - ε * v, p, b)) / (2 * ε)
+    let some_x = [0.58, 0.36], some_v = [0.23, 0.11], some_p = [0.49 0.96], some_b = [0.88]
+        @test DI.gradient(
+            p -> loss_taylor(some_x, p, some_b, some_v), backend, some_p) ≈
+              DI.gradient(
+            p -> loss_finite(some_x, p, some_b, some_v), backend, some_p)
+    end
+end
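
To exercise the new Enzyme support, the suite's backend can be swapped to the commented-out alternative above; a sketch, with the constructor keywords as they appear in the file (per ADTypes' AutoEnzyme, used here with Enzyme 0.13):

import Enzyme
using DifferentiationInterface: AutoEnzyme

# Reverse-mode Enzyme, annotating closures as Const, per the comment in the file
backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)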

test/runtests.jl (+1 −1)

@@ -3,5 +3,5 @@ using Test
 
 include("primitive.jl")
 include("derivative.jl")
-include("zygote.jl")
+include("downstream.jl")
 # include("lux.jl")

test/zygote.jl (−43)

This file was deleted; its Zygote tests are superseded by test/downstream.jl.

2 commit comments

tansongchen (Member, Author) commented on Oct 1, 2024:

@JuliaRegistrator register

JuliaRegistrator commented on Oct 1, 2024:

Registration pull request created: JuliaRegistries/General/116427

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR. If TagBot is installed, it will also be added to the release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add release notes here, just re-invoke the registrator and the PR will be updated.

Tagging

After the pull request above is merged, it is recommended to create a tag on this repository for the registered package version.

This will happen automatically if the Julia TagBot GitHub Action is installed; otherwise it can be done manually through the GitHub interface, or via:

git tag -a v0.2.5 -m "<description of version>" e06c8fd365385e9cdf90528eca2b307d1fe789ad
git push origin v0.2.5