diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 95d602bc9b46c..0b884d97f0c26 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -863,6 +863,7 @@ def __init__( tolerance: Any = None, conn_id: str | None = None, database: str | None = None, + parameters: Iterable | Mapping[str, Any] | None = None, **kwargs, ): super().__init__(conn_id=conn_id, database=database, **kwargs) @@ -871,6 +872,7 @@ def __init__( tol = _convert_to_float_if_possible(tolerance) self.tol = tol if isinstance(tol, float) else None self.has_tolerance = self.tol is not None + self.parameters = parameters def check_value(self, records): if not records: @@ -903,7 +905,7 @@ def check_value(self, records): def execute(self, context: Context): self.log.info("Executing SQL check: %s", self.sql) - records = self.get_db_hook().get_first(self.sql) + records = self.get_db_hook().get_first(self.sql, self.parameters) self.check_value(records) def _to_float(self, records): diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 23065cd3a717b..dd404d5f59e39 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -51,7 +51,7 @@ ) from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.utils import timezone +from airflow.utils import timezone # type: ignore[attr-defined] from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -844,7 +844,7 @@ def test_execute_pass(self, mock_get_db_hook): operator.execute(None) - mock_hook.get_first.assert_called_once_with(sql) + mock_hook.get_first.assert_called_once_with(sql, None) @mock.patch.object(SQLValueCheckOperator, "get_db_hook") def test_execute_fail(self, mock_get_db_hook): diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 48747381602eb..9834912720447 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -435,7 +435,7 @@ def _make_api_call_with_retries( :param url: The URL for the API endpoint. :param headers: The headers to include in the API call. :param params: (Optional) The query parameters to include in the API call. - :param data: (Optional) The data to include in the API call. + :param json: (Optional) The data to include in the API call. :return: The response object from the API call. """ with requests.Session() as session: diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index 4f214c681fb30..84f6773b2bfef 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -76,8 +76,6 @@ class SnowflakeCheckOperator(SQLCheckOperator): Template references are recognized by str ending in '.sql' :param snowflake_conn_id: Reference to :ref:`Snowflake connection id` - :param autocommit: if True, each command is automatically committed. - (default value: True) :param parameters: (optional) the parameters to render the SQL query with. :param warehouse: name of warehouse (will overwrite any warehouse defined in the connection's extra JSON) @@ -109,8 +107,6 @@ def __init__( sql: str, snowflake_conn_id: str = "snowflake_default", parameters: Iterable | Mapping[str, Any] | None = None, - autocommit: bool = True, - do_xcom_push: bool = True, warehouse: str | None = None, database: str | None = None, role: str | None = None, @@ -179,8 +175,6 @@ def __init__( tolerance: Any = None, snowflake_conn_id: str = "snowflake_default", parameters: Iterable | Mapping[str, Any] | None = None, - autocommit: bool = True, - do_xcom_push: bool = True, warehouse: str | None = None, database: str | None = None, role: str | None = None, @@ -202,7 +196,12 @@ def __init__( **hook_params, } super().__init__( - sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs + sql=sql, + pass_value=pass_value, + tolerance=tolerance, + conn_id=snowflake_conn_id, + parameters=parameters, + **kwargs, ) self.query_ids: list[str] = [] @@ -259,9 +258,6 @@ def __init__( date_filter_column: str = "ds", days_back: SupportsAbs[int] = -7, snowflake_conn_id: str = "snowflake_default", - parameters: Iterable | Mapping[str, Any] | None = None, - autocommit: bool = True, - do_xcom_push: bool = True, warehouse: str | None = None, database: str | None = None, role: str | None = None, diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py index 721b72e57811e..c4711ccc26d41 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py @@ -18,11 +18,13 @@ from __future__ import annotations from unittest import mock +from unittest.mock import call import pendulum import pytest from airflow.exceptions import AirflowException, TaskDeferred +from airflow.models import Connection from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance @@ -34,7 +36,7 @@ SnowflakeValueCheckOperator, ) from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger -from airflow.utils import timezone +from airflow.utils import timezone # type: ignore[attr-defined] from airflow.utils.types import DagRunType from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS @@ -107,25 +109,80 @@ def test_overwrite_params(self, mock_base_op): ) -@pytest.mark.parametrize( - "operator_class, kwargs", - [ - (SnowflakeCheckOperator, dict(sql="Select * from test_table")), - (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)), - (SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})), - ], -) -class TestSnowflakeCheckOperators: - @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") +@pytest.fixture(autouse=True) +def setup_connections(create_connection_without_db): + create_connection_without_db( + Connection( + conn_id="snowflake_default", + conn_type="snowflake", + host="test_host", + port=443, + schema="test_schema", + login="test_user", + password="test_password", + ) + ) + + +class TestSnowflakeCheckOperator: + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckOperator.get_db_hook") def test_get_db_hook( self, mock_get_db_hook, - operator_class, - kwargs, ): - operator = operator_class(task_id="snowflake_check", snowflake_conn_id="snowflake_default", **kwargs) - operator.get_db_hook() - mock_get_db_hook.assert_called_once() + operator = SnowflakeCheckOperator( + task_id="snowflake_check", + snowflake_conn_id="snowflake_default", + sql="Select * from test_table", + parameters={"param1": "value1"}, + ) + operator.execute({}) + mock_get_db_hook.assert_has_calls( + [call().get_first("Select * from test_table", {"param1": "value1"})] + ) + + +class TestSnowflakeValueCheckOperator: + @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.get_db_hook") + @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.check_value") + def test_get_db_hook( + self, + mock_check_value, + mock_get_db_hook, + ): + mock_get_db_hook.return_value.get_first.return_value = ["test_value"] + + operator = SnowflakeValueCheckOperator( + task_id="snowflake_check", + sql="Select * from test_table", + pass_value=95, + parameters={"param1": "value1"}, + ) + operator.execute({}) + mock_get_db_hook.assert_has_calls( + [call().get_first("Select * from test_table", {"param1": "value1"})] + ) + assert mock_check_value.call_args == call(["test_value"]) + + +class TestSnowflakeIntervalCheckOperator: + @mock.patch("airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator.__init__") + def test_get_db_hook( + self, + mock_snowflake_interval_check_operator, + ): + SnowflakeIntervalCheckOperator( + task_id="snowflake_check", table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5} + ) + assert mock_snowflake_interval_check_operator.call_args == mock.call( + table="test-table-id", + metrics_thresholds={"COUNT(*)": 1.5}, + date_filter_column="ds", + days_back=-7, + conn_id="snowflake_default", + task_id="snowflake_check", + default_args={}, + ) @pytest.mark.parametrize(