From 026e36f88769581960ed007b8db6b414d71051f8 Mon Sep 17 00:00:00 2001 From: Damian Owsianny Date: Thu, 2 Mar 2023 18:23:35 +0100 Subject: [PATCH] Add role parameter --- README.md | 12 ++++++++++++ tests/integration/test_dbapi_integration.py | 12 ++++++++++++ trino/client.py | 4 +++- 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f3ea42c3..c418dd23 100644 --- a/README.md +++ b/README.md @@ -359,6 +359,18 @@ conn = trino.dbapi.connect( ) ``` +You could also pass `system` role without explicitly specifing "system" catalog: + +```python +import trino +conn = trino.dbapi.connect( + host='localhost', + port=443, + user='the-user', + roles="role1" # equivalent to {"system": "role1"} +) +``` + ## Timezone The time zone for the session can be explicitly set using the IANA time zone diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 87d29284..e75d0ce7 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1144,6 +1144,18 @@ def test_set_role_in_connection(run_trino): assert_role_headers(cur, "system=ALL") +def test_set_system_role_in_connection(run_trino): + _, host, port = run_trino + + trino_connection = trino.dbapi.Connection( + host=host, port=port, user="test", catalog="tpch", roles="ALL" + ) + cur = trino_connection.cursor() + cur.execute('SHOW TABLES FROM information_schema') + cur.fetchall() + assert_role_headers(cur, "system=ALL") + + def assert_role_headers(cursor, expected_header): assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header diff --git a/trino/client.py b/trino/client.py index 593d579b..30d6ee51 100644 --- a/trino/client.py +++ b/trino/client.py @@ -135,7 +135,7 @@ def __init__( transaction_id: str = None, extra_credential: List[Tuple[str, str]] = None, client_tags: List[str] = None, - roles: Dict[str, str] = None, + roles: Union[Dict[str, str], str] = None, timezone: str = None, ): self._user = user @@ -239,6 +239,8 @@ def timezone(self): return self._timezone def _format_roles(self, roles): + if isinstance(roles, str): + roles = {"system": roles} formatted_roles = {} for catalog, role in roles.items(): is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None