From 88282f91530e2108ce287928acae8bac57bb07bd Mon Sep 17 00:00:00 2001 From: Przemek Denkiewicz Date: Thu, 27 Jun 2024 13:05:36 +0200 Subject: [PATCH] Apply new logic for parsing WWW-Authenticate header --- trino/auth.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/trino/auth.py b/trino/auth.py index 6262f95a..5139e0b1 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -22,7 +22,6 @@ from requests import PreparedRequest, Request, Response, Session from requests.auth import AuthBase, extract_cookies_to_jar -from requests.utils import parse_dict_header import trino.logging from trino.client import exceptions @@ -421,10 +420,21 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info): raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}") - auth_info_headers = parse_dict_header( - _OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) # type: ignore + split_challenge = auth_info.split(" ", 1) + self.scheme = split_challenge[0] + trimmed_challenge = split_challenge[1] - auth_server = auth_info_headers.get('x_redirect_server') + auth_info_headers = {} + + for item in trimmed_challenge.split(","): + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + auth_info_headers[key.lower()] = value + + auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server')) token_server = auth_info_headers.get('x_token_server') if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server") @@ -443,7 +453,7 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: request = response.request host = self._determine_host(request.url) user = self._determine_user(request.headers) - key = self._construct_cache_key(host, user) + key = self._construct_cache_key(host, user) # type: ignore self._store_token_to_cache(key, token) def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]: