Skip to content

Commit

Permalink
within_tracing, approach taken from Enzyme.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
jumerckx committed Dec 17, 2024
1 parent 55d1527 commit 9b5cd96
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
5 changes: 1 addition & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import ..Reactant:
TracedType,
Cached
using ScopedValues
import ReactantCore: enable_tracing

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline traced_getfield(
Expand Down Expand Up @@ -297,9 +296,7 @@ function compile_mlir!(mod, f, args, callcache; optimize::Union{Bool,Symbol}=tru
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
callcache!(callcache) do
with(enable_tracing=>true) do
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
end
end
Expand Down
24 changes: 24 additions & 0 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ function set_reactant_abi(
)
(; fargs, argtypes) = arginfo

if f === ReactantCore.within_tracing
if length(argtypes) != 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(
Core.Const(true),
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
)
else
return CallMeta(
Core.Const(true),
Union{},
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
)
end
end

# Improve inference by considering call_with_reactant as having the same results as
# the original call
if f === Reactant.call_with_reactant
Expand Down
4 changes: 2 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Reactant

using ReactantCore: ReactantCore, @trace, MissingTracedValue
using ReactantCore: ReactantCore, @trace, within_tracing, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Adapt: Adapt, WrappedArray
Expand Down Expand Up @@ -145,7 +145,7 @@ function Enzyme.make_zero(
end

using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, within_tracing

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
Expand Down

0 comments on commit 9b5cd96

Please sign in to comment.