Reviewed revision of the patch that makes SpannerHook.execute_dml (and the
SpannerQueryDatabaseInstanceOperator built on it) return affected-row counts.

Changes relative to the submitted patch:
  * _execute_sql_in_transaction returns ordered (sql, row_count) pairs instead
    of a mapping keyed by SQL text, so a statement that appears more than once
    in `queries` keeps a separate count for every execution.
  * The SELECT filter in execute_dml ignores leading whitespace and letter case.
  * The empty test_execute_dml_oqueries_row_count placeholder is replaced by a
    test of repeated statements; the parametrized test gains a matching case.
  * OrderedDict is no longer imported by the hook or its tests.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy; confirm wrapped context lines against the tree.
The `index` hashes of the two edited files are carried over from the submitted
patch and no longer describe the post-image.

diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
index 93ace3ff19606..d364dd17673a0 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
@@ -388,7 +388,7 @@ def execute_dml(
         database_id: str,
         queries: list[str],
         project_id: str,
-    ) -> None:
+    ) -> list[int]:
         """
         Execute an arbitrary DML query (INSERT, UPDATE, DELETE).
 
@@ -398,12 +398,31 @@ def execute_dml(
         :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
             database. If set to None or missing, the default project_id from the Google Cloud
             connection is used.
+        :return: list of numbers of affected rows by DML query
         """
-        self._get_client(project_id=project_id).instance(instance_id=instance_id).database(
-            database_id=database_id
-        ).run_in_transaction(lambda transaction: self._execute_sql_in_transaction(transaction, queries))
+        db = (
+            self._get_client(project_id=project_id)
+            .instance(instance_id=instance_id)
+            .database(database_id=database_id)
+        )
+
+        def _tx_runner(tx: Transaction) -> list[tuple[str, int]]:
+            return self._execute_sql_in_transaction(tx, queries)
+
+        result = db.run_in_transaction(_tx_runner)
+
+        result_rows_count_per_query = []
+        for i, (sql, rc) in enumerate(result, start=1):
+            # SELECT statements are skipped case-insensitively; only DML counts are reported.
+            if not sql.lstrip().upper().startswith("SELECT"):
+                preview = sql if len(sql) <= 300 else sql[:300] + "…"
+                self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview)
+                result_rows_count_per_query.append(rc)
+        return result_rows_count_per_query
 
     @staticmethod
-    def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]):
+    def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> list[tuple[str, int]]:
+        counts: list[tuple[str, int]] = []
         for sql in queries:
-            transaction.execute_update(sql)
+            counts.append((sql, transaction.execute_update(sql)))
+        return counts
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
index 51c4f61f20841..732b2e19b7c1b 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
@@ -280,8 +280,8 @@ def execute(self, context: Context):
             self.instance_id,
             self.database_id,
         )
