Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook:
:param hook_params: hook parameters
:return: default hook for this connection
"""
hook_params = hook_params or {}
connection = BaseHook.get_connection(conn_id)
conn_params = connection.extra_dejson
for conn_param in conn_params:
if conn_param not in hook_params:
hook_params[conn_param] = conn_params[conn_param]
return connection.get_hook(hook_params=hook_params)

@cached_property
Expand Down
22 changes: 22 additions & 0 deletions providers/common/sql/tests/unit/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_get_


class TestSQLExecuteQueryOperator:
def setup_method(self):
self.task_id = "test_task"
self.conn_id = "sql_default"
self._operator = SQLExecuteQueryOperator(task_id=self.task_id, conn_id=self.conn_id, sql="sql")

def _construct_operator(self, sql, **kwargs):
dag = DAG("test_dag", schedule=None, start_date=datetime.datetime(2017, 1, 1))
return SQLExecuteQueryOperator(
Expand Down Expand Up @@ -190,6 +195,23 @@ def test_output_processor(self, mock_get_db_hook):
assert descriptions == ("id", "name")
assert result == [(1, "Alice"), (2, "Bob")]

@skip_if_force_lowest_dependencies_marker
def test_sql_operator_extra_dejson_fields_to_hook_params(self):
with mock.patch(
"airflow.providers.common.sql.operators.sql.BaseHook.get_connection",
return_value=Connection(conn_id="sql_default", conn_type="postgres"),
) as mock_get_conn:
mock_get_conn.return_value = Connection(
conn_id="google_cloud_bigquery_default",
conn_type="gcpbigquery",
extra={"use_legacy_sql": False, "priority": "INTERACTIVE"},
)
self._operator.hook_params = {"use_legacy_sql": True, "location": "us-east1"}
assert self._operator._hook.conn_type == "gcpbigquery"
assert self._operator._hook.use_legacy_sql is True
assert self._operator._hook.location == "us-east1"
assert self._operator._hook.priority == "INTERACTIVE"


class TestColumnCheckOperator:
valid_column_mapping = {
Expand Down
12 changes: 6 additions & 6 deletions providers/redis/src/airflow/providers/redis/hooks/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class RedisHook(BaseHook):
conn_type = "redis"
hook_name = "Redis"

def __init__(self, redis_conn_id: str = default_conn_name) -> None:
def __init__(self, redis_conn_id: str = default_conn_name, **kwargs) -> None:
"""
Prepare hook to connect to a Redis database.

Expand All @@ -53,11 +53,11 @@ def __init__(self, redis_conn_id: str = default_conn_name) -> None:
super().__init__()
self.redis_conn_id = redis_conn_id
self.redis = None
self.host = None
self.port = None
self.username = None
self.password = None
self.db = None
self.host = kwargs.get("host", None)
self.port = kwargs.get("port", None)
self.username = kwargs.get("username", None)
self.password = kwargs.get("password", None)
self.db = kwargs.get("db", None)

def get_conn(self):
"""Return a Redis connection."""
Expand Down