From 6dac9a757c7ca88ed083ac5d947fb58c3edd78d3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 12:14:17 +0200 Subject: [PATCH 1/9] Interpolate everything in rule helpers --- src/rule_definition_tools.jl | 66 +++++++++++++++++++---------------- test/rule_definition_tools.jl | 61 ++++++++++++++++++++++++-------- 2 files changed, 82 insertions(+), 45 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 78e6d8c63..8bc8b5ea6 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -88,8 +88,8 @@ macro scalar_rule(call, maybe_setup, partials...) ############################################################################ # Final return: building the expression to insert in the place of this macro code = quote - if !($f isa Type) && fieldcount(typeof($f)) > 0 - throw(ArgumentError( + if !($f isa $Type) && $(fieldcount)($(typeof)($f)) > 0 + $(throw)($(ArgumentError)( "@scalar_rule cannot be used on closures/functors (such as $($f))" )) end @@ -156,7 +156,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) if n_outputs > 1 # For forward-mode we return a Composite if output actually a tuple. pushforward_returns = Expr( - :call, :(ChainRulesCore.Composite{typeof($(esc(:Ω)))}), pushforward_returns... + :call, :($(Composite){$(typeof)($(esc(:Ω)))}), pushforward_returns... ) else pushforward_returns = first(pushforward_returns) @@ -165,7 +165,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) return @strip_linenos quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) + function ($ChainRulesCore.frule)(($(esc(:_)), $(Δs...)), ::$(typeof)($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) @@ -193,12 +193,12 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) pullback = @strip_linenos quote @inline function $(esc(propagator_name(f, :pullback)))($pullback_input) $(__source__) - return (NO_FIELDS, $(pullback_returns...)) + return ($NO_FIELDS, $(pullback_returns...)) end end return @strip_linenos quote - function ChainRulesCore.rrule(::typeof($f), $(inputs...)) + function ($ChainRulesCore.rrule)(::$(typeof)($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) @@ -223,7 +223,7 @@ function propagation_expr(Δs, ∂s, _conj = false) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj - :(conj($(esc(∂s_i)))) + :($(conj)($(esc(∂s_i)))) else esc(∂s_i) end @@ -233,11 +233,11 @@ function propagation_expr(Δs, ∂s, _conj = false) summed_∂_mul_Δs = if n∂s > 1 # Explicit multiplication is only performed for the first pair # of partial and gradient. - init_expr = :((*).($(_∂s[1]), $(Δs[1]))) + init_expr = :(($(*)).($(_∂s[1]), $(Δs[1]))) # Apply `muladd` iteratively. foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) - :((muladd).($∂s_i, $Δs_i, $ex)) + :(($(muladd)).($∂s_i, $Δs_i, $ex)) end else # Note: we don't want to do broadcasting with only 1 multiply (no `+`), @@ -330,53 +330,57 @@ macro non_differentiable(sig_expr) end "changes `f(x,y)` into `f(x,y; kwargs....)`" -function _with_kwargs_expr(call_expr::Expr) +function _with_kwargs_expr(call_expr::Expr, kwargs) @assert isexpr(call_expr, :call) return Expr( - :call, call_expr.args[1], Expr(:parameters, :(kwargs...)), call_expr.args[2:end]... + :call, call_expr.args[1], Expr(:parameters, :($(kwargs)...)), call_expr.args[2:end]... ) end function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) - return esc(@strip_linenos :( - function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...) + @gensym kwargs + return @strip_linenos quote + function ($ChainRulesCore.frule)(_, $(map(esc, primal_sig_parts)...); $(esc(kwargs))...) $(__source__) # Julia functions always only have 1 output, so return a single DoesNotExist() - return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist()) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $(DoesNotExist)()) end - )) + end end function tuple_expression(primal_sig_parts) has_vararg = _isvararg(primal_sig_parts[end]) return if !has_vararg - num_primal_inputs = length(primal_sig_parts) - 1 # - primal - Expr(:tuple, ntuple(_->DoesNotExist(), num_primal_inputs)...) + num_primal_inputs = length(primal_sig_parts) + Expr(:tuple, ntuple(_ -> :($(DoesNotExist)()), num_primal_inputs)...) else - num_primal_inputs = length(primal_sig_parts) - 2 # - primal and vararg - length_expr = :($(num_primal_inputs) + length($(_unconstrain(primal_sig_parts[end])))) - Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) + num_primal_inputs = length(primal_sig_parts) - 1 # - vararg + length_expr = :($(num_primal_inputs) + $(length)($(esc(_unconstrain(primal_sig_parts[end]))))) + Expr(:call, :ntuple, Expr(:(->), :($(esc(:_))), :($(DoesNotExist)())), length_expr) end end function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) + esc_primal_sig_parts = map(esc, primal_sig_parts) tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) - pullback_expr = Expr( - :function, - Expr(:call, propagator_name(primal_name, :pullback), :_), - Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr)) - ) - return esc(@strip_linenos quote + pullback_expr = @strip_linenos quote + function $(esc(propagator_name(primal_name, :pullback)))($(esc(:_))) + return $(tup_expr) + end + end + + @gensym kwargs + return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...)) - return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr) + function (::$(Core.kwftype)($(typeof)($(rrule))))($(esc(kwargs))::$(Any), ::$(typeof)($(rrule)), $(esc_primal_sig_parts...)) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $(pullback_expr)) end - function ChainRulesCore.rrule($(primal_sig_parts...)) + function ($ChainRulesCore.rrule)($(esc_primal_sig_parts...)) $(__source__) - return ($primal_invoke, $pullback_expr) + return ($(esc(primal_invoke)), $(pullback_expr)) end - end) + end end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index c563786ea..80dae9112 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -254,21 +254,54 @@ end module IsolatedModuleForTestingScoping - using Test - # need to make sure macros work in something that hasn't imported all exports - # all that matters is that the following don't error, since they will resolve at - # parse time - using ChainRulesCore: ChainRulesCore +using ChainRulesCore: @scalar_rule, @non_differentiable - @testset "@non_differentiable" begin - # this is - # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 - fixed(x) = :abc - ChainRulesCore.@non_differentiable fixed(x) - end +# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved +const ChainRulesCore = nothing - @testset "@scalar_rule" begin - my_id(x) = x - ChainRulesCore.@scalar_rule(my_id(x), 1.0) +# this is +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 +fixed(x) = :abc +@non_differentiable fixed(x) + +# check name collision +fixed_kwargs(x; kwargs...) = :abc +@non_differentiable fixed_kwargs(kwargs) + +my_id(x) = x +@scalar_rule(my_id(x), 1.0) + +module IsolatedSubmodule +using Test +using ChainRulesCore: frule, rrule, Zero, DoesNotExist +using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + +@testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((Zero(), randn()), f, randn()) + @test y === :abc + @test ẏ === DoesNotExist() + + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (DoesNotExist(), DoesNotExist()) end + + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + @test y === :abc + @test f_pullback(randn()) === (DoesNotExist(), DoesNotExist()) +end + +@testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((Zero(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ + + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (Zero(), Δy) +end +end end From 1a64a05e6af29365d27d09763ab2f0a8983bbbba Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 12:15:15 +0200 Subject: [PATCH 2/9] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e662b1c17..92234b0f6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.34" +version = "0.9.35" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 4571cf4e223976fe743ec5c6ff29c44d0dae59ce Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 16:34:51 +0200 Subject: [PATCH 3/9] Less interpolation, simpler code, messier `@macroexpand` output... --- src/rule_definition_tools.jl | 42 ++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 8bc8b5ea6..32d3b034a 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -88,8 +88,8 @@ macro scalar_rule(call, maybe_setup, partials...) ############################################################################ # Final return: building the expression to insert in the place of this macro code = quote - if !($f isa $Type) && $(fieldcount)($(typeof)($f)) > 0 - $(throw)($(ArgumentError)( + if !($f isa Type) && fieldcount(typeof($f)) > 0 + throw(ArgumentError( "@scalar_rule cannot be used on closures/functors (such as $($f))" )) end @@ -156,7 +156,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) if n_outputs > 1 # For forward-mode we return a Composite if output actually a tuple. pushforward_returns = Expr( - :call, :($(Composite){$(typeof)($(esc(:Ω)))}), pushforward_returns... + :call, :(Composite{typeof($(esc(:Ω)))}), pushforward_returns... ) else pushforward_returns = first(pushforward_returns) @@ -165,7 +165,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) return @strip_linenos quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ($ChainRulesCore.frule)(($(esc(:_)), $(Δs...)), ::$(typeof)($f), $(inputs...)) + function ChainRulesCore.frule(($(esc(:_)), $(Δs...)), ::typeof($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) @@ -193,12 +193,12 @@ function scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) pullback = @strip_linenos quote @inline function $(esc(propagator_name(f, :pullback)))($pullback_input) $(__source__) - return ($NO_FIELDS, $(pullback_returns...)) + return (NO_FIELDS, $(pullback_returns...)) end end return @strip_linenos quote - function ($ChainRulesCore.rrule)(::$(typeof)($f), $(inputs...)) + function ChainRulesCore.rrule(::typeof($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) @@ -223,7 +223,7 @@ function propagation_expr(Δs, ∂s, _conj = false) # This is basically Δs ⋅ ∂s _∂s = map(∂s) do ∂s_i if _conj - :($(conj)($(esc(∂s_i)))) + :(conj($(esc(∂s_i)))) else esc(∂s_i) end @@ -233,11 +233,11 @@ function propagation_expr(Δs, ∂s, _conj = false) summed_∂_mul_Δs = if n∂s > 1 # Explicit multiplication is only performed for the first pair # of partial and gradient. - init_expr = :(($(*)).($(_∂s[1]), $(Δs[1]))) + init_expr = :((*).($(_∂s[1]), $(Δs[1]))) # Apply `muladd` iteratively. foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) - :(($(muladd)).($∂s_i, $Δs_i, $ex)) + :((muladd).($∂s_i, $Δs_i, $ex)) end else # Note: we don't want to do broadcasting with only 1 multiply (no `+`), @@ -340,10 +340,10 @@ end function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote - function ($ChainRulesCore.frule)(_, $(map(esc, primal_sig_parts)...); $(esc(kwargs))...) + function ChainRulesCore.frule(_, $(map(esc, primal_sig_parts)...); $(esc(kwargs))...) $(__source__) # Julia functions always only have 1 output, so return a single DoesNotExist() - return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $(DoesNotExist)()) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), DoesNotExist()) end end end @@ -352,11 +352,11 @@ function tuple_expression(primal_sig_parts) has_vararg = _isvararg(primal_sig_parts[end]) return if !has_vararg num_primal_inputs = length(primal_sig_parts) - Expr(:tuple, ntuple(_ -> :($(DoesNotExist)()), num_primal_inputs)...) + Expr(:tuple, ntuple(_ -> DoesNotExist(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg - length_expr = :($(num_primal_inputs) + $(length)($(esc(_unconstrain(primal_sig_parts[end]))))) - Expr(:call, :ntuple, Expr(:(->), :($(esc(:_))), :($(DoesNotExist)())), length_expr) + length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) + Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) end end @@ -364,21 +364,21 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) esc_primal_sig_parts = map(esc, primal_sig_parts) tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) - pullback_expr = @strip_linenos quote + pullback_expr = @strip_linenos :( function $(esc(propagator_name(primal_name, :pullback)))($(esc(:_))) return $(tup_expr) end - end + ) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::$(Core.kwftype)($(typeof)($(rrule))))($(esc(kwargs))::$(Any), ::$(typeof)($(rrule)), $(esc_primal_sig_parts...)) - return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $(pullback_expr)) + function (::Core.kwftype(typeof($rrule)))($(esc(kwargs))::Any, ::typeof($rrule), $(esc_primal_sig_parts...)) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end - function ($ChainRulesCore.rrule)($(esc_primal_sig_parts...)) + function ChainRulesCore.rrule($(esc_primal_sig_parts...)) $(__source__) - return ($(esc(primal_invoke)), $(pullback_expr)) + return ($(esc(primal_invoke)), $pullback_expr) end end end @@ -438,7 +438,7 @@ function _split_primal_name(primal_name) if primal_name isa Symbol || Meta.isexpr(primal_name, :(.)) || Meta.isexpr(primal_name, :curly) - primal_name_sig = :(::Core.Typeof($primal_name)) + primal_name_sig = :(::$Core.Typeof($primal_name)) return primal_name_sig, primal_name # e.g. (::T)(x, y) elseif Meta.isexpr(primal_name, :(::)) From cf6d283dc8002bb6b0ccb779aaff7baece32999a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 16:35:35 +0200 Subject: [PATCH 4/9] Change indentation of submodules in tests --- test/rule_definition_tools.jl | 80 +++++++++++++++++------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 80dae9112..a20ca5faf 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -254,54 +254,54 @@ end module IsolatedModuleForTestingScoping -using ChainRulesCore: @scalar_rule, @non_differentiable + using ChainRulesCore: @scalar_rule, @non_differentiable -# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved -const ChainRulesCore = nothing + # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved + const ChainRulesCore = nothing -# this is -# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 -fixed(x) = :abc -@non_differentiable fixed(x) + # this is + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317 + fixed(x) = :abc + @non_differentiable fixed(x) -# check name collision -fixed_kwargs(x; kwargs...) = :abc -@non_differentiable fixed_kwargs(kwargs) + # check name collision + fixed_kwargs(x; kwargs...) = :abc + @non_differentiable fixed_kwargs(kwargs) -my_id(x) = x -@scalar_rule(my_id(x), 1.0) + my_id(x) = x + @scalar_rule(my_id(x), 1.0) -module IsolatedSubmodule -using Test -using ChainRulesCore: frule, rrule, Zero, DoesNotExist -using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + module IsolatedSubmodule + using Test + using ChainRulesCore: frule, rrule, Zero, DoesNotExist + using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id -@testset "@non_differentiable" begin - for f in (fixed, fixed_kwargs) - y, ẏ = frule((Zero(), randn()), f, randn()) - @test y === :abc - @test ẏ === DoesNotExist() + @testset "@non_differentiable" begin + for f in (fixed, fixed_kwargs) + y, ẏ = frule((Zero(), randn()), f, randn()) + @test y === :abc + @test ẏ === DoesNotExist() - y, f_pullback = rrule(f, randn()) - @test y === :abc - @test f_pullback(randn()) === (DoesNotExist(), DoesNotExist()) - end + y, f_pullback = rrule(f, randn()) + @test y === :abc + @test f_pullback(randn()) === (DoesNotExist(), DoesNotExist()) + end - y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) - @test y === :abc - @test f_pullback(randn()) === (DoesNotExist(), DoesNotExist()) -end + y, f_pullback = rrule(fixed_kwargs, randn(); keyword=randn()) + @test y === :abc + @test f_pullback(randn()) === (DoesNotExist(), DoesNotExist()) + end -@testset "@scalar_rule" begin - x, ẋ = randn(2) - y, ẏ = frule((Zero(), ẋ), my_id, x) - @test y == x - @test ẏ == ẋ + @testset "@scalar_rule" begin + x, ẋ = randn(2) + y, ẏ = frule((Zero(), ẋ), my_id, x) + @test y == x + @test ẏ == ẋ - Δy = randn() - y, f_pullback = rrule(my_id, x) - @test y == x - @test f_pullback(Δy) == (Zero(), Δy) -end -end + Δy = randn() + y, f_pullback = rrule(my_id, x) + @test y == x + @test f_pullback(Δy) == (Zero(), Δy) + end + end end From c687250a52801776d283c250c41a3946c7af8052 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 16:53:53 +0200 Subject: [PATCH 5/9] Work around JuliaLang/julia/issues/32727 --- src/rule_definition_tools.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 32d3b034a..394fe2413 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -339,8 +339,9 @@ end function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs + # `::Any` instead of `_`: https://github.com/JuliaLang/julia/issues/32727 return @strip_linenos quote - function ChainRulesCore.frule(_, $(map(esc, primal_sig_parts)...); $(esc(kwargs))...) + function ChainRulesCore.frule(::Any, $(map(esc, primal_sig_parts)...); $(esc(kwargs))...) $(__source__) # Julia functions always only have 1 output, so return a single DoesNotExist() return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), DoesNotExist()) @@ -356,7 +357,7 @@ function tuple_expression(primal_sig_parts) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) - Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) + Expr(:call, :ntuple, Expr(:(->), :($(esc(:_))), DoesNotExist()), length_expr) end end From f4fcdf2430b8058bcf3641ffe5424cf07aaa322b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 17:17:29 +0200 Subject: [PATCH 6/9] Add additional comment --- test/rule_definition_tools.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index a20ca5faf..8021cb3b5 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -264,7 +264,8 @@ module IsolatedModuleForTestingScoping fixed(x) = :abc @non_differentiable fixed(x) - # check name collision + # check name collision between a primal input called `kwargs` and the actual keyword + # arguments fixed_kwargs(x; kwargs...) = :abc @non_differentiable fixed_kwargs(kwargs) From 3f474eecd0823711b0d0698c0133b1676bf6977b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 19:20:12 +0200 Subject: [PATCH 7/9] Remove remaining `esc(:_)` --- src/rule_definition_tools.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 394fe2413..dc0e063d3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -165,7 +165,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) return @strip_linenos quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule(($(esc(:_)), $(Δs...)), ::typeof($f), $(inputs...)) + function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) @@ -357,7 +357,7 @@ function tuple_expression(primal_sig_parts) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) - Expr(:call, :ntuple, Expr(:(->), :($(esc(:_))), DoesNotExist()), length_expr) + :(ntuple(::Any -> DoesNotExist(), $length_expr)) end end @@ -365,11 +365,11 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) esc_primal_sig_parts = map(esc, primal_sig_parts) tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) - pullback_expr = @strip_linenos :( - function $(esc(propagator_name(primal_name, :pullback)))($(esc(:_))) + pullback_expr = @strip_linenos quote + function $(esc(propagator_name(primal_name, :pullback)))(::Any) return $(tup_expr) end - ) + end @gensym kwargs return @strip_linenos quote From 28ae5e63c9f9ef8abb04f18fcf2843ddf005fcc1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 19:21:55 +0200 Subject: [PATCH 8/9] Remove unneeded interpolation of `rrule` --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index dc0e063d3..de226ede4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -374,7 +374,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs return @strip_linenos quote # Manually defined kw version to save compiler work. See explanation in rules.jl - function (::Core.kwftype(typeof($rrule)))($(esc(kwargs))::Any, ::typeof($rrule), $(esc_primal_sig_parts...)) + function (::Core.kwftype(typeof(rrule)))($(esc(kwargs))::Any, ::typeof(rrule), $(esc_primal_sig_parts...)) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), $pullback_expr) end function ChainRulesCore.rrule($(esc_primal_sig_parts...)) From c0832783ff0f3be080e84cf84ff66df6400f4164 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 19:26:23 +0200 Subject: [PATCH 9/9] Add explanations to tests --- test/rule_definition_tools.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 8021cb3b5..1f058c4da 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -254,6 +254,7 @@ end module IsolatedModuleForTestingScoping + # check that rules can be defined by macros without any additional imports using ChainRulesCore: @scalar_rule, @non_differentiable # ensure that functions, types etc. in module `ChainRulesCore` can't be resolved @@ -273,9 +274,11 @@ module IsolatedModuleForTestingScoping @scalar_rule(my_id(x), 1.0) module IsolatedSubmodule - using Test + # check that rules defined in isolated module without imports can be called + # without errors using ChainRulesCore: frule, rrule, Zero, DoesNotExist using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id + using Test @testset "@non_differentiable" begin for f in (fixed, fixed_kwargs)