diff --git a/setup.cfg b/setup.cfg index bf31095..f3a3cfe 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,6 +50,7 @@ dev = # pyarrow pydantic pydantic-settings + smbprotocol [options.package_data] upath = diff --git a/upath/_flavour.py b/upath/_flavour.py index 6bbabf0..5489960 100644 --- a/upath/_flavour.py +++ b/upath/_flavour.py @@ -108,6 +108,7 @@ class WrappedFileSystemFlavour: # (pathlib_abc.FlavourBase) "https", "s3", "s3a", + "smb", "gs", "gcs", "az", diff --git a/upath/implementations/smb.py b/upath/implementations/smb.py new file mode 100644 index 0000000..c072165 --- /dev/null +++ b/upath/implementations/smb.py @@ -0,0 +1,52 @@ +import warnings + +import smbprotocol.exceptions + +from upath import UPath + + +class SMBPath(UPath): + __slots__ = () + + def mkdir(self, mode=0o777, parents=False, exist_ok=False): + # smbclient does not support setting mode externally + if parents and not exist_ok and self.exists(): + raise FileExistsError(str(self)) + try: + self.fs.mkdir( + self.path, + create_parents=parents, + ) + except smbprotocol.exceptions.SMBOSError: + if not exist_ok: + raise FileExistsError(str(self)) + if not self.is_dir(): + raise FileExistsError(str(self)) + + def iterdir(self): + if not self.is_dir(): + raise NotADirectoryError(str(self)) + else: + return super().iterdir() + + def rename(self, target, **kwargs): + if "recursive" in kwargs: + warnings.warn( + "SMBPath.rename(): recursive is currently ignored.", + UserWarning, + stacklevel=2, + ) + if "maxdepth" in kwargs: + warnings.warn( + "SMBPath.rename(): maxdepth is currently ignored.", + UserWarning, + stacklevel=2, + ) + if not isinstance(target, UPath): + target = self.parent.joinpath(target).resolve() + self.fs.mv( + self.path, + target.path, + **kwargs, + ) + return target diff --git a/upath/registry.py b/upath/registry.py index 7a54b7f..c886e39 100644 --- a/upath/registry.py +++ b/upath/registry.py @@ -80,6 +80,7 @@ class _Registry(MutableMapping[str, "type[upath.UPath]"]): "webdav+http": "upath.implementations.webdav.WebdavPath", "webdav+https": "upath.implementations.webdav.WebdavPath", "github": "upath.implementations.github.GitHubPath", + "smb": "upath.implementations.smb.SMBPath", } if TYPE_CHECKING: diff --git a/upath/tests/conftest.py b/upath/tests/conftest.py index a2f85b0..976623e 100644 --- a/upath/tests/conftest.py +++ b/upath/tests/conftest.py @@ -12,6 +12,7 @@ import pytest from fsspec.implementations.local import LocalFileSystem from fsspec.implementations.local import make_path_posix +from fsspec.implementations.smb import SMBFileSystem from fsspec.registry import _registry from fsspec.registry import register_implementation from fsspec.utils import stringify_path @@ -409,3 +410,57 @@ def azure_fixture(azurite_credentials, azure_container): finally: for blob in client.list_blobs(): client.delete_blob(blob["name"]) + + +@pytest.fixture(scope="module") +def smb_container(): + try: + pchk = ["docker", "run", "--name", "fsspec_test_smb", "hello-world"] + subprocess.check_call(pchk) + stop_docker("fsspec_test_smb") + except (subprocess.CalledProcessError, FileNotFoundError): + pytest.skip("docker run not available") + + # requires docker + container = "fsspec_smb" + stop_docker(container) + cfg = "-p -u 'testuser;testpass' -s 'home;/share;no;no;no;testuser'" + port = 445 + img = f"docker run --name {container} --detach -p 139:139 -p {port}:445 dperson/samba" # noqa: E231 E501 + cmd = f"{img} {cfg}" + try: + subprocess.check_output(shlex.split(cmd)).strip().decode() + time.sleep(2) + yield { + "host": "localhost", + "port": port, + "username": "testuser", + "password": "testpass", + "register_session_retries": 100, # max ~= 10 seconds + } + finally: + import smbclient # pylint: disable=import-outside-toplevel + + smbclient.reset_connection_cache() + stop_docker(container) + + +@pytest.fixture +def smb_url(smb_container): + smb_url = "smb://{username}:{password}@{host}/home/" + smb_url = smb_url.format(**smb_container) + return smb_url + + +@pytest.fixture +def smb_fixture(local_testdir, smb_url, smb_container): + smb = SMBFileSystem( + host=smb_container["host"], + port=smb_container["port"], + username=smb_container["username"], + password=smb_container["password"], + ) + url = smb_url + "testdir/" + smb.put(local_testdir, "/home/testdir", recursive=True) + yield url + smb.delete("/home/testdir", recursive=True) diff --git a/upath/tests/implementations/test_smb.py b/upath/tests/implementations/test_smb.py new file mode 100644 index 0000000..f404613 --- /dev/null +++ b/upath/tests/implementations/test_smb.py @@ -0,0 +1,38 @@ +import pytest +from fsspec import __version__ as fsspec_version +from packaging.version import Version + +from upath import UPath +from upath.tests.cases import BaseTests +from upath.tests.utils import skip_on_windows + + +@skip_on_windows +class TestUPathSMB(BaseTests): + + @pytest.fixture(autouse=True) + def path(self, smb_fixture): + self.path = UPath(smb_fixture) + + @pytest.mark.parametrize( + "pattern", + ( + "*.txt", + pytest.param( + "*", + marks=pytest.mark.xfail( + reason="SMBFileSystem.info appends '/' to dirs" + ), + ), + pytest.param( + "**/*.txt", + marks=( + pytest.mark.xfail(reason="requires fsspec>=2023.9.0") + if Version(fsspec_version) < Version("2023.9.0") + else () + ), + ), + ), + ) + def test_glob(self, pathlib_base, pattern): + super().test_glob(pathlib_base, pattern) diff --git a/upath/tests/test_registry.py b/upath/tests/test_registry.py index 1c54357..e7fa162 100644 --- a/upath/tests/test_registry.py +++ b/upath/tests/test_registry.py @@ -22,6 +22,7 @@ "memory", "s3", "s3a", + "smb", "webdav", "webdav+http", "webdav+https",