Skip to content
This repository has been archived by the owner on Jun 24, 2022. It is now read-only.

Commit

Permalink
Merge pull request #28 from SciML/s/fix-rebase
Browse files Browse the repository at this point in the history
Cherry-pick into Refactor plus Fix a promote bug
  • Loading branch information
shashi authored Apr 15, 2020
2 parents 69c416a + e687ae4 commit a3843f7
Show file tree
Hide file tree
Showing 23 changed files with 948 additions and 989 deletions.
15 changes: 8 additions & 7 deletions src/SparsityDetection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ using Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
using Cassette: tagged_new_tuple, ContextTagged, BindingMeta, DisableHooks, nametype
using Core: SSAValue

export Sparsity, hsparsity, sparsity!
export Sparsity, jacobian_sparsity, hessian_sparsity, hsparsity, sparsity!

include("program_sparsity.jl")
include("sparsity_tracker.jl")
include("path.jl")
include("take_all_branches.jl")
include("terms.jl")
include("util.jl")
include("controlflow.jl")
include("propagate_tags.jl")
include("linearity.jl")
include("jacobian.jl")
include("hessian.jl")
include("blas.jl")
include("linearity_special.jl")

sparsity!(args...; kwargs...) = jacobian_sparsity(args...; kwargs...)
hsparsity(args...; kwargs...) = hessian_sparsity(args...; kwargs...)

end
39 changes: 6 additions & 33 deletions src/blas.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,12 @@
# generic implementations

_name(x::Symbol) = x
_name(x::Expr) = (@assert x.head == :(::); x.args[1])
macro reroute(f, g)
fname = f.args[1]
fargs = f.args[2:end]
gname = g.args[1]
gargs = g.args[2:end]
quote
@inline function Cassette.overdub(ctx::SparsityContext,
f::typeof($(esc(fname))),
$(fargs...))
Cassette.recurse(
ctx,
invoke,
$(esc(gname)),
$(esc(:(Tuple{$(gargs...)}))),
$(map(_name, fargs)...))
end
# Forward BLAS calls to generic implementation
#
using LinearAlgebra
import LinearAlgebra.BLAS

@inline function Cassette.overdub(ctx::HessianSparsityContext,
f::typeof($(esc(fname))),
$(fargs...))
Cassette.recurse(
ctx,
invoke,
$(esc(gname)),
$(esc(:(Tuple{$(gargs...)}))),
$(map(_name, fargs)...))
end
end
end
# generic implementations

