Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDK: Emit control message on config mutation #19428

Merged
merged 30 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3d4b5a
wip
alafanechere Nov 14, 2022
24d24dd
implementation
alafanechere Nov 15, 2022
355b90b
format
alafanechere Nov 15, 2022
61e9d4b
bump version
alafanechere Nov 15, 2022
1101ccc
Merge branch 'master' into augustin/cdk/emit-updated-configs
alafanechere Nov 15, 2022
22f75c7
always update by default, even if same value
alafanechere Nov 16, 2022
d562117
rename split_config to filter_internal_keywords
alafanechere Nov 16, 2022
733b89a
bing ads example
alafanechere Nov 16, 2022
a59a1e6
wrap around AirbyteMessage
alafanechere Nov 16, 2022
616256e
exclude unset
alafanechere Nov 16, 2022
3ef545d
observer does not write config to disk
alafanechere Nov 21, 2022
0ff7318
revert global changes
alafanechere Nov 21, 2022
9e1483c
revert global changes
alafanechere Nov 21, 2022
2bbff33
revert global changes
alafanechere Nov 21, 2022
1f02860
observe from Oauth2Authenticator
alafanechere Nov 21, 2022
f40d99c
Merge branch 'master' into augustin/cdk/emit-updated-configs
alafanechere Nov 21, 2022
04041d8
ref
alafanechere Nov 21, 2022
12db1b7
handle list of dicts
alafanechere Nov 21, 2022
63d129c
implement SingleUseRefreshTokenOauth2Authenticator
alafanechere Nov 23, 2022
62ef553
test SingleUseRefreshTokenOauth2Authenticator
alafanechere Nov 23, 2022
bbe802a
call copy in ObservedDict
alafanechere Nov 23, 2022
8fd52f9
add docstring
alafanechere Nov 23, 2022
c5a0b6d
source harvest example
alafanechere Nov 23, 2022
803b70a
use dpath
alafanechere Nov 25, 2022
37f474f
better doc string
alafanechere Nov 25, 2022
08f64e4
update changelog
alafanechere Nov 25, 2022
518a36f
use sequence instead of string path for dpath declaration
alafanechere Nov 29, 2022
64bbc48
Merge branch 'master' into augustin/cdk/emit-updated-configs
alafanechere Nov 29, 2022
4da6581
revert connector changes
alafanechere Nov 29, 2022
380e175
format
alafanechere Nov 29, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.11.0
Declare a new authenticator `SingleUseRefreshTokenOauth2Authenticator` that can perform connector configuration mutation and emit `AirbyteControlMessage.ConnectorConfig`.

## 0.10.0
Low-code: Add `start_from_page` option to a PageIncrement class

Expand Down
76 changes: 76 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from __future__ import ( # Used to evaluate type hints at runtime, a NameError: name 'ConfigObserver' is not defined is thrown otherwise
annotations,
)

import time
from typing import Any, List, MutableMapping

from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, OrchestratorType, Type


class ObservedDict(dict):
    """Dictionary that notifies an observer when one of its items is assigned.

    Mappings stored as values, and mappings found directly inside list values, are wrapped in
    ObservedDict too, so item assignments made deeper in the structure notify the same observer.
    Only item assignment (``d[key] = value``) is intercepted; the other mutating dict methods are
    not overridden by this class.
    """

    # Sentinel used to tell "key was absent" apart from "key held None".
    _MISSING = object()

    def __init__(self, non_observed_mapping: MutableMapping, observer: "ConfigObserver", update_on_unchanged_value: bool = True) -> None:
        """Build an observed copy of a mapping.

        Args:
            non_observed_mapping (MutableMapping): The mapping to copy and observe. It is not modified.
            observer (ConfigObserver): Object whose update() is called on item assignment.
            update_on_unchanged_value (bool): When True, notify on every assignment, even if the value did not change.
        """
        self.observer = observer
        self.update_on_unchanged_value = update_on_unchanged_value
        # dict.__init__ does not go through __setitem__, so building the copy does not notify the observer.
        super().__init__({key: self._observe(value) for key, value in non_observed_mapping.items()})

    def _observe(self, value: Any) -> Any:
        """Return value with nested mappings wrapped in ObservedDict; shared by __init__ and __setitem__."""
        if isinstance(value, MutableMapping):
            return ObservedDict(value, self.observer, self.update_on_unchanged_value)
        if isinstance(value, List):
            # Build a new list so the caller's list is left untouched.
            return [
                ObservedDict(sub_value, self.observer, self.update_on_unchanged_value) if isinstance(sub_value, MutableMapping) else sub_value
                for sub_value in value
            ]
        return value

    def __setitem__(self, item: Any, value: Any) -> None:
        """Override dict.__setitem__ by:
        1. Observing the new value if it is a dict or a list containing dicts
        2. Calling observer update, always if update_on_unchanged_value is True,
           otherwise only if the key is new or the value differs from the previous one
        """
        previous_value = self.get(item, self._MISSING)
        value = self._observe(value)
        super().__setitem__(item, value)
        if self.update_on_unchanged_value or value != previous_value:
            self.observer.update()


