Skip to content

Commit

Permalink
fix: add retry logic in case of google auth refresh credential error
Browse files Browse the repository at this point in the history
  • Loading branch information
dondaum committed Apr 15, 2024
1 parent 6520653 commit 0f506dd
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 3 deletions.
1 change: 1 addition & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ def cancel_job(
time.sleep(5)

@GoogleBaseHook.fallback_to_default_project_id
@GoogleBaseHook.refresh_credentials_retry()
def get_job(
self,
job_id: str,
Expand Down
40 changes: 39 additions & 1 deletion airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ def is_operation_in_progress_exception(exception: Exception) -> bool:
return False


def is_refresh_credentials_exception(exception: Exception) -> bool:
"""
Handle refresh credentials exceptions.
Some calls return 502 (server error) in case a new token cannot be obtained.
* Google BigQuery
"""
if isinstance(exception, RefreshError):
return "Unable to acquire impersonated credentials" in str(exception)
return False


class retry_if_temporary_quota(tenacity.retry_if_exception):
"""Retries if there was an exception for exceeding the temporary quote limit."""

Expand All @@ -122,12 +135,19 @@ def __init__(self):


class retry_if_operation_in_progress(tenacity.retry_if_exception):
"""Retries if there was an exception for exceeding the temporary quote limit."""
"""Retries if there was an exception in case of operation in progress."""

def __init__(self):
super().__init__(is_operation_in_progress_exception)


class retry_if_temporary_refresh_credentials(tenacity.retry_if_exception):
"""Retries if there was an exception for refreshing credentials."""

def __init__(self):
super().__init__(is_refresh_credentials_exception)


# A fake project_id to use in functions decorated by fallback_to_default_project_id
# This allows the 'project_id' argument to be of type str instead of str | None,
# making it easier to type hint the function body without dealing with the None
Expand Down Expand Up @@ -454,6 +474,24 @@ def decorator(fun: T):

return decorator

@staticmethod
def refresh_credentials_retry(*args, **kwargs) -> Callable[[T], T]:
"""Provide a mechanism to repeat requests in response to a temporary refresh credential issue."""

def decorator(fun: T):
default_kwargs = {
"wait": tenacity.wait_exponential(multiplier=1, max=5),
"stop": tenacity.stop_after_attempt(3),
"retry": retry_if_temporary_refresh_credentials(),
"reraise": True,
"before": tenacity.before_log(log, logging.DEBUG),
"after": tenacity.after_log(log, logging.DEBUG),
}
default_kwargs.update(**kwargs)
return cast(T, tenacity.retry(*args, **default_kwargs)(fun))

return decorator

@staticmethod
def fallback_to_default_project_id(func: Callable[..., RT]) -> Callable[..., RT]:
"""
Expand Down
32 changes: 32 additions & 0 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
from gcloud.aio.bigquery import Job, Table as Table_async
from google.api_core import page_iterator
from google.auth.exceptions import RefreshError
from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
from google.cloud.bigquery.table import _EmptyRowIterator
Expand Down Expand Up @@ -598,6 +599,37 @@ def test_poll_job_complete(self, mock_client):
mock_client.return_value.get_job.assert_called_once_with(job_id=JOB_ID)
mock_client.return_value.get_job.return_value.done.assert_called_once_with(retry=DEFAULT_RETRY)

@mock.patch("tenacity.nap.time.sleep", mock.MagicMock())
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
def test_get_job_credentials_refresh_error(self, mock_client):
error = "Unable to acquire impersonated credentials"
response_body = "<!DOCTYPE html>\n<html lang=en>\n <meta charset=utf-8>\n"
mock_job = mock.MagicMock(
job_id="123456_hash",
error_result=False,
state="PENDING",
done=lambda: False,
)
mock_client.return_value.get_job.side_effect = [RefreshError(error, response_body), mock_job]

job = self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID)
mock_client.assert_any_call(location=LOCATION, project_id=PROJECT_ID)
assert mock_client.call_count == 2
assert job == mock_job

@pytest.mark.parametrize(
"error",
[
RefreshError("Other error", "test body"),
ValueError(),
],
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client")
def test_get_job_credentials_error(self, mock_client, error):
mock_client.return_value.get_job.side_effect = error
with pytest.raises(type(error)):
self.hook.get_job(job_id=JOB_ID, location=LOCATION, project_id=PROJECT_ID)

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.poll_job_complete")
@mock.patch("logging.Logger.info")
def test_cancel_query_jobs_to_cancel(
Expand Down
44 changes: 42 additions & 2 deletions tests/providers/google/common/hooks/test_base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@
import pytest
import tenacity
from google.auth.environment_vars import CREDENTIALS
from google.auth.exceptions import GoogleAuthError
from google.auth.exceptions import GoogleAuthError, RefreshError
from google.cloud.exceptions import Forbidden

from airflow import version
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.utils.credentials_provider import _DEFAULT_SCOPES
from airflow.providers.google.common.hooks import base_google as hook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, is_refresh_credentials_exception
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id

default_creds_available = True
Expand Down Expand Up @@ -98,6 +98,46 @@ def test_raise_exception_on_non_quota_exception(self):
)


class TestRefreshCredentialsRetry:
@pytest.mark.parametrize(
"exc, retryable",
[
(RefreshError("Other error", "test body"), False),
(RefreshError("Unable to acquire impersonated credentials", "test body"), True),
(ValueError(), False),
],
)
def test_is_refresh_credentials_exception(self, exc, retryable):
assert is_refresh_credentials_exception(exc) is retryable

def test_do_nothing_on_non_error(self):
@hook.GoogleBaseHook.refresh_credentials_retry()
def func():
return 42

assert func(), 42

def test_raise_non_refresh_error(self):
@hook.GoogleBaseHook.refresh_credentials_retry()
def func():
raise ValueError()

with pytest.raises(ValueError):
func()

@mock.patch("tenacity.nap.time.sleep", mock.MagicMock())
def test_retry_on_refresh_error(self):
func_return = mock.Mock(
side_effect=[RefreshError("Unable to acquire impersonated credentials", "test body"), 42]
)

@hook.GoogleBaseHook.refresh_credentials_retry()
def func():
return func_return()

assert func() == 42


class FallbackToDefaultProjectIdFixtureClass:
def __init__(self, project_id):
self.mock = mock.Mock()
Expand Down

0 comments on commit 0f506dd

Please sign in to comment.