Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@trace function calls #366

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
533a19d
wip trace function call
jumerckx Dec 11, 2024
04690e8
multiple return values
jumerckx Dec 12, 2024
180604f
halfway there with caching: new tracemode `CallCache` transforms trac…
jumerckx Dec 12, 2024
caf0e0e
add `enable_tracing` scopedvalue
jumerckx Dec 13, 2024
dbf1dac
wip
jumerckx Dec 13, 2024
d914f04
caching!
jumerckx Dec 13, 2024
670a8c5
add ScopedValues as a dependency
jumerckx Dec 13, 2024
48be2e9
actually use ScopedValues...
jumerckx Dec 13, 2024
ae8701d
properly handle arg and result tracedvalues in `traced_call`
jumerckx Dec 13, 2024
c40e585
don't tranpose results when `do_tranpose=false`
jumerckx Dec 13, 2024
531d1b2
a few tests
jumerckx Dec 14, 2024
f0723e8
CallCache mode: support TracedRNumber
jumerckx Dec 14, 2024
5ead000
one more test
jumerckx Dec 14, 2024
1f05bc9
add Cached object with custom equality and hash for use a dict key
jumerckx Dec 15, 2024
e1852c5
caching test
jumerckx Dec 15, 2024
0cea97a
fix
jumerckx Dec 16, 2024
38e098c
callcache
jumerckx Dec 16, 2024
236babf
add deepcopy to Cached constructor
jumerckx Dec 17, 2024
2c87ddc
`within_tracing`, approach taken from Enzyme.jl
jumerckx Dec 17, 2024
ccb585f
make callcache a default arg
jumerckx Dec 17, 2024
469a2ff
`within_tracing` --> `within_compile`
Dec 22, 2024
84dd0d1
TracedToTypes tracing mode.
jumerckx Jan 1, 2025
39b07c9
remove ScopedValues from Project.toml
jumerckx Jan 2, 2025
af8ac2d
Merge branch 'main' into jm/funccall
jumerckx Jan 10, 2025
739977c
Merge branch 'main' into jm/funccall
jumerckx Jan 13, 2025
a34a228
support broadcasted calls in `@trace`
jumerckx Jan 22, 2025
8d0669c
handle mutation
jumerckx Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"
[sources]
ReactantCore = {path = "lib/ReactantCore"}

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
Expand Down
52 changes: 48 additions & 4 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

export @trace, MissingTracedValue
export @trace, within_compile, MissingTracedValue

# Traits
is_traced(x) = false
Expand All @@ -19,6 +19,13 @@ const SPECIAL_SYMBOLS = [
:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
]

"""
within_compile()

Returns true if this function is executed in a Reactant compilation context, otherwise false.
"""
@inline within_compile() = false # behavior is overwritten in Interpreter.jl

# Code generation
"""
@trace <expr>
Expand Down Expand Up @@ -115,6 +122,13 @@ macro trace(expr)
return esc(trace_if_with_returns(__module__, expr))
end
end
Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr))
if Meta.isexpr(expr, :(.), 2) && Meta.isexpr(expr.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(expr.args[1])))
args = only(expr.args[2:end]).args
call = Expr(:call, fname, args...)
return esc(trace_call(__module__, call))
end
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr))
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr)))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
Expand Down Expand Up @@ -182,7 +196,8 @@ function trace_for(mod, expr)
end

return quote
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
if $(within_compile)() &&
$(any)($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
$(reactant_code_block)
else
$(expr)
Expand All @@ -196,7 +211,7 @@ function trace_if_with_returns(mod, expr)
mod, expr.args[2]; store_last_line=expr.args[1], depth=1
)
return quote
if any($(is_traced), ($(all_check_vars...),))
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
$(new_expr)
else
$(expr)
Expand Down Expand Up @@ -342,14 +357,41 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
)

return quote
if any($(is_traced), ($(all_check_vars...),))
if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(original_expr)
end
end
end

function correct_maybe_bcast_call(fname)
startswith(string(fname), '.') || return false, fname, fname
return true, Symbol(string(fname)[2:end]), fname
end

function trace_call(mod, call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
f = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
end
end
else
:($(fname))
end
return quote
if $(within_compile)()
$(traced_call)($f, $(call.args[2:end]...))
else
$(call)
end
end
end

function remove_shortcircuiting(expr)
return MacroTools.prewalk(expr) do x
if MacroTools.@capture(x, a_ && b_)
Expand All @@ -373,6 +415,8 @@ function traced_while(cond_fn, body_fn, args)
return args
end

traced_call(f, args...; kwargs...) = f(args...; kwargs...)

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
if x isa Symbol && x ∈ all_vars
Expand Down
49 changes: 43 additions & 6 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import ..Reactant:
ancestor,
TracedType

import ..ReactantCore: correct_maybe_bcast_call

@inline function traced_getfield(@nospecialize(obj), field)
return Base.getfield(obj, field)
end
Expand Down Expand Up @@ -352,7 +354,7 @@ const cuLaunch = Ref{UInt}(0)
const cuFunc = Ref{UInt}(0)
const cuModule = Ref{UInt}(0)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
function compile_mlir!(mod, f, args, callcache=Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{MLIR.IR.Type}, traced_result::Any, mutated::Vector{Int}}}(); optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

Expand All @@ -361,7 +363,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
callcache!(callcache) do # TODO: don't create a closure here either.
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
finally
MLIR.IR.deactivate!(MLIR.IR.body(mod))
MLIR.IR.deactivate!(mod)
Expand Down Expand Up @@ -600,10 +604,6 @@ function compile_call_expr(mod, compiler, options, args...)
(; compiled=compiled_symbol, args=args_symbol)
end

function correct_maybe_bcast_call(fname)
startswith(string(fname), '.') || return false, fname, fname
return true, Symbol(string(fname)[2:end]), fname
end

"""
codegen_flatten!
Expand Down Expand Up @@ -1014,4 +1014,41 @@ function register_thunk(
return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f)
end

