diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index e52fa4df0fc77..0c4719bef93e2 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -617,10 +617,9 @@ def _get_openlineage_authority(self, _) -> str | None:
 
     def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
         """
-        Generate OpenLineage metadata for a Snowflake task instance based on executed query IDs.
+        Emit separate OpenLineage events for each Snowflake query, based on executed query IDs.
 
-        If a single query ID is present, attach an `ExternalQueryRunFacet` to the lineage metadata.
-        If multiple query IDs are present, emits separate OpenLineage events for each query.
+        If a single query ID is present, an `ExternalQueryRunFacet` is also added to the returned lineage metadata.
 
         Note that `get_openlineage_database_specific_lineage` is usually called after task's execution,
         so if multiple query IDs are present, both START and COMPLETE event for each query will be emitted
@@ -641,13 +640,22 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi
         )
 
         if not self.query_ids:
-            self.log.debug("openlineage: no snowflake query ids found.")
+            self.log.info("OpenLineage could not find Snowflake query ids.")
             return None
 
         self.log.debug("openlineage: getting connection to get database info")
         connection = self.get_connection(self.get_conn_id())
         namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
 
+        self.log.info("Separate OpenLineage events will be emitted for each query_id.")
+        emit_openlineage_events_for_snowflake_queries(
+            task_instance=task_instance,
+            hook=self,
+            query_ids=self.query_ids,
+            query_for_extra_metadata=True,
+            query_source_namespace=namespace,
+        )
+
         if len(self.query_ids) == 1:
             self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
             return OperatorLineage(
@@ -658,20 +666,4 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi
                 }
             )
 
-        self.log.info("Multiple query_ids found. Separate OpenLineage event will be emitted for each query.")
-        try:
-            from airflow.providers.openlineage.utils.utils import should_use_external_connection
-
-            use_external_connection = should_use_external_connection(self)
-        except ImportError:
-            # OpenLineage provider release < 1.8.0 - we always use connection
-            use_external_connection = True
-
-        emit_openlineage_events_for_snowflake_queries(
-            query_ids=self.query_ids,
-            query_source_namespace=namespace,
-            task_instance=task_instance,
-            hook=self if use_external_connection else None,
-        )
-
         return None
diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
index 0e3acf5d2542a..aa954c0444eac 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
@@ -19,7 +19,7 @@
 import datetime
 import logging
 from contextlib import closing
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from urllib.parse import quote, urlparse, urlunparse
 
 from airflow.providers.common.compat.openlineage.check import require_openlineage_version
@@ -31,6 +31,7 @@
     from openlineage.client.facet_v2 import JobFacet
 
     from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
+    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
 
 log = logging.getLogger(__name__)
 
@@ -204,9 +205,29 @@ def _run_single_query_with_hook(hook: SnowflakeHook, sql: str) -> list[dict]:
     return result
 
 
+def _run_single_query_with_api_hook(hook: SnowflakeSqlApiHook, sql: str) -> list[dict[str, Any]]:
+    """Execute a query against the Snowflake API without adding extra logging or instrumentation."""
+    # `hook.execute_query` resets `hook.query_ids`, so we need to save them and re-assign them when done
+    query_ids_before_execution = list(hook.query_ids)
+    try:
+        _query_ids = hook.execute_query(sql=sql, statement_count=0)
+        hook.wait_for_query(query_id=_query_ids[0], raise_error=True, poll_interval=1, timeout=3)
+        return hook.get_result_from_successful_sql_api_query(query_id=_query_ids[0])
+    finally:
+        hook.query_ids = query_ids_before_execution
+
+
+def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert 'START_TIME' and 'END_TIME' fields to UTC datetime objects."""
+    for row in data:
+        for key in ("START_TIME", "END_TIME"):
+            row[key] = datetime.datetime.fromtimestamp(float(row[key]), timezone.utc)
+    return data
+
+
 def _get_queries_details_from_snowflake(
-    hook: SnowflakeHook, query_ids: list[str]
-) -> dict[str, dict[str, str]]:
+    hook: SnowflakeHook | SnowflakeSqlApiHook, query_ids: list[str]
+) -> dict[str, dict[str, Any]]:
     """Retrieve execution details for specific queries from Snowflake's query history."""
     if not query_ids:
         return {}
@@ -221,7 +242,16 @@
         f";"
     )
 
-    result = _run_single_query_with_hook(hook=hook, sql=query)
+    try:
+        # We can't import the SnowflakeSqlApiHook class and do a proper isinstance check - circular imports
+        if hook.__class__.__name__ == "SnowflakeSqlApiHook":
+            result = _run_single_query_with_api_hook(hook=hook, sql=query)  # type: ignore[arg-type]
+            result = _process_data_from_api(data=result)
+        else:
+            result = _run_single_query_with_hook(hook=hook, sql=query)
+    except Exception as e:
+        log.warning("OpenLineage could not retrieve extra metadata from Snowflake. Error encountered: %s", e)
+        result = []
 
     return {row["QUERY_ID"]: row for row in result} if result else {}
 
