@@ -617,10 +617,9 @@ def _get_openlineage_authority(self, _) -> str | None:

def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
"""
Generate OpenLineage metadata for a Snowflake task instance based on executed query IDs.
Emit separate OpenLineage events for each Snowflake query, based on executed query IDs.

If a single query ID is present, attach an `ExternalQueryRunFacet` to the lineage metadata.
If multiple query IDs are present, emits separate OpenLineage events for each query.
If a single query ID is present, also add an `ExternalQueryRunFacet` to the returned lineage metadata.

Note that `get_openlineage_database_specific_lineage` is usually called after the task's execution,
so if multiple query IDs are present, both START and COMPLETE events for each query will be emitted
@@ -641,13 +640,22 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi
)

if not self.query_ids:
self.log.debug("openlineage: no snowflake query ids found.")
self.log.info("OpenLineage could not find snowflake query ids.")
return None

self.log.debug("openlineage: getting connection to get database info")
connection = self.get_connection(self.get_conn_id())
namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))

self.log.info("Separate OpenLineage events will be emitted for each query_id.")
emit_openlineage_events_for_snowflake_queries(
task_instance=task_instance,
hook=self,
query_ids=self.query_ids,
query_for_extra_metadata=True,
query_source_namespace=namespace,
)

if len(self.query_ids) == 1:
self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
return OperatorLineage(
@@ -658,20 +666,4 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi
}
)

self.log.info("Multiple query_ids found. Separate OpenLineage event will be emitted for each query.")
try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(self)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

emit_openlineage_events_for_snowflake_queries(
query_ids=self.query_ids,
query_source_namespace=namespace,
task_instance=task_instance,
hook=self if use_external_connection else None,
)

return None
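
# --- Illustrative sketch (not part of the PR): exercising the reworked method with mocks. ---
# Assumes apache-airflow-providers-openlineage >= 2.3.0 is installed; the connection id,
# authority, scheme and query id below are placeholders.
from unittest import mock

from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook

hook = SnowflakeHook(snowflake_conn_id="test_conn")
hook.query_ids = ["query1"]  # pretend a single query was executed
hook.get_connection = mock.MagicMock()  # avoid a real metadata DB lookup
hook.get_openlineage_database_info = lambda conn: mock.MagicMock(authority="auth", scheme="scheme")

ti = mock.MagicMock()  # stand-in Airflow task instance

with mock.patch(
    "airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries"
) as mock_emit:
    lineage = hook.get_openlineage_database_specific_lineage(ti)

# Event emission is always delegated to the utility, with extra-metadata queries enabled;
# with exactly one query id the method additionally returns an OperatorLineage carrying
# an ExternalQueryRunFacet, otherwise it returns None.
mock_emit.assert_called_once_with(
    task_instance=ti,
    hook=hook,
    query_ids=["query1"],
    query_for_extra_metadata=True,
    query_source_namespace="scheme://auth",
)
print(lineage)
# --- end of sketch ---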
@@ -19,7 +19,7 @@
import datetime
import logging
from contextlib import closing
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urlparse, urlunparse

from airflow.providers.common.compat.openlineage.check import require_openlineage_version
@@ -31,6 +31,7 @@
from openlineage.client.facet_v2 import JobFacet

from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook


log = logging.getLogger(__name__)
@@ -204,9 +205,29 @@ def _run_single_query_with_hook(hook: SnowflakeHook, sql: str) -> list[dict]:
return result


def _run_single_query_with_api_hook(hook: SnowflakeSqlApiHook, sql: str) -> list[dict[str, Any]]:
"""Execute a query against Snowflake API without adding extra logging or instrumentation."""
# `hook.execute_query` resets the query_ids, so we need to save them and re-assign after we're done
query_ids_before_execution = list(hook.query_ids)
try:
_query_ids = hook.execute_query(sql=sql, statement_count=0)
hook.wait_for_query(query_id=_query_ids[0], raise_error=True, poll_interval=1, timeout=3)
return hook.get_result_from_successful_sql_api_query(query_id=_query_ids[0])
finally:
hook.query_ids = query_ids_before_execution
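
# --- Illustrative sketch (not part of the PR): the save-and-restore pattern above,
# expressed as a reusable context manager; the hook passed in is a hypothetical stand-in.
from contextlib import contextmanager


@contextmanager
def _preserved_query_ids(hook):
    """Restore ``hook.query_ids`` after the block, even if the block raises."""
    saved = list(hook.query_ids)
    try:
        yield
    finally:
        hook.query_ids = saved
# --- end of sketch ---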


def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert 'START_TIME' and 'END_TIME' fields to UTC datetime objects."""
for row in data:
for key in ("START_TIME", "END_TIME"):
row[key] = datetime.datetime.fromtimestamp(float(row[key]), timezone.utc)
return data
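
# --- Illustrative sketch (not part of the PR): what the conversion above does to a
# single row; the epoch values are made up.
import datetime

row = {"QUERY_ID": "abc-123", "START_TIME": "1700000000.0", "END_TIME": "1700000042.5"}
for key in ("START_TIME", "END_TIME"):
    row[key] = datetime.datetime.fromtimestamp(float(row[key]), datetime.timezone.utc)
print(row["START_TIME"])  # 2023-11-14 22:13:20+00:00
# --- end of sketch ---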


def _get_queries_details_from_snowflake(
hook: SnowflakeHook, query_ids: list[str]
) -> dict[str, dict[str, str]]:
hook: SnowflakeHook | SnowflakeSqlApiHook, query_ids: list[str]
) -> dict[str, dict[str, Any]]:
"""Retrieve execution details for specific queries from Snowflake's query history."""
if not query_ids:
return {}
@@ -221,7 +242,16 @@ def _get_queries_details_from_snowflake(
f";"
)

result = _run_single_query_with_hook(hook=hook, sql=query)
try:
# Can't import the SnowflakeSqlApiHook class and do proper isinstance check - circular imports
if hook.__class__.__name__ == "SnowflakeSqlApiHook":
result = _run_single_query_with_api_hook(hook=hook, sql=query) # type: ignore[arg-type]
result = _process_data_from_api(data=result)
else:
result = _run_single_query_with_hook(hook=hook, sql=query)
except Exception as e:
log.warning("OpenLineage could not retrieve extra metadata from Snowflake. Error encountered: %s", e)
result = []

return {row["QUERY_ID"]: row for row in result} if result else {}

@@ -259,17 +289,18 @@ def _create_snowflake_event_pair(

@require_openlineage_version(provider_min_version="2.3.0")
def emit_openlineage_events_for_snowflake_queries(
query_ids: list[str],
query_source_namespace: str,
task_instance,
hook: SnowflakeHook | None = None,
hook: SnowflakeHook | SnowflakeSqlApiHook | None = None,
query_ids: list[str] | None = None,
query_source_namespace: str | None = None,
query_for_extra_metadata: bool = False,
additional_run_facets: dict | None = None,
additional_job_facets: dict | None = None,
) -> None:
"""
Emit OpenLineage events for executed Snowflake queries.

Metadata retrieval from Snowflake is attempted only if a `SnowflakeHook` is provided.
Metadata retrieval from Snowflake is attempted only if `query_for_extra_metadata` is True and a hook is provided.
If metadata is available, execution details such as start time, end time, execution status,
error messages, and SQL text are included in the events. If no metadata is found, the function
defaults to using the Airflow task instance's state and the current timestamp.
@@ -279,10 +310,16 @@ def emit_openlineage_events_for_snowflake_queries(
will correspond to actual query execution times.

Args:
query_ids: A list of Snowflake query IDs to emit events for.
query_source_namespace: The namespace to be included in ExternalQueryRunFacet.
task_instance: The Airflow task instance that ran these queries.
hook: A SnowflakeHook instance used to retrieve query metadata if available.
hook: A supported Snowflake hook instance used to retrieve query metadata if available.
If omitted, `query_ids` and `query_source_namespace` must be provided explicitly and
`query_for_extra_metadata` must be `False`.
query_ids: A list of Snowflake query IDs to emit events for; can be None only if `hook` is provided
and `hook.query_ids` is populated.
query_source_namespace: The namespace to be included in ExternalQueryRunFacet;
can be `None` only if `hook` is provided.
query_for_extra_metadata: Whether to query Snowflake for additional metadata about queries.
Must be `False` if `hook` is not provided.
additional_run_facets: Additional run facets to include in OpenLineage events.
additional_job_facets: Additional job facets to include in OpenLineage events.
"""
@@ -297,23 +334,49 @@ def emit_openlineage_events_for_snowflake_queries(
from airflow.providers.openlineage.conf import namespace
from airflow.providers.openlineage.plugins.listener import get_openlineage_listener

if not query_ids:
log.debug("No Snowflake query IDs provided; skipping OpenLineage event emission.")
return

query_ids = [q for q in query_ids] # Make a copy to make sure it does not change
log.info("OpenLineage will emit events for Snowflake queries.")

if hook:
if not query_ids:
log.debug("No Snowflake query IDs provided; Checking `hook.query_ids` property.")
query_ids = getattr(hook, "query_ids", [])
if not query_ids:
raise ValueError("No Snowflake query IDs provided and `hook.query_ids` are not present.")

if not query_source_namespace:
log.debug("No Snowflake query namespace provided; Creating one from scratch.")
from airflow.providers.openlineage.sqlparser import SQLParser

connection = hook.get_connection(hook.get_conn_id())
query_source_namespace = SQLParser.create_namespace(
hook.get_openlineage_database_info(connection)
)
else:
if not query_ids:
raise ValueError("If 'hook' is not provided, 'query_ids' must be set.")
if not query_source_namespace:
raise ValueError("If 'hook' is not provided, 'query_source_namespace' must be set.")
if query_for_extra_metadata:
raise ValueError("If 'hook' is not provided, 'query_for_extra_metadata' must be False.")

query_ids = [q for q in query_ids] # Make a copy to make sure we do not change hook's attribute

if query_for_extra_metadata and hook:
log.debug("Retrieving metadata for %s queries from Snowflake.", len(query_ids))
snowflake_metadata = _get_queries_details_from_snowflake(hook, query_ids)
else:
log.debug("SnowflakeHook not provided. No extra metadata fill be fetched from Snowflake.")
log.debug("`query_for_extra_metadata` is False. No extra metadata fill be fetched from Snowflake.")
snowflake_metadata = {}

# If real metadata is unavailable, we send events with eventTime=now
default_event_time = timezone.utcnow()
# If no query metadata is provided, we use task_instance's state when checking for success
default_state = task_instance.state.value if hasattr(task_instance, "state") else ""
# ti.state has no `value` attr (AF2) when the task is still running; in AF3 we get 'running'. In that case
# we assume it's a user call and the query succeeded, so we replace it with 'success'.
default_state = (
getattr(task_instance.state, "value", "running") if hasattr(task_instance, "state") else ""
)
default_state = "success" if default_state == "running" else default_state

common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
common_job_facets: dict[str, JobFacet] = {
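# --- Illustrative sketch (not part of the PR): the two parameter combinations that
# emit_openlineage_events_for_snowflake_queries now supports. The connection id, query id
# and namespace are placeholders, and the snippet assumes
# apache-airflow-providers-openlineage >= 2.3.0 is installed.
def _emit_lineage_for_my_task(task_instance):
    from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
    from airflow.providers.snowflake.utils.openlineage import (
        emit_openlineage_events_for_snowflake_queries,
    )

    hook = SnowflakeHook(snowflake_conn_id="snowflake_default")

    # 1) With a hook: query_ids and query_source_namespace can be derived from it,
    #    and extra metadata may be fetched from Snowflake's query history.
    emit_openlineage_events_for_snowflake_queries(
        task_instance=task_instance,
        hook=hook,
        query_ids=["01aa23bc-0000-0000-0000-000000000000"],
        query_for_extra_metadata=True,
    )

    # 2) Without a hook: both values must be passed explicitly and
    #    query_for_extra_metadata must stay False (the default).
    emit_openlineage_events_for_snowflake_queries(
        task_instance=task_instance,
        query_ids=["01aa23bc-0000-0000-0000-000000000000"],
        query_source_namespace="snowflake://my_account",
    )
# --- end of sketch ---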
39 changes: 21 additions & 18 deletions providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -905,14 +905,17 @@ def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema"
mock_get_first.assert_not_called()

def test_get_openlineage_database_specific_lineage_with_no_query_ids(self):
@mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
def test_get_openlineage_database_specific_lineage_with_no_query_ids(self, mock_emit):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
assert hook.query_ids == []

result = hook.get_openlineage_database_specific_lineage(None)
mock_emit.assert_not_called()
assert result is None

def test_get_openlineage_database_specific_lineage_with_single_query_id(self):
@mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
def test_get_openlineage_database_specific_lineage_with_single_query_id(self, mock_emit):
from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet
from airflow.providers.openlineage.extractors import OperatorLineage

@@ -921,21 +924,26 @@ def test_get_openlineage_database_specific_lineage_with_single_query_id(self):
hook.get_connection = mock.MagicMock()
hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")

result = hook.get_openlineage_database_specific_lineage(None)
ti = mock.MagicMock()

result = hook.get_openlineage_database_specific_lineage(ti)
mock_emit.assert_called_once_with(
**{
"hook": hook,
"query_ids": ["query1"],
"query_source_namespace": "scheme://auth",
"task_instance": ti,
"query_for_extra_metadata": True,
}
)
assert result == OperatorLineage(
run_facets={
"externalQuery": ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")
}
)

@pytest.mark.parametrize("use_external_connection", [True, False])
@mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection")
@mock.patch("airflow.providers.snowflake.utils.openlineage.emit_openlineage_events_for_snowflake_queries")
def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(
self, mock_emit, mock_use_conn, use_external_connection
):
mock_use_conn.return_value = use_external_connection

def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(self, mock_emit):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
hook.query_ids = ["query1", "query2"]
hook.get_connection = mock.MagicMock()
@@ -944,23 +952,19 @@ def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(
ti = mock.MagicMock()

result = hook.get_openlineage_database_specific_lineage(ti)
mock_use_conn.assert_called_once()
mock_emit.assert_called_once_with(
**{
"hook": hook if use_external_connection else None,
"hook": hook,
"query_ids": ["query1", "query2"],
"query_source_namespace": "scheme://auth",
"task_instance": ti,
"query_for_extra_metadata": True,
}
)
assert result is None

# emit_openlineage_events_for_snowflake_queries requires OL provider 2.3.0
@mock.patch("importlib.metadata.version", return_value="1.99.0")
@mock.patch("airflow.providers.openlineage.utils.utils.should_use_external_connection")
def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider(
self, mock_use_conn, mock_version
):
def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider(self, mock_version):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
hook.query_ids = ["query1", "query2"]
hook.get_connection = mock.MagicMock()
@@ -972,7 +976,6 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
)
with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err):
hook.get_openlineage_database_specific_lineage(mock.MagicMock())
mock_use_conn.assert_called_once()

@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet")
@mock.patch("snowflake.snowpark.Session.builder")