Skip to content

Commit

Permalink
🐛 fix declarative oauth initialization (#32967)
Browse files Browse the repository at this point in the history
Co-authored-by: girarda <girarda@users.noreply.github.com>
  • Loading branch information
girarda and girarda authored Jan 9, 2024
1 parent 1737ab1 commit c8ca4b1
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 59 deletions.
72 changes: 43 additions & 29 deletions airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,74 +46,88 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
refresh_token: Optional[Union[InterpolatedString, str]] = None
scopes: Optional[List[str]] = None
token_expiry_date: Optional[Union[InterpolatedString, str]] = None
_token_expiry_date: pendulum.DateTime = field(init=False, repr=False, default=None)
token_expiry_date_format: str = None
_token_expiry_date: Optional[pendulum.DateTime] = field(init=False, repr=False, default=None)
token_expiry_date_format: Optional[str] = None
token_expiry_is_time_of_expiration: bool = False
access_token_name: Union[InterpolatedString, str] = "access_token"
expires_in_name: Union[InterpolatedString, str] = "expires_in"
refresh_request_body: Optional[Mapping[str, Any]] = None
grant_type: Union[InterpolatedString, str] = "refresh_token"
message_repository: MessageRepository = NoopMessageRepository()

def __post_init__(self, parameters: Mapping[str, Any]):
self.token_refresh_endpoint = InterpolatedString.create(self.token_refresh_endpoint, parameters=parameters)
self.client_id = InterpolatedString.create(self.client_id, parameters=parameters)
self.client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__init__()
self._token_refresh_endpoint = InterpolatedString.create(self.token_refresh_endpoint, parameters=parameters)
self._client_id = InterpolatedString.create(self.client_id, parameters=parameters)
self._client_secret = InterpolatedString.create(self.client_secret, parameters=parameters)
if self.refresh_token is not None:
self.refresh_token = InterpolatedString.create(self.refresh_token, parameters=parameters)
self._refresh_token = InterpolatedString.create(self.refresh_token, parameters=parameters)
else:
self._refresh_token = None
self.access_token_name = InterpolatedString.create(self.access_token_name, parameters=parameters)
self.expires_in_name = InterpolatedString.create(self.expires_in_name, parameters=parameters)
self.grant_type = InterpolatedString.create(self.grant_type, parameters=parameters)
self._refresh_request_body = InterpolatedMapping(self.refresh_request_body or {}, parameters=parameters)
self._token_expiry_date = (
pendulum.parse(InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(self.config))
self._token_expiry_date: pendulum.DateTime = (
pendulum.parse(InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(self.config)) # type: ignore # pendulum.parse returns a datetime in this context
if self.token_expiry_date
else pendulum.now().subtract(days=1)
else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints
)
self._access_token = None
self._access_token: Optional[str] = None # access_token is initialized by a setter

if self.get_grant_type() == "refresh_token" and self.refresh_token is None:
if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
raise ValueError("OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`")

def get_token_refresh_endpoint(self) -> str:
return self.token_refresh_endpoint.eval(self.config)
refresh_token: str = self._token_refresh_endpoint.eval(self.config)
if not refresh_token:
raise ValueError("OAuthAuthenticator was unable to evaluate token_refresh_endpoint parameter")
return refresh_token

def get_client_id(self) -> str:
return self.client_id.eval(self.config)
client_id: str = self._client_id.eval(self.config)
if not client_id:
raise ValueError("OAuthAuthenticator was unable to evaluate client_id parameter")
return client_id

def get_client_secret(self) -> str:
return self.client_secret.eval(self.config)
client_secret: str = self._client_secret.eval(self.config)
if not client_secret:
raise ValueError("OAuthAuthenticator was unable to evaluate client_secret parameter")
return client_secret

def get_refresh_token(self) -> Optional[str]:
return None if self.refresh_token is None else self.refresh_token.eval(self.config)
return None if self._refresh_token is None else self._refresh_token.eval(self.config)

def get_scopes(self) -> [str]:
return self.scopes
def get_scopes(self) -> List[str]:
return self.scopes or []

def get_access_token_name(self) -> InterpolatedString:
return self.access_token_name.eval(self.config)
def get_access_token_name(self) -> str:
return self.access_token_name.eval(self.config) # type: ignore # eval returns a string in this context

def get_expires_in_name(self) -> InterpolatedString:
return self.expires_in_name.eval(self.config)
def get_expires_in_name(self) -> str:
return self.expires_in_name.eval(self.config) # type: ignore # eval returns a string in this context

def get_grant_type(self) -> InterpolatedString:
return self.grant_type.eval(self.config)
def get_grant_type(self) -> str:
return self.grant_type.eval(self.config) # type: ignore # eval returns a string in this context

def get_refresh_request_body(self) -> Mapping[str, Any]:
return self._refresh_request_body.eval(self.config)
return self._refresh_request_body.eval(self.config) # type: ignore # eval should return a Mapping in this context

def get_token_expiry_date(self) -> pendulum.DateTime:
return self._token_expiry_date
return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks

def set_token_expiry_date(self, value: Union[str, int]):
def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)

@property
def access_token(self) -> str:
if self._access_token is None:
raise ValueError("access_token is not set")
return self._access_token

@access_token.setter
def access_token(self, value: str):
def access_token(self, value: str) -> None:
self._access_token = value

@property
Expand All @@ -130,5 +144,5 @@ class DeclarativeSingleUseRefreshTokenOauth2Authenticator(SingleUseRefreshTokenO
Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
self._refresh_token_error_key = refresh_token_error_key
self._refresh_token_error_values = refresh_token_error_values

def __call__(self, request: requests.Request) -> requests.Request:
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Attach the HTTP headers required to authenticate on the HTTP request"""
request.headers.update(self.get_auth_header())
return request
Expand All @@ -65,7 +65,7 @@ def get_access_token(self) -> str:

def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return pendulum.now() > self.get_token_expiry_date()
return pendulum.now() > self.get_token_expiry_date() # type: ignore # this is always a bool despite what mypy thinks

def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Expand All @@ -80,7 +80,7 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:
"refresh_token": self.get_refresh_token(),
}

if self.get_scopes:
if self.get_scopes():
payload["scopes"] = self.get_scopes()

if self.get_refresh_request_body():
Expand All @@ -93,7 +93,10 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestException) -> bool:
try:
exception_content = exception.response.json()
if exception.response is not None:
exception_content = exception.response.json()
else:
return False
except JSONDecodeError:
return False
return (
Expand All @@ -109,15 +112,16 @@ def _wrap_refresh_token_exception(self, exception: requests.exceptions.RequestEx
),
max_time=300,
)
def _get_refresh_access_token_response(self):
def _get_refresh_access_token_response(self) -> Any:
try:
response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body())
self._log_response(response)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if e.response.status_code == 429 or e.response.status_code >= 500:
raise DefaultBackoffException(request=e.response.request, response=e.response)
if e.response is not None:
if e.response.status_code == 429 or e.response.status_code >= 500:
raise DefaultBackoffException(request=e.response.request, response=e.response)
if self._wrap_refresh_token_exception(e):
message = "Refresh token is invalid or expired. Please re-authenticate from Sources/<your source>/Settings."
raise AirbyteTracedException(internal_message=message, message=message, failure_type=FailureType.config_error)
Expand Down Expand Up @@ -147,7 +151,7 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> pendulum.DateT
raise ValueError(
f"Invalid token expiry date format {self.token_expiry_date_format}; a string representing the format is required."
)
return pendulum.from_format(value, self.token_expiry_date_format)
return pendulum.from_format(str(value), self.token_expiry_date_format)
else:
return pendulum.now().add(seconds=int(float(value)))

Expand Down Expand Up @@ -192,7 +196,7 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
"""Expiration date of the access token"""

@abstractmethod
def set_token_expiry_date(self, value: Union[str, int]):
def set_token_expiry_date(self, value: Union[str, int]) -> None:
"""Setter for access token expiration date"""

@abstractmethod
Expand Down Expand Up @@ -228,14 +232,15 @@ def _message_repository(self) -> Optional[MessageRepository]:
"""
return _NOOP_MESSAGE_REPOSITORY

def _log_response(self, response: requests.Response):
self._message_repository.log_message(
Level.DEBUG,
lambda: format_http_message(
response,
"Refresh token",
"Obtains access token",
self._NO_STREAM_NAME,
is_auxiliary=True,
),
)
def _log_response(self, response: requests.Response) -> None:
if self._message_repository:
self._message_repository.log_message(
Level.DEBUG,
lambda: format_http_message(
response,
"Refresh token",
"Obtains access token",
self._NO_STREAM_NAME,
is_auxiliary=True,
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def test_refresh_with_encode_config_params(self):
"client_id": base64.b64encode(config["client_id"].encode("utf-8")).decode(),
"client_secret": base64.b64encode(config["client_secret"].encode("utf-8")).decode(),
"refresh_token": None,
"scopes": None,
}
assert body == expected

Expand All @@ -104,7 +103,6 @@ def test_refresh_with_decode_config_params(self):
"client_id": "some_client_id",
"client_secret": "some_client_secret",
"refresh_token": None,
"scopes": None,
}
assert body == expected

Expand All @@ -126,7 +124,6 @@ def test_refresh_without_refresh_token(self):
"client_id": "some_client_id",
"client_secret": "some_client_secret",
"refresh_token": None,
"scopes": None,
}
assert body == expected

Expand Down Expand Up @@ -278,6 +275,28 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next
assert "access_token" == token
assert oauth.get_token_expiry_date() == pendulum.parse(next_day)

def test_error_handling(self, mocker):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ config['refresh_token'] }}",
config=config,
scopes=["scope1", "scope2"],
refresh_request_body={
"custom_field": "{{ config['custom_field'] }}",
"another_field": "{{ config['another_field'] }}",
"scopes": ["no_override"],
},
parameters={},
)
resp.status_code = 400
mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 123})
mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True)
with pytest.raises(requests.exceptions.HTTPError) as e:
oauth.refresh_access_token()
assert e.value.errno == 400


