From e72d6421b14ce02ab2610838f2e34bd708ec7804 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Mon, 2 Dec 2024 18:12:38 -0800 Subject: [PATCH] Revert part of remote_copy change to allow pathnames This commit backs out some of the prior change to _SFTPFileCopier so that it still handles file opens prior to deciding whether to use remote_copy() or not. This seems to make a difference on Windows in particular to properly handle errors in a copy. Passing in either an open file or pathname is still supported. This change only changes where the file open happens in the case where _SFTPFileCopier is used. --- asyncssh/sftp.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 6d29126..11b39e8 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -801,22 +801,6 @@ async def run_task(self, offset: int, size: int) -> Tuple[int, int]: async def run(self) -> None: """Perform parallel file copy""" - if self._srcfs == self._dstfs and \ - isinstance(self._srcfs, SFTPClient): - try: - await self._srcfs.remote_copy(self._srcpath, self._dstpath) - except SFTPOpUnsupported: - pass - else: - self._bytes_copied = self._total_bytes - - if self._progress_handler: - self._progress_handler(self._srcpath, self._dstpath, - self._bytes_copied, - self._total_bytes) - - return - try: self._src = await self._srcfs.open(self._srcpath, 'rb', block_size=0) @@ -826,6 +810,24 @@ async def run(self) -> None: if self._progress_handler and self._total_bytes == 0: self._progress_handler(self._srcpath, self._dstpath, 0, 0) + if self._srcfs == self._dstfs and \ + isinstance(self._srcfs, SFTPClient): + try: + await self._srcfs.remote_copy( + cast(SFTPClientFile, self._src), + cast(SFTPClientFile, self._dst)) + except SFTPOpUnsupported: + pass + else: + self._bytes_copied = self._total_bytes + + if self._progress_handler: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) + + return + async for _, datalen in self.iter(): if datalen: self._bytes_copied += datalen