Skip to content

Commit

Permalink
fix: scoping
Browse files Browse the repository at this point in the history
  • Loading branch information
thofma committed Aug 8, 2024
1 parent 3f7d845 commit 794d3a9
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 39 deletions.
3 changes: 2 additions & 1 deletion docs/src/manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ DocTestSetup = quote
@resumable function arrays_of_tuples()
for u in [[(1,2),(3,4)], [(5,6),(7,8)]]
for i in 1:2
local val
let i=i
val = [a[i] for a in u]
end
Expand Down Expand Up @@ -413,4 +414,4 @@ DocTestSetup = nothing
- In a `try` block only top level `@yield` statements are allowed.
- In a `catch` or a `finally` block a `@yield` statement is not allowed.
- An anonymous function can not contain a `@yield` statement.
- If a `FiniteStateMachineIterator` object is used in more than one `for` loop, only the `state` variable is reinitialised. A `@resumable function` that alters its arguments will use the modified values as initial parameters.
- If a `FiniteStateMachineIterator` object is used in more than one `for` loop, only the `state` variable is reinitialised. A `@resumable function` that alters its arguments will use the modified values as initial parameters.
34 changes: 30 additions & 4 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ macro resumable(expr::Expr)

# The function that executes a step of the finite state machine
func_def = splitdef(expr)
@debug func_def[:body]
rtype = :rtype in keys(func_def) ? func_def[:rtype] : Any
args, kwargs, arg_dict = get_args(func_def)
params = ((get_param_name(param) for param in func_def[:whereparams])...,)
Expand All @@ -43,7 +44,35 @@ macro resumable(expr::Expr)
func_def[:body] = postwalk(transform_arg_yieldfrom, func_def[:body])
func_def[:body] = postwalk(transform_yieldfrom, func_def[:body])
func_def[:body] = postwalk(x->transform_for(x, ui8), func_def[:body])
@debug func_def[:body]|>MacroTools.striplines
#func_def[:body] = postwalk(x->transform_macro(x), func_def[:body])
#@debug func_def[:body]|>MacroTools.striplines
#func_def[:body] = postwalk(x->transform_macro_undo(x), func_def[:body])
#@debug func_def[:body]|>MacroTools.striplines
#func_def[:body] = postwalk(x->transform_let(x), func_def[:body])
#@info func_def[:body]|>MacroTools.striplines
#func_def[:body] = postwalk(x->transform_local(x), func_def[:body])
#@info func_def[:body]|>MacroTools.striplines
# Scoping fixes

# :name is :(fA::A) if it is an overloading call function (fA::A)(...)
# ...
if func_def[:name] isa Expr
@assert func_def[:name].head == :(::)
_name = func_def[:name].args[1]
else
_name = func_def[:name]
end

scope = ScopeTracker(0, __module__, [Dict(i =>i for i in vcat(args, kwargs, [_name], params...))])
#@info func_def[:body]|>MacroTools.striplines
func_def[:body] = scoping(copy(func_def[:body]), scope)
#@info func_def[:body]|>MacroTools.striplines
func_def[:body] = postwalk(x->transform_remove_local(x), func_def[:body])
#@info func_def[:body]|>MacroTools.striplines

inferfn, slots = get_slots(copy(func_def), arg_dict, __module__)
@debug slots

# check if the resumable function is a callable struct instance (a functional) that is referencing itself
isfunctional = @capture(func_def[:name], functional_::T_) && inexpr(func_def[:body], functional)
Expand Down Expand Up @@ -74,7 +103,6 @@ macro resumable(expr::Expr)
fsmi._state = 0x00
fsmi
end

# the bare/fallback version of the constructor supplies default slot type parameters
# we only need to define this if there there are actually slot defaults to be filled
if !isempty(slot_T)
Expand All @@ -100,7 +128,6 @@ macro resumable(expr::Expr)
end
)
@debug type_expr|>MacroTools.striplines

# The "original" function that now is simply a wrapper around the construction of the finite state machine
call_def = copy(func_def)
call_def[:rtype] = nothing
Expand All @@ -119,7 +146,7 @@ macro resumable(expr::Expr)
end
call_expr = combinedef(call_def) |> flatten
@debug call_expr|>MacroTools.striplines