class ConfigObserver:
    """Observer for an ObservedDict connector configuration.

    Every call to update prints a CONTROL Airbyte message of type CONNECTOR_CONFIG that carries
    the configuration registered through set_config.
    """

    def set_config(self, config: ObservedDict) -> None:
        """Register the configuration that is serialized on each update."""
        self.config = config

    def update(self) -> None:
        """Report a configuration mutation by emitting the control message."""
        self._emit_airbyte_control_message()

    def _emit_airbyte_control_message(self) -> None:
        """Build the CONNECTOR_CONFIG control message and print it as JSON on stdout."""
        emitted_at_ms = time.time() * 1000
        connector_config_message = AirbyteControlConnectorConfigMessage(config=self.config)
        control = AirbyteControlMessage(
            type=OrchestratorType.CONNECTOR_CONFIG,
            emitted_at=emitted_at_ms,
            connectorConfig=connector_config_message,
        )
        message = AirbyteMessage(type=Type.CONTROL, control=control)
        print(message.json(exclude_unset=True))


def observe_connector_config(non_observed_connector_config: MutableMapping[str, Any]):
    """Return an ObservedDict copy of a connector configuration, wired to a new ConfigObserver.

    Raises:
        ValueError: If the given configuration is already an ObservedDict.
    """
    already_observed = isinstance(non_observed_connector_config, ObservedDict)
    if already_observed:
        raise ValueError("This connector configuration is already observed")
    observer = ConfigObserver()
    observed_config = ObservedDict(non_observed_connector_config, observer)
    # The observer needs the observed config so it can serialize it when update is called.
    observer.set_config(observed_config)
    return observed_config
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from .oauth import Oauth2Authenticator
from .oauth import Oauth2Authenticator, SingleUseRefreshTokenOauth2Authenticator
from .token import BasicHttpAuthenticator, MultipleTokenAuthenticator, TokenAuthenticator

__all__ = ["Oauth2Authenticator", "TokenAuthenticator", "MultipleTokenAuthenticator", "BasicHttpAuthenticator"]
__all__ = [
"Oauth2Authenticator",
"SingleUseRefreshTokenOauth2Authenticator",
"TokenAuthenticator",
"MultipleTokenAuthenticator",
"BasicHttpAuthenticator",
]
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,19 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

def _get_refresh_access_token_response(self):
    """POST the refresh request body to the token refresh endpoint and return the decoded JSON body.

    An HTTP error status is raised by response.raise_for_status().
    """
    refresh_endpoint = self.get_token_refresh_endpoint()
    request_body = self.build_refresh_request_body()
    response = requests.request(method="POST", url=refresh_endpoint, data=request_body)
    response.raise_for_status()
    return response.json()

def refresh_access_token(self) -> Tuple[str, int]:
"""
Returns the refresh token and its lifespan in seconds

:return: a tuple of (access_token, token_lifespan_in_seconds)
"""
try:
response = requests.request(method="POST", url=self.get_token_refresh_endpoint(), data=self.build_refresh_request_body())
response.raise_for_status()
response_json = response.json()
response_json = self._get_refresh_access_token_response()
return response_json[self.get_access_token_name()], response_json[self.get_expires_in_name()]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping
from typing import Any, List, Mapping, Sequence, Tuple

import dpath
import pendulum
from airbyte_cdk.config_observation import observe_connector_config
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator


class Oauth2Authenticator(AbstractOauth2Authenticator):
"""
Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials.
The generated access token is attached to each request via the Authorization header.
If a connector_config is provided any mutation of its value in the scope of this class will emit AirbyteControlConnectorConfigMessage.
"""

