Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolate more in rule helpers and fix escaping of @non_differentiable #325

Merged
merged 9 commits into from
Mar 31, 2021

Conversation

devmotion
Copy link
Member

@devmotion devmotion commented Mar 31, 2021

The PR interpolates all types and functions that are not provided by the user more types and functions in the AST in the macros @scalar_rule and @non_differentiable. This fixes the following bug:

julia> using ChainRulesCore: @non_differentiable

julia> @non_differentiable length(x::AbstractVector)
ERROR: UndefVarError: ChainRulesCore not defined
Stacktrace:
 [1] top-level scope
   @ ~/.julia/packages/ChainRulesCore/ASgvC/src/rule_definition_tools.jl:327

Additionally, not the whole output of @non_differentiable is escaped anymore, fixing a possible name collision for arguments of name kwargs:

julia> using ChainRulesCore: ChainRulesCore, @non_differentiable

julia> @non_differentiable length(kwargs::AbstractVector)
ERROR: syntax: function argument name not unique: "kwargs" around /home/david/.julia/packages/ChainRulesCore/ASgvC/src/rule_definition_tools.jl:327
Stacktrace:
 [1] top-level scope
   @ REPL[17]:1

Macro output on the master branch:

julia> Base.remove_linenums!(@macroexpand @scalar_rule sincos(x) cos(x) -sin(x))
quote
    if !(sincos isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sincos)) > 0
        ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sincos))"))
    end
    begin
        function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, var"##Δ1#257"), ::ChainRulesCore.typeof(sincos), x::Number)
            Ω = sincos(x)
            nothing
            return (Ω, (ChainRulesCore.ChainRulesCore).Composite{ChainRulesCore.typeof(Ω)}(cos(x) * var"##Δ1#257", -(sin(x)) * var"##Δ1#257"))
        end
    end
    begin
        function (ChainRulesCore.ChainRulesCore).rrule(::ChainRulesCore.typeof(sincos), x::Number)
            Ω = sincos(x)
            nothing
            return (Ω, begin
                        function sincos_pullback((var"##Δ1#258", var"##Δ2#259"))
                            $(Expr(:meta, :inline))
                            return (ChainRulesCore.NO_FIELDS, ChainRulesCore.muladd.(ChainRulesCore.conj(-(sin(x))), var"##Δ2#259", ChainRulesCore.:*.(ChainRulesCore.conj(cos(x)), var"##Δ1#258")))
                        end
                    end)
        end
    end
end

julia> Base.remove_linenums!(@macroexpand @non_differentiable length(x::Vector))
quote
    function ChainRulesCore.frule(var"##_#260", ::Core.Typeof(length), x::Vector; kwargs...)
        return (length(x; kwargs...), DoesNotExist())
    end
    begin
        function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), ::Core.Typeof(length), x::Vector)
            return (length(x; kwargs...), function length_pullback(_)
                        (ChainRulesCore.DoesNotExist(), (ChainRulesCore.DoesNotExist(),)...)
                    end)
        end
        function ChainRulesCore.rrule(::Core.Typeof(length), x::Vector)
            return (length(x), function length_pullback(_)
                        (ChainRulesCore.DoesNotExist(), (ChainRulesCore.DoesNotExist(),)...)
                    end)
        end
    end
end

julia> Base.remove_linenums!(@macroexpand @non_differentiable length(x::Vector...))
quote
    function ChainRulesCore.frule(var"##_#261", ::Core.Typeof(length), x::Vector...; kwargs...)
        return (length(x...; kwargs...), DoesNotExist())
    end
    begin
        function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), ::Core.Typeof(length), x::Vector...)
            return (length(x...; kwargs...), function length_pullback(_)
                        (ChainRulesCore.DoesNotExist(), ntuple((_->ChainRulesCore.DoesNotExist()), 0 + length(x))...)
                    end)
        end
        function ChainRulesCore.rrule(::Core.Typeof(length), x::Vector...)
            return (length(x...), function length_pullback(_)
                        (ChainRulesCore.DoesNotExist(), ntuple((_->ChainRulesCore.DoesNotExist()), 0 + length(x))...)
                    end)
        end
    end
end

Macro output with this PR:

