diff --git a/providers/trino/docs/connections.rst b/providers/trino/docs/connections.rst index f6bd0b46c1e5f..32e0abe360d61 100644 --- a/providers/trino/docs/connections.rst +++ b/providers/trino/docs/connections.rst @@ -55,5 +55,7 @@ Extra (optional, connection parameters) * ``session_properties`` - JSON dictionary which allows to set session_properties. Example: ``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}`` * ``client_tags`` - List of comma separated tags. Example ``{'client_tags':['sales','cluster1']}``` * ``timezone`` - The time zone for the session can be explicitly set using the IANA time zone name. Example: ``{'timezone':'Asia/Jerusalem'}``. + * ``extra_credential`` - List of key-value string pairs which are passed to the Trino connector. For more information, refer to the Trino client protocol doc page here: https://trino.io/docs/current/develop/client-protocol.html + * ``roles`` - Mapping of catalog names to their corresponding Trino authorization role. For more information, refer to the Trino Python client docs here: https://github.com/trinodb/trino-python-client?tab=readme-ov-file#roles Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` will take precedent. diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py index 1cca01a2770dd..35bb83cce2a8a 100644 --- a/providers/trino/src/airflow/providers/trino/hooks/trino.py +++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py @@ -211,6 +211,8 @@ def get_conn(self) -> Connection: session_properties=extra.get("session_properties") or None, client_tags=extra.get("client_tags") or None, timezone=extra.get("timezone") or None, + extra_credential=extra.get("extra_credential") or None, + roles=extra.get("roles") or None, ) return trino_conn diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py index 7e485a28138bb..02966d2e919f3 100644 --- a/providers/trino/tests/unit/trino/hooks/test_trino.py +++ b/providers/trino/tests/unit/trino/hooks/test_trino.py @@ -258,6 +258,22 @@ def test_get_conn_timezone(self, mock_connect, mock_get_connection): TrinoHook().get_conn() self.assert_connection_called_with(mock_connect, timezone="Asia/Jerusalem") + @patch(HOOK_GET_CONNECTION) + @patch(TRINO_DBAPI_CONNECT) + def test_get_conn_extra_credential(self, mock_connect, mock_get_connection): + extras = {"extra_credential": [["a.username", "bar"], ["a.password", "foo"]]} + self.set_get_connection_return_value(mock_get_connection, extra=json.dumps(extras)) + TrinoHook().get_conn() + self.assert_connection_called_with(mock_connect, extra_credential=extras["extra_credential"]) + + @patch(HOOK_GET_CONNECTION) + @patch(TRINO_DBAPI_CONNECT) + def test_get_conn_roles(self, mock_connect, mock_get_connection): + extras = {"roles": {"catalog1": "trinoRoleA", "catalog2": "trinoRoleB"}} + self.set_get_connection_return_value(mock_get_connection, extra=json.dumps(extras)) + TrinoHook().get_conn() + self.assert_connection_called_with(mock_connect, roles=extras["roles"]) + @staticmethod def set_get_connection_return_value(mock_get_connection, extra=None, password=None): mocked_connection = Connection( @@ -274,6 +290,8 @@ def assert_connection_called_with( session_properties=None, client_tags=None, timezone=None, + extra_credential=None, + roles=None, ): mock_connect.assert_called_once_with( catalog="hive", @@ -290,6 +308,8 @@ def assert_connection_called_with( session_properties=session_properties, client_tags=client_tags, timezone=timezone, + extra_credential=extra_credential, + roles=roles, )