Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDK: retry token refresh requests #23815

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
#


import logging
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple

import backoff
import pendulum
import requests
from deprecated import deprecated

from ..exceptions import DefaultBackoffException
from .core import HttpAuthenticator

logger = logging.getLogger("airbyte")


@deprecated(version="0.1.20", reason="Use airbyte_cdk.sources.streams.http.requests_native_auth.Oauth2Authenticator instead")
class Oauth2Authenticator(HttpAuthenticator):
Expand Down Expand Up @@ -69,6 +74,14 @@ def get_refresh_request_body(self) -> Mapping[str, Any]:

return payload

@backoff.on_exception(
backoff.expo,
DefaultBackoffException,
on_backoff=lambda details: logger.info(
f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
),
max_time=300,
)
def refresh_access_token(self) -> Tuple[str, int]:
"""
returns a tuple of (access_token, token_lifespan_in_seconds)
Expand All @@ -83,6 +96,10 @@ def refresh_access_token(self) -> Tuple[str, int]:
response.raise_for_status()
response_json = response.json()
return response_json["access_token"], response_json["expires_in"]
except requests.exceptions.RequestException as e:
if e.response.status_code == 429 or e.response.status_code >= 500:
raise DefaultBackoffException(request=e.response.request, response=e.response)
raise
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from typing import Any, List, Mapping, MutableMapping, Tuple, Union

import backoff
import pendulum
import requests
from requests.auth import AuthBase

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")


class AbstractOauth2Authenticator(AuthBase):
"""
Expand Down Expand Up @@ -64,22 +70,34 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

@backoff.on_exception(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity. Is it possible to disable or configure this behavior if needed? Not urgent/blocking.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, disabling may be a bit more challenging than configuring it, but I still think both are possible. Anyway, I think it's worth doing in a dedicated PR if we do need these options.

backoff.expo,
DefaultBackoffException,
on_backoff=lambda details: logger.info(
f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
),
max_time=300,
)
def _get_refresh_access_token_response(self):
    """
    Sends the token refresh POST request and returns the decoded JSON body of the response

    :return: the parsed JSON payload returned by the token refresh endpoint
    :raises DefaultBackoffException: when the endpoint answers with 429 or a 5xx status, which marks the failure as retryable
    :raises requests.exceptions.RequestException: for any other request failure, re-raised unchanged
    :raises Exception: for any non-request error, wrapped with an explanatory message
    """
    try:
        response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body())
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        # Failures that never produced an HTTP response (connection errors, timeouts) carry response=None.
        # Check for that before reading the status code so the original error propagates instead of an AttributeError.
        if e.response is not None and (e.response.status_code == 429 or e.response.status_code >= 500):
            raise DefaultBackoffException(request=e.response.request, response=e.response)
        raise
    except Exception as e:
        raise Exception(f"Error while refreshing access token: {e}") from e

def refresh_access_token(self) -> Tuple[str, int]:
"""
Returns the access token and its lifespan in seconds

:return: a tuple of (access_token, token_lifespan_in_seconds)
"""
try:
response_json = self._get_refresh_access_token_response()
return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
response_json = self._get_refresh_access_token_response()
return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]

@abstractmethod
def get_token_refresh_endpoint(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,9 @@ def get_access_token(self) -> str:
return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
"""
Fetches the token refresh response and extracts the new credentials from it

:return: a tuple of the response values stored under the access token, expires-in and refresh token names, in that order
"""
try:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)
except Exception as e:
raise Exception(f"Error while refreshing access token and refresh token: {e}") from e
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging

import pytest
from airbyte_cdk.sources.streams.http.auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Expand Down Expand Up @@ -143,6 +144,24 @@ def test_refresh_access_token(self, requests_mock):
assert self.refresh_access_token_headers[header] == mock_refresh_token_call.last_request.headers[header]
assert mock_refresh_token_call.called

@pytest.mark.parametrize("error_code", (429, 500, 502, 504))
def test_refresh_access_token_retry(self, error_code, requests_mock, monkeypatch):
    """
    Two consecutive retryable error responses followed by a success must yield the token after exactly three calls
    """
    # The retry decorator sleeps between attempts; stub time.sleep so the test does not wait for real.
    monkeypatch.setattr("time.sleep", lambda seconds: None)
    oauth = Oauth2Authenticator(
        TestOauth2Authenticator.refresh_endpoint,
        TestOauth2Authenticator.client_id,
        TestOauth2Authenticator.client_secret,
        TestOauth2Authenticator.refresh_token,
    )
    requests_mock.post(
        TestOauth2Authenticator.refresh_endpoint,
        [
            {"status_code": error_code},
            {"status_code": error_code},
            {"json": {"access_token": "token", "expires_in": 10}},
        ],
    )
    token, expires_in = oauth.refresh_access_token()
    assert (token, expires_in) == ("token", 10)
    assert requests_mock.call_count == 3

def test_refresh_access_authenticator(self):
oauth = Oauth2Authenticator(
TestOauth2Authenticator.refresh_endpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ def test_refresh_access_token(self, mocker):

assert ("access_token", 1000) == token

@pytest.mark.parametrize("error_code", (429, 500, 502, 504))
def test_refresh_access_token_retry(self, error_code, requests_mock, monkeypatch):
    """
    Two consecutive retryable error responses followed by a success must yield the token after exactly three calls
    """
    # The retry decorator sleeps between attempts; stub time.sleep so the test does not wait for real.
    monkeypatch.setattr("time.sleep", lambda seconds: None)
    refresh_url = f"https://{TestOauth2Authenticator.refresh_endpoint}"
    oauth = Oauth2Authenticator(
        refresh_url,
        TestOauth2Authenticator.client_id,
        TestOauth2Authenticator.client_secret,
        TestOauth2Authenticator.refresh_token,
    )
    requests_mock.post(
        refresh_url,
        [
            {"status_code": error_code},
            {"status_code": error_code},
            {"json": {"access_token": "token", "expires_in": 10}},
        ],
    )
    token, expires_in = oauth.refresh_access_token()
    assert (token, expires_in) == ("token", 10)
    assert requests_mock.call_count == 3

def test_auth_call_method(self, mocker):
oauth = Oauth2Authenticator(
token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint,
Expand Down