From 70604dea48bdc64983a186aa7329fac1e9b85218 Mon Sep 17 00:00:00 2001 From: EdenKik Date: Mon, 17 Feb 2025 21:28:04 +0200 Subject: [PATCH 1/3] Add TrinoOperator for handling Trino queries --- .../providers/trino/operators/__init__.py | 16 ++++ .../providers/trino/operators/trino.py | 94 +++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 providers/trino/src/airflow/providers/trino/operators/__init__.py create mode 100644 providers/trino/src/airflow/providers/trino/operators/trino.py diff --git a/providers/trino/src/airflow/providers/trino/operators/__init__.py b/providers/trino/src/airflow/providers/trino/operators/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/trino/src/airflow/providers/trino/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/trino/src/airflow/providers/trino/operators/trino.py b/providers/trino/src/airflow/providers/trino/operators/trino.py new file mode 100644 index 0000000000000..1818aef004d94 --- /dev/null +++ b/providers/trino/src/airflow/providers/trino/operators/trino.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains the Trino operator.""" +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, ClassVar, List + +from airflow.providers.common.sql.hooks.sql import return_single_query_results +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator, default_output_processor +from airflow.providers.trino.hooks.trino import TrinoHook + + +class TrinoOperator(SQLExecuteQueryOperator): + """ + General Trino Operator to execute queries using Trino query engine. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TrinoOperator` + + :param sql: the SQL code to be executed as a single string, or + a list of str (sql statements), or a reference to a template file. + :param trino_conn_id: id of the connection config for the target Trino environment + :param autocommit: What to set the connection's autocommit setting to before executing the query + :param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler). + :param parameters: (optional) the parameters to render the SQL query with. + :param output_processor: (optional) the function that will be applied to the result + (default: default_output_processor). + :param split_statements: (optional) if split single SQL string into statements. (default: True). + :param show_return_value_in_logs: (optional) if true operator output will be printed to the task log. + Use with caution. It's not recommended to dump large datasets to the log. (default: False). + + """ + + template_fields: Sequence[str] = tuple( + {"trino_conn_id"} + | set(SQLExecuteQueryOperator.template_fields) + ) + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} + conn_id_field = "trino_conn_id" + + def __init__( + self, + *, + sql: str | List[str], + trino_conn_id: str = TrinoHook.default_conn_name, + **kwargs: Any, + ) -> None: + super().__init__(sql=sql, conn_id=trino_conn_id, **kwargs) + self.trino_conn_id = trino_conn_id + + def get_db_hook(self) -> TrinoHook: + return TrinoHook(self.trino_conn_id) + + def execute(self, context): + self.log.info("Executing: %s", self.sql) + hook = self.get_db_hook() + + if self.split_statements is not None: + extra_kwargs = {"split_statements": self.split_statements} + else: + extra_kwargs = {} + + output = hook.run( + sql=self.sql, + autocommit=self.autocommit, + parameters=self.parameters, + handler=self.handler, + return_last=self.return_last, + **extra_kwargs, + ) + if return_single_query_results(self.sql, self.return_last, self.split_statements): + # For simplicity, we pass always list as input to _process_output, regardless if + # single query results are going to be returned, and we return the first element + # of the list in this case from the (always) list returned by _process_output + return self._process_output([output], hook.descriptions)[-1] + result = self._process_output(output, hook.descriptions) + self.log.info("result: %s", result) + return result From b4efd34626f3527f48564d20183c18e38154df58 Mon Sep 17 00:00:00 2001 From: EdenKik Date: Mon, 17 Feb 2025 21:29:09 +0200 Subject: [PATCH 2/3] Add tests for TrinoOperator --- .../providers/trino/operators/trino.py | 7 +- .../tests/unit/trino/operators/__init__.py | 16 +++ .../tests/unit/trino/operators/test_trino.py | 98 +++++++++++++++++++ 3 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 providers/trino/tests/unit/trino/operators/__init__.py create mode 100644 providers/trino/tests/unit/trino/operators/test_trino.py diff --git a/providers/trino/src/airflow/providers/trino/operators/trino.py b/providers/trino/src/airflow/providers/trino/operators/trino.py index 1818aef004d94..b2a08d05bb3c4 100644 --- a/providers/trino/src/airflow/providers/trino/operators/trino.py +++ b/providers/trino/src/airflow/providers/trino/operators/trino.py @@ -21,7 +21,7 @@ from typing import Any, ClassVar, List from airflow.providers.common.sql.hooks.sql import return_single_query_results -from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator, default_output_processor +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.trino.hooks.trino import TrinoHook @@ -47,10 +47,7 @@ class TrinoOperator(SQLExecuteQueryOperator): """ - template_fields: Sequence[str] = tuple( - {"trino_conn_id"} - | set(SQLExecuteQueryOperator.template_fields) - ) + template_fields: Sequence[str] = tuple({"trino_conn_id"} | set(SQLExecuteQueryOperator.template_fields)) template_fields_renderers: ClassVar[dict] = {"sql": "sql"} conn_id_field = "trino_conn_id" diff --git a/providers/trino/tests/unit/trino/operators/__init__.py b/providers/trino/tests/unit/trino/operators/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/trino/tests/unit/trino/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/trino/tests/unit/trino/operators/test_trino.py b/providers/trino/tests/unit/trino/operators/test_trino.py new file mode 100644 index 0000000000000..c54f89405c15c --- /dev/null +++ b/providers/trino/tests/unit/trino/operators/test_trino.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import MagicMock, Mock + +from airflow.exceptions import AirflowException +from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.trino.hooks.trino import TrinoHook +from airflow.providers.trino.operators.trino import TrinoOperator + + +class TestTrinoOperator: + @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") + def test_get_hook_from_conn(self, mock_get_db_hook): + """ + :class:`~.TrinoOperator` should use the hook returned by :meth:`airflow.models.Connection.get_hook` + if one is returned. + + Specifically we verify here that :meth:`~.TrinoOperator.get_hook` returns the hook returned from a + call of ``get_hook`` on the object returned from :meth:`~.BaseHook.get_connection`. + """ + mock_hook = MagicMock() + mock_get_db_hook.return_value = mock_hook + + operator = TrinoOperator(task_id="test", sql="") + assert operator.get_db_hook() == mock_hook + + @mock.patch( + "airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook", + autospec=TrinoHook, + ) + def test_get_hook_default(self, mock_get_db_hook): + """ + If :meth:`airflow.models.Connection.get_hook` does not return a hook (e.g. because of an invalid + conn type), then :class:`~.TrinoHook` should be used. + """ + mock_get_db_hook.return_value.side_effect = Mock(side_effect=AirflowException()) + + operator = TrinoOperator(task_id="test", sql="") + assert operator.get_db_hook().__class__.__name__ == "TrinoHook" + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") + def test_execute(self, mock_get_db_hook): + sql = "SELECT * FROM test_table" + trino_conn_id = "trino_default" + parameters = ["value"] + autocommit = False + context = "test_context" + task_id = "test_task_id" + + operator = TrinoOperator(sql=sql, trino_conn_id=trino_conn_id, parameters=parameters, task_id=task_id) + operator.execute(context=context) + mock_get_db_hook.return_value.run.assert_called_once_with( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=fetch_all_handler, + return_last=True, + ) + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") + def test_trino_operator_test_multi(self, mock_get_db_hook): + sql = [ + "CREATE TABLE IF NOT EXISTS test_airflow (dummy varchar)", + "TRUNCATE TABLE test_airflow", + "INSERT INTO test_airflow VALUES ('X')", + ] + trino_conn_id = "trino_default" + parameters = ["value"] + autocommit = False + context = "test_context" + task_id = "test_task_id" + + operator = TrinoOperator(sql=sql, trino_conn_id=trino_conn_id, parameters=parameters, task_id=task_id) + operator.execute(context=context) + mock_get_db_hook.return_value.run.assert_called_once_with( + sql=sql, + autocommit=autocommit, + parameters=parameters, + handler=fetch_all_handler, + return_last=True, + ) From 21206a69a22a3ac034596bb445cb84613234cf80 Mon Sep 17 00:00:00 2001 From: EdenKik Date: Mon, 17 Feb 2025 23:51:55 +0200 Subject: [PATCH 3/3] Add TrinoOperator documentation and example DAGs --- providers/trino/docs/operators/trino.rst | 11 ++-- .../trino/tests/system/trino/example_trino.py | 58 +++++++++---------- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/providers/trino/docs/operators/trino.rst b/providers/trino/docs/operators/trino.rst index f31bf4c038a14..927a8ecde0ead 100644 --- a/providers/trino/docs/operators/trino.rst +++ b/providers/trino/docs/operators/trino.rst @@ -17,22 +17,19 @@ .. _howto/operator:TrinoOperator: -Connect to Trino using SQLExecuteQueryOperator -============================================== +TrinoOperator +============= -Use the :class:`SQLExecuteQueryOperator ` to execute +Use the :class:`TrinoOperator ` to execute SQL commands in a `Trino `__ query engine. -.. warning:: - TrinoOperator is deprecated in favor of SQLExecuteQueryOperator. If you are using TrinoOperator you should migrate as soon as possible. - Using the Operator ^^^^^^^^^^^^^^^^^^ Use the ``trino_conn_id`` argument to connect to your Trino instance -An example usage of the SQLExecuteQueryOperator to connect to Trino is as follows: +An example usage of the TrinoOperator to connect to Trino is as follows: .. exampleinclude:: /../../providers/trino/tests/system/trino/example_trino.py :language: python diff --git a/providers/trino/tests/system/trino/example_trino.py b/providers/trino/tests/system/trino/example_trino.py index db9fef4128b93..625501f3bca83 100644 --- a/providers/trino/tests/system/trino/example_trino.py +++ b/providers/trino/tests/system/trino/example_trino.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """ -Example DAG using SQLExecuteQueryOperator to connect to Trino. +Example DAG using TrinoOperator to query with Trino. """ from __future__ import annotations @@ -24,7 +23,7 @@ from datetime import datetime from airflow import models -from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.trino.operators.trino import TrinoOperator SCHEMA = "hive.cities" TABLE = "city" @@ -36,47 +35,46 @@ with models.DAG( dag_id="example_trino", schedule="@once", # Override to match your needs - start_date=datetime(2022, 1, 1), + start_date=datetime(2025, 2, 17), catchup=False, tags=["example"], ) as dag: - trino_create_schema = SQLExecuteQueryOperator( + + trino_create_schema = TrinoOperator( task_id="trino_create_schema", - sql=f"CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://irisbkt/cities/');", - handler=list, + sql=f" CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://example-bucket/cities/') ", ) - trino_create_table = SQLExecuteQueryOperator( + + trino_create_table = TrinoOperator( task_id="trino_create_table", - sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}( - cityid bigint, - cityname varchar - )""", - handler=list, + sql=f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE} ( cityid bigint, cityname varchar) ", ) - trino_insert = SQLExecuteQueryOperator( + + trino_insert = TrinoOperator( task_id="trino_insert", - sql=f"""INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco');""", - handler=list, + sql=f" INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco') " ) - trino_multiple_queries = SQLExecuteQueryOperator( + + trino_multiple_queries = TrinoOperator( task_id="trino_multiple_queries", - sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar); - INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose'); - CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar); - INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego');""", - handler=list, + sql=[ + f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar) ", + f" INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose') ", + f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar) ", + f" INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego') " + ] ) - trino_templated_query = SQLExecuteQueryOperator( + + trino_templated_query = TrinoOperator( task_id="trino_templated_query", sql="SELECT * FROM {{ params.SCHEMA }}.{{ params.TABLE }}", - handler=list, - params={"SCHEMA": SCHEMA, "TABLE": TABLE1}, + params={"SCHEMA": SCHEMA, "TABLE": TABLE1} ) - trino_parameterized_query = SQLExecuteQueryOperator( + + trino_parameterized_query = TrinoOperator( task_id="trino_parameterized_query", - sql=f"select * from {SCHEMA}.{TABLE2} where cityname = ?", - parameters=("San Diego",), - handler=list, + sql=f"SELECT * FROM {SCHEMA}.{TABLE2} WHERE cityname = ?", + parameters=("San Diego",) ) ( @@ -91,7 +89,7 @@ # [END howto_operator_trino] -from tests_common.test_utils.system_tests import get_test_run # noqa: E402 +from tests_common.test_utils.system_tests import get_test_run # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) test_run = get_test_run(dag)