Support MPI #752
Open
mofeing wants to merge 49 commits into main from ss/mpi
+979
−36
Commits
c5f72cd Register MPI symbols on load (mofeing)
d5eaa2d ops (mofeing)
5f800fe Update ext/ReactantMPIExt/Overrides.jl (mofeing)
5e4a8cd Update ext/ReactantMPIExt/Overrides.jl (mofeing)
215cb1d register MPI constants (mofeing)
4b54755 Fix MPI specializations (mofeing)
e7fd20e fix some symbol registration (mofeing)
774982a refactor MPI Ops (mofeing)
d1fe99b Add functionality for parsing single operations (Julia code) (mofeing)
81dd8f2 Update `Ops.comm_rank` to use handwritten MLIR injection (mofeing)
2ebb52a comment (mofeing)
11ef645 Update `Ops.comm_size` (mofeing)
406fe08 fixes (mofeing)
108c679 Refactor `Ops.barrier` (mofeing)
bfac727 Refactor to `try_inject_to_top_block!` (mofeing)
20fadc2 Refactor MLIR injection (mofeing)
2babb79 Refactor MPI constante registration (mofeing)
2c8e95c Fix type inference in `Ops.hlo_call` on empty args (mofeing)
a6738f5 Fix MLIR of `Ops.comm_rank` (mofeing)
9710e59 Fix MLIR injection C-functions (mofeing)
f865e3e Go back to `Cint` for registering symbols (mofeing)
2778f1d Add `tryinjectop!` (mofeing)
eb5da6e Add `tryinject!`, `inject!` methods (mofeing)
5995a7b Update `comm_rank` (mofeing)
32b28ef Update `mlirOperationInject`, `mlirOperationParse` (mofeing)
6e9b1c5 Add `verify` flag to `tryinject!`, `parse(::Operation)` (mofeing)
f4acb15 Update `Ops.comm_rank` (mofeing)
85a79ed Update `comm_rank`, `comm_size`, `barrier`, `wait` (mofeing)
4e3477f Implement `Ops.allreduce` (mofeing)
9fb0163 Implement `Ops.send` (mofeing)
f14dd80 Remove comment (mofeing)
14d84e2 Update `Ops.wait` (mofeing)
beebfe8 Remove `comm` argument from `Comm_size`, `Barrier` overrides (mofeing)
708a574 Fix `Ops.comm_size` (mofeing)
da59f4a Fix `Ops.barrier` (mofeing)
b6a9cdf Fixes and renames (mofeing)
c207e02 Override `MPI.Allreduce!` (mofeing)
287bd32 Fix conversion of MPI constants to word-size type (mofeing)
0c32d65 Comment unused MPI datatypes (mofeing)
7eb4107 small fixes (mofeing)
19c0eca Implement `MPI.Recv!` (mofeing)
946339c Test MPI (mofeing)
8d2c3d1 Merge branch 'main' into ss/mpi (mofeing)
42bf6b2 Update src/mlir/IR/Operation.jl (mofeing)
fad1845 Fix `mpiexec` symbol import (mofeing)
0359559 Fix typo (mofeing)
dd327ec Init and Finalize on MPI tests (mofeing)
e9e657c Merge branch 'main' into ss/mpi (mofeing)
113b30f Fix changes introduces in "feat: IR inject functions (#1217)" (mofeing)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
using Reactant: @reactant_overlay, TracedRArray, TracedRNumber | ||
|
||
@reactant_overlay @noinline function MPI.Init(; kwargs...) | ||
if !isempty(kwargs) | ||
@warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs... | ||
end | ||
return Ops.init() | ||
end | ||
|
||
@reactant_overlay @noinline function MPI.Finalize(; kwargs...) | ||
return Ops.finalize() | ||
end | ||
|
||
@reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
return Ops.comm_rank() | ||
end | ||
|
||
@reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
return Ops.comm_size() | ||
end | ||
|
||
@reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
return Ops.barrier() | ||
end | ||
|
||
# TODO status not supported yet | ||
function MPI.Wait(req::TracedRequest) | ||
return Ops.wait(req) | ||
end | ||
|
||
# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` | ||
function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm) | ||
tag = Reactant.Ops.constant(tag) | ||
dest = Reactant.Ops.constant(dest) | ||
return MPI.Send(buf, dest, tag, comm) | ||
end | ||
|
||
# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` | ||
function MPI.Send( | ||
buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm | ||
) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
return Ops.send(buf, tag, dest) | ||
end | ||
|
||
# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` | ||
function MPI.Isend( | ||
buf::TracedRArray, | ||
dest::Union{T,TracedRNumber{T}}, | ||
tag::Union{T,TracedRNumber{T}}, | ||
comm::MPI.Comm, | ||
) where {T<:Integer} | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
|
||
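# Promote plain integers to traced constants; keep already-traced values as they are. | ||
# (An `if` without `else` would evaluate to `nothing` for traced inputs.) | ||
tag = tag isa TracedRNumber ? tag : Reactant.Ops.constant(tag) | ||
dest = dest isa TracedRNumber ? dest : Reactant.Ops.constant(dest) | ||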
|
||
return Ops.isend(buf, tag, dest) | ||
end | ||
|
||
# TODO should we error if other `AbstractRequest` types are passed in? | ||
function MPI.Isend( | ||
buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm, req::TracedRequest | ||
) | ||
gen_req = MPI.Isend(buf, dest, tag, comm) | ||
req.mlir_data = gen_req.mlir_data | ||
return req | ||
end | ||
|
||
function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) | ||
tag = Reactant.Ops.constant(tag) | ||
source = Reactant.Ops.constant(source) | ||
return MPI.Recv!(buf, source, tag, comm) | ||
end | ||
|
||
function MPI.Recv!( | ||
recvbuf::TracedRArray, | ||
source::Integer, | ||
tag::Integer, | ||
comm::MPI.Comm, | ||
::Type{MPI.API.MPI_Status}, | ||
) | ||
return MPI.Recv!(recvbuf, source, tag, comm) | ||
end | ||
|
||
function MPI.Recv!( | ||
recvbuf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing | ||
) | ||
return MPI.Recv!(recvbuf, source, tag, comm) | ||
end | ||
|
||
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` | ||
function MPI.Recv!( | ||
recvbuf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm | ||
) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
return Ops.recv!(recvbuf, tag, source) | ||
end | ||
|
||
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` | ||
function MPI.Irecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
|
||
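# Promote plain numbers to traced constants; keep already-traced values as they are. | ||
# (An `if` without `else` would evaluate to `nothing` for traced inputs.) | ||
tag = tag isa TracedRNumber ? tag : Reactant.Ops.constant(tag) | ||
source = source isa TracedRNumber ? source : Reactant.Ops.constant(source) | ||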
|
||
return Ops.irecv!(recvbuf, tag, source) | ||
end | ||
|
||
function MPI.Irecv!( | ||
recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, req::TracedRequest | ||
) | ||
gen_req = MPI.Irecv!(recvbuf, source, tag, comm) | ||
req.mlir_data = gen_req.mlir_data | ||
return req | ||
end | ||
|
||
function MPI.Allreduce!(sendbuf::TracedRArray, recvbuf::TracedRArray, op, comm::MPI.Comm) | ||
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" | ||
return Ops.allreduce!(op, sendbuf, recvbuf) | ||
end |
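The `MPI.Isend` and `MPI.Irecv!` overrides above promote plain integers to traced constants while leaving traced values alone. One Julia detail matters for this pattern: an `if` block without an `else` branch evaluates to `nothing` when its condition is false, so assigning the result of such a block can silently discard an input that was already traced. The sketch below shows the difference using hypothetical stand-ins (`Traced`, `to_traced`, `normalize_buggy` and `normalize` are illustrative names, not part of Reactant or of this PR).

```julia
# Hypothetical stand-in for Reactant's `TracedRNumber`; not the real type.
struct Traced{T}
    value::T
end

# Hypothetical stand-in for `Reactant.Ops.constant`: wraps a plain integer.
to_traced(x::Integer) = Traced(x)

# Pitfall: the `if` expression has no `else`, so it evaluates to `nothing`
# whenever `tag` is already a `Traced` value.
function normalize_buggy(tag)
    tag = if !(tag isa Traced)
        to_traced(tag)
    end
    return tag
end

# Safer form: an already-wrapped value is passed through unchanged.
normalize(tag) = tag isa Traced ? tag : to_traced(tag)

@assert normalize_buggy(3) === Traced(3)           # plain input works either way
@assert normalize_buggy(Traced(3)) === nothing     # traced input is lost
@assert normalize(3) === Traced(3)
@assert normalize(Traced(3)) === Traced(3)         # traced input is preserved
```

The ternary form keeps both branches explicit, which is why it is usually preferred when the variable is reassigned in place.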
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
module ReactantMPIExt | ||
|
||
using Reactant | ||
using Reactant: Reactant, Distributed, MLIR | ||
using MPI: MPI | ||
using Libdl | ||
|
||
# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py | ||
Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized() | ||
|
||
function Distributed.get_coordinator_address( | ||
::Distributed.MPIEnvDetector, timeout_in_seconds::Integer | ||
) | ||
if MPI.Comm_rank(MPI.COMM_WORLD) == 0 | ||
hostname = gethostname() | ||
port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1) | ||
hostname = "$(hostname):$(port_id)" | ||
else | ||
hostname = nothing | ||
end | ||
|
||
return MPI.bcast(hostname, MPI.COMM_WORLD; root=0) | ||
end | ||
|
||
function Distributed.get_process_count(::Distributed.MPIEnvDetector) | ||
return Int(MPI.Comm_size(MPI.COMM_WORLD)) | ||
end | ||
|
||
function Distributed.get_process_id(::Distributed.MPIEnvDetector) | ||
return Int(MPI.Comm_rank(MPI.COMM_WORLD)) | ||
end | ||
|
||
function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) | ||
new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0) | ||
return Int(MPI.Comm_rank(new_comm)) | ||
end | ||
|
||
function __init__() | ||
# TODO maybe it's more efficient if we use `RTLD_NOW` instead of `RTLD_LAZY`? | ||
libmpi_handle = Libdl.dlopen(MPI.API.libmpi, RTLD_LAZY | RTLD_GLOBAL) | ||
|
||
# register MPI routines | ||
for name in [ | ||
:MPI_Init, | ||
:MPI_Finalize, | ||
:MPI_Comm_rank, | ||
:MPI_Comm_size, | ||
:MPI_Send, | ||
:MPI_Recv, | ||
:MPI_Isend, | ||
:MPI_Irecv, | ||
:MPI_Barrier, | ||
:MPI_Wait, | ||
:MPI_Request_free, | ||
] | ||
sym = Libdl.dlsym(libmpi_handle, name) | ||
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid | ||
end | ||
|
||
# register MPI constants | ||
# NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, they are represented as word-size values (i.e. `int` or ptr) | ||
for name in [ | ||
# communicators | ||
:MPI_COMM_WORLD, | ||
:MPI_COMM_SELF, | ||
:MPI_COMM_NULL, | ||
:MPI_COMM_TYPE_SHARED, | ||
# datatypes | ||
:MPI_DATATYPE_NULL, | ||
:MPI_BYTE, | ||
:MPI_PACKED, | ||
:MPI_CHAR, | ||
:MPI_SHORT, | ||
:MPI_INT, | ||
:MPI_LONG, | ||
:MPI_FLOAT, | ||
:MPI_DOUBLE, | ||
:MPI_UNSIGNED_CHAR, | ||
:MPI_SIGNED_CHAR, | ||
:MPI_UNSIGNED_SHORT, | ||
:MPI_UNSIGNED_LONG, | ||
:MPI_UNSIGNED, | ||
:MPI_FLOAT_INT, | ||
:MPI_DOUBLE_INT, | ||
:MPI_LONG_DOUBLE_INT, | ||
:MPI_LONG_INT, | ||
:MPI_SHORT_INT, | ||
# :MPI_2INT, | ||
:MPI_UB, | ||
:MPI_LB, | ||
:MPI_WCHAR, | ||
:MPI_LONG_LONG_INT, | ||
:MPI_UNSIGNED_LONG_LONG, | ||
# :MPI_2COMPLEX, | ||
# :MPI_2DOUBLE_COMPLEX, | ||
:MPI_INT8_T, | ||
:MPI_UINT8_T, | ||
:MPI_INT16_T, | ||
:MPI_UINT16_T, | ||
:MPI_INT32_T, | ||
:MPI_UINT32_T, | ||
:MPI_INT64_T, | ||
:MPI_UINT64_T, | ||
:MPI_AINT, | ||
:MPI_OFFSET, | ||
:MPI_C_BOOL, | ||
:MPI_C_FLOAT_COMPLEX, | ||
:MPI_C_DOUBLE_COMPLEX, | ||
# :MPI_C_LONG_DOUBLE_COMPLEX, | ||
:MPI_COUNT, | ||
# ops | ||
:MPI_OP_NULL, | ||
:MPI_MAX, | ||
:MPI_MIN, | ||
:MPI_SUM, | ||
:MPI_PROD, | ||
:MPI_LAND, | ||
:MPI_BAND, | ||
:MPI_LOR, | ||
:MPI_BOR, | ||
:MPI_LXOR, | ||
:MPI_BXOR, | ||
:MPI_MINLOC, | ||
:MPI_MAXLOC, | ||
:MPI_REPLACE, | ||
:MPI_NO_OP, | ||
# request | ||
:MPI_REQUEST_NULL, | ||
# status | ||
:MPI_STATUS_IGNORE, | ||
:MPI_STATUSES_IGNORE, | ||
# error | ||
:MPI_SUCCESS, | ||
:MPI_ERR_BUFFER, | ||
:MPI_ERR_COUNT, | ||
:MPI_ERR_TYPE, | ||
:MPI_ERR_TAG, | ||
:MPI_ERR_COMM, | ||
:MPI_ERR_RANK, | ||
:MPI_ERR_REQUEST, | ||
:MPI_ERR_ROOT, | ||
:MPI_ERR_GROUP, | ||
:MPI_ERR_OP, | ||
:MPI_ERR_TOPOLOGY, | ||
:MPI_ERR_DIMS, | ||
:MPI_ERR_ARG, | ||
:MPI_ERR_UNKNOWN, | ||
:MPI_ERR_TRUNCATE, | ||
:MPI_ERR_OTHER, | ||
:MPI_ERR_INTERN, | ||
:MPI_ERR_IN_STATUS, | ||
:MPI_ERR_PENDING, | ||
:MPI_ERR_ACCESS, | ||
:MPI_ERR_AMODE, | ||
:MPI_ERR_ASSERT, | ||
:MPI_ERR_BAD_FILE, | ||
:MPI_ERR_BASE, | ||
:MPI_ERR_CONVERSION, | ||
:MPI_ERR_DISP, | ||
:MPI_ERR_DUP_DATAREP, | ||
:MPI_ERR_FILE_EXISTS, | ||
:MPI_ERR_FILE_IN_USE, | ||
:MPI_ERR_FILE, | ||
:MPI_ERR_INFO_KEY, | ||
:MPI_ERR_INFO_NOKEY, | ||
:MPI_ERR_INFO_VALUE, | ||
:MPI_ERR_INFO, | ||
:MPI_ERR_IO, | ||
:MPI_ERR_KEYVAL, | ||
:MPI_ERR_LOCKTYPE, | ||
:MPI_ERR_NAME, | ||
:MPI_ERR_NO_MEM, | ||
:MPI_ERR_NOT_SAME, | ||
:MPI_ERR_NO_SPACE, | ||
:MPI_ERR_NO_SUCH_FILE, | ||
:MPI_ERR_PORT, | ||
:MPI_ERR_QUOTA, | ||
:MPI_ERR_READ_ONLY, | ||
:MPI_ERR_RMA_CONFLICT, | ||
:MPI_ERR_RMA_SYNC, | ||
:MPI_ERR_SERVICE, | ||
:MPI_ERR_SIZE, | ||
:MPI_ERR_SPAWN, | ||
:MPI_ERR_UNSUPPORTED_DATAREP, | ||
:MPI_ERR_UNSUPPORTED_OPERATION, | ||
:MPI_ERR_WIN, | ||
# :MPI_T_ERR_MEMORY, | ||
# :MPI_T_ERR_NOT_INITIALIZED, | ||
# :MPI_T_ERR_CANNOT_INIT, | ||
# :MPI_T_ERR_INVALID_INDEX, | ||
# :MPI_T_ERR_INVALID_ITEM, | ||
# :MPI_T_ERR_INVALID_HANDLE, | ||
# :MPI_T_ERR_OUT_OF_HANDLES, | ||
# :MPI_T_ERR_OUT_OF_SESSIONS, | ||
# :MPI_T_ERR_INVALID_SESSION, | ||
# :MPI_T_ERR_CVAR_SET_NOT_NOW, | ||
# :MPI_T_ERR_CVAR_SET_NEVER, | ||
# :MPI_T_ERR_PVAR_NO_STARTSTOP, | ||
# :MPI_T_ERR_PVAR_NO_WRITE, | ||
# :MPI_T_ERR_PVAR_NO_ATOMIC, | ||
:MPI_ERR_RMA_RANGE, | ||
:MPI_ERR_RMA_ATTACH, | ||
:MPI_ERR_RMA_FLAVOR, | ||
:MPI_ERR_RMA_SHARED, | ||
# :MPI_T_ERR_INVALID, | ||
# :MPI_T_ERR_INVALID_NAME, | ||
# :MPI_ERR_PROC_ABORTED, | ||
# :MPI_ERR_PROC_FAILED, | ||
# :MPI_ERR_PROC_FAILED_PENDING, | ||
# :MPI_ERR_REVOKED, | ||
] | ||
value = getproperty(MPI.API, name) | ||
if value isa Base.RefValue | ||
value = value[] | ||
end | ||
value = convert(Int, value) | ||
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Int)::Cvoid | ||
end | ||
end | ||
|
||
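# Must be mutable: the `MPI.Isend`/`MPI.Irecv!` overrides that take a `req` assign to `req.mlir_data`. | ||
mutable struct TracedRequest <: MPI.AbstractRequest | ||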
mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} | ||
end | ||
|
||
include("Ops.jl") | ||
include("Overrides.jl") | ||
|
||
end # module |
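The port arithmetic in `Distributed.get_coordinator_address` can be checked in isolation. The following is a minimal sketch that needs neither MPI nor Reactant; `coordinator_port` and `coordinator_address` are hypothetical helper names introduced here, and the host names are made up. It reuses the expression from the diff, which maps the hostname hash onto the top 4096 ports (61440 through 65535).

```julia
# Same arithmetic as in `get_coordinator_address`: the hash modulo 2^12 selects
# an offset in 0:4095, which is added to the base port 65535 - 2^12 + 1 = 61440.
function coordinator_port(hostname::AbstractString)
    return Int(hash(hostname) % 2^12 + (65535 - 2^12 + 1))
end

# Hypothetical helper: formats the "host:port" string that rank 0 would broadcast.
function coordinator_address(hostname::AbstractString)
    return "$(hostname):$(coordinator_port(hostname))"
end

for host in ("node001", "node002", "login.cluster.example")
    port = coordinator_port(host)
    @assert 61440 <= port <= 65535
    println(coordinator_address(host))
end
```

Because the port is derived only from the hostname, nothing here checks that the port is actually free on the coordinator; that remains the caller's concern.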
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
using Test, MPI, Reactant | ||
|
||
MPI.Init() | ||
|
||
@testset "Comm_rank" begin | ||
comm = MPI.COMM_WORLD | ||
rank = MPI.Comm_rank(comm) | ||
@test rank == @jit MPI.Comm_rank(comm) | ||
end | ||
|
||
@testset "Comm_size" begin | ||
comm = MPI.COMM_WORLD | ||
nranks = MPI.Comm_size(comm) | ||
@test nranks == @jit MPI.Comm_size(comm) | ||
end | ||
|
||
@testset "Allreduce" begin | ||
comm = MPI.COMM_WORLD | ||
x = ConcreteRArray(fill(1)) | ||
nranks = MPI.Comm_size(comm) | ||
@test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) | ||
end | ||
|
||
MPI.Finalize() |