Skip to content

Commit

Permalink
Close requests.Socket in RemoteScheduler before exiting (spotify#3173)
Browse files Browse the repository at this point in the history
  • Loading branch information
starhel committed Jun 10, 2022
1 parent 78145d6 commit 285d257
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
2 changes: 2 additions & 0 deletions luigi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=No
success &= worker.run()
luigi_run_result = LuigiRunResult(worker, success)
logger.info(luigi_run_result.summary_text)
if hasattr(sch, 'close'):
sch.close()
return luigi_run_result


Expand Down
32 changes: 27 additions & 5 deletions luigi/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
rpc.py implements the client side of it, server.py implements the server side.
See :doc:`/central_scheduler` for more info.
"""
import abc
import collections
import os
import json
import logging
import signal
import socket
import base64

Expand Down Expand Up @@ -69,7 +72,17 @@ def __init__(self, message, sub_exception=None):
self.sub_exception = sub_exception


class URLLibFetcher:
class _FetcherInterface(metaclass=abc.ABCMeta):
@abc.abstractmethod
def fetch(self, full_url, body, timeout):
pass

@abc.abstractmethod
def close(self):
pass


class URLLibFetcher(_FetcherInterface):
raises = (URLError, socket.timeout)

def _create_request(self, full_url, body=None):
Expand All @@ -96,12 +109,15 @@ def fetch(self, full_url, body, timeout):
req = self._create_request(full_url, body=body)
return urlopen(req, timeout=timeout).read().decode('utf-8')

def close(self):
pass

class RequestsFetcher:
def __init__(self, session):

class RequestsFetcher(_FetcherInterface):
def __init__(self):
from requests import exceptions as requests_exceptions
self.raises = requests_exceptions.RequestException
self.session = session
self.session = requests.Session()
self.process_id = os.getpid()

def check_pid(self):
Expand All @@ -117,6 +133,9 @@ def fetch(self, full_url, body, timeout):
resp.raise_for_status()
return resp.text

def close(self):
self.session.close()


class RemoteScheduler:
"""
Expand All @@ -140,10 +159,13 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None):
self._rpc_log_retries = config.getboolean('core', 'rpc-log-retries', True)

if HAS_REQUESTS:
self._fetcher = RequestsFetcher(requests.Session())
self._fetcher = RequestsFetcher()
else:
self._fetcher = URLLibFetcher()

def close(self):
self._fetcher.close()

def _get_retryer(self):
def retry_logging(retry_state):
if self._rpc_log_retries:
Expand Down

0 comments on commit 285d257

Please sign in to comment.