Skip to content

Commit

Permalink
Fix refresh tokens flow
Browse files Browse the repository at this point in the history
Signed-off-by: sinkuladis <sink.vlad@gmail.com>
  • Loading branch information
sinkuladis committed Nov 17, 2022
1 parent bcd4039 commit 3cf5b80
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
11 changes: 11 additions & 0 deletions tests/unit/oauth_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def __call__(self, url):
self.redirect_server += url


class RedirectHandlerWithException:
def __init__(self, exception):
self.exception = exception

def __call__(self, url):
raise self.exception


class PostStatementCallback:
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
self.redirect_server = redirect_server
Expand All @@ -45,6 +53,9 @@ def __call__(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
elif self.redirect_server is None and self.token_server is not None:
return [401, {'Www-Authenticate': f'Bearer x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]
Expand Down
46 changes: 44 additions & 2 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
MultithreadedTokenServer,
PostStatementCallback,
RedirectHandler,
RedirectHandlerWithException,
_get_token_requests,
_post_statement_requests,
)
Expand Down Expand Up @@ -384,6 +385,48 @@ def test_oauth2_authentication_flow(attempts, sample_post_response_data):
assert len(_get_token_requests(challenge_id)) == attempts


@httprettified
def test_oauth2_refresh_token_flow(sample_post_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

post_statement_callback = PostStatementCallback(None, token_server, [token], sample_post_response_data)

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)

redirect_handler = RedirectHandlerWithException(
trino.exceptions.TrinoAuthError(
"Do not use redirect handler when there is no redirect_uri in the response"))

request = TrinoRequest(
host="coordinator",
port=constants.DEFAULT_TLS_PORT,
client_session=ClientSession(
user="test",
),
http_scheme=constants.HTTPS,
auth=trino.auth.OAuth2Authentication(redirect_auth_url_handler=redirect_handler))

response = request.post("select 1")

assert response.request.headers['Authorization'] == f"Bearer {token}"
assert get_token_callback.attempts == 0
assert len(_post_statement_requests()) == 2


@pytest.mark.parametrize("attempts", [6, 10])
@httprettified
def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):
Expand Down Expand Up @@ -430,10 +473,9 @@ def test_oauth2_exceed_max_attempts(attempts, sample_post_response_data):

@pytest.mark.parametrize("header,error", [
("", "Error: header WWW-Authenticate not available in the response."),
('Bearer"', 'Error: header info didn\'t have x_redirect_server'),
('Bearer"', 'Error: header info didn\'t have x_token_server'),
('x_redirect_server="redirect_server", x_token_server="token_server"', 'Error: header info didn\'t match x_redirect_server="redirect_server", x_token_server="token_server"'), # noqa: E501
('Bearer x_redirect_server="redirect_server"', 'Error: header info didn\'t have x_token_server'),
('Bearer x_token_server="token_server"', 'Error: header info didn\'t have x_redirect_server'),
])
@httprettified
def test_oauth2_authentication_missing_headers(header, error):
Expand Down
8 changes: 3 additions & 5 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,13 @@ def _attempt_oauth(self, response, **kwargs):
auth_info_headers = parse_dict_header(_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1))

auth_server = auth_info_headers.get('x_redirect_server')
if auth_server is None:
raise exceptions.TrinoAuthError("Error: header info didn't have x_redirect_server")

token_server = auth_info_headers.get('x_token_server')
if token_server is None:
raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")

# tell app that use this url to proceed with the authentication
self._redirect_auth_url(auth_server)
if auth_server is not None:
# tell app that use this url to proceed with the authentication
self._redirect_auth_url(auth_server)

# Consume content and release the original connection
# to allow our new request to reuse the same one.
Expand Down

0 comments on commit 3cf5b80

Please sign in to comment.