From 105980e11505f20bf84cf0b1d3dc370473bb9c62 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 28 Feb 2025 16:05:21 +0100 Subject: [PATCH 1/5] fix: SFTPHook.get_conn method should return SFTPClient connection instead of context manager for backward compatibility --- .../src/airflow/providers/sftp/hooks/sftp.py | 44 ++++++++++++------- .../sftp/tests/unit/sftp/hooks/test_sftp.py | 21 ++++++--- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index 367a8f903973b..3e7ee5596488b 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -87,6 +87,8 @@ def __init__( *args, **kwargs, ) -> None: + self.conn: SFTPClient | None = None + # TODO: remove support for ssh_hook when it is removed from SFTPOperator if kwargs.get("ssh_hook") is not None: warnings.warn( @@ -110,8 +112,20 @@ def __init__( super().__init__(*args, **kwargs) + def get_conn(self) -> SFTPClient: # type: ignore[override] + """Open an SFTP connection to the remote host.""" + if self.conn is None: + self.conn = super().get_conn().open_sftp() + return self.conn + + def close_conn(self) -> None: + """Close the SFTP connection.""" + if self.conn is not None: + self.conn.close() + self.conn = None + @contextmanager - def get_conn(self) -> Generator[SFTPClient, None, None]: + def get_sftp_conn(self) -> Generator[SFTPClient, None, None]: """Context manager that closes the connection after use.""" with closing(super().get_conn()) as conn: with closing(conn.open_sftp()) as sftp: @@ -126,7 +140,7 @@ def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None] :param path: full path to the remote directory """ - with self.get_conn() as conn: # type: SFTPClient + with self.get_sftp_conn() as conn: # type: SFTPClient flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) files = {} for f in flist: @@ -144,7 +158,7 @@ def list_directory(self, path: str) -> list[str]: :param path: full path to the remote directory to list """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: return sorted(conn.listdir(path)) def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: @@ -153,7 +167,7 @@ def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: :param path: full path to the remote directory to list """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: return [file for file in conn.listdir_attr(path)] def mkdir(self, path: str, mode: int = 0o777) -> None: @@ -166,7 +180,7 @@ def mkdir(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: conn.mkdir(path, mode=mode) def isdir(self, path: str) -> bool: @@ -175,7 +189,7 @@ def isdir(self, path: str) -> bool: :param path: full path to the remote directory to check """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: try: return stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore except OSError: @@ -187,7 +201,7 @@ def isfile(self, path: str) -> bool: :param path: full path to the remote file to check """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: try: return stat.S_ISREG(conn.stat(path).st_mode) # type: ignore except OSError: @@ -216,7 +230,7 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: self.create_directory(dirname, mode) if basename: self.log.info("Creating %s", path) - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: conn.mkdir(path, mode=mode) def delete_directory(self, path: str, include_files: bool = False) -> None: @@ -232,7 +246,7 @@ def delete_directory(self, path: str, include_files: bool = False) -> None: files, dirs, _ = self.get_tree_map(path) dirs = dirs[::-1] # reverse the order for deleting deepest directories first - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: for file_path in files: conn.remove(file_path) for dir_path in dirs: @@ -250,7 +264,7 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b :param local_full_path: full path to the local file or a file-like buffer :param prefetch: controls whether prefetch is performed (default: True) """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: if isinstance(local_full_path, BytesIO): conn.getfo(remote_full_path, local_full_path, prefetch=prefetch) else: @@ -266,7 +280,7 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool :param remote_full_path: full path to the remote file :param local_full_path: full path to the local file or a file-like buffer """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: if isinstance(local_full_path, BytesIO): conn.putfo(local_full_path, remote_full_path, confirm=confirm) else: @@ -278,7 +292,7 @@ def delete_file(self, path: str) -> None: :param path: full path to the remote file """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: conn.remove(path) def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: @@ -332,7 +346,7 @@ def get_mod_time(self, path: str) -> str: :param path: full path to the remote file """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: ftp_mdtm = conn.stat(path).st_mtime return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore @@ -342,7 +356,7 @@ def path_exists(self, path: str) -> bool: :param path: full path to the remote file or directory """ - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: try: conn.stat(path) except OSError: @@ -441,7 +455,7 @@ def append_matching_path_callback(list_: list[str]) -> Callable: def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory.""" try: - with self.get_conn() as conn: + with self.get_sftp_conn() as conn: conn.normalize(".") return True, "Connection successfully tested" except Exception as e: diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index afed97ba574d5..0e29f2ad157b8 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -103,7 +103,18 @@ def setup_test_cases(self, tmp_path_factory): self.update_connection(self.old_login) def test_get_conn(self): - with self.hook.get_conn() as conn: + output = self.hook.get_conn() + assert isinstance(output, paramiko.SFTPClient) + assert self.hook.conn is not None + + def test_close_conn(self): + self.hook.conn = self.hook.get_conn() + assert self.hook.conn is not None + self.hook.close_conn() + assert self.hook.conn is None + + def test_get_sftp_conn(self): + with self.hook.get_sftp_conn() as conn: assert isinstance(conn, paramiko.SFTPClient) @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") @@ -113,7 +124,7 @@ def test_get_close_conn(self, mock_get_conn): mock_ssh_client.open_sftp.return_value = mock_sftp_client mock_get_conn.return_value = mock_ssh_client - with SFTPHook().get_conn() as conn: + with SFTPHook().get_sftp_conn() as conn: assert conn == mock_sftp_client mock_sftp_client.close.assert_called_once() @@ -140,7 +151,7 @@ def test_mkdir(self): assert new_dir_name in output # test the directory has default permissions to 777 - umask umask = 0o022 - with self.hook.get_conn() as conn: + with self.hook.get_sftp_conn() as conn: output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) assert output.st_mode & 0o777 == 0o777 - umask @@ -151,7 +162,7 @@ def test_create_and_delete_directory(self): assert new_dir_name in output # test the directory has default permissions to 777 umask = 0o022 - with self.hook.get_conn() as conn: + with self.hook.get_sftp_conn() as conn: output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) assert output.st_mode & 0o777 == 0o777 - umask # test directory already exists for code coverage, should not raise an exception @@ -531,7 +542,7 @@ def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client) host_proxy_cmd=host_proxy_cmd, ) - with hook.get_conn(): + with hook.get_sftp_conn(): mock_proxy_command.assert_called_once_with(host_proxy_cmd) mock_ssh_client.return_value.connect.assert_called_once_with( hostname="example.com", From 22410b555334ecb08ef3d3d1b53084e4f8ca78db Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 28 Feb 2025 16:06:07 +0100 Subject: [PATCH 2/5] fix: SFTPSensor doesn't need to call the close_conn method as the delegated method use the context manager which take care of closing the connection --- providers/sftp/src/airflow/providers/sftp/sensors/sftp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/providers/sftp/src/airflow/providers/sftp/sensors/sftp.py b/providers/sftp/src/airflow/providers/sftp/sensors/sftp.py index fa6a5219a7e53..ca305098914ad 100644 --- a/providers/sftp/src/airflow/providers/sftp/sensors/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/sensors/sftp.py @@ -129,7 +129,6 @@ def poke(self, context: Context) -> PokeReturnValue | bool: else: files_found.append(actual_file_to_check) - self.hook.close_conn() if not len(files_found): return False From 78cb0730757f000b5181f9098cc3af349288aad7 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 28 Feb 2025 16:14:13 +0100 Subject: [PATCH 3/5] refactor: Make sure SFTPSensor doesn't need to call the close_conn method anymore and test this behaviour --- providers/sftp/tests/unit/sftp/sensors/test_sftp.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/providers/sftp/tests/unit/sftp/sensors/test_sftp.py b/providers/sftp/tests/unit/sftp/sensors/test_sftp.py index 4d1be081af16c..8366a8817c6a2 100644 --- a/providers/sftp/tests/unit/sftp/sensors/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/sensors/test_sftp.py @@ -41,6 +41,7 @@ def test_file_present(self, sftp_hook_mock): context = {"ds": "1970-01-01"} output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() assert output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -50,6 +51,7 @@ def test_file_absent(self, sftp_hook_mock): context = {"ds": "1970-01-01"} output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() assert not output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -77,6 +79,7 @@ def test_file_new_enough(self, sftp_hook_mock): context = {"ds": "1970-01-00"} output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() assert output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -91,6 +94,7 @@ def test_file_not_new_enough(self, sftp_hook_mock): context = {"ds": "1970-01-00"} output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() assert not output @pytest.mark.parametrize( @@ -116,6 +120,7 @@ def test_multiple_datetime_format_in_newer_than(self, sftp_hook_mock, newer_than context = {"ds": "1970-01-00"} output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() assert not output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -126,6 +131,7 @@ def test_file_present_with_pattern(self, sftp_hook_mock): context = {"ds": "1970-01-01"} output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/text_file.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() assert output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -135,6 +141,7 @@ def test_file_not_present_with_pattern(self, sftp_hook_mock): sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/", file_pattern="*.txt") context = {"ds": "1970-01-01"} output = sftp_sensor.poke(context) + sftp_hook_mock.return_value.close_conn.assert_not_called() assert not output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -149,6 +156,7 @@ def test_multiple_files_present_with_pattern(self, sftp_hook_mock): output = sftp_sensor.poke(context) get_mod_time = sftp_hook_mock.return_value.get_mod_time expected_calls = [call("/path/to/file/text_file.txt"), call("/path/to/file/another_text_file.txt")] + sftp_hook_mock.return_value.close_conn.assert_not_called() assert get_mod_time.mock_calls == expected_calls assert output @@ -176,6 +184,7 @@ def test_multiple_files_present_with_pattern_and_newer_than(self, sftp_hook_mock sftp_hook_mock.return_value.get_mod_time.assert_has_calls( [mock.call("/path/to/file/text_file1.txt"), mock.call("/path/to/file/text_file2.txt")] ) + sftp_hook_mock.return_value.close_conn.assert_not_called() assert output @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") @@ -206,6 +215,7 @@ def test_multiple_old_files_present_with_pattern_and_newer_than(self, sftp_hook_ mock.call("/path/to/file/text_file3.txt"), ] ) + sftp_hook_mock.return_value.close_conn.assert_not_called() assert not output @pytest.mark.parametrize( @@ -231,6 +241,7 @@ def test_file_path_present_with_callback(self, sftp_hook_mock, op_args, op_kwarg output = sftp_sensor.poke(context) sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") + sftp_hook_mock.return_value.close_conn.assert_not_called() sample_callable.assert_called_once_with(*op_args, **op_kwargs) assert isinstance(output, PokeReturnValue) assert output.is_done @@ -267,6 +278,7 @@ def test_file_pattern_present_with_callback(self, sftp_hook_mock, op_args, op_kw output = sftp_sensor.poke(context) sample_callable.assert_called_once_with(*op_args, **op_kwargs) + sftp_hook_mock.return_value.close_conn.assert_not_called() assert isinstance(output, PokeReturnValue) assert output.is_done assert output.xcom_value == { From b3464732944e3992a2cbfb3258782c9697cb88a5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 28 Feb 2025 18:08:44 +0100 Subject: [PATCH 4/5] refactor: Renamed get_sftp_conn method of SFTPHook to use_conn --- .../src/airflow/providers/sftp/hooks/sftp.py | 30 +++++++++---------- .../sftp/tests/unit/sftp/hooks/test_sftp.py | 12 ++++---- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index 3e7ee5596488b..118c5440d7095 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -125,7 +125,7 @@ def close_conn(self) -> None: self.conn = None @contextmanager - def get_sftp_conn(self) -> Generator[SFTPClient, None, None]: + def use_conn(self) -> Generator[SFTPClient, None, None]: """Context manager that closes the connection after use.""" with closing(super().get_conn()) as conn: with closing(conn.open_sftp()) as sftp: @@ -140,7 +140,7 @@ def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None] :param path: full path to the remote directory """ - with self.get_sftp_conn() as conn: # type: SFTPClient + with self.use_conn() as conn: # type: SFTPClient flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) files = {} for f in flist: @@ -158,7 +158,7 @@ def list_directory(self, path: str) -> list[str]: :param path: full path to the remote directory to list """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: return sorted(conn.listdir(path)) def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: @@ -167,7 +167,7 @@ def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: :param path: full path to the remote directory to list """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: return [file for file in conn.listdir_attr(path)] def mkdir(self, path: str, mode: int = 0o777) -> None: @@ -180,7 +180,7 @@ def mkdir(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: conn.mkdir(path, mode=mode) def isdir(self, path: str) -> bool: @@ -189,7 +189,7 @@ def isdir(self, path: str) -> bool: :param path: full path to the remote directory to check """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: try: return stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore except OSError: @@ -201,7 +201,7 @@ def isfile(self, path: str) -> bool: :param path: full path to the remote file to check """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: try: return stat.S_ISREG(conn.stat(path).st_mode) # type: ignore except OSError: @@ -230,7 +230,7 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: self.create_directory(dirname, mode) if basename: self.log.info("Creating %s", path) - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: conn.mkdir(path, mode=mode) def delete_directory(self, path: str, include_files: bool = False) -> None: @@ -246,7 +246,7 @@ def delete_directory(self, path: str, include_files: bool = False) -> None: files, dirs, _ = self.get_tree_map(path) dirs = dirs[::-1] # reverse the order for deleting deepest directories first - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: for file_path in files: conn.remove(file_path) for dir_path in dirs: @@ -264,7 +264,7 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b :param local_full_path: full path to the local file or a file-like buffer :param prefetch: controls whether prefetch is performed (default: True) """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: if isinstance(local_full_path, BytesIO): conn.getfo(remote_full_path, local_full_path, prefetch=prefetch) else: @@ -280,7 +280,7 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool :param remote_full_path: full path to the remote file :param local_full_path: full path to the local file or a file-like buffer """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: if isinstance(local_full_path, BytesIO): conn.putfo(local_full_path, remote_full_path, confirm=confirm) else: @@ -292,7 +292,7 @@ def delete_file(self, path: str) -> None: :param path: full path to the remote file """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: conn.remove(path) def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: @@ -346,7 +346,7 @@ def get_mod_time(self, path: str) -> str: :param path: full path to the remote file """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: ftp_mdtm = conn.stat(path).st_mtime return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore @@ -356,7 +356,7 @@ def path_exists(self, path: str) -> bool: :param path: full path to the remote file or directory """ - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: try: conn.stat(path) except OSError: @@ -455,7 +455,7 @@ def append_matching_path_callback(list_: list[str]) -> Callable: def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory.""" try: - with self.get_sftp_conn() as conn: + with self.use_conn() as conn: conn.normalize(".") return True, "Connection successfully tested" except Exception as e: diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index 0e29f2ad157b8..b2709b6a4feae 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -113,8 +113,8 @@ def test_close_conn(self): self.hook.close_conn() assert self.hook.conn is None - def test_get_sftp_conn(self): - with self.hook.get_sftp_conn() as conn: + def test_use_conn(self): + with self.hook.use_conn() as conn: assert isinstance(conn, paramiko.SFTPClient) @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") @@ -124,7 +124,7 @@ def test_get_close_conn(self, mock_get_conn): mock_ssh_client.open_sftp.return_value = mock_sftp_client mock_get_conn.return_value = mock_ssh_client - with SFTPHook().get_sftp_conn() as conn: + with SFTPHook().use_conn() as conn: assert conn == mock_sftp_client mock_sftp_client.close.assert_called_once() @@ -151,7 +151,7 @@ def test_mkdir(self): assert new_dir_name in output # test the directory has default permissions to 777 - umask umask = 0o022 - with self.hook.get_sftp_conn() as conn: + with self.hook.use_conn() as conn: output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) assert output.st_mode & 0o777 == 0o777 - umask @@ -162,7 +162,7 @@ def test_create_and_delete_directory(self): assert new_dir_name in output # test the directory has default permissions to 777 umask = 0o022 - with self.hook.get_sftp_conn() as conn: + with self.hook.use_conn() as conn: output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) assert output.st_mode & 0o777 == 0o777 - umask # test directory already exists for code coverage, should not raise an exception @@ -542,7 +542,7 @@ def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client) host_proxy_cmd=host_proxy_cmd, ) - with hook.get_sftp_conn(): + with hook.use_conn(): mock_proxy_command.assert_called_once_with(host_proxy_cmd) mock_ssh_client.return_value.connect.assert_called_once_with( hostname="example.com", From 3babb6027c893d4e004d43caa5b6374dd132e17b Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 28 Feb 2025 20:02:06 +0100 Subject: [PATCH 5/5] refactor: Renamed use_conn method of SFTPHook to get_managed_conn --- .../src/airflow/providers/sftp/hooks/sftp.py | 30 +++++++++---------- .../sftp/tests/unit/sftp/hooks/test_sftp.py | 12 ++++---- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index 118c5440d7095..0b38ffaea9a8d 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -125,7 +125,7 @@ def close_conn(self) -> None: self.conn = None @contextmanager - def use_conn(self) -> Generator[SFTPClient, None, None]: + def get_managed_conn(self) -> Generator[SFTPClient, None, None]: """Context manager that closes the connection after use.""" with closing(super().get_conn()) as conn: with closing(conn.open_sftp()) as sftp: @@ -140,7 +140,7 @@ def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None] :param path: full path to the remote directory """ - with self.use_conn() as conn: # type: SFTPClient + with self.get_managed_conn() as conn: # type: SFTPClient flist = sorted(conn.listdir_attr(path), key=lambda x: x.filename) files = {} for f in flist: @@ -158,7 +158,7 @@ def list_directory(self, path: str) -> list[str]: :param path: full path to the remote directory to list """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: return sorted(conn.listdir(path)) def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: @@ -167,7 +167,7 @@ def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]: :param path: full path to the remote directory to list """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: return [file for file in conn.listdir_attr(path)] def mkdir(self, path: str, mode: int = 0o777) -> None: @@ -180,7 +180,7 @@ def mkdir(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: conn.mkdir(path, mode=mode) def isdir(self, path: str) -> bool: @@ -189,7 +189,7 @@ def isdir(self, path: str) -> bool: :param path: full path to the remote directory to check """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: try: return stat.S_ISDIR(conn.stat(path).st_mode) # type: ignore except OSError: @@ -201,7 +201,7 @@ def isfile(self, path: str) -> bool: :param path: full path to the remote file to check """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: try: return stat.S_ISREG(conn.stat(path).st_mode) # type: ignore except OSError: @@ -230,7 +230,7 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: self.create_directory(dirname, mode) if basename: self.log.info("Creating %s", path) - with self.use_conn() as conn: + with self.get_managed_conn() as conn: conn.mkdir(path, mode=mode) def delete_directory(self, path: str, include_files: bool = False) -> None: @@ -246,7 +246,7 @@ def delete_directory(self, path: str, include_files: bool = False) -> None: files, dirs, _ = self.get_tree_map(path) dirs = dirs[::-1] # reverse the order for deleting deepest directories first - with self.use_conn() as conn: + with self.get_managed_conn() as conn: for file_path in files: conn.remove(file_path) for dir_path in dirs: @@ -264,7 +264,7 @@ def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: b :param local_full_path: full path to the local file or a file-like buffer :param prefetch: controls whether prefetch is performed (default: True) """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: if isinstance(local_full_path, BytesIO): conn.getfo(remote_full_path, local_full_path, prefetch=prefetch) else: @@ -280,7 +280,7 @@ def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool :param remote_full_path: full path to the remote file :param local_full_path: full path to the local file or a file-like buffer """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: if isinstance(local_full_path, BytesIO): conn.putfo(local_full_path, remote_full_path, confirm=confirm) else: @@ -292,7 +292,7 @@ def delete_file(self, path: str) -> None: :param path: full path to the remote file """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: conn.remove(path) def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: @@ -346,7 +346,7 @@ def get_mod_time(self, path: str) -> str: :param path: full path to the remote file """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: ftp_mdtm = conn.stat(path).st_mtime return datetime.datetime.fromtimestamp(ftp_mdtm).strftime("%Y%m%d%H%M%S") # type: ignore @@ -356,7 +356,7 @@ def path_exists(self, path: str) -> bool: :param path: full path to the remote file or directory """ - with self.use_conn() as conn: + with self.get_managed_conn() as conn: try: conn.stat(path) except OSError: @@ -455,7 +455,7 @@ def append_matching_path_callback(list_: list[str]) -> Callable: def test_connection(self) -> tuple[bool, str]: """Test the SFTP connection by calling path with directory.""" try: - with self.use_conn() as conn: + with self.get_managed_conn() as conn: conn.normalize(".") return True, "Connection successfully tested" except Exception as e: diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index b2709b6a4feae..b1c0d13c3693f 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -113,8 +113,8 @@ def test_close_conn(self): self.hook.close_conn() assert self.hook.conn is None - def test_use_conn(self): - with self.hook.use_conn() as conn: + def test_get_managed_conn(self): + with self.hook.get_managed_conn() as conn: assert isinstance(conn, paramiko.SFTPClient) @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn") @@ -124,7 +124,7 @@ def test_get_close_conn(self, mock_get_conn): mock_ssh_client.open_sftp.return_value = mock_sftp_client mock_get_conn.return_value = mock_ssh_client - with SFTPHook().use_conn() as conn: + with SFTPHook().get_managed_conn() as conn: assert conn == mock_sftp_client mock_sftp_client.close.assert_called_once() @@ -151,7 +151,7 @@ def test_mkdir(self): assert new_dir_name in output # test the directory has default permissions to 777 - umask umask = 0o022 - with self.hook.use_conn() as conn: + with self.hook.get_managed_conn() as conn: output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) assert output.st_mode & 0o777 == 0o777 - umask @@ -162,7 +162,7 @@ def test_create_and_delete_directory(self): assert new_dir_name in output # test the directory has default permissions to 777 umask = 0o022 - with self.hook.use_conn() as conn: + with self.hook.get_managed_conn() as conn: output = conn.lstat(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, new_dir_name)) assert output.st_mode & 0o777 == 0o777 - umask # test directory already exists for code coverage, should not raise an exception @@ -542,7 +542,7 @@ def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client) host_proxy_cmd=host_proxy_cmd, ) - with hook.use_conn(): + with hook.get_managed_conn(): mock_proxy_command.assert_called_once_with(host_proxy_cmd) mock_ssh_client.return_value.connect.assert_called_once_with( hostname="example.com",