From cce3d687532d90e0902ae464252c6a0cea1a0c23 Mon Sep 17 00:00:00 2001 From: Przemek Denkiewicz Date: Wed, 14 Aug 2024 16:42:00 +0200 Subject: [PATCH] Shard long JWT token on Windows --- setup.py | 1 + tests/unit/test_client.py | 83 ++++++++++++++++++++++++++++++++++++++- tests/unit/test_dbapi.py | 2 +- trino/auth.py | 40 +++++++++++++++++-- trino/constants.py | 1 + 5 files changed, 120 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 2047b415..fcd7cf5c 100755 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ "pre-commit", "black", "isort", + "keyring" ] setup( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 9b5613ff..d5d70ec9 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -16,11 +16,12 @@ import uuid from contextlib import nullcontext as does_not_raise from typing import Any, Dict, Optional -from unittest import mock +from unittest import TestCase, mock from urllib.parse import urlparse import gssapi import httpretty +import keyring import pytest import requests from httpretty import httprettified @@ -42,7 +43,12 @@ _post_statement_requests, ) from trino import __version__, constants -from trino.auth import GSSAPIAuthentication, KerberosAuthentication, _OAuth2TokenBearer +from trino.auth import ( + GSSAPIAuthentication, + KerberosAuthentication, + _OAuth2KeyRingTokenCache, + _OAuth2TokenBearer, +) from trino.client import ( ClientSession, TrinoQuery, @@ -1343,3 +1349,76 @@ def test_request_with_invalid_timezone(mock_get_and_post): ), ) assert str(zinfo_error.value).startswith("'No time zone found with key") + + +class TestShardedPassword(TestCase): + def test_store_short_password(self): + # set the keyring to mock class + keyring.set_keyring(MockKeyring()) + + host = "trino.com" + short_password = "x" * 10 + + cache = _OAuth2KeyRingTokenCache() + cache.store_token_to_cache(host, short_password) + + retrieved_password = cache.get_token_from_cache(host) + self.assertEqual(short_password, retrieved_password) + + def test_store_long_password(self): + # set the keyring to mock class + keyring.set_keyring(MockKeyring()) + + host = "trino.com" + long_password = "x" * 3000 + + cache = _OAuth2KeyRingTokenCache() + cache.store_token_to_cache(host, long_password) + + retrieved_password = cache.get_token_from_cache(host) + self.assertEqual(long_password, retrieved_password) + + +class MockKeyring(keyring.backend.KeyringBackend): + def __init__(self): + self.file_location = self._generate_test_root_dir() + + @staticmethod + def _generate_test_root_dir(): + import tempfile + + return tempfile.mkdtemp(prefix="trino-python-client-unit-test-") + + def file_path(self, servicename, username): + from os.path import join + + file_location = self.file_location + file_name = f"{servicename}_{username}.txt" + return join(file_location, file_name) + + def set_password(self, servicename, username, password): + file_path = self.file_path(servicename, username) + + with open(file_path, "w") as file: + file.write(password) + + def get_password(self, servicename, username): + import os + + file_path = self.file_path(servicename, username) + if not os.path.exists(file_path): + return None + + with open(file_path, "r") as file: + password = file.read() + + return password + + def delete_password(self, servicename, username): + import os + + file_path = self.file_path(servicename, username) + if not os.path.exists(file_path): + return None + + os.remove(file_path) diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index e2857235..17462cb7 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -119,7 +119,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl conn2.cursor().execute("SELECT 2") conn2.cursor().execute("SELECT 3") - assert len(_get_token_requests(challenge_id)) == 2 + assert len(_get_token_requests(challenge_id)) == 1 @httprettified diff --git a/trino/auth.py b/trino/auth.py index bf9b4b3b..e5939403 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -25,7 +25,7 @@ import trino.logging from trino.client import exceptions -from trino.constants import HEADER_USER +from trino.constants import HEADER_USER, MAX_NT_PASSWORD_SIZE logger = trino.logging.get_logger(__name__) @@ -347,17 +347,49 @@ def is_keyring_available(self) -> bool: and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring) def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: + password = self._keyring.get_password(key, "token") + try: - return self._keyring.get_password(key, "token") + password_as_dict = json.loads(str(password)) + if password_as_dict.get("sharded_password"): + # if password was stored shared, reconstruct it + shard_count = int(password_as_dict.get("shard_count")) + + password = "" + for i in range(shard_count): + password += str(self._keyring.get_password(key, f"token__{i}")) + except self._keyring.errors.NoKeyringError as e: raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " "detected, check https://pypi.org/project/keyring/ for more " "information.") from e + except ValueError: + pass + + return password def store_token_to_cache(self, key: Optional[str], token: str) -> None: + # keyring is installed, so we can store the token for reuse within multiple threads try: - # keyring is installed, so we can store the token for reuse within multiple threads - self._keyring.set_password(key, "token", token) + # if not Windows or "small" password, stick to the default + if os.name != "nt" or len(token) < MAX_NT_PASSWORD_SIZE: + self._keyring.set_password(key, "token", token) + else: + logger.debug(f"password is {len(token)} characters, sharding it.") + + password_shards = [ + token[i: i + MAX_NT_PASSWORD_SIZE] for i in range(0, len(token), MAX_NT_PASSWORD_SIZE) + ] + shard_info = { + "sharded_password": True, + "shard_count": len(password_shards), + } + + # store the "shard info" as the "base" password + self._keyring.set_password(key, "token", json.dumps(shard_info)) + # then store all shards with the shard number as postfix + for i, s in enumerate(password_shards): + self._keyring.set_password(key, f"token__{i}", s) except self._keyring.errors.NoKeyringError as e: raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " "detected, check https://pypi.org/project/keyring/ for more " diff --git a/trino/constants.py b/trino/constants.py index d4ba904d..1dd0df94 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -20,6 +20,7 @@ DEFAULT_AUTH: Optional[Any] = None DEFAULT_MAX_ATTEMPTS = 3 DEFAULT_REQUEST_TIMEOUT: float = 30.0 +MAX_NT_PASSWORD_SIZE: int = 1280 HTTP = "http" HTTPS = "https"