# Finalizing the function stepping through the finite state machine
if isempty(params)
func_def[:name] = :((_fsmi::$type_name))
Expand Down Expand Up @@ -153,7 +180,6 @@ macro resumable(expr::Expr)
call_expr = postwalk(x->x==:(ResumableFunctions.typed_fsmi) ? :(ResumableFunctions.typed_fsmi_fallback) : x, call_expr)
end
@debug func_expr|>MacroTools.striplines

# The final expression:
# - the finite state machine struct
# - the function stepping through the states
Expand Down
83 changes: 65 additions & 18 deletions src/transforms.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
function transform_remove_local(ex)
ex isa Expr && ex.head === :local && return Expr(:block)
return ex
end

function transform_macro(ex)
ex isa Expr || return ex
ex.head !== :macrocall && return ex
return Expr(:call, :__secret__, ex.args)
end

function transform_macro_undo(ex)
ex isa Expr || return ex
(ex.head !== :call || ex.args[1] !== :__secret__) && return ex
return Expr(:macrocall, ex.args[2]...)
end

"""
Function that replaces a variable
"""
Expand Down Expand Up @@ -70,22 +87,26 @@ Function that replaces a `for` loop by a corresponding `while` loop saving expli
"""
function transform_for(expr, ui8::BoxedUInt8)
@capture(expr, for element_ in iterator_ body_ end) || return expr
#@info element
localelement = Expr(:local, element)
ui8.n += one(UInt8)
next = Symbol("_iteratornext_", ui8.n)
state = Symbol("_iterstate_", ui8.n)
iterator_value = Symbol("_iterator_", ui8.n)
label = Symbol("_iteratorlabel_", ui8.n)
body = postwalk(x->transform_continue(x, label), :(begin $(body) end))
quote
res = quote
$iterator_value = $iterator
@nosave $next = iterate($iterator_value)
$next = iterate($iterator_value)
while $next !== nothing
$localelement
($element, $state) = $next
$body
@label $label
$next = iterate($iterator_value, $state)
end
end
res
end


Expand All @@ -102,7 +123,7 @@ Function that replaces a variable `x` in an expression by `_fsmi.x` where `x` is
"""
function transform_slots(expr, symbols)
expr isa Expr || return expr
expr.head === :let && return transform_slots_let(expr, symbols)
#expr.head === :let && return transform_slots_let(expr, symbols)
for i in 1:length(expr.args)
expr.head === :kw && i === 1 && continue
expr.head === Symbol("quote") && continue
Expand All @@ -111,27 +132,53 @@ function transform_slots(expr, symbols)
expr
end

"""
Function that handles `let` block
"""
function transform_slots_let(expr::Expr, symbols)
@capture(expr, let vars_; body_ end)
locals = Set{Symbol}()
(isa(vars, Expr) && vars.head==:(=)) || error("@resumable currently supports only single variable declarations in let blocks, i.e. only let blocks exactly of the form `let i=j; ...; end`. If you need multiple variables, please submit an issue on the issue tracker and consider contributing a patch.")
sym = vars.args[1].args[2].value
push!(locals, sym)
vars.args[1] = sym
body = postwalk(x->transform_let(x, locals), :(begin $(body) end))
:(let $vars; $body end)
#"""
#Function that handles `let` block
#"""
#function transform_slots_let(expr::Expr, symbols)
# @capture(expr, let vars_; body_ end)
# locals = Set{Symbol}()
# (isa(vars, Expr) && vars.head==:(=)) || error("@resumable currently supports only single variable declarations in let blocks, i.e. only let blocks exactly of the form `let i=j; ...; end`. If you need multiple variables, please submit an issue on the issue tracker and consider contributing a patch.")
# sym = vars.args[1].args[2].value
# push!(locals, sym)
# vars.args[1] = sym
# body = postwalk(x->transform_let(x, locals), :(begin $(body) end))
# :(let $vars; $body end)
#end

function transform_let(expr)
expr isa Expr || return expr
expr.head === :block && return expr
#@info "inside transform let"
@capture(expr, let arg_; body_; end) || return expr
#@info "captured let"
#arg |> dump
#@info expr
#@info arg
#error("ASds")
res = quote
let
local $arg
$body
end
end
#@info "emitting $res"
res
#expr.head === :. || return expr
#expr = expr.args[2].value in symbols ? :($(expr.args[2].value)) : expr
end

