Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions providers/trino/docs/operators/trino.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,19 @@

.. _howto/operator:TrinoOperator:

Connect to Trino using SQLExecuteQueryOperator
==============================================
TrinoOperator
=============

Use the :class:`SQLExecuteQueryOperator <airflow.providers.common.sql.operators.sql>` to execute
Use the :class:`TrinoOperator <airflow.providers.trino.operators.trino>` to execute
SQL commands in a `Trino <https://trino.io/>`__ query engine.

.. warning::
TrinoOperator is deprecated in favor of SQLExecuteQueryOperator. If you are using TrinoOperator you should migrate as soon as possible.


Using the Operator
^^^^^^^^^^^^^^^^^^

Use the ``trino_conn_id`` argument to connect to your Trino instance

An example usage of the SQLExecuteQueryOperator to connect to Trino is as follows:
An example usage of the TrinoOperator to connect to Trino is as follows:

.. exampleinclude:: /../../providers/trino/tests/system/trino/example_trino.py
:language: python
Expand Down
16 changes: 16 additions & 0 deletions providers/trino/src/airflow/providers/trino/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
91 changes: 91 additions & 0 deletions providers/trino/src/airflow/providers/trino/operators/trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains the Trino operator."""
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, ClassVar, List

from airflow.providers.common.sql.hooks.sql import return_single_query_results
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.trino.hooks.trino import TrinoHook


class TrinoOperator(SQLExecuteQueryOperator):
"""
General Trino Operator to execute queries using Trino query engine.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:TrinoOperator`

:param sql: the SQL code to be executed as a single string, or
a list of str (sql statements), or a reference to a template file.
:param trino_conn_id: id of the connection config for the target Trino environment
:param autocommit: What to set the connection's autocommit setting to before executing the query
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
:param parameters: (optional) the parameters to render the SQL query with.
:param output_processor: (optional) the function that will be applied to the result
(default: default_output_processor).
:param split_statements: (optional) if split single SQL string into statements. (default: True).
:param show_return_value_in_logs: (optional) if true operator output will be printed to the task log.
Use with caution. It's not recommended to dump large datasets to the log. (default: False).

"""

template_fields: Sequence[str] = tuple({"trino_conn_id"} | set(SQLExecuteQueryOperator.template_fields))
template_fields_renderers: ClassVar[dict] = {"sql": "sql"}
conn_id_field = "trino_conn_id"

def __init__(
self,
*,
sql: str | List[str],
trino_conn_id: str = TrinoHook.default_conn_name,
**kwargs: Any,
) -> None:
super().__init__(sql=sql, conn_id=trino_conn_id, **kwargs)
self.trino_conn_id = trino_conn_id

def get_db_hook(self) -> TrinoHook:
return TrinoHook(self.trino_conn_id)

def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()

if self.split_statements is not None:
extra_kwargs = {"split_statements": self.split_statements}
else:
extra_kwargs = {}

output = hook.run(
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
handler=self.handler,
return_last=self.return_last,
**extra_kwargs,
)
if return_single_query_results(self.sql, self.return_last, self.split_statements):
# For simplicity, we pass always list as input to _process_output, regardless if
# single query results are going to be returned, and we return the first element
# of the list in this case from the (always) list returned by _process_output
return self._process_output([output], hook.descriptions)[-1]
result = self._process_output(output, hook.descriptions)
self.log.info("result: %s", result)
return result
58 changes: 28 additions & 30 deletions providers/trino/tests/system/trino/example_trino.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand All @@ -16,15 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""
Example DAG using SQLExecuteQueryOperator to connect to Trino.
Example DAG using TrinoOperator to query with Trino.
"""

from __future__ import annotations

from datetime import datetime

from airflow import models
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.trino.operators.trino import TrinoOperator

SCHEMA = "hive.cities"
TABLE = "city"
Expand All @@ -36,47 +35,46 @@
with models.DAG(
dag_id="example_trino",
schedule="@once", # Override to match your needs
start_date=datetime(2022, 1, 1),
start_date=datetime(2025, 2, 17),
catchup=False,
tags=["example"],
) as dag:
trino_create_schema = SQLExecuteQueryOperator(

trino_create_schema = TrinoOperator(
task_id="trino_create_schema",
sql=f"CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://irisbkt/cities/');",
handler=list,
sql=f" CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://example-bucket/cities/') ",
)
trino_create_table = SQLExecuteQueryOperator(

trino_create_table = TrinoOperator(
task_id="trino_create_table",
sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}(
cityid bigint,
cityname varchar
)""",
handler=list,
sql=f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE} ( cityid bigint, cityname varchar) ",
)
trino_insert = SQLExecuteQueryOperator(

trino_insert = TrinoOperator(
task_id="trino_insert",
sql=f"""INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco');""",
handler=list,
sql=f" INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco') "
)
trino_multiple_queries = SQLExecuteQueryOperator(

trino_multiple_queries = TrinoOperator(
task_id="trino_multiple_queries",
sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar);
INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose');
CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar);
INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego');""",
handler=list,
sql=[
f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar) ",
f" INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose') ",
f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar) ",
f" INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego') "
]
)
trino_templated_query = SQLExecuteQueryOperator(

trino_templated_query = TrinoOperator(
task_id="trino_templated_query",
sql="SELECT * FROM {{ params.SCHEMA }}.{{ params.TABLE }}",
handler=list,
params={"SCHEMA": SCHEMA, "TABLE": TABLE1},
params={"SCHEMA": SCHEMA, "TABLE": TABLE1}
)
trino_parameterized_query = SQLExecuteQueryOperator(

trino_parameterized_query = TrinoOperator(
task_id="trino_parameterized_query",
sql=f"select * from {SCHEMA}.{TABLE2} where cityname = ?",
parameters=("San Diego",),
handler=list,
sql=f"SELECT * FROM {SCHEMA}.{TABLE2} WHERE cityname = ?",
parameters=("San Diego",)
)

(
Expand All @@ -91,7 +89,7 @@
# [END howto_operator_trino]


from tests_common.test_utils.system_tests import get_test_run # noqa: E402
from tests_common.test_utils.system_tests import get_test_run

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
16 changes: 16 additions & 0 deletions providers/trino/tests/unit/trino/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
98 changes: 98 additions & 0 deletions providers/trino/tests/unit/trino/operators/test_trino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock, Mock

from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.trino.hooks.trino import TrinoHook
from airflow.providers.trino.operators.trino import TrinoOperator


class TestTrinoOperator:
@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
def test_get_hook_from_conn(self, mock_get_db_hook):
"""
:class:`~.TrinoOperator` should use the hook returned by :meth:`airflow.models.Connection.get_hook`
if one is returned.

Specifically we verify here that :meth:`~.TrinoOperator.get_hook` returns the hook returned from a
call of ``get_hook`` on the object returned from :meth:`~.BaseHook.get_connection`.
"""
mock_hook = MagicMock()
mock_get_db_hook.return_value = mock_hook

operator = TrinoOperator(task_id="test", sql="")
assert operator.get_db_hook() == mock_hook

@mock.patch(
"airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook",
autospec=TrinoHook,
)
def test_get_hook_default(self, mock_get_db_hook):
"""
If :meth:`airflow.models.Connection.get_hook` does not return a hook (e.g. because of an invalid
conn type), then :class:`~.TrinoHook` should be used.
"""
mock_get_db_hook.return_value.side_effect = Mock(side_effect=AirflowException())

operator = TrinoOperator(task_id="test", sql="")
assert operator.get_db_hook().__class__.__name__ == "TrinoHook"

@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
def test_execute(self, mock_get_db_hook):
sql = "SELECT * FROM test_table"
trino_conn_id = "trino_default"
parameters = ["value"]
autocommit = False
context = "test_context"
task_id = "test_task_id"

operator = TrinoOperator(sql=sql, trino_conn_id=trino_conn_id, parameters=parameters, task_id=task_id)
operator.execute(context=context)
mock_get_db_hook.return_value.run.assert_called_once_with(
sql=sql,
autocommit=autocommit,
parameters=parameters,
handler=fetch_all_handler,
return_last=True,
)

@mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
def test_trino_operator_test_multi(self, mock_get_db_hook):
sql = [
"CREATE TABLE IF NOT EXISTS test_airflow (dummy varchar)",
"TRUNCATE TABLE test_airflow",
"INSERT INTO test_airflow VALUES ('X')",
]
trino_conn_id = "trino_default"
parameters = ["value"]
autocommit = False
context = "test_context"
task_id = "test_task_id"

operator = TrinoOperator(sql=sql, trino_conn_id=trino_conn_id, parameters=parameters, task_id=task_id)
operator.execute(context=context)
mock_get_db_hook.return_value.run.assert_called_once_with(
sql=sql,
autocommit=autocommit,
parameters=parameters,
handler=fetch_all_handler,
return_last=True,
)
Loading