function activate_callcache!(callcache)
stack = get!(task_local_storage(), :callcache) do
return []
end
push!(stack, callcache)
return nothing
end

function deactivate_callcache!(callcache)
callcache === last(task_local_storage(:callcache)) ||
error("Deactivating wrong callcache")
return pop!(task_local_storage(:callcache))
end

function _has_callcache()
return haskey(task_local_storage(), :callcache) &&
!Base.isempty(task_local_storage(:callcache))
end

function callcache(; throw_error::Bool=true)
if !_has_callcache()
throw_error && error("No callcache is active")
return nothing
end
return last(task_local_storage(:callcache))
end

function callcache!(f, callcache)
activate_callcache!(callcache)
try
return f()
finally
deactivate_callcache!(callcache)
end
end


end
97 changes: 89 additions & 8 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
function ReactantCore.traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn(
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results, _) = Reactant.TracedUtils.make_mlir_fn(
true_fn,
args,
(),
string(gensym("true_branch")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
args_in_result=:none,
construct_function_without_args=true,
)

(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn(
(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results, _) = Reactant.TracedUtils.make_mlir_fn(
false_fn,
args,
(),
string(gensym("false_branch")),
false;
return_dialect=:stablehlo,
no_args_in_result=true,
args_in_result=:none,
construct_function_without_args=true,
)

Expand Down Expand Up @@ -88,24 +88,24 @@ function ReactantCore.traced_while(
end for v in args
]

(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results, _) = Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
no_args_in_result=true,
args_in_result=:none,
return_dialect=:stablehlo,
do_transpose=false,
)

(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results, _) = Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
no_args_in_result=true,
args_in_result=:none,
return_dialect=:stablehlo,
do_transpose=false,
)
Expand All @@ -130,6 +130,87 @@ function ReactantCore.traced_while(
end
end

function ReactantCore.traced_call(f::Function, args...)
seen_cache = Reactant.OrderedIdDict()
make_tracer(
seen_cache,
args,
(), # we have to insert something here, but we remove it immediately below.
TracedTrack;
toscalar=false,
track_numbers=(), # TODO: track_numbers?
)
linear_args = []
mlir_caller_args = Reactant.MLIR.IR.Value[]
for (k, v) in seen_cache
v isa TracedType || continue
push!(linear_args, v)
push!(mlir_caller_args, v.mlir_data)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
end

seen = Dict()
cache_key = []
make_tracer(seen, (f, args...), cache_key, TracedToTypes)
cache = Reactant.Compiler.callcache()
if haskey(cache, cache_key)
# cache lookup:
(; f_name, mlir_result_types, traced_result, mutated) = cache[cache_key]
else
f_name = String(gensym(Symbol(f)))
temp = Reactant.TracedUtils.make_mlir_fn(
f,
args,
(),
f_name,
false;
args_in_result=:mutated,
do_transpose=false,
)
traced_result, ret, mutated = temp[[3, 6, 10]]
mlir_result_types = [MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret)]
cache[cache_key] = (; f_name, mlir_result_types, traced_result, mutated)
end

call_op = MLIR.Dialects.func.call(
mlir_caller_args;
result_0=mlir_result_types,
callee=MLIR.IR.FlatSymbolRefAttribute(f_name),
)

seen_results = Reactant.OrderedIdDict()
jumerckx marked this conversation as resolved.
Show resolved Hide resolved
traced_result = make_tracer(
seen_results,
traced_result,
(), # we have to insert something here, but we remove it immediately below.
TracedSetPath;
toscalar=false,
track_numbers=(),
)
i = 1
for (k, v) in seen_results
v isa TracedType || continue
# this mutates `traced_result`, which is what we want:
v.mlir_data = MLIR.IR.result(call_op, i)
# make tracer inserted `()` into the path, here we remove it:
v.paths = v.paths[1:end-1]
i += 1
end
nres = MLIR.IR.nresults(call_op)
# mutated args are included as the last ones in the call op results
for (result_i, arg_i) in zip(nres-length(mutated):nres, mutated)
TracedUtils.set_mlir_data!(linear_args[arg_i], MLIR.IR.result(call_op, result_i+1))
paths = TracedUtils.get_paths(linear_args[arg_i])
if length(paths) > 0 && length(paths[1]) == 2 && paths[1][1] == :args
# we remove arg from path to make sure it is definitely returned (since it changed)
TracedUtils.set_paths!(linear_args[arg_i], paths[2:end])
end
end

return traced_result
end

function take_region(compiled_fn)
region = MLIR.IR.Region()
MLIR.API.mlirRegionTakeBody(region, MLIR.API.mlirOperationGetRegion(compiled_fn, 0))
Expand Down
21 changes: 20 additions & 1 deletion src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,25 @@ function set_reactant_abi(
)
(; fargs, argtypes) = arginfo

if f === ReactantCore.within_compile
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(
Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
)
else
return CallMeta(
Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
)
end
end

# Improve inference by considering call_with_reactant as having the same results as
# the original call
if f === Reactant.call_with_reactant
Expand Down Expand Up @@ -236,7 +255,7 @@ function overload_autodiff(
primf = f.val
primargs = ((v.val for v in args)...,)

fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn(
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results, _ = TracedUtils.make_mlir_fn(
primf, primargs, (), string(f) * "_autodiff", false
)

Expand Down
2 changes: 1 addition & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ end
(a, b),
(),
"comparator";
no_args_in_result=true,
args_in_result=:none,
return_dialect=:stablehlo,
)[2]
@assert MLIR.IR.nregions(func) == 1
Expand Down
Loading