Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolate more in rule helpers and fix escaping of @non_differentiable #325

Merged
merged 9 commits into from
Mar 31, 2021
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.34"
version = "0.9.35"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
53 changes: 29 additions & 24 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
if n_outputs > 1
# For forward-mode we return a Composite if output actually a tuple.
pushforward_returns = Expr(
:call, :(ChainRulesCore.Composite{typeof($(esc(:Ω)))}), pushforward_returns...
:call, :(Composite{typeof($(esc(:Ω)))}), pushforward_returns...
)
else
pushforward_returns = first(pushforward_returns)
Expand Down Expand Up @@ -330,53 +330,58 @@ macro non_differentiable(sig_expr)
end

"changes `f(x,y)` into `f(x,y; kwargs....)`"
function _with_kwargs_expr(call_expr::Expr)
function _with_kwargs_expr(call_expr::Expr, kwargs)
@assert isexpr(call_expr, :call)
return Expr(
:call, call_expr.args[1], Expr(:parameters, :(kwargs...)), call_expr.args[2:end]...
:call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]...
)
end

function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
return esc(@strip_linenos :(
function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
@gensym kwargs
# `::Any` instead of `_`: https://github.com/JuliaLang/julia/issues/32727
return @strip_linenos quote
function ChainRulesCore.frule(::Any, $(map(esc, primal_sig_parts)...); $(esc(kwargs))...)
$(__source__)
# Julia functions always only have 1 output, so return a single DoesNotExist()
return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist())
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), DoesNotExist())
end
))
end
end

function tuple_expression(primal_sig_parts)
has_vararg = _isvararg(primal_sig_parts[end])
return if !has_vararg
num_primal_inputs = length(primal_sig_parts) - 1 # - primal
Expr(:tuple, ntuple(_->DoesNotExist(), num_primal_inputs)...)
num_primal_inputs = length(primal_sig_parts)
Expr(:tuple, ntuple(_ -> DoesNotExist(), num_primal_inputs)...)
else
num_primal_inputs = length(primal_sig_parts) - 2 # - primal and vararg
length_expr = :($(num_primal_inputs) + length($(_unconstrain(primal_sig_parts[end]))))
Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr)
num_primal_inputs = length(primal_sig_parts) - 1 # - vararg
length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end])))))
:(ntuple(::Any -> DoesNotExist(), $length_expr))
end
end

function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
esc_primal_sig_parts = map(esc, primal_sig_parts)
tup_expr = tuple_expression(primal_sig_parts)
primal_name = first(primal_invoke.args)
pullback_expr = Expr(
:function,
Expr(:call, propagator_name(primal_name, :pullback), :_),
Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr))
)
return esc(@strip_linenos quote
pullback_expr = @strip_linenos quote
function $(esc(propagator_name(primal_name, :pullback)))(::Any)
return $(tup_expr)
end
end

@gensym kwargs
return @strip_linenos quote
# Manually defined kw version to save compiler work. See explanation in rules.jl
function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...))
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr)
end
function ChainRulesCore.rrule($(primal_sig_parts...))
function ChainRulesCore.rrule($(esc_primal_sig_parts...))
$(__source__)
return ($primal_invoke, $pullback_expr)
return ($(esc(primal_invoke)), $pullback_expr)
end
end)
end
end


Expand Down Expand Up @@ -434,7 +439,7 @@ function _split_primal_name(primal_name)
if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) ||
Meta.isexpr(primal_name, :curly)

primal_name_sig = :(::Core.Typeof($primal_name))
primal_name_sig = :(::$Core.Typeof($primal_name))
return primal_name_sig, primal_name
# e.g. (::T)(x, y)
elseif Meta.isexpr(primal_name, :(::))
Expand Down
65 changes: 51 additions & 14 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,21 +254,58 @@ end


module IsolatedModuleForTestingScoping
using Test
# need to make sure macros work in something that hasn't imported all exports
# all that matters is that the following don't error, since they will resolve at
# parse time
using ChainRulesCore: ChainRulesCore
# check that rules can be defined by macros without any additional imports
using ChainRulesCore: @scalar_rule, @non_differentiable

# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved
const ChainRulesCore = nothing

# this is
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317
fixed(x) = :abc
@non_differentiable fixed(x)

# check name collision between a primal input called `kwargs` and the actual keyword
# arguments
fixed_kwargs(x; kwargs...) = :abc
@non_differentiable fixed_kwargs(kwargs)

my_id(x) = x
@scalar_rule(my_id(x), 1.0)

module IsolatedSubmodule
# check that rules defined in isolated module without imports can be called
# without errors
using ChainRulesCore: frule, rrule, Zero, DoesNotExist
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
using Test

@testset "@non_differentiable" begin
for f in (fixed, fixed_kwargs)
y, ẏ = frule((Zero(), randn()), f, randn())
@test y === :abc
@test ẏ === DoesNotExist()

y, f_pullback = rrule(f, randn())
@test y === :abc
@test f_pullback(randn()) === (DoesNotExist(), DoesNotExist())
end

@testset "@non_differentiable" begin
# this is
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317
fixed(x) = :abc
ChainRulesCore.@non_differentiable fixed(x)
end
y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn())
@test y === :abc
@test f_pullback(randn()) === (DoesNotExist(), DoesNotExist())
end

@testset "@scalar_rule" begin
my_id(x) = x
ChainRulesCore.@scalar_rule(my_id(x), 1.0)
@testset "@scalar_rule" begin
x, ẋ = randn(2)
y, ẏ = frule((Zero(), ẋ), my_id, x)
@test y == x
@test ẏ == ẋ

Δy = randn()
y, f_pullback = rrule(my_id, x)
@test y == x
@test f_pullback(Δy) == (Zero(), Δy)
end
end
end