diff --git a/python/seldon_core/app.py b/python/seldon_core/app.py index dce598f830..6be784786e 100644 --- a/python/seldon_core/app.py +++ b/python/seldon_core/app.py @@ -1,12 +1,20 @@ import os import logging +import atexit +from multiprocessing.util import _exit_function from typing import Dict, Union from gunicorn.app.base import BaseApplication logger = logging.getLogger(__name__) +def post_worker_init(worker): + # Remove the atexit handler set up by the parent process + # https://github.com/benoitc/gunicorn/issues/1391#issuecomment-467010209 + atexit.unregister(_exit_function) + + def accesslog(log_level: str) -> Union[str, None]: """ Enable / disable access log in Gunicorn depending on the log level. diff --git a/python/seldon_core/microservice.py b/python/seldon_core/microservice.py index 5696f03cd1..4149b95c44 100644 --- a/python/seldon_core/microservice.py +++ b/python/seldon_core/microservice.py @@ -20,6 +20,7 @@ UserModelApplication, accesslog, threads, + post_worker_init, ) logger = logging.getLogger(__name__) @@ -53,18 +54,23 @@ def start_servers( Auxilary flask process """ - p2 = mp.Process(target=target2) - p2.daemon = True - p2.start() + p2 = None + if target2: + p2 = mp.Process(target=target2, daemon=True) + p2.start() - p3 = mp.Process(target=metrics_target) - p3.daemon = True - p3.start() + p3 = None + if metrics_target: + p3 = mp.Process(target=metrics_target, daemon=True) + p3.start() target1() - p2.join() - p3.join() + if p2: + p2.join() + + if p3: + p3.join() def parse_parameters(parameters: Dict) -> Dict: @@ -391,6 +397,7 @@ def rest_prediction_server(): "workers": args.workers, "max_requests": args.max_requests, "max_requests_jitter": args.max_requests_jitter, + "post_worker_init": post_worker_init, } app = seldon_microservice.get_rest_microservice( user_object, seldon_metrics @@ -452,6 +459,7 @@ def rest_metrics_server(): "timeout": 5000, "max_requests": args.max_requests, "max_requests_jitter": args.max_requests_jitter, + "post_worker_init": post_worker_init, } StandaloneApplication(app, options=options).run() @@ -465,7 +473,7 @@ def rest_metrics_server(): else: server2_func = None - logger.info("Starting servers") + logger.info("Starting servers") start_servers(server1_func, server2_func, metrics_server_func)