Skip to content

Commit

Permalink
fix division return type (#233)
Browse files Browse the repository at this point in the history
* fix division

* add tests

* more tests; relax hypot types

* fixes

* fix tests

* use oftype; use one(x)
  • Loading branch information
CarloLucibello authored Jul 10, 2020
1 parent dc3db4a commit d3cd83e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.7"
version = "0.7.8"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ end
),
(!(islow | ishigh), islow, ishigh),
)
@scalar_rule x \ y (-((y / x) / x), inv(x))
@scalar_rule x \ y (-(Ω / x), one(y) / x)

function frule((_, ẏ), ::typeof(identity), x)
return (x, ẏ)
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ let

@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (inv(y), -((x / y) / y))
@scalar_rule x / y (one(x) / y, - / y))
#log(complex(x)) is required so it gives correct complex answer for x<0
@scalar_rule(x ^ y,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
)
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
# is undefined, so we adopt subgradient convention and set derivative to 0.
@scalar_rule(x::Real ^ y::Real,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(ifelse(x 0, one(x), x))),
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(oftype(Ω, ifelse(x 0, one(x), x)))),
)
@scalar_rule(
rem(x, y),
Expand Down
24 changes: 23 additions & 1 deletion test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,29 @@ const FASTABLE_AST = quote
Δz = randn(typeof(f(x, y)))

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot, ^), T in (Float32, Float64)
x, Δx, x̄ = 10rand(T, 3)
y, Δy, ȳ = rand(T, 3)
@assert T == typeof(f(x, y))
Δz = randn(typeof(f(x, y)))

@test frule((Zero(), Δx, Δy), f, x, y) isa Tuple{T, T}
_, ∂x, ∂y = rrule(f, x, y)[2](Δz)
@test extern.((∂x, ∂y)) isa Tuple{T, T}

if f != hypot
# Issue #233
@test frule((Zero(), Δx, Δy), f, x, 2) isa Tuple{T, T}
_, ∂x, ∂y = rrule(f, x, 2)[2](Δz)
@test extern.((∂x, ∂y)) isa Tuple{T, T}

@test frule((Zero(), Δx, Δy), f, 2, y) isa Tuple{T, T}
_, ∂x, ∂y = rrule(f, 2, y)[2](Δz)
@test extern.((∂x, ∂y)) isa Tuple{T, T}
end
end

@testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)
Expand Down

2 comments on commit d3cd83e

@mattBrzezinski
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17753

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.8 -m "<description of version>" d3cd83e5d202475fbc18de8fadea69fd9042f66b
git push origin v0.7.8

Please sign in to comment.