diff --git a/jina/serve/runtimes/worker/__init__.py b/jina/serve/runtimes/worker/__init__.py index 4e7764935a792..29f79fd0195aa 100644 --- a/jina/serve/runtimes/worker/__init__.py +++ b/jina/serve/runtimes/worker/__init__.py @@ -138,7 +138,7 @@ async def _async_setup_grpc_server(self): reflection.SERVICE_NAME, ) # Mark all services as healthy. - health_pb2_grpc.add_HealthServicer_to_server(self, self._grpc_server) + health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self._grpc_server) for service in service_names: self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) @@ -275,22 +275,23 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto: infoProto.envs[k] = str(v) return infoProto - async def Check( - self, request: health_pb2.HealthCheckRequest, context - ) -> health_pb2.HealthCheckResponse: - '''Calls the underlying HealthServicer.Check method with the same arguments - :param request: grpc request - :param context: grpc request context - :returns: the grpc HealthCheckResponse - ''' - return self._health_servicer.Check(request, context) - - async def Watch( - self, request: health_pb2.HealthCheckRequest, context - ) -> health_pb2.HealthCheckResponse: - '''Calls the underlying HealthServicer.Watch method with the same arguments - :param request: grpc request - :param context: grpc request context - :returns: the grpc HealthCheckResponse - ''' - return self._health_servicer.Watch(request, context) + # async def Check( + # self, request: health_pb2.HealthCheckRequest, context + # ) -> health_pb2.HealthCheckResponse: + # '''Calls the underlying HealthServicer.Check method with the same arguments + # :param request: grpc request + # :param context: grpc request context + # :returns: the grpc HealthCheckResponse + # ''' + # print(f' CHECK REQUEST') + # return self._health_servicer.Check(request, context) + # + # async def Watch( + # self, request: health_pb2.HealthCheckRequest, context + # ) -> health_pb2.HealthCheckResponse: + # '''Calls the underlying HealthServicer.Watch method with the same arguments + # :param request: grpc request + # :param context: grpc request context + # :returns: the grpc HealthCheckResponse + # ''' + # return self._health_servicer.Watch(request, context) diff --git a/tests/integration/runtimes/test_runtimes.py b/tests/integration/runtimes/test_runtimes.py index 6337a5df7f4ee..d2b28055dd954 100644 --- a/tests/integration/runtimes/test_runtimes.py +++ b/tests/integration/runtimes/test_runtimes.py @@ -115,7 +115,7 @@ def complete_graph_dict(): @pytest.mark.parametrize('uses_after', [True, False]) # test gateway, head and worker runtime by creating them manually in a more Flow like topology with branching/merging async def test_runtimes_flow_topology( - complete_graph_dict, uses_before, uses_after, port_generator + complete_graph_dict, uses_before, uses_after, port_generator ): pods = [ pod_name for pod_name in complete_graph_dict.keys() if 'gateway' not in pod_name @@ -443,7 +443,7 @@ async def test_runtimes_with_executor(port_generator): assert len(response_list) == 20 assert ( - len(response_list[0]) == (1 + 1 + 1) * 10 + 1 + len(response_list[0]) == (1 + 1 + 1) * 10 + 1 ) # 1 starting doc + 1 uses_before + every exec adds 1 * 10 shards + 1 doc uses_after doc_texts = [doc.text for doc in response_list[0]] @@ -693,13 +693,13 @@ def _create_worker_runtime(port, name='', executor=None): def _create_head_runtime( - port, - connection_list_dict, - name='', - polling='ANY', - uses_before=None, - uses_after=None, - retries=-1, + port, + connection_list_dict, + name='', + polling='ANY', + uses_before=None, + uses_after=None, + retries=-1, ): args = set_pod_parser().parse_args([]) args.port = port @@ -717,23 +717,23 @@ def _create_head_runtime( def _create_gateway_runtime( - graph_description, pod_addresses, port, protocol='grpc', retries=-1 + graph_description, pod_addresses, port, protocol='grpc', retries=-1 ): with GatewayRuntime( - set_gateway_parser().parse_args( - [ - '--graph-description', - graph_description, - '--deployments-addresses', - pod_addresses, - '--port', - str(port), - '--retries', - str(retries), - '--protocol', - protocol, - ] - ) + set_gateway_parser().parse_args( + [ + '--graph-description', + graph_description, + '--deployments-addresses', + pod_addresses, + '--port', + str(port), + '--retries', + str(retries), + '--protocol', + protocol, + ] + ) ) as runtime: runtime.run_forever() @@ -784,8 +784,8 @@ async def test_head_runtime_with_offline_shards(port_generator): ) with grpc.insecure_channel( - f'0.0.0.0:{head_port}', - options=GrpcConnectionPool.get_default_grpc_options(), + f'0.0.0.0:{head_port}', + options=GrpcConnectionPool.get_default_grpc_options(), ) as channel: stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) _, call = stub.process_single_data.with_call( @@ -804,3 +804,53 @@ async def test_head_runtime_with_offline_shards(port_generator): head_process.join() for shard_process in shard_processes: shard_process.join() + + +def test_runtime_slow_processing_readiness(port_generator): + class SlowProcessingExecutor(Executor): + @requests + def foo(self, **kwargs): + time.sleep(10) + + worker_port = port_generator() + # create a single worker runtime + worker_process = multiprocessing.Process( + target=_create_worker_runtime, args=(worker_port, f'pod0', 'SlowProcessingExecutor') + ) + try: + worker_process.start() + AsyncNewLoopRuntime.wait_for_ready_or_shutdown( + timeout=5.0, + ctrl_address=f'0.0.0.0:{worker_port}', + ready_or_shutdown_event=multiprocessing.Event(), + ) + + def _send_messages(): + with grpc.insecure_channel( + f'0.0.0.0:{worker_port}', + options=GrpcConnectionPool.get_default_grpc_options(), + ) as channel: + stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(channel) + resp, _ = stub.process_single_data.with_call( + list(request_generator('/', DocumentArray([Document(text='abc')])))[0] + ) + assert resp.docs[0].text == 'abc' + + send_message_process = multiprocessing.Process( + target=_send_messages + ) + send_message_process.start() + + for _ in range(50): + is_ready = WorkerRuntime.is_ready(f'0.0.0.0:{worker_port}') + assert is_ready + time.sleep(0.5) + except Exception: + raise + finally: + worker_process.terminate() + send_message_process.terminate() + worker_process.join() + send_message_process.join() + assert worker_process.exitcode == 0 + assert send_message_process.exitcode == 0