Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Add line information to more generated code (#144)
Browse files Browse the repository at this point in the history
In some places we are willfully omitting line information by passing
`nothing` to certain `Expr`s, and in others we're generating expressions
that are syntactically valid but are not what Julia itself produces.
Part of that is the omission of line information.

This changes the function definition expressions from things like
```julia
f(x::Node{Int}) = Branch(f, (x,), getfield(x, :tape))
```
to
```julia
function f(x::Node{Int})
    #= line info here =#
    Branch(f, (x,), getfield(x, :tape))
end
```
"What's wrong with the former?" you might ask. Nothing in principle,
though if you were to write that in your own code, Julia parses it as
```julia
f(x::Node{Int}) = begin
    #= line info here =#
    Branch(f, (x,), getfield(x, :tape))
end
```
For ease of visual parsing, we'll change it to use `:function`
expression heads, i.e. long-form function definitions, as it's
functionally equivalent.

The benefit of having line information is for backtraces. If something
goes wrong, the user should now be able to get a better idea of where a
problem is, rather than Julia providing a backtrace that just says
something along the lines of "in Nabla."
  • Loading branch information
ararslan authored Mar 25, 2019
1 parent 782fc96 commit 7ee3613
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 77 deletions.
24 changes: 12 additions & 12 deletions src/code_transformation/differentiable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ get_quote_body(code::QuoteNode) = code.value
Unionise the code inside a call to `eval`, such that when the `eval` call actually occurs
the code inside will be unionised.
"""
function unionise_eval(code::Expr)
body = Expr(:macrocall, Symbol("@unionise"), nothing, deepcopy(get_quote_body(code.args[end])))
function unionise_eval(code::Expr, linfo::LineNumberNode=LineNumberNode(0))
body = Expr(:macrocall, Symbol("@unionise"), linfo, deepcopy(get_quote_body(code.args[end])))
return length(code.args) == 3 ?
Expr(:call, :eval, deepcopy(code.args[2]), quot(body)) :
Expr(:call, :eval, quot(body))
Expand All @@ -62,11 +62,11 @@ end
Unionise the code in a call to @eval, such that when the `eval` call actually occurs, the
code inside will be unionised.
"""
function unionise_macro_eval(code::Expr)
body = Expr(:macrocall, Symbol("@unionise"), nothing, deepcopy(code.args[end]))
function unionise_macro_eval(code::Expr, linfo::LineNumberNode=LineNumberNode(0))
body = Expr(:macrocall, Symbol("@unionise"), linfo, deepcopy(code.args[end]))
return length(code.args) == 4 ?
Expr(:macrocall, Symbol("@eval"), nothing, deepcopy(code.args[3]), body) :
Expr(:macrocall, Symbol("@eval"), nothing, body)
Expr(:macrocall, Symbol("@eval"), linfo, deepcopy(code.args[3]), body) :
Expr(:macrocall, Symbol("@eval"), linfo, body)
end

"""
Expand Down Expand Up @@ -125,24 +125,24 @@ arguments. This should not affect the existing functionality of the code.
function unionise end

# If we get a symbol then we cannot have found a function definition, so ignore it.
unionise(code) = code
unionise(code, linfo::LineNumberNode=LineNumberNode(0)) = code

# Recurse through an expression, bottoming out if we find a function definition or a
# quoted expression to be `eval`-ed.
function unionise(code::Expr)
function unionise(code::Expr, linfo::LineNumberNode=LineNumberNode(0))
if code.head in (:function, Symbol("->"))
return Expr(code.head, unionise_sig(code.args[1]), code.args[2])
elseif code.head == Symbol("=") && !isa(code.args[1], Symbol) &&
(get_body(code.args[1]).head == :tuple || get_body(code.args[1]).head isa Symbol)
return Expr(code.head, unionise_sig(code.args[1]), code.args[2])
elseif code.head == :call && code.args[1] == :eval
return unionise_eval(code)
return unionise_eval(code, linfo)
elseif code.head == :macrocall && code.args[1] == Symbol("@eval")
return unionise_macro_eval(code)
return unionise_macro_eval(code, linfo)
elseif code.head == :struct
return unionise_struct(code)
else
return Expr(code.head, [unionise(arg) for arg in code.args]...)
return Expr(code.head, [unionise(arg, linfo) for arg in code.args]...)
end
end

Expand All @@ -153,5 +153,5 @@ Transform code such that each function definition accepts `Node` objects as argu
without effecting dispatch in other ways.
"""
macro unionise(code)
return esc(unionise(code))
return esc(unionise(code, __source__))
end
19 changes: 10 additions & 9 deletions src/sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ macro explicit_intercepts(
end
insert!(oldcall.args, 2, params)
# The actual function definition
def = Expr(:function, newcall, oldcall)
def = Expr(:function, newcall, Expr(:block, __source__, oldcall))
end
# NOTE: If kws is nonempty, explicit_intercepts will add methods to both f and _f
# See boxed_method
ex = explicit_intercepts(f, type_tuple, isnode; kws...)
ex = explicit_intercepts(f, type_tuple, isnode, __source__; kws...)
# The result contains all method definitions generated for f (and _f if applicable)
return esc(Expr(:block, def, ex))
end
Expand All @@ -92,10 +92,10 @@ Return a `:block` expression which evaluates to declare all of the combinations
that could be required to catch if a `Node` is ever passed to the function specified in
`expr`.
"""
function explicit_intercepts(f::SymOrExpr, types::Expr, is_node::Vector{Bool}; kwargs...)
function explicit_intercepts(f::SymOrExpr, types::Expr, is_node::Vector{Bool}, linfo::LineNumberNode; kwargs...)
function explicit_intercepts_(states::Vector{Bool})
if length(states) == length(is_node)
return any(states) ? boxed_method(f, types, states; kwargs...) : []
return any(states) ? boxed_method(f, types, states, linfo; kwargs...) : []
else
return vcat(
explicit_intercepts_(vcat(states, false)),
Expand Down Expand Up @@ -145,7 +145,8 @@ function boxed_method(
f::SymOrExpr,
type_tuple::Expr,
is_node::Vector{Bool},
arg_names::Vector{Symbol};
arg_names::Vector{Symbol}=[gensym() for _ in is_node],
linfo::LineNumberNode=LineNumberNode(0);
kwargs...
)
# Get the argument types and create the function call.
Expand All @@ -161,7 +162,7 @@ function boxed_method(
body = Expr(:call, :Branch, f, tuple_expr, tape_expr)

# Combine call signature with the body to create a new function.
return Expr(:(=), call, body)
return Expr(:function, call, Expr(:block, linfo, body))
else
_type_tuple = copy(type_tuple)
_is_node = copy(is_node)
Expand All @@ -171,15 +172,15 @@ function boxed_method(
push!(_is_node, false)
push!(_arg_names, k)
end
kw_def = Expr(:function, call, Expr(:call, kwfname(f), _arg_names...))
kw_def = Expr(:function, call, Expr(:block, linfo, Expr(:call, kwfname(f), _arg_names...)))

# Recurse on the internal function to get a Branch call
branch_def = boxed_method(kwfname(f), _type_tuple, _is_node, _arg_names)
branch_def = boxed_method(kwfname(f), _type_tuple, _is_node, _arg_names, linfo)

return Expr(:block, kw_def, branch_def)
end
end
boxed_method(f, t, n; kwargs...) = boxed_method(f, t, n, [gensym() for _ in n]; kwargs...)
boxed_method(f, t, n, l; kwargs...) = boxed_method(f, t, n, [gensym() for _ in n], l; kwargs...)

"""
get_sig(f::SymOrExpr, arg_names::Vector{Symbol}, types::Vector; kwargs...)
Expand Down
108 changes: 52 additions & 56 deletions test/sensitivity.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
@testset "sensitivity" begin
using Base.Meta
using Nabla: boxed_method

import Base.Meta.quot
function expected_func(sig::Expr, body::Expr)
return Expr(:function, sig, Expr(:block, LineNumberNode(0), body))
end

@testset "sensitivity" begin
# # "Test" `Nabla.get_body`. (Not currently unit testing this as it is awkward. Will
# # change this at some point in the future to be more unit-testable.)
# let
Expand All @@ -23,60 +27,52 @@
# println(full_expr)
# end

# Test `Nabla.boxed_method`.
import Nabla.Nabla.boxed_method
let
from_func = boxed_method(:foo, :(Tuple{Any}), [true], [:x1])
expected = Expr(Symbol("="),
:(foo(x1::Node{<:Any})),
:(Branch(foo, (x1,), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{T{V}}), [true], [:x1])
expected = Expr(Symbol("="),
:(foo(x1::Node{<:T{V}})),
:(Branch(foo, (x1,), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, false], [:x1, :x2])
expected = Expr(Symbol("="),
:(foo(x1::Node{<:Any}, x2::Any)),
:(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, true], [:x1, :x2])
expected = Expr(Symbol("="),
:(foo(x1::Node{<:Any}, x2::Node{<:Any})),
:(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2])
expected = Expr(Symbol("="),
:(foo(x1::Any, x2::Node{<:Any})),
:(Branch(foo, (x1, x2), getfield(x2, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{T} where T), [true], [:x1])
expected = Expr(Symbol("="),
:(foo(x1::Node{<:T}) where T),
:(Branch(foo, (x1,), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2]; a=1, b=2)
expected = Expr(:block,
Expr(:function,
:(foo(x1::Any, x2::Node{<:Any}; a=1, b=2)),
:(_foo(x1, x2, a, b))),
Expr(:(=),
:(_foo(x1::Any, x2::Node{<:Any}, a::Any, b::Any)),
:(Branch(_foo, (x1, x2, a, b), getfield(x2, $(quot(:tape)))))))
@test from_func == expected
@testset "boxed_method" begin
let
from_func = boxed_method(:foo, :(Tuple{Any}), [true], [:x1])
expected = expected_func(:(foo(x1::Node{<:Any})),
:(Branch(foo, (x1,), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{T{V}}), [true], [:x1])
expected = expected_func(:(foo(x1::Node{<:T{V}})),
:(Branch(foo, (x1,), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, false], [:x1, :x2])
expected = expected_func(:(foo(x1::Node{<:Any}, x2::Any)),
:(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, true], [:x1, :x2])
expected = expected_func(:(foo(x1::Node{<:Any}, x2::Node{<:Any})),
:(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2])
expected = expected_func(:(foo(x1::Any, x2::Node{<:Any})),
:(Branch(foo, (x1, x2), getfield(x2, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{T} where T), [true], [:x1])
expected = expected_func(:(foo(x1::Node{<:T}) where T),
:(Branch(foo, (x1,), getfield(x1, $(quot(:tape))))))
@test from_func == expected
end
let
from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2]; a=1, b=2)
expected = Expr(:block,
expected_func(:(foo(x1::Any, x2::Node{<:Any}; a=1, b=2)),
:(_foo(x1, x2, a, b))),
expected_func(:(_foo(x1::Any, x2::Node{<:Any}, a::Any, b::Any)),
:(Branch(_foo, (x1, x2, a, b), getfield(x2, $(quot(:tape)))))))
@test from_func == expected
end
end

# Test `Nabla.branch_expr`.
Expand Down

0 comments on commit 7ee3613

Please sign in to comment.