Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Sibling PR of introduction of Setfield.jl in AbstractPPL.jl #295

Closed
wants to merge 41 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
1388502
show full timings for evaluation rather than just min
torfjelde Jul 14, 2021
678ef1d
initial work on allowing more than just real and array variables
torfjelde Jul 27, 2021
ddf761c
Merge branch 'master' into tor/allow-non-array-variables
torfjelde Jul 28, 2021
b867d00
ensure that varname uses concretize
torfjelde Jul 29, 2021
6ad5d95
update PointwiseLikelihoodContext
torfjelde Jul 29, 2021
4bf663f
update unwrap_right_left_vns and fix
torfjelde Jul 29, 2021
e4922c9
formatting
torfjelde Jul 29, 2021
ab4b384
fixed doctest
torfjelde Jul 30, 2021
9dadb3a
Merge branch 'tor/allow-non-array-variables' of github.com:TuringLang…
torfjelde Jul 30, 2021
26216e3
forgot to remove escaping some places
torfjelde Jul 31, 2021
405d52c
removed usage of Setfield.set for .= and some other niceties
torfjelde Jul 31, 2021
505d690
fixed a doctests that will inevitably fail on Julia 1.3
torfjelde Jul 31, 2021
9f8c47b
updated a comment
torfjelde Jul 31, 2021
0a47953
added deprecations of the tildes
torfjelde Jul 31, 2021
0c51329
Update src/DynamicPPL.jl
torfjelde Jul 31, 2021
54f8c89
Merge branch 'master' into tor/benchmark-improvements
torfjelde Jul 31, 2021
d3fe07c
Merge branch 'tor/benchmark-improvements' into tor/allow-non-array-va…
torfjelde Jul 31, 2021
8134a1a
use impl of get for VarName instead of the hacky stuff we currently have
torfjelde Aug 1, 2021
9d3c1dd
uncomment commented out test suite
torfjelde Aug 1, 2021
5c4bf0e
Merge branch 'master' into tor/allow-non-array-variables
torfjelde Aug 1, 2021
4121b0e
use BangBang to prefer mutation when using set
torfjelde Aug 1, 2021
a38168a
remove redundant and outdated tests for VarInfo in integration tests
torfjelde Aug 1, 2021
51e7426
formatting
torfjelde Aug 1, 2021
0a3655d
no longer need the custom make_set method after Setfield v0.7.1
torfjelde Aug 1, 2021
f82579e
formatting
torfjelde Aug 1, 2021
9aa7298
drop concretize argument to varname
torfjelde Aug 1, 2021
b130db9
added a couple of additional benchmarks
torfjelde Aug 1, 2021
1ae04b3
Merge branch 'master' into tor/allow-non-array-variables
torfjelde Aug 17, 2021
475da88
fixed tests
torfjelde Aug 17, 2021
4af7e30
formatting
torfjelde Aug 17, 2021
e03ef4e
Update src/context_implementations.jl
torfjelde Aug 17, 2021
9b79b61
Merge branch 'master' into tor/allow-non-array-variables
torfjelde Sep 8, 2021
6dd6de9
bumped APPL compat bound
torfjelde Sep 8, 2021
4c7e882
some bugfixes
torfjelde Sep 9, 2021
4c325c3
updated tests
torfjelde Sep 9, 2021
fa228d8
Merge branch 'master' into tor/allow-non-array-variables
torfjelde Sep 9, 2021
edad225
Merge branch 'master' into tor/allow-non-array-variables
yebai Sep 9, 2021
553ae9b
drop testing begin indexing since incomp with Julia 1.3
torfjelde Sep 9, 2021
00d8411
fixed tests I think
torfjelde Sep 9, 2021
a99a2b1
we try again
torfjelde Sep 9, 2021
472629d
fixed test
torfjelde Sep 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ZygoteRules: ZygoteRules

using Setfield: Setfield

using Random: Random

import Base:
Expand Down
182 changes: 137 additions & 45 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
let $vn = $(varname(expr, true))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
Expand All @@ -38,7 +38,7 @@ isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x)))
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@views($x)))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@views is now needed since we can have property-access, etc. in x.


