Skip to content

Commit

Permalink
test: add test proving readinessProbe can pass while processing
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
  • Loading branch information
JoanFM committed Nov 21, 2022
1 parent bbc25ae commit 308a3fa
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 46 deletions.
41 changes: 21 additions & 20 deletions jina/serve/runtimes/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def _async_setup_grpc_server(self):
reflection.SERVICE_NAME,
)
# Mark all services as healthy.
health_pb2_grpc.add_HealthServicer_to_server(self, self._grpc_server)
health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self._grpc_server)

for service in service_names:
self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
Expand Down Expand Up @@ -275,22 +275,23 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
infoProto.envs[k] = str(v)
return infoProto

async def Check(
self, request: health_pb2.HealthCheckRequest, context
) -> health_pb2.HealthCheckResponse:
'''Calls the underlying HealthServicer.Check method with the same arguments
:param request: grpc request
:param context: grpc request context
:returns: the grpc HealthCheckResponse
'''
return self._health_servicer.Check(request, context)

async def Watch(
self, request: health_pb2.HealthCheckRequest, context
) -> health_pb2.HealthCheckResponse:
'''Calls the underlying HealthServicer.Watch method with the same arguments
:param request: grpc request
:param context: grpc request context
:returns: the grpc HealthCheckResponse
'''
return self._health_servicer.Watch(request, context)
# async def Check(
# self, request: health_pb2.HealthCheckRequest, context
# ) -> health_pb2.HealthCheckResponse:
# '''Calls the underlying HealthServicer.Check method with the same arguments
# :param request: grpc request
# :param context: grpc request context
# :returns: the grpc HealthCheckResponse
# '''
# print(f' CHECK REQUEST')
# return self._health_servicer.Check(request, context)
#
# async def Watch(
# self, request: health_pb2.HealthCheckRequest, context
# ) -> health_pb2.HealthCheckResponse:
# '''Calls the underlying HealthServicer.Watch method with the same arguments
# :param request: grpc request
# :param context: grpc request context
# :returns: the grpc HealthCheckResponse
# '''
# return self._health_servicer.Watch(request, context)
102 changes: 76 additions & 26 deletions tests/integration/runtimes/test_runtimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def complete_graph_dict():
@pytest.mark.parametrize('uses_after', [True, False])
# test gateway, head and worker runtime by creating them manually in a more Flow like topology with branching/merging
async def test_runtimes_flow_topology(
complete_graph_dict, uses_before, uses_after, port_generator
complete_graph_dict, uses_before, uses_after, port_generator
):
pods = [
pod_name for pod_name in complete_graph_dict.keys() if 'gateway' not in pod_name
Expand Down Expand Up @@ -443,7 +443,7 @@ async def test_runtimes_with_executor(port_generator):

assert len(response_list) == 20
assert (
len(response_list[0]) == (1 + 1 + 1) * 10 + 1
len(response_list[0]) == (1 + 1 + 1) * 10 + 1
) # 1 starting doc + 1 uses_before + every exec adds 1 * 10 shards + 1 doc uses_after

doc_texts = [doc.text for doc in response_list[0]]
Expand Down Expand Up @@ -693,13 +693,13 @@ def _create_worker_runtime(port, name='', executor=None):


def _create_head_runtime(
port,
connection_list_dict,
name='',
polling='ANY',
uses_before=None,
uses_after=None,
retries=-1,
port,
connection_list_dict,
name='',
polling='ANY',
uses_before=None,
uses_after=None,
retries=-1,
):
args = set_pod_parser().parse_args([])
args.port = port
Expand All @@ -717,23 +717,23 @@ def _create_head_runtime(


def _create_gateway_runtime(
graph_description, pod_addresses, port, protocol='grpc', retries=-1
graph_description, pod_addresses, port, protocol='grpc', retries=-1
):
with GatewayRuntime(
set_gateway_parser().parse_args(
[
'--graph-description',
graph_description,
'--deployments-addresses',
pod_addresses,
'--port',
str(port),
'--retries',
str(retries),
'--protocol',
protocol,
]
)
set_gateway_parser().parse_args(
[
'--graph-description',
graph_description,
'--deployments-addresses',
pod_addresses,
'--port',
str(port),
'--retries',
str(retries),
'--protocol',
protocol,
]
)
) as runtime:
runtime.run_forever()

Expand Down Expand Up @@ -784,8 +784,8 @@ async def test_head_runtime_with_offline_shards(port_generator):
)

with grpc.insecure_channel(
f'0.0.0.0:{head_port}',
options=GrpcConnectionPool.get_default_grpc_options(),
f'0.0.0.0:{head_port}',
options=GrpcConnectionPool.get_default_grpc_options(),
) as channel:
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
_, call = stub.process_single_data.with_call(
Expand All @@ -804,3 +804,53 @@ async def test_head_runtime_with_offline_shards(port_generator):
head_process.join()
for shard_process in shard_processes:
shard_process.join()


def test_runtime_slow_processing_readiness(port_generator):
class SlowProcessingExecutor(Executor):
@requests
def foo(self, **kwargs):
time.sleep(10)

worker_port = port_generator()
# create a single worker runtime
worker_process = multiprocessing.Process(
target=_create_worker_runtime, args=(worker_port, f'pod0', 'SlowProcessingExecutor')
)
try:
worker_process.start()
AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
timeout=5.0,
ctrl_address=f'0.0.0.0:{worker_port}',
ready_or_shutdown_event=multiprocessing.Event(),
)

def _send_messages():
with grpc.insecure_channel(
f'0.0.0.0:{worker_port}',
options=GrpcConnectionPool.get_default_grpc_options(),
) as channel:
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel)
resp, _ = stub.process_single_data.with_call(
list(request_generator('/', DocumentArray([Document(text='abc')])))[0]
)
assert resp.docs[0].text == 'abc'

send_message_process = multiprocessing.Process(
target=_send_messages
)
send_message_process.start()

for _ in range(50):
is_ready = WorkerRuntime.is_ready(f'0.0.0.0:{worker_port}')
assert is_ready
time.sleep(0.5)
except Exception:
raise
finally:
worker_process.terminate()
send_message_process.terminate()
worker_process.join()
send_message_process.join()
assert worker_process.exitcode == 0
assert send_message_process.exitcode == 0

0 comments on commit 308a3fa

Please sign in to comment.