def mock_request(method, url, data):
if url == "refresh_end":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,10 @@ def test_interpolate_config():
)

assert isinstance(authenticator, DeclarativeOauth2Authenticator)
assert authenticator.client_id.eval(input_config) == "some_client_id"
assert authenticator.client_secret.string == "some_client_secret"
assert authenticator.token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth"
assert authenticator.refresh_token.eval(input_config) == "verysecrettoken"
assert authenticator._client_id.eval(input_config) == "some_client_id"
assert authenticator._client_secret.string == "some_client_secret"
assert authenticator._token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth"
assert authenticator._refresh_token.eval(input_config) == "verysecrettoken"
assert authenticator._refresh_request_body.mapping == {"body_field": "yoyoyo", "interpolated_body_field": "{{ config['apikey'] }}"}
assert authenticator.get_refresh_request_body() == {"body_field": "yoyoyo", "interpolated_body_field": "verysecrettoken"}

Expand All @@ -332,9 +332,9 @@ def test_interpolate_config_with_token_expiry_date_format():
assert isinstance(authenticator, DeclarativeOauth2Authenticator)
assert authenticator.token_expiry_date_format == "%Y-%m-%d %H:%M:%S.%f+00:00"
assert authenticator.token_expiry_is_time_of_expiration
assert authenticator.client_id.eval(input_config) == "some_client_id"
assert authenticator.client_secret.string == "some_client_secret"
assert authenticator.token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth"
assert authenticator._client_id.eval(input_config) == "some_client_id"
assert authenticator._client_secret.string == "some_client_secret"
assert authenticator._token_refresh_endpoint.eval(input_config) == "https://api.sendgrid.com/v3/auth"


def test_single_use_oauth_branch():
Expand Down

0 comments on commit c8ca4b1

Please sign in to comment.