@@ -863,6 +863,7 @@ def __init__(
         tolerance: Any = None,
         conn_id: str | None = None,
         database: str | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ):
         super().__init__(conn_id=conn_id, database=database, **kwargs)
@@ -871,6 +872,7 @@ def __init__(
         tol = _convert_to_float_if_possible(tolerance)
         self.tol = tol if isinstance(tol, float) else None
         self.has_tolerance = self.tol is not None
+        self.parameters = parameters
 
     def check_value(self, records):
         if not records:
@@ -903,7 +905,7 @@ def check_value(self, records):
 
     def execute(self, context: Context):
         self.log.info("Executing SQL check: %s", self.sql)
-        records = self.get_db_hook().get_first(self.sql)
+        records = self.get_db_hook().get_first(self.sql, self.parameters)
         self.check_value(records)
 
     def _to_float(self, records):
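
Taken together, these hunks thread a new parameters argument through SQLValueCheckOperator.__init__ and forward it to get_first, so the check query can use driver-side bind variables instead of string interpolation. A minimal usage sketch, assuming a DB-API driver that accepts pyformat placeholders (the connection id, table, and values below are illustrative, not from this PR):

from airflow.providers.common.sql.operators.sql import SQLValueCheckOperator

# Hypothetical task: the placeholder style (%(ds)s, ?, :ds, ...) depends on
# the DB-API driver behind conn_id.
check = SQLValueCheckOperator(
    task_id="row_count_check",
    conn_id="my_db",
    sql="SELECT COUNT(*) FROM my_table WHERE ds = %(ds)s",
    pass_value=42,
    parameters={"ds": "2024-01-01"},  # now forwarded to hook.get_first(sql, parameters)
)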
@@ -51,7 +51,7 @@
 )
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.utils import timezone
+from airflow.utils import timezone  # type: ignore[attr-defined]
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
@@ -844,7 +844,7 @@ def test_execute_pass(self, mock_get_db_hook):
 
         operator.execute(None)
 
-        mock_hook.get_first.assert_called_once_with(sql)
+        mock_hook.get_first.assert_called_once_with(sql, None)
 
     @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
     def test_execute_fail(self, mock_get_db_hook):
@@ -435,7 +435,7 @@ def _make_api_call_with_retries(
         :param url: The URL for the API endpoint.
         :param headers: The headers to include in the API call.
         :param params: (Optional) The query parameters to include in the API call.
-        :param data: (Optional) The data to include in the API call.
+        :param json: (Optional) The data to include in the API call.
         :return: The response object from the API call.
         """
         with requests.Session() as session:
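
The docstring rename lines up with the keyword the session call actually receives: requests treats json= as a JSON body (serialized, sent with a Content-Type: application/json header), while data= would be sent form-encoded. A short sketch of the distinction, with an illustrative endpoint and payload:

import requests

with requests.Session() as session:
    # json= serializes the dict to a JSON body and sets the Content-Type header;
    # data= would send the same dict as form-encoded fields instead.
    response = session.post(
        "https://api.example.com/v1/jobs",
        headers={"Authorization": "Bearer <token>"},
        params={"page": 1},
        json={"job_id": 123},
    )
    response.raise_for_status()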
@@ -76,8 +76,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         Template references are recognized by str ending in '.sql'
     :param snowflake_conn_id: Reference to
         :ref:`Snowflake connection id<howto/connection:snowflake>`
-    :param autocommit: if True, each command is automatically committed.
-        (default value: True)
     :param parameters: (optional) the parameters to render the SQL query with.
     :param warehouse: name of warehouse (will overwrite any warehouse
         defined in the connection's extra JSON)
@@ -109,8 +107,6 @@ def __init__(
         sql: str,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -179,8 +175,6 @@ def __init__(
         tolerance: Any = None,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -202,7 +196,12 @@ def __init__(
             **hook_params,
         }
         super().__init__(
-            sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance,
+            conn_id=snowflake_conn_id,
+            parameters=parameters,
+            **kwargs,
         )
         self.query_ids: list[str] = []
 
@@ -259,9 +258,6 @@ def __init__(
         date_filter_column: str = "ds",
         days_back: SupportsAbs[int] = -7,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
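
With parameters= now passed through super().__init__ (and the unused autocommit/do_xcom_push knobs removed from the check operators), values supplied to SnowflakeValueCheckOperator reach the underlying get_first call. A minimal sketch of the resulting usage, assuming the Snowflake connector's pyformat placeholder style (table, value, and date are illustrative):

from airflow.providers.snowflake.operators.snowflake import SnowflakeValueCheckOperator

value_check = SnowflakeValueCheckOperator(
    task_id="check_row_count",
    snowflake_conn_id="snowflake_default",
    sql="SELECT COUNT(*) FROM my_table WHERE ds = %(ds)s",
    pass_value=100,
    parameters={"ds": "2024-01-01"},  # reaches get_first() via the common-sql base class
)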
@@ -18,11 +18,13 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import call
 
 import pendulum
 import pytest
 
 from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import Connection
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
@@ -34,7 +36,7 @@
     SnowflakeValueCheckOperator,
 )
 from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
-from airflow.utils import timezone
+from airflow.utils import timezone  # type: ignore[attr-defined]
 from airflow.utils.types import DagRunType
 
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
@@ -107,25 +109,80 @@ def test_overwrite_params(self, mock_base_op):
         )
 
 
-@pytest.mark.parametrize(
-    "operator_class, kwargs",
-    [
-        (SnowflakeCheckOperator, dict(sql="Select * from test_table")),
-        (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)),
-        (SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})),
-    ],
-)
-class TestSnowflakeCheckOperators:
-    @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
+@pytest.fixture(autouse=True)
+def setup_connections(create_connection_without_db):
+    create_connection_without_db(
+        Connection(
+            conn_id="snowflake_default",
+            conn_type="snowflake",
+            host="test_host",
+            port=443,
+            schema="test_schema",
+            login="test_user",
+            password="test_password",
+        )
+    )
+
+
+class TestSnowflakeCheckOperator:
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckOperator.get_db_hook")
     def test_get_db_hook(
         self,
         mock_get_db_hook,
-        operator_class,
-        kwargs,
     ):
-        operator = operator_class(task_id="snowflake_check", snowflake_conn_id="snowflake_default", **kwargs)
-        operator.get_db_hook()
-        mock_get_db_hook.assert_called_once()
+        operator = SnowflakeCheckOperator(
+            task_id="snowflake_check",
+            snowflake_conn_id="snowflake_default",
+            sql="Select * from test_table",
+            parameters={"param1": "value1"},
+        )
+        operator.execute({})
+        mock_get_db_hook.assert_has_calls(
+            [call().get_first("Select * from test_table", {"param1": "value1"})]
+        )
+
+
+class TestSnowflakeValueCheckOperator:
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.get_db_hook")
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.check_value")
+    def test_get_db_hook(
+        self,
+        mock_check_value,
+        mock_get_db_hook,
+    ):
+        mock_get_db_hook.return_value.get_first.return_value = ["test_value"]
+
+        operator = SnowflakeValueCheckOperator(
+            task_id="snowflake_check",
+            sql="Select * from test_table",
+            pass_value=95,
+            parameters={"param1": "value1"},
+        )
+        operator.execute({})
+        mock_get_db_hook.assert_has_calls(
+            [call().get_first("Select * from test_table", {"param1": "value1"})]
+        )
+        assert mock_check_value.call_args == call(["test_value"])
+
+
+class TestSnowflakeIntervalCheckOperator:
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator.__init__")
+    def test_get_db_hook(
+        self,
+        mock_snowflake_interval_check_operator,
+    ):
+        SnowflakeIntervalCheckOperator(
+            task_id="snowflake_check", table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5}
+        )
+        assert mock_snowflake_interval_check_operator.call_args == mock.call(
+            table="test-table-id",
+            metrics_thresholds={"COUNT(*)": 1.5},
+            date_filter_column="ds",
+            days_back=-7,
+            conn_id="snowflake_default",
+            task_id="snowflake_check",
+            default_args={},
+        )
 
 
 @pytest.mark.parametrize(