-        self.log.info(queries)
-        hook.execute_dml(
+        self.log.info("Executing queries: %s", queries)
+        result_rows_count_per_query = hook.execute_dml(
             project_id=self.project_id,
             instance_id=self.instance_id,
             database_id=self.database_id,
@@ -293,6 +293,7 @@ def execute(self, context: Context):
             database_id=self.database_id,
             project_id=self.project_id or hook.project_id,
         )
+        return result_rows_count_per_query
 
     @staticmethod
     def sanitize_queries(queries: list[str]) -> None:
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
index 527a0cb0cce79..ad1f1906795b1 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
@@ -17,9 +17,10 @@
 # under the License.
 from __future__ import annotations
 
 from unittest import mock
 from unittest.mock import MagicMock, PropertyMock
 
+import pytest
 import sqlalchemy
 
 from airflow.providers.google.cloud.hooks.spanner import SpannerHook
@@ -405,14 +406,14 @@ def test_execute_dml(self, get_client, mock_project_id):
         res = self.spanner_hook_default_project_id.execute_dml(
             instance_id=SPANNER_INSTANCE,
             database_id=SPANNER_DATABASE,
-            queries="",
+            queries=[""],
             project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
         )
         get_client.assert_called_once_with(project_id="example-project")
         instance_method.assert_called_once_with(instance_id="instance")
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
-        assert res is None
+        assert res == []
 
     @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
     def test_execute_dml_overridden_project_id(self, get_client):
@@ -422,13 +423,87 @@ def test_execute_dml_overridden_project_id(self, get_client):
         database_method = instance_method.return_value.database
         run_in_transaction_method = database_method.return_value.run_in_transaction
         res = self.spanner_hook_default_project_id.execute_dml(
-            project_id="new-project", instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, queries=""
+            project_id="new-project", instance_id=SPANNER_INSTANCE, database_id=SPANNER_DATABASE, queries=[""]
         )
         get_client.assert_called_once_with(project_id="new-project")
         instance_method.assert_called_once_with(instance_id="instance")
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
-        assert res is None
+        assert res == []
+
+    def test_execute_sql_in_transaction_keeps_duplicate_queries(self):
+        transaction = MagicMock()
+        transaction.execute_update.side_effect = [1, 2]
+        sql = "DELETE FROM T WHERE id = 1"
+
+        res = SpannerHook._execute_sql_in_transaction(transaction, [sql, sql])
+
+        assert res == [(sql, 1), (sql, 2)]
+
+    @pytest.mark.parametrize(
+        "returned_items, expected_counts",
+        [
+            pytest.param(
+                [
+                    ("DELETE FROM T WHERE archived = TRUE", 5),
+                    ("SELECT * FROM T", 42),
+                    ("UPDATE U SET flag = FALSE WHERE x = 1", 3),
+                ],
+                [5, 3],
+            ),
+            pytest.param(
+                [
+                    ("DELETE FROM Logs WHERE created_at < '2024-01-01'", 7),
+                ],
+                [7],
+            ),
+            pytest.param(
+                [
+                    (
+                        "UPDATE Accounts SET active=false WHERE last_login < DATE_SUB(CURRENT_DATE(), INTERVAL 365 DAY)",
+                        11,
+                    ),
+                    ("DELETE FROM Sessions WHERE expires_at < CURRENT_TIMESTAMP()", 23),
+                ],
+                [11, 23],
+            ),
+            pytest.param(
+                [
+                    ("SELECT COUNT(*) FROM Users", 50000),
+                    ("SELECT * FROM BigTable", 123456),
+                ],
+                [],
+            ),
+            pytest.param(
+                [
+                    ("DELETE FROM T WHERE id = 1", 1),
+                    ("DELETE FROM T WHERE id = 1", 0),
+                    ("  select * from T", 9),
+                ],
+                [1, 0],
+            ),
+            pytest.param(
+                [],
+                [],
+            ),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+    def test_execute_dml_parametrized(self, get_client, returned_items, expected_counts):
+        instance_method = get_client.return_value.instance
+        database_method = instance_method.return_value.database
+        run_in_tx = database_method.return_value.run_in_transaction
+
+        run_in_tx.return_value = returned_items
+
+        res = self.spanner_hook_default_project_id.execute_dml(
+            instance_id=SPANNER_INSTANCE,
+            database_id=SPANNER_DATABASE,
+            queries=[sql for sql, _ in returned_items],
+            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+        )
+
+        assert res == expected_counts
 
     def test_get_uri(self):
         self.spanner_hook_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
@@ -682,13 +757,13 @@ def test_execute_dml_overridden_project_id(self, get_client):
             project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
             instance_id=SPANNER_INSTANCE,
             database_id=SPANNER_DATABASE,
-            queries="",
+            queries=[""],
         )
         get_client.assert_called_once_with(project_id="example-project")
         instance_method.assert_called_once_with(instance_id="instance")
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
-        assert res is None
+        assert res == []
 
     def test_get_uri(self):
         self.spanner_hook_no_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
diff --git a/providers/google/tests/unit/google/cloud/operators/test_spanner.py b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
index e9d800665bf2c..1784a0499aab0 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
@@ -250,7 +250,7 @@ def test_instance_delete_ex_if_param_missing(self, mock_hook, project_id, instan
 
     @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
     def test_instance_query(self, mock_hook):
-        mock_hook.return_value.execute_sql.return_value = None
+        mock_hook.return_value.execute_dml.return_value = [3]
         op = SpannerQueryDatabaseInstanceOperator(
             project_id=PROJECT_ID,
             instance_id=INSTANCE_ID,
@@ -258,8 +258,7 @@ def test_instance_query(self, mock_hook):
             query=INSERT_QUERY,
             task_id="id",
         )
-        context = mock.MagicMock()
-        result = op.execute(context=context)
+        result = op.execute(context=mock.MagicMock())
         mock_hook.assert_called_once_with(
             gcp_conn_id="google_cloud_default",
             impersonation_chain=None,
@@ -267,11 +266,11 @@ def test_instance_query(self, mock_hook):
         mock_hook.return_value.execute_dml.assert_called_once_with(
             project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY]
         )
-        assert result is None
+        assert result == [3]
 
     @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
     def test_instance_query_missing_project_id(self, mock_hook):
-        mock_hook.return_value.execute_sql.return_value = None
+        mock_hook.return_value.execute_dml.return_value = [3]
         op = SpannerQueryDatabaseInstanceOperator(
             instance_id=INSTANCE_ID, database_id=DB_ID, query=INSERT_QUERY, task_id="id"
         )
@@ -284,7 +283,7 @@ def test_instance_query_missing_project_id(self, mock_hook):
         mock_hook.return_value.execute_dml.assert_called_once_with(
             project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, queries=[INSERT_QUERY]
         )
-        assert result is None
+        assert result == [3]
 
     @pytest.mark.parametrize(
         "project_id, instance_id, database_id, query, exp_msg",