diff --git a/Project.toml b/Project.toml index c99a7a2e4..2129ac056 100644 --- a/Project.toml +++ b/Project.toml @@ -60,7 +60,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.32" +Reactant_jll = "0.0.33" Scratch = "1.2" Statistics = "1.10" YaoBlocks = "0.13" diff --git a/deps/ReactantExtra/.bazelrc b/deps/ReactantExtra/.bazelrc index a8b84e0f0..ba56a9d61 100644 --- a/deps/ReactantExtra/.bazelrc +++ b/deps/ReactantExtra/.bazelrc @@ -18,8 +18,8 @@ build -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NVCC_CLANG=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.6.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.6.0" # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 0a512067a..781548809 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -54,6 +54,9 @@ cc_toolchain_config( coverage_link_flags = ["--coverage"], cpu = "k8", cxx_builtin_include_directories = [ + "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0", + "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/x86_64-linux-musl", + "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/13.2.0/backward", "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0", "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/x86_64-linux-musl", "/opt/x86_64-linux-musl/x86_64-linux-musl/include/c++/10.2.0/backward", @@ -149,14 +152,14 @@ cc_toolchain_config( abi_libc_version = "local", abi_version = "local", cxx_builtin_include_directories = [ - "/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include", - "/opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed", + "/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include", + "/opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed", "/opt/BB_TARGET/BB_TARGET/include", "/opt/BB_TARGET/BB_TARGET/sys-root/usr/include", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward", - "/opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel" + "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0", + "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET", + "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward", + "/opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel" ], tool_paths = { "ar": "/opt/bin/BB_FULL_TARGET/ar", @@ -193,14 +196,14 @@ cc_toolchain_config( "-Wno-free-nonheap-object", "-fno-omit-frame-pointer", # TODO cxx_builtin_include_directories doesn't seem to be working, so we add the INCLUDE_PATHs manually - "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include", - "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/10.2.0/include-fixed", + "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include", + "-isystem /opt/BB_TARGET/lib/gcc/BB_TARGET/13.2.0/include-fixed", "-isystem /opt/BB_TARGET/BB_TARGET/include", "-isystem /opt/BB_TARGET/BB_TARGET/sys-root/usr/include", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/BB_TARGET", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/backward", - "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/10.2.0/parallel", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/BB_TARGET", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/backward", + "-isystem /opt/BB_TARGET/BB_TARGET/include/c++/13.2.0/parallel", ], opt_compile_flags = [ "-g0", @@ -361,6 +364,7 @@ cc_library( ) + [ "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", + "@enzyme_ad//src/enzyme_ad/jax:gpu.cc", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", # "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", @@ -429,7 +433,7 @@ cc_library( "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", -"-Wl,-exported_symbol,_EnzymeGPUCustomCall", +"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler", "-Wl,-exported_symbol,_ReactantThrowError", ]}), deps = [ @@ -469,6 +473,9 @@ cc_library( "@llvm-project//llvm:X86CodeGen", "@enzyme_ad//src/enzyme_ad/jax:TransformOps", "@enzyme_ad//src/enzyme_ad/jax:XLADerivatives", + # "@enzyme_ad//src/enzyme_ad/jax:gpu", + "@xla//xla/ffi/api:ffi", + "@xla//xla/ffi:ffi_api", "@stablehlo//:chlo_ops", "@xla//xla/pjrt:pjrt_api", "@xla//xla/pjrt:pjrt_c_api_client", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index dc72ecba9..95c2b7bb9 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "b6d6563aa3a3050474a4250bf18322f7ebf0b486" +ENZYMEXLA_COMMIT = "74046d05089c02946058f8fd94ed23efd0bf3ccc" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/deps/build_local.jl b/deps/build_local.jl index 4138d2b6c..8a0c03e96 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -93,6 +93,8 @@ else run( Cmd( `bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) + --repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc + --repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures :libReactantExtra.so`; dir=source_dir, diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ba0765af5..2d709d05f 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -218,11 +218,11 @@ function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where return res end -const _kernel_instances = Dict{Any,Any}() - +# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string +# and not the operation itself). struct LLVMFunc{F,tt} f::Union{F,Nothing} - entry::MLIR.IR.Operation + entry::String end const GPUCompiler = CUDA.GPUCompiler @@ -324,9 +324,9 @@ function compile(job) )::MLIR.API.MlirOperation entry = MLIR.IR.Operation(linkRes) - - entry + String(Reactant.TracedUtils.get_attribute_by_name(linkRes, "sym_name")) end + return LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, entry) end @@ -378,9 +378,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( output_operand_aliases = MLIR.IR.Attribute(aliases) - fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name") - # Force public for now while we don't have real users - # MLIR.IR.rmattr!(func.entry, "sym_visibility") + fname = func.entry operands = MLIR.IR.Value[] for idx in @@ -460,25 +458,27 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( end function __init__() - handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false) - if handle === nothing - handle = C_NULL - end - ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false) - if ptr1 === nothing - ptr1 = C_NULL - end - ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false) - if ptr2 === nothing - ptr2 = C_NULL - end - ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false) - if ptr3 === nothing - ptr3 = C_NULL + if CUDA.CUDA_Driver_jll.libcuda !== nothing + handle = Reactant.XLA.Libdl.dlopen(CUDA.CUDA_Driver_jll.libcuda; throw_error=false) + if handle === nothing + handle = C_NULL + end + ptr1 = Reactant.XLA.Libdl.dlsym(handle, "cuLaunchKernel"; throw_error=false) + if ptr1 === nothing + ptr1 = C_NULL + end + ptr2 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleLoadData"; throw_error=false) + if ptr2 === nothing + ptr2 = C_NULL + end + ptr3 = Reactant.XLA.Libdl.dlsym(handle, "cuModuleGetFunction"; throw_error=false) + if ptr3 === nothing + ptr3 = C_NULL + end + Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) + Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) + Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) end - Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) - Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) - Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) return nothing end diff --git a/src/XLA.jl b/src/XLA.jl index 54b45cd00..6255737e4 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -144,11 +144,7 @@ function __init__() end end - @ccall MLIR.API.mlir_c.RegisterCustomCallTarget( - "enzymexla_gpu"::Cstring, - cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, - "CUDA"::Cstring, - )::Cvoid + @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid # This wasn't properly exported on macos, we'll remove the try once macOS JLL # has the fix. diff --git a/src/utils.jl b/src/utils.jl index 56fa7587b..83c9d51b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -364,7 +364,10 @@ function call_with_reactant_generator( ir, rt = CC.typeinf_ircode(interp, mi, nothing) end - ir, any_changed = rewrite_insts!(ir, interp) + if !is_reactant_method(mi::Core.MethodInstance) + ir, any_changed = rewrite_insts!(ir, interp) + end + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) src.slotnames = fill(:none, length(ir.argtypes) + 1) src.slotflags = fill(zero(UInt8), length(ir.argtypes)) diff --git a/test/cuda.jl b/test/cuda.jl deleted file mode 100644 index 549002e4f..000000000 --- a/test/cuda.jl +++ /dev/null @@ -1,36 +0,0 @@ -using Reactant -using Test -using CUDA - -using Reactant_jll -@show Reactant_jll.libReactantExtra_path - -function square_kernel!(x) - #i = threadIdx().x - #x[i] *= x[i] - #@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n", - # 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) - #x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z) - - # sync_threads() - return nothing -end - -# basic squaring on GPU -function square!(x) - @cuda blocks = 1 threads = length(x) square_kernel!(x) - return nothing -end - -@testset "Square Kernel" begin - oA = collect(1:1:64) - A = Reactant.to_rarray(oA) - # @show @code_hlo optimize = false square!(A) - # @show @code_hlo optimize = :before_kernel square!(A) - # @show @code_hlo square!(A) - func! = @compile square!(A) - func!(A) - @show A - @show oA - @test all(Array(A) .≈ (oA .* oA)) -end diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl new file mode 100644 index 000000000..47bb8c23a --- /dev/null +++ b/test/integration/cuda.jl @@ -0,0 +1,29 @@ +using Reactant +using Test +using CUDA + +function square_kernel!(x, y) + i = threadIdx().x + x[i] *= y[i] + sync_threads() + return nothing +end + +# basic squaring on GPU +function square!(x, y) + @cuda blocks = 1 threads = length(x) square_kernel!(x, y) + return nothing +end + +@testset "Square Kernel" begin + oA = collect(1:1:64) + A = Reactant.to_rarray(oA) + B = Reactant.to_rarray(100 .* oA) + if CUDA.functional() + @jit square!(A, B) + @test all(Array(A) .≈ (oA .* oA .* 100)) + @test all(Array(B) .≈ (oA .* 100)) + else + @code_hlo optimize = :before_kernel square!(A, B) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3238d10..834d9b504 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,6 +59,8 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" + # Temporarily disabled as minutia are debugged + # @safetestset "CUDA" include("integration/cuda.jl") @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") @safetestset "Random" include("integration/random.jl")