@@ -259,17 +289,18 @@
 
 @require_openlineage_version(provider_min_version="2.3.0")
 def emit_openlineage_events_for_snowflake_queries(
-    query_ids: list[str],
-    query_source_namespace: str,
     task_instance,
-    hook: SnowflakeHook | None = None,
+    hook: SnowflakeHook | SnowflakeSqlApiHook | None = None,
+    query_ids: list[str] | None = None,
+    query_source_namespace: str | None = None,
+    query_for_extra_metadata: bool = False,
     additional_run_facets: dict | None = None,
     additional_job_facets: dict | None = None,
 ) -> None:
     """
     Emit OpenLineage events for executed Snowflake queries.
 
-    Metadata retrieval from Snowflake is attempted only if a `SnowflakeHook` is provided.
+    Metadata retrieval from Snowflake is attempted only if `query_for_extra_metadata` is True and a hook is provided.
     If metadata is available, execution details such as start time, end time, execution status,
     error messages, and SQL text are included in the events. If no metadata is found, the function
     defaults to using the Airflow task instance's state and the current timestamp.
@@ -279,10 +310,16 @@
     will correspond to actual query execution times.
 
     Args:
-        query_ids: A list of Snowflake query IDs to emit events for.
-        query_source_namespace: The namespace to be included in ExternalQueryRunFacet.
         task_instance: The Airflow task instance that run these queries.
-        hook: A SnowflakeHook instance used to retrieve query metadata if available.
+        hook: A supported Snowflake hook instance used to retrieve query metadata if available.
+            If omitted, `query_ids` and `query_source_namespace` must be provided explicitly and
+            `query_for_extra_metadata` must be `False`.
+        query_ids: A list of Snowflake query IDs to emit events for; can be None only if `hook` is provided
+            and `hook.query_ids` are present.
+        query_source_namespace: The namespace to be included in ExternalQueryRunFacet;
+            can be `None` only if `hook` is provided.
+        query_for_extra_metadata: Whether to query Snowflake for additional metadata about queries.
+            Must be `False` if `hook` is not provided.
         additional_run_facets: Additional run facets to include in OpenLineage events.
        additional_job_facets: Additional job facets to include in OpenLineage events.
""" @@ -297,23 +334,49 @@ def emit_openlineage_events_for_snowflake_queries( from airflow.providers.openlineage.conf import namespace from airflow.providers.openlineage.plugins.listener import get_openlineage_listener - if not query_ids: - log.debug("No Snowflake query IDs provided; skipping OpenLineage event emission.") - return - - query_ids = [q for q in query_ids] # Make a copy to make sure it does not change + log.info("OpenLineage will emit events for Snowflake queries.") if hook: + if not query_ids: + log.debug("No Snowflake query IDs provided; Checking `hook.query_ids` property.") + query_ids = getattr(hook, "query_ids", []) + if not query_ids: + raise ValueError("No Snowflake query IDs provided and `hook.query_ids` are not present.") + + if not query_source_namespace: + log.debug("No Snowflake query namespace provided; Creating one from scratch.") + from airflow.providers.openlineage.sqlparser import SQLParser + + connection = hook.get_connection(hook.get_conn_id()) + query_source_namespace = SQLParser.create_namespace( + hook.get_openlineage_database_info(connection) + ) + else: + if not query_ids: + raise ValueError("If 'hook' is not provided, 'query_ids' must be set.") + if not query_source_namespace: + raise ValueError("If 'hook' is not provided, 'query_source_namespace' must be set.") + if query_for_extra_metadata: + raise ValueError("If 'hook' is not provided, 'query_for_extra_metadata' must be False.") + + query_ids = [q for q in query_ids] # Make a copy to make sure we do not change hook's attribute + + if query_for_extra_metadata and hook: log.debug("Retrieving metadata for %s queries from Snowflake.", len(query_ids)) snowflake_metadata = _get_queries_details_from_snowflake(hook, query_ids) else: - log.debug("SnowflakeHook not provided. No extra metadata fill be fetched from Snowflake.") + log.debug("`query_for_extra_metadata` is False. No extra metadata fill be fetched from Snowflake.") snowflake_metadata = {} # If real metadata is unavailable, we send events with eventTime=now default_event_time = timezone.utcnow() # If no query metadata is provided, we use task_instance's state when checking for success - default_state = task_instance.state.value if hasattr(task_instance, "state") else "" + # ti.state has no `value` attr (AF2) when task it's still running, in AF3 we get 'running', in that case + # assuming it's user call and query succeeded, so we replace it with success. 
+    default_state = (
+        getattr(task_instance.state, "value", "running") if hasattr(task_instance, "state") else ""
+    )
+    default_state = "success" if default_state == "running" else default_state
 
     common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
     common_job_facets: dict[str, JobFacet] = {
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index e28d0bca28e7d..fe21b3cad0871 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -905,14 +905,17 @@ def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
         assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema"
         mock_get_first.assert_not_called()
 
-    def test_get_openlineage_database_specific_lineage_with_no_query_ids(self):
+    @mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
+    def test_get_openlineage_database_specific_lineage_with_no_query_ids(self, mock_emit):
         hook = SnowflakeHook(snowflake_conn_id="test_conn")
         assert hook.query_ids == []
 
         result = hook.get_openlineage_database_specific_lineage(None)
+        mock_emit.assert_not_called()
         assert result is None
 
-    def test_get_openlineage_database_specific_lineage_with_single_query_id(self):
+    @mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
+    def test_get_openlineage_database_specific_lineage_with_single_query_id(self, mock_emit):
         from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet
         from airflow.providers.openlineage.extractors import OperatorLineage
 
@@ -921,21 +924,26 @@ def test_get_openlineage_database_specific_lineage_with_single_query_id(self):
         hook.get_connection = mock.MagicMock()
         hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")
 
-        result = hook.get_openlineage_database_specific_lineage(None)
+        ti = mock.MagicMock()
+
+        result = hook.get_openlineage_database_specific_lineage(ti)
+        mock_emit.assert_called_once_with(
+            **{
+                "hook": hook,
+                "query_ids": ["query1"],
+                "query_source_namespace": "scheme://auth",
+                "task_instance": ti,
+                "query_for_extra_metadata": True,
+            }
+        )
         assert result == OperatorLineage(
             run_facets={
                 "externalQuery": ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")
             }
         )
 
-    @pytest.mark.parametrize("use_external_connection", [True, False])
-    @mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection")
     @mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
-    def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(
-        self, mock_emit, mock_use_conn, use_external_connection
-    ):
-        mock_use_conn.return_value = use_external_connection
-
+    def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(self, mock_emit):
         hook = SnowflakeHook(snowflake_conn_id="test_conn")
         hook.query_ids = ["query1", "query2"]
         hook.get_connection = mock.MagicMock()
@@ -944,23 +952,19 @@ def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(
         ti = mock.MagicMock()
 
         result = hook.get_openlineage_database_specific_lineage(ti)
-        mock_use_conn.assert_called_once()
         mock_emit.assert_called_once_with(
             **{
-                "hook": hook if use_external_connection else None,
+                "hook": hook,
                 "query_ids": ["query1", "query2"],
"query_source_namespace": "scheme://auth", "task_instance": ti, + "query_for_extra_metadata": True, } ) assert result is None - # emit_openlineage_events_for_snowflake_queries requires OL provider 2.0.0 @mock.patch("importlib.metadata.version", return_value="1.99.0") - @mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection") - def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider( - self, mock_use_conn, mock_version - ): + def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider(self, mock_version): hook = SnowflakeHook(snowflake_conn_id="test_conn") hook.query_ids = ["query1", "query2"] hook.get_connection = mock.MagicMock() @@ -972,7 +976,6 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider ) with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err): hook.get_openlineage_database_specific_lineage(mock.MagicMock()) - mock_use_conn.assert_called_once() @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet") @mock.patch("snowflake.snowpark.Session.builder") diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py index bba8d317032c2..ec83cabc4fbb5 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py @@ -17,6 +17,7 @@ from __future__ import annotations import copy +import datetime from unittest import mock import pytest @@ -31,11 +32,14 @@ ) from airflow.providers.openlineage.conf import namespace from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook from airflow.providers.snowflake.utils.openlineage import ( _create_snowflake_event_pair, _get_ol_run_id, _get_parent_run_facet, _get_queries_details_from_snowflake, + _process_data_from_api, + _run_single_query_with_api_hook, _run_single_query_with_hook, emit_openlineage_events_for_snowflake_queries, fix_account_name, @@ -164,6 +168,47 @@ def test_get_parent_run_facet(): assert result.job.name == "dag_id.task_id" +def test_process_data_from_api(): + data = [ + { + "QUERY_ID": "ABC", + "EXECUTION_STATUS": "SUCCESS", + "START_TIME": "1750245171.326000", + "END_TIME": 1750245171.387000, + "QUERY_TEXT": "SELECT * FROM test_table;", + "ERROR_CODE": None, + "ERROR_MESSAGE": None, + }, + { + "START_TIME": 1750245171.326000, + "END_TIME": "1750245171.387000", + }, + ] + expected_details = [ + { + "QUERY_ID": "ABC", + "EXECUTION_STATUS": "SUCCESS", + "START_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 326000, tzinfo=datetime.timezone.utc), + "END_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 387000, tzinfo=datetime.timezone.utc), + "QUERY_TEXT": "SELECT * FROM test_table;", + "ERROR_CODE": None, + "ERROR_MESSAGE": None, + }, + { + "START_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 326000, tzinfo=datetime.timezone.utc), + "END_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 387000, tzinfo=datetime.timezone.utc), + }, + ] + result = _process_data_from_api(data=data) + assert len(result) == 2 + assert result == expected_details + + +def test_process_data_from_api_error(): + with pytest.raises(KeyError): + _process_data_from_api(data=[{"START_TIME": "1750245171.326000"}]) + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") 
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.set_autocommit") @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_cursor") @@ -180,6 +225,58 @@ def test_run_single_query_with_hook(mock_get_cursor, mock_set_autocommit, mock_g assert result == [{"col1": "value1"}, {"col2": "value2"}] +@mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_result_from_successful_sql_api_query" +) +@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.wait_for_query") +@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query") +def test_run_single_query_with_api_hook_success(mock_execute, mock_wait, mock_get_result): + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + hook.query_ids = ["old-id"] + + # Simulate that execute_query overwrites hook.query_ids + def execute_query_side_effect(*args, **kwargs): + hook.query_ids = ["overwritten-id"] + return ["new-id"] + + mock_execute.side_effect = execute_query_side_effect + mock_get_result.return_value = [{"col": "value"}] + + result = _run_single_query_with_api_hook(hook, "SELECT 1") + + assert result == [{"col": "value"}] + mock_execute.assert_called_once_with(sql="SELECT 1", statement_count=0) + mock_wait.assert_called_once_with(query_id="new-id", raise_error=True, poll_interval=1, timeout=3) + mock_get_result.assert_called_once_with(query_id="new-id") + assert hook.query_ids == ["old-id"] + + +@mock.patch( + "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_result_from_successful_sql_api_query" +) +@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.wait_for_query") +@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query") +def test_run_single_query_exception_restores_query_ids(mock_execute, mock_wait, mock_get_result): + hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn") + hook.query_ids = ["persistent-id"] + + # Simulate that execute_query overwrites hook.query_ids + def execute_query_side_effect(*args, **kwargs): + hook.query_ids = [] + return ["new-id"] + + mock_execute.side_effect = execute_query_side_effect + mock_wait.side_effect = RuntimeError("execution failed") + + with pytest.raises(RuntimeError, match="execution failed"): + _run_single_query_with_api_hook(hook, "SELECT 1") + + assert hook.query_ids == ["persistent-id"] + mock_execute.assert_called_once_with(sql="SELECT 1", statement_count=0) + mock_wait.assert_called_once_with(query_id="new-id", raise_error=True, poll_interval=1, timeout=3) + mock_get_result.assert_not_called() + + def test_get_queries_details_from_snowflake_empty_query_ids(): details = _get_queries_details_from_snowflake(None, []) assert details == {} @@ -187,7 +284,7 @@ def test_get_queries_details_from_snowflake_empty_query_ids(): @mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook") def test_get_queries_details_from_snowflake_single_query(mock_run_single_query): - hook = mock.MagicMock() + hook = SnowflakeHook(snowflake_conn_id="test_conn") query_ids = ["ABC"] fake_result = [ { @@ -212,9 +309,46 @@ def test_get_queries_details_from_snowflake_single_query(mock_run_single_query): assert details == {"ABC": fake_result[0]} +@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_api_hook") +def test_get_queries_details_from_snowflake_single_query_api_hook(mock_run_single_query_api): + hook = 
+    query_ids = ["ABC"]
+    fake_result = [
+        {
+            "QUERY_ID": "ABC",
+            "EXECUTION_STATUS": "SUCCESS",
+            "START_TIME": "1750245171.326000",
+            "END_TIME": "1750245171.387000",
+            "QUERY_TEXT": "SELECT * FROM test_table;",
+            "ERROR_CODE": None,
+            "ERROR_MESSAGE": None,
+        }
+    ]
+    mock_run_single_query_api.return_value = fake_result
+
+    details = _get_queries_details_from_snowflake(hook, query_ids)
+
+    expected_query = (
+        "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE "
+        "FROM table(information_schema.query_history()) "
+        "WHERE QUERY_ID = 'ABC';"
+    )
+    expected_details = {
+        "QUERY_ID": "ABC",
+        "EXECUTION_STATUS": "SUCCESS",
+        "START_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 326000, tzinfo=datetime.timezone.utc),
+        "END_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 387000, tzinfo=datetime.timezone.utc),
+        "QUERY_TEXT": "SELECT * FROM test_table;",
+        "ERROR_CODE": None,
+        "ERROR_MESSAGE": None,
+    }
+    mock_run_single_query_api.assert_called_once_with(hook=hook, sql=expected_query)
+    assert details == {"ABC": expected_details}
+
+
 @mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook")
 def test_get_queries_details_from_snowflake_multiple_queries(mock_run_single_query):
