578578const DEBUG_KERNEL = Ref {Bool} (false )
579579const DUMP_LLVMIR = Ref {Bool} (false )
580580
581- const Raise = Ref {Bool} (false )
581+ function activate_raising! (is_raising:: Bool )
582+ stack = get! (task_local_storage (), :reactant_is_raising ) do
583+ Bool[]
584+ end
585+ push! (stack, is_raising)
586+ return nothing
587+ end
588+
589+ function deactivate_raising! (is_raising:: Bool )
590+ key = :reactant_is_raising
591+ is_raising === last (task_local_storage (key)) ||
592+ error (" Deactivating wrong Reactant raising context" )
593+ return pop! (task_local_storage (key))
594+ end
595+
596+ function raising (; throw_error:: Bool = true )
597+ key = :reactant_is_raising
598+ if ! (haskey (task_local_storage (), key) && ! Base. isempty (task_local_storage (key)))
599+ throw_error && error (" No Reactant raising context" )
600+ end
601+ return last (task_local_storage (key):: Vector{Bool} )
602+ end
603+
604+ function raising! (f, is_raising:: Bool )
605+ activate_raising! (is_raising)
606+ try
607+ return f ()
608+ finally
609+ deactivate_raising! (is_raising)
610+ end
611+ end
582612
583613function compile_mlir! (
584614 mod,
@@ -605,6 +635,7 @@ function compile_mlir!(
605635 no_nan:: Bool = false ,
606636 backend= " gpu" ,
607637 fn_kwargs= (),
638+ raise:: Union{Bool,String} = false ,
608639)
609640 # Explicitly don't use block! to avoid creating a closure, which creates
610641 # both compile-time and relocatability issues
@@ -614,9 +645,16 @@ function compile_mlir!(
614645 activate_callcache! (callcache)
615646 activate_sdycache! (sdycache)
616647
648+ # Save in the TLS whether we are raising. We identify that condition by
649+ # checking whether the user set an explicit list of passes, or chose
650+ # `raise=true` to use the default passes.
651+ is_raising = raise isa String || raise
652+ activate_raising! (is_raising)
653+
617654 mlir_fn_res = try
618655 Reactant. TracedUtils. make_mlir_fn (f, args, fn_kwargs, " main" , true )
619656 finally
657+ deactivate_raising! (is_raising)
620658 deactivate_sdycache! (sdycache)
621659 deactivate_callcache! (callcache)
622660 MLIR. IR. deactivate! (MLIR. IR. body (mod))
@@ -648,14 +686,14 @@ function compile_mlir!(
648686 )
649687 @assert curesulthandler != = nothing
650688 curesulthandler = Base. reinterpret (UInt, curesulthandler)
651- kern = if Raise[]
689+ kern = if is_raising
652690 " lower-kernel{backend=cpu},symbol-dce,canonicalize"
653691 else
654692 " lower-kernel,canonicalize"
655693 end
656694 jit = " lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures ()) run_init=true toolkitPath=$toolkit },symbol-dce"
657695 else
658- kern = if Raise[]
696+ kern = if is_raising
659697 " lower-kernel{backend=cpu},symbol-dce,canonicalize"
660698 else
661699 " lower-kernel,canonicalize"
@@ -666,8 +704,13 @@ function compile_mlir!(
666704 opt_passes = optimization_passes (; no_nan, sroa= true )
667705 opt_passes2 = optimization_passes (; no_nan, sroa= false )
668706
669- raise = if Raise[]
670- " canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,raise-affine-to-stablehlo,arith-raise{stablehlo=true}," * opt_passes2
707+ raise_passes = if raise isa String
708+ # Raising passes were specified
709+ raise
710+ elseif raise
711+ # Raise enabled but use default passes
712+ " canonicalize,llvm-to-memref-access,canonicalize,convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,canonicalize,func.func(canonicalize-loops),canonicalize-scf-for,canonicalize,affine-cfg,canonicalize,func.func(canonicalize-loops),canonicalize,llvm-to-affine-access,canonicalize,delinearize-indexing,canonicalize,raise-affine-to-stablehlo,arith-raise{stablehlo=true}," *
713+ opt_passes2
671714 else
672715 " canonicalize"
673716 end
@@ -686,7 +729,7 @@ function compile_mlir!(
686729 " enzyme-simplify-math" ,
687730 opt_passes2,
688731 kern,
689- raise ,
732+ raise_passes ,
690733 jit,
691734 ],
692735 ' ,' ,
@@ -723,7 +766,7 @@ function compile_mlir!(
723766 " enzyme-simplify-math" ,
724767 opt_passes2,
725768 kern,
726- raise ,
769+ raise_passes ,
727770 ],
728771 ' ,' ,
729772 ),
@@ -787,7 +830,7 @@ function compile_mlir!(
787830 " enzyme-simplify-math" ,
788831 opt_passes2,
789832 kern,
790- raise ,
833+ raise_passes ,
791834 jit,
792835 ],
793836 ' ,' ,
@@ -804,7 +847,7 @@ function compile_mlir!(
804847 [
805848 " canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math" ,
806849 kern,
807- raise ,
850+ raise_passes ,
808851 jit,
809852 ],
810853 ' ,' ,
@@ -891,7 +934,7 @@ See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
891934"""
892935macro code_hlo (args... )
893936 default_options = Dict {Symbol,Any} (
894- :optimize => true , :no_nan => false , :client => nothing
937+ :optimize => true , :no_nan => false , :client => nothing , :raise => false
895938 )
896939 compile_expr, (; compiled) = compile_call_expr (
897940 __module__, compile_mlir, default_options, args...
@@ -915,7 +958,7 @@ See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
915958"""
916959macro code_mhlo (args... )
917960 default_options = Dict {Symbol,Any} (
918- :optimize => true , :no_nan => false , :client => nothing
961+ :optimize => true , :no_nan => false , :client => nothing , :raise => false
919962 )
920963 compile_expr, (; compiled) = compile_call_expr (
921964 __module__, compile_xla, default_options, args...
@@ -939,7 +982,7 @@ See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
939982"""
940983macro code_xla (args... )
941984 default_options = Dict {Symbol,Any} (
942- :optimize => true , :no_nan => false , :client => nothing
985+ :optimize => true , :no_nan => false , :client => nothing , :raise => false
943986 )
944987 compile_expr, (; compiled) = compile_call_expr (
945988 __module__, compile_xla, default_options, args...
@@ -961,7 +1004,11 @@ end
9611004"""
9621005macro compile (args... )
9631006 default_options = Dict {Symbol,Any} (
964- :optimize => true , :sync => false , :no_nan => false , :client => nothing
1007+ :optimize => true ,
1008+ :sync => false ,
1009+ :no_nan => false ,
1010+ :client => nothing ,
1011+ :raise => false ,
9651012 )
9661013 return esc (first (compile_call_expr (__module__, compile, default_options, args... )))
9671014end
@@ -973,7 +1020,11 @@ Run @compile f(args..) then immediately execute it
9731020"""
9741021macro jit (args... )
9751022 default_options = Dict {Symbol,Any} (
976- :optimize => true , :sync => false , :no_nan => false , :client => nothing
1023+ :optimize => true ,
1024+ :sync => false ,
1025+ :no_nan => false ,
1026+ :client => nothing ,
1027+ :raise => false ,
9771028 )
9781029 compile_expr, (; compiled, args) = compile_call_expr (
9791030 __module__, compile, default_options, args...
@@ -988,14 +1039,14 @@ macro jit(args...)
9881039 # ! format: on
9891040end
9901041
991- function compile_call_expr (mod, compiler, options, args... )
1042+ function compile_call_expr (mod, compiler, options:: Dict , args... )
9921043 while length (args) > 1
9931044 option, args = args[1 ], args[2 : end ]
9941045 if ! Meta. isexpr (option, :(= ))
9951046 error (" Invalid option $(option) " )
9961047 else
9971048 option_name = option. args[1 ]
998- @assert haskey (options, option_name) " Invalid option $(option_name) "
1049+ @assert haskey (options, option_name) " Invalid option name ' $(option_name) '. Valid options are $( join ( keys (options), " , " ) )"
9991050 options[option_name] = option. args[2 ]
10001051 end
10011052 end
0 commit comments