diff --git a/luigi/contrib/bigquery.py b/luigi/contrib/bigquery.py index 9da6d883ed..5d93efdaed 100644 --- a/luigi/contrib/bigquery.py +++ b/luigi/contrib/bigquery.py @@ -51,7 +51,7 @@ def is_error_5xx(err): wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(3), reraise=True, - after=lambda x: x.args[0].__initialise_client() + after=lambda x: x.args[0]._initialise_client() ) @@ -152,9 +152,9 @@ def __init__(self, oauth_credentials=None, descriptor='', http_=None): self.descriptor = descriptor self.http_ = http_ - self.__initialise_client() + self._initialise_client() - def __initialise_client(self): + def _initialise_client(self): authenticate_kwargs = gcp.get_authenticate_kwargs(self.oauth_credentials, self.http_) if self.descriptor: diff --git a/test/contrib/bigquery_test.py b/test/contrib/bigquery_test.py index 61ee2d7cda..25ad19fef9 100644 --- a/test/contrib/bigquery_test.py +++ b/test/contrib/bigquery_test.py @@ -23,9 +23,11 @@ import mock import pytest +from mock.mock import MagicMock +from luigi.contrib import bigquery from luigi.contrib.bigquery import BigQueryLoadTask, BigQueryTarget, BQDataset, \ - BigQueryRunQueryTask, BigQueryExtractTask + BigQueryRunQueryTask, BigQueryExtractTask, BigQueryClient from luigi.contrib.gcs import GCSTarget @@ -147,3 +149,31 @@ def output(self): } } run_job.assert_called_with('proj', expected_body, dataset=BQDataset('proj', 'ds', None)) + + +class BigQueryClientTest(unittest.TestCase): + + def test_retry_succeeds_on_second_attempt(self): + try: + from googleapiclient import errors + except ImportError: + raise unittest.SkipTest('Unable to load googleapiclient module') + client = MagicMock(spec=BigQueryClient) + attempts = 0 + + @bigquery.bq_retry + def fail_once(bq_client): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise errors.HttpError( + resp=MagicMock(status=500), + content=b'{"error": {"message": "stub"}', + ) + else: + return MagicMock(status=200) + + response = fail_once(client) + client._initialise_client.assert_called_once() + self.assertEqual(attempts, 2) + self.assertEqual(response.status, 200)