-    hook = mock.MagicMock()
+    hook = SnowflakeHook(snowflake_conn_id="test_conn")
     query_ids = ["ABC", "DEF"]
     fake_result = [
         {
@@ -250,9 +384,67 @@ def test_get_queries_details_from_snowflake_multiple_queries(mock_run_single_que
     assert details == {row["QUERY_ID"]: row for row in fake_result}
 
 
+@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_api_hook")
+def test_get_queries_details_from_snowflake_multiple_queries_api_hook(mock_run_single_query_api):
+    hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+    query_ids = ["ABC", "DEF"]
+    fake_result = [
+        {
+            "QUERY_ID": "ABC",
+            "EXECUTION_STATUS": "SUCCESS",
+            "START_TIME": "1750245171.326000",
+            "END_TIME": "1750245171.387000",
+            "QUERY_TEXT": "SELECT * FROM table1;",
+            "ERROR_CODE": None,
+            "ERROR_MESSAGE": None,
+        },
+        {
+            "QUERY_ID": "DEF",
+            "EXECUTION_STATUS": "FAILED",
+            "START_TIME": "1750245171.326000",
+            "END_TIME": "1750245171.387000",
+            "QUERY_TEXT": "SELECT * FROM table2;",
+            "ERROR_CODE": "123",
+            "ERROR_MESSAGE": "Some error",
+        },
+    ]
+    mock_run_single_query_api.return_value = fake_result
+
+    details = _get_queries_details_from_snowflake(hook, query_ids)
+
+    expected_query_condition = f"IN {tuple(query_ids)}"
+    expected_query = (
+        "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE "
+        "FROM table(information_schema.query_history()) "
+        f"WHERE QUERY_ID {expected_query_condition};"
+    )
+    expected_details = [
+        {
+            "QUERY_ID": "ABC",
+            "EXECUTION_STATUS": "SUCCESS",
+            "START_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 326000, tzinfo=datetime.timezone.utc),
+            "END_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 387000, tzinfo=datetime.timezone.utc),
+            "QUERY_TEXT": "SELECT * FROM table1;",
+            "ERROR_CODE": None,
+            "ERROR_MESSAGE": None,
+        },
+        {
+            "QUERY_ID": "DEF",
+            "EXECUTION_STATUS": "FAILED",
+            "START_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 326000, tzinfo=datetime.timezone.utc),
+            "END_TIME": datetime.datetime(2025, 6, 18, 11, 12, 51, 387000, tzinfo=datetime.timezone.utc),
+            "QUERY_TEXT": "SELECT * FROM table2;",
+            "ERROR_CODE": "123",
+            "ERROR_MESSAGE": "Some error",
+        },
+    ]
+    mock_run_single_query_api.assert_called_once_with(hook=hook, sql=expected_query)
+    assert details == {row["QUERY_ID"]: row for row in expected_details}
+
+
 @mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook")
 def test_get_queries_details_from_snowflake_no_data_found(mock_run_single_query):
-    hook = mock.MagicMock()
+    hook = SnowflakeHook(snowflake_conn_id="test_conn")
     query_ids = ["ABC", "DEF"]
     mock_run_single_query.return_value = []
 
@@ -268,6 +460,83 @@ def test_get_queries_details_from_snowflake_no_data_found(mock_run_single_query)
     assert details == {}
 
 
+@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_api_hook")
+def test_get_queries_details_from_snowflake_no_data_found_api_hook(mock_run_single_query_api):
+    hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+    query_ids = ["ABC", "DEF"]
+    mock_run_single_query_api.return_value = []
+
+    details = _get_queries_details_from_snowflake(hook, query_ids)
+
+    expected_query_condition = f"IN {tuple(query_ids)}"
+    expected_query = (
+        "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE "
+        "FROM table(information_schema.query_history()) "
+        f"WHERE QUERY_ID {expected_query_condition};"
+    )
+    mock_run_single_query_api.assert_called_once_with(hook=hook, sql=expected_query)
+    assert details == {}
+
+
+@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook")
+def test_get_queries_details_from_snowflake_error(mock_run_single_query):
+    hook = SnowflakeHook(snowflake_conn_id="test_conn")
+    query_ids = ["ABC", "DEF"]
+    mock_run_single_query.side_effect = ValueError("Query failure")
+
+    details = _get_queries_details_from_snowflake(hook, query_ids)
+
+    expected_query_condition = f"IN {tuple(query_ids)}"
+    expected_query = (
+        "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE "
+        "FROM table(information_schema.query_history()) "
+        f"WHERE QUERY_ID {expected_query_condition};"
+    )
+    mock_run_single_query.assert_called_once_with(hook=hook, sql=expected_query)
+    assert details == {}
+
+
+@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_api_hook")
+def test_get_queries_details_from_snowflake_error_api_hook(mock_run_single_query_api):
+    hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+    query_ids = ["ABC", "DEF"]
+    mock_run_single_query_api.side_effect = ValueError("Query failure")
+
+    details = _get_queries_details_from_snowflake(hook, query_ids)
+
+    expected_query_condition = f"IN {tuple(query_ids)}"
+    expected_query = (
+        "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE "
+        "FROM table(information_schema.query_history()) "
+        f"WHERE QUERY_ID {expected_query_condition};"
+    )
+    mock_run_single_query_api.assert_called_once_with(hook=hook, sql=expected_query)
+    assert details == {}
+
+
+@mock.patch("airflow.providers.snowflake.utils.openlineage._process_data_from_api")
+@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_api_hook")
+def test_get_queries_details_from_snowflake_error_api_hook_process_data(
+    mock_run_single_query_api, mock_process_data
+):
+    hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+    query_ids = ["ABC", "DEF"]
+    mock_run_single_query_api.return_value = ["some_data"]
+    mock_process_data.side_effect = ValueError("Processing failure")
+
+    details = _get_queries_details_from_snowflake(hook, query_ids)
+
expected_query_condition = f"IN {tuple(query_ids)}" + expected_query = ( + "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE " + "FROM table(information_schema.query_history()) " + f"WHERE QUERY_ID {expected_query_condition};" + ) + mock_run_single_query_api.assert_called_once_with(hook=hook, sql=expected_query) + mock_process_data.assert_called_once_with(data=["some_data"]) + assert details == {} + + @pytest.mark.parametrize("is_successful", [True, False]) @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_create_snowflake_event_pair_success(mock_generate_uuid, is_successful): @@ -310,7 +579,9 @@ def test_create_snowflake_event_pair_success(mock_generate_uuid, is_successful): @mock.patch("importlib.metadata.version", return_value="2.3.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") @mock.patch("airflow.utils.timezone.utcnow") -def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_generate_uuid, mock_version): +def test_emit_openlineage_events_for_snowflake_queries_with_extra_metadata( + mock_now, mock_generate_uuid, mock_version +): fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" mock_generate_uuid.return_value = fake_uuid @@ -373,7 +644,8 @@ def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_ query_ids=query_ids, query_source_namespace="snowflake_ns", task_instance=mock_ti, - hook=mock.MagicMock(), # any non-None hook to trigger metadata retrieval + hook=mock.MagicMock(), + query_for_extra_metadata=True, additional_run_facets=additional_run_facets, additional_job_facets=additional_job_facets, ) @@ -548,7 +820,7 @@ def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_ @mock.patch("importlib.metadata.version", return_value="2.3.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") @mock.patch("airflow.utils.timezone.utcnow") -def test_emit_openlineage_events_for_snowflake_queries_without_hook( +def test_emit_openlineage_events_for_snowflake_queries_without_extra_metadata( mock_now, mock_generate_uuid, mock_version ): fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" @@ -589,7 +861,8 @@ def test_emit_openlineage_events_for_snowflake_queries_without_hook( query_ids=query_ids, query_source_namespace="snowflake_ns", task_instance=mock_ti, - hook=None, # None so metadata retrieval is not triggered + hook=mock.MagicMock(), + # query_for_extra_metadata=False, # False by default additional_run_facets=additional_run_facets, additional_job_facets=additional_job_facets, ) @@ -664,9 +937,36 @@ def test_emit_openlineage_events_for_snowflake_queries_without_hook( @mock.patch("importlib.metadata.version", return_value="2.3.0") -def test_emit_openlineage_events_for_snowflake_queries_without_query_ids(mock_version): - query_ids = [] - original_query_ids = copy.deepcopy(query_ids) +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_snowflake_queries_without_query_ids( + mock_now, mock_generate_uuid, mock_version +): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + hook = mock.MagicMock() + hook.query_ids = ["query1"] + original_query_ids = copy.deepcopy(hook.query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + 
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state=TaskInstanceState.RUNNING,  # "running" is mapped to "success" as the default query state
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
 
     fake_adapter = mock.MagicMock()
     fake_adapter.emit = mock.MagicMock()
@@ -678,11 +978,395 @@ def test_emit_openlineage_events_for_snowflake_queries_without_query_ids(mock_ve
         return_value=fake_listener,
     ):
         emit_openlineage_events_for_snowflake_queries(
-            query_ids=query_ids,
+            query_ids=[],
             query_source_namespace="snowflake_ns",
-            task_instance=None,
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False, # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
         )
 
+    assert hook.query_ids == original_query_ids  # Verify that the input query_ids list is unchanged.
+    assert fake_adapter.emit.call_count == 2  # Expect two events per query.
+
+    expected_common_job_facets = {
+        "jobType": job_type_job.JobTypeJobFacet(
+            jobType="QUERY",
+            processingType="BATCH",
+            integration="SNOWFLAKE",
+        ),
+        "custom_job": "value_job",
+    }
+    expected_common_run_facets = {
+        "parent": parent_run.ParentRunFacet(
+            run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+            job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"),
+            root=parent_run.Root(
+                run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                job=parent_run.RootJob(namespace=namespace(), name="dag_id"),
+            ),
+        ),
+        "custom_run": "value_run",
+    }
+
+    expected_calls = [
+        mock.call(  # Query1: START event (no metadata)
+            RunEvent(
+                eventTime=default_event_time.isoformat(),
+                eventType=RunState.START,
+                run=Run(
+                    runId=fake_uuid,
+                    facets={
+                        "externalQuery": ExternalQueryRunFacet(
+                            externalQueryId="query1", source="snowflake_ns"
+                        ),
+                        **expected_common_run_facets,
+                    },
+                ),
+                job=Job(
+                    namespace=namespace(),
+                    name="dag_id.task_id.query.1",
+                    facets=expected_common_job_facets,
+                ),
+            )
+        ),
+        mock.call(  # Query1: COMPLETE event (no metadata)
+            RunEvent(
+                eventTime=default_event_time.isoformat(),
+                eventType=RunState.COMPLETE,
+                run=Run(
+                    runId=fake_uuid,
+                    facets={
+                        "externalQuery": ExternalQueryRunFacet(
+                            externalQueryId="query1", source="snowflake_ns"
+                        ),
+                        **expected_common_run_facets,
+                    },
+                ),
+                job=Job(
+                    namespace=namespace(),
+                    name="dag_id.task_id.query.1",
+                    facets=expected_common_job_facets,
+                ),
+            )
+        ),
+    ]
+
+    assert fake_adapter.emit.call_args_list == expected_calls
+
+
+@mock.patch("airflow.providers.openlineage.sqlparser.SQLParser.create_namespace", return_value="snowflake_ns")
+@mock.patch("importlib.metadata.version", return_value="2.3.0")
+@mock.patch("openlineage.client.uuid.generate_new_uuid")
+@mock.patch("airflow.utils.timezone.utcnow")
+def test_emit_openlineage_events_for_snowflake_queries_without_query_ids_and_namespace(
+    mock_now, mock_generate_uuid, mock_version, mock_parser
+):
+    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+    mock_generate_uuid.return_value = fake_uuid
+
+    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+    mock_now.return_value = default_event_time
+
+    hook = mock.MagicMock()
+    hook.query_ids = ["query1"]
+    original_query_ids = copy.deepcopy(hook.query_ids)
+    logical_date = timezone.datetime(2025, 1, 1)
+    mock_ti = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state="running",  # "running" is mapped to "success" as the default query state
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        emit_openlineage_events_for_snowflake_queries(
+            query_ids=[],
+            query_source_namespace=None,
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False, # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
+        )
+
+    assert hook.query_ids == original_query_ids  # Verify that the input query_ids list is unchanged.
+    assert fake_adapter.emit.call_count == 2  # Expect two events per query.
+
+    expected_common_job_facets = {
+        "jobType": job_type_job.JobTypeJobFacet(
+            jobType="QUERY",
+            processingType="BATCH",
+            integration="SNOWFLAKE",
+        ),
+        "custom_job": "value_job",
+    }
+    expected_common_run_facets = {
+        "parent": parent_run.ParentRunFacet(
+            run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+            job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"),
+            root=parent_run.Root(
+                run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                job=parent_run.RootJob(namespace=namespace(), name="dag_id"),
+            ),
+        ),
+        "custom_run": "value_run",
+    }
+
+    expected_calls = [
+        mock.call(  # Query1: START event (no metadata)
+            RunEvent(
+                eventTime=default_event_time.isoformat(),
+                eventType=RunState.START,
+                run=Run(
+                    runId=fake_uuid,
+                    facets={
+                        "externalQuery": ExternalQueryRunFacet(
+                            externalQueryId="query1", source="snowflake_ns"
+                        ),
+                        **expected_common_run_facets,
+                    },
+                ),
+                job=Job(
+                    namespace=namespace(),
+                    name="dag_id.task_id.query.1",
+                    facets=expected_common_job_facets,
+                ),
+            )
+        ),
+        mock.call(  # Query1: COMPLETE event (no metadata)
+            RunEvent(
+                eventTime=default_event_time.isoformat(),
+                eventType=RunState.COMPLETE,
+                run=Run(
+                    runId=fake_uuid,
+                    facets={
+                        "externalQuery": ExternalQueryRunFacet(
+                            externalQueryId="query1", source="snowflake_ns"
+                        ),
+                        **expected_common_run_facets,
+                    },
+                ),
+                job=Job(
+                    namespace=namespace(),
+                    name="dag_id.task_id.query.1",
+                    facets=expected_common_job_facets,
+                ),
+            )
+        ),
+    ]
+
+    assert fake_adapter.emit.call_args_list == expected_calls
+
+
+@mock.patch("importlib.metadata.version", return_value="2.3.0")
+@mock.patch("openlineage.client.uuid.generate_new_uuid")
+@mock.patch("airflow.utils.timezone.utcnow")
+def test_emit_openlineage_events_for_snowflake_queries_with_query_ids_and_hook_query_ids(
+    mock_now, mock_generate_uuid, mock_version
+):
+    fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
+    mock_generate_uuid.return_value = fake_uuid
+
+    default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0)
+    mock_now.return_value = default_event_time
+
+    hook = mock.MagicMock()
+    hook.query_ids = ["query1"]
+    original_query_ids = copy.deepcopy(hook.query_ids)
+    logical_date = timezone.datetime(2025, 1, 1)
+    mock_ti = mock.MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state="running",  # "running" is mapped to "success" as the default query state
+        dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0),
+    )
+    mock_ti.get_template_context.return_value = {
+        "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0)
+    }
+
+    additional_run_facets = {"custom_run": "value_run"}
+    additional_job_facets = {"custom_job": "value_job"}
+
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        emit_openlineage_events_for_snowflake_queries(
+            query_ids=["query2"],
+            query_source_namespace="snowflake_ns",
+            task_instance=mock_ti,
+            hook=hook,
+            # query_for_extra_metadata=False, # False by default
+            additional_run_facets=additional_run_facets,
+            additional_job_facets=additional_job_facets,
+        )
+
+    assert hook.query_ids == original_query_ids  # Verify that the input query_ids list is unchanged.
+    assert fake_adapter.emit.call_count == 2  # Expect two events per query.
+
+    expected_common_job_facets = {
+        "jobType": job_type_job.JobTypeJobFacet(
+            jobType="QUERY",
+            processingType="BATCH",
+            integration="SNOWFLAKE",
+        ),
+        "custom_job": "value_job",
+    }
+    expected_common_run_facets = {
+        "parent": parent_run.ParentRunFacet(
+            run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"),
+            job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"),
+            root=parent_run.Root(
+                run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"),
+                job=parent_run.RootJob(namespace=namespace(), name="dag_id"),
+            ),
+        ),
+        "custom_run": "value_run",
+    }
+
+    expected_calls = [
+        mock.call(  # Query1: START event (no metadata)
+            RunEvent(
+                eventTime=default_event_time.isoformat(),
+                eventType=RunState.START,
+                run=Run(
+                    runId=fake_uuid,
+                    facets={
+                        "externalQuery": ExternalQueryRunFacet(
+                            externalQueryId="query2", source="snowflake_ns"
+                        ),
+                        **expected_common_run_facets,
+                    },
+                ),
+                job=Job(
+                    namespace=namespace(),
+                    name="dag_id.task_id.query.1",
+                    facets=expected_common_job_facets,
+                ),
+            )
+        ),
+        mock.call(  # Query1: COMPLETE event (no metadata)
+            RunEvent(
+                eventTime=default_event_time.isoformat(),
+                eventType=RunState.COMPLETE,
+                run=Run(
+                    runId=fake_uuid,
+                    facets={
+                        "externalQuery": ExternalQueryRunFacet(
+                            externalQueryId="query2", source="snowflake_ns"
+                        ),
+                        **expected_common_run_facets,
+                    },
+                ),
+                job=Job(
+                    namespace=namespace(),
+                    name="dag_id.task_id.query.1",
+                    facets=expected_common_job_facets,
+                ),
+            )
+        ),
+    ]
+
+    assert fake_adapter.emit.call_args_list == expected_calls
+
+
+@mock.patch("importlib.metadata.version", return_value="2.3.0")
+def test_emit_openlineage_events_for_snowflake_queries_missing_query_ids_and_hook(mock_version):
+    fake_adapter = mock.MagicMock()
+    fake_adapter.emit = mock.MagicMock()
+    fake_listener = mock.MagicMock()
+    fake_listener.adapter = fake_adapter
+
+    with mock.patch(
+        "airflow.providers.openlineage.plugins.listener.get_openlineage_listener",
+        return_value=fake_listener,
+    ):
+        with pytest.raises(ValueError, match="If 'hook' is not provided, 'query_ids' must be set."):
+            emit_openlineage_events_for_snowflake_queries(
+                task_instance=None, query_source_namespace="snowflake_ns", query_for_extra_metadata=False
+            )
+
+    fake_adapter.emit.assert_not_called()  # No events should be emitted
+
+
+@mock.patch("importlib.metadata.version", return_value="2.3.0") +def test_emit_openlineage_events_for_snowflake_queries_missing_query_namespace_and_hook(mock_version): + query_ids = ["1", "2"] + original_query_ids = copy.deepcopy(query_ids) + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + with pytest.raises( + ValueError, match="If 'hook' is not provided, 'query_source_namespace' must be set." + ): + emit_openlineage_events_for_snowflake_queries( + task_instance=None, query_ids=query_ids, query_for_extra_metadata=False + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + fake_adapter.emit.assert_not_called() # No events should be emitted + + +@mock.patch("importlib.metadata.version", return_value="2.3.0") +def test_emit_openlineage_events_for_snowflake_queries_missing_hook_and_query_for_extra_metadata_true( + mock_version, +): + query_ids = ["1", "2"] + original_query_ids = copy.deepcopy(query_ids) + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + with pytest.raises( + ValueError, match="If 'hook' is not provided, 'query_for_extra_metadata' must be False." + ): + emit_openlineage_events_for_snowflake_queries( + task_instance=None, + query_source_namespace="snowflake_ns", + query_ids=query_ids, + query_for_extra_metadata=True, + ) + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. fake_adapter.emit.assert_not_called() # No events should be emitted