julia> Base.remove_linenums!(@macroexpand @scalar_rule sincos(x) cos(x) -sin(x))
quote
    if !(sincos isa ChainRulesCore.Type) && ChainRulesCore.fieldcount(ChainRulesCore.typeof(sincos)) > 0
        ChainRulesCore.throw(ChainRulesCore.ArgumentError("@scalar_rule cannot be used on closures/functors (such as $(sincos))"))
    end
    begin
        function (ChainRulesCore.ChainRulesCore).frule((ChainRulesCore._, var"##Δ1#277"), ::ChainRulesCore.typeof(sincos), x::Number)
            Ω = sincos(x)
            nothing
            return (Ω, ChainRulesCore.Composite{ChainRulesCore.typeof(Ω)}(cos(x) * var"##Δ1#277", -(sin(x)) * var"##Δ1#277"))
        end
    end
    begin
        function (ChainRulesCore.ChainRulesCore).rrule(::ChainRulesCore.typeof(sincos), x::Number)
            Ω = sincos(x)
            nothing
            return (Ω, begin
                        function sincos_pullback((var"##Δ1#278", var"##Δ2#279"))
                            $(Expr(:meta, :inline))
                            return (ChainRulesCore.NO_FIELDS, ChainRulesCore.muladd.(ChainRulesCore.conj(-(sin(x))), var"##Δ2#279", ChainRulesCore.:*.(ChainRulesCore.conj(cos(x)), var"##Δ1#278")))
                        end
                    end)
        end
    end
end

julia> Base.remove_linenums!(@macroexpand @non_differentiable length(x::Vector))
quote
    begin
        function (ChainRulesCore.ChainRulesCore).frule(::ChainRulesCore.Any, ::(Core).Typeof(length), x::Vector; var"##kwargs#280"...)
            return (length(x; var"##kwargs#280"...), ChainRulesCore.DoesNotExist())
        end
    end
    begin
        function (::(ChainRulesCore.Core).kwftype(ChainRulesCore.typeof(ChainRulesCore.rrule)))(var"##kwargs#281"::ChainRulesCore.Any, ::ChainRulesCore.typeof(ChainRulesCore.rrule
), ::(Core).Typeof(length), x::Vector)
            return (length(x; var"##kwargs#281"...), begin
                        function length_pullback(::ChainRulesCore.Any)
                            return (ChainRulesCore.DoesNotExist(), ChainRulesCore.DoesNotExist())
                        end
                    end)
        end
        function (ChainRulesCore.ChainRulesCore).rrule(::(Core).Typeof(length), x::Vector)
            return (length(x), begin
                        function length_pullback(::ChainRulesCore.Any)
                            return (ChainRulesCore.DoesNotExist(), ChainRulesCore.DoesNotExist())
                        end
                    end)
        end
    end
end

julia> Base.remove_linenums!(@macroexpand @non_differentiable length(x::Vector...))
quote
    begin
        function (ChainRulesCore.ChainRulesCore).frule(::ChainRulesCore.Any, ::(Core).Typeof(length), x::Vector...; var"##kwargs#282"...)
            return (length(x...; var"##kwargs#282"...), ChainRulesCore.DoesNotExist())
        end
    end
    begin
        function (::(ChainRulesCore.Core).kwftype(ChainRulesCore.typeof(ChainRulesCore.rrule)))(var"##kwargs#283"::ChainRulesCore.Any, ::ChainRulesCore.typeof(ChainRulesCore.rrule
), ::(Core).Typeof(length), x::Vector...)
            return (length(x...; var"##kwargs#283"...), begin
                        function length_pullback(::ChainRulesCore.Any)
                            return ChainRulesCore.ntuple(((::ChainRulesCore.Any,)->begin
                                            ChainRulesCore.DoesNotExist()
                                        end), 1 + ChainRulesCore.length(x))
                        end
                    end)
        end
        function (ChainRulesCore.ChainRulesCore).rrule(::(Core).Typeof(length), x::Vector...)
            return (length(x...), begin
                        function length_pullback(::ChainRulesCore.Any)
                            return ChainRulesCore.ntuple(((::ChainRulesCore.Any,)->begin
                                            ChainRulesCore.DoesNotExist()
                                        end), 1 + ChainRulesCore.length(x))
                        end
                    end)
        end
    end
end

This fixes #320.

Edit: Updated after the suggestions by the reviewers were incorporated.

@codecov-io
Copy link

codecov-io commented Mar 31, 2021

Codecov Report

