From c540f501756e3d10b9e962bd6b49ea3f6bb26232 Mon Sep 17 00:00:00 2001 From: Huw Date: Tue, 7 May 2024 07:59:13 +0000 Subject: [PATCH] Support object as value in extra_credential --- tests/unit/test_client.py | 34 ++++++++++++++++++++++++++++++++++ trino/client.py | 3 ++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 48ac3948..1d178406 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -866,6 +866,40 @@ def test_extra_credential_value_encoding(mock_get_and_post): assert constants.HEADER_EXTRA_CREDENTIAL in headers assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84" +def test_extra_credential_value_object(mock_get_and_post): + _, post = mock_get_and_post + + class TestCredential(object): + value = "initial" + + def __str__(self): + return self.value + + credential = TestCredential() + + req = TrinoRequest( + host="coordinator", + port=constants.DEFAULT_TLS_PORT, + client_session=ClientSession( + user="test", + extra_credential=[("foo", credential)] + ) + ) + + req.post("SELECT 1") + _, post_kwargs = post.call_args + headers = post_kwargs["headers"] + assert constants.HEADER_EXTRA_CREDENTIAL in headers + assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=initial" + + # Make a second request, assert that credential has changed + credential.value = "changed" + req.post("SELECT 1") + _, post_kwargs = post.call_args + headers = post_kwargs["headers"] + assert constants.HEADER_EXTRA_CREDENTIAL in headers + assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=changed" + class MockGssapiCredentials: def __init__(self, name: gssapi.Name, usage: str): diff --git a/trino/client.py b/trino/client.py index 763e0ebc..cbcd15af 100644 --- a/trino/client.py +++ b/trino/client.py @@ -486,7 +486,8 @@ def http_headers(self) -> Dict[str, str]: # extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format) headers[constants.HEADER_EXTRA_CREDENTIAL] = \ ", ".join( - [f"{tup[0]}={urllib.parse.quote_plus(tup[1])}" for tup in self._client_session.extra_credential]) + [f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}" + for tup in self._client_session.extra_credential]) return headers