Skip to content

Commit

Permalink
CDK: retry token refresh requests (#23815)
Browse files Browse the repository at this point in the history
* #23767 CDK: retry token refresh requests

* Automated Commit - Formatting Changes

---------

Co-authored-by: davydov-d <davydov-d@users.noreply.github.com>
  • Loading branch information
davydov-d and davydov-d authored Mar 7, 2023
1 parent b6b4203 commit f9f1402
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 17 deletions.
17 changes: 17 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/http/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
#


import logging
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple

import backoff
import pendulum
import requests
from deprecated import deprecated

from ..exceptions import DefaultBackoffException
from .core import HttpAuthenticator

logger = logging.getLogger("airbyte")


@deprecated(version="0.1.20", reason="Use airbyte_cdk.sources.streams.http.requests_native_auth.Oauth2Authenticator instead")
class Oauth2Authenticator(HttpAuthenticator):
Expand Down Expand Up @@ -69,6 +74,14 @@ def get_refresh_request_body(self) -> Mapping[str, Any]:

return payload

@backoff.on_exception(
backoff.expo,
DefaultBackoffException,
on_backoff=lambda details: logger.info(
f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
),
max_time=300,
)
def refresh_access_token(self) -> Tuple[str, int]:
"""
returns a tuple of (access_token, token_lifespan_in_seconds)
Expand All @@ -83,6 +96,10 @@ def refresh_access_token(self) -> Tuple[str, int]:
response.raise_for_status()
response_json = response.json()
return response_json["access_token"], response_json["expires_in"]
except requests.exceptions.RequestException as e:
if e.response.status_code == 429 or e.response.status_code >= 500:
raise DefaultBackoffException(request=e.response.request, response=e.response)
raise
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,19 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from abc import abstractmethod
from typing import Any, List, Mapping, MutableMapping, Tuple, Union

import backoff
import pendulum
import requests
from requests.auth import AuthBase

from ..exceptions import DefaultBackoffException

logger = logging.getLogger("airbyte")


class AbstractOauth2Authenticator(AuthBase):
"""
Expand Down Expand Up @@ -64,22 +70,34 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

@backoff.on_exception(
backoff.expo,
DefaultBackoffException,
on_backoff=lambda details: logger.info(
f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
),
max_time=300,
)
def _get_refresh_access_token_response(self):
response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body())
response.raise_for_status()
return response.json()
try:
response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body())
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if e.response.status_code == 429 or e.response.status_code >= 500:
raise DefaultBackoffException(request=e.response.request, response=e.response)
raise
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e

def refresh_access_token(self) -> Tuple[str, int]:
"""
Returns the refresh token and its lifespan in seconds
:return: a tuple of (access_token, token_lifespan_in_seconds)
"""
try:
response_json = self._get_refresh_access_token_response()
return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
response_json = self._get_refresh_access_token_response()
return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]

@abstractmethod
def get_token_refresh_endpoint(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,9 @@ def get_access_token(self) -> str:
return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
try:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)
except Exception as e:
raise Exception(f"Error while refreshing access token and refresh token: {e}") from e
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging

import pytest
from airbyte_cdk.sources.streams.http.auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Expand Down Expand Up @@ -143,6 +144,24 @@ def test_refresh_access_token(self, requests_mock):
assert self.refresh_access_token_headers[header] == mock_refresh_token_call.last_request.headers[header]
assert mock_refresh_token_call.called

@pytest.mark.parametrize("error_code", (429, 500, 502, 504))
def test_refresh_access_token_retry(self, error_code, requests_mock):
oauth = Oauth2Authenticator(
TestOauth2Authenticator.refresh_endpoint,
TestOauth2Authenticator.client_id,
TestOauth2Authenticator.client_secret,
TestOauth2Authenticator.refresh_token
)
requests_mock.post(
TestOauth2Authenticator.refresh_endpoint,
[
{"status_code": error_code}, {"status_code": error_code}, {"json": {"access_token": "token", "expires_in": 10}}
]
)
token, expires_in = oauth.refresh_access_token()
assert (token, expires_in) == ("token", 10)
assert requests_mock.call_count == 3

def test_refresh_access_authenticator(self):
oauth = Oauth2Authenticator(
TestOauth2Authenticator.refresh_endpoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,24 @@ def test_refresh_access_token(self, mocker):

assert ("access_token", 1000) == token

@pytest.mark.parametrize("error_code", (429, 500, 502, 504))
def test_refresh_access_token_retry(self, error_code, requests_mock):
oauth = Oauth2Authenticator(
f"https://{TestOauth2Authenticator.refresh_endpoint}",
TestOauth2Authenticator.client_id,
TestOauth2Authenticator.client_secret,
TestOauth2Authenticator.refresh_token
)
requests_mock.post(
f"https://{TestOauth2Authenticator.refresh_endpoint}",
[
{"status_code": error_code}, {"status_code": error_code}, {"json": {"access_token": "token", "expires_in": 10}}
]
)
token, expires_in = oauth.refresh_access_token()
assert (token, expires_in) == ("token", 10)
assert requests_mock.call_count == 3

def test_auth_call_method(self, mocker):
oauth = Oauth2Authenticator(
token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint,
Expand Down

0 comments on commit f9f1402

Please sign in to comment.