Skip to content
Merged
44 changes: 29 additions & 15 deletions providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(
*args,
**kwargs,
) -> None:
self.conn: SFTPClient | None = None

# TODO: remove support for ssh_hook when it is removed from SFTPOperator
if kwargs.get("ssh_hook") is not None:
warnings.warn(
Expand All @@ -110,8 +112,20 @@ def __init__(

super().__init__(*args, **kwargs)

def get_conn(self) -> SFTPClient:  # type: ignore[override]
    """
    Return the SFTP client for this hook, opening it on first use.

    The client is cached on ``self.conn``: while that attribute is set,
    the same object is returned and no new session is opened.

    :return: the cached SFTP client
    """
    # Reuse the client from an earlier call if there is one.
    if self.conn is not None:
        return self.conn

    # First use: obtain the parent hook's connection and open an SFTP
    # session on top of it.
    parent_conn = super().get_conn()
    self.conn = parent_conn.open_sftp()
    return self.conn

def close_conn(self) -> None:
    """
    Close the cached SFTP connection, if one is open.

    Does nothing when ``self.conn`` is ``None``. The cached reference is
    cleared even when ``close()`` raises, so a failed close never leaves a
    possibly broken client behind for ``get_conn`` to return again. The
    exception itself still propagates to the caller.
    """
    if self.conn is None:
        return
    try:
        self.conn.close()
    finally:
        # Always drop the reference so the next get_conn() opens a fresh client.
        self.conn = None

@contextmanager
def get_conn(self) -> Generator[SFTPClient, None, None]:
def get_managed_conn(self) -> Generator[SFTPClient, None, None]:
"""Context manager that closes the connection after use."""
with closing(super().get_conn()) as conn:
with closing(conn.open_sftp()) as sftp:
Expand All @@ -126,7 +140,7 @@ def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]

:param path: full path to the remote directory
"""
with self.get_conn() as conn: # type: SFTPClient
with self.get_managed_conn() as conn: # type: SFTPClient
flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename)
files = {}
for f in flist:
Expand All @@ -144,7 +158,7 @@ def list_directory(self, path: str) -> list[str]:

:param path: full path to the remote directory to list
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
return sorted(conn.listdir(path))

def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]:
Expand All @@ -153,7 +167,7 @@ def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]:

:param path: full path to the remote directory to list
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
return [file for file in conn.listdir_attr(path)]

def mkdir(self, path: str, mode: int = 0o777) -> None:
Expand All @@ -166,7 +180,7 @@ def mkdir(self, path: str, mode: int = 0o777) -> None:
:param path: full path to the remote directory to create
:param mode: int permissions of octal mode for directory
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
conn.mkdir(path, mode=mode)

def isdir(self, path: str) -> bool:
Expand All @@ -175,7 +189,7 @@ def isdir(self, path: str) -> bool:

:param path: full path to the remote directory to check
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
try:
return stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore
except OSError:
Expand All @@ -187,7 +201,7 @@ def isfile(self, path: str) -> bool:

:param path: full path to the remote file to check
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
try:
return stat.S_ISREG(conn.stat(path).st_mode) # type: ignore
except OSError:
Expand Down Expand Up @@ -216,7 +230,7 @@ def create_directory(self, path: str, mode: int = 0o777) -> None:
self.create_directory(dirname, mode)
if basename:
self.log.info("Creating %s", path)
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
conn.mkdir(path, mode=mode)

def delete_directory(self, path: str, include_files: bool = False) -> None:
Expand All @@ -232,7 +246,7 @@ def delete_directory(self, path: str, include_files: bool = False) -> None:
files, dirs, _ = self.get_tree_map(path)
dirs = dirs[::-1] # reverse the order for deleting deepest directories first

with self.get_conn() as conn:
with self.get_managed_conn() as conn:
for file_path in files:
conn.remove(file_path)
for dir_path in dirs:
Expand All @@ -250,7 +264,7 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b
:param local_full_path: full path to the local file or a file-like buffer
:param prefetch: controls whether prefetch is performed (default: True)
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
if isinstance(local_full_path, BytesIO):
conn.getfo(remote_full_path, local_full_path, prefetch=prefetch)
else:
Expand All @@ -266,7 +280,7 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool
:param remote_full_path: full path to the remote file
:param local_full_path: full path to the local file or a file-like buffer
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
if isinstance(local_full_path, BytesIO):
conn.putfo(local_full_path, remote_full_path, confirm=confirm)
else:
Expand All @@ -278,7 +292,7 @@ def delete_file(self, path: str) -> None:

:param path: full path to the remote file
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
conn.remove(path)

def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
Expand Down Expand Up @@ -332,7 +346,7 @@ def get_mod_time(self, path: str) -> str:

:param path: full path to the remote file
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
ftp_mdtm = conn.stat(path).st_mtime
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore

Expand All @@ -342,7 +356,7 @@ def path_exists(self, path: str) -> bool:

:param path: full path to the remote file or directory
"""
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
try:
conn.stat(path)
except OSError:
Expand Down Expand Up @@ -441,7 +455,7 @@ def append_matching_path_callback(list_: list[str]) -> Callable:
def test_connection(self) -> tuple[bool, str]:
"""Test the SFTP connection by calling path with directory."""
try:
with self.get_conn() as conn:
with self.get_managed_conn() as conn:
conn.normalize(".")
return True, "Connection successfully tested"
except Exception as e:
Expand Down
1 change: 0 additions & 1 deletion providers/sftp/src/airflow/providers/sftp/sensors/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def poke(self, context: Context) -> PokeReturnValue | bool:
else:
files_found.append(actual_file_to_check)

self.hook.close_conn()
if not len(files_found):
return False

Expand Down
21 changes: 16 additions & 5 deletions providers/sftp/tests/unit/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,18 @@ def setup_test_cases(self, tmp_path_factory):
self.update_connection(self.old_login)

def test_get_conn(self):
with self.hook.get_conn() as conn:
output = self.hook.get_conn()
assert isinstance(output, paramiko.SFTPClient)
assert self.hook.conn is not None

def test_close_conn(self):
    """``close_conn`` clears the client cached on the hook."""
    hook = self.hook

    # Open a connection and store it on the hook.
    hook.conn = hook.get_conn()
    assert hook.conn is not None

    # Closing must reset the cached client.
    hook.close_conn()
    assert hook.conn is None

def test_get_managed_conn(self):
    """``get_managed_conn`` yields a paramiko SFTP client inside the ``with`` block."""
    with self.hook.get_managed_conn() as sftp_client:
        assert isinstance(sftp_client, paramiko.SFTPClient)

@patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn")
Expand All @@ -113,7 +124,7 @@ def test_get_close_conn(self, mock_get_conn):
mock_ssh_client.open_sftp.return_value = mock_sftp_client
mock_get_conn.return_value = mock_ssh_client

with SFTPHook().get_conn() as conn:
with SFTPHook().get_managed_conn() as conn:
assert conn == mock_sftp_client

mock_sftp_client.close.assert_called_once()
Expand All @@ -140,7 +151,7 @@ def test_mkdir(self):
assert new_dir_name in output
# test the directory has default permissions to 777 - umask
umask = 0o022
with self.hook.get_conn() as conn:
with self.hook.get_managed_conn() as conn:
output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name))
assert output.st_mode & 0o777 == 0o777 - umask

