diff --git a/src/smbclient/shutil.py b/src/smbclient/shutil.py index 13a89e48..edd18659 100644 --- a/src/smbclient/shutil.py +++ b/src/smbclient/shutil.py @@ -281,7 +281,7 @@ def copytree( source path and the destination path as arguments. By default copy() is used, but any function that supports the same signature (like copy()) can be used. - In this current form, copytree() only supports remote to remote copies over SMB. + In this current form, copytree() only supports remote to remote copies over SMB, or remote to local copies. :param src: The source directory to copy. :param dst: The destination directory to copy to. @@ -296,7 +296,11 @@ def copytree( :return: The dst path. """ dir_entries = list(scandir(src, **kwargs)) - makedirs(dst, exist_ok=dirs_exist_ok, **kwargs) + + if is_remote_path(dst): + makedirs(dst, exist_ok=dirs_exist_ok, **kwargs) + else: + os.makedirs(dst, exist_ok=dirs_exist_ok) ignored = [] if ignore is not None: diff --git a/tests/test_smbclient_shutil.py b/tests/test_smbclient_shutil.py index 4f48ea42..be23d8ad 100644 --- a/tests/test_smbclient_shutil.py +++ b/tests/test_smbclient_shutil.py @@ -1149,6 +1149,33 @@ def ignore(name, children): assert fd.read() == "file3.txt" +def test_copytree_with_local_dst(smb_share, tmp_path): + src_dirname = "%s\\source" % smb_share + dst_dirname = str(tmp_path / "target") + + makedirs("%s\\dir1\\subdir1" % src_dirname) + with open_file("%s\\file1.txt" % src_dirname, mode="w") as fd: + fd.write("file1.txt") + with open_file("%s\\dir1\\file2.txt" % src_dirname, mode="w") as fd: + fd.write("file2.txt") + with open_file("%s\\dir1\\subdir1\\file3.txt" % src_dirname, mode="w") as fd: + fd.write("file3.txt") + + actual = copytree(src_dirname, dst_dirname) + assert actual == dst_dirname + + assert sorted(list(os.listdir(dst_dirname))) == ["dir1", "file1.txt"] + assert sorted(list(os.listdir(os.path.join(dst_dirname, "dir1")))) == ["file2.txt", "subdir1"] + assert sorted(list(os.listdir(os.path.join(dst_dirname, "dir1", "subdir1")))) == ["file3.txt"] + + with open(os.path.join(dst_dirname, "file1.txt")) as fd: + assert fd.read() == "file1.txt" + with open(os.path.join(dst_dirname, "dir1", "file2.txt")) as fd: + assert fd.read() == "file2.txt" + with open(os.path.join(dst_dirname, "dir1", "subdir1", "file3.txt")) as fd: + assert fd.read() == "file3.txt" + + @pytest.mark.skipif( os.name != "nt" and not os.environ.get("SMB_FORCE", False), reason="Samba does not update timestamps" )