From 59084fd1f4c200986433f9ff60b28cd6f8a0bcc1 Mon Sep 17 00:00:00 2001 From: Sebastian Daum Date: Sun, 21 Apr 2024 08:19:24 +0200 Subject: [PATCH] fix: add retry logic in case of google auth refresh credential error (#38961) --- .../providers/google/cloud/hooks/bigquery.py | 1 + .../google/common/hooks/base_google.py | 48 +++++++++++++++++-- .../google/cloud/hooks/test_bigquery.py | 32 +++++++++++++ .../google/common/hooks/test_base_google.py | 44 ++++++++++++++++- 4 files changed, 118 insertions(+), 7 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 7482be89fb4193..f8c0fd1a5d0d49 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -1580,6 +1580,7 @@ def cancel_job( time.sleep(5) @GoogleBaseHook.fallback_to_default_project_id + @GoogleBaseHook.refresh_credentials_retry() def get_job( self, job_id: str, diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 0fe5d16aae1708..5800f8e44ccb6e 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -114,6 +114,19 @@ def is_operation_in_progress_exception(exception: Exception) -> bool: return False +def is_refresh_credentials_exception(exception: Exception) -> bool: + """ + Handle refresh credentials exceptions. + + Some calls return 502 (server error) in case a new token cannot be obtained. + + * Google BigQuery + """ + if isinstance(exception, RefreshError): + return "Unable to acquire impersonated credentials" in str(exception) + return False + + class retry_if_temporary_quota(tenacity.retry_if_exception): """Retries if there was an exception for exceeding the temporary quote limit.""" @@ -122,12 +135,19 @@ def __init__(self): class retry_if_operation_in_progress(tenacity.retry_if_exception): - """Retries if there was an exception for exceeding the temporary quote limit.""" + """Retries if there was an exception in case of operation in progress.""" def __init__(self): super().__init__(is_operation_in_progress_exception) +class retry_if_temporary_refresh_credentials(tenacity.retry_if_exception): + """Retries if there was an exception for refreshing credentials.""" + + def __init__(self): + super().__init__(is_refresh_credentials_exception) + + # A fake project_id to use in functions decorated by fallback_to_default_project_id # This allows the 'project_id' argument to be of type str instead of str | None, # making it easier to type hint the function body without dealing with the None @@ -426,7 +446,7 @@ def scopes(self) -> Sequence[str]: def quota_retry(*args, **kwargs) -> Callable: """Provide a mechanism to repeat requests in response to exceeding a temporary quota limit.""" - def decorator(fun: Callable): + def decorator(func: Callable): default_kwargs = { "wait": tenacity.wait_exponential(multiplier=1, max=100), "retry": retry_if_temporary_quota(), @@ -434,7 +454,7 @@ def decorator(fun: Callable): "after": tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) - return tenacity.retry(*args, **default_kwargs)(fun) + return tenacity.retry(*args, **default_kwargs)(func) return decorator @@ -442,7 +462,7 @@ def decorator(fun: Callable): def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]: """Provide a mechanism to repeat requests in response to operation in progress (HTTP 409) limit.""" - def decorator(fun: T): + def decorator(func: T): default_kwargs = { "wait": tenacity.wait_exponential(multiplier=1, max=300), "retry": retry_if_operation_in_progress(), @@ -450,7 +470,25 @@ def decorator(fun: T): "after": tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) - return cast(T, tenacity.retry(*args, **default_kwargs)(fun)) + return cast(T, tenacity.retry(*args, **default_kwargs)(func)) + + return decorator + + @staticmethod + def refresh_credentials_retry(*args, **kwargs) -> Callable[[T], T]: + """Provide a mechanism to repeat requests in response to a temporary refresh credential issue.""" + + def decorator(func: T): + default_kwargs = { + "wait": tenacity.wait_exponential(multiplier=1, max=5), + "stop": tenacity.stop_after_attempt(3), + "retry": retry_if_temporary_refresh_credentials(), + "reraise": True, + "before": tenacity.before_log(log, logging.DEBUG), + "after": tenacity.after_log(log, logging.DEBUG), + } + default_kwargs.update(**kwargs) + return cast(T, tenacity.retry(*args, **default_kwargs)(func)) return decorator diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 37096b0ff3bde9..c63dc581a99fd5 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -26,6 +26,7 @@ import pytest from gcloud.aio.bigquery import Job, Table as Table_async from google.api_core import page_iterator +from google.auth.exceptions import RefreshError from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem from google.cloud.bigquery.table import _EmptyRowIterator @@ -598,6 +599,37 @@ def test_poll_job_complete(self, mock_client): mock_client.return_value.get_job.assert_called_once_with(job_id=JOB_ID) mock_client.return_value.get_job.return_value.done.assert_called_once_with(retry=DEFAULT_RETRY) + @mock.patch("tenacity.nap.time.sleep", mock.MagicMock()) + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_job_credentials_refresh_error(self, mock_client): + error = "Unable to acquire impersonated credentials" + response_body = "\n\n \n" + mock_job = mock.MagicMock( + job_id="123456_hash", + error_result=False, + state="PENDING", + done=lambda: False, + ) + mock_client.return_value.get_job.side_effect = [RefreshError(error, response_body), mock_job] + + job = self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID) + mock_client.assert_any_call(location=LOCATION, project_id=PROJECT_ID) + assert mock_client.call_count == 2 + assert job == mock_job + + @pytest.mark.parametrize( + "error", + [ + RefreshError("Other error", "test body"), + ValueError(), + ], + ) + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_job_credentials_error(self, mock_client, error): + mock_client.return_value.get_job.side_effect = error + with pytest.raises(type(error)): + self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID) + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete") @mock.patch("logging.Logger.info") def test_cancel_query_jobs_to_cancel( diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index ab2e26f59bda76..1a91f0742d5106 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -30,14 +30,14 @@ import pytest import tenacity from google.auth.environment_vars import CREDENTIALS -from google.auth.exceptions import GoogleAuthError +from google.auth.exceptions import GoogleAuthError, RefreshError from google.cloud.exceptions import Forbidden from airflow import version from airflow.exceptions import AirflowException from airflow.providers.google.cloud.utils.credentials_provider import _DEFAULT_SCOPES from airflow.providers.google.common.hooks import base_google as hook -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, is_refresh_credentials_exception from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id default_creds_available = True @@ -98,6 +98,46 @@ def test_raise_exception_on_non_quota_exception(self): ) +class TestRefreshCredentialsRetry: + @pytest.mark.parametrize( + "exc, retryable", + [ + (RefreshError("Other error", "test body"), False), + (RefreshError("Unable to acquire impersonated credentials", "test body"), True), + (ValueError(), False), + ], + ) + def test_is_refresh_credentials_exception(self, exc, retryable): + assert is_refresh_credentials_exception(exc) is retryable + + def test_do_nothing_on_non_error(self): + @hook.GoogleBaseHook.refresh_credentials_retry() + def func(): + return 42 + + assert func() == 42 + + def test_raise_non_refresh_error(self): + @hook.GoogleBaseHook.refresh_credentials_retry() + def func(): + raise ValueError() + + with pytest.raises(ValueError): + func() + + @mock.patch("tenacity.nap.time.sleep", mock.MagicMock()) + def test_retry_on_refresh_error(self): + func_return = mock.Mock( + side_effect=[RefreshError("Unable to acquire impersonated credentials", "test body"), 42] + ) + + @hook.GoogleBaseHook.refresh_credentials_retry() + def func(): + return func_return() + + assert func() == 42 + + class FallbackToDefaultProjectIdFixtureClass: def __init__(self, project_id): self.mock = mock.Mock()