diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py index 95d87e2f2b494..83b0ec7b27243 100644 --- a/airflow-core/tests/unit/always/test_project_structure.py +++ b/airflow-core/tests/unit/always/test_project_structure.py @@ -199,6 +199,7 @@ def test_providers_modules_should_have_tests(self): "providers/opensearch/tests/unit/opensearch/test_version_compat.py", "providers/presto/tests/unit/presto/test_version_compat.py", "providers/redis/tests/unit/redis/test_version_compat.py", + "providers/snowflake/tests/unit/snowflake/test_version_compat.py", "providers/snowflake/tests/unit/snowflake/triggers/test_snowflake_trigger.py", "providers/standard/tests/unit/standard/operators/test_branch.py", "providers/standard/tests/unit/standard/operators/test_empty.py", diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 2c5fc991291c8..bd7f7a559adc7 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1297,7 +1297,7 @@ }, "snowflake": { "deps": [ - "apache-airflow-providers-common-compat>=1.1.0", + "apache-airflow-providers-common-compat>=1.6.0", "apache-airflow-providers-common-sql>=1.20.0", "apache-airflow>=2.9.0", "pandas>=2.1.2,<2.2", diff --git a/providers/snowflake/README.rst b/providers/snowflake/README.rst index 88df99d5f5e84..ce5655db4730a 100644 --- a/providers/snowflake/README.rst +++ b/providers/snowflake/README.rst @@ -54,7 +54,7 @@ Requirements PIP package Version required ========================================== ===================================== ``apache-airflow`` ``>=2.9.0`` -``apache-airflow-providers-common-compat`` ``>=1.1.0`` +``apache-airflow-providers-common-compat`` ``>=1.6.0`` ``apache-airflow-providers-common-sql`` ``>=1.20.0`` ``pandas`` ``>=2.1.2,<2.2`` ``pyarrow`` ``>=14.0.1`` diff --git a/providers/snowflake/pyproject.toml b/providers/snowflake/pyproject.toml index c939e2c5f279b..93e4f2f502b6c 100644 --- a/providers/snowflake/pyproject.toml +++ b/providers/snowflake/pyproject.toml @@ -58,7 +58,7 @@ requires-python = "~=3.9" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.9.0", - "apache-airflow-providers-common-compat>=1.1.0", + "apache-airflow-providers-common-compat>=1.6.0", "apache-airflow-providers-common-sql>=1.20.0", # In pandas 2.2 minimal version of the sqlalchemy is 2.0 # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies diff --git a/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py b/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py index 9ef7da6016dc4..d73fb58ea2674 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py +++ b/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py @@ -154,7 +154,7 @@ def get_provider_info(): ], "dependencies": [ "apache-airflow>=2.9.0", - "apache-airflow-providers-common-compat>=1.1.0", + "apache-airflow-providers-common-compat>=1.6.0", "apache-airflow-providers-common-sql>=1.20.0", "pandas>=2.1.2,<2.2", "pyarrow>=14.0.1", diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index f997d20aecdef..092c8be9069a8 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ 
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -544,15 +544,41 @@ def _get_openlineage_authority(self, _) -> str | None:
         uri = fix_snowflake_sqlalchemy_uri(self.get_uri())
         return urlparse(uri).hostname

-    def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None:
+    def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
+        """
+        Generate OpenLineage metadata for a Snowflake task instance based on executed query IDs.
+
+        If a single query ID is present, an `ExternalQueryRunFacet` is attached to the lineage metadata.
+        If multiple query IDs are present, separate OpenLineage events are emitted for each query.
+
+        Note that `get_openlineage_database_specific_lineage` is usually called after the task's execution,
+        so if multiple query IDs are present, both the START and COMPLETE events for each query will be
+        emitted after the task has finished. If we are able to query Snowflake for query execution
+        metadata, the event times will correspond to each query's actual start and finish times.
+
+        Args:
+            task_instance: The Airflow TaskInstance object for which lineage is being collected.
+
+        Returns:
+            An `OperatorLineage` object if a single query ID is found; otherwise `None`.
+        """
         from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet
         from airflow.providers.openlineage.extractors import OperatorLineage
         from airflow.providers.openlineage.sqlparser import SQLParser
+        from airflow.providers.snowflake.utils.openlineage import (
+            emit_openlineage_events_for_snowflake_queries,
+        )

-        if self.query_ids:
-            self.log.debug("openlineage: getting connection to get database info")
-            connection = self.get_connection(self.get_conn_id())
-            namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
+        if not self.query_ids:
+            self.log.debug("openlineage: no snowflake query ids found.")
+            return None
+
+        self.log.debug("openlineage: getting connection to get database info")
+        connection = self.get_connection(self.get_conn_id())
+        namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
+
+        if len(self.query_ids) == 1:
+            self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
             return OperatorLineage(
                 run_facets={
                     "externalQuery": ExternalQueryRunFacet(
@@ -560,4 +586,21 @@ def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None
                     )
                 }
             )
+
+        self.log.info("Multiple query_ids found. Separate OpenLineage events will be emitted for each query.")
+        try:
+            from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+            use_external_connection = should_use_external_connection(self)
+        except ImportError:
+            # OpenLineage provider release < 1.8.0 - we always use connection
+            use_external_connection = True
+
+        emit_openlineage_events_for_snowflake_queries(
+            query_ids=self.query_ids,
+            query_source_namespace=namespace,
+            task_instance=task_instance,
+            hook=self if use_external_connection else None,
+        )
+        return None
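Review note, not part of the patch: the new branching above has three outcomes, sketched below with mocked I/O (`unittest.mock`; the connection id and namespace values are placeholders, and the OpenLineage provider is assumed to be installed). Emission for the multi-query path is delegated to the helper added in the next file.

from unittest import mock

from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

hook = SnowflakeHook(snowflake_conn_id="my_conn")  # hypothetical connection id
hook.get_connection = mock.MagicMock()
hook.get_openlineage_database_info = lambda conn: mock.MagicMock(authority="auth", scheme="snowflake")

hook.query_ids = []  # path 1: no queries ran -> nothing to report
assert hook.get_openlineage_database_specific_lineage(mock.MagicMock()) is None

hook.query_ids = ["qid-1"]  # path 2: single query -> facet returned inline, no events emitted here
assert "externalQuery" in hook.get_openlineage_database_specific_lineage(mock.MagicMock()).run_facets

hook.query_ids = ["qid-1", "qid-2"]  # path 3: multiple queries -> events emitted as a side effect
with mock.patch(
    "airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries"
) as emitter:
    assert hook.get_openlineage_database_specific_lineage(mock.MagicMock()) is None
    emitter.assert_called_once()

diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
index 12821d49f35c3..c5483be651877 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py
@@ -16,8 +16,26 @@
 # under the License.
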
 from __future__ import annotations

+import datetime
+import logging
+from contextlib import closing
+from typing import TYPE_CHECKING
 from urllib.parse import quote, urlparse, urlunparse

+from airflow.providers.common.compat.openlineage.check import require_openlineage_version
+from airflow.providers.snowflake.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
+from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from openlineage.client.event_v2 import RunEvent
+    from openlineage.client.facet_v2 import JobFacet
+
+    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
+
+
+log = logging.getLogger(__name__)
+

 def fix_account_name(name: str) -> str:
     """Fix account name to have the following format: ..."""
@@ -78,3 +96,246 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str:
         hostname = fix_account_name(hostname)
     # else - its new hostname, just return it
     return urlunparse((parts.scheme, hostname, parts.path, parts.params, parts.query, parts.fragment))
+
+
+# todo: move this run_id logic into OpenLineage's listener to avoid differences
+def _get_ol_run_id(task_instance) -> str:
+    """
+    Get OpenLineage run_id from TaskInstance.
+
+    It's crucial that the task_instance's run_id creation logic matches OpenLineage's listener implementation.
+    Only then can we ensure that the generated run_id aligns with the Airflow task,
+    enabling a proper connection between events.
+    """
+    from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
+
+    def _get_logical_date():
+        # todo: remove when min airflow version >= 3.0
+        if AIRFLOW_V_3_0_PLUS:
+            dagrun = task_instance.get_template_context()["dag_run"]
+            return dagrun.logical_date or dagrun.run_after
+
+        if hasattr(task_instance, "logical_date"):
+            date = task_instance.logical_date
+        else:
+            date = task_instance.execution_date
+
+        return date
+
+    def _get_try_number_success():
+        """We run this in _on_complete, so we need to adjust for try_number changes."""
+        # todo: remove when min airflow version >= 2.10.0
+        if AIRFLOW_V_2_10_PLUS:
+            return task_instance.try_number
+        if task_instance.state == TaskInstanceState.SUCCESS:
+            return task_instance.try_number - 1
+        return task_instance.try_number
+
+    # Generate the same OL run id as was generated for the current task instance
+    return OpenLineageAdapter.build_task_instance_run_id(
+        dag_id=task_instance.dag_id,
+        task_id=task_instance.task_id,
+        logical_date=_get_logical_date(),
+        try_number=_get_try_number_success(),
+        map_index=task_instance.map_index,
+    )
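Review note, not part of the patch: the parent linkage relies on `build_task_instance_run_id` being deterministic — recomputing it with the same task coordinates reproduces the run id the OpenLineage listener already used for the task's own events (the fixed UUIDs asserted in the new tests depend on the same property). A minimal sketch with arbitrary example values:

from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
from airflow.utils import timezone

# Same coordinates in -> same run id out; this is what ties child query events
# back to the task's own OpenLineage run. All values here are arbitrary examples.
coords = dict(
    dag_id="example_dag",
    task_id="example_task",
    logical_date=timezone.datetime(2025, 1, 1),
    try_number=1,
    map_index=-1,
)
assert OpenLineageAdapter.build_task_instance_run_id(**coords) == (
    OpenLineageAdapter.build_task_instance_run_id(**coords)
)

+
+
+def _get_parent_run_facet(task_instance):
+    """
+    Retrieve the ParentRunFacet associated with a specific Airflow task instance.
+
+    This facet helps link OpenLineage events of child jobs - such as queries executed within
+    external systems (e.g., Snowflake) by the Airflow task - to the original Airflow task execution.
+    Establishing this connection enables better lineage tracking and observability.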
+ """ + from openlineage.client.facet_v2 import parent_run + + from airflow.providers.openlineage.conf import namespace + + parent_run_id = _get_ol_run_id(task_instance) + + return parent_run.ParentRunFacet( + run=parent_run.Run(runId=parent_run_id), + job=parent_run.Job( + namespace=namespace(), + name=f"{task_instance.dag_id}.{task_instance.task_id}", + ), + ) + + +def _run_single_query_with_hook(hook: SnowflakeHook, sql: str) -> list[dict]: + """Execute a query against Snowflake without adding extra logging or instrumentation.""" + with closing(hook.get_conn()) as conn: + hook.set_autocommit(conn, False) + with hook._get_cursor(conn, return_dictionaries=True) as cur: + cur.execute(sql) + result = cur.fetchall() + conn.commit() + return result + + +def _get_queries_details_from_snowflake( + hook: SnowflakeHook, query_ids: list[str] +) -> dict[str, dict[str, str]]: + """Retrieve execution details for specific queries from Snowflake's query history.""" + if not query_ids: + return {} + query_condition = f"IN {tuple(query_ids)}" if len(query_ids) > 1 else f"= '{query_ids[0]}'" + query = ( + "SELECT " + "QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE " + "FROM " + "table(information_schema.query_history()) " + f"WHERE " + f"QUERY_ID {query_condition}" + f";" + ) + + result = _run_single_query_with_hook(hook=hook, sql=query) + + return {row["QUERY_ID"]: row for row in result} if result else {} + + +def _create_snowflake_event_pair( + job_namespace: str, + job_name: str, + start_time: datetime.datetime, + end_time: datetime.datetime, + is_successful: bool, + run_facets: dict | None = None, + job_facets: dict | None = None, +) -> tuple[RunEvent, RunEvent]: + """Create a pair of OpenLineage RunEvents representing the start and end of a Snowflake job execution.""" + from openlineage.client.event_v2 import Job, Run, RunEvent, RunState + from openlineage.client.uuid import generate_new_uuid + + run = Run(runId=str(generate_new_uuid()), facets=run_facets or {}) + job = Job(namespace=job_namespace, name=job_name, facets=job_facets or {}) + + start = RunEvent( + eventType=RunState.START, + eventTime=start_time.isoformat(), + run=run, + job=job, + ) + end = RunEvent( + eventType=RunState.COMPLETE if is_successful else RunState.FAIL, + eventTime=end_time.isoformat(), + run=run, + job=job, + ) + return start, end + + +@require_openlineage_version(provider_min_version="2.0.0") +def emit_openlineage_events_for_snowflake_queries( + query_ids: list[str], + query_source_namespace: str, + task_instance, + hook: SnowflakeHook | None = None, + additional_run_facets: dict | None = None, + additional_job_facets: dict | None = None, +) -> None: + """ + Emit OpenLineage events for executed Snowflake queries. + + Metadata retrieval from Snowflake is attempted only if a `SnowflakeHook` is provided. + If metadata is available, execution details such as start time, end time, execution status, + error messages, and SQL text are included in the events. If no metadata is found, the function + defaults to using the Airflow task instance's state and the current timestamp. + + Note that both START and COMPLETE event for each query will be emitted at the same time. + If we are able to query Snowflake for query execution metadata, event times + will correspond to actual query execution times. + + Args: + query_ids: A list of Snowflake query IDs to emit events for. + query_source_namespace: The namespace to be included in ExternalQueryRunFacet. 
+        task_instance: The Airflow task instance that ran these queries.
+        hook: A SnowflakeHook instance used to retrieve query metadata if available.
+        additional_run_facets: Additional run facets to include in OpenLineage events.
+        additional_job_facets: Additional job facets to include in OpenLineage events.
+    """
+    from openlineage.client.facet_v2 import job_type_job
+
+    from airflow.providers.common.compat.openlineage.facet import (
+        ErrorMessageRunFacet,
+        ExternalQueryRunFacet,
+        SQLJobFacet,
+    )
+    from airflow.providers.openlineage.conf import namespace
+    from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
+
+    if not query_ids:
+        log.debug("No Snowflake query IDs provided; skipping OpenLineage event emission.")
+        return
+
+    query_ids = [q for q in query_ids]  # Make a copy to make sure it does not change
+
+    if hook:
+        log.debug("Retrieving metadata for %s queries from Snowflake.", len(query_ids))
+        snowflake_metadata = _get_queries_details_from_snowflake(hook, query_ids)
+    else:
+        log.debug("SnowflakeHook not provided. No extra metadata will be fetched from Snowflake.")
+        snowflake_metadata = {}
+
+    # If real metadata is unavailable, we send events with eventTime=now
+    default_event_time = timezone.utcnow()
+    # If no query metadata is provided, we use task_instance's state when checking for success
+    default_state = str(task_instance.state) if hasattr(task_instance, "state") else ""
+
+    common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
+    common_job_facets: dict[str, JobFacet] = {
+        "jobType": job_type_job.JobTypeJobFacet(
+            jobType="QUERY",
+            integration="SNOWFLAKE",
+            processingType="BATCH",
+        )
+    }
+    additional_run_facets = additional_run_facets or {}
+    additional_job_facets = additional_job_facets or {}
+
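Review note, not part of the patch: in the loop below, per-query facets are merged with the common and caller-supplied facets via plain dict unpacking, where the rightmost source wins on key collisions — so an entry in `additional_run_facets`/`additional_job_facets` can deliberately override a computed facet. A pure-Python illustration of that merge order:

# Later unpacks win on duplicate keys, so `additional` (rightmost) overrides
# both the query-specific and common dictionaries.
query_specific = {"externalQuery": "computed per query"}
common = {"parent": "computed once per task"}
additional = {"externalQuery": "caller override"}

merged = {**query_specific, **common, **additional}
assert merged == {"externalQuery": "caller override", "parent": "computed once per task"}

+    events: list[RunEvent] = []
+    for counter, query_id in enumerate(query_ids, 1):
+        query_metadata = snowflake_metadata.get(query_id, {})
+        log.debug(
+            "Metadata for query no. 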
%s, (ID `%s`): `%s`", + counter, + query_id, + query_metadata if query_metadata else "not found", + ) + + query_specific_run_facets = { + "externalQuery": ExternalQueryRunFacet(externalQueryId=query_id, source=query_source_namespace) + } + if query_metadata.get("ERROR_MESSAGE"): + query_specific_run_facets["error"] = ErrorMessageRunFacet( + message=f"{query_metadata.get('ERROR_CODE')} : {query_metadata['ERROR_MESSAGE']}", + programmingLanguage="SQL", + ) + + query_specific_job_facets = {} + if query_metadata.get("QUERY_TEXT"): + query_specific_job_facets["sql"] = SQLJobFacet(query=query_metadata["QUERY_TEXT"]) + + log.debug("Creating OpenLineage event pair for query ID: %s", query_id) + event_batch = _create_snowflake_event_pair( + job_namespace=namespace(), + job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}", + start_time=query_metadata.get("START_TIME", default_event_time), # type: ignore[arg-type] + end_time=query_metadata.get("END_TIME", default_event_time), # type: ignore[arg-type] + # `EXECUTION_STATUS` can be `success`, `fail` or `incident` (Snowflake outage, so still failure) + is_successful=query_metadata.get("EXECUTION_STATUS", default_state).lower() == "success", + run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets}, + job_facets={**query_specific_job_facets, **common_job_facets, **additional_job_facets}, + ) + events.extend(event_batch) + + log.debug("Generated %s OpenLineage events; emitting now.", len(events)) + client = get_openlineage_listener().adapter.get_or_create_openlineage_client() + for event in events: + client.emit(event) + + log.info("OpenLineage has successfully finished processing information about Snowflake queries.") + return diff --git a/providers/snowflake/src/airflow/providers/snowflake/version_compat.py b/providers/snowflake/src/airflow/providers/snowflake/version_compat.py new file mode 100644 index 0000000000000..21e7170194e36 --- /dev/null +++ b/providers/snowflake/src/airflow/providers/snowflake/version_compat.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS COPIED MANUALLY IN OTHER PROVIDERS DELIBERATELY TO AVOID ADDING UNNECESSARY +# DEPENDENCIES BETWEEN PROVIDERS. 
IF YOU WANT TO ADD CONDITIONAL CODE IN YOUR PROVIDER THAT DEPENDS +# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT +# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE +# +from __future__ import annotations + + +def get_base_airflow_version_tuple() -> tuple[int, int, int]: + from packaging.version import Version + + from airflow import __version__ + + airflow_version = Version(__version__) + return airflow_version.major, airflow_version.minor, airflow_version.micro + + +AIRFLOW_V_2_10_PLUS = get_base_airflow_version_tuple() >= (2, 10, 0) +AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index b1a65b4293b66..2897921dc0cf7 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -28,6 +28,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa +from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.models import Connection from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook @@ -721,6 +722,75 @@ def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first): assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema" mock_get_first.assert_not_called() + def test_get_openlineage_database_specific_lineage_with_no_query_ids(self): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + assert hook.query_ids == [] + + result = hook.get_openlineage_database_specific_lineage(None) + assert result is None + + def test_get_openlineage_database_specific_lineage_with_single_query_id(self): + from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet + from airflow.providers.openlineage.extractors import OperatorLineage + + hook = SnowflakeHook(snowflake_conn_id="test_conn") + hook.query_ids = ["query1"] + hook.get_connection = mock.MagicMock() + hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme") + + result = hook.get_openlineage_database_specific_lineage(None) + assert result == OperatorLineage( + run_facets={ + "externalQuery": ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth") + } + ) + + @pytest.mark.parametrize("use_external_connection", [True, False]) + @mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection") + @mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries") + def test_get_openlineage_database_specific_lineage_with_multiple_query_ids( + self, mock_emit, mock_use_conn, use_external_connection + ): + mock_use_conn.return_value = use_external_connection + + hook = SnowflakeHook(snowflake_conn_id="test_conn") + hook.query_ids = ["query1", "query2"] + hook.get_connection = mock.MagicMock() + hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme") + + ti = mock.MagicMock() + + result = hook.get_openlineage_database_specific_lineage(ti) + mock_use_conn.assert_called_once() + mock_emit.assert_called_once_with( + **{ + "hook": hook if use_external_connection else None, + "query_ids": ["query1", "query2"], + "query_source_namespace": "scheme://auth", + "task_instance": ti, + } + ) + assert result is None + + # 
emit_openlineage_events_for_snowflake_queries requires OL provider 2.0.0 + @mock.patch("importlib.metadata.version", return_value="1.99.0") + @mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection") + def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider( + self, mock_use_conn, mock_version + ): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + hook.query_ids = ["query1", "query2"] + hook.get_connection = mock.MagicMock() + hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme") + + expected_err = ( + "OpenLineage provider version `1.99.0` is lower than required `2.0.0`, " + "skipping function `emit_openlineage_events_for_snowflake_queries` execution" + ) + with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err): + hook.get_openlineage_database_specific_lineage(mock.MagicMock()) + mock_use_conn.assert_called_once() + @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet") @mock.patch("snowflake.snowpark.Session.builder") def test_get_snowpark_session(self, mock_session_builder): diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py index 393341c5ff4bd..1ecaf75af1804 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py @@ -16,9 +16,35 @@ # under the License. from __future__ import annotations +import copy +from unittest import mock + import pytest +from openlineage.client.event_v2 import Job, Run, RunEvent, RunState +from openlineage.client.facet_v2 import job_type_job, parent_run + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.providers.common.compat.openlineage.facet import ( + ErrorMessageRunFacet, + ExternalQueryRunFacet, + SQLJobFacet, +) +from airflow.providers.openlineage.conf import namespace +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook +from airflow.providers.snowflake.utils.openlineage import ( + _create_snowflake_event_pair, + _get_ol_run_id, + _get_parent_run_facet, + _get_queries_details_from_snowflake, + _run_single_query_with_hook, + emit_openlineage_events_for_snowflake_queries, + fix_account_name, + fix_snowflake_sqlalchemy_uri, +) +from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState -from airflow.providers.snowflake.utils.openlineage import fix_account_name, fix_snowflake_sqlalchemy_uri +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS @pytest.mark.parametrize( @@ -84,3 +110,589 @@ def test_fix_account_name(name, expected): fix_snowflake_sqlalchemy_uri(f"snowflake://{name}/database/schema") == f"snowflake://{expected}/database/schema" ) + + +def test_get_ol_run_id_ti_success(): + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1 if AIRFLOW_V_2_10_PLUS else 2, + logical_date=logical_date, + state=TaskInstanceState.SUCCESS, + ) + mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + + result = _get_ol_run_id(mock_ti) + assert result == "01941f29-7c00-7087-8906-40e512c257bd" + + +def test_get_ol_run_id_ti_failed(): + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + 
try_number=1, + logical_date=logical_date, + state=TaskInstanceState.FAILED, + ) + mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + + result = _get_ol_run_id(mock_ti) + assert result == "01941f29-7c00-7087-8906-40e512c257bd" + + +def test_get_parent_run_facet(): + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1 if AIRFLOW_V_2_10_PLUS else 2, + logical_date=logical_date, + state=TaskInstanceState.SUCCESS, + ) + mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + + result = _get_parent_run_facet(mock_ti) + + assert result.run.runId == "01941f29-7c00-7087-8906-40e512c257bd" + assert result.job.namespace == namespace() + assert result.job.name == "dag_id.task_id" + + +@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") +@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.set_autocommit") +@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_cursor") +def test_run_single_query_with_hook(mock_get_cursor, mock_set_autocommit, mock_get_conn): + mock_cursor = mock.MagicMock() + mock_cursor.fetchall.return_value = [{"col1": "value1"}, {"col2": "value2"}] + mock_get_cursor.return_value.__enter__.return_value = mock_cursor + hook = SnowflakeHook(snowflake_conn_id="test_conn") + + sql_query = "SELECT * FROM test_table;" + result = _run_single_query_with_hook(hook, sql_query) + + mock_cursor.execute.assert_called_once_with(sql_query) + assert result == [{"col1": "value1"}, {"col2": "value2"}] + + +def test_get_queries_details_from_snowflake_empty_query_ids(): + details = _get_queries_details_from_snowflake(None, []) + assert details == {} + + +@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook") +def test_get_queries_details_from_snowflake_single_query(mock_run_single_query): + hook = mock.MagicMock() + query_ids = ["ABC"] + fake_result = [ + { + "QUERY_ID": "ABC", + "EXECUTION_STATUS": "SUCCESS", + "START_TIME": timezone.datetime(2025, 1, 1), + "END_TIME": timezone.datetime(2025, 1, 1), + "QUERY_TEXT": "SELECT * FROM test_table;", + "ERROR_CODE": None, + "ERROR_MESSAGE": None, + } + ] + mock_run_single_query.return_value = fake_result + + details = _get_queries_details_from_snowflake(hook, query_ids) + expected_query = ( + "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE " + "FROM table(information_schema.query_history()) " + "WHERE QUERY_ID = 'ABC';" + ) + mock_run_single_query.assert_called_once_with(hook=hook, sql=expected_query) + assert details == {"ABC": fake_result[0]} + + +@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook") +def test_get_queries_details_from_snowflake_multiple_queries(mock_run_single_query): + hook = mock.MagicMock() + query_ids = ["ABC", "DEF"] + fake_result = [ + { + "QUERY_ID": "ABC", + "EXECUTION_STATUS": "SUCCESS", + "START_TIME": timezone.datetime(2025, 1, 1), + "END_TIME": timezone.datetime(2025, 1, 1), + "QUERY_TEXT": "SELECT * FROM table1;", + "ERROR_CODE": None, + "ERROR_MESSAGE": None, + }, + { + "QUERY_ID": "DEF", + "EXECUTION_STATUS": "FAILED", + "START_TIME": timezone.datetime(2025, 1, 1), + "END_TIME": timezone.datetime(2025, 1, 1), + "QUERY_TEXT": "SELECT * FROM table2;", + "ERROR_CODE": "123", + "ERROR_MESSAGE": "Some error", + }, + ] + mock_run_single_query.return_value = 
fake_result + + details = _get_queries_details_from_snowflake(hook, query_ids) + + expected_query_condition = f"IN {tuple(query_ids)}" + expected_query = ( + "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE " + "FROM table(information_schema.query_history()) " + f"WHERE QUERY_ID {expected_query_condition};" + ) + mock_run_single_query.assert_called_once_with(hook=hook, sql=expected_query) + assert details == {row["QUERY_ID"]: row for row in fake_result} + + +@mock.patch("airflow.providers.snowflake.utils.openlineage._run_single_query_with_hook") +def test_get_queries_details_from_snowflake_no_data_found(mock_run_single_query): + hook = mock.MagicMock() + query_ids = ["ABC", "DEF"] + mock_run_single_query.return_value = [] + + details = _get_queries_details_from_snowflake(hook, query_ids) + + expected_query_condition = f"IN {tuple(query_ids)}" + expected_query = ( + "SELECT QUERY_ID, EXECUTION_STATUS, START_TIME, END_TIME, QUERY_TEXT, ERROR_CODE, ERROR_MESSAGE " + "FROM table(information_schema.query_history()) " + f"WHERE QUERY_ID {expected_query_condition};" + ) + mock_run_single_query.assert_called_once_with(hook=hook, sql=expected_query) + assert details == {} + + +@pytest.mark.parametrize("is_successful", [True, False]) +@mock.patch("openlineage.client.uuid.generate_new_uuid") +def test_create_snowflake_event_pair_success(mock_generate_uuid, is_successful): + fake_uuid = "01941f29-7c00-7087-8906-40e512c257bd" + mock_generate_uuid.return_value = fake_uuid + + job_namespace = "test_namespace" + job_name = "test_job" + start_time = timezone.datetime(2021, 1, 1, 10, 0, 0) + end_time = timezone.datetime(2021, 1, 1, 10, 30, 0) + run_facets = {"run_key": "run_value"} + job_facets = {"job_key": "job_value"} + + start_event, end_event = _create_snowflake_event_pair( + job_namespace, + job_name, + start_time, + end_time, + is_successful=is_successful, + run_facets=run_facets, + job_facets=job_facets, + ) + + assert start_event.eventType == RunState.START + assert start_event.eventTime == start_time.isoformat() + assert end_event.eventType == RunState.COMPLETE if is_successful else RunState.FAIL + assert end_event.eventTime == end_time.isoformat() + + assert start_event.run.runId == fake_uuid + assert start_event.run.facets == run_facets + + assert start_event.job.namespace == job_namespace + assert start_event.job.name == job_name + assert start_event.job.facets == job_facets + + assert start_event.run is end_event.run + assert start_event.job == end_event.job + + +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_snowflake_queries_with_hook(mock_now, mock_generate_uuid): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + query_ids = ["query1", "query2", "query3"] + original_query_ids = copy.deepcopy(query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1, + logical_date=logical_date, + state=TaskInstanceState.FAILED, # This will be query default state if no metadata found + ) + mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + + fake_metadata = { + "query1": { + "START_TIME": timezone.datetime(2025, 1, 1, 0, 0, 0), + "END_TIME": timezone.datetime(2025, 
1, 2, 0, 0, 0), + "EXECUTION_STATUS": "SUCCESS", + "QUERY_TEXT": "SELECT * FROM table1", + # No error for query1 + }, + "query2": { + "START_TIME": timezone.datetime(2025, 1, 3, 0, 0, 0), + "END_TIME": timezone.datetime(2025, 1, 4, 0, 0, 0), + "EXECUTION_STATUS": "FAIL", + "QUERY_TEXT": "SELECT * FROM table2", + "ERROR_MESSAGE": "Error occurred", + "ERROR_CODE": "ERR001", + }, + # No metadata for query3 + } + + additional_run_facets = {"custom_run": "value_run"} + additional_job_facets = {"custom_job": "value_job"} + + fake_client = mock.MagicMock() + fake_client.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + + with ( + mock.patch( + "airflow.providers.snowflake.utils.openlineage._get_queries_details_from_snowflake", + return_value=fake_metadata, + ), + mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ), + ): + emit_openlineage_events_for_snowflake_queries( + query_ids=query_ids, + query_source_namespace="snowflake_ns", + task_instance=mock_ti, + hook=mock.MagicMock(), # any non-None hook to trigger metadata retrieval + additional_run_facets=additional_run_facets, + additional_job_facets=additional_job_facets, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + assert fake_client.emit.call_count == 6 # Expect two events per query. + + expected_common_job_facets = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + processingType="BATCH", + integration="SNOWFLAKE", + ), + "custom_job": "value_job", + } + expected_common_run_facets = { + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"), + job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"), + ), + "custom_run": "value_run", + } + + expected_calls = [ + mock.call( # Query1: START event + RunEvent( + eventTime=fake_metadata["query1"]["START_TIME"].isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="snowflake_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets={ + "sql": SQLJobFacet(query="SELECT * FROM table1"), + **expected_common_job_facets, + }, + ), + ) + ), + mock.call( # Query1: COMPLETE event + RunEvent( + eventTime=fake_metadata["query1"]["END_TIME"].isoformat(), + eventType=RunState.COMPLETE, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="snowflake_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets={ + "sql": SQLJobFacet(query="SELECT * FROM table1"), + **expected_common_job_facets, + }, + ), + ) + ), + mock.call( # Query2: START event + RunEvent( + eventTime=fake_metadata["query2"]["START_TIME"].isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query2", source="snowflake_ns" + ), + "error": ErrorMessageRunFacet( + message="ERR001 : Error occurred", programmingLanguage="SQL" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.2", + facets={ + "sql": SQLJobFacet(query="SELECT * FROM table2"), + **expected_common_job_facets, + }, + ), + ) + ), + mock.call( # 
Query2: FAIL event + RunEvent( + eventTime=fake_metadata["query2"]["END_TIME"].isoformat(), + eventType=RunState.FAIL, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query2", source="snowflake_ns" + ), + "error": ErrorMessageRunFacet( + message="ERR001 : Error occurred", programmingLanguage="SQL" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.2", + facets={ + "sql": SQLJobFacet(query="SELECT * FROM table2"), + **expected_common_job_facets, + }, + ), + ) + ), + mock.call( # Query3: START event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), # no metadata for query3 + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query3", source="snowflake_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.3", + facets=expected_common_job_facets, + ), + ) + ), + mock.call( # Query3: FAIL event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), # no metadata for query3 + eventType=RunState.FAIL, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query3", source="snowflake_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.3", + facets=expected_common_job_facets, + ), + ) + ), + ] + + assert fake_client.emit.call_args_list == expected_calls + + +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_snowflake_queries_without_hook(mock_now, mock_generate_uuid): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + query_ids = ["query1"] + original_query_ids = copy.deepcopy(query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1 if AIRFLOW_V_2_10_PLUS else 2, + logical_date=logical_date, + state=TaskInstanceState.SUCCESS, # This will be query default state if no metadata found + ) + mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + + additional_run_facets = {"custom_run": "value_run"} + additional_job_facets = {"custom_job": "value_job"} + + fake_client = mock.MagicMock() + fake_client.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + emit_openlineage_events_for_snowflake_queries( + query_ids=query_ids, + query_source_namespace="snowflake_ns", + task_instance=mock_ti, + hook=None, # None so metadata retrieval is not triggered + additional_run_facets=additional_run_facets, + additional_job_facets=additional_job_facets, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + assert fake_client.emit.call_count == 2 # Expect two events per query. 
+ + expected_common_job_facets = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + processingType="BATCH", + integration="SNOWFLAKE", + ), + "custom_job": "value_job", + } + expected_common_run_facets = { + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"), + job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"), + ), + "custom_run": "value_run", + } + + expected_calls = [ + mock.call( # Query1: START event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="snowflake_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + mock.call( # Query1: COMPLETE event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.COMPLETE, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="snowflake_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + ] + + assert fake_client.emit.call_args_list == expected_calls + + +def test_emit_openlineage_events_for_snowflake_queries_without_query_ids(): + query_ids = [] + original_query_ids = copy.deepcopy(query_ids) + + fake_client = mock.MagicMock() + fake_client.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + emit_openlineage_events_for_snowflake_queries( + query_ids=query_ids, + query_source_namespace="snowflake_ns", + task_instance=None, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + fake_client.emit.assert_not_called() # No events should be emitted + + +# emit_openlineage_events_for_snowflake_queries requires OL provider 2.0.0 +@mock.patch("importlib.metadata.version", return_value="1.99.0") +def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): + query_ids = ["q1", "q2"] + original_query_ids = copy.deepcopy(query_ids) + + fake_client = mock.MagicMock() + fake_client.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter.get_or_create_openlineage_client.return_value = fake_client + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + expected_err = ( + "OpenLineage provider version `1.99.0` is lower than required `2.0.0`, " + "skipping function `emit_openlineage_events_for_snowflake_queries` execution" + ) + + with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err): + emit_openlineage_events_for_snowflake_queries( + query_ids=query_ids, + query_source_namespace="snowflake_ns", + task_instance=None, + ) + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + fake_client.emit.assert_not_called() # No events should be emitted
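Review note, not part of the patch: for anyone trying the new helper end to end, here is a minimal sketch of calling it directly from task code. It assumes a configured Snowflake connection named "snowflake_default" (a placeholder) and apache-airflow-providers-openlineage>=2.0.0, which the `require_openlineage_version` decorator enforces at call time.

from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.utils.openlineage import (
    emit_openlineage_events_for_snowflake_queries,
)


def report_query_lineage(task_instance):
    """Run a couple of statements and emit one OpenLineage event pair per query."""
    hook = SnowflakeHook(snowflake_conn_id="snowflake_default")  # placeholder conn id
    hook.run(["SELECT 1;", "SELECT 2;"])  # SnowflakeHook.run() records the executed query ids
    emit_openlineage_events_for_snowflake_queries(
        query_ids=hook.query_ids,
        query_source_namespace="snowflake://example-account",  # example namespace value
        task_instance=task_instance,
        hook=hook,  # pass hook=None to skip the query_history() metadata lookup
    )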