Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 119 additions & 90 deletions providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,86 @@ def _raise_exception(self, exception_string: str) -> NoReturn:
raise AirflowException(exception_string)
raise AirflowFailException(exception_string)

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
"""Generate OpenLineage facets on start for SQL operators."""
try:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import SQLParser
except ImportError:
self.log.debug("OpenLineage could not import required classes. Skipping.")
return None

sql = getattr(self, "sql", None)
if not sql:
self.log.debug("OpenLineage could not find 'sql' attribute on `%s`.", type(self).__name__)
return OperatorLineage()

hook = self.get_db_hook()
try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(hook)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
self.log.debug("%s has no database info provided", hook)
database_info = None

if database_info is None:
self.log.debug("OpenLineage could not retrieve database information. Skipping.")
return OperatorLineage()

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
self.log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=sql,
hook=hook,
database_info=database_info,
database=self.database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=use_external_connection,
)

return operator_lineage

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
"""Generate OpenLineage facets when task completes."""
try:
from airflow.providers.openlineage.extractors import OperatorLineage
except ImportError:
self.log.debug("OpenLineage could not import required classes. Skipping.")
return None

operator_lineage = self.get_openlineage_facets_on_start() or OperatorLineage()
hook = self.get_db_hook()
try:
database_specific_lineage = hook.get_openlineage_database_specific_lineage(task_instance)
except AttributeError:
self.log.debug("%s has no database specific lineage provided", hook)
database_specific_lineage = None

if database_specific_lineage is None:
return operator_lineage

return OperatorLineage(
inputs=operator_lineage.inputs + database_specific_lineage.inputs,
outputs=operator_lineage.outputs + database_specific_lineage.outputs,
run_facets=merge_dicts(operator_lineage.run_facets, database_specific_lineage.run_facets),
job_facets=merge_dicts(operator_lineage.job_facets, database_specific_lineage.job_facets),
)