Merging #325 (1a64a05) into master (d675be3) will increase coverage by 0.06%.
The diff coverage is 94.11%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #325      +/-   ##
==========================================
+ Coverage   89.76%   89.83%   +0.06%     
==========================================
  Files          13       13              
  Lines         469      472       +3     
==========================================
+ Hits          421      424       +3     
  Misses         48       48              
Impacted Files Coverage Δ
src/rule_definition_tools.jl 96.12% <94.11%> (+0.09%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d675be3...1a64a05. Read the comment docs.

@mzgubic
Copy link
Member

mzgubic commented Mar 31, 2021

Wow, that's great, thanks for looking into this! I will take a detailed look later today.

For now just a tip (I think from @oxinabox) I've found very useful when debugging macros: Base.remove_linenums!(@macroexpand @non_differentiable length(x::AbstractVector)) removes line numbers from the output and makes it much easier to read.

Comment on lines 379 to 382
function ($ChainRulesCore.rrule)($(esc_primal_sig_parts...))
$(__source__)
return ($primal_invoke, $pullback_expr)
return ($(esc(primal_invoke)), $(pullback_expr))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is now escaped correctly, there shouldn't be a need to interpolate something like ChainRules anymore. That was just my proposed quick fix, since we didn't handle escaping correctly before.

@@ -165,7 +165,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials)
return @strip_linenos quote
# _ is the input derivative w.r.t. function internals. since we do not
# allow closures/functors with @scalar_rule, it is always ignored
function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...))
function ($ChainRulesCore.frule)(($(esc(:_)), $(Δs...)), ::$(typeof)($f), $(inputs...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to escape _

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/325/files#r604881501, the main intention was to obtain a (somewhat) cleaner output. I don't mind changing it though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like _ getting mangled by macro hygiene is something that we should fix in Base. It definitely shouldn't get turned into a GlobalRef. Mind opening an issue?

Copy link
Member

@simeonschaub simeonschaub Mar 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, JuliaLang/julia#40280 should fix this, so I think we don't need to escape it here.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, basically looks good too me.
Once you have addressed the comments to your satisfaction merge and tag a release

@@ -88,8 +88,8 @@ macro scalar_rule(call, maybe_setup, partials...)
############################################################################
# Final return: building the expression to insert in the place of this macro
code = quote
if !($f isa Type) && fieldcount(typeof($f)) > 0
throw(ArgumentError(
if !($f isa $Type) && $(fieldcount)($(typeof)($f)) > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you feel about not using brackets that are redudent?
On the one hand less visually noisy, on the other hand less consistent with other symbols interpolated in

Suggested change
if !($f isa $Type) && $(fieldcount)($(typeof)($f)) > 0
if !($f isa $Type) && $fieldcount($typeof($f)) > 0

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got rid of all the interpolations again: #325 (comment)

src/rule_definition_tools.jl Outdated Show resolved Hide resolved
if !($f isa Type) && fieldcount(typeof($f)) > 0
throw(ArgumentError(
if !($f isa $Type) && $(fieldcount)($(typeof)($f)) > 0
$(throw)($(ArgumentError)(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL throw can be overwritten. It isn't a language keyword, just a built in function.
Though there seems like some inconsistency in what we are interpolating in.
In that we are interpolating in throw, but not ! or >, or isa.
All of which i think would be at least as common to shadow as throw (and still incredibly rare)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, now macro hygiene takes care of all these things (the only downside being the messier output of @macroexpand).

# all that matters is that the following don't error, since they will resolve at
# parse time
using ChainRulesCore: ChainRulesCore
using ChainRulesCore: @scalar_rule, @non_differentiable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BlueStyle doesn't have a rule on this, but I think that it should.
And that submodules within another file should be indented.
JuliaDiff/BlueStyle#85

Espeically in this case when there are so many of them nested, that i am having trouble following exactly what is going on.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the indentation.

fixed(x) = :abc
@non_differentiable fixed(x)

# check name collision
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# check name collision
# check name collision between a primal input called `kwargs` and the actual keyword args

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added your suggestion (unfortunately I missed it initially and couldn't commit the suggestion on Github anymore then).

# Manually defined kw version to save compiler work. See explanation in rules.jl
function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
function (::$(Core.kwftype)($(typeof)($(rrule))))($(esc(kwargs))::$(Any), ::$(typeof)($(rrule)), $(esc_primal_sig_parts...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure you can actually shadow Core

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this should be handled by macro hygiene.

my_id(x) = x
@scalar_rule(my_id(x), 1.0)

module IsolatedSubmodule
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is being tested here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That all frule and rrules can be called without errors - this would detect e.g. if the macro output would try to look up something in ChainRulesCore in the module the rules were defined in (i.e., in IsolatedModuleForTestingScoping.ChainRulesCore). To avoid having to import anything from ChainRulesCore in addition to @scalar_rule and @non_differentiable, these checks are performed in a separate submodule.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we indicate this in the code with a comment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

)
return esc(@strip_linenos quote
pullback_expr = @strip_linenos quote
function $(esc(propagator_name(primal_name, :pullback)))($(esc(:_)))
Copy link
Member

@mzgubic mzgubic Mar 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for changing this as well, it's much clearer now. Similarly here I don't think we need to escape :_

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, escaping :_ ends up with (arguably) cleaner output but makes the implementation a bit annoying to read.

Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tried this without all the interpolation, just with the changes to how esc is used? I would prefer being careful with how we use esc instead of just interpolating everything, since that's really how macro hygiene is intended to be used and I find the interpolations make the source code harder to read. I also wouldn't be too worried about the output of @macroexpand, as long as stacktraces are good.

@devmotion
Copy link
Member Author

Have you tried this without all the interpolation, just with the changes to how esc is used? I would prefer being careful with how we use esc instead of just interpolating everything, since that's really how macro hygiene is intended to be used and I find the interpolations make the source code harder to read. I also wouldn't be too worried about the output of @macroexpand, as long as stacktraces are good.

I agree, I already started to reduce the numbers interpolations in my local branch since I had the same feeling. I was quite happy about the (arguably) clean output of @macroexpand but I don't think it's worth it anymore.

@@ -356,7 +357,7 @@ function tuple_expression(primal_sig_parts)
else
num_primal_inputs = length(primal_sig_parts) - 1 # - vararg
length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end])))))
Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr)
Expr(:call, :ntuple, Expr(:(->), :($(esc(:_))), DoesNotExist()), length_expr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not rely on this as a fix, since this behavior will likely change soon.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you suggest? Using ::Any instead? Or just :_? I also noticed that if esc is removed both in the pullback function signature and this anonymous function, then the macro hygiene will replace both occurrences of _ with the same variable name (something like `var"#249#_").

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, in that case I'd go with ::Any or just something like _tmp.

@devmotion devmotion changed the title Interpolate everything in rule helpers Interpolate ~~everything~~ more in rule helpers and fix escaping of @non_differentiable Mar 31, 2021
@devmotion devmotion changed the title Interpolate ~~everything~~ more in rule helpers and fix escaping of @non_differentiable Interpolate more in rule helpers and fix escaping of @non_differentiable Mar 31, 2021
Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just this small nit plus removing all remaining uses of esc(:_) and this looks good to go. Sorry for being so picky here, but I am now very happy with this. Macro hygiene is always a bit difficult to get right, since mistakes can show up in subtle ways or only in edgecases. Thanks for sticking with me!

# Manually defined kw version to save compiler work. See explanation in rules.jl
function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
function (::Core.kwftype(typeof($rrule)))($(esc(kwargs))::Any, ::typeof($rrule), $(esc_primal_sig_parts...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interpolation of rrule can be removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks!

@devmotion
Copy link
Member Author

Would it be useful to add @nospecialize to the discarded argument of type Any in the frule and in the pullback of rrule generated by @nondifferentiable (I guess it doesn't matter for the ntuple expression and can't be applied to the tuple destructuring in the frule of @scalar_rule)?

@simeonschaub
Copy link
Member

Yes, I think that would make sense, but probably better as a separate PR.

@devmotion devmotion requested a review from mzgubic March 31, 2021 17:58
@mzgubic
Copy link
Member

mzgubic commented Mar 31, 2021

all good from my side, @simeonschaub and @oxinabox are far more knowledgeable than me anyway, I wanted to review to learn really

@devmotion devmotion merged commit fb12855 into JuliaDiff:master Mar 31, 2021
@devmotion devmotion deleted the dw/macro_hygiene branch April 1, 2021 07:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make macros very careful about what they expect to be in scope
5 participants