Support MPI #752


Open · wants to merge 49 commits into main

Commits (49)
c5f72cd
Register MPI symbols on load
mofeing Feb 15, 2025
d5eaa2d
ops
mofeing Feb 16, 2025
5f800fe
Update ext/ReactantMPIExt/Overrides.jl
mofeing Feb 16, 2025
5e4a8cd
Update ext/ReactantMPIExt/Overrides.jl
mofeing Feb 22, 2025
215cb1d
register MPI constants
mofeing Feb 28, 2025
4b54755
Fix MPI specializations
mofeing Mar 3, 2025
e7fd20e
fix some symbol registration
mofeing Mar 3, 2025
774982a
refactor MPI Ops
mofeing Mar 3, 2025
d1fe99b
Add functionality for parsing single operations (Julia code)
mofeing Mar 4, 2025
81dd8f2
Update `Ops.comm_rank` to use handwritten MLIR injection
mofeing Mar 4, 2025
2ebb52a
comment
mofeing Mar 4, 2025
11ef645
Update `Ops.comm_size`
mofeing Mar 4, 2025
406fe08
fixes
mofeing Mar 4, 2025
108c679
Refactor `Ops.barrier`
mofeing Mar 4, 2025
bfac727
Refactor to `try_inject_to_top_block!`
mofeing Mar 4, 2025
20fadc2
Refactor MLIR injection
mofeing Mar 5, 2025
2babb79
Refactor MPI constant registration
mofeing Mar 5, 2025
2c8e95c
Fix type inference in `Ops.hlo_call` on empty args
mofeing Mar 5, 2025
a6738f5
Fix MLIR of `Ops.comm_rank`
mofeing Mar 5, 2025
9710e59
Fix MLIR injection C-functions
mofeing Mar 6, 2025
f865e3e
Go back to `Cint` for registering symbols
mofeing Mar 6, 2025
2778f1d
Add `tryinjectop!`
mofeing Mar 6, 2025
eb5da6e
Add `tryinject!`, `inject!` methods
mofeing Mar 6, 2025
5995a7b
Update `comm_rank`
mofeing Mar 6, 2025
32b28ef
Update `mlirOperationInject`, `mlirOperationParse`
mofeing Mar 6, 2025
6e9b1c5
Add `verify` flag to `tryinject!`, `parse(::Operation)`
mofeing Mar 6, 2025
f4acb15
Update `Ops.comm_rank`
mofeing Mar 6, 2025
85a79ed
Update `comm_rank`, `comm_size`, `barrier`, `wait`
mofeing Mar 11, 2025
4e3477f
Implement `Ops.allreduce`
mofeing Mar 11, 2025
9fb0163
Implement `Ops.send`
mofeing Mar 11, 2025
f14dd80
Remove comment
mofeing Mar 11, 2025
14d84e2
Update `Ops.wait`
mofeing Mar 11, 2025
beebfe8
Remove `comm` argument from `Comm_size`, `Barrier` overrides
mofeing Mar 11, 2025
708a574
Fix `Ops.comm_size`
mofeing Mar 11, 2025
da59f4a
Fix `Ops.barrier`
mofeing Mar 11, 2025
b6a9cdf
Fixes and renames
mofeing Mar 11, 2025
c207e02
Override `MPI.Allreduce!`
mofeing Mar 11, 2025
287bd32
Fix conversion of MPI constants to word-size type
mofeing Mar 11, 2025
0c32d65
Comment unused MPI datatypes
mofeing Mar 11, 2025
7eb4107
small fixes
mofeing Mar 11, 2025
19c0eca
Implement `MPI.Recv!`
mofeing Mar 11, 2025
946339c
Test MPI
mofeing Mar 16, 2025
8d2c3d1
Merge branch 'main' into ss/mpi
mofeing Mar 16, 2025
42bf6b2
Update src/mlir/IR/Operation.jl
mofeing Mar 16, 2025
fad1845
Fix `mpiexec` symbol import
mofeing Mar 16, 2025
0359559
Fix typo
mofeing Mar 16, 2025
dd327ec
Init and Finalize on MPI tests
mofeing Mar 16, 2025
e9e657c
Merge branch 'main' into ss/mpi
mofeing May 18, 2025
113b30f
Fix changes introduced in "feat: IR inject functions (#1217)"
mofeing May 18, 2025
36 changes: 0 additions & 36 deletions ext/ReactantMPIExt.jl

This file was deleted.

586 changes: 586 additions & 0 deletions ext/ReactantMPIExt/Ops.jl

Large diffs are not rendered by default.

134 changes: 134 additions & 0 deletions ext/ReactantMPIExt/Overrides.jl
@@ -0,0 +1,134 @@
using Reactant: @reactant_overlay, TracedRArray, TracedRNumber

@reactant_overlay @noinline function MPI.Init(; kwargs...)
    if !isempty(kwargs)
        @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs...
    end
    return Ops.init()
end

@reactant_overlay @noinline function MPI.Finalize(; kwargs...)
    return Ops.finalize()
end

@reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
    return Ops.comm_rank()
end

@reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
    return Ops.comm_size()
end

@reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
    return Ops.barrier()
end

# TODO status not supported yet
function MPI.Wait(req::TracedRequest)
    return Ops.wait(req)
end

# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer`
function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm)
    tag = Reactant.Ops.constant(tag)
    dest = Reactant.Ops.constant(dest)
    return MPI.Send(buf, dest, tag, comm)
end

# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer`
function MPI.Send(
    buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm
)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
    return Ops.send(buf, tag, dest)
end

# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer`
function MPI.Isend(
    buf::TracedRArray,
    dest::Union{T,TracedRNumber{T}},
    tag::Union{T,TracedRNumber{T}},
    comm::MPI.Comm,
) where {T<:Integer}
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"

    tag = if !(tag isa TracedRNumber)
        Reactant.Ops.constant(tag)
    else
        tag
    end

    dest = if !(dest isa TracedRNumber)
        Reactant.Ops.constant(dest)
    else
        dest
    end

    return Ops.isend(buf, tag, dest)
end

# TODO should we error if other `AbstractRequest` types are passed in?
function MPI.Isend(
    buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm, req::TracedRequest
)
    gen_req = MPI.Isend(buf, dest, tag, comm)
    req.mlir_data = gen_req.mlir_data
    return req
end

function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm)
    tag = Reactant.Ops.constant(tag)
    source = Reactant.Ops.constant(source)
    return MPI.Recv!(buf, source, tag, comm)
end

function MPI.Recv!(
    recvbuf::TracedRArray,
    source::Integer,
    tag::Integer,
    comm::MPI.Comm,
    ::Type{MPI.API.MPI_Status},
)
    return MPI.Recv!(recvbuf, source, tag, comm)
end

function MPI.Recv!(
    recvbuf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing
)
    return MPI.Recv!(recvbuf, source, tag, comm)
end

# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`
function MPI.Recv!(
    recvbuf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm
)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
    return Ops.recv!(recvbuf, tag, source)
end

# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`
function MPI.Irecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"

    tag = if !(tag isa TracedRNumber)
        Reactant.Ops.constant(tag)
    else
        tag
    end

    source = if !(source isa TracedRNumber)
        Reactant.Ops.constant(source)
    else
        source
    end

    return Ops.irecv!(recvbuf, tag, source)
end

function MPI.Irecv!(
    recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, req::TracedRequest
)
    gen_req = MPI.Irecv!(recvbuf, source, tag, comm)
    req.mlir_data = gen_req.mlir_data
    return req
end

function MPI.Allreduce!(sendbuf::TracedRArray, recvbuf::TracedRArray, op, comm::MPI.Comm)
    @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
    return Ops.allreduce!(op, sendbuf, recvbuf)
end
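
For orientation, a minimal sketch of how these overrides are exercised (not part of the diff; it assumes both MPI and Reactant are loaded so the extension is active, and uses only entry points defined above and in the test file):

using MPI, Reactant

MPI.Init()

# While tracing, the calls below hit the overlay methods above, so the rank query and
# the reduction are recorded as MLIR operations instead of calling into MPI directly.
sendbuf = ConcreteRArray(fill(1.0, 4))
recvbuf = ConcreteRArray(zeros(4))

rank = @jit MPI.Comm_rank(MPI.COMM_WORLD)
@jit MPI.Allreduce!(sendbuf, recvbuf, MPI.SUM, MPI.COMM_WORLD)

MPI.Finalize()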
228 changes: 228 additions & 0 deletions ext/ReactantMPIExt/ReactantMPIExt.jl
@@ -0,0 +1,228 @@
module ReactantMPIExt

using Reactant
using Reactant: Reactant, Distributed, MLIR
using MPI: MPI
using Libdl

# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized()

function Distributed.get_coordinator_address(
    ::Distributed.MPIEnvDetector, timeout_in_seconds::Integer
)
    if MPI.Comm_rank(MPI.COMM_WORLD) == 0
        hostname = gethostname()
        # deterministically derive a port in the range 61440:65535 from the hostname
        port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1)
        hostname = "$(hostname):$(port_id)"
    else
        hostname = nothing
    end

    return MPI.bcast(hostname, MPI.COMM_WORLD; root=0)
end

function Distributed.get_process_count(::Distributed.MPIEnvDetector)
    return Int(MPI.Comm_size(MPI.COMM_WORLD))
end

function Distributed.get_process_id(::Distributed.MPIEnvDetector)
    return Int(MPI.Comm_rank(MPI.COMM_WORLD))
end

function Distributed.get_local_process_id(::Distributed.MPIEnvDetector)
    new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0)
    return Int(MPI.Comm_rank(new_comm))
end

function __init__()
# TODO maybe it's more efficient if we use `RTLD_NOW` instead of `RTLD_LAZY`?
libmpi_handle = Libdl.dlopen(MPI.API.libmpi, RTLD_LAZY | RTLD_GLOBAL)

# register MPI routines
for name in [
:MPI_Init,
:MPI_Finalize,
:MPI_Comm_rank,
:MPI_Comm_size,
:MPI_Send,
:MPI_Recv,
:MPI_Isend,
:MPI_Irecv,
:MPI_Barrier,
:MPI_Wait,
:MPI_Request_free,
]
sym = Libdl.dlsym(libmpi_handle, name)
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid
end

# register MPI constants
# NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, they are represented as word-size values (i.e. `int` or ptr)
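# For example (illustration, not from this diff): MPICH-style ABIs expose `MPI_COMM_WORLD`
# as a plain integer handle, while Open MPI exposes it as a pointer to an opaque struct;
# both fit in a machine word, which is why each constant is coerced to `Int` below before
# being forwarded to `EnzymeJaXMapSymbol`.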
for name in [
# communicators
:MPI_COMM_WORLD,
:MPI_COMM_SELF,
:MPI_COMM_NULL,
:MPI_COMM_TYPE_SHARED,
# datatypes
:MPI_DATATYPE_NULL,
:MPI_BYTE,
:MPI_PACKED,
:MPI_CHAR,
:MPI_SHORT,
:MPI_INT,
:MPI_LONG,
:MPI_FLOAT,
:MPI_DOUBLE,
:MPI_UNSIGNED_CHAR,
:MPI_SIGNED_CHAR,
:MPI_UNSIGNED_SHORT,
:MPI_UNSIGNED_LONG,
:MPI_UNSIGNED,
:MPI_FLOAT_INT,
:MPI_DOUBLE_INT,
:MPI_LONG_DOUBLE_INT,
:MPI_LONG_INT,
:MPI_SHORT_INT,
# :MPI_2INT,
:MPI_UB,
:MPI_LB,
:MPI_WCHAR,
:MPI_LONG_LONG_INT,
:MPI_UNSIGNED_LONG_LONG,
# :MPI_2COMPLEX,
# :MPI_2DOUBLE_COMPLEX,
:MPI_INT8_T,
:MPI_UINT8_T,
:MPI_INT16_T,
:MPI_UINT16_T,
:MPI_INT32_T,
:MPI_UINT32_T,
:MPI_INT64_T,
:MPI_UINT64_T,
:MPI_AINT,
:MPI_OFFSET,
:MPI_C_BOOL,
:MPI_C_FLOAT_COMPLEX,
:MPI_C_DOUBLE_COMPLEX,
# :MPI_C_LONG_DOUBLE_COMPLEX,
:MPI_COUNT,
# ops
:MPI_OP_NULL,
:MPI_MAX,
:MPI_MIN,
:MPI_SUM,
:MPI_PROD,
:MPI_LAND,
:MPI_BAND,
:MPI_LOR,
:MPI_BOR,
:MPI_LXOR,
:MPI_BXOR,
:MPI_MINLOC,
:MPI_MAXLOC,
:MPI_REPLACE,
:MPI_NO_OP,
# request
:MPI_REQUEST_NULL,
# status
:MPI_STATUS_IGNORE,
:MPI_STATUSES_IGNORE,
# error
:MPI_SUCCESS,
:MPI_ERR_BUFFER,
:MPI_ERR_COUNT,
:MPI_ERR_TYPE,
:MPI_ERR_TAG,
:MPI_ERR_COMM,
:MPI_ERR_RANK,
:MPI_ERR_REQUEST,
:MPI_ERR_ROOT,
:MPI_ERR_GROUP,
:MPI_ERR_OP,
:MPI_ERR_TOPOLOGY,
:MPI_ERR_DIMS,
:MPI_ERR_ARG,
:MPI_ERR_UNKNOWN,
:MPI_ERR_TRUNCATE,
:MPI_ERR_OTHER,
:MPI_ERR_INTERN,
:MPI_ERR_IN_STATUS,
:MPI_ERR_PENDING,
:MPI_ERR_ACCESS,
:MPI_ERR_AMODE,
:MPI_ERR_ASSERT,
:MPI_ERR_BAD_FILE,
:MPI_ERR_BASE,
:MPI_ERR_CONVERSION,
:MPI_ERR_DISP,
:MPI_ERR_DUP_DATAREP,
:MPI_ERR_FILE_EXISTS,
:MPI_ERR_FILE_IN_USE,
:MPI_ERR_FILE,
:MPI_ERR_INFO_KEY,
:MPI_ERR_INFO_NOKEY,
:MPI_ERR_INFO_VALUE,
:MPI_ERR_INFO,
:MPI_ERR_IO,
:MPI_ERR_KEYVAL,
:MPI_ERR_LOCKTYPE,
:MPI_ERR_NAME,
:MPI_ERR_NO_MEM,
:MPI_ERR_NOT_SAME,
:MPI_ERR_NO_SPACE,
:MPI_ERR_NO_SUCH_FILE,
:MPI_ERR_PORT,
:MPI_ERR_QUOTA,
:MPI_ERR_READ_ONLY,
:MPI_ERR_RMA_CONFLICT,
:MPI_ERR_RMA_SYNC,
:MPI_ERR_SERVICE,
:MPI_ERR_SIZE,
:MPI_ERR_SPAWN,
:MPI_ERR_UNSUPPORTED_DATAREP,
:MPI_ERR_UNSUPPORTED_OPERATION,
:MPI_ERR_WIN,
# :MPI_T_ERR_MEMORY,
# :MPI_T_ERR_NOT_INITIALIZED,
# :MPI_T_ERR_CANNOT_INIT,
# :MPI_T_ERR_INVALID_INDEX,
# :MPI_T_ERR_INVALID_ITEM,
# :MPI_T_ERR_INVALID_HANDLE,
# :MPI_T_ERR_OUT_OF_HANDLES,
# :MPI_T_ERR_OUT_OF_SESSIONS,
# :MPI_T_ERR_INVALID_SESSION,
# :MPI_T_ERR_CVAR_SET_NOT_NOW,
# :MPI_T_ERR_CVAR_SET_NEVER,
# :MPI_T_ERR_PVAR_NO_STARTSTOP,
# :MPI_T_ERR_PVAR_NO_WRITE,
# :MPI_T_ERR_PVAR_NO_ATOMIC,
:MPI_ERR_RMA_RANGE,
:MPI_ERR_RMA_ATTACH,
:MPI_ERR_RMA_FLAVOR,
:MPI_ERR_RMA_SHARED,
# :MPI_T_ERR_INVALID,
# :MPI_T_ERR_INVALID_NAME,
# :MPI_ERR_PROC_ABORTED,
# :MPI_ERR_PROC_FAILED,
# :MPI_ERR_PROC_FAILED_PENDING,
# :MPI_ERR_REVOKED,
]
value = getproperty(MPI.API, name)
if value isa Base.RefValue
value = value[]
end
value = convert(Int, value)
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Int)::Cvoid
end
end

mutable struct TracedRequest <: MPI.AbstractRequest
    mlir_data::Union{Nothing,Reactant.MLIR.IR.Value}
end

include("Ops.jl")
include("Overrides.jl")

end # module
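
As a rough illustration of how `TracedRequest` is meant to flow through the overrides in `Overrides.jl` (a sketch, not code from this PR; it assumes `Ops.isend`/`Ops.irecv!` return a `TracedRequest` carrying the request's `mlir_data`):

function exchange!(sendbuf, recvbuf, dest, source)
    sreq = MPI.Isend(sendbuf, dest, 0, MPI.COMM_WORLD)     # overlay returns a traced request
    rreq = MPI.Irecv!(recvbuf, source, 0, MPI.COMM_WORLD)  # likewise
    MPI.Wait(sreq)                                         # lowered through Ops.wait
    MPI.Wait(rreq)
    return recvbuf
end

Such a function would then be compiled and run with `@jit`/`@compile`, as in `test/integration/mpi.jl`.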
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
@@ -47,6 +48,7 @@ LinearAlgebra = "1.10"
Lux = "1.4.1"
LuxLib = "1.3"
MLUtils = "0.4.4"
MPI = "0.20"
NNlib = "0.9.26"
OffsetArrays = "1"
OneHotArrays = "0.2.6"
24 changes: 24 additions & 0 deletions test/integration/mpi.jl
@@ -0,0 +1,24 @@
using Test, MPI, Reactant

MPI.Init()

@testset "Comm_rank" begin
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
@test rank == @jit MPI.Comm_rank(comm)
end

@testset "Comm_size" begin
comm = MPI.COMM_WORLD
nranks = MPI.Comm_size(comm)
@test nranks == @jit MPI.Comm_size(comm)
end

@testset "Allreduce" begin
comm = MPI.COMM_WORLD
x = ConcreteRArray(fill(1))
nranks = MPI.Comm_size(comm)
@test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD)
end

MPI.Finalize()
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -48,6 +48,11 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Random" include("integration/random.jl")
@safetestset "Python" include("integration/python.jl")
@safetestset "Optimisers" include("integration/optimisers.jl")
@safetestset "MPI" begin
using MPI
nranks = 2
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
Suggested change
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) --startup-file=no $(joinpath(@__DIR__, "integration", "mpi.jl"))`)
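(The suggested `--startup-file=no` keeps the worker Julia processes from loading the user's `startup.jl`, and `joinpath(@__DIR__, "integration", "mpi.jl")` resolves the test script relative to `runtests.jl`, so the command no longer depends on the directory the tests are launched from.)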

end
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"