diff --git a/providers/amazon/pyproject.toml b/providers/amazon/pyproject.toml index aa2f29ca5584b..dcae5f2bdfdd2 100644 --- a/providers/amazon/pyproject.toml +++ b/providers/amazon/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.13.0", - "apache-airflow-providers-common-sql>=1.27.0", + "apache-airflow-providers-common-sql>=1.27.0", # use next version "apache-airflow-providers-http", # We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number # of candidates to consider. Make sure to configure boto3 version here as well as in all the tools below diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py index e6e3066540ab7..5c9621cf69848 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py @@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.utils.waiter_with_logging import wait from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage if TYPE_CHECKING: from botocore.paginate import PageIterator @@ -126,6 +127,11 @@ def run_query( response = self.get_conn().start_query_execution(**params) query_execution_id = response["QueryExecutionId"] self.log.info("Query execution id: %s", query_execution_id) + send_sql_hook_lineage( + context=self, + sql=query, + job_id=query_execution_id, + ) return query_execution_id def get_query_info(self, query_execution_id: str, use_cache: bool = False) -> dict: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py index c57d1d2a26763..c6c50e6a2820a 100644 
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -28,6 +28,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook from airflow.providers.amazon.aws.utils import trim_none_values +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage if TYPE_CHECKING: from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa: F401 @@ -154,6 +155,19 @@ def execute_query( statement_id = resp["Id"] + send_sql_hook_lineage( + context=self, + sql="; ".join(sql) if isinstance(sql, list) else sql, + sql_parameters=parameters or None, + job_id=statement_id, + default_db=database, + extra={ + "cluster_identifier": cluster_identifier, + "workgroup_name": workgroup_name, + "session_id": session_id or resp.get("SessionId"), + }, + ) + if wait_for_completion: self.wait_for_results(statement_id, poll_interval=poll_interval) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena.py index ebc141c026d05..e743e831873e5 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_athena.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_athena.py @@ -124,6 +124,21 @@ def test_hook_run_query_no_log_query(self, mock_conn, log): ) assert athena_hook_no_log_query.log.info.call_count == 1 + @mock.patch("airflow.providers.amazon.aws.hooks.athena.send_sql_hook_lineage") + @mock.patch.object(AthenaHook, "get_conn") + def test_run_query_hook_lineage(self, mock_conn, mock_send_lineage): + mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION + self.athena.run_query( + query=MOCK_DATA["query"], + query_context=mock_query_context, + result_configuration=mock_result_configuration, + ) + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.athena + assert 
call_kw["sql"] == MOCK_DATA["query"] + assert call_kw["job_id"] == MOCK_DATA["query_execution_id"] + @mock.patch.object(AthenaHook, "get_conn") def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py index 1c30a2ecfc1fb..5940cf2009eb9 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_data.py @@ -276,6 +276,30 @@ def test_execute_reuse_session(self, mock_conn): Id=STATEMENT_ID, ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.send_sql_hook_lineage") + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") + def test_execute_query_hook_lineage(self, mock_conn, mock_send_lineage): + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + hook = RedshiftDataHook() + hook.execute_query( + database=DATABASE, + cluster_identifier="cluster_identifier", + sql=SQL, + wait_for_completion=False, + ) + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == SQL + assert call_kw["sql_parameters"] is None + assert call_kw["job_id"] == STATEMENT_ID + assert call_kw["default_db"] == DATABASE + assert call_kw["extra"] == { + "cluster_identifier": "cluster_identifier", + "workgroup_name": None, + "session_id": None, + } + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_batch_execute(self, mock_conn): mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py index 4c78ac8828d55..f1cf985517058 
100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py @@ -284,3 +284,81 @@ def test_get_openlineage_redshift_authority_part( assert f"{expected_identity}:{LOGIN_PORT}" == self.db_hook._get_openlineage_redshift_authority_part( self.connection ) + + +class TestRedshiftSQLHookLineage: + def setup_method(self): + self.cur = mock.MagicMock(rowcount=0) + self.conn = mock.MagicMock() + self.conn.cursor.return_value = self.cur + conn = self.conn + + class UnitTestRedshiftSQLHook(RedshiftSQLHook): + conn_name_attr = "test_conn_id" + + def get_conn(self): + return conn + + self.db_hook = UnitTestRedshiftSQLHook() + self.db_hook.get_connection = mock.Mock( + return_value=Connection( + conn_type="redshift", + login=LOGIN_USER, + password=LOGIN_PASSWORD, + host=LOGIN_HOST, + port=LOGIN_PORT, + schema=LOGIN_SCHEMA, + ) + ) + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def 
test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters diff --git a/providers/apache/drill/pyproject.toml b/providers/apache/drill/pyproject.toml index 2f017350eb90c..808319a4b151b 100644 --- a/providers/apache/drill/pyproject.toml +++ b/providers/apache/drill/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "apache-airflow-providers-common-compat>=1.8.0", # Workaround until we get https://github.com/JohnOmernik/sqlalchemy-drill/issues/94 fixed. 
"sqlalchemy-drill>=1.1.0,!=1.1.6,!=1.1.7", diff --git a/providers/apache/drill/tests/unit/apache/drill/hooks/test_drill.py b/providers/apache/drill/tests/unit/apache/drill/hooks/test_drill.py index a1d23ed8597fe..8af4f67dec510 100644 --- a/providers/apache/drill/tests/unit/apache/drill/hooks/test_drill.py +++ b/providers/apache/drill/tests/unit/apache/drill/hooks/test_drill.py @@ -191,3 +191,40 @@ def test_insert_rows_raises_not_implemented(self): db_hook = self.db_hook() with pytest.raises(NotImplementedError, match=r"There is no INSERT statement in Drill."): db_hook.insert_rows(table="my_table", rows=[("a",)]) + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.db_hook().run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is not None + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + self.db_hook().get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is not None + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook().get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + 
assert call_kw["context"] is not None + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters diff --git a/providers/apache/druid/pyproject.toml b/providers/apache/druid/pyproject.toml index c1368ef735c32..3bdcfb7364fde 100644 --- a/providers/apache/druid/pyproject.toml +++ b/providers/apache/druid/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "apache-airflow-providers-common-compat>=1.10.1", "pydruid>=0.6.6", ] diff --git a/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py b/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py index a152be0ffc082..c2e82b84e5570 100644 --- a/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py +++ b/providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py @@ -509,3 +509,40 @@ def test_get_df_polars(self): assert column == df.columns[0] assert result_sets[0][0] == df.row(0)[0] assert result_sets[1][0] == df.row(1)[0] + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.db_hook().run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is not None + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + self.db_hook().get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = 
mock_send_lineage.call_args.kwargs + assert call_kw["context"] is not None + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook().get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is not None + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters diff --git a/providers/apache/hive/pyproject.toml b/providers/apache/hive/pyproject.toml index 8b37a262a9030..8d3103b926886 100644 --- a/providers/apache/hive/pyproject.toml +++ b/providers/apache/hive/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "hmsclient>=0.1.0", 'pandas>=2.1.2; python_version <"3.13"', 'pandas>=2.2.3; python_version >="3.13"', diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py index 7b01f79b057fd..a9758f61bf7e7 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py @@ -39,6 +39,7 @@ BaseHook, conf, ) +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.security import utils from airflow.utils.helpers import as_flattened_list @@ -332,6 +333,8 @@ def run_cli( if sub_process.returncode: raise 
AirflowException(stdout) + send_sql_hook_lineage(context=self, sql=hql) + return stdout def test_hql(self, hql: str) -> None: @@ -916,6 +919,7 @@ def _get_results( for statement in sql: cur.execute(statement) + send_sql_hook_lineage(context=self, sql=statement, cur=cur, default_schema=schema) # we only get results of statements that returns lowered_statement = statement.lower().strip() if lowered_statement.startswith(("select", "with", "show")) or ( diff --git a/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py b/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py index a6ef4caa45648..f94c05c3e27a3 100644 --- a/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py +++ b/providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py @@ -117,6 +117,33 @@ def test_run_cli(self, mock_popen, mock_temp_dir): close_fds=True, ) + @mock.patch("airflow.providers.apache.hive.hooks.hive.send_sql_hook_lineage") + @mock.patch("tempfile.tempdir", "/tmp/") + @mock.patch("tempfile._RandomNameSequence.__next__") + @mock.patch("subprocess.Popen") + def test_run_cli_hook_lineage(self, mock_popen, mock_temp_dir, mock_send_lineage): + mock_subprocess = MockSubProcess() + mock_popen.return_value = mock_subprocess + mock_temp_dir.return_value = "test_run_cli" + hql = "SHOW DATABASES" + envron_name = "AIRFLOW_CTX_LOGICAL_DATE" if AIRFLOW_V_3_0_PLUS else "AIRFLOW_CTX_EXECUTION_DATE" + with mock.patch.dict( + "os.environ", + { + "AIRFLOW_CTX_DAG_ID": "test_dag_id", + "AIRFLOW_CTX_TASK_ID": "test_task_id", + envron_name: "2015-01-01T00:00:00+00:00", + "AIRFLOW_CTX_TRY_NUMBER": "1", + "AIRFLOW_CTX_DAG_RUN_ID": "55", + "AIRFLOW_CTX_DAG_OWNER": "airflow", + "AIRFLOW_CTX_DAG_EMAIL": "test@airflow.com", + }, + ): + hook = MockHiveCliHook() + hook.run_cli(hql, schema="some_schema") + + mock_send_lineage.assert_called_once_with(context=hook, sql=f"USE some_schema;\n{hql}\n") + def test_hive_cli_hook_invalid_schema(self): hook = InvalidHiveCliHook() with 
pytest.raises(RuntimeError) as error: @@ -767,6 +794,20 @@ def test_get_results_data(self): assert results["data"] == [(1, 1), (2, 2)] + @mock.patch("airflow.providers.apache.hive.hooks.hive.send_sql_hook_lineage") + def test_get_results_sends_hook_lineage(self, mock_send_lineage): + hook = MockHiveServer2Hook() + + query = f"SELECT * FROM {self.table}" + hook.get_results(query, schema=self.database) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == query + assert call_kw["cur"] is hook.mock_cursor + assert call_kw["default_schema"] == self.database + def test_to_csv(self): hook = MockHiveServer2Hook() hook._get_results = mock.MagicMock( diff --git a/providers/apache/impala/pyproject.toml b/providers/apache/impala/pyproject.toml index 80a5dc94a9b8a..e88a8902f3d02 100644 --- a/providers/apache/impala/pyproject.toml +++ b/providers/apache/impala/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "impyla>=0.22.0,<1.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "apache-airflow>=2.11.0", ] diff --git a/providers/apache/impala/tests/unit/apache/impala/hooks/test_impala.py b/providers/apache/impala/tests/unit/apache/impala/hooks/test_impala.py index 3283e3092349b..8feb5171454b3 100644 --- a/providers/apache/impala/tests/unit/apache/impala/hooks/test_impala.py +++ b/providers/apache/impala/tests/unit/apache/impala/hooks/test_impala.py @@ -31,6 +31,7 @@ def impala_hook_fixture() -> ImpalaHook: mock_get_conn.return_value.cursor = MagicMock() mock_get_conn.return_value.cursor.return_value.rowcount = 2 hook.get_conn = mock_get_conn # type:ignore[method-assign] + hook.get_connection = MagicMock(return_value=Connection(conn_type="impala")) # type:ignore[method-assign] return hook @@ -136,3 +137,56 @@ def 
test_get_df_polars(impala_hook_fixture): assert column == df.columns[0] assert result_sets[0][0] == df.row(0)[0] assert result_sets[1][0] == df.row(1)[0] + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +def test_run_hook_lineage(mock_send_lineage, impala_hook_fixture): + sql = "SELECT 1" + impala_hook_fixture.run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is impala_hook_fixture + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is impala_hook_fixture.get_conn.return_value.cursor.return_value + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +def test_insert_rows_hook_lineage(mock_send_lineage, impala_hook_fixture): + table = "table" + rows = [("hello",), ("world",)] + impala_hook_fixture.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is impala_hook_fixture + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") +def test_get_df_hook_lineage(mock_get_pandas_df, mock_send_lineage, impala_hook_fixture): + sql = "SELECT 1" + impala_hook_fixture.get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is impala_hook_fixture + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") +def test_get_df_by_chunks_hook_lineage(mock_get_pandas_df_by_chunks, mock_send_lineage, impala_hook_fixture): + sql = "SELECT 1" + parameters = ("x",) + 
impala_hook_fixture.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is impala_hook_fixture + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters diff --git a/providers/apache/pinot/pyproject.toml b/providers/apache/pinot/pyproject.toml index 7ee153fe22f15..7f7326157f5b6 100644 --- a/providers/apache/pinot/pyproject.toml +++ b/providers/apache/pinot/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.10.1", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "pinotdb>=5.1.0", ] diff --git a/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py b/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py index 3093844fb2767..11eaf189a8cb4 100644 --- a/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py +++ b/providers/apache/pinot/src/airflow/providers/apache/pinot/hooks/pinot.py @@ -26,6 +26,7 @@ from pinotdb import connect from airflow.providers.common.compat.sdk import AirflowException, BaseHook +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: @@ -335,6 +336,7 @@ def get_records( """ with self.get_conn() as cur: cur.execute(sql) + send_sql_hook_lineage(context=self, sql=sql, sql_parameters=parameters, cur=cur) return cur.fetchall() def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any: @@ -347,6 +349,7 @@ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, An """ with self.get_conn() as cur: cur.execute(sql) + send_sql_hook_lineage(context=self, sql=sql, sql_parameters=parameters, cur=cur) return cur.fetchone() def 
set_autocommit(self, conn: Connection, autocommit: Any) -> Any: diff --git a/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py b/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py index 2af53cc8aafd2..14614f1f84c87 100644 --- a/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py +++ b/providers/apache/pinot/tests/unit/apache/pinot/hooks/test_pinot.py @@ -260,12 +260,81 @@ def test_get_records(self): self.cur.fetchall.return_value = result_sets assert result_sets == self.db_hook().get_records(statement) + @mock.patch("airflow.providers.apache.pinot.hooks.pinot.send_sql_hook_lineage") + def test_get_records_hook_lineage(self, mock_send_lineage): + statement = "SQL" + hook = self.db_hook() + hook.get_records(statement) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + def test_get_first(self): statement = "SQL" result_sets = [("row1",), ("row2",)] self.cur.fetchone.return_value = result_sets[0] assert result_sets[0] == self.db_hook().get_first(statement) + @mock.patch("airflow.providers.apache.pinot.hooks.pinot.send_sql_hook_lineage") + def test_get_first_hook_lineage(self, mock_send_lineage): + statement = "SQL" + hook = self.db_hook() + hook.get_first(statement) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + hook = self.db_hook() + self.cur.fetchall.return_value = [] + + hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + 
assert call_kw["context"] is hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + statement = "SELECT 1" + parameters = ("x",) + hook = self.db_hook() + hook.get_df(statement, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + statement = "SELECT 1" + parameters = ("x",) + hook = self.db_hook() + list(hook.get_df_by_chunks(statement, parameters=parameters, chunksize=1)) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] == parameters + def test_get_df_pandas(self): statement = "SQL" column = "col" diff --git a/providers/common/sql/provider.yaml b/providers/common/sql/provider.yaml index b491a539d1945..a56f441b2debb 100644 --- a/providers/common/sql/provider.yaml +++ b/providers/common/sql/provider.yaml @@ -110,6 +110,7 @@ hooks: - integration-name: Common SQL python-modules: - airflow.providers.common.sql.hooks.handlers + - airflow.providers.common.sql.hooks.lineage - airflow.providers.common.sql.hooks.sql triggers: diff --git a/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py 
b/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py index 29ac848d4fc90..e2f8332bd95af 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py +++ b/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py @@ -55,6 +55,7 @@ def get_provider_info(): "integration-name": "Common SQL", "python-modules": [ "airflow.providers.common.sql.hooks.handlers", + "airflow.providers.common.sql.hooks.lineage", "airflow.providers.common.sql.hooks.sql", ], } diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py index aabb2571e5d1a..e67ddc8fb0f26 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py @@ -51,6 +51,15 @@ def return_single_query_results( return isinstance(sql, str) and return_last +def get_row_count(cursor) -> int | None: + # According to PEP 249, this is -1 or None when query result is not applicable. 
+ # We standardize so it's either None (when not applicable) or positive integer / 0 (when applicable) + row_count = getattr(cursor, "rowcount", None) + if isinstance(row_count, int) and row_count >= 0: + return row_count + return None + + def fetch_all_handler(cursor) -> list[tuple] | None: """Return results for DbApiHook.run().""" if not hasattr(cursor, "description"): diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.pyi b/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.pyi index daa577822b60c..b969f64c68e10 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.pyi +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.pyi @@ -37,5 +37,6 @@ from collections.abc import Iterable def return_single_query_results( sql: str | Iterable[str], return_last: bool, split_statements: bool | None ): ... +def get_row_count(cursor) -> int | None: ... def fetch_all_handler(cursor) -> list[tuple] | None: ... def fetch_one_handler(cursor) -> tuple | None: ... diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/lineage.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/lineage.py new file mode 100644 index 0000000000000..b8662486070c9 --- /dev/null +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/lineage.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from enum import Enum +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector +from airflow.providers.common.sql.hooks.handlers import get_row_count + +if TYPE_CHECKING: + from airflow.providers.common.compat.lineage.hook import LineageContext + + +log = logging.getLogger(__name__) + + +class SqlJobHookLineageExtra(str, Enum): + """ + Keys for the SQL job hook-level lineage extra entry. + + Reported via ``get_hook_lineage_collector().add_extra()``. ``KEY`` is the + extra entry key; ``VALUE__*`` are the keys inside the value dict (one entry + per SQL statement so job_id, SQL text, row count, default_db, etc. stay stitched). + """ + + KEY = "sql_job" + VALUE__SQL_STATEMENT = "sql" + VALUE__SQL_STATEMENT_PARAMETERS = "sql_parameters" + VALUE__JOB_ID = "job_id" + VALUE__ROW_COUNT = "row_count" + VALUE__DEFAULT_DB = "default_db" + VALUE__DEFAULT_SCHEMA = "default_schema" + VALUE__EXTRA = "extra" + + @classmethod + def value_keys(cls) -> tuple[SqlJobHookLineageExtra, ...]: + """Value-dict keys only (KEY excluded). 
Use when iterating or validating the value dict.""" + return ( + cls.VALUE__SQL_STATEMENT, + cls.VALUE__SQL_STATEMENT_PARAMETERS, + cls.VALUE__JOB_ID, + cls.VALUE__ROW_COUNT, + cls.VALUE__DEFAULT_DB, + cls.VALUE__DEFAULT_SCHEMA, + cls.VALUE__EXTRA, + ) + + +def send_sql_hook_lineage( + *, + context: LineageContext, + sql: str | list[str], + sql_parameters: Any = None, + cur: Any = None, + job_id: str | None = None, + row_count: int | None = None, + default_db: str | None = None, + default_schema: str | None = None, + extra: dict[str, Any] | None = None, +) -> None: + """ + Report a single SQL execution to the hook lineage collector. + + Call this after running a SQL statement so that hook lineage collectors can associate the execution + with the task. Each call produces one extra entry in the collector; when executing multiple statements + in one hook run, one should call this function separately for each sql job, so that job_id, SQL text, + row count, and other fields stay tied together per statement. + + Usable from any hook: pass the hook instance as ``context``. Not limited to + ``DbApiHook`` subclasses. + + :param context: Lineage context, typically the hook instance. Must be valid for + ``get_hook_lineage_collector().add_extra(context=..., ...)``. + :param sql: The SQL statement that was executed (or a representative string). + :param sql_parameters: Optional parameters bound to the statement. + :param cur: Optional DB-API cursor after execution. If given, job_id is taken + from ``query_id`` or ``sfqid`` when not provided explicitly, and row_count + from ``cur.rowcount`` when applicable (PEP 249). + :param job_id: Explicit job ID; used instead of cursor-derived value when set. + :param row_count: Explicit row count; used instead of cursor-derived value when set. + :param default_db: Default database/catalog name for this execution context. + :param default_schema: Default schema name for this execution context. 
+ :param extra: Optional additional key-value data to attach to this lineage entry. + """ + try: + sql = "; ".join(sql) if isinstance(sql, list) else sql + value: dict[str, Any] = {SqlJobHookLineageExtra.VALUE__SQL_STATEMENT.value: sql} + if sql_parameters: + value[SqlJobHookLineageExtra.VALUE__SQL_STATEMENT_PARAMETERS.value] = sql_parameters + + # Get SQL job_id: either explicitly or from cursor + if job_id is not None: + value[SqlJobHookLineageExtra.VALUE__JOB_ID.value] = job_id + elif cur is not None: + for attr in ("query_id", "sfqid"): + if (cursor_job_id := getattr(cur, attr, None)) is not None: + value[SqlJobHookLineageExtra.VALUE__JOB_ID.value] = cursor_job_id + break + + # Get row count: either explicitly or from cursor + if row_count is None and cur is not None: + row_count = get_row_count(cur) + if row_count is not None and row_count >= 0: + value[SqlJobHookLineageExtra.VALUE__ROW_COUNT.value] = row_count + + if default_db is not None: + value[SqlJobHookLineageExtra.VALUE__DEFAULT_DB.value] = default_db + if default_schema is not None: + value[SqlJobHookLineageExtra.VALUE__DEFAULT_SCHEMA.value] = default_schema + if extra: + value[SqlJobHookLineageExtra.VALUE__EXTRA.value] = extra + + get_hook_lineage_collector().add_extra( + context=context, + key=SqlJobHookLineageExtra.KEY.value, + value=value, + ) + except Exception as e: + log.warning("Sending SQL hook level lineage failed: %s", f"{e.__class__.__name__}: {str(e)}") + log.debug("Exception details:", exc_info=True) diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/lineage.pyi b/providers/common/sql/src/airflow/providers/common/sql/hooks/lineage.pyi new file mode 100644 index 0000000000000..876364231b74b --- /dev/null +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/lineage.pyi @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This is automatically generated stub for the `common.sql` provider +# +# This file is generated automatically by the `update-common-sql-api stubs` prek hook +# and the .pyi file represents part of the "public" API that the +# `common.sql` provider exposes to other providers. +# +# Any, potentially breaking change in the stubs will require deliberate manual action from the contributor +# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking +# if only public API of the common.sql provider is used by all the other providers. +# +# You can read more in the README_API.md file +# +""" +Definition of the public interface for +airflow.providers.common.sql.src.airflow.providers.common.sql.hooks.lineage. +""" + +from enum import Enum +from typing import Any + +from airflow.providers.common.compat.lineage.hook import LineageContext + +class SqlJobHookLineageExtra(str, Enum): + KEY = "sql_job" + VALUE__SQL_STATEMENT = "sql" + VALUE__SQL_STATEMENT_PARAMETERS = "sql_parameters" + VALUE__JOB_ID = "job_id" + VALUE__ROW_COUNT = "row_count" + VALUE__DEFAULT_DB = "default_db" + VALUE__DEFAULT_SCHEMA = "default_schema" + VALUE__EXTRA = "extra" + @classmethod + def value_keys(cls) -> tuple[SqlJobHookLineageExtra, ...]: ... 
+ +def send_sql_hook_lineage( + *, + context: LineageContext, + sql: str | list[str], + sql_parameters: Any = None, + cur: Any = None, + job_id: str | None = None, + row_count: int | None = None, + default_db: str | None = None, + default_schema: str | None = None, + extra: dict[str, Any] | None = None, +) -> None: ... diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py index 3b439947596c0..db62736c78a8e 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py @@ -52,6 +52,7 @@ ) from airflow.providers.common.sql.dialects.dialect import Dialect from airflow.providers.common.sql.hooks import handlers +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage if TYPE_CHECKING: from pandas import DataFrame as PandasDataFrame @@ -469,9 +470,18 @@ def get_df( :param kwargs: (optional) passed into `pandas.io.sql.read_sql` or `polars.read_database` method """ if df_type == "pandas": - return self._get_pandas_df(sql, parameters, **kwargs) - if df_type == "polars": - return self._get_polars_df(sql, parameters, **kwargs) + result: PandasDataFrame | PolarsDataFrame = self._get_pandas_df(sql, parameters, **kwargs) + elif df_type == "polars": + result = self._get_polars_df(sql, parameters, **kwargs) + else: + raise ValueError(f"Unsupported df_type: {df_type}") + + send_sql_hook_lineage( + context=self, + sql=sql, + sql_parameters=parameters, + ) + return result def _get_pandas_df( self, @@ -568,9 +578,20 @@ def get_df_by_chunks( :param kwargs: (optional) passed into `pandas.io.sql.read_sql` or `polars.read_database` method """ if df_type == "pandas": - return self._get_pandas_df_by_chunks(sql, parameters, chunksize=chunksize, **kwargs) - if df_type == "polars": - return self._get_polars_df_by_chunks(sql, parameters, chunksize=chunksize, **kwargs) + result: 
Generator[PandasDataFrame | PolarsDataFrame, None, None] = self._get_pandas_df_by_chunks( + sql, parameters, chunksize=chunksize, **kwargs + ) + elif df_type == "polars": + result = self._get_polars_df_by_chunks(sql, parameters, chunksize=chunksize, **kwargs) + else: + raise ValueError(f"Unsupported df_type: {df_type}") + + send_sql_hook_lineage( + context=self, + sql=sql, + sql_parameters=parameters, + ) + return result def _get_pandas_df_by_chunks( self, @@ -836,9 +857,15 @@ def _run_command(self, cur, sql_statement, parameters): else: cur.execute(sql_statement) - # According to PEP 249, this is -1 when query result is not applicable. - if cur.rowcount >= 0: - self.log.info("Rows affected: %s", cur.rowcount) + send_sql_hook_lineage( + context=self, + sql=sql_statement, + sql_parameters=parameters, + cur=cur, + ) + + if (row_count := handlers.get_row_count(cur)) is not None: + self.log.info("Rows affected: %s", row_count) def set_autocommit(self, conn, autocommit): """Set the autocommit flag on the connection.""" @@ -928,6 +955,7 @@ def insert_rows( before executing the query. """ nb_rows = 0 + sql = None # not generated unless we actually process at least one chunk with self._create_autocommit_connection(autocommit) as conn: conn.commit() with closing(conn.cursor()) as cur: @@ -979,6 +1007,11 @@ def insert_rows( self.log.info("Loaded %s rows into %s so far", i, table) nb_rows += 1 conn.commit() + + if sql: + # We only send lineage once, not for each value collection, to save memory. + send_sql_hook_lineage(context=self, sql=sql, row_count=nb_rows) + self.log.info("Done loading. 
Loaded a total of %s rows into %s", nb_rows, table) @classmethod diff --git a/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py b/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py index 8127740b17025..8c187c43028d2 100644 --- a/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py +++ b/providers/common/sql/tests/unit/common/sql/hooks/test_dbapi.py @@ -637,3 +637,58 @@ def test_insert_rows_with_executemany_correctly_logs_amount_of_commited_rows(sel assert any(f"Loaded 3 rows into {table} so far" in message for message in caplog.messages) assert any(f"Loaded 6 rows into {table} so far" in message for message in caplog.messages) assert any(f"Loaded 9 rows into {table} so far" in message for message in caplog.messages) + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_calls_send_sql_hook_lineage(self, mock_send_lineage): + statement = "SQL" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_calls_send_sql_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_calls_send_sql_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + params 
= ("x",) + + self.db_hook.get_df(sql, parameters=params) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == params + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_calls_send_sql_hook_lineage(self, mock_get_pandas_df_chunks, mock_send_lineage): + sql = "SELECT 1" + params = ("x",) + + self.db_hook.get_df_by_chunks(sql, parameters=params, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == params diff --git a/providers/common/sql/tests/unit/common/sql/hooks/test_lineage.py b/providers/common/sql/tests/unit/common/sql/hooks/test_lineage.py new file mode 100644 index 0000000000000..a2fb0ba184926 --- /dev/null +++ b/providers/common/sql/tests/unit/common/sql/hooks/test_lineage.py @@ -0,0 +1,201 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import logging +from unittest import mock + +from airflow.providers.common.sql.hooks.lineage import ( + SqlJobHookLineageExtra, + send_sql_hook_lineage, +) + + +class TestSqlJobHookLineageExtra: + def test_key_value(self): + assert SqlJobHookLineageExtra.KEY.value == "sql_job" + + def test_value_keys_includes_all_value_members(self): + keys = SqlJobHookLineageExtra.value_keys() + assert len(keys) == 7 + assert keys == ( + SqlJobHookLineageExtra.VALUE__SQL_STATEMENT, + SqlJobHookLineageExtra.VALUE__SQL_STATEMENT_PARAMETERS, + SqlJobHookLineageExtra.VALUE__JOB_ID, + SqlJobHookLineageExtra.VALUE__ROW_COUNT, + SqlJobHookLineageExtra.VALUE__DEFAULT_DB, + SqlJobHookLineageExtra.VALUE__DEFAULT_SCHEMA, + SqlJobHookLineageExtra.VALUE__EXTRA, + ) + + +class TestSendSqlHookLineage: + """Test send_sql_hook_lineage calls get_hook_lineage_collector().add_extra with correct params.""" + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_add_extra_called_with_minimal_args(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + + send_sql_hook_lineage(context=mock_context, sql="SELECT 1") + + mock_collector.add_extra.assert_called_once() + call_kw = mock_collector.add_extra.call_args.kwargs + assert len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + assert len(call_kw["value"]) == 1 + assert call_kw["value"] == {"sql": "SELECT 1"} + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_add_extra_called_with_sql_list_joined(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + + send_sql_hook_lineage(context=mock_context, sql=["SELECT 1", "SELECT 2"]) + + mock_collector.add_extra.assert_called_once() + call_kw = 
mock_collector.add_extra.call_args.kwargs + assert len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + assert len(call_kw["value"]) == 1 + assert call_kw["value"] == {"sql": "SELECT 1; SELECT 2"} + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_add_extra_called_with_all_args_no_cursor(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + + send_sql_hook_lineage( + context=mock_context, + sql="INSERT INTO t VALUES (%s)", + sql_parameters=("x",), + job_id="job-123", + row_count=42, + default_db="mydb", + default_schema="myschema", + extra={"custom": "data"}, + ) + + mock_collector.add_extra.assert_called_once() + call_kw = mock_collector.add_extra.call_args.kwargs + assert len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + value = call_kw["value"] + assert len(value) == 7 + assert value == { + "sql": "INSERT INTO t VALUES (%s)", + "sql_parameters": ("x",), + "job_id": "job-123", + "row_count": 42, + "default_db": "mydb", + "default_schema": "myschema", + "extra": {"custom": "data"}, + } + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_add_extra_job_id_from_cursor(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + mock_cur = mock.MagicMock() + mock_cur.query_id = "cursor-query-id" + + send_sql_hook_lineage(context=mock_context, sql="SELECT 1", cur=mock_cur) + + call_kw = mock_collector.add_extra.call_args.kwargs + assert len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + value = call_kw["value"] + assert len(value) == 2 + assert value == {"sql": "SELECT 1", "job_id": "cursor-query-id"} + + 
@mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_add_extra_row_count_from_cursor(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + mock_cur = mock.MagicMock() + mock_cur.rowcount = 10 + mock_cur.query_id = "123" + + send_sql_hook_lineage(context=mock_context, sql="SELECT 1", cur=mock_cur) + + call_kw = mock_collector.add_extra.call_args.kwargs + assert len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + value = call_kw["value"] + assert len(value) == 3 + assert value == {"sql": "SELECT 1", "row_count": 10, "job_id": "123"} + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_explicit_job_id_overrides_cursor(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + mock_cur = mock.MagicMock() + mock_cur.query_id = "cursor-id" + + send_sql_hook_lineage(context=mock_context, sql="SELECT 1", cur=mock_cur, job_id="explicit-id") + + call_kw = mock_collector.add_extra.call_args.kwargs + assert len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + value = call_kw["value"] + assert len(value) == 2 + assert value == {"sql": "SELECT 1", "job_id": "explicit-id"} + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_explicit_row_count_overrides_cursor(self, mock_get_collector): + mock_collector = mock.MagicMock() + mock_get_collector.return_value = mock_collector + mock_context = mock.MagicMock() + mock_cur = mock.MagicMock() + mock_cur.rowcount = 99 + del mock_cur.query_id + del mock_cur.sfqid + + send_sql_hook_lineage(context=mock_context, sql="SELECT 1", cur=mock_cur, row_count=1) + + call_kw = mock_collector.add_extra.call_args.kwargs + assert 
len(call_kw) == 3 + assert call_kw["context"] is mock_context + assert call_kw["key"] == "sql_job" + value = call_kw["value"] + assert len(value) == 2 + assert value == {"sql": "SELECT 1", "row_count": 1} + + @mock.patch("airflow.providers.common.sql.hooks.lineage.get_hook_lineage_collector") + def test_exception_is_swallowed_and_logged(self, mock_get_collector, caplog): + mock_collector = mock.MagicMock() + mock_collector.add_extra.side_effect = RuntimeError("collector broke") + mock_get_collector.return_value = mock_collector + + with caplog.at_level(logging.WARNING, logger="airflow.providers.common.sql.hooks.lineage"): + send_sql_hook_lineage(context=mock.MagicMock(), sql="SELECT 1") + + assert "Sending SQL hook level lineage failed" in caplog.text + assert "RuntimeError: collector broke" in caplog.text diff --git a/providers/databricks/pyproject.toml b/providers/databricks/pyproject.toml index 241ce0ba14658..5243119ef7c14 100644 --- a/providers/databricks/pyproject.toml +++ b/providers/databricks/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.13.0", - "apache-airflow-providers-common-sql>=1.27.0", + "apache-airflow-providers-common-sql>=1.27.0", # use next version "requests>=2.32.0,<3", "databricks-sql-connector>=4.0.0", "aiohttp>=3.9.2, <4", diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py index 84c28c38a4377..e7b273973f369 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py @@ -35,6 +35,7 @@ from requests import exceptions as requests_exceptions from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from 
airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook GET_CLUSTER_ENDPOINT = ("GET", "2.1/clusters/get") @@ -800,7 +801,17 @@ def post_sql_statement(self, json: dict[str, Any]) -> str: :return: The statement_id as a string. """ response = self._do_api_call(("POST", f"{SQL_STATEMENTS_ENDPOINT}"), json) - return response["statement_id"] + statement_id = response["statement_id"] + if (sql_statement := json.get("statement")) is not None: + send_sql_hook_lineage( + context=self, + sql=sql_statement, + sql_parameters=json.get("parameters"), + job_id=statement_id, + default_db=json.get("catalog"), + default_schema=json.get("schema"), + ) + return statement_id def get_sql_statement_state(self, statement_id: str) -> SQLStatementState: """ diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py index 3971a096fab39..431dea7b95619 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py @@ -1226,6 +1226,31 @@ def test_post_sql_statement(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks.send_sql_hook_lineage") + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_post_sql_statement_hook_lineage(self, mock_requests, mock_send_lineage): + mock_requests.post.return_value.json.return_value = { + "statement_id": "01f00ed2-04e2-15bd-a944-a8ae011dac69" + } + json_payload = { + "statement": "select * from test.test;", + "warehouse_id": WAREHOUSE_ID, + "catalog": "some_catalog", + "schema": "some_schema", + "parameters": {"a": 1}, + "wait_timeout": "0s", + } + self.hook.post_sql_statement(json_payload) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.hook + assert call_kw["sql"] == "select * from 
test.test;" + assert call_kw["job_id"] == "01f00ed2-04e2-15bd-a944-a8ae011dac69" + assert call_kw["sql_parameters"] == {"a": 1} + assert call_kw["default_db"] == "some_catalog" + assert call_kw["default_schema"] == "some_schema" + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_get_sql_statement_state(self, mock_requests): mock_requests.codes.ok = 200 diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py index a25fc835d9f4f..98ea7e1d34779 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py @@ -386,6 +386,79 @@ def test_no_query(databricks_hook, empty_statement): databricks_hook.run(sql=empty_statement) +@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +def test_run_hook_lineage(mock_send_lineage, mock_get_conn, mock_get_requests): + """Ensure run() triggers send_sql_hook_lineage via base DbApiHook._run_command.""" + conn = mock.MagicMock() + cur = mock.MagicMock( + rowcount=1, + description=[("id",), ("value",)], + ) + cur.fetchall.return_value = [Row(id=1, value=2)] + conn.cursor.return_value = cur + mock_get_conn.return_value = conn + + sql = "SELECT 1" + hook = DatabricksSqlHook(sql_endpoint_name="Test") + hook.run(sql=sql, handler=fetch_all_handler) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is cur + + +@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +def test_insert_rows_hook_lineage(mock_send_lineage, mock_get_conn): + conn = mock.MagicMock() + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + mock_get_conn.return_value = conn + + table = "table" + rows = [("hello",), 
("world",)] + hook = DatabricksSqlHook(sql_endpoint_name="Test") + hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + +@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") +def test_get_df_hook_lineage(mock_get_pandas_df, mock_send_lineage, mock_get_conn): + sql = "SELECT 1" + parameters = ("x",) + hook = DatabricksSqlHook(sql_endpoint_name="Test") + hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + +@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") +def test_get_df_by_chunks_hook_lineage(mock_get_pandas_df_by_chunks, mock_send_lineage, mock_get_conn): + sql = "SELECT 1" + parameters = ("x",) + hook = DatabricksSqlHook(sql_endpoint_name="Test") + hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @pytest.mark.parametrize( ("row_objects", "fields_names"), [ diff --git a/providers/elasticsearch/pyproject.toml b/providers/elasticsearch/pyproject.toml index 7404f93ee4770..9894c9096a7d1 100644 --- a/providers/elasticsearch/pyproject.toml +++ b/providers/elasticsearch/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - 
"apache-airflow-providers-common-sql>=1.27.0", + "apache-airflow-providers-common-sql>=1.27.0", # use next version "elasticsearch>=8.10,<9", ] diff --git a/providers/elasticsearch/tests/unit/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/unit/elasticsearch/hooks/test_elasticsearch.py index 095ed14200f8f..78c141a9baa05 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/hooks/test_elasticsearch.py @@ -190,6 +190,43 @@ def test_run(self): self.spy_agency.assert_spy_called(self.cur.close) self.spy_agency.assert_spy_called(self.cur.execute) + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT * FROM hollywood.actors" + self.db_hook.run(statement, handler=fetch_all_handler) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + statement = "SELECT 1" + self.db_hook.get_df(statement, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + 
self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + @mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch") def test_execute_sql_query(self, mock_es): mock_es_sql_client = MagicMock() diff --git a/providers/exasol/pyproject.toml b/providers/exasol/pyproject.toml index a8bd509c2c094..6a6337ec4b099 100644 --- a/providers/exasol/pyproject.toml +++ b/providers/exasol/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "pyexasol>=0.26.0", 'pandas>=2.1.2; python_version <"3.13"', 'pandas>=2.2.3; python_version >="3.13"', diff --git a/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py b/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py index 9cda81368939d..0d418bc798ea9 100644 --- a/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py +++ b/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py @@ -33,6 +33,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException from airflow.providers.common.sql.hooks.handlers import return_single_query_results +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook if TYPE_CHECKING: @@ -188,6 +189,12 @@ def get_records( :param parameters: The parameters to render the SQL query with. 
""" with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur: + send_sql_hook_lineage( + context=self, + sql=sql, + sql_parameters=parameters, + cur=cur, + ) return cur.fetchall() def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any: @@ -199,6 +206,12 @@ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, An :param parameters: The parameters to render the SQL query with. """ with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur: + send_sql_hook_lineage( + context=self, + sql=sql, + sql_parameters=parameters, + cur=cur, + ) return cur.fetchone() def export_to_file( @@ -332,6 +345,13 @@ def run( results.append(result) self.descriptions.append(self.get_description(exa_statement)) self.log.info("Rows affected: %s", exa_statement.rowcount()) + rc = exa_statement.rowcount() + send_sql_hook_lineage( + context=self, + sql=sql_statement, + sql_parameters=parameters, + row_count=rc if rc is not None and rc >= 0 else None, + ) # If autocommit was set to False or db does not support autocommit, we do a manual commit. 
if not self.get_autocommit(conn): diff --git a/providers/exasol/tests/unit/exasol/hooks/test_exasol.py b/providers/exasol/tests/unit/exasol/hooks/test_exasol.py index e94fe4181f05d..098a0ff3292fe 100644 --- a/providers/exasol/tests/unit/exasol/hooks/test_exasol.py +++ b/providers/exasol/tests/unit/exasol/hooks/test_exasol.py @@ -216,6 +216,65 @@ def test_run_no_queries(self): with pytest.raises(ValueError, match="List of SQL statements is empty"): self.db_hook.run(sql=[]) + @mock.patch("airflow.providers.exasol.hooks.exasol.send_sql_hook_lineage") + def test_run_calls_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.db_hook.run(sql, autocommit=True) + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["row_count"] == 0 + + @mock.patch("airflow.providers.exasol.hooks.exasol.send_sql_hook_lineage") + def test_get_records_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.db_hook.get_records(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.exasol.hooks.exasol.send_sql_hook_lineage") + def test_get_first_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.db_hook.get_first(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_get_df_hook_lineage(self, mock_send_lineage): + statement = "SQL" + self.db_hook.get_df(statement, df_type="pandas") + + 
mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + statement = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(statement, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] == parameters + def test_no_result_set(self): """Queries like DROP and SELECT are of type rowCount (not resultSet), which raises an error in pyexasol if trying to iterate over them""" diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml index f274638b5660f..64c97f4c4be47 100644 --- a/providers/google/pyproject.toml +++ b/providers/google/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.13.0", - "apache-airflow-providers-common-sql>=1.27.0", + "apache-airflow-providers-common-sql>=1.27.0", # use next version "asgiref>=3.5.2", "dill>=0.2.3", "gcloud-aio-auth>=5.2.0", diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 3c9c6c07077c8..1963ff93a0b75 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -62,6 +62,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.lineage.hook import 
get_hook_lineage_collector from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes @@ -375,11 +376,19 @@ def get_df( defaults to use `self.use_legacy_sql` if not specified :param kwargs: (optional) passed into pandas_gbq.read_gbq method """ - if df_type == "polars": - return self._get_polars_df(sql, parameters, dialect, **kwargs) - if df_type == "pandas": - return self._get_pandas_df(sql, parameters, dialect, **kwargs) + if df_type == "pandas": + result: pd.DataFrame | pl.DataFrame = self._get_pandas_df(sql, parameters, dialect, **kwargs) + elif df_type == "polars": + result = self._get_polars_df(sql, parameters, dialect, **kwargs) + else: + raise ValueError(f"Unsupported df_type: {df_type}") + + send_sql_hook_lineage( + context=self, + sql=sql, + sql_parameters=parameters, + ) + return result @deprecated( planned_removal_date="November 30, 2025", @@ -713,6 +722,15 @@ def insert_all( ignore_unknown_values=ignore_unknown_values, skip_invalid_rows=skip_invalid_rows, ) + get_hook_lineage_collector().add_output_asset( + context=self, + scheme="bigquery", + asset_kwargs={ + "project_id": table.project, + "dataset_id": table.dataset_id, + "table_id": table.table_id, + }, + ) if errors: error_msg = f"{len(errors)} insert error(s) occurred.
Details: {errors}" self.log.error(error_msg) @@ -1015,14 +1033,24 @@ def list_rows( table_id=table_id, ) + table_object = Table.from_api_repr(table) iterator = self.get_client(project_id=project_id, location=location).list_rows( - table=Table.from_api_repr(table), + table=table_object, selected_fields=selected_fields_sequence, max_results=max_results, page_token=page_token, start_index=start_index, retry=retry, ) + get_hook_lineage_collector().add_input_asset( + context=self, + scheme="bigquery", + asset_kwargs={ + "project_id": table_object.project, + "dataset_id": table_object.dataset_id, + "table_id": table_object.table_id, + }, + ) if return_iterator: return iterator return list(iterator) @@ -1301,8 +1329,20 @@ def insert_job( else: # Start the job and wait for it to complete and get the result. job_api_repr.result(timeout=timeout, retry=retry) + + self._send_hook_level_lineage_for_bq_job(job=job_api_repr) + return job_api_repr + def _send_hook_level_lineage_for_bq_job(self, job): + # TODO(kacpermuda) Add support for other job types and more params to sql job + if job.job_type == QueryJob.job_type: + send_sql_hook_lineage( + context=self, + sql=job.query, + job_id=job.job_id, + ) + def generate_job_id( self, job_id: str | None, diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py index e52f22de9af38..b8be5e1245e4f 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py @@ -28,6 +28,7 @@ from sqlalchemy import create_engine from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, 
get_field @@ -421,6 +422,11 @@ def _tx_runner(tx: Transaction) -> dict[str, int]: preview = sql if len(sql) <= 300 else sql[:300] + "…" self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview) result_rows_count_per_query.append(rc) + send_sql_hook_lineage( + context=self, + sql=sql, + row_count=rc, + ) return result_rows_count_per_query @staticmethod diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index 6d55aabd34a99..a2cd8894e4f23 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -1951,3 +1951,152 @@ def test_delete_table_collects_assets(self, mock_bq_client, table_id, project_id assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/{TABLE_ID}" ) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_create_table_collects_assets(self, mock_bq_client, hook_lineage_collector): + mock_bq_client.return_value.create_table.return_value = Table(TABLE_REFERENCE) + + self.hook.create_table( + dataset_id=DATASET_ID, + table_id=TABLE_ID, + table_resource={"tableReference": TABLE_REFERENCE_REPR}, + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 0 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/{TABLE_ID}" + ) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_insert_all_collects_assets(self, mock_bq_client, hook_lineage_collector): + mock_bq_client.return_value.get_table.return_value = Table(TABLE_REFERENCE) + mock_bq_client.return_value.insert_rows.return_value = [] + + self.hook.insert_all( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + rows=[{"json": 
{"a_key": "a_value"}}], + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 0 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/{TABLE_ID}" + ) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_list_rows_collects_assets(self, mock_bq_client, hook_lineage_collector): + mock_bq_client.return_value.list_rows.return_value = _EmptyRowIterator() + + self.hook.list_rows( + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 0 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/{TABLE_ID}" + ) + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_conn") + def test_run_hook_lineage(self, mock_get_conn, mock_send_lineage): + mock_cur = mock.MagicMock() + mock_cur.rowcount = 0 + mock_conn = mock.MagicMock() + mock_conn.cursor.return_value = mock_cur + mock_conn.autocommit = True + mock_get_conn.return_value = mock_conn + + sql = "SELECT 1" + self.hook.run(sql, autocommit=True) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is mock_cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_conn") + def test_run_hook_lineage_with_parameters(self, mock_get_conn, mock_send_lineage): + mock_cur = mock.MagicMock() + mock_cur.rowcount = 0 + mock_conn = mock.MagicMock() + mock_conn.cursor.return_value = mock_cur + 
mock_conn.autocommit = True + mock_get_conn.return_value = mock_conn + + sql = "SELECT 1" + parameters = ("x",) + self.hook.run(sql, parameters=parameters, autocommit=True) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + assert call_kw["cur"] is mock_cur + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.send_sql_hook_lineage") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq") + def test_get_df_hook_lineage(self, mock_read_gbq, mock_send_lineage): + mock_read_gbq.return_value = mock.MagicMock() + sql = "select 1" + parameters = {"x": 1} + self.hook.get_df(sql, parameters=parameters, df_type="pandas") + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.send_sql_hook_lineage") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.QueryJob") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_insert_job_hook_lineage(self, mock_client, mock_query_job, mock_send_lineage): + query_job_type = "query" + job_conf = { + 
query_job_type: { + query_job_type: "SELECT * FROM test", + "useLegacySql": "False", + } + } + mock_query_job._JOB_TYPE = query_job_type + mock_query_job.job_type = query_job_type + mock_job_instance = mock.MagicMock() + mock_job_instance.job_id = JOB_ID + mock_job_instance.query = "SELECT * FROM test" + mock_job_instance.job_type = query_job_type + mock_query_job.from_api_repr.return_value = mock_job_instance + + self.hook.insert_job( + configuration=job_conf, + job_id=JOB_ID, + project_id=PROJECT_ID, + location=LOCATION, + nowait=True, + ) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.hook + assert call_kw["sql"] == "SELECT * FROM test" + assert call_kw["job_id"] == JOB_ID diff --git a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py index 070db8fe9f2fa..68d108d9d13b8 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py @@ -436,6 +436,32 @@ def test_execute_dml_overridden_project_id(self, get_client): def test_execute_dml_oqueries_row_count(self, get_client): pass + @mock.patch("airflow.providers.google.cloud.hooks.spanner.send_sql_hook_lineage") + @mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id", + new_callable=PropertyMock, + return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST, + ) + @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client") + def test_execute_dml_hook_lineage(self, get_client, mock_project_id, mock_send_lineage): + instance_method = get_client.return_value.instance + database_method = instance_method.return_value.database + run_in_tx = database_method.return_value.run_in_transaction + run_in_tx.return_value = OrderedDict([("INSERT INTO T VALUES (1)", 1)]) + + self.spanner_hook_default_project_id.execute_dml( + 
instance_id=SPANNER_INSTANCE, + database_id=SPANNER_DATABASE, + queries=["INSERT INTO T VALUES (1)"], + project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, + ) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.spanner_hook_default_project_id + assert call_kw["sql"] == "INSERT INTO T VALUES (1)" + assert call_kw["row_count"] == 1 + @pytest.mark.parametrize( ("returned_items", "expected_counts"), [ diff --git a/providers/jdbc/pyproject.toml b/providers/jdbc/pyproject.toml index 323f2158ac7e6..d05e9997c09cc 100644 --- a/providers/jdbc/pyproject.toml +++ b/providers/jdbc/pyproject.toml @@ -59,8 +59,8 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-compat>=1.10.1", #use next version - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-compat>=1.10.1", # use next version + "apache-airflow-providers-common-sql>=1.20.0", # use next version "jaydebeapi>=1.1.1", ] diff --git a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py index 30df0fdfb6d01..ff0f77b07bcd7 100644 --- a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py +++ b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py @@ -328,6 +328,61 @@ def call_get_conn(): assert mock_connect.call_count == 10 + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect", autospec=True, return_value=jdbc_conn_mock) + def test_run_hook_lineage(self, jdbc_mock, mock_send_lineage): + hook = get_hook() + jdbc_conn_mock.cursor.return_value.rowcount = 0 + sql = "SELECT 1" + hook.run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert 
call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect", autospec=True, return_value=jdbc_conn_mock) + def test_insert_rows_hook_lineage(self, jdbc_mock, mock_send_lineage): + hook = get_hook() + table = "table" + rows = [("hello",), ("world",)] + hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + hook = get_hook() + sql = "SELECT 1" + hook.get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + hook = get_hook() + sql = "SELECT 1" + parameters = ("x",) + hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + @pytest.mark.parametrize( ("params", "expected_uri"), [ diff --git a/providers/microsoft/mssql/pyproject.toml b/providers/microsoft/mssql/pyproject.toml index bba94c828e52c..9a35ed62ace75 100644 --- a/providers/microsoft/mssql/pyproject.toml +++ 
b/providers/microsoft/mssql/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-sql>=1.23.0", + "apache-airflow-providers-common-sql>=1.23.0", # use next version "pymssql>=2.3.5", "methodtools>=0.4.7", ] diff --git a/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py b/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py index 24972cd60db07..e547b54d22487 100644 --- a/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py +++ b/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py @@ -191,6 +191,89 @@ def test_sqlalchemy_scheme_is_default(self, get_connection, mssql_connections): hook = MsSqlHook() assert hook.sqlalchemy_scheme == hook.DEFAULT_SQLALCHEMY_SCHEME + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn") + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") + def test_run_hook_lineage(self, get_connection, mssql_get_conn, mock_send_lineage, mssql_connections): + get_connection.return_value = mssql_connections["default"] + cur = mock.MagicMock(rowcount=0) + cur.fetchall.return_value = [] + conn = mock.MagicMock() + conn.cursor.return_value = cur + mssql_get_conn.return_value = conn + + hook = MsSqlHook() + statement = "SELECT 1" + hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn") + 
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") + def test_insert_rows_hook_lineage( + self, get_connection, mssql_get_conn, mock_send_lineage, mssql_connections + ): + get_connection.return_value = mssql_connections["default"] + cur = mock.MagicMock(rowcount=0) + conn = mock.MagicMock() + conn.cursor.return_value = cur + mssql_get_conn.return_value = conn + + hook = MsSqlHook() + table = "table" + rows = [("hello",), ("world",)] + hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") + def test_get_df_hook_lineage( + self, get_connection, mock_get_pandas_df, mock_send_lineage, mssql_connections + ): + get_connection.return_value = mssql_connections["default"] + + hook = MsSqlHook() + sql = "SELECT 1" + parameters = ("x",) + hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection") + def test_get_df_by_chunks_hook_lineage( + self, get_connection, mock_get_pandas_df_by_chunks, mock_send_lineage, mssql_connections + ): + get_connection.return_value = mssql_connections["default"] + + hook = MsSqlHook() + sql = "SELECT 1" + parameters = ("x",) + 
hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_sqlalchemy_scheme_is_from_hook(self): hook = MsSqlHook(sqlalchemy_scheme="mssql+mytestdriver") assert hook.sqlalchemy_scheme == "mssql+mytestdriver" diff --git a/providers/mysql/pyproject.toml b/providers/mysql/pyproject.toml index 45e9495a7d1ca..9ff8113b27623 100644 --- a/providers/mysql/pyproject.toml +++ b/providers/mysql/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.20.0", # use next version # The mysqlclient package creates friction when installing on MacOS as it needs pkg-config to # Install and compile, and it's really only used by MySQL provider, so we can skip it on MacOS # Instead, if someone attempts to use it on MacOS, they will get explanatory error on how to install it diff --git a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py index 99061d1521e2e..40d597325302c 100644 --- a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py +++ b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py @@ -25,6 +25,7 @@ from urllib.parse import quote_plus, urlencode from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook logger = logging.getLogger(__name__) @@ -248,11 +249,19 @@ def bulk_load(self, table: str, tmp_file: str) -> None: if not re.fullmatch(r"^[a-zA-Z0-9_.]+$", table): raise ValueError(f"Invalid table name: {table}") + sql_statement = 
f"LOAD DATA LOCAL INFILE %s INTO TABLE `{table}`" + parameters = (tmp_file,) cur.execute( - f"LOAD DATA LOCAL INFILE %s INTO TABLE `{table}`", - (tmp_file,), + sql_statement, + parameters, ) conn.commit() + send_sql_hook_lineage( + context=self, + sql=sql_statement, + sql_parameters=parameters, + cur=cur, + ) conn.close() def bulk_dump(self, table: str, tmp_file: str) -> None: @@ -265,11 +274,14 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: if not re.fullmatch(r"^[a-zA-Z0-9_.]+$", table): raise ValueError(f"Invalid table name: {table}") + sql_statement = f"SELECT * INTO OUTFILE %s FROM `{table}`" + parameters = (tmp_file,) cur.execute( - f"SELECT * INTO OUTFILE %s FROM `{table}`", - (tmp_file,), + sql_statement, + parameters, ) conn.commit() + send_sql_hook_lineage(context=self, sql=sql_statement, sql_parameters=parameters, cur=cur) conn.close() @staticmethod @@ -330,11 +342,14 @@ def bulk_load_custom( conn = self.get_conn() cursor = conn.cursor() + sql_statement = f"LOAD DATA LOCAL INFILE %s %s INTO TABLE `{table}` %s" + parameters = (tmp_file, duplicate_key_handling, extra_options) cursor.execute( - f"LOAD DATA LOCAL INFILE %s %s INTO TABLE `{table}` %s", - (tmp_file, duplicate_key_handling, extra_options), + sql_statement, + parameters, ) + send_sql_hook_lineage(context=self, sql=sql_statement, sql_parameters=parameters, cur=cursor) cursor.close() conn.commit() conn.close() diff --git a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py index 3e5a731d79d5b..35b3df6976360 100644 --- a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py +++ b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py @@ -408,16 +408,90 @@ def test_run_multi_queries(self): self.cur.execute.assert_has_calls(calls, any_order=True) self.conn.commit.assert_not_called() + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 
1" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_bulk_load(self): self.db_hook.bulk_load("table", "/tmp/file") 
self.cur.execute.assert_called_once_with( "LOAD DATA LOCAL INFILE %s INTO TABLE `table`", ("/tmp/file",) ) + @mock.patch("airflow.providers.mysql.hooks.mysql.send_sql_hook_lineage") + def test_bulk_load_hook_lineage(self, mock_send_lineage): + self.db_hook.bulk_load("table", "/tmp/file") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "LOAD DATA LOCAL INFILE %s INTO TABLE `table`" + assert call_kw["sql_parameters"] == ("/tmp/file",) + assert call_kw["cur"] is self.cur + def test_bulk_dump(self): self.db_hook.bulk_dump("table", "/tmp/file") self.cur.execute.assert_called_once_with("SELECT * INTO OUTFILE %s FROM `table`", ("/tmp/file",)) + @mock.patch("airflow.providers.mysql.hooks.mysql.send_sql_hook_lineage") + def test_bulk_dump_hook_lineage(self, mock_send_lineage): + self.db_hook.bulk_dump("table", "/tmp/file") + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "SELECT * INTO OUTFILE %s FROM `table`" + assert call_kw["sql_parameters"] == ("/tmp/file",) + assert call_kw["cur"] is self.cur + def test_serialize_cell(self): assert self.db_hook._serialize_cell("foo", None) == "foo" @@ -442,6 +516,21 @@ def test_bulk_load_custom(self, table): ), ) + @mock.patch("airflow.providers.mysql.hooks.mysql.send_sql_hook_lineage") + def test_bulk_load_custom_hook_lineage(self, mock_send_lineage): + self.db_hook.bulk_load_custom( + "table", + "/tmp/file", + "IGNORE", + "FIELDS TERMINATED BY ';'", + ) + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "LOAD DATA LOCAL INFILE %s %s INTO TABLE `table` %s" + assert call_kw["sql_parameters"] == ("/tmp/file", "IGNORE", "FIELDS TERMINATED BY ';'") + assert call_kw["cur"] is self.cur + def 
test_reserved_words(self): hook = MySqlHook() assert hook.reserved_words == sqlalchemy.dialects.mysql.reserved_words.RESERVED_WORDS_MYSQL diff --git a/providers/odbc/pyproject.toml b/providers/odbc/pyproject.toml index 72e9c090c11c4..0d0908f0f7d9f 100644 --- a/providers/odbc/pyproject.toml +++ b/providers/odbc/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.20.0", # use next version "pyodbc>=5.0.0; python_version < '3.13'", "pyodbc>=5.2.0; python_version >= '3.13'", ] diff --git a/providers/odbc/tests/unit/odbc/hooks/test_odbc.py b/providers/odbc/tests/unit/odbc/hooks/test_odbc.py index 8c7df6aae60fa..ed55eaf8e94cd 100644 --- a/providers/odbc/tests/unit/odbc/hooks/test_odbc.py +++ b/providers/odbc/tests/unit/odbc/hooks/test_odbc.py @@ -367,6 +367,45 @@ def test_dialect_name_when_resolved_from_dialect_in_extra(self): hook.get_uri = raise_argument_error assert hook.dialect_name == "oracle" + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + hook = mock_db_hook(OdbcHook) + sql = "SELECT 1" + hook.run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + hook = mock_db_hook(OdbcHook) + sql = "SELECT 1" + hook.get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert 
call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + hook = mock_db_hook(OdbcHook) + sql = "SELECT 1" + parameters = ("x",) + hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_get_sqlalchemy_engine_verify_creator_is_being_used(self): hook = mock_db_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "sqlite"}}) diff --git a/providers/oracle/pyproject.toml b/providers/oracle/pyproject.toml index 516f0025306e0..74eda14aceb2b 100644 --- a/providers/oracle/pyproject.toml +++ b/providers/oracle/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.8.0", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.20.0", # use next version "oracledb>=2.0.0", ] diff --git a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py index 186ed44a2b2a6..6bd444f460aa9 100644 --- a/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/oracle/src/airflow/providers/oracle/hooks/oracle.py @@ -29,6 +29,7 @@ from airflow.models.connection import Connection from airflow.providers.openlineage.sqlparser import DatabaseInfo +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.oracle.hooks import handlers @@ -362,6 +363,7 @@ def insert_rows( self.set_autocommit(conn, False) cur = 
conn.cursor() i = 0 + sql = None # not generated unless we actually process at least one chunk for row in rows: i += 1 lst = [] @@ -383,6 +385,11 @@ def insert_rows( conn.commit() self.log.info("Loaded %s into %s rows so far", i, table) conn.commit() + + if sql: + # We only send lineage once, not for each value collection, to save memory. + send_sql_hook_lineage(context=self, sql=sql, row_count=i) + cur.close() conn.close() self.log.info("Done loading. Loaded a total of %s rows", i) @@ -458,6 +465,8 @@ def bulk_insert_rows( cursor.executemany(None, row_chunk) conn.commit() self.log.info("[%s] inserted %s rows", table, row_count) + # We only send lineage once, not for each value collection, to save memory. + send_sql_hook_lineage(context=self, sql=prepared_stm, row_count=row_count) cursor.close() conn.close() diff --git a/providers/oracle/tests/unit/oracle/hooks/test_oracle.py b/providers/oracle/tests/unit/oracle/hooks/test_oracle.py index 42ef28a9936de..2810b698acce5 100644 --- a/providers/oracle/tests/unit/oracle/hooks/test_oracle.py +++ b/providers/oracle/tests/unit/oracle/hooks/test_oracle.py @@ -361,6 +361,46 @@ def test_run_with_parameters(self): self.cur.execute.assert_called_once_with(sql, param) assert self.conn.commit.called + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + parameters = 
("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_insert_rows_with_fields(self): rows = [ ( @@ -412,6 +452,18 @@ def test_insert_rows_without_fields(self): "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')" ) + @mock.patch("airflow.providers.oracle.hooks.oracle.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + rows = [("a", "b", "c")] + target_fields = ["col1", "col2", "col3"] + self.db_hook.insert_rows("table", rows, target_fields) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT /*+ APPEND */ INTO table (col1, col2, col3) VALUES ('a','b','c')" + assert call_kw["row_count"] == 1 + def test_bulk_insert_rows_with_fields(self): rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] target_fields = ["col1", "col2", "col3"] @@ -440,6 +492,18 @@ def test_bulk_insert_rows_without_fields(self): self.cur.prepare.assert_called_once_with("insert into table values (:1, :2, :3)") self.cur.executemany.assert_called_once_with(None, rows) + 
@mock.patch("airflow.providers.oracle.hooks.oracle.send_sql_hook_lineage") + def test_bulk_insert_rows_hook_lineage(self, mock_send_lineage): + rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + target_fields = ["col1", "col2", "col3"] + self.db_hook.bulk_insert_rows("table", rows, target_fields) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "insert into table (col1, col2, col3) values (:1, :2, :3)" + assert call_kw["row_count"] == 3 + def test_bulk_insert_rows_no_rows(self): rows = [] with pytest.raises(ValueError, match="parameter rows could not be None or empty iterable"): diff --git a/providers/pgvector/pyproject.toml b/providers/pgvector/pyproject.toml index 63c63a3b8d3ac..3f7720776e949 100644 --- a/providers/pgvector/pyproject.toml +++ b/providers/pgvector/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.8.0", - "apache-airflow-providers-postgres>=5.7.1", + "apache-airflow-providers-postgres>=5.7.1", # use next version "pgvector>=0.3.1", ] diff --git a/providers/pgvector/tests/unit/pgvector/hooks/test_pgvector.py b/providers/pgvector/tests/unit/pgvector/hooks/test_pgvector.py index 4053a1d5f7959..30183e3c6eb9d 100644 --- a/providers/pgvector/tests/unit/pgvector/hooks/test_pgvector.py +++ b/providers/pgvector/tests/unit/pgvector/hooks/test_pgvector.py @@ -16,13 +16,12 @@ # under the License. 
from __future__ import annotations -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch import pytest -from airflow.providers.pgvector.hooks.pgvector import ( - PgVectorHook, -) +from airflow.models import Connection +from airflow.providers.pgvector.hooks.pgvector import PgVectorHook @pytest.fixture @@ -30,6 +29,24 @@ def pg_vector_hook(): return PgVectorHook(postgres_conn_id="your_postgres_conn_id") +@pytest.fixture +def pgvector_hook_setup(): + """Set up mock PgVectorHook for testing (follows the postgres test pattern).""" + cur = MagicMock(rowcount=0) + conn = MagicMock() + conn.cursor.return_value = cur + + class UnitTestPgVectorHook(PgVectorHook): + conn_name_attr = "test_conn_id" + + def get_conn(self): + return conn + + db_hook = UnitTestPgVectorHook() + db_hook.get_connection = MagicMock(return_value=Connection(conn_type="postgres")) + return MagicMock(cur=cur, conn=conn, db_hook=db_hook) + + def test_create_table(pg_vector_hook): pg_vector_hook.run = Mock() table_name = "my_table" @@ -59,3 +76,62 @@ def test_truncate_table(pg_vector_hook): table_name = "my_table" pg_vector_hook.truncate_table(table_name, restart_identity=True) pg_vector_hook.run.assert_called_with("TRUNCATE TABLE my_table RESTART IDENTITY") + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +def test_run_hook_lineage(mock_send_lineage, pgvector_hook_setup): + setup = pgvector_hook_setup + sql = "SELECT 1" + setup.db_hook.run(sql) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is setup.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is setup.cur + + +@patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") +@patch("airflow.providers.postgres.hooks.postgres.PostgresHook._get_polars_df") +def test_get_df_hook_lineage(mock_get_polars_df, mock_send_lineage, pgvector_hook_setup): + setup = 
pgvector_hook_setup + sql = "SELECT 1" + parameters = ("x",) + setup.db_hook.get_df(sql, parameters=parameters, df_type="polars") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is setup.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") +def test_get_df_by_chunks_hook_lineage(mock_get_pandas_df_by_chunks, mock_send_lineage, pgvector_hook_setup): + setup = pgvector_hook_setup + sql = "SELECT 1" + parameters = ("x",) + setup.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is setup.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +def test_insert_rows_hook_lineage(mock_send_lineage, pgvector_hook_setup): + setup = pgvector_hook_setup + table = "table" + rows = [("hello",), ("world",)] + + setup.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is setup.db_hook + assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" + assert call_kw["row_count"] == 2 diff --git a/providers/postgres/pyproject.toml b/providers/postgres/pyproject.toml index f0e687f42be6f..833a6590b02b6 100644 --- a/providers/postgres/pyproject.toml +++ b/providers/postgres/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.23.0", + "apache-airflow-providers-common-sql>=1.23.0", # use next version "psycopg2-binary>=2.9.9; python_version < '3.13'", 
"psycopg2-binary>=2.9.10; python_version >= '3.13'", "asyncpg>=0.30.0", diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 52a0d40a772ed..8e3edabb94a68 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -33,6 +33,7 @@ Connection, conf, ) +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.postgres.dialects.postgres import PostgresDialect @@ -338,14 +339,17 @@ def get_df( with engine.connect() as conn: if isinstance(sql, list): sql = "; ".join(sql) # Or handle multiple queries differently - return cast("PandasDataFrame", psql.read_sql(sql, con=conn, params=parameters, **kwargs)) - + result: PandasDataFrame | PolarsDataFrame = cast( + "PandasDataFrame", psql.read_sql(sql, con=conn, params=parameters, **kwargs) + ) elif df_type == "polars": - return self._get_polars_df(sql, parameters, **kwargs) - + result = self._get_polars_df(sql, parameters, **kwargs) else: raise ValueError(f"Unsupported df_type: {df_type}") + send_sql_hook_lineage(context=self, sql=sql, sql_parameters=parameters) + return result + def copy_expert(self, sql: str, filename: str) -> None: """ Execute SQL using psycopg's ``copy_expert`` method. @@ -371,6 +375,7 @@ def copy_expert(self, sql: str, filename: str) -> None: while data := file.read(8192): copy.write(data) conn.commit() + send_sql_hook_lineage(context=self, sql=sql, sql_parameters=(filename,), cur=cur) else: # Handle COPY TO STDOUT: read from the database and write to the file. 
with open(filename, "wb") as file, self.get_conn() as conn, conn.cursor() as cur: @@ -378,6 +383,7 @@ def copy_expert(self, sql: str, filename: str) -> None: for data in copy: file.write(data) conn.commit() + send_sql_hook_lineage(context=self, sql=sql, sql_parameters=(filename,), cur=cur) else: if not os.path.isfile(filename): with open(filename, "w"): @@ -391,6 +397,7 @@ def copy_expert(self, sql: str, filename: str) -> None: cur.copy_expert(sql, file) file.truncate(file.tell()) conn.commit() + send_sql_hook_lineage(context=self, sql=sql, sql_parameters=(filename,), cur=cur) def get_uri(self) -> str: """ @@ -672,6 +679,7 @@ def insert_rows( # if fast_executemany is enabled, use optimized execute_batch from psycopg nb_rows = 0 + sql = None # not generated unless we actually process at least one chunk with self._create_autocommit_connection(autocommit) as conn: conn.commit() with closing(conn.cursor()) as cur: @@ -695,4 +703,10 @@ def insert_rows( conn.commit() nb_rows += len(chunked_rows) self.log.info("Loaded %s rows into %s so far", nb_rows, table) + + if sql: + # We only send lineage once, not for each value collection, to save memory. + send_sql_hook_lineage(context=self, sql=sql, row_count=nb_rows) + self.log.info("Done loading. 
Loaded a total of %s rows into %s", nb_rows, table) + return None diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py index 788e26095d869..f4570f6563b8f 100644 --- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py +++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py @@ -842,6 +842,63 @@ def test_copy_expert(self, mocker): self.cur.copy_expert.assert_called_once_with(statement, open_mock.return_value) assert open_mock.call_args.args == (filename, "r+") + @mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") + def test_copy_expert_hook_lineage(self, mock_send_lineage, mocker): + open_mock = mocker.mock_open(read_data='{"some": "json"}') + mocker.patch("airflow.providers.postgres.hooks.postgres.open", open_mock) + statement = "COPY t FROM STDIN" + filename = "file" + + self.db_hook.copy_expert(statement, filename) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] == (filename,) + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") + @mock.patch("pandas.io.sql.read_sql", return_value=pd.DataFrame({"a": [1]})) + @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.get_sqlalchemy_engine") + def test_get_df_hook_lineage(self, 
mock_engine, mock_read_sql, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_insert_rows(self, postgres_hook_setup): setup = postgres_hook_setup table = "table" @@ -858,6 +915,20 @@ def test_insert_rows(self, postgres_hook_setup): sql = f"INSERT INTO {table} VALUES (%s)" setup.cur.executemany.assert_any_call(sql, rows) + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage, postgres_hook_setup): + setup = postgres_hook_setup + table = "table" + rows = [("hello",), ("world",)] + + setup.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is setup.db_hook + assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" + assert call_kw["row_count"] == 2 + @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") def test_insert_rows_fast_executemany(self, mock_execute_batch, postgres_hook_setup): setup = postgres_hook_setup @@ -882,6 +953,23 @@ def test_insert_rows_fast_executemany(self, mock_execute_batch, 
postgres_hook_se # executemany should NOT be called in this mode setup.cur.executemany.assert_not_called() + @mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage") + @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch") + def test_insert_rows_fast_executemany_hook_lineage( + self, mock_execute_batch, mock_send_lineage, postgres_hook_setup + ): + setup = postgres_hook_setup + table = "table" + rows = [("hello",), ("world",)] + + setup.db_hook.insert_rows(table, rows, fast_executemany=True) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is setup.db_hook + assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)" + assert call_kw["row_count"] == 2 + def test_insert_rows_replace(self, postgres_hook_setup): setup = postgres_hook_setup table = "table" diff --git a/providers/presto/pyproject.toml b/providers/presto/pyproject.toml index fa5c0d80c2c90..9317e3df9c5dd 100644 --- a/providers/presto/pyproject.toml +++ b/providers/presto/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "presto-python-client>=0.8.4", 'pandas>=2.1.2; python_version <"3.13"', 'pandas>=2.2.3; python_version >="3.13"', diff --git a/providers/presto/tests/unit/presto/hooks/test_presto.py b/providers/presto/tests/unit/presto/hooks/test_presto.py index 9e0e28e10a484..b9777c9885b01 100644 --- a/providers/presto/tests/unit/presto/hooks/test_presto.py +++ b/providers/presto/tests/unit/presto/hooks/test_presto.py @@ -236,6 +236,7 @@ def get_isolation_level(self): return IsolationLevel.READ_COMMITTED self.db_hook = UnitTestPrestoHook() + self.db_hook.get_connection = mock.Mock(return_value=Connection(conn_type="presto")) 
@patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows") def test_insert_rows(self, mock_insert_rows): @@ -297,3 +298,56 @@ def test_split_sql_string(self): def test_serialize_cell(self): assert self.db_hook._serialize_cell("foo", None) == "foo" assert self.db_hook._serialize_cell(1, None) == 1 + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (?)" + assert call_kw["row_count"] == 2 + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, 
mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters diff --git a/providers/snowflake/pyproject.toml b/providers/snowflake/pyproject.toml index c7ee1f7a67c4e..2ee111815542c 100644 --- a/providers/snowflake/pyproject.toml +++ b/providers/snowflake/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.27.5", + "apache-airflow-providers-common-sql>=1.27.5", # use next version 'pandas>=2.1.2; python_version <"3.13"', 'pandas>=2.2.3; python_version >="3.13"', "pyarrow>=16.1.0; python_version < '3.13'", diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index d077b07233016..a71ac283646db 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -790,7 +790,7 @@ def _get_cursor(self, conn: Any, return_dictionaries: bool): def get_openlineage_database_info(self, connection) -> DatabaseInfo: from airflow.providers.openlineage.sqlparser import DatabaseInfo - database = self.database or self._get_field(connection.extra_dejson, "database") + database = self._get_conn_params()["database"] return DatabaseInfo( scheme=self.get_openlineage_database_dialect(connection), @@ -803,7 +803,7 @@ def get_openlineage_database_info(self, connection) -> DatabaseInfo: "data_type", "table_catalog", ], - database=database, + database=database or None, is_information_schema_cross_db=True, is_uppercase_names=True, ) @@ -812,7 
+812,7 @@ def get_openlineage_database_dialect(self, _) -> str: return "snowflake" def get_openlineage_default_schema(self) -> str | None: - return self._get_conn_params()["schema"] + return self._get_conn_params()["schema"] or None def _get_openlineage_authority(self, _) -> str | None: uri = fix_snowflake_sqlalchemy_uri(self.get_uri()) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index f5f34b34153e7..efe45ed9d13bf 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -42,6 +42,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook from airflow.providers.snowflake.utils.sql_api_generate_jwt import JWTGenerator @@ -226,6 +227,29 @@ def execute_query( self.query_ids.append(json_response["statementHandle"]) else: raise AirflowException("No statementHandle/statementHandles present in response") + + # Send Hook Level Lineage + len_query_ids = len(self.query_ids) + if len_query_ids == 1: + send_sql_hook_lineage( + context=self, + sql=sql, + job_id=self.query_ids[0], + ) + else: + sql_statements = sql.split(";") + if len(sql_statements) == len_query_ids: + for single_sql, single_query_id in zip(sql_statements, self.query_ids): + send_sql_hook_lineage( + context=self, + sql=single_sql, + job_id=single_query_id, + ) + else: # SQL/query ID count mismatch; can't correlate sql with id - send SQL only. 
+ send_sql_hook_lineage( + context=self, + sql=sql, + ) return self.query_ids def get_headers(self) -> dict[str, Any]: diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py index 4e89a8900a4c7..a5e332df41c17 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_sql_api.py @@ -280,6 +280,40 @@ def test_execute_query( query_ids = hook.execute_query(sql, statement_count) assert query_ids == expected_query_ids + @pytest.mark.parametrize( + ("sql", "statement_count", "expected_response", "expected_query_ids"), + [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])], + ) + @mock.patch(f"{MODULE_PATH}.send_sql_hook_lineage") + @mock.patch(f"{HOOK_PATH}._get_conn_params") + @mock.patch(f"{HOOK_PATH}.get_headers") + def test_execute_query_hook_lineage( + self, + mock_get_header, + mock_conn_param, + mock_send_lineage, + sql, + statement_count, + expected_response, + expected_query_ids, + mock_requests, + ): + mock_requests.codes.ok = 200 + mock_requests.request.side_effect = [ + create_successful_response_mock(expected_response), + ] + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.request.return_value).status_code = status_code_mock + + hook = SnowflakeSqlApiHook("mock_conn_id") + hook.execute_query(sql, statement_count) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["job_id"] == expected_query_ids[0] + @mock.patch(f"{HOOK_PATH}._get_conn_params") @mock.patch(f"{HOOK_PATH}.get_headers") def test_execute_query_multiple_times_give_fresh_query_ids_each_time( diff --git a/providers/sqlite/pyproject.toml b/providers/sqlite/pyproject.toml index 5dece8c2fe095..a0cfd62bbf166 100644 --- a/providers/sqlite/pyproject.toml +++ 
b/providers/sqlite/pyproject.toml @@ -59,7 +59,7 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version ] [dependency-groups] diff --git a/providers/sqlite/tests/unit/sqlite/hooks/test_sqlite.py b/providers/sqlite/tests/unit/sqlite/hooks/test_sqlite.py index 69068c2df0d28..32928ec350e31 100644 --- a/providers/sqlite/tests/unit/sqlite/hooks/test_sqlite.py +++ b/providers/sqlite/tests/unit/sqlite/hooks/test_sqlite.py @@ -96,6 +96,7 @@ def get_conn(self): return conn self.db_hook = UnitTestSqliteHook() + self.db_hook.get_connection = mock.Mock(return_value=Connection(conn_type="sqlite")) def test_get_first_record(self): statement = "SQL" @@ -175,6 +176,55 @@ def test_generate_insert_sql_replace_true(self): assert sql == expected_sql + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.db_hook.run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (?)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + 
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + self.db_hook.get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + @pytest.mark.db_test def test_sqlalchemy_engine(self): """Test that the sqlalchemy engine is initialized""" diff --git a/providers/teradata/pyproject.toml b/providers/teradata/pyproject.toml index eafa49162fef1..059e3dbe4cfff 100644 --- a/providers/teradata/pyproject.toml +++ b/providers/teradata/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.20.0", # use next version "teradatasqlalchemy>=17.20.0.0", "teradatasql>=17.20.0.28", ] diff --git a/providers/teradata/tests/unit/teradata/hooks/test_teradata.py b/providers/teradata/tests/unit/teradata/hooks/test_teradata.py index f10c1e629d211..95ff5f62782f4 100644 --- a/providers/teradata/tests/unit/teradata/hooks/test_teradata.py +++ b/providers/teradata/tests/unit/teradata/hooks/test_teradata.py @@ -250,6 +250,55 @@ def 
getvalue(self): result = self.test_db_hook.callproc("proc", True, parameters) assert result == parameters + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + sql = "SELECT 1" + self.test_db_hook.run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.test_db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + self.test_db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.test_db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (?)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + self.test_db_hook.get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.test_db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.test_db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + 
assert call_kw["context"] is self.test_db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_set_query_band(self): query_band_text = "example_query_band_text" _handle_user_query_band_text(query_band_text) diff --git a/providers/trino/pyproject.toml b/providers/trino/pyproject.toml index e6de866c9d33a..301fe39787566 100644 --- a/providers/trino/pyproject.toml +++ b/providers/trino/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.20.0", # use next version 'pandas>=2.1.2; python_version <"3.13"', 'pandas>=2.2.3; python_version >="3.13"', "trino>=0.319.0", diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py index 001d045ed6247..8acb5b9d03a0f 100644 --- a/providers/trino/tests/unit/trino/hooks/test_trino.py +++ b/providers/trino/tests/unit/trino/hooks/test_trino.py @@ -333,6 +333,7 @@ def get_isolation_level(self): return IsolationLevel.READ_COMMITTED self.db_hook = UnitTestTrinoHook() + self.db_hook.get_connection = mock.Mock(return_value=Connection(conn_type="trino")) @patch("airflow.providers.common.sql.hooks.sql.DbApiHook.insert_rows") def test_insert_rows(self, mock_insert_rows): @@ -433,6 +434,59 @@ def test_run_multistatement_defaults_to_split(self, super_run): self.db_hook.run(sql) super_run.assert_called_once_with(sql, False, None, None, True, True) + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert 
call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (?)" + assert call_kw["row_count"] == 2 + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + def test_connection_success(self): status, msg = self.db_hook.test_connection() assert status is True diff --git a/providers/vertica/pyproject.toml b/providers/vertica/pyproject.toml index 110f03ad0d732..ab53a90779853 100644 --- a/providers/vertica/pyproject.toml +++ b/providers/vertica/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" 
dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "apache-airflow-providers-common-sql>=1.26.0", + "apache-airflow-providers-common-sql>=1.26.0", # use next version "vertica-python>=1.3.0", ] diff --git a/providers/vertica/tests/unit/vertica/hooks/test_vertica.py b/providers/vertica/tests/unit/vertica/hooks/test_vertica.py index 0063bf30534d2..8fc260f1d8512 100644 --- a/providers/vertica/tests/unit/vertica/hooks/test_vertica.py +++ b/providers/vertica/tests/unit/vertica/hooks/test_vertica.py @@ -197,3 +197,80 @@ def test_get_df_polars(self): assert column == df.columns[0] assert result_sets[0][0] == df.row(0)[0] assert result_sets[1][0] == df.row(1)[0] + + +class TestVerticaHookLineage: + def setup_method(self): + self.cur = mock.MagicMock(rowcount=0) + self.conn = mock.MagicMock() + self.conn.cursor.return_value = self.cur + conn = self.conn + + class UnitTestVerticaHook(VerticaHook): + conn_name_attr = "vertica_conn_id" + + def get_conn(self): + return conn + + self.db_hook = UnitTestVerticaHook() + self.db_hook.get_connection = mock.Mock( + return_value=Connection( + login="login", + password="password", + host="host", + schema="vertica", + ) + ) + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_run_hook_lineage(self, mock_send_lineage): + statement = "SELECT 1" + self.cur.fetchall.return_value = [] + + self.db_hook.run(statement) + + mock_send_lineage.assert_called() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == statement + assert call_kw["sql_parameters"] is None + assert call_kw["cur"] is self.cur + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + def test_insert_rows_hook_lineage(self, mock_send_lineage): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows) + + mock_send_lineage.assert_called() + call_kw = 
mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == "INSERT INTO table VALUES (%s)" + assert call_kw["row_count"] == 2 + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") + def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df(sql, parameters=parameters) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters + + @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") + @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") + def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage): + sql = "SELECT 1" + parameters = ("x",) + self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is self.db_hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters diff --git a/providers/ydb/pyproject.toml b/providers/ydb/pyproject.toml index 899b6f21cde98..4db6667a061f2 100644 --- a/providers/ydb/pyproject.toml +++ b/providers/ydb/pyproject.toml @@ -60,7 +60,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.10.1", - "apache-airflow-providers-common-sql>=1.20.0", + "apache-airflow-providers-common-sql>=1.20.0", # use next version "ydb>=3.18.8", "ydb-dbapi>=0.1.0", ] diff --git a/providers/ydb/tests/unit/ydb/hooks/test_ydb.py b/providers/ydb/tests/unit/ydb/hooks/test_ydb.py index 7930180194698..587ec68e6033a 100644 --- a/providers/ydb/tests/unit/ydb/hooks/test_ydb.py +++ 
b/providers/ydb/tests/unit/ydb/hooks/test_ydb.py @@ -99,3 +99,110 @@ def test_execute(cursor_class, mock_session_pool, mock_driver, mock_get_connecti assert cur.fetchone() == (1, 2) assert cur.fetchmany(2) == [(1, 2), (2, 3)] assert cur.fetchall() == [(1, 2), (2, 3), (3, 4)] + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") +@patch("ydb.Driver") +@patch("ydb.QuerySessionPool") +@patch("ydb_dbapi.Connection._cursor_cls", new_callable=PropertyMock) +def test_run_hook_lineage( + cursor_class, mock_session_pool, mock_driver, mock_get_connection, mock_send_lineage +): + mock_get_connection.return_value = Connection( + conn_type="ydb", + host="grpc://localhost", + port=2135, + login="my_user", + password="my_pwd", + extra={"database": "/my_db1"}, + ) + driver_instance = FakeDriver() + + cursor_class.return_value = FakeYDBCursor + mock_driver.return_value = driver_instance + mock_session_pool.return_value = FakeSessionPool(driver_instance) + + hook = YDBHook() + sql = "SELECT 1" + hook.run(sql) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df") +@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") +@patch("ydb.Driver") +@patch("ydb.QuerySessionPool") +@patch("ydb_dbapi.Connection._cursor_cls", new_callable=PropertyMock) +def test_get_df_hook_lineage( + cursor_class, mock_session_pool, mock_driver, mock_get_connection, mock_get_pandas_df, mock_send_lineage +): + mock_get_connection.return_value = Connection( + conn_type="ydb", + host="grpc://localhost", + port=2135, + login="my_user", + password="my_pwd", + extra={"database": "/my_db1"}, + ) + driver_instance = FakeDriver() + + cursor_class.return_value = 
FakeYDBCursor + mock_driver.return_value = driver_instance + mock_session_pool.return_value = FakeSessionPool(driver_instance) + + hook = YDBHook() + sql = "SELECT 1" + hook.get_df(sql, df_type="pandas") + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] is None + + +@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage") +@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks") +@patch(f"{BASEHOOK_PATCH_PATH}.get_connection") +@patch("ydb.Driver") +@patch("ydb.QuerySessionPool") +@patch("ydb_dbapi.Connection._cursor_cls", new_callable=PropertyMock) +def test_get_df_by_chunks_hook_lineage( + cursor_class, + mock_session_pool, + mock_driver, + mock_get_connection, + mock_get_pandas_df_by_chunks, + mock_send_lineage, +): + mock_get_connection.return_value = Connection( + conn_type="ydb", + host="grpc://localhost", + port=2135, + login="my_user", + password="my_pwd", + extra={"database": "/my_db1"}, + ) + driver_instance = FakeDriver() + + cursor_class.return_value = FakeYDBCursor + mock_driver.return_value = driver_instance + mock_session_pool.return_value = FakeSessionPool(driver_instance) + + hook = YDBHook() + sql = "SELECT 1" + parameters = ("x",) + hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1) + + mock_send_lineage.assert_called_once() + call_kw = mock_send_lineage.call_args.kwargs + assert call_kw["context"] is hook + assert call_kw["sql"] == sql + assert call_kw["sql_parameters"] == parameters