Skip to content

Commit

Permalink
Finish testing and cleaning code on non_differentiable macro
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Aug 26, 2020
1 parent fe95249 commit 29ccf55
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 31 deletions.
50 changes: 24 additions & 26 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
@assert Meta.isexpr(call, :call)

# Annotate all arguments in the signature as scalars
inputs = _constrain_and_name.(call.args[2:end], :Number)

inputs = esc.(_constrain_and_name.(call.args[2:end], :Number))
# Remove annotations and escape names for the call
call.args = _unconstrain.(call.args)
call.args[2:end] .= _unconstrain.(call.args[2:end])
call.args = esc.(call.args)

# For consistency in code that follows we make all partials tuple expressions
Expand Down Expand Up @@ -188,7 +187,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is a pull-back there is one per output of function
Δs = [Symbol(string(, i)) for i in 1:n_outputs]
Δs = [Symbol(, i) for i in 1:n_outputs]

# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
Expand All @@ -199,7 +198,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
# Multi-output functions have pullbacks with a tuple input that will be destructured
pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
pullback = quote
function $(propagator_name(f, :pullback))($pullback_input)
function $(esc(propagator_name(f, :pullback)))($pullback_input)
return (NO_FIELDS, $(pullback_returns...))
end
end
Expand All @@ -225,16 +224,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
∂s = map(esc, ∂s)
n∂s = length(∂s)

# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
# literals.
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
∂_mul_Δs = if _conj
ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s)
else
ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
end

# Avoiding the extra `+` operation, it is potentially expensive for vector
# mode AD.
# Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
sumed_∂_mul_Δs = if n∂s > 1
# we use `@.` to broadcast `*` and `+`
:(@. +($(∂_mul_Δs...)))
Expand Down Expand Up @@ -275,37 +272,38 @@ macro non_differentiable(call_expr)
primal_name, orig_args = Iterators.peel(call_expr.args)

constrained_args = _constrain_and_name.(orig_args, :Any)
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]

unconstrained_args = _unconstrain.(constrained_args)
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)


primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]

quote
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
$(_nondiff_rrule_expr(primal_sig_parts, primal_invoke))
end
end

# TODO Move to frule helper
frule_defn = Expr(
function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
return Expr(
:(=),
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
# How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that
# returns `DoesNotExist()` for every position.
# Julia functions always only have 1 output, so just return a single DoesNotExist()
Expr(:tuple, primal_invoke, DoesNotExist())
)
end

# TODO Move to rrule helper

function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
num_primal_inputs = length(primal_sig_parts) - 1
primal_name = first(primal_invoke.args)
pullback_expr = Expr(
:function,
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in constrained_args)...)
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
)
rrule_defn = Expr(
:(=),
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
Expr(:tuple, primal_invoke, pullback_expr),
)

quote
$frule_defn
$rrule_defn
end
end

return rrule_defn
end
39 changes: 34 additions & 5 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
@testset "rule_definition_tools.jl" begin

@testset "@nondifferentiable" begin
@testset "@non_differentiable" begin
@testset "nondiff_2_1" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist())
res, pullback = rrule(nondiff_2_1, 3, 2)
@test res == 7.5
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist())
end

end
end
@testset "nondiff_1_2" begin
nondiff_1_2(x) = (5.0, 3.0)
@non_differentiable nondiff_1_2(::Any)
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist())
res, pullback = rrule(nondiff_1_2, 3.1)
@test res == (5.0, 3.0)
@test isequal(
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)),
(NO_FIELDS, DoesNotExist()),
)
end

@testset "specific signature" begin
nonembed_identity(x) = x
@non_differentiable nonembed_identity(::Integer)

@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist())
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing

Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO))
res, pullback = rrule(nonembed_identity, 2)
@test res == 2
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())

@test rrule(nonembed_identity, 2.0) == nothing
end
end
end

@non_differentiable println(io::IO)

0 comments on commit 29ccf55

Please sign in to comment.