Reducing redundancy for primitive functions #250

Open
cscherrer opened this issue Mar 21, 2021 · 0 comments
@cscherrer
Owner

"Primitive" here is a term I've been using for functions that use GeneralizedGenerated.jl to generate a function based on a Model and usually some other values. For each of these, there's a source____ function that builds the AST, for example sourceLogdensity and sourceRand.

For example, logdensity is built from

function sourceLogdensity()
    function(_m::Model)
        proc(_m, st :: Assign)     = :($(st.x) = $(st.rhs))
        proc(_m, st :: Return)     = nothing
        proc(_m, st :: LineNumber) = nothing
        function proc(_m, st :: Sample)
            x = st.x
            rhs = st.rhs
            @q begin
                _ℓ += logdensity($rhs, $x)
                $x = Soss.predict($rhs, $x)
            end
        end

        wrap(kernel) = @q begin
            _ℓ = 0.0
            $kernel
            return _ℓ
        end

        buildSource(_m, proc, wrap) |> MacroTools.flatten
    end
end

and rand is built from

function sourceRand() 
    function(_m::Model)
        proc(_m, st::Assign)  = :($(st.x) = $(st.rhs))
        proc(_m, st::Sample)  = :($(st.x) = rand(_rng, $(st.rhs)))
        proc(_m, st::Return)  = :(return $(st.rhs))
        proc(_m, st::LineNumber) = nothing

        vals = map(x -> Expr(:(=), x, x), parameters(_m))

        wrap(kernel) = @q begin
            _rng -> begin
                $kernel
                $(Expr(:tuple, vals...))
            end
        end

        buildSource(_m, proc, wrap) |> MacroTools.flatten
    end
end

There's clearly a lot of commonality between these, and also between the many calls to @gg:

chad@albatross ~/g/Soss.jl (dev)> rg @gg
src/importance.jl
167:@gg M function _importanceSample(_::Type{M}, p::Model, _pargs, q::Model, _qargs, _data) where M <: TypeLevel{Module}

src/simulate.jl
124:@gg M function _simulate(_::Type{M}, _m::Model, _args, trace_assignments::Val{V}) where {V, M <: TypeLevel{Module}}
131:@gg M function _simulate(_::Type{M}, _m::Model, _args::NamedTuple{()}, trace_assignments::Val{V}) where {V, M <: TypeLevel{Module}}

src/particles.jl
150:@gg M function _particles(_::Type{M}, _m::Model, _args, _n::Val{_N}) where {M <: TypeLevel{Module},_N}
156:@gg M function _particles(_::Type{M}, _m::Model, _args::NamedTuple{()}, _n::Val{_N}) where {M <: TypeLevel{Module},_N}

src/primitives/likelihood-weighting.jl
38:@gg M function _weightedSample(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module}

src/primitives/rand.jl
59:@gg M function _rand(_::Type{M}, _m::Model, _args) where M <: TypeLevel{Module}
65:@gg M function _rand(_::Type{M}, _m::Model, _args::NamedTuple{()}) where M <: TypeLevel{Module}

src/primitives/logdensity.jl
43:@gg M function _logdensity(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}

src/primitives/xform.jl
148:@gg M function _xform(_::Type{M}, _m::Model{Asub,B}, _args::A, _data) where {M <: TypeLevel{Module}, Asub, A,B}

src/primitives/entropy.jl
55:@gg M function _entropy(_::Type{M}, _m::Model, _args, _n::Val{_N}) where {M <: TypeLevel{Module},_N}
61:@gg M function _entropy(_::Type{M}, _m::Model, _args::NamedTuple{()}, _n::Val{_N}) where {M <: TypeLevel{Module},_N}

src/symbolic/symbolic.jl
143:@gg M function _symlogdensity(_::Type{M}, _m::Model, ::Type{T}) where {T, M <: TypeLevel{Module}}

src/primitives/basemeasure.jl
40:@gg M function _basemeasure(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}
chad@albatross ~/g/Soss.jl (dev)> 

This makes me wonder: can we put all of this under a common higher-order function? Maybe something like

@gg M function makeprimitive(::Type{M}, _m::Model, f, post, args...)

where f takes the place of proc (a name that isn't very descriptive anyway) and args... holds whatever other arguments are passed. post is a function Expr -> Expr; in many cases it might just add some surrounding context, as wrap does above.
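As a rough illustration of the f/post split, here is a minimal, self-contained Julia sketch. It deliberately does not use Soss, MacroTools, or GeneralizedGenerated: Statement, Assign, Sample, build_primitive_source, toy_logdensity, and the *_f/*_post helpers are all hypothetical stand-ins, and the generated expressions are turned into functions with plain top-level eval rather than @gg. The only point is that rand and logdensity can share one traversal and differ just in f and post. Extra per-primitive data (here the variable list, playing the role of parameters(_m) in sourceRand) is captured in a closure passed as post.

```julia
using Random

# Hypothetical stand-ins for model statements (not Soss's real types).
abstract type Statement end

struct Assign <: Statement
    x::Symbol
    rhs::Any
end

struct Sample <: Statement
    x::Symbol
    rhs::Any
end

# Hypothetical shared traversal, in the spirit of the proposed `makeprimitive`:
# `f` maps one statement to an Expr (or `nothing` to drop it), and
# `post` wraps the assembled kernel in primitive-specific context.
function build_primitive_source(stmts::Vector{Statement}, f, post)
    kernel = Expr(:block)
    for st in stmts
        ex = f(st)
        ex === nothing || push!(kernel.args, ex)
    end
    return post(kernel)
end

# Toy density: uniform over an integer range (illustration only).
toy_logdensity(d::AbstractUnitRange, x) = x in d ? -log(length(d)) : -Inf

# rand: draw each sampled variable, then return all variables as a NamedTuple.
rand_f(st::Assign) = :($(st.x) = $(st.rhs))
rand_f(st::Sample) = :($(st.x) = rand(_rng, $(st.rhs)))

function rand_post(kernel, vars)
    nt = Expr(:tuple, (Expr(:(=), v, v) for v in vars)...)
    return :(_rng -> begin
        $kernel
        $nt
    end)
end

# logdensity: unpack the given values, then accumulate log-density terms.
logdensity_f(st::Assign) = :($(st.x) = $(st.rhs))
logdensity_f(st::Sample) = :(_ℓ += toy_logdensity($(st.rhs), $(st.x)))

function logdensity_post(kernel, vars)
    unpack = [:($v = _pars.$v) for v in vars]
    return :(_pars -> begin
        $(unpack...)
        _ℓ = 0.0
        $kernel
        return _ℓ
    end)
end

# Usage: same statements and same traversal, two different primitives.
stmts = Statement[Sample(:μ, :(1:6)), Sample(:y, :(μ:(μ + 5)))]
vars = [:μ, :y]
sampler = eval(build_primitive_source(stmts, rand_f, k -> rand_post(k, vars)))
logp = eval(build_primitive_source(stmts, logdensity_f, k -> logdensity_post(k, vars)))

draw = sampler(MersenneTwister(1))
println(draw)
println(logp((μ = 3, y = 4)))
```

Even in this toy, the shape of post differs between primitives (one closes over an RNG, the other over a NamedTuple of values), which is the "args changes a lot" problem in miniature.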

Some challenges:

  • The way args is used varies a lot across functions.
  • In the past I've found it very tricky to manage exactly what is known when. In some cases we need to know values at AST generation time; in other cases, just types.
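The values-versus-types distinction already shows up with plain Julia @generated functions, used here in place of @gg purely for illustration (this sketch does not claim @gg behaves identically in every respect). Inside the generator body, arguments are bound to their types, so anything the generated AST depends on has to be lifted into the type domain, for example with Val, much like the _n::Val{_N} arguments in the signatures above. type_only and unrolled_sum are hypothetical names.

```julia
# Inside the generator body, `x` is bound to the argument's *type*
# (e.g. Float64), not its value; the returned literal becomes the method body.
@generated function type_only(x)
    return string(x)
end

# `N` is a type parameter, so it *is* known while building the AST.
# The contents of `v` are not; they only appear in the returned expression.
@generated function unrolled_sum(v, ::Val{N}) where {N}
    terms = [:(v[$i]) for i in 1:N]   # built at generation time
    return :(+($(terms...)))          # e.g. :(+(v[1], v[2], v[3]))
end

println(type_only(2.5))                   # the type name, not the value
println(unrolled_sum([1, 2, 3], Val(3)))
```

A generic makeprimitive would need some convention for which of its args... are lifted into types like this and which stay as runtime values.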

If building new primitives becomes easier, more people will be encouraged to use this functionality. I think there's great potential here if we can do it. Things do get tricky at this level of abstraction, so we need to be sure we can represent everything we already have without losing performance.
