diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index 1301fd9102cfe..367a8f903973b 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -113,8 +113,9 @@ def __init__( @contextmanager def get_conn(self) -> Generator[SFTPClient, None, None]: """Context manager that closes the connection after use.""" - with closing(super().get_conn().open_sftp()) as conn: - yield conn + with closing(super().get_conn()) as conn: + with closing(conn.open_sftp()) as sftp: + yield sftp def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]: """ @@ -204,18 +205,18 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: :param path: full path to the remote directory to create :param mode: int permissions of octal mode for directory """ - with self.get_conn() as conn: - if self.isdir(path): - self.log.info("%s already exists", path) - return - elif self.isfile(path): - raise AirflowException(f"{path} already exists and is a file") - else: - dirname, basename = os.path.split(path) - if dirname and not self.isdir(dirname): - self.create_directory(dirname, mode) - if basename: - self.log.info("Creating %s", path) + if self.isdir(path): + self.log.info("%s already exists", path) + return + elif self.isfile(path): + raise AirflowException(f"{path} already exists and is a file") + else: + dirname, basename = os.path.split(path) + if dirname and not self.isdir(dirname): + self.create_directory(dirname, mode) + if basename: + self.log.info("Creating %s", path) + with self.get_conn() as conn: conn.mkdir(path, mode=mode) def delete_directory(self, path: str, include_files: bool = False) -> None: @@ -224,14 +225,18 @@ def delete_directory(self, path: str, include_files: bool = False) -> None: :param path: full path to the remote directory to delete """ + files: list[str] = [] + dirs: list[str] = [] + + if include_files is True: + files, dirs, _ = self.get_tree_map(path) + dirs = dirs[::-1] # reverse the order for deleting deepest directories first + with self.get_conn() as conn: - if include_files is True: - files, dirs, _ = self.get_tree_map(path) - dirs = dirs[::-1] # reverse the order for deleting deepest directories first - for file_path in files: - conn.remove(file_path) - for dir_path in dirs: - conn.rmdir(dir_path) + for file_path in files: + conn.remove(file_path) + for dir_path in dirs: + conn.rmdir(dir_path) conn.rmdir(path) def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None: diff --git a/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py b/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py index a35ee4010712d..afed97ba574d5 100644 --- a/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/provider_tests/sftp/hooks/test_sftp.py @@ -117,6 +117,7 @@ def test_get_close_conn(self, mock_get_conn): assert conn == mock_sftp_client mock_sftp_client.close.assert_called_once() + mock_ssh_client.close.assert_called_once() def test_describe_directory(self): output = self.hook.describe_directory(self.temp_dir)