class SQLExecuteQueryOperator(BaseSQLOperator):
"""
Expand Down Expand Up @@ -343,76 +423,6 @@ def prepare_template(self) -> None:
if isinstance(self.parameters, str):
self.parameters = ast.literal_eval(self.parameters)

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
try:
from airflow.providers.openlineage.sqlparser import SQLParser
except ImportError:
return None

hook = self.get_db_hook()

try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(hook)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
self.log.debug("%s has no database info provided", hook)
database_info = None

if database_info is None:
return None

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
self.log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=self.sql,
hook=hook,
database_info=database_info,
database=self.database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=use_external_connection,
)

return operator_lineage

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
try:
from airflow.providers.openlineage.extractors import OperatorLineage
except ImportError:
return None

operator_lineage = self.get_openlineage_facets_on_start() or OperatorLineage()

hook = self.get_db_hook()
try:
database_specific_lineage = hook.get_openlineage_database_specific_lineage(task_instance)
except AttributeError:
database_specific_lineage = None

if database_specific_lineage is None:
return operator_lineage

return OperatorLineage(
inputs=operator_lineage.inputs + database_specific_lineage.inputs,
outputs=operator_lineage.outputs + database_specific_lineage.outputs,
run_facets=merge_dicts(operator_lineage.run_facets, database_specific_lineage.run_facets),
job_facets=merge_dicts(operator_lineage.job_facets, database_specific_lineage.job_facets),
)


class SQLColumnCheckOperator(BaseSQLOperator):
"""
Expand Down Expand Up @@ -999,8 +1009,13 @@ def __init__(

self.sql1 = f"{sqlt}'{{{{ ds }}}}'"
self.sql2 = f"{sqlt}'{{{{ macros.ds_add(ds, {self.days_back}) }}}}'"
# Save all queries as `sql` attr - similar to other sql operators (to be used by listeners).
self.sql: list[str] = [self.sql1, self.sql2]

def execute(self, context: Context):
# Re-set with templated queries
self.sql = [self.sql1, self.sql2]

hook = self.get_db_hook()
self.log.info("Using ratio formula: %s", self.ratio_formula)
self.log.info("Executing SQL check: %s", self.sql2)
Expand All @@ -1017,25 +1032,36 @@ def execute(self, context: Context):
reference = dict(zip(self.metrics_sorted, row2))

ratios: dict[str, int | None] = {}
test_results = {}
# Save all details about all tests to be used in error message if needed
all_tests_results: dict[str, dict[str, Any]] = {}

for metric in self.metrics_sorted:
cur = current[metric]
ref = reference[metric]
threshold = self.metrics_thresholds[metric]
single_metric_results = {
"metric": metric,
"current_metric": cur,
"past_metric": ref,
"threshold": threshold,
"ignore_zero": self.ignore_zero,
}
if cur == 0 or ref == 0:
ratios[metric] = None
test_results[metric] = self.ignore_zero
single_metric_results["ratio"] = None
single_metric_results["success"] = self.ignore_zero
else:
ratio_metric = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
ratios[metric] = ratio_metric
single_metric_results["ratio"] = ratio_metric
if ratio_metric is not None:
test_results[metric] = ratio_metric < threshold
single_metric_results["success"] = ratio_metric < threshold
else:
test_results[metric] = self.ignore_zero
single_metric_results["success"] = self.ignore_zero

all_tests_results[metric] = single_metric_results
self.log.info(
("Current metric for %s: %s\nPast metric for %s: %s\nRatio for %s: %s\nThreshold: %s\n"),
"Current metric for %s: %s\nPast metric for %s: %s\nRatio for %s: %s\nThreshold: %s\n",
metric,
cur,
metric,
Expand All @@ -1045,21 +1071,24 @@ def execute(self, context: Context):
threshold,
)

if not all(test_results.values()):
failed_tests = [it[0] for it in test_results.items() if not it[1]]
failed_tests = [single for single in all_tests_results.values() if not single["success"]]
if failed_tests:
self.log.warning(
"The following %s tests out of %s failed:",
len(failed_tests),
len(self.metrics_sorted),
)
for k in failed_tests:
for single_filed_test in failed_tests:
self.log.warning(
"'%s' check failed. %s is above %s",
k,
ratios[k],
self.metrics_thresholds[k],
single_filed_test["metric"],
single_filed_test["ratio"],
single_filed_test["threshold"],
)
self._raise_exception(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
failed_test_details = "; ".join(
f"{t['metric']}: {t}" for t in sorted(failed_tests, key=lambda x: x["metric"])
)
self._raise_exception(f"The following tests have failed:\n {failed_test_details}")

self.log.info("All tests have passed")

Expand Down Expand Up @@ -1206,6 +1235,8 @@ def __init__(
self.parameters = parameters
self.follow_task_ids_if_true = follow_task_ids_if_true
self.follow_task_ids_if_false = follow_task_ids_if_false
# Chosen branch, after evaluating condition, set during execution, to be used by listeners
self.follow_branch: list[str] | None = None

def execute(self, context: Context):
self.log.info(
Expand All @@ -1232,32 +1263,30 @@ def execute(self, context: Context):

self.log.info("Query returns %s, type '%s'", query_result, type(query_result))

follow_branch = None
try:
if isinstance(query_result, bool):
if query_result:
follow_branch = self.follow_task_ids_if_true
self.follow_branch = self.follow_task_ids_if_true
elif isinstance(query_result, str):
# return result is not Boolean, try to convert from String to Boolean
if _parse_boolean(query_result):
follow_branch = self.follow_task_ids_if_true
self.follow_branch = self.follow_task_ids_if_true
elif isinstance(query_result, int):
if bool(query_result):
follow_branch = self.follow_task_ids_if_true
self.follow_branch = self.follow_task_ids_if_true
else:
raise AirflowException(
f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
)

if follow_branch is None:
follow_branch = self.follow_task_ids_if_false
if self.follow_branch is None:
self.follow_branch = self.follow_task_ids_if_false
except ValueError:
raise AirflowException(
f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
)

# TODO(potiuk) remove the type ignore once we solve provider <-> Task SDK relationship
self.skip_all_except(context["ti"], follow_branch)
self.skip_all_except(context["ti"], self.follow_branch)


class SQLInsertRowsOperator(BaseSQLOperator):
Expand Down
20 changes: 18 additions & 2 deletions providers/common/sql/tests/unit/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,16 @@ def returned_row():
ignore_zero=True,
)

with pytest.raises(AirflowException, match="f0, f1, f2"):
expected_err_message = (
"The following tests have failed:\n "
"f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 2, 'threshold': 1.0,"
" 'ignore_zero': True, 'ratio': 2.0, 'success': False}; "
"f1: {'metric': 'f1', 'current_metric': 1, 'past_metric': 2, 'threshold': 1.5,"
" 'ignore_zero': True, 'ratio': 2.0, 'success': False}; "
"f2: {'metric': 'f2', 'current_metric': 1, 'past_metric': 2, 'threshold': 2.0,"
" 'ignore_zero': True, 'ratio': 2.0, 'success': False}"
)
with pytest.raises(AirflowException, match=expected_err_message):
operator.execute(context=MagicMock())

@mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
Expand Down Expand Up @@ -969,7 +978,14 @@ def returned_row():
ignore_zero=True,
)

with pytest.raises(AirflowException, match="f0, f1"):
expected_err_message = (
"The following tests have failed:\n "
"f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 3, 'threshold': 0.5, "
"'ignore_zero': True, 'ratio': 0.6666666666666666, 'success': False}; "
"f1: {'metric': 'f1', 'current_metric': 1, 'past_metric': 3, 'threshold': 0.6, "
"'ignore_zero': True, 'ratio': 0.6666666666666666, 'success': False}"
)
with pytest.raises(AirflowException, match=expected_err_message):
operator.execute(context=MagicMock())


Expand Down