# If the result of a `view` is a zero-dim array then it's just a
# single element. Likely the rest is expecting type `eltype(x)`, hence
Expand Down Expand Up @@ -90,6 +90,28 @@ left-hand side of a `.~` expression such as `x .~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.

# Examples
```jldoctest
julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns
2-element Vector{VarName{:x, Setfield.IndexLens{Tuple{Colon, Int64}}}}:
x[:,1]
x[:,2]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns
1Γ—2 Matrix{VarName{:x, Setfield.IndexLens{Tuple{Int64, Int64}}}}:
x[1,1] x[1,2]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns
1Γ—2 Matrix{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Colon}}, Setfield.IndexLens{Tuple{Int64, Int64}}}}}:
x[:][1,1] x[:][1,2]

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns
3-element Vector{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Int64}}, Setfield.IndexLens{Tuple{Int64}}}}}:
x[1][1]
x[1][2]
x[1][3]
```
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
Expand All @@ -103,7 +125,7 @@ function unwrap_right_left_vns(
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Colon(), Tuple(i)))
return vn ∘ Setfield.IndexLens((Colon(), i))
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
end
return unwrap_right_left_vns(right, left, vns)
end
Expand All @@ -113,7 +135,7 @@ function unwrap_right_left_vns(
vn::VarName,
)
vns = map(CartesianIndices(left)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
return vn ∘ Setfield.IndexLens(Tuple(i))
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down Expand Up @@ -271,6 +293,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Do we don't want escaped expressions because we unfortunately
# escape the entire body afterwards.
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we do the wrong thing and escape the entire body of the method in @model, we cannot handle nested escaping. And since Setfield.lensmacro does the right thing, only escaping what it needs to, we run into issues.

This is a hack to essentially ensure that any escaping will be removed. Note that this doesn't break anything because before we couldn't even use escaped expressions within @model, and so I think it's fine for this PR alone. BUT we should really address this properly, i.e. rewrite @model to only escape what it needs to.


# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
Expand Down Expand Up @@ -303,95 +329,161 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
end

function generate_tilde_literal(left, right)
# If the LHS is a literal, it is always an observation
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end

"""
generate_tilde(left, right)

Generate an `observe` expression for data variables and `assume` expression for parameter
variables.
"""
function generate_tilde(left, right)
# If the LHS is a literal, it is always an observation
if isliteral(left)
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
isliteral(left) && return generate_tilde_literal(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn isassumption

return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$vn = $(remove_escape(varname(left, true)))
$isassumption = $(remove_escape(DynamicPPL.isassumption(left)))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
)
$(generate_tilde_assume(left, right, vn))
else
$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
)
end
end
end

function generate_tilde_assume(left::Symbol, right, vn)
return quote
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
__varinfo__,
)
end
end

function generate_tilde_assume(left::Expr, right, vn)
expr = :(
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
__varinfo__,
)
)

return remove_escape(setmacro(identity, expr, overwrite=true))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. setmacro correctly escapes, we don't, and so we hack.

end

"""
generate_dot_tilde(left, right)

Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right)
# If the LHS is a literal, it is always an observation
if isliteral(left)
return quote
$(DynamicPPL.dot_tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
isliteral(left) && return generate_tilde_literal(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn isassumption
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$vn = $(varname(left, true))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
$inds,
__varinfo__,
)
$(generate_dot_tilde_assume(left, right, vn))
else
$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
)
end
end
end

function generate_dot_tilde_assume(left::Symbol, right, vn)
return :(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
)
)
end

function generate_dot_tilde_assume(left::Expr, right, vn)
expr = :(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
)
)

return remove_escape(setmacro(identity, expr, overwrite=true))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

# HACK: This is unfortunate. It's a consequence of the fact that in
# DynamicPPL we the entire function body. Instead we should be
# more selective with our escape. Until that's the case, we remove them all.
remove_escape(x) = x
function remove_escape(expr::Expr)
Meta.isexpr(expr, :escape) && return remove_escape(expr.args[1])
return Expr(expr.head, map(x -> remove_escape(x), expr.args)...)
end

# TODO: Make PR to Setfield.jl to use `gensym` for the `lens` variable.
# This seems like it should be the case anyways since it allows multiple
# calls to `setmacro` without any cost to the current functionality.
function setmacro(lenstransform, ex::Expr; overwrite::Bool=false)
@assert ex.head isa Symbol
@assert length(ex.args) == 2
ref, val = ex.args
obj, lens = Setfield.parse_obj_lens(ref)
lens_var = gensym("lens")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is essentially copy-paste from Setfields' implementation, but we add this in here.

I believe this will also be sorted out if we fix the "escape EVERYTHING" in @model since then lens should automatically be assigned a unique symbol expanded.

If not we should just make a PR to Setfield.jl. Using gensym seems like it loses nothing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding my above comment about setmacro: if we have copied the function to here anyway, we might as well rename it to something more descriptive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with that, though my intention is to make a PR to Setfield.jl now that we seem to be going in the direction of using it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dst = overwrite ? obj : gensym("_")
val = esc(val)
ret = if ex.head == :(=)
quote
$lens_var = ($lenstransform)($lens)
$dst = $(Setfield.set)($obj, $lens_var, $val)
end
else
op = get_update_op(ex.head)
f = :($(Setfield._UpdateOp)($op,$val))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
quote
$lens_var = ($lenstransform)($lens)
$dst = $(Setfield.modify)($f, $obj, $lens_var)
end
end
ret
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}}
hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA)
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
Expand Down
Loading