Skip to content

Commit

Permalink
Apply new logic for parsing WWW-Authenticate header
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco committed Jun 27, 2024
1 parent 169226e commit ca7ad24
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
2 changes: 2 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,8 @@ def test_oauth2_authentication_missing_headers(header, error):
'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"',
'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"',
'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"',
'realm="Trino", Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{redirect_server}", '
'x_token_server="{token_server}"'
])
@httprettified
def test_oauth2_header_parsing(header, sample_post_response_data):
Expand Down
25 changes: 21 additions & 4 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from requests import PreparedRequest, Request, Response, Session
from requests.auth import AuthBase, extract_cookies_to_jar
from requests.utils import parse_dict_header

import trino.logging
from trino.client import exceptions
Expand Down Expand Up @@ -421,10 +420,13 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info):
raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}")

auth_info_headers = parse_dict_header(
_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) # type: ignore
# Example www-authenticate header value:
# 'Basic realm="Trino", Bearer realm="Trino", token_type="JWT",
# Bearer x_redirect_server="https://trino.com/oauth2/token/uuid4",
# x_token_server="https://trino.com/oauth2/token/uuid4"'
auth_info_headers = self._parse_authenticate_header(auth_info)

auth_server = auth_info_headers.get('x_redirect_server')
auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server'))
token_server = auth_info_headers.get('x_token_server')
if token_server is None:
raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")
Expand Down Expand Up @@ -510,6 +512,21 @@ def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[s
else:
return f"{host}@{user}"

@staticmethod
def _parse_authenticate_header(header: str) -> Dict[str, str]:
split_challenge = header.split(" ", 1)
trimmed_challenge = split_challenge[1] if len(split_challenge) > 1 else ""
auth_info_headers = {}

for item in trimmed_challenge.split(","):
comps = item.split("=")
if len(comps) == 2:
key = comps[0].strip(' "')
value = comps[1].strip(' "')
if key:
auth_info_headers[key.lower()] = value
return auth_info_headers


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
Expand Down

0 comments on commit ca7ad24

Please sign in to comment.