
Commit 3ff6ed3

Sparks0219 and dayshah authored
[core] Add release test to simulate network transient error via ip tables (#58241)
Signed-off-by: joshlee <joshlee@anyscale.com>
Co-authored-by: Dhyey Shah <dhyey2019@gmail.com>
1 parent dad8002 commit 3ff6ed3

File tree

2 files changed: +273 -2 lines changed

simulate_cross_az_network_failure.py

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
import argparse
import random
import subprocess
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import ray
import ray.util

# The goal of this script is to simulate cross-AZ transient network failures periodically on a Ray job.
# We do this by modifying the iptables to drop all inbound and outbound traffic for a given duration,
# except for intra-node and SSH traffic. After the duration, the iptables rules are restored.
# The failure script is run in a background thread while the main command is run in the foreground.
# NOTE: The script itself does not spin up a Ray cluster; it operates on the assumption that an existing
# Ray cluster is running and we are able to SSH into the nodes (like on Anyscale).

PARALLEL = 500  # concurrent SSH sessions
SSH_USER = "ubuntu"  # Anyscale default
AFFECT_WORKER_RATIO = 0.50  # failure affects 50% of worker nodes
EXTRA_SSH = [
    "-o",
    "BatchMode=yes",
    "-o",
    "StrictHostKeyChecking=accept-new",
    "-o",
    "ConnectTimeout=10",
]


def iptables_cmd(self_ip: str, seconds: int) -> str:
    return f"""\
nohup setsid bash -lc '
sudo iptables -w -A INPUT -p tcp --dport 22 -j ACCEPT
sudo iptables -w -A OUTPUT -p tcp --sport 22 -j ACCEPT
sudo iptables -w -A INPUT -s 127.0.0.0/8 -d 127.0.0.0/8 -j ACCEPT
sudo iptables -w -A OUTPUT -s 127.0.0.0/8 -d 127.0.0.0/8 -j ACCEPT
sudo iptables -w -A INPUT -s {self_ip} -d {self_ip} -j ACCEPT
sudo iptables -w -A OUTPUT -s {self_ip} -d {self_ip} -j ACCEPT
sudo iptables -w -A INPUT -j DROP
sudo iptables -w -A OUTPUT -j DROP
sleep {seconds}
sudo iptables -w -D OUTPUT -j DROP
sudo iptables -w -D INPUT -j DROP
sudo iptables -w -D OUTPUT -s {self_ip} -d {self_ip} -j ACCEPT
sudo iptables -w -D INPUT -s {self_ip} -d {self_ip} -j ACCEPT
sudo iptables -w -D OUTPUT -s 127.0.0.0/8 -d 127.0.0.0/8 -j ACCEPT
sudo iptables -w -D INPUT -s 127.0.0.0/8 -d 127.0.0.0/8 -j ACCEPT
sudo iptables -w -D OUTPUT -p tcp --sport 22 -j ACCEPT
sudo iptables -w -D INPUT -p tcp --dport 22 -j ACCEPT
' &>/dev/null &
"""


def ssh_run(ip: str, cmd: str) -> tuple[bool, str]:
    """Run SSH command on remote host."""
    target = f"{SSH_USER}@{ip}"
    res = subprocess.run(
        ["ssh", *EXTRA_SSH, target, cmd], capture_output=True, text=True
    )
    ok = res.returncode == 0
    msg = res.stdout.strip() if ok else (res.stderr.strip() or res.stdout.strip())
    return ok, msg


def simulate_cross_az_network_failure(seconds: int):
    if not ray.is_initialized():
        ray.init(address="auto")

    nodes = ray.nodes()
    all_ips = [n["NodeManagerAddress"] for n in nodes if n.get("Alive", False)]
    # Always inject failures on the head node (assumed to be the node this script runs on).
    head_ip = next(
        (
            n["NodeManagerAddress"]
            for n in nodes
            if n.get("NodeManagerAddress") == ray.util.get_node_ip_address()
        ),
        None,
    )

    print(f"Discovered {len(all_ips)} alive nodes")
    print(f"Head node: {head_ip}")

    worker_ips = [ip for ip in all_ips if ip != head_ip]
    print(f"Eligible worker nodes: {len(worker_ips)}")
    if not worker_ips:
        print("ERROR: No worker nodes found")
        return

    k = max(1, int(len(worker_ips) * AFFECT_WORKER_RATIO))
    affected = random.sample(worker_ips, k)
    # NOTE: When running this script on Anyscale with longer failure durations, the blacked-out
    # head node could cause your workspace to lag and die. To avoid this, comment out the line below.
    affected.append(head_ip)
    print(
        f"Affecting {len(affected)} nodes (~{AFFECT_WORKER_RATIO*100:.0f}% of workers + head node):"
    )
    print(", ".join(affected[:10]) + (" ..." if len(affected) > 10 else ""))

    cmds = {ip: iptables_cmd(ip, seconds) for ip in affected}

    print(f"\nTriggering {seconds}s of transient network failure...")
    successes, failures = [], {}

    with ThreadPoolExecutor(max_workers=PARALLEL) as ex:
        futs = {ex.submit(ssh_run, ip, cmds[ip]): ip for ip in affected}
        for fut in as_completed(futs):
            ip = futs[fut]
            try:
                ok, msg = fut.result()
                if ok:
                    successes.append(ip)
                else:
                    failures[ip] = msg
            except Exception as e:
                failures[ip] = str(e)

    print("\n=== Summary ===")
    print(f"Succeeded: {len(successes)} nodes")
    print(f"Failed   : {len(failures)} nodes")
    if failures:
        for ip, msg in list(failures.items()):
            print(f"  {ip}: {msg}")


def network_failure_loop(interval, network_failure_duration):
    """
    Run the network failure loop in a background thread at regular intervals.

    Args:
        interval: Interval in seconds between network failure events
        network_failure_duration: Duration in seconds of each network failure
    """
    print(
        f"[NETWORK FAILURE {time.strftime('%H:%M:%S')}] Starting network failure thread with interval: {interval} seconds"
    )

    while True:
        # Sleep for the interval duration
        time.sleep(interval)

        # Simulate a network failure
        print(
            f"[NETWORK FAILURE {time.strftime('%H:%M:%S')}] Triggering network failure simulation..."
        )
        try:
            simulate_cross_az_network_failure(network_failure_duration)
        except Exception as e:
            print(
                f"[NETWORK FAILURE {time.strftime('%H:%M:%S')}] ERROR: Network failure simulation failed: {e}"
            )


def parse_args():
    parser = argparse.ArgumentParser(
        description="Run benchmark with network failure injection at regular intervals",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run map_benchmark with network failures injected every 300 seconds, each lasting 5 seconds
  python simulate_cross_az_network_failure.py --network-failure-interval 300 --network-failure-duration 5 --command python map_benchmark.py --api map_batches --sf 1000
""",
    )
    parser.add_argument(
        "--network-failure-interval",
        type=int,
        required=True,
        help="Interval in seconds between network failure events",
    )
    parser.add_argument(
        "--network-failure-duration",
        type=int,
        required=True,
        help="Duration in seconds of each network failure",
    )
    parser.add_argument(
        "--command",
        nargs=argparse.REMAINDER,
        required=True,
        help="The main command to run (e.g., 'python map_benchmark.py --api map_batches ...')",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Validate command (argparse catches missing --command, but not empty --command)
    if not args.command:
        print("ERROR: --command requires at least one argument")
        print(
            "Usage: python simulate_cross_az_network_failure.py --network-failure-interval <seconds> --network-failure-duration <seconds> --command <command>"
        )
        sys.exit(1)

    print("=" * 80)
    print("Running with Network Failure Injection")
    print("=" * 80)
    print(f"Network failure interval: {args.network_failure_interval} seconds")
    print(f"Network failure duration: {args.network_failure_duration} seconds")
    print(f"Command: {' '.join(args.command)}")
    print("=" * 80)
    print()

    # Start network failure thread as daemon - it will die with the process
    network_failure_thread = threading.Thread(
        target=network_failure_loop,
        args=(args.network_failure_interval, args.network_failure_duration),
        daemon=True,
    )
    network_failure_thread.start()

    try:
        # Run the main command in the foreground
        print(
            f"[MAIN {time.strftime('%H:%M:%S')}] Starting command: {' '.join(args.command)}"
        )
        main_result = subprocess.run(args.command)
        print(
            f"\n[MAIN {time.strftime('%H:%M:%S')}] Command completed with exit code: {main_result.returncode}"
        )
        exit_code = main_result.returncode

    except KeyboardInterrupt:
        print("\n[MAIN] Interrupted by user")
        exit_code = 130

    except Exception as e:
        print(f"[MAIN] ERROR: {e}")
        exit_code = 1

    print("\n" + "=" * 80)
    print(f"Execution completed with exit code: {exit_code}")
    print("=" * 80)

    sys.exit(exit_code)


if __name__ == "__main__":
    main()
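
Editor's note: the injection above is fire-and-forget; ssh_run only confirms that the blackhole command was launched, not that the affected nodes rejoin the cluster afterwards. Below is a minimal sketch (not part of this commit) of how recovery could be verified between failure windows, reusing the same ray.nodes() fields the script already relies on; wait_for_nodes_alive is a hypothetical helper name.

import time

import ray


def wait_for_nodes_alive(expected_ips, timeout_s=300, poll_s=5):
    """Poll ray.nodes() until every IP in expected_ips is reported Alive again."""
    if not ray.is_initialized():
        ray.init(address="auto")
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        alive = {n["NodeManagerAddress"] for n in ray.nodes() if n.get("Alive", False)}
        missing = set(expected_ips) - alive
        if not missing:
            return True
        print(f"Still waiting on {len(missing)} nodes, e.g. {sorted(missing)[:3]}")
        time.sleep(poll_s)
    return False

Calling such a helper with the affected list after each blackout would turn the release test into a hard assertion on recovery time, at the cost of a longer runtime.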

release/release_data_tests.yaml

Lines changed: 31 additions & 2 deletions
@@ -653,9 +653,8 @@
 #################################################

 - name: "cross_az_map_batches_autoscaling"
-  frequency: nightly
+  frequency: manual
   env: gce
-
   cluster:
     cluster_compute: cross_az_250_350_compute_gce.yaml

@@ -690,3 +689,33 @@
 # - RAY_testing_rpc_failure="*=-1:10:10:1:1"
 # - RAY_testing_rpc_failure_avoid_intra_node_failures=1
 # cluster_compute: cross_az_250_350_compute_aws.yaml
+
+- name: "cross_az_map_batches_autoscaling_iptable_failure_injection"
+  frequency: nightly
+  env: gce
+  working_dir: nightly_tests
+
+  cluster:
+    byod:
+      runtime_env:
+        - RAY_health_check_period_ms=10000
+        - RAY_health_check_timeout_ms=100000
+        - RAY_health_check_failure_threshold=10
+        - RAY_gcs_rpc_server_connect_timeout_s=60
+    cluster_compute: dataset/cross_az_250_350_compute_gce.yaml
+
+  run:
+    timeout: 10800
+    # The network failure interval is set to 210 seconds since the test as is takes around double that to run without failures.
+    # If the runtime of the test is dramatically reduced in the future, the interval will have to be retuned.
+    script: >
+      python simulate_cross_az_network_failure.py --network-failure-interval 210 --network-failure-duration 5 --command python dataset/map_benchmark.py
+      --api map_batches --batch-format numpy --compute actors --sf 1000
+      --repeat-inputs 1 --concurrency 1024 2048
+
+  variations:
+    - __suffix__: gce
+    - __suffix__: aws
+      env: aws
+      cluster:
+        cluster_compute: dataset/cross_az_250_350_compute_aws.yaml
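
Editor's note: the runtime_env overrides in this test are what keep a 5-second blackout "transient" from the cluster's point of view. A rough sketch of the margin, under the assumption that a node is only marked dead after RAY_health_check_failure_threshold consecutive missed checks spaced RAY_health_check_period_ms apart (the exact accounting inside the GCS may differ):

# Back-of-the-envelope margin check for the health-check overrides above.
# Assumption: a node is declared dead only after `failure_threshold` consecutive
# missed health checks, one attempt per `period`.
health_check_period_s = 10_000 / 1000  # RAY_health_check_period_ms=10000
failure_threshold = 10                 # RAY_health_check_failure_threshold=10
blackout_s = 5                         # --network-failure-duration in the test above

time_to_mark_dead_s = health_check_period_s * failure_threshold  # ~100 s
print(f"blackout={blackout_s}s, node marked dead after ~{time_to_mark_dead_s:.0f}s")
assert blackout_s < time_to_mark_dead_s, "outage would no longer look transient"

If the failure duration were ever raised close to that budget, the health-check overrides (or the duration) would need to be retuned together.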