"""
Function that replaces a variable `_fsmi.x` in an expression by `x` where `x` is a variable declared in a `let` block.
"""
function transform_let(expr, symbols::Set{Symbol})
function transform_local(expr)
expr isa Expr || return expr
expr.head === :. || return expr
expr = expr.args[2].value in symbols ? :($(expr.args[2].value)) : expr
@capture(expr, local arg_ = ex_) || return expr
res = quote
local $arg
$arg = $ex
end
res
end

"""
Expand Down
150 changes: 150 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,153 @@ end
function typed_fsmi_fallback(fsmi::Type{T}, fargs...)::T where T
return T()
end

mutable struct ScopeTracker
i::Int
mod::Module
scope_stack::Vector
end

function lookup!(s::Symbol, S::ScopeTracker; new = false)
if isdefined(S.mod, s)
return s
end
if !new
for D in Iterators.reverse(S.scope_stack)
if haskey(D, s)
return D[s]
end
end
end
D = last(S.scope_stack)
new = Symbol(s, Symbol("_$(S.i)"))
S.i += 1
D[s] = new
return new
end

scoping(e::LineNumberNode, scope) = e

scoping(e::Int, scope) = e

scoping(e::String, scope) = e
scoping(e::typeof(ResumableFunctions.generate), scope) = e
scoping(e::typeof(ResumableFunctions.IteratorReturn), scope) = e
scoping(e::QuoteNode, scope) = e
scoping(e::Bool, scope) = e
scoping(e::Nothing, scope) = e

function scoping(s::Symbol, scope; new = false)
#@info "scoping $s, $new"
return lookup!(s, scope; new = new)
end

function scoping(expr::Expr, scope)
if expr.head === :macrocall
for i in 2:length(expr.args)
expr.args[i] = scoping(expr.args[i], scope)
end
return expr
end
new_stack = false
if expr.head === :let
# Replace
# let i, k = 2, j = 1
# [...]
# end
#
# by
#
# let
# local i_new
# local k_new = 2
# local j_new = 1
# end
#
# Caveat:
# let i = i, j = i
#
# must be
# new_i = old_i
# new_j = new_i
#
# :(

# defer adding a new scope after the right hand side have been renamed
@capture(expr, let arg_; body_ end) || return expr
@capture(arg, begin x__ end)
replace_rhs = []
for i in 1:length(x)
y = x[i]
fl = @capture(y, k_ = v_)
if fl
push!(replace_rhs, scoping(v, scope))
else
# there was no right side
push!(replace_rhs, nothing)
end
end
new_stack = true
push!(scope.scope_stack, Dict())
replace_lhs = []
rep = []
for i in 1:length(x)
y = x[i]
fl = @capture(y, k_ = v_)
if fl
push!(replace_lhs, scoping(k, scope, new = true))
push!(rep, quote local $(replace_lhs[i]); $(replace_lhs[i]) = $(replace_rhs[i]) end)
else
@assert y isa Symbol
push!(replace_lhs, scoping(y, scope, new = true))
push!(rep, quote local $(replace_lhs[i]) end)
end
end
rep = quote
$(rep...)
end
rep = MacroTools.flatten(rep)
expr.args[1] = Expr(:block)
pushfirst!(expr.args[2].args, rep)

# Now continue recursively
# but skip the local/dance, since we already replaced them
for i in 2:length(expr.args[2].args)
a = expr.args[2].args[i]
expr.args[2].args[i] = scoping(a, scope)
end
pop!(scope.scope_stack)
return expr
end

if expr.head === :while || expr.head === :let
push!(scope.scope_stack, Dict())
new_stack = true
end
if expr.head === :local
# this is my local dance
# explain and rewrite using @capture
if length(expr.args) == 1 && expr.args[1] isa Symbol
expr.args[1] = scoping(expr.args[1], scope, new = true)
elseif length(expr.args) == 1 && expr.args[1].head === :tuple
for i in 1:length(expr.args[1].args)
a = expr.args[1].args[i]
expr.args[1].args[i] = scoping(a, scope, new = true)
end
else
for i in 1:length(expr.args)
a = expr.args[i]
expr.args[i] = scoping(a, scope, new = true)
end
end
else
for i in 1:length(expr.args)
a = expr.args[i]
expr.args[i] = scoping(a, scope)
end
end
if new_stack
pop!(scope.scope_stack)
end
return expr
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ end

macro doset(descr)
quote
@info "====================================="
@info $descr
if doset($descr)
@safetestset $descr begin
include("test_"*$descr*".jl")
Expand Down
Loading

0 comments on commit 794d3a9

Please sign in to comment.