CUDA kernels take 3 (EnzymeAD#427)
* CUDA take 3

* conditional run cuda

* Update test/integration/cuda.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* bump enzymexla

* fix

* fix gpu reg

* Update BUILD

* Update BUILD

* Update Project.toml

* Update ReactantCUDAExt.jl

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix reactant method blocker

* Update ReactantCUDAExt.jl

* only do compile

* use names in cache

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* cleanup further gc issues

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
wsmoses and github-actions[bot] authored Dec 29, 2024
1 parent b0a58bd commit d4e7c76
Showing 11 changed files with 88 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -60,7 +60,7 @@
 PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.3"
-Reactant_jll = "0.0.32"
+Reactant_jll = "0.0.33"
 Scratch = "1.2"
 Statistics = "1.10"
 YaoBlocks = "0.13"
4 changes: 2 additions & 2 deletions deps/ReactantExtra/.bazelrc
@@ -18,8 +18,8 @@
 build -c opt
 build:cuda --repo_env TF_NEED_CUDA=1
 build:cuda --repo_env TF_NVCC_CLANG=1
 build:cuda --repo_env TF_NCCL_USE_STUB=1
-build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
-build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
+build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2"
+build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.6.0"
 # "sm" means we emit only cubin, which is forward compatible within a GPU generation.
 # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
 build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
33 changes: 20 additions & 13 deletions deps/ReactantExtra/BUILD
@@ -54,6 +54,9 @@ cc_toolchain_config(
     coverage_link_flags = ["--coverage"],
     cpu = "k8",
     cxx_builtin_include_directories = [
+        "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0",
+        "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/x86_64-linux-musl",
+        "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/backward",
         "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0",
         "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/x86_64-linux-musl",
         "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/backward",
@@ -149,14 +152,14 @@ cc_toolchain_config(
     abi_libc_version = "local",
     abi_version = "local",
     cxx_builtin_include_directories = [
-        "/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
-        "/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
+        "/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
+        "/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
         "/opt/BB_TARGET/BB_TARGET/include",
         "/opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
-        "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
-        "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
-        "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
-        "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel"
+        "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
+        "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
+        "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
+        "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel"
     ],
     tool_paths = {
         "ar": "/opt/bin/BB_FULL_TARGET/ar",
@@ -193,14 +196,14 @@ cc_toolchain_config(
         "-Wno-free-nonheap-object",
         "-fno-omit-frame-pointer",
         # TODO cxx_builtin_include_directories doesn't seem to be working, so we add the INCLUDE_PATHs manually
-        "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
-        "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
+        "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
+        "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
         "-isystem /opt/BB_TARGET/BB_TARGET/include",
         "-isystem /opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
-        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
-        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
-        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
-        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel",
+        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
+        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
+        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
+        "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel",
     ],
     opt_compile_flags = [
         "-g0",
@@ -361,6 +364,7 @@ cc_library(
 
     ) + [
         "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
+        "@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
         # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
         # "@xla//xla:xla.pb.cc",
         "@xla//xla:xla_data.pb.cc",
@@ -429,7 +433,7 @@ cc_library(
         "-Wl,-exported_symbol,_ifrt_*",
         "-Wl,-exported_symbol,_RegisterCustomCallTarget",
         "-Wl,-exported_symbol,_ConvertLLVMToMLIR",
-        "-Wl,-exported_symbol,_EnzymeGPUCustomCall",
+        "-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
         "-Wl,-exported_symbol,_ReactantThrowError",
     ]}),
     deps = [
@@ -469,6 +473,9 @@ cc_library(
         "@llvm-project//llvm:X86CodeGen",
         "@enzyme_ad//src/enzyme_ad/jax:TransformOps",
         "@enzyme_ad//src/enzyme_ad/jax:XLADerivatives",
+        # "@enzyme_ad//src/enzyme_ad/jax:gpu",
+        "@xla//xla/ffi/api:ffi",
+        "@xla//xla/ffi:ffi_api",
         "@stablehlo//:chlo_ops",
         "@xla//xla/pjrt:pjrt_api",
         "@xla//xla/pjrt:pjrt_c_api_client",
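
The exported-symbol swap above pairs with the runtime change in src/XLA.jl further down: instead of Julia registering the lone EnzymeGPUCustomCall target by function pointer, the shared library now exposes a single entry point that registers all EnzymeXLA GPU handlers, which the Julia side invokes as:

    @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

The new @xla//xla/ffi dependencies presumably supply the XLA FFI machinery that the handler-based registration is built against.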
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486"
+ENZYMEXLA_COMMIT = "74046d05089c02946058f8fd94ed23efd0bf3ccc"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(
2 changes: 2 additions & 0 deletions deps/build_local.jl
@@ -93,6 +93,8 @@ else
     run(
         Cmd(
             `bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
+            --repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc
+            --repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang
             --repo_env HERMETIC_PYTHON_VERSION="3.10"
             --check_visibility=false --verbose_failures :libReactantExtra.so`;
             dir=source_dir,
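
Note that the new `--repo_env=CC` flag bakes in one developer's local Clang path. A hedged sketch, not part of the commit, of how the same flags could be read from the environment so other machines can override them without editing the script:

    # Hypothetical override mechanism: fall back to generic system compilers
    # only when the caller has not set these environment variables.
    gcc_host = get(ENV, "GCC_HOST_COMPILER_PATH", "/usr/bin/gcc")
    clang_cc = get(ENV, "CC", "/usr/bin/clang")
    compiler_flags = `--repo_env=GCC_HOST_COMPILER_PATH=$(gcc_host) --repo_env=CC=$(clang_cc)`
    # `compiler_flags` would then be interpolated into the `bazel build` Cmd above.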
52 changes: 26 additions & 26 deletions ext/ReactantCUDAExt.jl
@@ -218,11 +218,11 @@ function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where
     return res
 end
 
-const _kernel_instances = Dict{Any,Any}()
-
+# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
+# and not the operation itself).
 struct LLVMFunc{F,tt}
     f::Union{F,Nothing}
-    entry::MLIR.IR.Operation
+    entry::String
 end
 
 const GPUCompiler = CUDA.GPUCompiler
@@ -324,9 +324,9 @@ function compile(job)
         )::MLIR.API.MlirOperation
 
-        entry = MLIR.IR.Operation(linkRes)
-
-        entry
+        String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name"))
     end
 
     return LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry)
 end

@@ -378,9 +378,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
 
     output_operand_aliases = MLIR.IR.Attribute(aliases)
 
-    fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name")
-    # Force public for now while we don't have real users
-    # MLIR.IR.rmattr!(func.entry, "sym_visibility")
+    fname = func.entry
 
     operands = MLIR.IR.Value[]
     for idx in
@@ -460,25 +458,27 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
 end
 
 function __init__()
-    handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
-    if handle === nothing
-        handle = C_NULL
-    end
-    ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
-    if ptr1 === nothing
-        ptr1 = C_NULL
-    end
-    ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
-    if ptr2 === nothing
-        ptr2 = C_NULL
-    end
-    ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
-    if ptr3 === nothing
-        ptr3 = C_NULL
-    end
-    Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
-    Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
-    Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
+    if CUDA.CUDA_Driver_jll.libcuda !== nothing
+        handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
+        if handle === nothing
+            handle = C_NULL
+        end
+        ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
+        if ptr1 === nothing
+            ptr1 = C_NULL
+        end
+        ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
+        if ptr2 === nothing
+            ptr2 = C_NULL
+        end
+        ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
+        if ptr3 === nothing
+            ptr3 = C_NULL
+        end
+        Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
+        Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
+        Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
+    end
     return nothing
 end

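Two threads run through this file's changes. First, the cache-safety fix ("use names in cache" and "cleanup further gc issues" in the commit message): a cached MLIR.IR.Operation can outlive the MLIR module and context that own it, so LLVMFunc now stores only the kernel's sym_name as a plain String, and the _kernel_instances dict is gone. A minimal sketch of the pattern, with an illustrative cache that is not the package's actual layout:

    # Illustrative only: cache plain data, never handles into MLIR IR.
    const kernel_names = Dict{DataType,String}()

    function remember_kernel!(key::DataType, op::MLIR.IR.Operation)
        # Read the symbol name while `op` is still alive; the String stays
        # valid after the owning module/context is destroyed.
        kernel_names[key] = String(Reactant.TracedUtils.get_attribute_by_name(op, "sym_name"))
    end

Second, __init__ now only dlopens the CUDA driver when CUDA_Driver_jll actually provides a libcuda, so loading the extension on a GPU-less machine skips the symbol lookups entirely.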
6 changes: 1 addition & 5 deletions src/XLA.jl
@@ -144,11 +144,7 @@ function __init__()
         end
     end
 
-    @ccall MLIR.API.mlir_c.RegisterCustomCallTarget(
-        "enzymexla_gpu"::Cstring,
-        cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid},
-        "CUDA"::Cstring,
-    )::Cvoid
+    @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
 
     # This wasn't properly exported on macos, we'll remove the try once macOS JLL
     # has the fix.
5 changes: 4 additions & 1 deletion src/utils.jl
@@ -364,7 +364,10 @@ function call_with_reactant_generator(
         ir, rt = CC.typeinf_ircode(interp, mi, nothing)
     end
 
-    ir, any_changed = rewrite_insts!(ir, interp)
+    if !is_reactant_method(mi::Core.MethodInstance)
+        ir, any_changed = rewrite_insts!(ir, interp)
+    end
+
     src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
     src.slotnames = fill(:none, length(ir.argtypes) + 1)
     src.slotflags = fill(zero(UInt8), length(ir.argtypes))
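
This is the "fix reactant method blocker" item from the commit message: the generator now leaves methods that Reactant itself defines unrewritten instead of running rewrite_insts! on their IR. A rough sketch of what such a predicate can look like (hypothetical; the real is_reactant_method lives elsewhere in Reactant and may use a different criterion):

    # Hypothetical sketch: a method counts as Reactant-owned when its defining
    # module sits somewhere under the Reactant module tree.
    function is_reactant_method_sketch(mi::Core.MethodInstance)
        mod = mi.def isa Method ? mi.def.module : mi.def
        while true
            mod === Reactant && return true
            parentmodule(mod) === mod && return false   # reached the root module
            mod = parentmodule(mod)
        end
    end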
36 changes: 0 additions & 36 deletions test/cuda.jl

This file was deleted.

29 changes: 29 additions & 0 deletions test/integration/cuda.jl
@@ -0,0 +1,29 @@
+using Reactant
+using Test
+using CUDA
+
+function square_kernel!(x, y)
+    i = threadIdx().x
+    x[i] *= y[i]
+    sync_threads()
+    return nothing
+end
+
+# basic squaring on GPU
+function square!(x, y)
+    @cuda blocks = 1 threads = length(x) square_kernel!(x, y)
+    return nothing
+end
+
+@testset "Square Kernel" begin
+    oA = collect(1:1:64)
+    A = Reactant.to_rarray(oA)
+    B = Reactant.to_rarray(100 .* oA)
+    if CUDA.functional()
+        @jit square!(A, B)
+        @test all(Array(A) .≈ (oA .* oA .* 100))
+        @test all(Array(B) .≈ (oA .* 100))
+    else
+        @code_hlo optimize = :before_kernel square!(A, B)
+    end
+end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -59,6 +59,8 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
 end
 
 if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
+    # Temporarily disabled as minutia are debugged
+    # @safetestset "CUDA" include("integration/cuda.jl")
     @safetestset "Linear Algebra" include("integration/linear_algebra.jl")
     @safetestset "AbstractFFTs" include("integration/fft.jl")
     @safetestset "Random" include("integration/random.jl")
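
The CUDA suite stays commented out here pending debugging; once re-enabled it will run under the integration group. A hedged sketch of selecting that group locally — the REACTANT_TEST_GROUP name comes from the diff above, while the Pkg invocation assumes a standard project setup:

    # Run only the integration tests (including CUDA, once un-commented).
    ENV["REACTANT_TEST_GROUP"] = "integration"
    using Pkg
    Pkg.test("Reactant")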
