diff --git a/.gitignore b/.gitignore index 3841d7a..1fe9bf4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ _* Manifest.toml benchmark/*.json benchmark/Manifest.toml +.vscode/ \ No newline at end of file diff --git a/Project.toml b/Project.toml index 4b58513..9f684b9 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [compat] CUDA = "3" diff --git a/src/compile.jl b/src/compile.jl index b17c307..e9952d2 100644 --- a/src/compile.jl +++ b/src/compile.jl @@ -1,16 +1,72 @@ -make_name(id::Int) = Symbol("x$id") -make_name(op::AbstractOp) = Symbol("x$(op.id)") +const NEXT_UNIQUE_ID = Ref{Int}(0) +next_unique_id() = (NEXT_UNIQUE_ID[] += 1; NEXT_UNIQUE_ID[]) -arg2expr(v::Variable) = make_name(v.id) -arg2expr(s::Symbol) = QuoteNode(s) -arg2expr(c) = c +make_name(id::Int, prefix="") = Symbol("$(prefix)x$id") +make_name(op::AbstractOp, prefix="") = Symbol("$(prefix)x$(op.id)") +make_name(name::String, prefix="") = Symbol("$(prefix)$(name)") -function to_expr(op::Call) - call = Expr(:call, map(arg2expr, (op.fn, op.args...))...) - return Expr(:(=), make_name(op.id), call) +arg2expr(v::Variable, prefix="") = make_name(v.id, prefix) +arg2expr(s::Symbol, prefix="") = QuoteNode(s) +arg2expr(c, prefix="") = c + +function to_expr(op::Call, prefix="") + call = Expr(:call, [arg2expr(v, prefix) for v in (op.fn, op.args...)]...) + return Expr(:(=), make_name(op.id, prefix), call) +end + +to_expr(op::Constant, prefix="") = :($(make_name(op.id, prefix)) = $(op.val)) + + +function loop_exit_tuple_expr_at_point(op::Loop, id::Int, prefix::String, loop_prefix::String) + exit_name = make_name(op.id, prefix) + arg_vars = loop_exit_vars_at_point(op, id) + arg_names = [make_name(v.id, loop_prefix) for v in arg_vars] + return Expr(:(=), exit_name, Expr(:call, tuple, arg_names...)) end -to_expr(op::Constant) = :($(make_name(op.id)) = $(op.val)) + +function to_expr(op::Loop, prefix="") + loop_prefix = "l$(next_unique_id())" + exprs = [] + # map parent input ids to continue ids + init_var_names = [] + for (inp, parent) in zip(inputs(op.subtape), op.parent_inputs) + init_var_name = make_name(inp.id, loop_prefix) + push!(init_var_names, init_var_name) + ex = Expr(:(=), init_var_name, make_name(parent.id, prefix)) + push!(exprs, ex) + end + # add exit tuple which will be used in case of zero trip count + init_exit_tuple_ex = loop_exit_tuple_expr_at_point(op, 0, prefix, loop_prefix) + push!(exprs, init_exit_tuple_ex) + loop_ex = :(while true end) + body = loop_ex.args[2] + for (id, subop) in enumerate(op.subtape) + if !isa(subop, Input) + subex = to_expr(subop, loop_prefix) + if subex isa Vector + push!(body.args, subex...) + else + push!(body.args, subex) + end + if id == op.condition.id + exit_expr = :(if !$(make_name(op.condition.id, loop_prefix)) end) + exit_body = exit_expr.args[2] + # update exit tuple + exit_tuple_ex = loop_exit_tuple_expr_at_point(op, id, prefix, loop_prefix) + push!(exit_body.args, exit_tuple_ex) + push!(exit_body.args, Expr(:break)) + push!(body.args, exit_expr) + end + end + end + # map continue vars to inputs + for (inp, cont) in zip(inputs(op.subtape), op.cont_vars) + ex = Expr(:(=), make_name(inp.id, loop_prefix), make_name(cont.id, loop_prefix)) + push!(body.args, ex) + end + push!(exprs, loop_ex) +end function to_expr(tape::Tape) @@ -23,7 +79,12 @@ function to_expr(tape::Tape) body = Expr(:block) for op in tape op isa Input && continue - push!(body.args, to_expr(op)) + ex = to_expr(op) + if ex isa Vector + push!(body.args, ex...) + else + push!(body.args, ex) + end end push!(body.args, Expr(:return, make_name(tape.result.id))) fn_ex = Expr(:function, header, body) diff --git a/src/tape.jl b/src/tape.jl index 1ba454c..7f2b98c 100644 --- a/src/tape.jl +++ b/src/tape.jl @@ -173,15 +173,15 @@ mutable struct Tape{C} ops::Vector{<:AbstractOp} # result variable result::Variable + # for subtapes - parent tape + parent::Union{Tape,Nothing} + # tape metadata (depends on the context) + meta::Dict # application-specific context - c::C - # # derivs[var] == grad_var - # derivs::LittleDict{Variable, Variable} - # # pb_derivs[var] == pullback_var - # pullbacks::LittleDict{Variable, Variable} +c::C end -Tape(c::C) where C = Tape(AbstractOp[], Variable(0), c) +Tape(c::C) where C = Tape(AbstractOp[], Variable(0), nothing, Dict(), c) # by default context is just a Dict{Any, Any} Tape() = Tape(Dict{Any,Any}()) @@ -205,9 +205,19 @@ end inputs(tape::Tape) = [V(op) for op in tape.ops if op isa Input] function inputs!(tape::Tape, vals...) - @assert length(tape) == 0 "Can only set inputs to an empty tape" - for val in vals - push!(tape, Input(val)) + @assert(isempty(tape) || length(inputs(tape)) == length(vals), + "This tape contains $(length(inputs(tape))) inputs, but " * + "$(length(vals)) value(s) were provided") + if isempty(tape) + # initialize inputs + for val in vals + push!(tape, Input(val)) + end + else + # rewrite input values + for (i, val) in enumerate(vals) + tape[V(i)].val = val + end end return [V(op) for op in tape.ops[1:length(vals)]] end @@ -292,6 +302,27 @@ function Base.replace!(tape::Tape, idx_ops::Pair{<:Integer,<:Union{Tuple,Vector} end +######################################################################## +# SPECIAL OPERATIONS # +######################################################################## + +## Loop + +mutable struct Loop <: AbstractOp + id::Int + parent_inputs::Vector{Variable} + condition::Variable + cont_vars::Vector{Variable} + exit_vars::Vector{Variable} + subtape::Tape + val::Any +end + +function Base.show(io::IO, loop::Loop) + input_str = join(map(string, loop.parent_inputs), ", ") + print(io, "%$(loop.id) = Loop($input_str)") +end + ############################################################################### # REBIND # ############################################################################### @@ -334,9 +365,9 @@ function rebind!(tape::Tape, op::Call, st::Dict) return op end + """ rebind_context!(tape::Tape, st::Dict) - Rebind variables in the tape's context according to substitution table. By default does nothing, but can be overwitten for specific Tape{C} """ @@ -367,6 +398,64 @@ function exec!(tape::Tape, op::Call) end +""" +Collect variables which will be used at loop exit if it happens +at this point on tape. +""" +function loop_exit_vars_at_point(op::Loop, id::Int) + input_vars = inputs(op.subtape) + exit_idxs = findall(v -> v in op.exit_vars, op.cont_vars) + vars = Vector{Variable}(undef, length(exit_idxs)) + for (i, idx) in enumerate(exit_idxs) + if id > op.cont_vars[idx].id + # if condition is checked after this continue var is changed, + # use continue var + vars[i] = op.cont_vars[idx] + else + # otherwise use input var + vars[i] = input_vars[idx] + end + end + return vars +end + + +function exec!(tape::Tape, op::Loop) + subtape = op.subtape + # initialize inputs + inputs!(subtape, [tape[v].val for v in op.parent_inputs]...) + # run the loop strictly while continue condition is true + # note that subtape execution may finish before the full + # iteration is done + cond_var = op.condition + vi0 = length(op.parent_inputs) + 1 + vi = vi0 + while true + # @show vi + # @show subtape[V(1)].val + # @show subtape[V(2)].val + # @show subtape[V(7)].val + # sleep(1) + exec!(subtape, subtape[V(vi)]) + if vi == cond_var.id && subtape[V(vi)].val == false + actual_exit_vars = loop_exit_vars_at_point(op, vi) + op.val = ([v._op.val for v in actual_exit_vars]...,) + break + end + vi += 1 + if vi > length(subtape) + vi = vi0 + inputs!(subtape, [subtape[v].val for v in op.cont_vars]...) + end + end + # # exit_var is special - it's a tuple combining all the exit variables + # # since it doesn't exist in the original code, it may be not executed + # # by loop logic at the last iteration; hence, we execute it manually + # exec!(subtape, subtape[op.exit_var]) + # op.val = subtape[op.exit_var].val +end + + function play!(tape::Tape, args...; debug=false) for (i, val) in enumerate(args) @assert(tape[V(i)] isa Input, "More arguments than the original function had") @@ -389,4 +478,4 @@ end function call_signature(tape::Tape, op::Call) farg_vals = map_vars(v -> tape[v].val, [op.fn, op.args...]) return Tuple{map(typeof, farg_vals)...} -end \ No newline at end of file +end diff --git a/src/trace.jl b/src/trace.jl index 934daf9..6f6ddef 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -1,5 +1,6 @@ import IRTools import IRTools: IR, @dynamo, self, insertafter! +import UUIDs: UUID, uuid1 function module_functions(modl) @@ -63,6 +64,34 @@ function Base.show(io::IO, fr::Frame) end +################################################################################ +# Tracer Options # +################################################################################ + +const TRACING_OPTIONS = Ref(Dict()) + +""" + should_trace_loops!(val=false) + +Turn on/off loop tracing. Without parameters, resets the flag to the default value +""" +should_trace_loops!(val::Bool=false) = (TRACING_OPTIONS[][:trace_loops] = val) +should_trace_loops() = get(TRACING_OPTIONS[], :trace_loops, false) + + +""" +Tracer options. Configured globally via the following methods: + +- should_trace_loops!() +""" +struct TracerOptions + trace_loops::Bool +end + + +TracerOptions() = TracerOptions(should_trace_loops()) + + ################################################################################ # IRTracer (defs and utils) # ################################################################################ @@ -71,11 +100,12 @@ mutable struct IRTracer is_primitive::Function tape::Tape frames::Vector{Frame} + options::TracerOptions end function IRTracer(; ctx=Dict(), is_primitive=is_chainrules_primitive) tape = Tape(ctx) - return IRTracer(is_primitive, tape, []) + return IRTracer(is_primitive, tape, [], TracerOptions()) end Base.show(io::IO, t::IRTracer) = print(io, "IRTracer($(length(t.tape)))") @@ -142,11 +172,240 @@ end """Set return variable for the current frame""" function set_return!(t::IRTracer, arg_sid_ref) + @assert(arg_sid_ref[] !== nothing, + "Cannot set return value to nothing. Does this function actually return a value?") tape_var = get_tape_vars(t, [arg_sid_ref[]])[1] t.frames[end].result = tape_var end +################################################################################ +# Loop utils # +################################################################################ + + +function is_loop(block::IRTools.Block) + for br in IRTools.branches(block) + # if a branch refers to an earlier block and is not return + # then it must be a loop + if br.block <= block.id && br.block != 0 + return true + end + end + return false +end + + +"""Get SSA IDs of a block's inputs""" +function block_input_ir_ids(block::IRTools.Block) + result = [] + for (stmt_id, (block_id, arg_id)) in enumerate(block.ir.defs) + if block_id == block.id && arg_id < 0 + push!(result, stmt_id) + end + end + return result +end + + +""" +Get SSA IDs of arguments which are not part of this block +(e.g. come from outside of a loop) +""" +function block_outsider_ir_ids(block::IRTools.Block) + result = [] + min_id = minimum(block_input_ir_ids(block)) + for (_, stmt) in block + ex = stmt.expr + @assert Meta.isexpr(ex, :call) + for arg in ex.args[2:end] + if arg isa IRTools.Variable && arg.id < min_id + push!(result, arg.id) + end + end + end + return result +end + + +function loop_exit_branch(block::IRTools.Block) + branches = IRTools.branches(block) + target_branch_id = findfirst([br.block > block.id for br in branches]) + return target_branch_id !== nothing ? branches[target_branch_id] : nothing +end + + +function loop_continue_branch(block::IRTools.Block) + branches = IRTools.branches(block) + target_branch_id = findfirst([br.block <= block.id for br in branches]) + return return target_branch_id !== nothing ? branches[target_branch_id] : nothing +end + + +function loop_condition_ir_id(block::IRTools.Block) + br = loop_exit_branch(block) + return br !== nothing ? br.condition.id : nothing +end + + +"""Get SSA IDs of exit arguments of a loop block""" +function loop_exit_ir_ids(block::IRTools.Block) + br = loop_exit_branch(block) + return br !== nothing ? [arg.id for arg in br.args] : nothing +end + + +function loop_continue_ir_ids(block::IRTools.Block) + br = loop_continue_branch(block) + return br !== nothing ? [arg.id for arg in br.args] : nothing +end + + +"""Pseudo op to designate loop end. Removed after Loop op is created""" +mutable struct _LoopEnd <: AbstractOp + id::Int +end +_LoopEnd() = _LoopEnd(0) + + +const LOOP_EXIT_TAPE_IDS = "loop_exit_tape_ids" +const LOOP_COND_ID = "loop_cond_id" +const LOOP_CONTINUE_TAPE_IDS = "loop_continue_tape_ids" + + +is_loop_traced(t::IRTracer, loop_id::UUID) = haskey(t.tape.meta, "loop_already_traced_$loop_id") +loop_traced!(t::IRTracer, loop_id::UUID) = (t.tape.meta["loop_already_traced_$loop_id"] = true) + + +""" +Trigger loop start operations. + +Arguments: +---------- + + * t :: IRTracer + Current tracer + * loop_id :: Int + Unique ID of a loop being entered + * loop_input_ir_ids :: Vector{Int} + IR IDs of variables which will be used as loop inputs. Includes + loop block inputs and any outside IDs + +This function is added to the very beginning of the loop block(s). +During the first iteration we initialize a subtape which will be used +later to create the Loop operation on the parent tape. Since all iterations +of the loop produce identical operations, we only need to trace it once. +However, it turns out to be easier to record all iterations (seprated by +a special _LoopEnd op) and then prune unused iterations. + +Another important detail is that Loop's subtape is a completely valid +and independent tape with its own call frame and inputs which include +all explicit and implicit inputs to the loop's block in the original IR. +""" +function enter_loop!(t::IRTracer, loop_id, loop_input_ir_ids::Vector) + t.options.trace_loops || return + # skip if it's not the first iteration + is_loop_traced(t, loop_id) && return + # create subtape, with the current tape as parent + C = typeof(t.tape.c) + subtape = Tape(C()) + subtape.parent = t.tape + # create a new frame and push to the list + frame = Frame(Dict(), V(0), ()) + push!(t.frames, frame) + # record inputs to the subtape & populate the new frame's ir2tape + for ir_id in loop_input_ir_ids + parent_tape_var = t.frames[end - 1].ir2tape[ir_id] + val = subtape.parent[parent_tape_var].val + tape_var = push!(subtape, Input(val)) + t.frames[end].ir2tape[ir_id] = tape_var + end + # replace IRTracer.tape with subtape + t.tape = subtape +end + + +""" +Set flags designating the end of the first run of the loop. +""" +function stop_loop_tracing!(t::IRTracer, + loop_id::Any, + input_ir_ids::Vector, + cond_ir_id::Any, + cont_ir_ids::Vector, + exit_ir_ids::Vector, + exit_target_ir_ids::Vector) + t.options.trace_loops || return + if !is_loop_traced(t, loop_id) + # record exit tape IDs as of first iteration + # we will use them later + t.tape.meta[LOOP_EXIT_TAPE_IDS] = + [t.frames[end].ir2tape[ir_id] for ir_id in exit_ir_ids] + t.tape.meta[LOOP_COND_ID] = t.frames[end].ir2tape[cond_ir_id] + t.tape.meta[LOOP_CONTINUE_TAPE_IDS] = + [t.frames[end].ir2tape[ir_id] for ir_id in cont_ir_ids] + # set flag to stop creatnig new subtapes + loop_traced!(t, loop_id) + end + # record a special op to designate the end of the loop code + # tracer will continue to record ops, but later we truncate + # the tape to get only ops before _LoopEnd + push!(t.tape, _LoopEnd()) +end + + +""" +Trigget loop end operations. + +This function is added just before the end of the loop block. +Since we record all iterations of the loop, we must remember tape IDs +of continuation condition and exit variables during the first run. +""" +function exit_loop!(t::IRTracer, + input_ir_ids::Vector, + cond_ir_id::Any, + cont_ir_ids::Vector, + exit_ir_ids::Vector, + exit_target_ir_ids::Vector) + t.options.trace_loops || return + # loop subtape already contains a variable designating condition + # of loop continuation; if this condition is false, + # we are ready to exit the loop and record Loop operation + cond_op = t.tape[t.frames[end].ir2tape[cond_ir_id]] + if !cond_op.val + # swap active tape back + pop!(t.frames) + parent_ir2tape = t.frames[end].ir2tape + subtape = t.tape + t.tape = t.tape.parent + # remove repeating blocks + first_loop_end = findfirst(op -> isa(op, _LoopEnd), subtape.ops) + subtape.ops = subtape.ops[1:first_loop_end-1] + # record output tuple + exit_tape_vars = subtape.meta[LOOP_EXIT_TAPE_IDS] + exit_val = tuple([subtape[id].val for id in exit_tape_vars]...) + # exit_var = push!(subtape, mkcall(tuple, exit_tape_vars...)) + # exit_val = subtape[exit_var].val + subtape.result = bound(t.tape, V(length(t.tape))) # dummy, not used in practice + # record the loop operation + # global STATE = (t, input_ir_ids, exit_target_ir_ids) + parent_input_vars = [parent_ir2tape[ir_id] for ir_id in input_ir_ids] + condition = subtape.meta[LOOP_COND_ID] + cont_vars = subtape.meta[LOOP_CONTINUE_TAPE_IDS] + loop_id = push!(t.tape, Loop(0, parent_input_vars, + condition, cont_vars, + exit_tape_vars, subtape, exit_val)) + # destructure loop return values to separate vars on the main tape + # and map branch arguments to these vars + for i=1:length(exit_val) + res_id = push!(t.tape, mkcall(getfield, loop_id, i)) + parent_ir2tape[exit_target_ir_ids[i]] = res_id + end + end +end + + + ################################################################################ # IRTracer (body) + irtrace() # ################################################################################ @@ -172,8 +431,6 @@ Params: """ function record_or_recurse!(t::IRTracer, res_sid::Int, farg_irvars, fargs...) fn, args = fargs[1], fargs[2:end] - # global STATE = (t, res_sid, farg_irvars, fargs) - # fn == Core.kwfunc(sum) && error() if t.is_primitive(Tuple{map(typeof, fargs)...}) tape_vars = get_tape_vars(t, farg_irvars) # record corresponding op to the tape @@ -199,6 +456,28 @@ function record_const!(t::IRTracer, res_sid, val) end +""" +Get SSA IDs of the branch target parameters. +For example, given code like this: + + + 2: (%9, %10, %11) + ... + br 3 (%14, %15) unless %18 + br 2 (%16, %14, %15) + 3: (%19, %20) + ... + +This function will return: + + branch_target_params(ir,
) ==> [19, 20] + +""" +function branch_target_params(ir:: IR, branch::IRTools.Branch) + return [v.id for v in ir.blocks[branch.block].args] +end + + function trace_branches!(ir::IR) # if a block ends with a branch, we map its parameters to tape IDs # which currently correspond to argument SSA IDs @@ -217,10 +496,61 @@ function trace_branches!(ir::IR) end +function loop_start_end_blocks(ir::IR, block::IRTools.Block) + start_block = IRTools.block(ir, loop_continue_branch(block).block) + exit_branch = loop_exit_branch(block) + # check if this block contains the exit branch + if exit_branch !== nothing && exit_branch.block > block.id + end_block = block + else + exit_branch = loop_exit_branch(start_block) + if exit_branch !== nothing && exit_branch.block > block.id + end_block = start_block + else + error("Cannot find end block of a loop") + end + end + return start_block, end_block +end + + +function trace_loops!(ir::IR) + for block in IRTools.blocks(ir) + if is_loop(block) + loop_id = uuid1() # unique ID of this loop + start_block, end_block = loop_start_end_blocks(ir, block) + # loop start - the first block of the loop + loop_input_ir_ids = vcat( + block_input_ir_ids(start_block), + block_outsider_ir_ids(start_block), + ) + pushfirst!(start_block, Expr(:call, enter_loop!, self, loop_id, loop_input_ir_ids)) + # loop tracing border - at this point all operations of the loop + # have been executed at least once, even if continuation condition + # is in another block + push!(block, Expr(:call, stop_loop_tracing!, self, + loop_id, + loop_input_ir_ids, + loop_condition_ir_id(end_block), + loop_continue_ir_ids(block), + loop_exit_ir_ids(end_block), + branch_target_params(ir, loop_exit_branch(end_block)))) + # loop end - continuation condition is checked here + push!(end_block, Expr(:call, exit_loop!, self, + loop_input_ir_ids, + loop_condition_ir_id(end_block), + loop_continue_ir_ids(block), + loop_exit_ir_ids(end_block), + branch_target_params(ir, loop_exit_branch(end_block)))) + end + end +end + + @dynamo function (t::IRTracer)(fargs...) ir = IR(fargs...) ir === nothing && return # intrinsic functions - # TODO (for loops): IRTools.expand!(ir) + IRTools.expand!(ir) rewrite_special_cases!(ir) for (v, st) in ir ex = st.expr @@ -240,6 +570,7 @@ end end end trace_branches!(ir) + trace_loops!(ir) return ir end @@ -273,4 +604,4 @@ function trace(f, args...; is_primitive=is_primitive, primitives=nothing, ctx=Di t.tape.result = t.frames[1].result tape = t.tape return val, tape -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 0ba1b4a..a084553 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test using Yota using Yota: gradtape, gradcheck, update_chainrules_primitives!, compile +using Yota: Loop, should_trace_loops, should_trace_loops! using CUDA import ChainRulesCore: Tangent, ZeroTangent diff --git a/test/test_trace.jl b/test/test_trace.jl index 17f6425..da98430 100644 --- a/test/test_trace.jl +++ b/test/test_trace.jl @@ -79,9 +79,7 @@ function loop1(a, n) a = 2a for i in 1:n a = a * n - n = n + 1 end - a = a + n return a end @@ -106,23 +104,68 @@ function loop3(a, b) end -function cond1(a, b) - if b > 0 - a = 2a +function loop4(x, n, m) + for i in 1:n + for j in 1:m + x = 2x + end end - return a + return x end -@testset "trace: loops" begin - # smoke tests, will be replaced with loop testing when it's ready - val, tape = trace(loop1, 8.0, 2) - @test val == loop1(8.0, 2) - - val, tape = trace(loop2, 5, 10) - @test val == loop2(5, 10) +function loop5(a, n) + for i=1:3 + a = loop1(a, n) + end + return a +end - val, tape = trace(loop3, 1, 3) - @test val == loop3(1, 3) +@testset "trace: loops" begin + should_trace_loops!(false) + + _, tape = trace(loop1, 1.0, 3) + @test findfirst(op -> op isa Loop, tape.ops) === nothing + # same number of iteration + @test play!(tape, loop1, 2.0, 3) == loop1(2.0, 3) + @test compile(tape)(loop1, 2.0, 3) == loop1(2.0, 3) + # different number of iteration - with loop tracing off, should be incorrect + @test play!(tape, loop1, 2.0, 4) != loop1(2.0, 4) + @test compile(tape)(loop1, 2.0, 4) != loop1(2.0, 4) + + should_trace_loops!(true) + + _, tape = trace(loop1, 1.0, 3) + @test play!(tape, loop1, 2.0, 4) == loop1(2.0, 4) + @test compile(tape)(loop1, 2.0, 4) == loop1(2.0, 4) + @test findfirst(op -> op isa Loop, tape.ops) !== nothing + + _, tape = trace(loop2, 1.0, 3) + @test play!(tape, loop2, 2.0, 4) == loop2(2.0, 4) + @test compile(tape)(loop2, 2.0, 4) == loop2(2.0, 4) + @test findfirst(op -> op isa Loop, tape.ops) !== nothing + + _, tape = trace(loop3, 1.0, 3) + @test play!(tape, loop3, 2.0, 4) == loop3(2.0, 4) + @test compile(tape)(loop3, 2.0, 4) == loop3(2.0, 4) + @test findfirst(op -> op isa Loop, tape.ops) !== nothing + + _, tape = trace(loop4, 1.0, 2, 3) + @test play!(tape, loop4, 2.0, 3, 4) == loop4(2.0, 3, 4) + @test compile(tape)(loop4, 2.0, 3, 4) == loop4(2.0, 3, 4) + loop_idx = findfirst(op -> op isa Loop, tape.ops) + @test loop_idx !== nothing + subtape = tape[V(loop_idx)].subtape + @test findfirst(op -> op isa Loop, subtape.ops) !== nothing + + _, tape = trace(loop5, 1.0, 3) + @test play!(tape, loop5, 2.0, 4) == loop5(2.0, 4) + @test compile(tape)(loop5, 2.0, 4) == loop5(2.0, 4) + loop_idx = findfirst(op -> op isa Loop, tape.ops) + @test loop_idx !== nothing + subtape = tape[V(loop_idx)].subtape + @test findfirst(op -> op isa Loop, subtape.ops) !== nothing + + should_trace_loops!() end \ No newline at end of file