Skip to content

Commit

Permalink
Support refresh tokens flow
Browse files Browse the repository at this point in the history
  • Loading branch information
sinkuladis committed Nov 9, 2022
1 parent 2e272a0 commit ba4537a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 4 deletions.
3 changes: 3 additions & 0 deletions tests/unit/oauth_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __call__(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
elif self.redirect_server is None and self.token_server is not None:
return [401, {'Www-Authenticate': f'Bearer x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,59 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
assert len(_get_token_requests(challenge_id)) == attempts


@httprettified
def test_oauth2_refresh_token_flow(sample_post_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)

redirect_handler = RedirectHandler()

request = TrinoRequest(
host="coordinator",
port=constants.DEFAULT_TLS_PORT,
client_session=ClientSession(
user="test",
),
http_scheme=constants.HTTPS,
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))

request.post("select 1")

# post response without x_redirect_server
post_statement_callback = PostStatementCallback(None, token_server, [token], sample_post_response_data)

# rebind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

response = request.post("select 1")

assert response.request.headers['Authorization'] == f"Bearer {token}"
assert redirect_handler.redirect_server == redirect_server
assert get_token_callback.attempts == 0
assert len(_post_statement_requests()) == 3


@pytest.mark.parametrize("attempts", [6, 10])
@httprettified
def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
Expand Down
12 changes: 8 additions & 4 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ class _OAuth2TokenBearer(AuthBase):
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)

def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
self._redirect_auth_url = redirect_auth_url_handler
self._redirect_auth_url = None
self._redirect_auth_url_handler = redirect_auth_url_handler
keyring_cache = _OAuth2KeyRingTokenCache()
self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache()
self._token_lock = threading.Lock()
Expand Down Expand Up @@ -322,15 +323,18 @@ def _attempt_oauth(self, response, **kwargs):
auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1))

auth_server = auth_info_headers.get('x_redirect_server')
if auth_server is None:
if auth_server is None and self._redirect_auth_url is None:
# app didn't receive redirect url neither now nor in previous responses
raise exceptions.TrinoAuthError("Error: header info didn't have x_redirect_server")

token_server = auth_info_headers.get('x_token_server')
if token_server is None:
raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")

# tell app that use this url to proceed with the authentication
self._redirect_auth_url(auth_server)
if auth_server is not None:
# tell app that use this url to proceed with the authentication
self._redirect_auth_url = auth_server
self._redirect_auth_url_handler(self._redirect_auth_url)

# Consume content and release the original connection
# to allow our new request to reuse the same one.
Expand Down

0 comments on commit ba4537a

Please sign in to comment.