forked from AI-Hypercomputer/maxtext
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgpu_multi_process_run.sh
152 lines (136 loc) · 4.87 KB
/
gpu_multi_process_run.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#! /bin/bash
#
# Runs ${COMMAND} as a background process after exporting the NCCL settings
# selected by USE_GPUDIRECT and resolving the JAX coordinator address.
#
# Required environment: NNODES, NODE_RANK, JAX_COORDINATOR_PORT,
# JAX_COORDINATOR_ADDRESS, GPUS_PER_NODE, COMMAND.

# Strict mode: stop on errors, on unset variables and on failing pipelines.
set -euo pipefail

# Abort immediately, naming the variable, when a required input is missing
# or empty.
: "${NNODES:?Must set NNODES}"
: "${NODE_RANK:?Must set NODE_RANK}"
: "${JAX_COORDINATOR_PORT:?Must set JAX_COORDINATOR_PORT}"
: "${JAX_COORDINATOR_ADDRESS:?Must set JAX_COORDINATOR_ADDRESS}"
: "${GPUS_PER_NODE:?Must set GPUS_PER_NODE}"
: "${COMMAND:?Must set COMMAND}"

# Make these visible to child processes even if they arrived as plain shell
# variables rather than environment variables.
export GPUS_PER_NODE JAX_COORDINATOR_PORT JAX_COORDINATOR_ADDRESS
#######################################
# Export the NCCL/XLA environment for the GPUDirect transport named by
# USE_GPUDIRECT.
# Globals:
#   USE_GPUDIRECT (read)      - "tcpx" or "fastrak"; any other value, or an
#                               unset variable, leaves the environment as is.
#   LD_LIBRARY_PATH (written) - the transport's library directory is appended.
#   CUDA_*, NCCL_*, NVTE_*, TF_CPP_*, XLA_* (written and exported).
# Outputs:
#   One line on stdout naming the selected mode.
# Returns:
#   0
#######################################
set_nccl_gpudirect_tcpx_specific_configuration() {
  # The script runs under `set -u`; default to empty so that an unset
  # USE_GPUDIRECT selects the "not using" path instead of aborting.
  local gpudirect_mode=${USE_GPUDIRECT:-}

  if [[ "$gpudirect_mode" != "tcpx" && "$gpudirect_mode" != "fastrak" ]]; then
    echo "NOT using GPUDirect"
    return 0
  fi

  # Settings shared by both transports.
  export CUDA_DEVICE_MAX_CONNECTIONS=1
  export NCCL_CROSS_NIC=0
  export NCCL_DEBUG=INFO
  export NCCL_DYNAMIC_CHUNK_SIZE=524288
  export NCCL_NET_GDR_LEVEL=PIX
  export NCCL_NVLS_ENABLE=0
  export NCCL_P2P_NET_CHUNKSIZE=524288
  export NCCL_P2P_NVL_CHUNKSIZE=1048576
  export NCCL_P2P_PCI_CHUNKSIZE=524288
  export NCCL_PROTO=Simple
  export NCCL_SOCKET_IFNAME=eth0
  export NVTE_FUSED_ATTN=1
  export TF_CPP_MAX_LOG_LEVEL=100
  export TF_CPP_VMODULE=profile_guided_latency_estimator=10
  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85

  # ${VAR:+${VAR}:} emits "<old value>:" only when LD_LIBRARY_PATH is set and
  # non-empty. That avoids both the `set -u` abort on an unset variable and a
  # leading ":" (an empty search-path entry) when it is empty.
  case "$gpudirect_mode" in
    tcpx)
      echo "Using GPUDirect-TCPX"
      export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/tcpx/lib64"
      export NCCL_ALGO=Ring
      export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV,TUNING,NET,VERSION
      export NCCL_GPUDIRECTTCPX_CTRL_DEV=eth0
      export NCCL_GPUDIRECTTCPX_FORCE_ACK=0
      export NCCL_GPUDIRECTTCPX_PROGRAM_FLOW_STEERING_WAIT_MICROS=1000000
      export NCCL_GPUDIRECTTCPX_RX_BINDINGS="eth1:22-35,124-139;eth2:22-35,124-139;eth3:74-87,178-191;eth4:74-87,178-191"
      export NCCL_GPUDIRECTTCPX_SOCKET_IFNAME=eth1,eth2,eth3,eth4
      export NCCL_GPUDIRECTTCPX_TX_BINDINGS="eth1:8-21,112-125;eth2:8-21,112-125;eth3:60-73,164-177;eth4:60-73,164-177"
      export NCCL_GPUDIRECTTCPX_TX_COMPLETION_NANOSLEEP=1000
      export NCCL_MAX_NCHANNELS=12
      export NCCL_MIN_NCHANNELS=12
      export NCCL_NSOCKS_PERTHREAD=4
      export NCCL_P2P_PXN_LEVEL=0
      export NCCL_SOCKET_NTHREADS=1
      ;;
    fastrak)
      echo "Using GPUDirect-TCPFasTrak"
      export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
      export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/tcpxo/lib64"
      export NCCL_ALGO=Ring,Tree
      export NCCL_BUFFSIZE=8388608
      export NCCL_FASTRAK_CTRL_DEV=eth0
      export NCCL_FASTRAK_ENABLE_CONTROL_CHANNEL=0
      export NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING=0
      export NCCL_FASTRAK_IFNAME=eth1,eth2,eth3,eth4,eth5,eth6,eth7,eth8
      export NCCL_FASTRAK_NUM_FLOWS=2
      export NCCL_FASTRAK_USE_LLCM=1
      export NCCL_FASTRAK_USE_SNAP=1
      export NCCL_MIN_NCHANNELS=4
      export NCCL_TUNER_CONFIG_PATH=/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
      export NCCL_TUNER_PLUGIN=libnccl-tuner.so
      ;;
  esac
}
# LD_LIBRARY_PATH may legitimately be unset; default it to empty so that
# `set -u` does not abort the script on this log line.
echo "LD_LIBRARY_PATH ${LD_LIBRARY_PATH:-}"
set_nccl_gpudirect_tcpx_specific_configuration
#######################################
# Poll a set of background PIDs until all of them have finished successfully.
# Exits the whole script as soon as any of them is found to have failed.
# Arguments:
#   $@ - PIDs to watch.
# Outputs:
#   "All pids succeeded" on success, or a line naming the failed PID and its
#   exit code, on stdout.
# Returns:
#   0 when every PID finished with status 0 (or no PID was given); otherwise
#   the script exits with the failing PID's exit code.
#######################################
wait_all_success_or_exit() {
  # https://www.baeldung.com/linux/background-process-get-exit-code
  local pids=("$@")
  # Keep the loop state local so the caller's globals are not clobbered.
  local pid code all_success

  while [[ ${#pids[@]} -ne 0 ]]; do
    all_success="true"
    for pid in "${pids[@]}"; do
      # non_blocking_wait prints 127 while the process is still running.
      code=$(non_blocking_wait "$pid")
      if [[ $code -ne 127 ]]; then
        if [[ $code -ne 0 ]]; then
          echo "PID $pid failed with exit code $code"
          exit "$code"
        fi
      else
        all_success="false"
      fi
    done
    if [[ $all_success == "true" ]]; then
      echo "All pids succeeded"
      break
    fi
    # At least one process is still running; poll again shortly.
    sleep 5
  done
}
#######################################
# Report the state of a background process without blocking on it.
# Arguments:
#   $1 - PID to inspect.
# Outputs:
#   Writes one number to stdout: 127 while /proc/<pid> still exists,
#   otherwise the status returned by `wait <pid>`.
# Returns:
#   0
#######################################
non_blocking_wait() {
  # https://www.baeldung.com/linux/background-process-get-exit-code
  local pid=$1
  local code=127 # special code to indicate not-finished
  if [[ ! -d "/proc/${pid}" ]]; then
    # Capture the status through `||` so that an active errexit (direct call,
    # or `shopt -s inherit_errexit`) cannot abort before the code is printed.
    code=0
    wait "${pid}" || code=$?
  fi
  echo "${code}"
}
#######################################
# Resolve JAX_COORDINATOR_ADDRESS with nslookup, retrying until an address is
# reported or the attempt budget is used up.
# Globals:
#   JAX_COORDINATOR_ADDRESS (read)  - name handed to nslookup.
#   JAX_COORDINATOR_IP (exported)   - first reported address, on success only.
# Arguments:
#   $1 - maximum number of lookups (optional, default 500).
#   $2 - pause between lookups, passed to `sleep` (optional, default 1).
# Outputs:
#   Progress messages on stdout.
# Returns:
#   0 once an address was found, 1 when every attempt failed.
#######################################
resolve_coordinator_ip() {
  local max_coordinator_lookups=${1:-500}
  local retry_delay=${2:-1}
  local lookup_attempt
  local coordinator_ip_address

  echo "Coordinator Address $JAX_COORDINATOR_ADDRESS"
  for ((lookup_attempt = 1; lookup_attempt <= max_coordinator_lookups; lookup_attempt++)); do
    # Take the first "Address: <ip>" line of the nslookup output. A failed
    # lookup is an expected, retryable outcome, so its status is deliberately
    # discarded; without `|| true` an active errexit + pipefail would abort
    # here instead of retrying.
    coordinator_ip_address=$(
      nslookup "$JAX_COORDINATOR_ADDRESS" 2>/dev/null \
        | awk '/^Address: / { print $2 }' \
        | head -n 1
    ) || true

    if [[ -n "$coordinator_ip_address" ]]; then
      echo "Coordinator IP address: $coordinator_ip_address"
      export JAX_COORDINATOR_IP=$coordinator_ip_address
      return 0
    fi

    echo "Failed to recognize coordinator address $JAX_COORDINATOR_ADDRESS on attempt $lookup_attempt, retrying..."
    sleep "$retry_delay"
  done

  # Only reached when no attempt produced an address.
  echo "Failed to resolve coordinator address after $max_coordinator_lookups attempts."
  return 1
}
# Resolve the coordinator IP. errexit is suspended around the call, so a
# failed lookup does not stop the launch.
set +e
resolve_coordinator_ip
set -e

PIDS=()
# COMMAND is quoted so that eval receives the string exactly as supplied;
# unquoted, it would be word-split and glob-expanded before eval re-parses it.
# NOTE(review): eval executes COMMAND as shell code; it must come from a
# trusted launcher.
eval "${COMMAND}" &
PID=$!
PIDS+=("$PID")
wait_all_success_or_exit "${PIDS[@]}"