-
Notifications
You must be signed in to change notification settings - Fork 130
/
Copy pathdemo-horovod-mpi.py
99 lines (78 loc) · 2.64 KB
/
demo-horovod-mpi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
"""
Small demo for Horovod + MPI.
"""
import os
import sys
def _print_process_info():
    """
    Print the PID, the Python version and the sorted process environment.
    """
    print("pid %i: Hello" % os.getpid())
    print("Python version:", sys.version)
    print("Env:")
    for key, value in sorted(os.environ.items()):
        print("%s=%s" % (key, value))
    print()


def _print_hostfile():
    """
    Print the content of the file named by ``PE_HOSTFILE``, if that variable is set and non-empty.

    Best-effort: any error while reading is printed, not raised.
    """
    hostfile = os.environ.get("PE_HOSTFILE")
    if not hostfile:
        return
    try:
        print("PE_HOSTFILE, %s:" % hostfile)
        with open(hostfile, "r") as f:
            print(f.read())
    except Exception as exc:
        print(exc)


def _print_job_spool_dir():
    """
    List the directory named by ``SGE_JOB_SPOOL_DIR``, if that variable is set and non-empty.

    Best-effort: any error while listing is printed, not raised.
    """
    spool_dir = os.environ.get("SGE_JOB_SPOOL_DIR")
    if not spool_dir:
        return
    print("SGE_JOB_SPOOL_DIR, %s:" % spool_dir)
    try:
        for name in os.listdir(spool_dir):
            print(name)
        print()
    except Exception as exc:
        print(exc)


def _print_ompi_file_location():
    """
    List the directory two levels above ``OMPI_FILE_LOCATION`` and print its ``contact.txt``.

    Does nothing if the variable is unset or empty.
    Best-effort: any error while listing or reading is printed, not raised.
    """
    location = os.environ.get("OMPI_FILE_LOCATION")
    if not location:
        return
    print("OMPI_FILE_LOCATION, %s:" % location)
    # The directory of interest is the grandparent of the given location.
    base_dir = os.path.dirname(os.path.dirname(location))
    try:
        print("dir:", base_dir)
        for name in os.listdir(base_dir):
            print(name)
        print()
        print("contact.txt:")
        with open(os.path.join(base_dir, "contact.txt"), "r") as f:
            print(f.read())
        print()
    except Exception as exc:
        print(exc)


def _preload_hwloc():
    """
    Try to load ``libhwloc.so`` with ``RTLD_GLOBAL``; a failure is reported and ignored.

    See https://github.com/horovod/horovod/issues/1123
    """
    try:
        import ctypes

        ctypes.CDLL("libhwloc.so", mode=ctypes.RTLD_GLOBAL)
    except Exception as exc:
        print("Exception while loading libhwloc.so, ignoring...", exc)


def _print_sys_path():
    """
    Print every entry of ``sys.path``, one per line.
    """
    print("sys.path:")
    for entry in sys.path:
        print(entry)
    print()


def _print_mpi4py_info():
    """
    Print this process's MPI processor name, rank and world size, plus a rank -> name map of all ranks.

    If mpi4py cannot be imported, a notice is printed instead.
    Errors from the MPI calls themselves are not caught.
    """
    # Only the import is guarded, so that an ImportError raised later
    # is not misreported as "mpi4py not available".
    try:
        from mpi4py import MPI  # noqa
    except ImportError:
        print("mpi4py not available")
        return

    name = MPI.Get_processor_name()
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    print("mpi4py:", "name: %s," % name, "rank: %i," % rank, "size: %i" % comm.Get_size())
    hosts = comm.allgather((rank, name))  # Get the names of all the other hosts
    print(" all hosts:", dict(hosts))


def _print_horovod_info():
    """
    Import TensorFlow and Horovod, print their versions, initialize Horovod and print its rank info.

    :raises ImportError: if tensorflow or horovod cannot be imported
    """
    # Announced before the import; presumably so the log shows which step was reached.
    print("Import TF now...")
    import tensorflow as tf

    print("TF version:", tf.__version__)

    import horovod  # noqa

    print("Horovod version:", horovod.__version__)
    import horovod.tensorflow as hvd  # noqa

    # Initialize Horovod
    hvd.init()

    print(
        "pid %i: hvd: rank: %i, size: %i, local_rank %i, local_size %i"
        % (os.getpid(), hvd.rank(), hvd.size(), hvd.local_rank(), hvd.local_size())
    )


def main():
    """
    Main entry.

    Prints diagnostics to stdout in a fixed order:
    process info and environment,
    the files/directories named by ``PE_HOSTFILE``, ``SGE_JOB_SPOOL_DIR`` and ``OMPI_FILE_LOCATION``,
    ``sys.path``, mpi4py info (if importable),
    and finally the TensorFlow / Horovod versions and Horovod rank info.

    :raises ImportError: if tensorflow or horovod cannot be imported
    """
    _print_process_info()
    _print_hostfile()
    _print_job_spool_dir()
    _print_ompi_file_location()
    # Kept ahead of the mpi4py / Horovod imports, as in the original ordering (see the linked issue).
    _preload_hwloc()
    _print_sys_path()
    _print_mpi4py_info()
    _print_horovod_info()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported as a module.
    main()