test: simplify test a little
JoanFM committed Nov 6, 2024
1 parent 790c158 commit 22ffc84
Showing 3 changed files with 126 additions and 124 deletions.
4 changes: 2 additions & 2 deletions tests/integration/docarray_v2/test_v2.py
@@ -1399,10 +1399,10 @@ def search(


@pytest.mark.parametrize(
-    'protocols', [['grpc'], ['http'], ['websocket'], ['grpc', 'http', 'websocket']]
+    'protocols', [['grpc'], ['http'], ['websocket']]
)
@pytest.mark.parametrize('reduce', [True, False])
-@pytest.mark.parametrize('sleep_time', [0.1, 5])
+@pytest.mark.parametrize('sleep_time', [5])
def test_flow_with_shards_all_shards_return(protocols, reduce, sleep_time):
from typing import List

231 changes: 116 additions & 115 deletions tests/integration/network_failures/test_network_failures.py
@@ -100,13 +100,96 @@ def _test_error(gateway_port, error_ports, protocol):
assert str(port) in err_info.value.args[0]


@pytest.mark.parametrize('protocol', ['grpc', 'http'])
@pytest.mark.parametrize('fail_endpoint_discovery', [True, False])
@pytest.mark.asyncio
async def test_runtimes_reconnect(port_generator, protocol, fail_endpoint_discovery):
# create gateway and workers manually, then terminate worker process to provoke an error
worker_port = port_generator()
gateway_port = port_generator()
graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'
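    # topology: gateway -> pod0 -> gateway, with pod0 served by a single worker on worker_port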

gateway_process = _create_gateway(
gateway_port, graph_description, pod_addresses, protocol
)

BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{gateway_port}',
ready_or_shutdown_event=multiprocessing.Event(),
)
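    # wait (up to 5s) for the gateway to report ready before exercising it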

try:
if fail_endpoint_discovery:
# send request while Executor is not UP, WILL FAIL
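            # run the client in a separate process so a failed request surfaces as a
            # nonzero exitcode (per the note elsewhere in this file, reusing grpc in
            # this process would crash it)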
p = multiprocessing.Process(
target=_send_request, args=(gateway_port, protocol)
)
p.start()
p.join()
assert p.exitcode != 0, f"The _send_request #0 Process exited with exitcode {p.exitcode}" # The request will fail and raise

worker_process = _create_worker(worker_port)
assert BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{worker_port}',
ready_or_shutdown_event=multiprocessing.Event(),
), "The BaseServer wait_for_ready_or_shutdown for worker_port failed"

p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
assert p.exitcode == 0, f"The _send_request #1 Process exited with exitcode {p.exitcode}" # The request will not fail and raise
worker_process.terminate() # kill worker
worker_process.join()
assert not worker_process.is_alive()

# send request while Executor is not UP, WILL FAIL
p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
assert p.exitcode != 0, f"The _send_request #2 Process exited with exitcode {p.exitcode}" # The request will not fail and rais

worker_process = _create_worker(worker_port)

assert BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{worker_port}',
ready_or_shutdown_event=multiprocessing.Event(),
), "The BaseServer wait_for_ready_or_shutdown for worker_port failed"
p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
assert (
p.exitcode == 0
), f"The _send_request #3 Process exited with exitcode {p.exitcode}" # The request will not fail and rais # if exitcode != 0 then test in other process did not pass and this should fail
# ----------- 2. test that gateways remain alive -----------
# just do the same again, expecting the same failure
worker_process.terminate() # kill worker
worker_process.join()
assert not worker_process.is_alive(), "Worker process is still alive"
assert (
worker_process.exitcode == 0
), f"The worker_process Process exited with exitcode {worker_process.exitcode}" # if exitcode != 0 then test in other process did not pass and this should fail

except Exception as exc:
print(f'===> Exception: {exc}')
assert False
finally: # clean up runtimes
gateway_process.terminate()
gateway_process.join()
worker_process.terminate()
worker_process.join()


@pytest.mark.parametrize(
'fail_before_endpoint_discovery', [True, False]
) # if not before, then after
@pytest.mark.parametrize('protocol', ['http', 'websocket', 'grpc'])
@pytest.mark.asyncio
async def test_runtimes_headless_topology(
    port_generator, protocol, fail_before_endpoint_discovery
):
# create gateway and workers manually, then terminate worker process to provoke an error
worker_port = port_generator()
@@ -134,7 +217,7 @@ async def test_runtimes_headless_topology(
)

if (
        fail_before_endpoint_discovery
): # kill worker before having sent the first request, so before endpoint discov.
worker_process.terminate()
worker_process.join()
@@ -150,7 +233,7 @@ async def test_runtimes_headless_topology(
p.start()
p.join()
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
else:
# just ping the Flow without having killed a worker before. This (also) performs endpoint discovery
@@ -172,7 +255,7 @@ async def test_runtimes_headless_topology(
p.start()
p.join()
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
except Exception:
assert False
@@ -236,90 +319,8 @@ async def patch_process_data(self, requests_, context, **kwargs):
p.start()
p.join()
assert (
p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
except Exception:
assert False
finally: # clean up runtimes
gateway_process.terminate()
gateway_process.join()
worker_process.terminate()
worker_process.join()


@pytest.mark.parametrize('protocol', ['grpc', 'http', 'grpc'])
@pytest.mark.parametrize('fail_endpoint_discovery', [True, False])
@pytest.mark.asyncio
async def test_runtimes_reconnect(port_generator, protocol, fail_endpoint_discovery):
# create gateway and workers manually, then terminate worker process to provoke an error
worker_port = port_generator()
gateway_port = port_generator()
graph_description = '{"start-gateway": ["pod0"], "pod0": ["end-gateway"]}'
pod_addresses = f'{{"pod0": ["0.0.0.0:{worker_port}"]}}'

gateway_process = _create_gateway(
gateway_port, graph_description, pod_addresses, protocol
)

BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{gateway_port}',
ready_or_shutdown_event=multiprocessing.Event(),
)

try:
if fail_endpoint_discovery:
# send request while Executor is not UP, WILL FAIL
p = multiprocessing.Process(
target=_send_request, args=(gateway_port, protocol)
)
p.start()
p.join()
assert p.exitcode != 0 # The request will fail and raise

worker_process = _create_worker(worker_port)
assert BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{worker_port}',
ready_or_shutdown_event=multiprocessing.Event(),
)

p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
assert p.exitcode == 0 # The request will not fail and raise
worker_process.terminate() # kill worker
worker_process.join()
assert not worker_process.is_alive()

# send request while Executor is not UP, WILL FAIL
p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
assert p.exitcode != 0

worker_process = _create_worker(worker_port)

assert BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{worker_port}',
ready_or_shutdown_event=multiprocessing.Event(),
)
p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
assert (
p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
# ----------- 2. test that gateways remain alive -----------
# just do the same again, expecting the same failure
worker_process.terminate() # kill worker
worker_process.join()
assert not worker_process.is_alive()
assert (
worker_process.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail

except Exception:
assert False
finally: # clean up runtimes
@@ -329,11 +330,11 @@ async def test_runtimes_reconnect(port_generator, protocol, fail_endpoint_discovery):
worker_process.join()


-@pytest.mark.parametrize('protocol', ['grpc', 'http', 'grpc'])
+@pytest.mark.parametrize('protocol', ['grpc', 'http'])
@pytest.mark.parametrize('fail_endpoint_discovery', [True, False])
@pytest.mark.asyncio
async def test_runtimes_reconnect_replicas(
    port_generator, protocol, fail_endpoint_discovery
):
# create gateway and workers manually, then terminate worker process to provoke an error
worker_ports = [port_generator() for _ in range(3)]
@@ -367,7 +368,7 @@ async def test_runtimes_reconnect_replicas(
p_first_check.start()
p_first_check.join()
assert (
        p_first_check.exitcode == 0
) # all replicas are connected. At the end, the Flow should return to this state.

worker_processes[1].terminate() # kill 'middle' worker
@@ -424,7 +425,7 @@ async def test_runtimes_reconnect_replicas(
@pytest.mark.parametrize('fail_before_endpoint_discovery', [True, False])
@pytest.mark.asyncio
async def test_runtimes_replicas(
    port_generator, protocol, fail_before_endpoint_discovery
):
# create gateway and workers manually, then terminate worker process to provoke an error
worker_ports = [port_generator() for _ in range(3)]
@@ -453,23 +454,23 @@ async def test_runtimes_replicas(
)

if (
        not fail_before_endpoint_discovery
): # make successful request and trigger endpoint discovery
# we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
p = multiprocessing.Process(target=_send_request, args=(gateway_port, protocol))
p.start()
p.join()
# different replica should be picked, no error should be raised
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail

worker_processes[0].terminate() # kill first worker
worker_processes[0].join()

try:
for _ in range(
            len(worker_ports)
): # make sure all workers are targeted by round robin
# ----------- 1. test that useful errors are given -----------
# we have to do this in a new process because otherwise grpc will be sad and everything will crash :(
@@ -480,7 +481,7 @@ async def test_runtimes_replicas(
p.join()
# different replica should be picked, no error should be raised
assert (
                p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
except Exception:
assert False
@@ -555,7 +556,7 @@ async def test_runtimes_headful_topology(port_generator, protocol, terminate_hea
p.start()
p.join()
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
# ----------- 2. test that gateways remain alive -----------
# just do the same again, expecting the same outcome
Expand All @@ -565,7 +566,7 @@ async def test_runtimes_headful_topology(port_generator, protocol, terminate_hea
p.start()
p.join()
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
except Exception:
raise
@@ -581,8 +582,8 @@ async def test_runtimes_headful_topology(port_generator, protocol, terminate_hea
def _send_gql_request(gateway_port):
"""send request to gateway and see what happens"""
mutation = (
f'mutation {{'
+ '''docs(data: {text: "abcd"}) {
id
}
}
Expand All @@ -601,20 +602,20 @@ def _test_gql_error(gateway_port, error_port):

def _create_gqlgateway_runtime(graph_description, pod_addresses, port):
with AsyncNewLoopRuntime(
set_gateway_parser().parse_args(
[
'--graph-description',
graph_description,
'--deployments-addresses',
pod_addresses,
'--port',
str(port),
'--expose-graphql-endpoint',
'--protocol',
'http',
]
),
req_handler_cls=GatewayRequestHandler,
) as runtime:
runtime.run_forever()

@@ -666,7 +667,7 @@ async def test_runtimes_graphql(port_generator):
p.start()
p.join()
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
# ----------- 2. test that gateways remain alive -----------
# just do the same again, expecting the same outcome
@@ -676,7 +677,7 @@ async def test_runtimes_graphql(port_generator):
p.start()
p.join()
assert (
            p.exitcode == 0
) # if exitcode != 0 then test in other process did not pass and this should fail
except Exception:
raise

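For reference, the `_send_request` helper that these tests repeatedly spawn in child processes is defined elsewhere in test_network_failures.py and is not part of this diff. A minimal sketch of what such a helper looks like, assuming the stock Jina Client API; the payload and parameters here are illustrative, not the repository's actual code:

    def _send_request(gateway_port, protocol):
        """Send one request to the gateway; raise if it is unreachable."""
        from jina import Client, Document  # assumed imports

        client = Client(host='localhost', port=gateway_port, protocol=protocol)
        # a failure here raises inside the child process, which the parent
        # observes as a nonzero exitcode
        client.post('/', inputs=[Document(text='hi')], request_size=1)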