Skip to content

Commit 9af8d87

Browse files
giordanowsmoses
andauthored
[Compiler] Make raise a keyword argument (#797)
* [Compiler] Make `raise` a boolean keyword argument * [Compiler] Allow `raise` to be a string too * Use `task_local_storage` to set whether we're raising or not * Add super basic tests for `raise` keyword argument * [Compiler] Make a stack of raising contexts --------- Co-authored-by: William Moses <wmoses@google.com>
1 parent 2e138de commit 9af8d87

File tree

3 files changed

+80
-17
lines changed

3 files changed

+80
-17
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ReactantCUDAExt
33
using CUDA
44
using Reactant:
55
Reactant, TracedRArray, AnyTracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber
6+
using Reactant.Compiler: raising
67
using ReactantCore: @trace
78
using GPUCompiler: GPUCompiler
89
using KernelAbstractions: KernelAbstractions
@@ -294,7 +295,7 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...)
294295

295296
# figure out the optimal workgroupsize automatically
296297
if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
297-
if !Reactant.Compiler.PartitionKA[] || Reactant.Compiler.Raise[]
298+
if !Reactant.Compiler.PartitionKA[] || raising()
298299
threads = prod(ndrange)
299300
else
300301
config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))

src/Compiler.jl

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,37 @@ end
578578
const DEBUG_KERNEL = Ref{Bool}(false)
579579
const DUMP_LLVMIR = Ref{Bool}(false)
580580

581-
const Raise = Ref{Bool}(false)
581+
function activate_raising!(is_raising::Bool)
582+
stack = get!(task_local_storage(), :reactant_is_raising) do
583+
Bool[]
584+
end
585+
push!(stack, is_raising)
586+
return nothing
587+
end
588+
589+
function deactivate_raising!(is_raising::Bool)
590+
key = :reactant_is_raising
591+
is_raising === last(task_local_storage(key)) ||
592+
error("Deactivating wrong Reactant raising context")
593+
return pop!(task_local_storage(key))
594+
end
595+
596+
function raising(; throw_error::Bool=true)
597+
key = :reactant_is_raising
598+
if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key)))
599+
throw_error && error("No Reactant raising context")
600+
end
601+
return last(task_local_storage(key)::Vector{Bool})
602+
end
603+
604+
function raising!(f, is_raising::Bool)
605+
activate_raising!(is_raising)
606+
try
607+
return f()
608+
finally
609+
deactivate_raising!(is_raising)
610+
end
611+
end
582612

