diff --git a/Project.toml b/Project.toml index 1b8c235f3b..46e072d40d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.2.20" +version = "0.2.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -70,7 +70,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.4" -Reactant_jll = "0.0.45" +Reactant_jll = "0.0.46" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index b95fc8e4f4..2c5819451f 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -446,6 +446,12 @@ std::vector col_major(int64_t dim) { return minor_to_major; } +extern "C" void ReactantLLVMParseCommandLineOptions(int argc, const char *const *argv, + const char *Overview) { + llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview), + &llvm::nulls()); +} + std::vector row_major(int64_t dim) { std::vector minor_to_major; for (int i = 0; i < dim; i++) { diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index f737f39517..7825fb1d10 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -450,7 +450,8 @@ cc_library( "-Wl,-exported_symbol,_ProfilerActivityStart", "-Wl,-exported_symbol,_ProfilerActivityEnd", "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", -"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion" +"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", +"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions" ]}), deps = [ "@enzyme//:EnzymeMLIR", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index f42737260b..5dd18ef305 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -136,7 +136,7 @@ http_archive( ) # load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256") -XLA_COMMIT = "281c11225c4a0bb7b710a290610a06d71194febd" +XLA_COMMIT = "e0c92850a41cf5208744d8a919b969fa3506863c" XLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 29e7749e82..e5f6e4fc3b 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -400,11 +400,10 @@ function compile(job) if !isempty(errors) throw(GPUCompiler.InvalidIRError(job, errors)) end - LLVM.strip_debuginfo!(mod) + # LLVM.strip_debuginfo!(mod) modstr = string(mod) # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version # it is probably safer to reparse a string using the right llvm module api, so we will do that. - mmod = MLIR.IR.Module( @ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR( modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext diff --git a/src/XLA.jl b/src/XLA.jl index 3676bb8d1e..3f06c8d026 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -2,6 +2,13 @@ module XLA import ...MLIR +function LLVMclopts(opts...) + args = ["", opts...] + @ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions( + length(args)::Cint, args::Ptr{Cstring}, C_NULL::Ptr{Cvoid} + )::Cvoid +end + mutable struct Client client::Ptr{Cvoid} @@ -50,6 +57,7 @@ function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true) end f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) + LLVMclopts("-nvptx-fma-level=1") #client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} return Client(client) end @@ -73,6 +81,7 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu") if client == C_NULL throw(AssertionError(unsafe_string(refstr[]))) end + LLVMclopts("-nvptx-fma-level=1") return Client(client) end @@ -83,6 +92,7 @@ function TPUClient(tpu_path::String) if client == C_NULL throw(AssertionError(unsafe_string(refstr[]))) end + LLVMclopts("-nvptx-fma-level=1") return Client(client) end