Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace naive retry with tenacity #3026

Merged
merged 13 commits into from
Feb 23, 2021
42 changes: 18 additions & 24 deletions luigi/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import json
import logging
import socket
import time
import base64

from urllib.parse import urljoin, urlencode, urlparse
from urllib.request import urlopen, Request
from urllib.error import URLError

from tenacity import Retrying, wait_fixed, stop_after_attempt
from luigi import configuration
from luigi.scheduler import RPC_METHODS

Expand Down Expand Up @@ -144,35 +144,29 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None):
else:
self._fetcher = URLLibFetcher()

def _wait(self):
if self._rpc_log_retries:
logger.info("Wait for %d seconds" % self._rpc_retry_wait)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to keep this log. Never know what users would do with it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added wait log and fixed test in 8f36797

time.sleep(self._rpc_retry_wait)
def _get_retryer(self):
def retry_logging(retry_state):
if self._rpc_log_retries:
logger.warning("Failed connecting to remote scheduler %r", self._url, exc_info=True)
logger.info("Retrying attempt %r of %r (max)" % (retry_state.attempt_number + 1, self._rpc_retry_attempts))
logger.info("Wait for %d seconds" % self._rpc_retry_wait)

return Retrying(wait=wait_fixed(self._rpc_retry_wait),
stop=stop_after_attempt(self._rpc_retry_attempts),
reraise=True,
after=retry_logging)

def _fetch(self, url_suffix, body):
full_url = _urljoin(self._url, url_suffix)
last_exception = None
attempt = 0
while attempt < self._rpc_retry_attempts:
attempt += 1
if last_exception:
if self._rpc_log_retries:
logger.info("Retrying attempt %r of %r (max)" % (attempt, self._rpc_retry_attempts))
self._wait() # wait for a bit and retry
try:
response = self._fetcher.fetch(full_url, body, self._connect_timeout)
break
except self._fetcher.raises as e:
last_exception = e
if self._rpc_log_retries:
logger.warning("Failed connecting to remote scheduler %r", self._url,
exc_info=True)
continue
else:
scheduler_retry = self._get_retryer()

try:
response = scheduler_retry(self._fetcher.fetch, full_url, body, self._connect_timeout)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for some reason i'm struggling to see where it's defined that full_uri, body, and self._connect_timeout are being passed to self._fetcher.fetch() here.

Copy link
Contributor Author

@hirosassa hirosassa Feb 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your comment!
Here is a definition of Retrying, first argument of this class' __call__ is function, and followings are its arguments.
https://github.com/jd/tenacity/blob/3e2244535ccfbb6a4b7fdd77bfc9aa61a1302302/tenacity/__init__.py#L422-L442

So, the code above calls self._fetcher.fetch(full_url, body, self._connect_timeout) with retrying feature.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, i see. A bit obfuscated, but it is being used as tenacity expects

except self._fetcher.raises as e:
raise RPCError(
"Errors (%d attempts) when connecting to remote scheduler %r" %
(self._rpc_retry_attempts, self._url),
last_exception
e
)
return response

Expand Down
14 changes: 4 additions & 10 deletions test/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import scheduler_api_test
import luigi.server
from server_test import ServerTestBase
import time
import socket
from multiprocessing import Process, Queue
import requests
Expand All @@ -41,15 +40,8 @@ def testUrlArgumentVariations(self):
fetcher.fetch.assert_called_once_with('http://zorg.com/api/123', '{}', 42)

def get_work(self, fetcher_side_effect):
class ShorterWaitRemoteScheduler(luigi.rpc.RemoteScheduler):
"""
A RemoteScheduler which waits shorter than usual before retrying (to speed up tests).
"""

def _wait(self):
time.sleep(1)

scheduler = ShorterWaitRemoteScheduler('http://zorg.com', 42)
scheduler = luigi.rpc.RemoteScheduler('http://zorg.com', 42)
scheduler._rpc_retry_wait = 1 # shorten wait time to speed up tests

with mock.patch.object(scheduler, '_fetcher') as fetcher:
fetcher.raises = socket.timeout, socket.gaierror
Expand Down Expand Up @@ -83,8 +75,10 @@ def test_log_rpc_retries_enabled(self, mock_logger):
self.assertEqual([
mock.call.warning('Failed connecting to remote scheduler %r', 'http://zorg.com', exc_info=True),
mock.call.info('Retrying attempt 2 of 3 (max)'),
mock.call.info('Wait for 1 seconds'),
mock.call.warning('Failed connecting to remote scheduler %r', 'http://zorg.com', exc_info=True),
mock.call.info('Retrying attempt 3 of 3 (max)'),
mock.call.info('Wait for 1 seconds'),
], mock_logger.mock_calls)

@with_config({'core': {'rpc-log-retries': 'false'}})
Expand Down
12 changes: 3 additions & 9 deletions test/worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,13 +1151,7 @@ def run(self, result=None):
@email_patch
def test_connection_error(self, emails):
sch = RemoteScheduler('http://tld.invalid:1337', connect_timeout=1)

self.waits = 0

def dummy_wait():
self.waits += 1

sch._wait = dummy_wait
sch._rpc_retry_wait = 1 # shorten wait time to speed up tests

class A(DummyTask):
pass
Expand All @@ -1167,8 +1161,8 @@ class A(DummyTask):
with Worker(scheduler=sch) as worker:
try:
worker.add(a)
except RPCError:
self.assertEqual(self.waits, 2) # should attempt to add it 3 times
except RPCError as e:
self.assertTrue(str(e).find("Errors (3 attempts)") != -1)
self.assertNotEqual(emails, [])
self.assertTrue(emails[0].find("Luigi: Framework error while scheduling %s" % (a,)) != -1)
else:
Expand Down