Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import asyncssh
from asgiref.sync import sync_to_async
from paramiko.config import SSH_PORT

from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -703,10 +704,10 @@ class SFTPHookAsync(BaseHook):
def __init__( # nosec: B107
self,
sftp_conn_id: str = default_conn_name,
host: str = "",
port: int = 22,
username: str = "",
password: str = "",
host: str | None = None,
port: int | None = None,
username: str | None = None,
password: str | None = None,
known_hosts: str = default_known_hosts,
key_file: str = "",
passphrase: str = "",
Expand Down Expand Up @@ -762,11 +763,19 @@ async def _get_conn(self) -> asyncssh.SSHClientConnection:
if conn.extra is not None:
self._parse_extras(conn) # type: ignore[arg-type]

conn_config: dict[str, Any] = {
"host": conn.host,
"port": conn.port,
"username": conn.login,
"password": conn.password,
def _get_value(self_val, conn_val, default=None):
"""Return the first non-None value among self, conn, default."""
if self_val is not None:
return self_val
if conn_val is not None:
return conn_val
return default

conn_config = {
"host": _get_value(self.host, conn.host),
"port": _get_value(self.port, conn.port, SSH_PORT),
"username": _get_value(self.username, conn.login),
"password": _get_value(self.password, conn.password),
}
if self.key_file:
conn_config.update(client_keys=self.key_file)
Expand Down
61 changes: 61 additions & 0 deletions providers/sftp/tests/unit/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,67 @@ async def test_connection_private(self, mock_get_connection, mock_import_private

mock_connect.assert_called_with(**expected_connection_details)

@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
async def test_connection_port_default_to_22(self, mock_get_connection, mock_connect):
from unittest.mock import Mock, call

mock_get_connection.return_value = Mock(
host="localhost",
port=None,
login="username",
password="password",
extra="{}",
extra_dejson={},
)

hook = SFTPHookAsync()
await hook._get_conn()
assert mock_connect.mock_calls == [
call(
host="localhost",
# Even if the port is not specified in conn_config, it should still default to 22.
# This behavior is consistent with STPHook.
port=22,
username="username",
password="password",
known_hosts=None,
),
]

@pytest.mark.asyncio
@patch("asyncssh.connect", new_callable=AsyncMock)
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection")
async def test_init_argument_not_ignored(self, mock_get_connection, mock_connect):
from unittest.mock import Mock, call

mock_get_connection.return_value = Mock(
host="localhost",
port=None,
login="username",
password="password",
extra="{}",
extra_dejson={},
)

hook = SFTPHookAsync(
host="localhost-from-init",
port=25,
username="username-from-init",
password="password-from-init",
)
await hook._get_conn()
assert mock_connect.mock_calls == [
call(
host="localhost-from-init",
port=25,
username="username-from-init",
password="password-from-init",
known_hosts=None,
),
]

@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
async def test_list_directory_path_does_not_exist(self, mock_hook_get_conn):
Expand Down