Skip to content

Commit 998db88

Browse files
committed
Patch release for fixing division bug
1 parent 29b1312 commit 998db88

File tree

5 files changed

+12
-53
lines changed

5 files changed

+12
-53
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TaylorDiff"
22
uuid = "b36ab563-344f-407b-a36a-4f200bebf99c"
33
authors = ["Songchen Tan <i@tansongchen.com>"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"

examples/ode.jl

-45
This file was deleted.

src/chainrules.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
3030
end
3131

3232
function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P}
33-
partials_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
33+
function partials_pullback(v̄::NTuple{P, A})
34+
NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
35+
end
3436
return partials(t), partials_pullback
3537
end
3638

src/utils.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ function process(d, expr)
9191
@match x begin
9292
a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx])
9393
(a_ = b_) => (push!(known_names, a); :($a = $b))
94-
(a_ += b_) => a in known_names ? :($a += $b) : (push!(known_names, a); :($a = $b))
95-
(a_ -= b_) => a in known_names ? :($a -= $b) : (push!(known_names, a); :($a = -$b))
94+
(a_ += b_) => a in known_names ? :($a += $b) :
95+
(push!(known_names, a); :($a = $b))
96+
(a_ -= b_) => a in known_names ? :($a -= $b) :
97+
(push!(known_names, a); :($a = -$b))
9698
TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...))))
9799
_ => x
98100
end

test/primitive.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ end
5050
# end
5151
end
5252

53-
@testset "Multi-argument functions" begin
54-
@test derivative(x -> 1 + 1/x, 1.0, Val(1))-1.0 rtol=1e-6
55-
@test derivative(x -> (x+1)/x, 1.0, Val(1))-1.0 rtol=1e-6
56-
@test derivative(x -> x/x, 1.0, Val(1)) 0.0 rtol=1e-6
53+
@testset "Multi-argument functions" begin
54+
@test derivative(x -> 1 + 1 / x, 1.0, Val(1))-1.0 rtol=1e-6
55+
@test derivative(x -> (x + 1) / x, 1.0, Val(1))-1.0 rtol=1e-6
56+
@test derivative(x -> x / x, 1.0, Val(1))0.0 rtol=1e-6
5757
end
5858

5959
@testset "Corner cases" begin

0 commit comments

Comments
 (0)