diff --git a/Project.toml b/Project.toml index cd04fd278..210d77792 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.7" +version = "0.7.8" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 82609236a..cc2481f09 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -138,7 +138,7 @@ end ), (!(islow | ishigh), islow, ishigh), ) -@scalar_rule x \ y (-((y / x) / x), inv(x)) +@scalar_rule x \ y (-(Ω / x), one(y) / x) function frule((_, ẏ), ::typeof(identity), x) return (x, ẏ) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 04bb97bb2..0089ccbc0 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -133,7 +133,7 @@ let @scalar_rule x + y (One(), One()) @scalar_rule x - y (One(), -1) - @scalar_rule x / y (inv(y), -((x / y) / y)) + @scalar_rule x / y (one(x) / y, -(Ω / y)) #log(complex(x)) is required so it gives correct complex answer for x<0 @scalar_rule(x ^ y, (ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))), @@ -141,7 +141,7 @@ let # x^y for x < 0 errors when y is not an integer, but then derivative wrt y # is undefined, so we adopt subgradient convention and set derivative to 0. @scalar_rule(x::Real ^ y::Real, - (ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(ifelse(x ≤ 0, one(x), x))), + (ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(oftype(Ω, ifelse(x ≤ 0, one(x), x)))), ) @scalar_rule( rem(x, y), diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 8ccdbe67d..49e30d55f 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -137,7 +137,29 @@ const FASTABLE_AST = quote Δz = randn(typeof(f(x, y))) frule_test(f, (x, Δx), (y, Δy)) - rrule_test(f, Δz, (x, x̄), (y, ȳ)) + rrule_test(f, Δz, (x, x̄), (y, ȳ)) + end + + @testset "$f(x::$T, y::$T) type check" for f in (/, +, -,\, hypot, ^), T in (Float32, Float64) + x, Δx, x̄ = 10rand(T, 3) + y, Δy, ȳ = rand(T, 3) + @assert T == typeof(f(x, y)) + Δz = randn(typeof(f(x, y))) + + @test frule((Zero(), Δx, Δy), f, x, y) isa Tuple{T, T} + _, ∂x, ∂y = rrule(f, x, y)[2](Δz) + @test extern.((∂x, ∂y)) isa Tuple{T, T} + + if f != hypot + # Issue #233 + @test frule((Zero(), Δx, Δy), f, x, 2) isa Tuple{T, T} + _, ∂x, ∂y = rrule(f, x, 2)[2](Δz) + @test extern.((∂x, ∂y)) isa Tuple{T, T} + + @test frule((Zero(), Δx, Δy), f, 2, y) isa Tuple{T, T} + _, ∂x, ∂y = rrule(f, 2, y)[2](Δz) + @test extern.((∂x, ∂y)) isa Tuple{T, T} + end end @testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)