Expand All @@ -151,7 +162,7 @@ def test_create_and_delete_directory(self):
assert new_dir_name in output
# test the directory has default permissions to 777
umask = 0o022
with self.hook.get_conn() as conn:
with self.hook.get_managed_conn() as conn:
output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name))
assert output.st_mode & 0o777 == 0o777 - umask
# test directory already exists for code coverage, should not raise an exception
Expand Down Expand Up @@ -531,7 +542,7 @@ def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client)
host_proxy_cmd=host_proxy_cmd,
)

with hook.get_conn():
with hook.get_managed_conn():
mock_proxy_command.assert_called_once_with(host_proxy_cmd)
mock_ssh_client.return_value.connect.assert_called_once_with(
hostname="example.com",
Expand Down
12 changes: 12 additions & 0 deletions providers/sftp/tests/unit/sftp/sensors/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_file_present(self, sftp_hook_mock):
context = {"ds": "1970-01-01"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand All @@ -50,6 +51,7 @@ def test_file_absent(self, sftp_hook_mock):
context = {"ds": "1970-01-01"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert not output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand Down Expand Up @@ -77,6 +79,7 @@ def test_file_new_enough(self, sftp_hook_mock):
context = {"ds": "1970-01-00"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand All @@ -91,6 +94,7 @@ def test_file_not_new_enough(self, sftp_hook_mock):
context = {"ds": "1970-01-00"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert not output

@pytest.mark.parametrize(
Expand All @@ -116,6 +120,7 @@ def test_multiple_datetime_format_in_newer_than(self, sftp_hook_mock, newer_than
context = {"ds": "1970-01-00"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert not output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand All @@ -126,6 +131,7 @@ def test_file_present_with_pattern(self, sftp_hook_mock):
context = {"ds": "1970-01-01"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/text_file.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand All @@ -135,6 +141,7 @@ def test_file_not_present_with_pattern(self, sftp_hook_mock):
sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/", file_pattern="*.txt")
context = {"ds": "1970-01-01"}
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert not output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand All @@ -149,6 +156,7 @@ def test_multiple_files_present_with_pattern(self, sftp_hook_mock):
output = sftp_sensor.poke(context)
get_mod_time = sftp_hook_mock.return_value.get_mod_time
expected_calls = [call("/path/to/file/text_file.txt"), call("/path/to/file/another_text_file.txt")]
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert get_mod_time.mock_calls == expected_calls
assert output

Expand Down Expand Up @@ -176,6 +184,7 @@ def test_multiple_files_present_with_pattern_and_newer_than(self, sftp_hook_mock
sftp_hook_mock.return_value.get_mod_time.assert_has_calls(
[mock.call("/path/to/file/text_file1.txt"), mock.call("/path/to/file/text_file2.txt")]
)
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert output

@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
Expand Down Expand Up @@ -206,6 +215,7 @@ def test_multiple_old_files_present_with_pattern_and_newer_than(self, sftp_hook_
mock.call("/path/to/file/text_file3.txt"),
]
)
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert not output

@pytest.mark.parametrize(
Expand All @@ -231,6 +241,7 @@ def test_file_path_present_with_callback(self, sftp_hook_mock, op_args, op_kwarg
output = sftp_sensor.poke(context)

sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
sftp_hook_mock.return_value.close_conn.assert_not_called()
sample_callable.assert_called_once_with(*op_args, **op_kwargs)
assert isinstance(output, PokeReturnValue)
assert output.is_done
Expand Down Expand Up @@ -267,6 +278,7 @@ def test_file_pattern_present_with_callback(self, sftp_hook_mock, op_args, op_kw
output = sftp_sensor.poke(context)

sample_callable.assert_called_once_with(*op_args, **op_kwargs)
sftp_hook_mock.return_value.close_conn.assert_not_called()
assert isinstance(output, PokeReturnValue)
assert output.is_done
assert output.xcom_value == {
Expand Down