583613
function compile_mlir!(
584614
mod,
@@ -605,6 +635,7 @@ function compile_mlir!(
605635
no_nan::Bool=false,
606636
backend="gpu",
607637
fn_kwargs=(),
638+
raise::Union{Bool,String}=false,
608639
)
609640
# Explicitly don't use block! to avoid creating a closure, which creates
610641
# both compile-time and relocatability issues
@@ -614,9 +645,16 @@ function compile_mlir!(
614645
activate_callcache!(callcache)
615646
activate_sdycache!(sdycache)
616647

648+
# Save in the TLS whether we are raising. We identify that condition by
649+
# checking whether the user set an explicit list of passes, or chose
650+
# `raise=true` to use the default passes.
651+
is_raising = raise isa String || raise
652+
activate_raising!(is_raising)
653+
617654
mlir_fn_res = try
618655
Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true)
619656
finally
657+
deactivate_raising!(is_raising)
620658
deactivate_sdycache!(sdycache)
621659
deactivate_callcache!(callcache)
622660
MLIR.IR.deactivate!(MLIR.IR.body(mod))
@@ -648,14 +686,14 @@ function compile_mlir!(
648686
)
649687
@assert curesulthandler !== nothing
650688
curesulthandler = Base.reinterpret(UInt, curesulthandler)
651-
kern = if Raise[]
689+
kern = if is_raising
652690
"lower-kernel{backend=cpu},symbol-dce,canonicalize"
653691
else
654692
"lower-kernel,canonicalize"
655693
end
656694
jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
657695
else
658-
kern = if Raise[]
696+
kern = if is_raising
659697
"lower-kernel{backend=cpu},symbol-dce,canonicalize"
660698
else
661699
"lower-kernel,canonicalize"
@@ -666,8 +704,13 @@ function compile_mlir!(
666704
opt_passes = optimization_passes(; no_nan, sroa=true)
667705
opt_passes2 = optimization_passes(; no_nan, sroa=false)
668706

669-
raise = if Raise[]
670-
"canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,raise-affine-to-stablehlo,arith-raise{stablehlo=true}," * opt_passes2
707+
raise_passes = if raise isa String
708+
# Raising passes were specified
709+
raise
710+
elseif raise
711+
# Raise enabled but use default passes
712+
"canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,raise-affine-to-stablehlo,arith-raise{stablehlo=true}," *
713+
opt_passes2
671714
else
672715
"canonicalize"
673716
end
@@ -686,7 +729,7 @@ function compile_mlir!(
686729
"enzyme-simplify-math",
687730
opt_passes2,
688731
kern,
689-
raise,
732+
raise_passes,
690733
jit,
691734
],
692735
',',
@@ -723,7 +766,7 @@ function compile_mlir!(
723766
"enzyme-simplify-math",
724767
opt_passes2,
725768
kern,
726-
raise,
769+
raise_passes,
727770
],
728771
',',
729772
),
@@ -787,7 +830,7 @@ function compile_mlir!(
787830
"enzyme-simplify-math",
788831
opt_passes2,
789832
kern,
790-
raise,
833+
raise_passes,
791834
jit,
792835
],
793836
',',
@@ -804,7 +847,7 @@ function compile_mlir!(
804847
[
805848
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
806849
kern,
807-
raise,
850+
raise_passes,
808851
jit,
809852
],
810853
',',
@@ -891,7 +934,7 @@ See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
891934
"""
892935
macro code_hlo(args...)
893936
default_options = Dict{Symbol,Any}(
894-
:optimize => true, :no_nan => false, :client => nothing
937+
:optimize => true, :no_nan => false, :client => nothing, :raise => false
895938
)
896939
compile_expr, (; compiled) = compile_call_expr(
897940
__module__, compile_mlir, default_options, args...
@@ -915,7 +958,7 @@ See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
915958
"""
916959
macro code_mhlo(args...)
917960
default_options = Dict{Symbol,Any}(
918-
:optimize => true, :no_nan => false, :client => nothing
961+
:optimize => true, :no_nan => false, :client => nothing, :raise => false
919962
)
920963
compile_expr, (; compiled) = compile_call_expr(
921964
__module__, compile_xla, default_options, args...
@@ -939,7 +982,7 @@ See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
939982
"""
940983
macro code_xla(args...)
941984
default_options = Dict{Symbol,Any}(
942-
:optimize => true, :no_nan => false, :client => nothing
985+
:optimize => true, :no_nan => false, :client => nothing, :raise => false
943986
)
944987
compile_expr, (; compiled) = compile_call_expr(
945988
__module__, compile_xla, default_options, args...
@@ -961,7 +1004,11 @@ end
9611004
"""
9621005
macro compile(args...)
9631006
default_options = Dict{Symbol,Any}(
964-
:optimize => true, :sync => false, :no_nan => false, :client => nothing
1007+
:optimize => true,
1008+
:sync => false,
1009+
:no_nan => false,
1010+
:client => nothing,
1011+
:raise => false,
9651012
)
9661013
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
9671014
end
@@ -973,7 +1020,11 @@ Run @compile f(args..) then immediately execute it
9731020
"""
9741021
macro jit(args...)
9751022
default_options = Dict{Symbol,Any}(
976-
:optimize => true, :sync => false, :no_nan => false, :client => nothing
1023+
:optimize => true,
1024+
:sync => false,
1025+
:no_nan => false,
1026+
:client => nothing,
1027+
:raise => false,
9771028
)
9781029
compile_expr, (; compiled, args) = compile_call_expr(
9791030
__module__, compile, default_options, args...
@@ -988,14 +1039,14 @@ macro jit(args...)
9881039
#! format: on
9891040
end
9901041

991-
function compile_call_expr(mod, compiler, options, args...)
1042+
function compile_call_expr(mod, compiler, options::Dict, args...)
9921043
while length(args) > 1
9931044
option, args = args[1], args[2:end]
9941045
if !Meta.isexpr(option, :(=))
9951046
error("Invalid option $(option)")
9961047
else
9971048
option_name = option.args[1]
998-
@assert haskey(options, option_name) "Invalid option $(option_name)"
1049+
@assert haskey(options, option_name) "Invalid option name '$(option_name)'. Valid options are $(join(keys(options), ", "))"
9991050
options[option_name] = option.args[2]
10001051
end
10011052
end

test/basic.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,3 +920,14 @@ end
920920
@test contains(hlo, "HloModule")
921921
@test contains(hlo, "sine")
922922
end
923+
924+
@testset "Raise keyword" begin
925+
v = randn(Float32, 16)
926+
rv = Reactant.to_rarray(v)
927+
@test sin.(v) @jit raise = true sin.(rv)
928+
@test cos.(v) @jit raise = false cos.(rv)
929+
@test exp.(v) @jit raise = "canonicalize" exp.(rv)
930+
@test_throws Reactant.MLIR.IR.AddPipelineException @jit raise = "this_pass-does_not_ExisT" exp.(
931+
rv
932+
)
933+
end

0 commit comments

Comments
 (0)