def __init__(
Expand Down Expand Up @@ -80,3 +83,126 @@ def access_token(self) -> str:
@access_token.setter
def access_token(self, value: str):
self._access_token = value


class SingleUseRefreshTokenOauth2Authenticator(Oauth2Authenticator):
"""
Authenticator that should be used for API implementing single use refresh tokens:
when refreshing the access token, some APIs return a new refresh token that needs to be used in the next refresh flow.
This authenticator updates the configuration with new refresh token by emitting Airbyte control message from an observed mutation.
By default this authenticator expects a connector config with a "credentials" field with the following nested fields: client_id, client_secret, refresh_token.
This behavior can be changed by defining custom config path (using dpath paths) in client_id_config_path, client_secret_config_path, refresh_token_config_path constructor arguments.
"""

def __init__(
self,
connector_config: Mapping[str, Any],
alafanechere marked this conversation as resolved.
Show resolved Hide resolved
token_refresh_endpoint: str,
scopes: List[str] = None,
token_expiry_date: pendulum.DateTime = None,
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
client_id_config_path: Sequence[str] = ("credentials", "client_id"),
client_secret_config_path: Sequence[str] = ("credentials", "client_secret"),
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
):
"""

Args:
connector_config (Mapping[str, Any]): The full connector configuration
token_refresh_endpoint (str): Full URL to the token refresh endpoint
scopes (List[str], optional): List of OAuth scopes to pass in the refresh token request body. Defaults to None.
token_expiry_date (pendulum.DateTime, optional): Datetime at which the current token will expire. Defaults to None.
access_token_name (str, optional): Name of the access token field, used to parse the refresh token response. Defaults to "access_token".
expires_in_name (str, optional): Name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
refresh_token_name (str, optional): Name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
client_id_config_path (Sequence[str]): Dpath to the client_id field in the connector configuration. Defaults to ("credentials", "client_id").
client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret").
refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token").
"""
self._client_id_config_path = client_id_config_path
self._client_secret_config_path = client_secret_config_path
self._refresh_token_config_path = refresh_token_config_path
self._refresh_token_name = refresh_token_name
self._connector_config = observe_connector_config(connector_config)
self._validate_connector_config()
super().__init__(
token_refresh_endpoint,
self.get_client_id(),
self.get_client_secret(),
self.get_refresh_token(),
scopes,
token_expiry_date,
access_token_name,
expires_in_name,
refresh_request_body,
grant_type,
)

def _validate_connector_config(self):
"""Validates the defined getters for configuration values are returning values.

Raises:
ValueError: Raised if the defined getters are not returning a value.
"""
for field_path, getter, parameter_name in [
(self._client_id_config_path, self.get_client_id, "client_id_config_path"),
(self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"),
(self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"),
]:
try:
assert getter()
except KeyError:
raise ValueError(
f"This authenticator expects a value under the {field_path} field path. Please check your configuration structure or change the {parameter_name} value at initialization of this authenticator."
)

def get_refresh_token_name(self) -> str:
    """Return the refresh token field name received by the constructor as refresh_token_name."""
    return self._refresh_token_name

def get_client_id(self) -> str:
    """Read the client id from the observed connector config at client_id_config_path."""
    return dpath.util.get(self._connector_config, self._client_id_config_path)

def get_client_secret(self) -> str:
    """Read the client secret from the observed connector config at client_secret_config_path."""
    return dpath.util.get(self._connector_config, self._client_secret_config_path)

def get_refresh_token(self) -> str:
    """Read the current refresh token from the observed connector config at refresh_token_config_path."""
    return dpath.util.get(self._connector_config, self._refresh_token_config_path)

def set_refresh_token(self, new_refresh_token: str) -> None:
    """Set the new refresh token value. The mutation of the connector_config object will emit an Airbyte control message.

    The value is written into the observed connector config at refresh_token_config_path.

    Args:
        new_refresh_token (str): The new refresh token value.
    """
    dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token)

def get_access_token(self) -> str:
    """Return the access token, refreshing it first when the current one has expired.

    On refresh, the new access token and its expiry date are stored, and the refresh token
    returned alongside it is saved through set_refresh_token.

    Returns:
        str: The current access_token, updated if it was previously expired.
    """
    if not self.token_has_expired():
        return self.access_token
    # Taken before the refresh call so the expiry date is computed from the start of the request.
    refresh_started_at = pendulum.now()
    refreshed_access_token, expires_in_seconds, refreshed_refresh_token = self.refresh_access_token()
    self.access_token = refreshed_access_token
    self.set_token_expiry_date(refresh_started_at.add(seconds=expires_in_seconds))
    self.set_refresh_token(refreshed_refresh_token)
    return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
try:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would recommend using dpath

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed but for implementation consistency with parent class and abstract class, I'd prefer to do this in a separate PR in which we can replace the access_token_name, expires_in_name, and refresh_token_name by dpaths that can be used when parsing these responses. Wdyt?

Copy link
Contributor

@sherifnada sherifnada Nov 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense!

response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)
except Exception as e:
raise Exception(f"Error while refreshing access token and refresh token: {e}") from e
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.10.0",
version="0.11.0",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import json
import logging

import pendulum
import pytest
import requests
from airbyte_cdk.config_observation import ObservedDict
from airbyte_cdk.sources.streams.http.requests_native_auth import (
BasicHttpAuthenticator,
MultipleTokenAuthenticator,
Oauth2Authenticator,
SingleUseRefreshTokenOauth2Authenticator,
TokenAuthenticator,
)
from requests import Response
Expand Down Expand Up @@ -175,6 +179,73 @@ def test_auth_call_method(self, mocker):
assert {"Authorization": "Bearer access_token"} == prepared_request.headers


class TestSingleUseRefreshTokenOauth2Authenticator:
    """Tests for SingleUseRefreshTokenOauth2Authenticator: config observation, validation and token refresh."""

    @pytest.fixture
    def connector_config(self):
        """Connector config with credentials stored under the authenticator's default paths."""
        return {
            "credentials": {
                "access_token": "my_access_token",
                "refresh_token": "my_refresh_token",
                "client_id": "my_client_id",
                "client_secret": "my_client_secret",
            }
        }

    @pytest.fixture
    def invalid_connector_config(self):
        """Connector config without the "credentials" field the authenticator looks up by default."""
        return {"no_credentials_key": "foo"}

    def test_init(self, connector_config):
        """The authenticator stores its connector config as an ObservedDict."""
        authenticator = SingleUseRefreshTokenOauth2Authenticator(
            connector_config,
            token_refresh_endpoint="foobar",
        )
        assert isinstance(authenticator._connector_config, ObservedDict)

    def test_init_with_invalid_config(self, invalid_connector_config):
        """A config missing the expected credential paths is rejected with ValueError."""
        with pytest.raises(ValueError):
            SingleUseRefreshTokenOauth2Authenticator(
                invalid_connector_config,
                token_refresh_endpoint="foobar",
            )

    def test_get_access_token(self, capsys, mocker, connector_config):
        """An expired token triggers a refresh that emits the updated config; a valid token emits nothing."""
        authenticator = SingleUseRefreshTokenOauth2Authenticator(
            connector_config,
            token_refresh_endpoint="foobar",
        )
        authenticator.refresh_access_token = mocker.Mock(return_value=("new_access_token", 42, "new_refresh_token"))
        authenticator.token_has_expired = mocker.Mock(return_value=True)
        access_token = authenticator.get_access_token()
        captured = capsys.readouterr()
        airbyte_message = json.loads(captured.out)
        # Build the expected config from fresh dicts: connector_config.copy() is shallow, so writing
        # into its nested "credentials" dict would also mutate the fixture.
        expected_new_config = {
            **connector_config,
            "credentials": {**connector_config["credentials"], "refresh_token": "new_refresh_token"},
        }
        assert airbyte_message["control"]["connectorConfig"]["config"] == expected_new_config
        assert authenticator.access_token == access_token == "new_access_token"
        assert authenticator.get_refresh_token() == "new_refresh_token"
        assert authenticator.get_token_expiry_date() > pendulum.now()
        authenticator.token_has_expired = mocker.Mock(return_value=False)
        access_token = authenticator.get_access_token()
        captured = capsys.readouterr()
        assert not captured.out
        assert authenticator.access_token == access_token == "new_access_token"

    def test_refresh_access_token(self, mocker, connector_config):
        """refresh_access_token returns the access token, its lifespan and the new refresh token from the response."""
        authenticator = SingleUseRefreshTokenOauth2Authenticator(
            connector_config,
            token_refresh_endpoint="foobar",
        )
        authenticator._get_refresh_access_token_response = mocker.Mock(
            return_value={
                authenticator.get_access_token_name(): "new_access_token",
                authenticator.get_expires_in_name(): 42,
                authenticator.get_refresh_token_name(): "new_refresh_token",
            }
        )
        assert authenticator.refresh_access_token() == ("new_access_token", 42, "new_refresh_token")


def mock_request(method, url, data):
if url == "refresh_end":
return resp
Expand Down
Loading