149 changes: 147 additions & 2 deletions ext/ReactantCUDAExt.jl
@@ -297,7 +297,7 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...)

    # figure out the optimal workgroupsize automatically
    if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
        if !Reactant.Compiler.PartitionKA[]
        if !Reactant.Compiler.PartitionKA[] || Reactant.Compiler.Raise[]
            threads = prod(ndrange)
        else
            config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))
@@ -459,6 +459,145 @@ function vendored_optimize_module!(
    end
end

function vendored_buildEarlyOptimizerPipeline(mpm, @nospecialize(job), opt_level; instcombine=false)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
function vendored_buildEarlyOptimizerPipeline(
    mpm, @nospecialize(job), opt_level; instcombine=false
)

    LLVM.add!(mpm, LLVM.NewPMCGSCCPassManager()) do cgpm
        # TODO invokeCGSCCCallbacks
        LLVM.add!(cgpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LLVM.Interop.AllocOptPass())
            LLVM.add!(fpm, LLVM.Float2IntPass())
            LLVM.add!(fpm, LLVM.LowerConstantIntrinsicsPass())
        end
    end
    LLVM.add!(mpm, GPULowerCPUFeaturesPass())
    if opt_level >= 1
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            if opt_level >= 2
                LLVM.add!(fpm, LLVM.SROAPass())
                if instcombine
                    LLVM.add!(fpm, LLVM.InstCombinePass())
                else
                    LLVM.add!(fpm, LLVM.InstSimplifyPass())
                end
                LLVM.add!(fpm, LLVM.JumpThreadingPass())
                LLVM.add!(fpm, LLVM.CorrelatedValuePropagationPass())
                LLVM.add!(fpm, LLVM.ReassociatePass())
                LLVM.add!(fpm, LLVM.EarlyCSEPass())
                LLVM.add!(fpm, LLVM.Interop.AllocOptPass())
            else
                if instcombine
                    LLVM.add!(fpm, LLVM.InstCombinePass())
                else
                    LLVM.add!(fpm, LLVM.InstSimplifyPass())
                end
                LLVM.add!(fpm, LLVM.EarlyCSEPass())
            end
        end
        # TODO invokePeepholeCallbacks
    end
end

function vendored_buildIntrinsicLoweringPipeline(mpm, @nospecialize(job), opt_level; instcombine::Bool=false)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
function vendored_buildIntrinsicLoweringPipeline(
    mpm, @nospecialize(job), opt_level; instcombine::Bool=false
)

    GPUCompiler.add!(mpm, LLVM.Interop.RemoveNIPass())

    # lower GC intrinsics
    if !GPUCompiler.uses_julia_runtime(job)
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, GPULowerGCFramePass())
        end
    end

    # lower kernel state intrinsics
    # NOTE: we can only do so here, as GC lowering can introduce calls to the runtime,
    # and thus additional uses of the kernel state intrinsics.
    if job.config.kernel
        # TODO: now that all kernel state-related passes are being run here, merge some?
        LLVM.add!(mpm, AddKernelStatePass())
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LowerKernelStatePass())
        end
        LLVM.add!(mpm, CleanupKernelStatePass())
    end

    if !GPUCompiler.uses_julia_runtime(job)
        # remove dead uses of ptls
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LLVM.ADCEPass())
        end
        LLVM.add!(mpm, GPULowerPTLSPass())
    end

    LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
        # lower exception handling
        if GPUCompiler.uses_julia_runtime(job)
            LLVM.add!(fpm, LLVM.Interop.LowerExcHandlersPass())
        end
        LLVM.add!(fpm, GPUCompiler.GCInvariantVerifierPass())
        LLVM.add!(fpm, LLVM.Interop.LateLowerGCPass())
        if GPUCompiler.uses_julia_runtime(job) && VERSION >= v"1.11.0-DEV.208"
            LLVM.add!(fpm, LLVM.Interop.FinalLowerGCPass())
        end
    end
    if GPUCompiler.uses_julia_runtime(job) && VERSION < v"1.11.0-DEV.208"
        LLVM.add!(mpm, LLVM.Interop.FinalLowerGCPass())
    end

    if opt_level >= 2
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            LLVM.add!(fpm, LLVM.GVNPass())
            LLVM.add!(fpm, LLVM.SCCPPass())
            LLVM.add!(fpm, LLVM.DCEPass())
        end
    end

    # lower PTLS intrinsics
    if GPUCompiler.uses_julia_runtime(job)
        LLVM.add!(mpm, LLVM.Interop.LowerPTLSPass())
    end

    if opt_level >= 1
        LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
            if instcombine
                LLVM.add!(fpm, LLVM.InstCombinePass())
            else
                LLVM.add!(fpm, LLVM.InstSimplifyPass())
            end
            LLVM.add!(fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
            LLVM.add!(
                fpm, LLVM.SimplifyCFGPass(; GPUCompiler.AggressiveSimplifyCFGOptions...)
            )

        end
    end

    # remove Julia address spaces
    LLVM.add!(mpm, LLVM.Interop.RemoveJuliaAddrspacesPass())

    # Julia's operand bundles confuse the inliner, so repeat here now they are gone.
    # FIXME: we should fix the inliner so that inlined code gets optimized early-on
    LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
    return LLVM.add!(mpm, LLVM.AlwaysInlinerPass())

end

function vendored_buildNewPMPipeline!(mpm, @nospecialize(job), opt_level)
    # Doesn't call instcombine
    GPUCompiler.buildEarlySimplificationPipeline(mpm, job, opt_level)
    LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
    vendored_buildEarlyOptimizerPipeline(mpm, job, opt_level)
    LLVM.add!(mpm, LLVM.NewPMFunctionPassManager()) do fpm
        # Doesn't call instcombine
        GPUCompiler.buildLoopOptimizerPipeline(fpm, job, opt_level)
        # Doesn't call instcombine
        GPUCompiler.buildScalarOptimizerPipeline(fpm, job, opt_level)
        if GPUCompiler.uses_julia_runtime(job) && opt_level >= 2
            # XXX: we disable vectorization, as this generally isn't useful for GPU targets
            # and actually causes issues with some back-end compilers (like Metal).
            # TODO: make this not dependent on `uses_julia_runtime` (likely CPU), but its own control
            # Doesn't call instcombine
            GPUCompiler.buildVectorPipeline(fpm, job, opt_level)
        end
        # if isdebug(:optim)
        #     add!(fpm, WarnMissedTransformationsPass())
        # end
    end
    vendored_buildIntrinsicLoweringPipeline(mpm, job, opt_level)
    GPUCompiler.buildCleanupPipeline(mpm, job, opt_level)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
    return GPUCompiler.buildCleanupPipeline(mpm, job, opt_level)

end

# compile to executable machine code
function compile(job)
    # lower to PTX
@@ -495,11 +634,17 @@ function compile(job)
        LLVM.register!(pb, CleanupKernelStatePass())

        LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
            GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level)
            vendored_buildNewPMPipeline!(mpm, job, opt_level)
        end
        LLVM.run!(pb, mod, tm)
    end
    if Reactant.Compiler.DUMP_LLVMIR[]
        println("cuda.jl pre vendor IR\n", string(mod))
    end
    vendored_optimize_module!(job, mod)
    if Reactant.Compiler.DUMP_LLVMIR[]
        println("cuda.jl post vendor IR\n", string(mod))
    end
    LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)

    for fname in ("gpu_report_exception", "gpu_signal_exception")
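For context on how these ReactantCUDAExt.jl changes surface to users, here is a minimal usage sketch of the two flags the new code consults, `Reactant.Compiler.DUMP_LLVMIR` and `Reactant.Compiler.Raise`. The toy kernel, the `@cuda` launch, and the `@compile` call are illustrative assumptions rather than part of this PR:

    using Reactant, CUDA

    # A toy kernel; any CUDA.jl kernel reached from a traced function would do.
    function square!(x)
        i = threadIdx().x
        @inbounds x[i] = x[i] * x[i]
        return nothing
    end

    launch!(x) = (@cuda threads=length(x) square!(x); x)

    x = Reactant.to_rarray(ones(Float32, 32))

    # Flags consulted by the vendored pipeline above (both default to false).
    Reactant.Compiler.DUMP_LLVMIR[] = true  # prints "cuda.jl pre/post vendor IR" during compilation
    Reactant.Compiler.Raise[] = true        # selects lower-kernel{backend=cpu} plus the raising passes

    compiled = Reactant.@compile launch!(x)
    compiled(x)

Both flags are read at compile time (see `compile(job)` above and `compile_mlir!` below), so they must be set before the kernel is first compiled.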
80 changes: 74 additions & 6 deletions src/Compiler.jl
[JuliaFormatter] reported by reviewdog 🐶

Reactant.jl/src/Compiler.jl, lines 780 to 782 in 4e39a42:

        run_pass_pipeline!(
            mod, "canonicalize"
        )

@@ -416,10 +416,12 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
    if sroa
        push!(passes, "propagate-constant-bounds")
        if DUMP_LLVMIR[]
            push!(passes, "sroa-wrappers{dump_prellvm=true dump_postllvm=true}")
            push!(passes, "sroa-wrappers{dump_prellvm=true dump_postllvm=true instcombine=false instsimplify=true}")
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
            push!(
                passes,
                "sroa-wrappers{dump_prellvm=true dump_postllvm=true instcombine=false instsimplify=true}",
            )

        else
            push!(passes, "sroa-wrappers")
            push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}")
        end
        push!(passes, "canonicalize")
        push!(passes, "sroa-wrappers{instcombine=false instsimplify=true}")
        push!(passes, "libdevice-funcs-raise")
        push!(passes, "canonicalize")
        push!(passes, "remove-duplicate-func-def")
@@ -556,6 +558,9 @@ end
const DEBUG_KERNEL = Ref{Bool}(false)
const DUMP_LLVMIR = Ref{Bool}(false)


[JuliaFormatter] reported by reviewdog 🐶

Suggested change: remove the extra blank line.
const Raise = Ref{Bool}(false)

function compile_mlir!(
    mod,
    f,
@@ -605,16 +610,33 @@ function compile_mlir!(
    end

    if backend == "cpu"
        kern = "lower-kernel{backend=cpu},canonicalize,lower-jit{openmp=true backend=cpu},symbol-dce"
        kern = "lower-kernel{backend=cpu},canonicalize"
        jit = "lower-jit{openmp=true backend=cpu},symbol-dce"
    elseif DEBUG_KERNEL[]
        curesulthandler = dlsym(
            Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
        )
        @assert curesulthandler !== nothing
        curesulthandler = Base.reinterpret(UInt, curesulthandler)
        kern = "lower-kernel,canonicalize,lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
        kern = if Raise[]
            "lower-kernel{backend=cpu},canonicalize"
        else
            "lower-kernel,canonicalize"
        end
        jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
    else
        kern = "lower-kernel,canonicalize,lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
        kern = if Raise[]
            "lower-kernel{backend=cpu},canonicalize"
        else
            "lower-kernel,canonicalize"
        end
        jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
    end

    raise = if Raise[]
        "convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,llvm-to-affine-access,canonicalize"
    else
        "canonicalize"
    end

    opt_passes = optimization_passes(; no_nan, sroa=true)
@@ -634,6 +656,8 @@ function compile_mlir!(
"enzyme-simplify-math",
opt_passes2,
kern,
raise,
jit
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
                    jit,

                ],
                ',',
            ),
@@ -655,6 +679,43 @@ function compile_mlir!(
                ',',
            ),
        )
    elseif optimize === :before_jit
        run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
        run_pass_pipeline!(
            mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
        )
        run_pass_pipeline!(
            mod,
            join(
                [
                    "canonicalize",
                    "remove-unnecessary-enzyme-ops",
                    "enzyme-simplify-math",
                    opt_passes2,
                    kern,
                    raise,
                ],
                ',',
            ),
        )
    elseif optimize === :before_raise
        run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
        run_pass_pipeline!(
            mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
        )
        run_pass_pipeline!(
            mod,
            join(
                [
                    "canonicalize",
                    "remove-unnecessary-enzyme-ops",
                    "enzyme-simplify-math",
                    opt_passes2,
                    kern
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
                    kern,

                ],
                ',',
            ),
        )
    elseif optimize === :no_enzyme
        run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
        run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
@@ -696,6 +757,8 @@ function compile_mlir!(
"enzyme-simplify-math",
opt_passes2,
kern,
raise,
jit
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
                    jit,

                ],
                ',',
            ),
@@ -706,7 +769,12 @@ function compile_mlir!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
mod, join([
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
kern,
raise,
jit
], ',')
Comment on lines +772 to +777
[JuliaFormatter] reported by reviewdog 🐶

Suggested change:
            mod,
            join(
                [
                    "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
                    kern,
                    raise,
                    jit,
                ],
                ',',
            ),

        )
    elseif optimize === :canonicalize
        run_pass_pipeline!(
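Since the same `kern` / `raise` / `jit` selection recurs across the `optimize` branches above, a standalone sketch of how the fragments combine may help. This is plain Julia for illustration only; the function name, the trimmed `jit` options, and the hard-coded toolkit path are stand-ins, not Reactant APIs:

    # Mirrors the selection logic in compile_mlir!: RAISE[] swaps the direct CUDA
    # lowering for a CPU-backend lowering followed by raising to affine/SCF.
    const RAISE = Ref(false)

    function kernel_pipelines(backend::String; toolkit::String="/opt/cuda")
        if backend == "cpu"
            kern = "lower-kernel{backend=cpu},canonicalize"
            jit = "lower-jit{openmp=true backend=cpu},symbol-dce"
        else
            kern = RAISE[] ? "lower-kernel{backend=cpu},canonicalize" : "lower-kernel,canonicalize"
            jit = "lower-jit{run_init=true toolkitPath=$(toolkit)},symbol-dce"  # options trimmed
        end
        raise = if RAISE[]
            "convert-llvm-to-cf,canonicalize,enzyme-lift-cf-to-scf,llvm-to-affine-access,canonicalize"
        else
            "canonicalize"
        end
        # The post-Enzyme cleanup pipeline is then assembled as in compile_mlir!:
        return join(
            ["canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", kern, raise, jit],
            ',',
        )
    end

    kernel_pipelines("cuda")                  # direct CUDA lowering; raise is just "canonicalize"
    RAISE[] = true; kernel_pipelines("cuda")  # CPU-backend lowering plus the affine raising passes

With the raising path enabled, the new `:before_raise` and `:before_jit` stopping points added above let the module be inspected right after `kern` (`:before_raise`) or after `kern` and `raise` but before `lower-jit` (`:before_jit`).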