diff --git a/docs/src/api/api.md b/docs/src/api/api.md index b08afef91..b7befc9df 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -13,6 +13,10 @@ Reactant.@jit ## ReactantCore API +```@docs +within_compile +``` + ```@docs @trace ``` diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 7279ba585..4b02f42bd 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -3,7 +3,7 @@ module ReactantCore using ExpressionExplorer: ExpressionExplorer using MacroTools: MacroTools -export @trace, MissingTracedValue +export @trace, within_compile, MissingTracedValue # Traits is_traced(x) = false @@ -21,6 +21,13 @@ const SPECIAL_SYMBOLS = [ :(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core ] +""" + within_compile() + +Returns true if this function is executed in a Reactant compilation context, otherwise false. +""" +@inline within_compile() = false # behavior is overwritten in Interpreter.jl + # Code generation """ @trace @@ -117,6 +124,13 @@ macro trace(expr) return esc(trace_if_with_returns(__module__, expr)) end end + Meta.isexpr(expr, :call) && return esc(trace_call(__module__, expr)) + if Meta.isexpr(expr, :(.), 2) && Meta.isexpr(expr.args[2], :tuple) + fname = :($(Base.Broadcast.BroadcastFunction)($(expr.args[1]))) + args = only(expr.args[2:end]).args + call = Expr(:call, fname, args...) + return esc(trace_call(__module__, call)) + end Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr)) Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr))) return error("Only `if-elseif-else` blocks are currently supported by `@trace`") @@ -196,7 +210,9 @@ function trace_for(mod, expr) end return quote - if any($(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...))) + if $(within_compile)() && $(any)( + $(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...)) + ) $(reactant_code_block) else $(expr) @@ -210,7 +226,7 @@ function trace_if_with_returns(mod, expr) mod, expr.args[2]; store_last_line=expr.args[1], depth=1 ) return quote - if any($(is_traced), ($(all_check_vars...),)) + if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),)) $(new_expr) else $(expr) @@ -356,7 +372,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) ) return quote - if any($(is_traced), ($(all_check_vars...),)) + if $(within_compile)() && $(any)($(is_traced), ($(all_check_vars...),)) $(reactant_code_block) else $(original_expr) @@ -364,6 +380,33 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) end end +function correct_maybe_bcast_call(fname) + startswith(string(fname), '.') || return false, fname, fname + return true, Symbol(string(fname)[2:end]), fname +end + +function trace_call(mod, call) + bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1]) + f = if bcast + quote + if isdefined(mod, $(Meta.quot(fname_full))) + $(fname_full) + else + Base.Broadcast.BroadcastFunction($(fname)) + end + end + else + :($(fname)) + end + return quote + if $(within_compile)() + $(traced_call)($f, $(call.args[2:end]...)) + else + $(call) + end + end +end + function remove_shortcircuiting(expr) return MacroTools.prewalk(expr) do x if MacroTools.@capture(x, a_ && b_) @@ -382,6 +425,8 @@ end function traced_while end # defined inside Reactant.jl +traced_call(f, args...; kwargs...) = f(args...; kwargs...) + function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars) return MacroTools.postwalk(expr) do x if Meta.isexpr(x, :kw) # undo lhs rewriting diff --git a/src/Compiler.jl b/src/Compiler.jl index b38308faf..9d98ec2fe 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -20,6 +20,8 @@ import ..Reactant: ancestor, TracedType +import ..ReactantCore: correct_maybe_bcast_call + @inline function traced_getfield(@nospecialize(obj), field) return Base.getfield(obj, field) end @@ -440,18 +442,34 @@ const DEBUG_KERNEL = Ref{Bool}(false) const DUMP_LLVMIR = Ref{Bool}(false) function compile_mlir!( - mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu" + mod, + f, + args, + callcache=Dict{ + Vector, + @NamedTuple{ + f_name::String, + mlir_result_types::Vector{MLIR.IR.Type}, + traced_result::Any, + mutated::Vector{Int}, + } + }(); + optimize::Union{Bool,Symbol}=true, + no_nan::Bool=false, + backend="gpu", ) # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues MLIR.IR.activate!(mod) MLIR.IR.activate!(MLIR.IR.body(mod)) + activate_callcache!(callcache) fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = try Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) finally + deactivate_callcache!(callcache) MLIR.IR.deactivate!(MLIR.IR.body(mod)) MLIR.IR.deactivate!(mod) end @@ -716,11 +734,6 @@ function compile_call_expr(mod, compiler, options, args...) (; compiled=compiled_symbol, args=args_symbol) end -function correct_maybe_bcast_call(fname) - startswith(string(fname), '.') || return false, fname, fname - return true, Symbol(string(fname)[2:end]), fname -end - """ codegen_flatten! @@ -1167,4 +1180,40 @@ function register_thunk( return Thunk{Core.Typeof(f),tag,argtys,isclosure}(f) end +function activate_callcache!(callcache) + stack = get!(task_local_storage(), :callcache) do + return [] + end + push!(stack, callcache) + return nothing +end + +function deactivate_callcache!(callcache) + callcache === last(task_local_storage(:callcache)) || + error("Deactivating wrong callcache") + return pop!(task_local_storage(:callcache)) +end + +function _has_callcache() + return haskey(task_local_storage(), :callcache) && + !Base.isempty(task_local_storage(:callcache)) +end + +function callcache(; throw_error::Bool=true) + if !_has_callcache() + throw_error && error("No callcache is active") + return nothing + end + return last(task_local_storage(:callcache)) +end + +function callcache!(f, callcache) + activate_callcache!(callcache) + try + return f() + finally + deactivate_callcache!(callcache) + end +end + end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index a5bb723ba..d3271de59 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -4,6 +4,10 @@ function ReactantCore.traced_if( return Ops.if_condition(cond, true_fn, false_fn, args...) end +function ReactantCore.traced_call(f::Function, args...) + return Ops.call(f, args...) +end + function ReactantCore.traced_while(cond_fn::CFn, body_fn::BFn, args) where {CFn,BFn} return Ops.while_loop(cond_fn, body_fn, args...) end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 449f78de3..638c1b535 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -39,6 +39,25 @@ function set_reactant_abi( ) (; fargs, argtypes) = arginfo + if f === ReactantCore.within_compile + if length(argtypes) != 1 + @static if VERSION < v"1.11.0-" + return CallMeta(Union{}, Effects(), NoCallInfo()) + else + return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + end + end + @static if VERSION < v"1.11.0-" + return CallMeta( + Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure() + ) + else + return CallMeta( + Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure() + ) + end + end + # Improve inference by considering call_with_reactant as having the same results as # the original call if f === Reactant.call_with_reactant @@ -236,7 +255,7 @@ function overload_autodiff( primf = f.val primargs = ((v.val for v in args)...,) - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn( + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results, _ = TracedUtils.make_mlir_fn( primf, primargs, (), string(f) * "_autodiff", false ) diff --git a/src/Ops.jl b/src/Ops.jl index 31d7f64a6..d9e0d9547 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1070,7 +1070,7 @@ end (sample_inputs...,), (), "comparator"; - no_args_in_result=true, + args_in_result=:none, return_dialect=:stablehlo, )[2] @assert MLIR.IR.nregions(func) == 1 @@ -1679,7 +1679,7 @@ end string(gensym("cond_fn")), false; return_dialect=:stablehlo, - no_args_in_result=true, + args_in_result=:none, do_transpose=false, ) @@ -1690,7 +1690,7 @@ end string(gensym("body_fn")), false; return_dialect=:stablehlo, - no_args_in_result=true, + args_in_result=:none, do_transpose=false, ) @@ -2060,4 +2060,75 @@ end return corrected_traced_results end +@noinline function call(f, args...) + seen_cache = Reactant.OrderedIdDict() + Reactant.make_tracer( + seen_cache, + args, + (), # we have to insert something here, but we remove it immediately below. + Reactant.TracedTrack; + toscalar=false, + ) + linear_args = [] + mlir_caller_args = Reactant.MLIR.IR.Value[] + for (k, v) in seen_cache + v isa Reactant.TracedType || continue + push!(linear_args, v) + push!(mlir_caller_args, v.mlir_data) + # make tracer inserted `()` into the path, here we remove it: + v.paths = v.paths[1:(end - 1)] + end + + seen = Dict() + cache_key = [] + Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) + cache = Reactant.Compiler.callcache() + if haskey(cache, cache_key) + # cache lookup: + (; f_name, mlir_result_types, traced_result, mutated) = cache[cache_key] + else + f_name = String(gensym(Symbol(f))) + temp = Reactant.TracedUtils.make_mlir_fn( + f, args, (), f_name, false; args_in_result=:mutated, do_transpose=false + ) + traced_result, ret, mutated = temp[[3, 6, 10]] + mlir_result_types = [ + MLIR.IR.type(MLIR.IR.operand(ret, i)) for i in 1:MLIR.IR.noperands(ret) + ] + cache[cache_key] = (; f_name, mlir_result_types, traced_result, mutated) + end + + call_op = MLIR.Dialects.func.call( + mlir_caller_args; + result_0=mlir_result_types, + callee=MLIR.IR.FlatSymbolRefAttribute(f_name), + ) + + seen_results = Reactant.OrderedIdDict() + traced_result = Reactant.make_tracer( + seen_results, + traced_result, + (), # we have to insert something here, but we remove it immediately below. + Reactant.TracedSetPath; + toscalar=false, + ) + i = 1 + for (k, v) in seen_results + v isa Reactant.TracedType || continue + # this mutates `traced_result`, which is what we want: + v.mlir_data = MLIR.IR.result(call_op, i) + # make tracer inserted `()` into the path, here we remove it: + v.paths = v.paths[1:(end - 1)] + i += 1 + end + nres = MLIR.IR.nresults(call_op) + # mutated args are included as the last ones in the call op results + for (result_i, arg_i) in zip((nres - length(mutated)):nres, mutated) + Reactant.TracedUtils.set_mlir_data!( + linear_args[arg_i], MLIR.IR.result(call_op, result_i + 1) + ) + end + return traced_result +end + end # module Ops diff --git a/src/Reactant.jl b/src/Reactant.jl index 377e33e46..458082aee 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -1,6 +1,6 @@ module Reactant -using ReactantCore: ReactantCore, @trace, MissingTracedValue +using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue using LinearAlgebra: LinearAlgebra using Random: Random, AbstractRNG @@ -231,7 +231,7 @@ function Enzyme.make_zero( end using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile -export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace +export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, within_compile const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 08d962a58..90206c02a 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -134,8 +134,9 @@ function make_mlir_fn( concretein=true; toscalar=false, return_dialect=:func, + args_in_result::Symbol=:all, + construct_function_without_args::Bool=false, do_transpose=true, - no_args_in_result=false, ) if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction return ( @@ -148,8 +149,9 @@ function make_mlir_fn( concretein; toscalar, return_dialect, + args_in_result, + construct_function_without_args, do_transpose, - no_args_in_result, )[2:end]..., ) end @@ -218,6 +220,17 @@ function make_mlir_fn( MLIR.IR.deactivate!(fnbody) end + # check which arguments have been mutated + mutated_args = Int[] + if !construct_function_without_args + for (i, arg) in enumerate(linear_args) + if get_mlir_data(arg) != MLIR.IR.argument(fnbody, i) + # mutation occured! + push!(mutated_args, i) + end + end + end + seen_results = OrderedIdDict() traced_result = Reactant.make_tracer( @@ -240,11 +253,18 @@ function make_mlir_fn( linear_results = Reactant.TracedType[] for (k, v) in seen_results v isa Reactant.TracedType || continue - (no_args_in_result && has_argidx(v)) && continue + (args_in_result != :all && has_argidx(v)) && continue push!(linear_results, v) end + if args_in_result == :mutated + append!(linear_results, linear_args[mutated_args]) + end - out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] + out_tys = if do_transpose + [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] + else + [Ops.mlir_type(arg) for arg in linear_results] + end MLIR.IR.activate!(fnbody) ret = try @@ -259,7 +279,7 @@ function make_mlir_fn( end push!(vals, col_maj) end - !no_args_in_result && @assert length(vals) == length(linear_results) + args_in_result == :all && @assert length(vals) == length(linear_results) dialect = getfield(MLIR.Dialects, return_dialect) dialect.return_(vals) @@ -289,6 +309,7 @@ function make_mlir_fn( linear_args, in_tys, linear_results, + mutated_args, ) end @@ -401,7 +422,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} return f(scalar_args...) end - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results, _ = make_mlir_fn( f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true ) diff --git a/src/Tracing.jl b/src/Tracing.jl index 9ffa30aaf..d08381d76 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -4,7 +4,12 @@ TracedToConcrete = 3 ArrayToConcrete = 4 TracedSetPath = 5 - NoStopTracedTrack = 6 + TracedToTypes = 6 + NoStopTracedTrack = 7 +end + +struct VisitedObject + id::Int end function traced_type_inner end @@ -659,16 +664,29 @@ function make_tracer( @nospecialize(track_numbers::Type = Union{}), kwargs..., ) - if mode != NoStopTracedTrack && haskey(seen, prev) - return seen[prev] - end RT = Core.Typeof(prev) + if haskey(seen, prev) + if mode == TracedToTypes + id = seen[prev] + push!(path, id) + return nothing + elseif mode != NoStopTracedTrack && haskey(seen, prev) + return seen[prev] + end + elseif mode == TracedToTypes + push!(path, RT) + seen[prev] = VisitedObject(length(seen) + 1) + end TT = traced_type(RT, Val(mode), track_numbers) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) if TT === Module || TT === String + if mode == TracedToTypes + push!(path, prev) + return nothing + end return prev end @@ -678,16 +696,10 @@ function make_tracer( changed = false for i in 1:nf if isdefined(prev, i) + newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) xi2 = make_tracer( - seen, - xi, - append_path(path, i), - mode; - toscalar, - tobatch, - track_numbers, - kwargs..., + seen, xi, newpath, mode; toscalar, tobatch, track_numbers, kwargs... ) if xi !== xi2 changed = true @@ -703,6 +715,10 @@ function make_tracer( end if nf == 0 + if mode == TracedToTypes + push!(path, prev) + return nothing + end return prev end @@ -710,16 +726,10 @@ function make_tracer( changed = false for i in 1:nf if isdefined(prev, i) + newpath = mode == TracedToTypes ? path : append_path(path, i) xi = Base.getfield(prev, i) xi2 = make_tracer( - seen, - xi, - append_path(path, i), - mode; - toscalar, - tobatch, - track_numbers, - kwargs..., + seen, xi, newpath, mode; toscalar, tobatch, track_numbers, kwargs... ) if xi !== xi2 changed = true @@ -730,6 +740,9 @@ function make_tracer( break end end + if mode == TracedToTypes + return nothing + end if !changed seen[prev] = prev return prev @@ -742,6 +755,9 @@ end function make_tracer( seen, @nospecialize(prev::ConcreteRArray{T,N}), @nospecialize(path), mode; kwargs... ) where {T,N} + if mode == TracedToTypes + throw("Cannot have ConcreteRArray as function call argument.") + end if mode == ArrayToConcrete return prev end @@ -760,6 +776,9 @@ end function make_tracer( seen, prev::ConcreteRNumber{T}, @nospecialize(path), mode; kwargs... ) where {T} + if mode == TracedToTypes + throw("Cannot have ConcreteRNumber as function call argument.") + end if mode == ArrayToConcrete return prev end @@ -786,6 +805,10 @@ function make_tracer( if mode == ConcreteToTraced throw("Cannot trace existing trace type") end + if mode == TracedToTypes + push!(path, MLIR.IR.type(prev.mlir_data)) + return nothing + end if mode == TracedTrack TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) @@ -839,6 +862,10 @@ function make_tracer( if mode == ConcreteToTraced throw("Cannot trace existing trace type") end + if mode == TracedToTypes + push!(path, MLIR.IR.type(prev.mlir_data)) + return nothing + end if mode == TracedTrack TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) @@ -886,6 +913,9 @@ function make_tracer( if mode == ConcreteToTraced throw("Cannot trace existing trace type") end + if mode == TracedToTypes + throw("Cannot have MissingTracedValue as function call argument.") + end if mode == TracedTrack TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) @@ -921,6 +951,10 @@ function make_tracer( @nospecialize(track_numbers::Type = Union{}), kwargs..., ) + if mode == TracedToTypes + push!(path, prev) + return nothing + end RT = Core.Typeof(prev) if RT <: track_numbers if mode == ArrayToConcrete @@ -949,8 +983,20 @@ function make_tracer( return prev end -make_tracer(seen, @nospecialize(prev::Type), @nospecialize(path), mode; kwargs...) = prev -make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev +function make_tracer(seen, @nospecialize(prev::Type), @nospecialize(path), mode; kwargs...) + if mode == TracedToTypes + push!(path, prev) + return nothing + end + return prev +end +function make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) + if mode == TracedToTypes + push!(path, prev) + return nothing + end + return prev +end function make_tracer( seen, @@ -961,6 +1007,12 @@ function make_tracer( tobatch=nothing, kwargs..., ) + if mode == TracedToTypes + push!(path, Core.Typeof(prev)) + make_tracer(seen, prev.re, path, mode; toscalar, tobatch, kwargs...) + make_tracer(seen, prev.im, path, mode; toscalar, tobatch, kwargs...) + return nothing + end return Complex( make_tracer( seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs... @@ -976,15 +1028,35 @@ function make_tracer( @nospecialize(prev::Array), @nospecialize(path), mode; - @nospecialize(track_numbers::Type = Union{}), + track_numbers::Type=Union{}, kwargs..., ) RT = Core.Typeof(prev) if mode != NoStopTracedTrack && haskey(seen, prev) + if mode == TracedToTypes + visited = seen[prev] + push!(path, visited) + return nothing + end return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive - return seen[prev] = ConcreteRArray(prev) + if eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete && return seen[prev] = ConcreteRArray(prev) + elseif mode == TracedToTypes + # Original array can get mutated so we store a copy: + push!(path, copy(prev)) + seen[prev] = VisitedObject(length(seen) + 1) + return nothing + end + elseif mode == TracedToTypes + push!(path, RT) + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + make_tracer(seen, pv, path, mode; track_numbers, kwargs...) + end + end + return nothing end TT = traced_type(eltype(RT), Val(mode), track_numbers) newa = Array{TT,ndims(RT)}(undef, size(prev)) @@ -1008,6 +1080,14 @@ function make_tracer( end function make_tracer(seen, @nospecialize(prev::Tuple), @nospecialize(path), mode; kwargs...) + RT = Core.Typeof(prev) + if mode == TracedToTypes + push!(path, RT) + for v in prev + make_tracer(seen, v, path, mode; kwargs...) + end + return nothing + end return ( ( make_tracer(seen, v, append_path(path, i), mode; kwargs...) for @@ -1027,6 +1107,14 @@ function make_tracer( NT = Core.Typeof(prev) A = NT.parameters[1] RT = NT.parameters[2] + + if mode == TracedToTypes + push!(path, NT) + for i in 1:length(A) + make_tracer(seen, Base.getfield(prev, i), path, mode; track_numbers, kwargs...) + end + return nothing + end return NamedTuple{A,traced_type(RT, Val(mode), track_numbers)}(( ( make_tracer( @@ -1042,6 +1130,10 @@ function make_tracer( end function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) + if mode == TracedToTypes + push!(path, Core.Box) + return make_tracer(seen, prev.contents, path, mode; kwargs...) + end if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end diff --git a/src/utils.jl b/src/utils.jl index 4b909757b..36c4ca7fa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -151,7 +151,10 @@ function should_rewrite_call(@nospecialize(ft)) if ft <: Type{<:TracedRArray} || ft <: Type{<:TracedRNumber} || ft === Type{MLIR.IR.Location} || - ft === Type{MLIR.IR.Block} + ft === Type{MLIR.IR.Block} || + # TODO: perhaps problematic calls in `traced_call` + # should be moved to TracedUtils.jl: + ft <: typeof(Reactant.ReactantCore.traced_call) return false end diff --git a/test/control_flow.jl b/test/control_flow.jl index dae8aeb8a..e59162875 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -626,3 +626,113 @@ end @test @jit(for_with_named_tuple(x_ra)) ≈ for_with_named_tuple(x) end + +_call1(a, b) = a +function call1(a, b) + x = @trace _call1(a, b) + y = @trace _call1(a, b) + return @trace _call1(x, y) +end + +@testset "call: basic" begin + a = rand(2, 3) + b = rand(2, 3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit(call1(a_ra, b_ra)) ≈ call1(a, b) + + # check whether the func for _call1 was only generated once: + ir = @code_hlo optimize = false call1(a_ra, b_ra) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 2 # call1, _call1 + + # With different operand sizes, different functions need to be generated: + c = rand(4, 5) + c_ra = Reactant.to_rarray(c) + + @test @jit(call1(a_ra, c_ra)) ≈ call1(a, c) + ir = @code_hlo optimize = false call1(a_ra, c_ra) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 3 +end + +_call2(a) = a + a +function call2(a) + return @trace _call2(a) +end + +@testset "call: rnumber" begin + a = 10 + a_rn = Reactant.ConcreteRNumber(a) + + @test @jit(call2(a_rn)) == call2(a) +end + +function _call3(x::Int, y) + if x > 10 + return y .+ y + else + return y .* y + end +end + +function call3(y) + z = @trace _call3(1, y) + @trace _call3(1, z) # doesn't generate new function because y.shape == z.shape + @trace _call3(11, y) # new function because x changed. +end + +@testset "call: caching for Julia operands" begin + y = rand(3) + y_ra = Reactant.to_rarray(y) + + ir = @code_hlo optimize = false call3(y_ra) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 5 # call3, .+, .*, _call3 (2X) +end + +struct Foo + x +end +struct Bar + x +end + +_call4(foobar::Union{Foo,Bar}) = foobar.x +function call4(foo, foo2, bar) + @trace _call4(foo) + @trace _call4(foo2) + @trace _call4(bar) +end + +@testset "call: Caching struct arguments" begin + a = rand(10) + b = rand(10) + foo = Foo(Reactant.to_rarray(a)) + foo2 = Foo(Reactant.to_rarray(b)) + bar = Foo(Bar(Reactant.to_rarray(b))) # typeof(foo) == typeof(bar), but these don't match! + ir = @code_hlo optimize = false call4(foo, foo2, bar) + ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))] + @test length(ops) == 3 # call4, _call4 for {foo, foo2}, and _call4 for bar +end + +function _call5!(a, b) + @allowscalar a[1] = zero(eltype(a)) + return b +end + +function call5!(a, b) + @trace _call5!(a, b) + return a +end + +@testset "call: argument mutation" begin + a = ones(3) + b = ones(3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + @jit call5!(a_ra, b_ra) + call5!(a, b) + @test a_ra == a +end