@reroute LinearAlgebra.BLAS.dot(x,y) LinearAlgebra.dot(Any, Any)
@reroute LinearAlgebra.BLAS.axpy!(x, y) LinearAlgebra.axpy!(Any,
@reroute LinearAlgebra.BLAS.axpy!(a, x, y) LinearAlgebra.axpy!(Any,
AbstractArray,
AbstractArray)

Expand Down
229 changes: 229 additions & 0 deletions src/controlflow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#### Path

# First just do it for the case where there we assume
# tainted gotoifnots do not go in a loop!
# TODO: write a thing to detect this! (overdub predicates only in tainted ifs)
# implement snapshotting function state as an optimization for branch exploration
mutable struct Path
path::BitVector
cursor::Int
end

Path() = Path([], 1)

function increment!(bitvec)
for i=1:length(bitvec)
if bitvec[i] === true
bitvec[i] = false
else
bitvec[i] = true
break
end
end
end

function reset!(p::Path)
p.cursor=1
increment!(p.path)
nothing
end

function alldone(p::Path) # must be called at the end of the function!
all(identity, p.path)
end

function current_predicate!(p::Path)
if p.cursor > length(p.path)
push!(p.path, false)
else
p.path[p.cursor]
end
val = p.path[p.cursor]
p.cursor+=1
val
end

alldone(c) = alldone(c.metadata[2])
reset!(c) = reset!(c.metadata[2])
current_predicate!(c) = current_predicate!(c.metadata[2])

#=
julia> p=Path()
Path(Bool[], 1)
julia> alldone(p) # must be called at the end of a full run
true
julia> current_predicate!(p)
false
julia> alldone(p) # must be called at the end of a full run
false
julia> current_predicate!(p)
false
julia> p
Path(Bool[false, false], 3)
julia> alldone(p)
false
julia> reset!(p)
julia> p
Path(Bool[true, false], 1)
julia> current_predicate!(p)
true
julia> current_predicate!(p)
false
julia> alldone(p)
false
julia> reset!(p)
julia> p
Path(Bool[false, true], 1)
julia> current_predicate!(p)
false
julia> current_predicate!(p)
true
julia> reset!(p)
julia> current_predicate!(p)
true
julia> current_predicate!(p)
true
julia> alldone(p)
true
=#

"""
`abstract_run(g, ctx, overdubbed_fn, args...)`
First rewrites every if statement
```julia
if <expr>
...
end
as
```julia
cond = <expr>
if istainted(ctx, cond) ? @amb(true, false) : cond
...
end
```
Then runs `g(Cassette.overdub(ctx, overdubbed_fn, args...)`
as many times as there are available paths. i.e. `2^n` ways
where `n` is the number of tainted branch conditions.
# Examples:
```
meta = Any[]
abstract_run(ctx, f. args...) do result
push!(meta, metadata(result, ctx))
end
# do something to merge metadata from all the runs
```
"""
function abstract_run(acc, ctx::Cassette.Context, overdub_fn, args...; verbose=true)
path = Path()
pass_ctx = Cassette.similarcontext(ctx, metadata=(ctx.metadata, path), pass=AbsintPass)

while true
acc(Cassette.recurse(pass_ctx, ()->overdub_fn(args...)))

verbose && println("Explored path: ", path)
alldone(path) && break
reset!(path)
end
end

"""
`istainted(ctx, cond)`
Does `cond` have any metadata?
"""
function istainted(ctx, cond)
error("Method needed: istainted(::$(typeof(ctx)), ::Bool)." *
" See docs for `istainted`.")
end

# Must return 7 exprs
function rewrite_branch(ctx, stmt, extraslot, i)
# turn
# gotoifnot %p #g
# into
# %t = istainted(%p)
# gotoifnot %t #orig
# %rec = @amb true false
# gotoifnot %rec #orig+1 (the next statement after gotoifnot)

exprs = Any[]
cond = stmt.args[1] # already an SSAValue

# insert a check to see if SSAValue(i) isa Tainted
istainted_ssa = Core.SSAValue(i)
push!(exprs, :($(Expr(:nooverdub, istainted))($(Expr(:contextslot)),
$cond)))

# not tainted? jump to the penultimate statement
push!(exprs, Expr(:gotoifnot, istainted_ssa, i+5))

# tainted? then use current_predicate!(SSAValue(1))
current_pred = i+2
push!(exprs, :($(Expr(:nooverdub, current_predicate!))($(Expr(:contextslot)))))

# Store the interpreter-provided predicate in the slot
push!(exprs, Expr(:(=), extraslot, SSAValue(i+2)))

push!(exprs, Core.GotoNode(i+6))

push!(exprs, Expr(:(=), extraslot, cond))

# here we put in the original code
stmt1 = copy(stmt)
stmt.args[1] = extraslot
push!(exprs, stmt)

exprs
end

function rewrite_ir(ctx, ref)
# turn
# <val> ? t : f
# into
# istainted(<val>) ? current_predicate!(p) : <val> ? t : f

ir = ref.code_info
ir = copy(ir)

extraslot = gensym("tmp")
push!(ir.slotnames, extraslot)
push!(ir.slotflags, 0x00)
extraslot = Core.SlotNumber(length(ir.slotnames))

Cassette.insert_statements!(ir.code, ir.codelocs,
(stmt, i) -> Base.Meta.isexpr(stmt, :gotoifnot) ? 7 : nothing,
(stmt, i) -> rewrite_branch(ctx, stmt, extraslot, i))

ir.ssavaluetypes = length(ir.code)
# Core.Compiler.validate_code(ir)
#@show ref.method
#@show ir
return ir
end

const AbsintPass = Cassette.@pass rewrite_ir
Loading

0 comments on commit a3843f7

Please sign in to comment.