Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
:param return_last: (optional) return the result of only last statement (default: True).
:param show_return_value_in_logs: (optional) if true operator output will be printed to the task log.
Use with caution. It's not recommended to dump large datasets to the log. (default: False).
:param requires_result_fetch: (optional) if True, ensures that query results are fetched before
completing execution. If `do_xcom_push` is True, results are fetched automatically,
making this parameter redundant. (default: False).

.. seealso::
For more information on how to use this operator, take a look at the guide:
Expand Down Expand Up @@ -254,6 +257,7 @@ def __init__(
split_statements: bool | None = None,
return_last: bool = True,
show_return_value_in_logs: bool = False,
requires_result_fetch: bool = False,
**kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
Expand All @@ -265,6 +269,7 @@ def __init__(
self.split_statements = split_statements
self.return_last = return_last
self.show_return_value_in_logs = show_return_value_in_logs
self.requires_result_fetch = requires_result_fetch

def _process_output(
self, results: list[Any], descriptions: list[Sequence[Sequence] | None]
Expand Down Expand Up @@ -303,7 +308,9 @@ def execute(self, context):
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
handler=self.handler if self._should_run_output_processing() else None,
handler=self.handler
if self._should_run_output_processing() or self.requires_result_fetch
else None,
return_last=self.return_last,
**extra_kwargs,
)
Expand Down
22 changes: 20 additions & 2 deletions providers/common/sql/tests/unit/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ def _construct_operator(self, sql, **kwargs):

@mock.patch.object(SQLExecuteQueryOperator, "_process_output")
@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_do_xcom_push(self, mock_get_db_hook, mock_process_output):
operator = self._construct_operator("SELECT 1;", do_xcom_push=True)
@pytest.mark.parametrize("requires_result_fetch", [True, False])
def test_do_xcom_push(self, mock_get_db_hook, mock_process_output, requires_result_fetch):
operator = self._construct_operator(
"SELECT 1;", do_xcom_push=True, requires_result_fetch=requires_result_fetch
)
operator.execute(context=MagicMock())

mock_get_db_hook.return_value.run.assert_called_once_with(
Expand All @@ -152,6 +155,21 @@ def test_dont_xcom_push(self, mock_get_db_hook, mock_process_output):
)
mock_process_output.assert_not_called()

@mock.patch.object(SQLExecuteQueryOperator, "_process_output")
@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_requires_result_fetch_dont_xcom_push(self, mock_get_db_hook, mock_process_output):
operator = self._construct_operator("SELECT 1;", requires_result_fetch=True, do_xcom_push=False)
operator.execute(context=MagicMock())

mock_get_db_hook.return_value.run.assert_called_once_with(
sql="SELECT 1;",
autocommit=False,
handler=fetch_all_handler,
parameters=None,
return_last=True,
)
mock_process_output.assert_not_called()

@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_output_processor(self, mock_get_db_hook):
data = [(1, "Alice"), (2, "Bob")]
Expand Down
25 changes: 13 additions & 12 deletions providers/trino/tests/system/trino/example_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,36 @@
with models.DAG(
dag_id="example_trino",
schedule="@once", # Override to match your needs
start_date=datetime(2022, 1, 1),
start_date=datetime(2025, 2, 24),
catchup=False,
tags=["example"],
) as dag:
trino_create_schema = SQLExecuteQueryOperator(
task_id="trino_create_schema",
sql=f"CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://irisbkt/cities/');",
sql=f" CREATE SCHEMA IF NOT EXISTS {SCHEMA} WITH (location = 's3://irisbkt/cities/') ",
handler=list,
)
trino_create_table = SQLExecuteQueryOperator(
task_id="trino_create_table",
sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}(
cityid bigint,
cityname varchar
)""",
sql=f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE}( cityid bigint, cityname varchar) ",
handler=list,
)
trino_insert = SQLExecuteQueryOperator(
task_id="trino_insert",
sql=f"""INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco');""",
sql=f" INSERT INTO {SCHEMA}.{TABLE} VALUES (1, 'San Francisco') ",
handler=list,
requires_result_fetch=True,
)
trino_multiple_queries = SQLExecuteQueryOperator(
task_id="trino_multiple_queries",
sql=f"""CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar);
INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose');
CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar);
INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego');""",
sql=[
f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE1}(cityid bigint,cityname varchar) ",
f" INSERT INTO {SCHEMA}.{TABLE1} VALUES (2, 'San Jose') ",
f" CREATE TABLE IF NOT EXISTS {SCHEMA}.{TABLE2}(cityid bigint,cityname varchar) ",
f" INSERT INTO {SCHEMA}.{TABLE2} VALUES (3, 'San Diego') ",
],
handler=list,
requires_result_fetch=True,
)
trino_templated_query = SQLExecuteQueryOperator(
task_id="trino_templated_query",
Expand All @@ -74,7 +75,7 @@
)
trino_parameterized_query = SQLExecuteQueryOperator(
task_id="trino_parameterized_query",
sql=f"select * from {SCHEMA}.{TABLE2} where cityname = ?",
sql=f" SELECT * FROM {SCHEMA}.{TABLE2} WHERE cityname = ?",
parameters=("San Diego",),
handler=list,
)
Expand Down