Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.2.20"
version = "0.2.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -70,7 +70,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.4"
Reactant_jll = "0.0.45"
Reactant_jll = "0.0.46"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
6 changes: 6 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,12 @@ std::vector<int64_t> col_major(int64_t dim) {
return minor_to_major;
}

extern "C" void ReactantLLVMParseCommandLineOptions(int argc, const char *const *argv,
const char *Overview) {
llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview),
&llvm::nulls());
}

std::vector<int64_t> row_major(int64_t dim) {
std::vector<int64_t> minor_to_major;
for (int i = 0; i < dim; i++) {
Expand Down
3 changes: 2 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ cc_library(
"-Wl,-exported_symbol,_ProfilerActivityStart",
"-Wl,-exported_symbol,_ProfilerActivityEnd",
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion"
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
]}),
deps = [
"@enzyme//:EnzymeMLIR",
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ http_archive(
)

# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "281c11225c4a0bb7b710a290610a06d71194febd"
XLA_COMMIT = "e0c92850a41cf5208744d8a919b969fa3506863c"
XLA_SHA256 = ""

http_archive(
Expand Down
3 changes: 1 addition & 2 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,10 @@ function compile(job)
if !isempty(errors)
throw(GPUCompiler.InvalidIRError(job, errors))
end
LLVM.strip_debuginfo!(mod)
# LLVM.strip_debuginfo!(mod)
modstr = string(mod)
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
# it is probably safer to reparse a string using the right llvm module api, so we will do that.

mmod = MLIR.IR.Module(
@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(
modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext
Expand Down
10 changes: 10 additions & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ module XLA

import ...MLIR

function LLVMclopts(opts...)
args = ["", opts...]
@ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions(
length(args)::Cint, args::Ptr{Cstring}, C_NULL::Ptr{Cvoid}
)::Cvoid
end

mutable struct Client
client::Ptr{Cvoid}

Expand Down Expand Up @@ -50,6 +57,7 @@ function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true)
end
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
LLVMclopts("-nvptx-fma-level=1")
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
return Client(client)
end
Expand All @@ -73,6 +81,7 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu")
if client == C_NULL
throw(AssertionError(unsafe_string(refstr[])))
end
LLVMclopts("-nvptx-fma-level=1")
return Client(client)
end

Expand All @@ -83,6 +92,7 @@ function TPUClient(tpu_path::String)
if client == C_NULL
throw(AssertionError(unsafe_string(refstr[])))
end
LLVMclopts("-nvptx-fma-level=1")
return Client(client)
end

Expand Down
Loading