Skip to content

Commit 01b937c

Browse files
committed
Apply new logic for parsing WWW-Authenticate header
1 parent 169226e commit 01b937c

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

trino/auth.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from requests import PreparedRequest, Request, Response, Session
2424
from requests.auth import AuthBase, extract_cookies_to_jar
25-
from requests.utils import parse_dict_header
2625

2726
import trino.logging
2827
from trino.client import exceptions
@@ -421,10 +420,21 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
421420
if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info):
422421
raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}")
423422

424-
auth_info_headers = parse_dict_header(
425-
_OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) # type: ignore
423+
split_challenge = auth_info.split(" ", 1)
424+
self.scheme = split_challenge[0]
425+
trimmed_challenge = split_challenge[1] if len(split_challenge) > 1 else ""
426426

427-
auth_server = auth_info_headers.get('x_redirect_server')
427+
auth_info_headers = {}
428+
429+
for item in trimmed_challenge.split(","):
430+
comps = item.split("=")
431+
if len(comps) == 2:
432+
key = comps[0].strip(' "')
433+
value = comps[1].strip(' "')
434+
if key:
435+
auth_info_headers[key.lower()] = value
436+
437+
auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server'))
428438
token_server = auth_info_headers.get('x_token_server')
429439
if token_server is None:
430440
raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")
@@ -443,7 +453,7 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
443453
request = response.request
444454
host = self._determine_host(request.url)
445455
user = self._determine_user(request.headers)
446-
key = self._construct_cache_key(host, user)
456+
key = self._construct_cache_key(host, user) # type: ignore
447457
self._store_token_to_cache(key, token)
448458

449459
def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:

0 commit comments

Comments
 (0)