Skip to content

Commit 8111623

Browse files
Avik Palavik-pal
authored andcommitted
feat: add OMPI cluster detection
1 parent dbea6f7 commit 8111623

File tree

2 files changed

+83
-3
lines changed

2 files changed

+83
-3
lines changed

ext/ReactantMPIExt.jl

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,82 @@ module ReactantMPIExt
33
using Reactant: Reactant, Distributed
44
using MPI: MPI
55

6-
# Code taken from https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
6+
# Code taken from:
7+
# 1. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
8+
# 2. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py
79

8-
# XXX: Is this a good check??
10+
# Based on ompi_cluster
11+
const _ORTE_URI = "OMPI_MCA_orte_hnp_uri"
12+
const _PMIX_SERVER_URI = (
13+
"PMIX_SERVER_URI2",
14+
"PMIX_SERVER_URI3",
15+
"PMIX_SERVER_URI4",
16+
"PMIX_SERVER_URI41",
17+
"PMIX_SERVER_URI21",
18+
)
19+
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
20+
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
21+
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"
22+
23+
Distributed.is_env_present(::Distributed.OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI)
24+
25+
function Distributed.is_env_present(::Distributed.OpenMPIPMIXEnvDetector)
26+
return any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
27+
end
28+
29+
function Distributed.get_coordinator_address(
30+
::Distributed.OpenMPIORTEEnvDetector, timeout_in_seconds::Integer
31+
)
32+
orte_uri = ENV[_ORTE_URI]
33+
34+
job_id = parse(Int, split(orte_uri, '.'; limit=2)[1])
35+
port = job_id % 2^12 + (65535 - 2^12 + 1)
36+
37+
launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
38+
39+
@assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \
40+
Open MPI environment."
41+
42+
launcher_ip = launcher_ip_match.captures[findfirst(
43+
!isnothing, launcher_ip_match.captures
44+
)]
45+
return "$(launcher_ip):$(port)"
46+
end
47+
48+
function Distributed.get_coordinator_address(
49+
::Distributed.OpenMPIPMIXEnvDetector, timeout_in_seconds::Integer
50+
)
51+
varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
52+
pmix_uri = ENV[_PMIX_SERVER_URI[varname]]
53+
54+
job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1])
55+
port = job_id % 2^12 + (65535 - 2^12 + 1)
56+
57+
launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)
58+
59+
@assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \
60+
Open MPI environment."
61+
62+
launcher_ip = launcher_ip_match.captures[findfirst(
63+
!isnothing, launcher_ip_match.captures
64+
)]
65+
66+
return "$(launcher_ip):$(port)"
67+
end
68+
69+
function Distributed.get_process_count(::Distributed.AbstractOMPIClusterEnvDetector)
70+
return parse(Int, ENV[_OMPI_PROCESS_COUNT])
71+
end
72+
73+
function Distributed.get_process_id(::Distributed.AbstractOMPIClusterEnvDetector)
74+
return parse(Int, ENV[_OMPI_PROCESS_ID])
75+
end
76+
77+
function Distributed.get_local_process_id(::Distributed.AbstractOMPIClusterEnvDetector)
78+
return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID])
79+
end
80+
81+
# Based on mpi4py
982
Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized()
1083

1184
function Distributed.get_coordinator_address(

src/Distributed.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ end
3636

3737
abstract type AbstractClusterEnvDetector end
3838

39+
abstract type AbstractOMPIClusterEnvDetector <: AbstractClusterEnvDetector end
40+
41+
struct OpenMPIORTEEnvDetector <: AbstractOMPIClusterEnvDetector end
42+
struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end
43+
3944
struct MPIEnvDetector <: AbstractClusterEnvDetector end
4045

4146
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py
@@ -48,7 +53,7 @@ function get_process_id end
4853
function get_local_process_id end
4954

5055
function auto_detect_unset_distributed_params(;
51-
detector_list=[MPIEnvDetector()],
56+
detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()],
5257
coordinator_address::Union{Nothing,String}=nothing,
5358
num_processes::Union{Nothing,Integer}=nothing,
5459
process_id::Union{Nothing,Integer}=nothing,
@@ -70,6 +75,8 @@ function auto_detect_unset_distributed_params(;
7075

7176
detector = detector_list[idx]
7277

78+
@debug "Detected cluster environment" detector
79+
7380
if coordinator_address === nothing
7481
coordinator_address = get_coordinator_address(
7582
detector, initialization_timeout_in_seconds

0 commit comments

Comments
 (0)