Skip to content

Commit

Permalink
Merge pull request #87 from dfdx/loop-op-2
Browse files Browse the repository at this point in the history
Loop operation
  • Loading branch information
dfdx authored Jun 11, 2021
2 parents 1fae1a0 + 2ca76a4 commit 3c4bc41
Show file tree
Hide file tree
Showing 7 changed files with 568 additions and 41 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ _*
Manifest.toml
benchmark/*.json
benchmark/Manifest.toml
.vscode/
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
CUDA = "3"
Expand Down
81 changes: 71 additions & 10 deletions src/compile.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,72 @@
make_name(id::Int) = Symbol("x$id")
make_name(op::AbstractOp) = Symbol("x$(op.id)")
const NEXT_UNIQUE_ID = Ref{Int}(0)
next_unique_id() = (NEXT_UNIQUE_ID[] += 1; NEXT_UNIQUE_ID[])

arg2expr(v::Variable) = make_name(v.id)
arg2expr(s::Symbol) = QuoteNode(s)
arg2expr(c) = c
make_name(id::Int, prefix="") = Symbol("$(prefix)x$id")
make_name(op::AbstractOp, prefix="") = Symbol("$(prefix)x$(op.id)")
make_name(name::String, prefix="") = Symbol("$(prefix)$(name)")

function to_expr(op::Call)
call = Expr(:call, map(arg2expr, (op.fn, op.args...))...)
return Expr(:(=), make_name(op.id), call)
arg2expr(v::Variable, prefix="") = make_name(v.id, prefix)
arg2expr(s::Symbol, prefix="") = QuoteNode(s)
arg2expr(c, prefix="") = c

function to_expr(op::Call, prefix="")
call = Expr(:call, [arg2expr(v, prefix) for v in (op.fn, op.args...)]...)
return Expr(:(=), make_name(op.id, prefix), call)
end

to_expr(op::Constant, prefix="") = :($(make_name(op.id, prefix)) = $(op.val))


function loop_exit_tuple_expr_at_point(op::Loop, id::Int, prefix::String, loop_prefix::String)
exit_name = make_name(op.id, prefix)
arg_vars = loop_exit_vars_at_point(op, id)
arg_names = [make_name(v.id, loop_prefix) for v in arg_vars]
return Expr(:(=), exit_name, Expr(:call, tuple, arg_names...))
end

to_expr(op::Constant) = :($(make_name(op.id)) = $(op.val))

function to_expr(op::Loop, prefix="")
loop_prefix = "l$(next_unique_id())"
exprs = []
# map parent input ids to continue ids
init_var_names = []
for (inp, parent) in zip(inputs(op.subtape), op.parent_inputs)
init_var_name = make_name(inp.id, loop_prefix)
push!(init_var_names, init_var_name)
ex = Expr(:(=), init_var_name, make_name(parent.id, prefix))
push!(exprs, ex)
end
# add exit tuple which will be used in case of zero trip count
init_exit_tuple_ex = loop_exit_tuple_expr_at_point(op, 0, prefix, loop_prefix)
push!(exprs, init_exit_tuple_ex)
loop_ex = :(while true end)
body = loop_ex.args[2]
for (id, subop) in enumerate(op.subtape)
if !isa(subop, Input)
subex = to_expr(subop, loop_prefix)
if subex isa Vector
push!(body.args, subex...)
else
push!(body.args, subex)
end
if id == op.condition.id
exit_expr = :(if !$(make_name(op.condition.id, loop_prefix)) end)
exit_body = exit_expr.args[2]
# update exit tuple
exit_tuple_ex = loop_exit_tuple_expr_at_point(op, id, prefix, loop_prefix)
push!(exit_body.args, exit_tuple_ex)
push!(exit_body.args, Expr(:break))
push!(body.args, exit_expr)
end
end
end
# map continue vars to inputs
for (inp, cont) in zip(inputs(op.subtape), op.cont_vars)
ex = Expr(:(=), make_name(inp.id, loop_prefix), make_name(cont.id, loop_prefix))
push!(body.args, ex)
end
push!(exprs, loop_ex)
end


function to_expr(tape::Tape)
Expand All @@ -23,7 +79,12 @@ function to_expr(tape::Tape)
body = Expr(:block)
for op in tape
op isa Input && continue
push!(body.args, to_expr(op))
ex = to_expr(op)
if ex isa Vector
push!(body.args, ex...)
else
push!(body.args, ex)
end
end
push!(body.args, Expr(:return, make_name(tape.result.id)))
fn_ex = Expr(:function, header, body)
Expand Down
111 changes: 100 additions & 11 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ mutable struct Tape{C}
ops::Vector{<:AbstractOp}
# result variable
result::Variable
# for subtapes - parent tape
parent::Union{Tape,Nothing}
# tape metadata (depends on the context)
meta::Dict
# application-specific context
c::C
# # derivs[var] == grad_var
# derivs::LittleDict{Variable, Variable}
# # pb_derivs[var] == pullback_var
# pullbacks::LittleDict{Variable, Variable}
c::C
end

Tape(c::C) where C = Tape(AbstractOp[], Variable(0), c)
Tape(c::C) where C = Tape(AbstractOp[], Variable(0), nothing, Dict(), c)
# by default context is just a Dict{Any, Any}
Tape() = Tape(Dict{Any,Any}())

Expand All @@ -205,9 +205,19 @@ end

inputs(tape::Tape) = [V(op) for op in tape.ops if op isa Input]
function inputs!(tape::Tape, vals...)
@assert length(tape) == 0 "Can only set inputs to an empty tape"
for val in vals
push!(tape, Input(val))
@assert(isempty(tape) || length(inputs(tape)) == length(vals),
"This tape contains $(length(inputs(tape))) inputs, but " *
"$(length(vals)) value(s) were provided")
if isempty(tape)
# initialize inputs
for val in vals
push!(tape, Input(val))
end
else
# rewrite input values
for (i, val) in enumerate(vals)
tape[V(i)].val = val
end
end
return [V(op) for op in tape.ops[1:length(vals)]]
end
Expand Down Expand Up @@ -292,6 +302,27 @@ function Base.replace!(tape::Tape, idx_ops::Pair{<:Integer,<:Union{Tuple,Vector}
end


########################################################################
# SPECIAL OPERATIONS #
########################################################################

## Loop

mutable struct Loop <: AbstractOp
id::Int
parent_inputs::Vector{Variable}
condition::Variable
cont_vars::Vector{Variable}
exit_vars::Vector{Variable}
subtape::Tape
val::Any
end

function Base.show(io::IO, loop::Loop)
input_str = join(map(string, loop.parent_inputs), ", ")
print(io, "%$(loop.id) = Loop($input_str)")
end

###############################################################################
# REBIND #
###############################################################################
Expand Down Expand Up @@ -334,9 +365,9 @@ function rebind!(tape::Tape, op::Call, st::Dict)
return op
end


"""
rebind_context!(tape::Tape, st::Dict)
Rebind variables in the tape's context according to substitution table.
By default does nothing, but can be overwitten for specific Tape{C}
"""
Expand Down Expand Up @@ -367,6 +398,64 @@ function exec!(tape::Tape, op::Call)
end


"""
Collect variables which will be used at loop exit if it happens
at this point on tape.
"""
function loop_exit_vars_at_point(op::Loop, id::Int)
input_vars = inputs(op.subtape)
exit_idxs = findall(v -> v in op.exit_vars, op.cont_vars)
vars = Vector{Variable}(undef, length(exit_idxs))
for (i, idx) in enumerate(exit_idxs)
if id > op.cont_vars[idx].id
# if condition is checked after this continue var is changed,
# use continue var
vars[i] = op.cont_vars[idx]
else
# otherwise use input var
vars[i] = input_vars[idx]
end
end
return vars
end


function exec!(tape::Tape, op::Loop)
subtape = op.subtape
# initialize inputs
inputs!(subtape, [tape[v].val for v in op.parent_inputs]...)
# run the loop strictly while continue condition is true
# note that subtape execution may finish before the full
# iteration is done
cond_var = op.condition
vi0 = length(op.parent_inputs) + 1
vi = vi0
while true
# @show vi
# @show subtape[V(1)].val
# @show subtape[V(2)].val
# @show subtape[V(7)].val
# sleep(1)
exec!(subtape, subtape[V(vi)])
if vi == cond_var.id && subtape[V(vi)].val == false
actual_exit_vars = loop_exit_vars_at_point(op, vi)
op.val = ([v._op.val for v in actual_exit_vars]...,)
break
end
vi += 1
if vi > length(subtape)
vi = vi0
inputs!(subtape, [subtape[v].val for v in op.cont_vars]...)
end
end
# # exit_var is special - it's a tuple combining all the exit variables
# # since it doesn't exist in the original code, it may be not executed
# # by loop logic at the last iteration; hence, we execute it manually
# exec!(subtape, subtape[op.exit_var])
# op.val = subtape[op.exit_var].val
end


function play!(tape::Tape, args...; debug=false)
for (i, val) in enumerate(args)
@assert(tape[V(i)] isa Input, "More arguments than the original function had")
Expand All @@ -389,4 +478,4 @@ end
function call_signature(tape::Tape, op::Call)
farg_vals = map_vars(v -> tape[v].val, [op.fn, op.args...])
return Tuple{map(typeof, farg_vals)...}
end
end
Loading

0 comments on commit 3c4bc41

Please sign in to comment.