@@ -3,9 +3,82 @@ module ReactantMPIExt
33using Reactant: Reactant, Distributed
44using MPI: MPI
55
6- # Code taken from https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
6+ # Code taken from:
7+ # 1. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
8+ # 2. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py
79
8- # XXX : Is this a good check??
10+ # Based on ompi_cluster
# Names of the environment variables consulted by the Open MPI detectors below.
# They are used verbatim as `ENV` keys, so they must not carry any surrounding
# whitespace.

# Variable holding the ORTE head-node URI.
const _ORTE_URI = "OMPI_MCA_orte_hnp_uri"

# Candidate variables holding the PMIx server URI. They are probed in this order and
# the first one present in `ENV` is used.
const _PMIX_SERVER_URI = (
    "PMIX_SERVER_URI2",
    "PMIX_SERVER_URI3",
    "PMIX_SERVER_URI4",
    "PMIX_SERVER_URI41",
    "PMIX_SERVER_URI21",
)

# Variables parsed as integers for the process count, process id and local process id.
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"
22+
# The ORTE detector applies when the `_ORTE_URI` variable is set in the environment.
function Distributed.is_env_present(::Distributed.OpenMPIORTEEnvDetector)
    return haskey(ENV, _ORTE_URI)
end
24+
# The PMIx detector applies when at least one of the `_PMIX_SERVER_URI` candidate
# variables is set in the environment.
function Distributed.is_env_present(::Distributed.OpenMPIPMIXEnvDetector)
    return any(name -> haskey(ENV, name), _PMIX_SERVER_URI)
end
28+
# Build the coordinator address `"<ip>:<port>"` from the ORTE head-node URI stored in
# `ENV[_ORTE_URI]`.
#
# The URI starts with the job id followed by a `.`; the launcher address follows as
# `tcp://<ip>[,<ip>...]:<port>` or `tcp6://[<ip>[,<ip>...]]:<port>`. The first listed
# IP is used. `timeout_in_seconds` is accepted for interface compatibility and is not
# used here.
#
# Throws `KeyError` if the variable is unset, `ArgumentError` if the job id is not an
# integer, and `ErrorException` if no launcher IP can be extracted.
function Distributed.get_coordinator_address(
    ::Distributed.OpenMPIORTEEnvDetector, timeout_in_seconds::Integer
)
    orte_uri = ENV[_ORTE_URI]

    raw_job_id = parse(Int, first(split(orte_uri, '.'; limit=2)))
    # ORTE job ids are multiples of 2^12 (as noted in the JAX `ompi_cluster`
    # implementation this is ported from). Without dividing first, the modulus below
    # would always be 0 and every job would pick the same port.
    job_id = div(raw_job_id, 2^12)
    # Pick a port in the range [65535 - 2^12 + 1, 65535].
    port = job_id % 2^12 + (65535 - 2^12 + 1)

    launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)

    # Explicit check rather than `@assert`: assertions may be disabled, and this is
    # validation of external input.
    if launcher_ip_match === nothing
        error("Could not parse coordinator IP address from Open MPI environment.")
    end

    # Exactly one alternative of the regex matched; take its capture group.
    launcher_ip = first(Iterators.filter(!isnothing, launcher_ip_match.captures))
    return "$(launcher_ip):$(port)"
end
47+
# Build the coordinator address `"<ip>:<port>"` from the PMIx server URI found in the
# first `_PMIX_SERVER_URI` candidate variable that is set.
#
# The job id is the integer between the last `-` and the first `@` of the URI, and the
# launcher address is read from the `tcp4://<ip>:` or `tcp6://[<ip>]` part.
# `timeout_in_seconds` is accepted for interface compatibility and is not used here.
#
# Throws `ErrorException` if none of the candidate variables is set or no launcher IP
# can be extracted, and `ArgumentError` if the job id is not an integer.
function Distributed.get_coordinator_address(
    ::Distributed.OpenMPIPMIXEnvDetector, timeout_in_seconds::Integer
)
    idx = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
    if idx === nothing
        error(
            "No PMIx server URI found in the environment (checked: " *
            join(_PMIX_SERVER_URI, ", ") *
            ").",
        )
    end
    pmix_uri = ENV[_PMIX_SERVER_URI[idx]]

    # Take everything before the first '@', then the segment after its last '-'.
    # Splitting from the right keeps this working when the hostname embedded in the
    # namespace itself contains hyphens.
    namespace = first(split(pmix_uri, '@'; limit=2))
    job_id = parse(Int, last(rsplit(namespace, '-'; limit=2)))
    # Pick a port in the range [65535 - 2^12 + 1, 65535].
    port = job_id % 2^12 + (65535 - 2^12 + 1)

    launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)

    # Explicit check rather than `@assert`: assertions may be disabled, and this is
    # validation of external input.
    if launcher_ip_match === nothing
        error("Could not parse coordinator IP address from Open MPI environment.")
    end

    # Exactly one alternative of the regex matched; take its capture group.
    launcher_ip = first(Iterators.filter(!isnothing, launcher_ip_match.captures))

    return "$(launcher_ip):$(port)"
end
68+
# Process count, parsed as an `Int` from the `_OMPI_PROCESS_COUNT` environment
# variable. Throws `KeyError` if it is unset and `ArgumentError` if it is not an
# integer.
function Distributed.get_process_count(::Distributed.AbstractOMPIClusterEnvDetector)
    world_size = ENV[_OMPI_PROCESS_COUNT]
    return parse(Int, world_size)
end
72+
# Process id, parsed as an `Int` from the `_OMPI_PROCESS_ID` environment variable.
# Throws `KeyError` if it is unset and `ArgumentError` if it is not an integer.
function Distributed.get_process_id(::Distributed.AbstractOMPIClusterEnvDetector)
    world_rank = ENV[_OMPI_PROCESS_ID]
    return parse(Int, world_rank)
end
76+
# Local process id, parsed as an `Int` from the `_OMPI_LOCAL_PROCESS_ID` environment
# variable. Throws `KeyError` if it is unset and `ArgumentError` if it is not an
# integer.
function Distributed.get_local_process_id(::Distributed.AbstractOMPIClusterEnvDetector)
    local_rank = ENV[_OMPI_LOCAL_PROCESS_ID]
    return parse(Int, local_rank)
end
80+
81+ # Based on mpi4py
# The generic MPI detector applies whenever `MPI.Initialized()` reports true.
function Distributed.is_env_present(::Distributed.MPIEnvDetector)
    return MPI.Initialized()
end
1083
1184function Distributed. get_coordinator_address (
0 commit comments