diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 3714abc82503f..ddf1a8b5ce6d3 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -161,7 +161,12 @@ def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook: :param hook_params: hook parameters :return: default hook for this connection """ + hook_params = hook_params or {} connection = BaseHook.get_connection(conn_id) + conn_params = connection.extra_dejson + for conn_param in conn_params: + if conn_param not in hook_params: + hook_params[conn_param] = conn_params[conn_param] return connection.get_hook(hook_params=hook_params) @cached_property diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index c27e94773673f..12166d88f9f99 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -113,6 +113,11 @@ def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_get_ class TestSQLExecuteQueryOperator: + def setup_method(self): + self.task_id = "test_task" + self.conn_id = "sql_default" + self._operator = SQLExecuteQueryOperator(task_id=self.task_id, conn_id=self.conn_id, sql="sql") + def _construct_operator(self, sql, **kwargs): dag = DAG("test_dag", schedule=None, start_date=datetime.datetime(2017, 1, 1)) return SQLExecuteQueryOperator( @@ -190,6 +195,23 @@ def test_output_processor(self, mock_get_db_hook): assert descriptions == ("id", "name") assert result == [(1, "Alice"), (2, "Bob")] + @skip_if_force_lowest_dependencies_marker + def test_sql_operator_extra_dejson_fields_to_hook_params(self): + with mock.patch( + "airflow.providers.common.sql.operators.sql.BaseHook.get_connection", + return_value=Connection(conn_id="sql_default", conn_type="postgres"), + ) as mock_get_conn: + mock_get_conn.return_value = Connection( + conn_id="google_cloud_bigquery_default", + conn_type="gcpbigquery", + extra={"use_legacy_sql": False, "priority": "INTERACTIVE"}, + ) + self._operator.hook_params = {"use_legacy_sql": True, "location": "us-east1"} + assert self._operator._hook.conn_type == "gcpbigquery" + assert self._operator._hook.use_legacy_sql is True + assert self._operator._hook.location == "us-east1" + assert self._operator._hook.priority == "INTERACTIVE" + class TestColumnCheckOperator: valid_column_mapping = { diff --git a/providers/redis/src/airflow/providers/redis/hooks/redis.py b/providers/redis/src/airflow/providers/redis/hooks/redis.py index 2f956a5cb137e..97ca40ea9bb5c 100644 --- a/providers/redis/src/airflow/providers/redis/hooks/redis.py +++ b/providers/redis/src/airflow/providers/redis/hooks/redis.py @@ -43,7 +43,7 @@ class RedisHook(BaseHook): conn_type = "redis" hook_name = "Redis" - def __init__(self, redis_conn_id: str = default_conn_name) -> None: + def __init__(self, redis_conn_id: str = default_conn_name, **kwargs) -> None: """ Prepare hook to connect to a Redis database. @@ -53,11 +53,11 @@ def __init__(self, redis_conn_id: str = default_conn_name) -> None: super().__init__() self.redis_conn_id = redis_conn_id self.redis = None - self.host = None - self.port = None - self.username = None - self.password = None - self.db = None + self.host = kwargs.get("host", None) + self.port = kwargs.get("port", None) + self.username = kwargs.get("username", None) + self.password = kwargs.get("password", None) + self.db = kwargs.get("db", None) def get_conn(self): """Return a Redis connection."""