Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.2.20"
version = "0.2.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -70,7 +70,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.4"
Reactant_jll = "0.0.45"
Reactant_jll = "0.0.46"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
6 changes: 6 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,12 @@ std::vector<int64_t> col_major(int64_t dim) {
return minor_to_major;
}

extern "C" void ReactantLLVMParseCommandLineOptions(int argc, const char *const *argv,
const char *Overview) {
llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview),
&llvm::nulls());
}

std::vector<int64_t> row_major(int64_t dim) {
std::vector<int64_t> minor_to_major;
for (int i = 0; i < dim; i++) {
Expand Down
3 changes: 2 additions & 1 deletion deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ cc_library(
"-Wl,-exported_symbol,_ProfilerActivityStart",
"-Wl,-exported_symbol,_ProfilerActivityEnd",
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion"
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
]}),
deps = [
"@enzyme//:EnzymeMLIR",
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ http_archive(
)

# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "281c11225c4a0bb7b710a290610a06d71194febd"
XLA_COMMIT = "e0c92850a41cf5208744d8a919b969fa3506863c"
XLA_SHA256 = ""

http_archive(
Expand Down
3 changes: 1 addition & 2 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,10 @@ function compile(job)
if !isempty(errors)
throw(GPUCompiler.InvalidIRError(job, errors))
end
LLVM.strip_debuginfo!(mod)
# LLVM.strip_debuginfo!(mod)
modstr = string(mod)
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
# it is probably safer to reparse a string using the right llvm module api, so we will do that.

mmod = MLIR.IR.Module(
@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(
modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext
Expand Down
8 changes: 8 additions & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ module XLA

import ...MLIR

function LLVMclopts(opts...)
args = ["", opts...]
@ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions(length(args)::Cint, args::Ptr{Cstring}, C_NULL::Ptr{Cvoid})::Cvoid
end

mutable struct Client
client::Ptr{Cvoid}

Expand Down Expand Up @@ -50,6 +55,7 @@ function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true)
end
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
LLVMclopts("-nvptx-fma-level=1")
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
return Client(client)
end
Expand All @@ -73,6 +79,7 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu")
if client == C_NULL
throw(AssertionError(unsafe_string(refstr[])))
end
LLVMclopts("-nvptx-fma-level=1")
return Client(client)
end

Expand All @@ -83,6 +90,7 @@ function TPUClient(tpu_path::String)
if client == C_NULL
throw(AssertionError(unsafe_string(refstr[])))
end
LLVMclopts("-nvptx-fma-level=1")
return Client(client)
end

Expand Down
Loading