CUDA kernels take 3 #427

Merged 19 commits on Dec 29, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -60,7 +60,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.32"
Reactant_jll = "0.0.33"
Scratch = "1.2"
Statistics = "1.10"
YaoBlocks = "0.13"
4 changes: 2 additions & 2 deletions deps/ReactantExtra/.bazelrc
@@ -18,8 +18,8 @@ build -c opt
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NVCC_CLANG=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.6.0"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
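Editorial note (not part of the PR): the forward-compatibility behaviour described in the comment above can be sanity-checked from Julia with CUDA.jl's device queries. A device whose compute capability falls within one of the prebuilt sm_* generations can use that cubin directly, while newer devices rely on the driver JIT-compiling the embedded compute_90 PTX. A minimal sketch, assuming CUDA.jl is installed and a GPU is visible:

using CUDA

# Architectures baked into the hermetic build above:
# cubins for sm_50..sm_80, plus PTX for compute_90.
built_cubins = [v"5.0", v"6.0", v"7.0", v"8.0"]
ptx_arch = v"9.0"

if CUDA.functional()
    cap = CUDA.capability(CUDA.device())  # e.g. v"8.6" on an Ampere consumer GPU
    if any(c -> c.major == cap.major && c <= cap, built_cubins)
        @info "device generation is covered by a prebuilt cubin" cap
    elseif cap >= ptx_arch
        @info "device will run via driver JIT of the embedded PTX" cap
    else
        @warn "device is older than every architecture in this build" cap
    end
end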
33 changes: 20 additions & 13 deletions deps/ReactantExtra/BUILD
@@ -54,6 +54,9 @@ cc_toolchain_config(
coverage_link_flags = ["--coverage"],
cpu = "k8",
cxx_builtin_include_directories = [
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/x86_64-linux-musl",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/backward",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/x86_64-linux-musl",
"/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/backward",
@@ -149,14 +152,14 @@ cc_toolchain_config(
abi_libc_version = "local",
abi_version = "local",
cxx_builtin_include_directories = [
"/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
"/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
"/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
"/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
"/opt/BB_TARGET/BB_TARGET/include",
"/opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
"/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel"
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
"/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel"
],
tool_paths = {
"ar": "/opt/bin/BB_FULL_TARGET/ar",
@@ -193,14 +196,14 @@ cc_toolchain_config(
"-Wno-free-nonheap-object",
"-fno-omit-frame-pointer",
# TODO cxx_builtin_include_directories doesn't seem to be working, so we add the INCLUDE_PATHs manually
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include",
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed",
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include",
"-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed",
"-isystem /opt/BB_TARGET/BB_TARGET/include",
"-isystem /opt/BB_TARGET/BB_TARGET/sys-root/usr/include",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward",
"-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel",
],
opt_compile_flags = [
"-g0",
@@ -361,6 +364,7 @@ cc_library(

) + [
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
# "@xla//xla:xla.pb.cc",
"@xla//xla:xla_data.pb.cc",
@@ -429,7 +433,7 @@ cc_library(
"-Wl,-exported_symbol,_ifrt_*",
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
"-Wl,-exported_symbol,_EnzymeGPUCustomCall",
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
"-Wl,-exported_symbol,_ReactantThrowError",
]}),
deps = [
@@ -469,6 +473,9 @@ cc_library(
"@llvm-project//llvm:X86CodeGen",
"@enzyme_ad//src/enzyme_ad/jax:TransformOps",
"@enzyme_ad//src/enzyme_ad/jax:XLADerivatives",
# "@enzyme_ad//src/enzyme_ad/jax:gpu",
"@xla//xla/ffi/api:ffi",
"@xla//xla/ffi:ffi_api",
"@stablehlo//:chlo_ops",
"@xla//xla/pjrt:pjrt_api",
"@xla//xla/pjrt:pjrt_c_api_client",
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
@@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486"
ENZYMEXLA_COMMIT = "74046d05089c02946058f8fd94ed23efd0bf3ccc"
ENZYMEXLA_SHA256 = ""

http_archive(
2 changes: 2 additions & 0 deletions deps/build_local.jl
@@ -93,6 +93,8 @@ else
run(
Cmd(
`bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
--repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc
--repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang
--repo_env HERMETIC_PYTHON_VERSION="3.10"
--check_visibility=false --verbose_failures :libReactantExtra.so`;
dir=source_dir,
18 changes: 18 additions & 0 deletions ext/ReactantCUDAExt.jl
@@ -460,6 +460,7 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
end

function __init__()
if CUDA.CUDA_Driver_jll.libcuda !== nothing
handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false)
if handle === nothing
handle = C_NULL
@@ -479,6 +480,23 @@ function __init__()
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
end
ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false)
if ptr1 === nothing
ptr1 = C_NULL
end
ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false)
if ptr2 === nothing
ptr2 = C_NULL
end
ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false)
if ptr3 === nothing
ptr3 = C_NULL
end
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
end
return nothing
end

6 changes: 1 addition & 5 deletions src/XLA.jl
@@ -144,11 +144,7 @@ function __init__()
end
end

@ccall MLIR.API.mlir_c.RegisterCustomCallTarget(
"enzymexla_gpu"::Cstring,
cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid},
"CUDA"::Cstring,
)::Cvoid
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

# This wasn't properly exported on macos, we'll remove the try once macOS JLL
# has the fix.
5 changes: 4 additions & 1 deletion src/utils.jl
@@ -364,7 +364,10 @@ function call_with_reactant_generator(
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
end

ir, any_changed = rewrite_insts!(ir, interp)
if !is_reactant_method(mi::Core.MethodInstance)
Review comment (PR author): @jumerckx, as a note: something I missed earlier. We previously had a catch-all like this so that rewrite_insts! only runs when the top-level function is not a Reactant method, but I think that was accidentally removed by your earlier PR. In any case it is restored here (and once CUDA lands it will actually be tested).

ir, any_changed = rewrite_insts!(ir, interp)
end

src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
src.slotnames = fill(:none, length(ir.argtypes) + 1)
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
Expand Down
36 changes: 0 additions & 36 deletions test/cuda.jl

This file was deleted.

29 changes: 29 additions & 0 deletions test/integration/cuda.jl
@@ -0,0 +1,29 @@
using Reactant
using Test
using CUDA

function square_kernel!(x, y)
i = threadIdx().x
x[i] *= y[i]
sync_threads()
return nothing
end

# basic squaring on GPU
function square!(x, y)
@cuda blocks = 1 threads = length(x) square_kernel!(x, y)
return nothing
end

@testset "Square Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
B = Reactant.to_rarray(100 .* oA)
if CUDA.functional()
@jit square!(A, B)
@test all(Array(A) .≈ (oA .* oA .* 100))
@test all(Array(B) .≈ (oA .* 100))
else
@compile optimize = :before_kernel square!(A, B)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -59,6 +59,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "CUDA" include("integration/cuda.jl")
Review comment (collaborator): conditionally run with a CUDA.functional()?

Reply (PR author): We have that check in the test itself, but it's actually important that we can compile CUDA code even without a CUDA GPU (since we want to take existing CUDA kernels and rewrite them to whatever hardware).

@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
@safetestset "Random" include("integration/random.jl")