Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Add sensitivities for float (#143)
Browse files Browse the repository at this point in the history
`float` acts on arrays and scalars, turning integers into floats of the
corresponding size (e.g. `float(::Int32)` -> `Float32`) and is a no-op
for floats. This same behavior is applied when differentiating.
  • Loading branch information
ararslan authored Mar 22, 2019
1 parent e3640fc commit 782fc96
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/sensitivities/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ end

# Add method to resolve exponentiation ambiguity.
^(n::Node{<:Real}, p::Integer) = invoke(^, Tuple{Node{<:Real}, Real}, n, p)

import Base: float
@explicit_intercepts float Tuple{∇ArrayOrScalar}
(::typeof(float), ::Type{Arg{1}}, p, y, ȳ, x) = float(ȳ)
20 changes: 20 additions & 0 deletions test/sensitivities/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,24 @@ end
# Test whether the exponentiation amibiguity is resolved.
@test (x -> x^2)(1) == (2.0,)
end

@testset "float" begin
# Scalars
x_ = 4
x = Leaf(Tape(), x_)
y = float(x)
@test y isa Branch{Float64}
@test unbox(y) == 4.0

# Arrays
X_ = [1,2,3,4]
X = Leaf(Tape(), X_)
Y = float(X)
@test Y isa Branch{Vector{Float64}}
@test unbox(Y) == Float64[1,2,3,4]

# In expressions
@test (x->2x)(1) === (2,)
@test (x->2*float(x))(1) === (2.0,)
end
end

0 comments on commit 782fc96

Please sign in to comment.