From 782fc9650cde202a9b26bf32dd9f8892f391f6ce Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Fri, 22 Mar 2019 10:47:24 -0700
Subject: [PATCH] Add sensitivities for float (#143)

`float` acts on arrays and scalars, converting integers to floating
point values (e.g. `float(::Int)` -> `Float64`) and acting as a no-op
on values that are already floats. The same behavior applies when
differentiating.
---
 src/sensitivities/scalar.jl  |  4 ++++
 test/sensitivities/scalar.jl | 20 ++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/src/sensitivities/scalar.jl b/src/sensitivities/scalar.jl
index 1ea7ae75..531854b6 100644
--- a/src/sensitivities/scalar.jl
+++ b/src/sensitivities/scalar.jl
@@ -44,3 +44,7 @@ end
 
 # Add method to resolve exponentiation ambiguity.
 ^(n::Node{<:Real}, p::Integer) = invoke(^, Tuple{Node{<:Real}, Real}, n, p)
+
+import Base: float
+@explicit_intercepts float Tuple{∇ArrayOrScalar}
+∇(::typeof(float), ::Type{Arg{1}}, p, y, ȳ, x) = float(ȳ)
diff --git a/test/sensitivities/scalar.jl b/test/sensitivities/scalar.jl
index 214cdd45..b1f6d0a5 100644
--- a/test/sensitivities/scalar.jl
+++ b/test/sensitivities/scalar.jl
@@ -91,4 +91,24 @@ end
         # Test whether the exponentiation ambiguity is resolved.
         @test ∇(x -> x^2)(1) == (2.0,)
     end
+
+    @testset "float" begin
+        # Scalars
+        x_ = 4
+        x = Leaf(Tape(), x_)
+        y = float(x)
+        @test y isa Branch{Float64}
+        @test unbox(y) == 4.0
+
+        # Arrays
+        X_ = [1, 2, 3, 4]
+        X = Leaf(Tape(), X_)
+        Y = float(X)
+        @test Y isa Branch{Vector{Float64}}
+        @test unbox(Y) == Float64[1, 2, 3, 4]
+
+        # In expressions
+        @test ∇(x -> 2x)(1) === (2,)
+        @test ∇(x -> 2 * float(x))(1) === (2.0,)
+    end
 end
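
Note (not part of the patch): a minimal sketch of the behavior these hunks
introduce, assuming the enclosing package is Nabla.jl, whose `Leaf`, `Tape`,
`Branch`, `unbox`, and `∇` appear in the tests above; the `using Nabla` line
is that assumption made explicit.

    using Nabla

    # Forward pass: calling float on a traced integer records a Branch
    # whose unboxed value is the Float64 conversion of the input.
    x = Leaf(Tape(), 4)
    y = float(x)    # Branch{Float64}
    unbox(y)        # 4.0

    # Reverse pass: the new sensitivity returns float(ȳ), so the gradient
    # passes through float unchanged apart from promotion to Float64.
    ∇(x -> 2 * float(x))(1)    # (2.0,)

Because the reverse rule is just `float(ȳ)`, `float` is effectively the
identity for differentiation purposes, mirroring its forward no-op behavior
on inputs that are already floats.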