diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e1671cd1..eb7bcdcc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,7 @@ repos: rev: v2.4.1 hooks: - id: codespell + args: ['-L', 'fo'] additional_dependencies: ["tomli"] - repo: https://github.com/asottile/pyupgrade rev: v3.19.1 diff --git a/pyproject.toml b/pyproject.toml index ac0b69b0..671af64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,7 @@ exclude_lines = [ "raise NotImplementedError", "raise AssertionError", "@overload", + "except ImportError", ] [tool.mypy] @@ -176,10 +177,12 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = "pydantic.*" +ignore_missing_imports = true ignore_errors = true [[tool.mypy.overrides]] module = "pydantic_core.*" +ignore_missing_imports = true ignore_errors = true [[tool.mypy.overrides]] diff --git a/upath/__init__.py b/upath/__init__.py index d08612a9..9db64eee 100644 --- a/upath/__init__.py +++ b/upath/__init__.py @@ -1,10 +1,25 @@ """Pathlib API extended to use fsspec backends.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + try: from upath._version import __version__ except ImportError: __version__ = "not-installed" -from upath.core import UPath +if TYPE_CHECKING: + from upath.core import UPath __all__ = ["UPath"] + + +def __getattr__(name): + if name == "UPath": + from upath.core import UPath + + globals()["UPath"] = UPath + return UPath + else: + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/upath/_chain.py b/upath/_chain.py new file mode 100644 index 00000000..9b779196 --- /dev/null +++ b/upath/_chain.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import sys +import warnings +from collections import defaultdict +from collections.abc import MutableMapping +from collections.abc import Sequence +from collections.abc import Set +from typing import TYPE_CHECKING +from typing import Any +from typing import NamedTuple + +if TYPE_CHECKING: + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + +from upath._flavour import WrappedFileSystemFlavour +from upath._protocol import get_upath_protocol +from upath.registry import available_implementations + +__all__ = [ + "ChainSegment", + "Chain", + "FSSpecChainParser", + "DEFAULT_CHAIN_PARSER", +] + + +class ChainSegment(NamedTuple): + path: str | None # support for path passthrough (i.e. simplecache) + protocol: str + storage_options: dict[str, Any] + + +class Chain: + """holds current chain segments""" + + __slots__ = ( + "_segments", + "_index", + ) + + def __init__( + self, + *segments: ChainSegment, + index: int = 0, + ) -> None: + if not (0 <= index < len(segments)): + raise ValueError("index must be between 0 and len(segments)") + self._segments = segments + self._index = index + + def __repr__(self) -> str: + args = ", ".join(map(repr, self._segments)) + if self._index != 0: + args += f", index={self._index}" + return f"{type(self).__name__}({args})" + + @property + def current(self) -> ChainSegment: + return self._segments[self._index] + + @property + def _path_index(self) -> int: + for idx, segment in enumerate(self._segments[self._index :], start=self._index): + if segment.path is not None: + return idx + raise IndexError("No target path found") + + @property + def active_path(self) -> str: + path = self._segments[self._path_index].path + if path is None: + raise RuntimeError + return path + + @property + def active_path_protocol(self) -> str: + return self._segments[self._path_index].protocol + + def replace( + self, + *, + path: str | None = None, + protocol: str | None = None, + storage_options: dict[str, Any] | None = None, + ) -> Self: + """replace the current chain segment keeping remaining chain segments""" + segments = self.to_list() + index = self._index + + replacements: MutableMapping[int, dict[str, Any]] = defaultdict(dict) + if protocol is not None: + replacements[index]["protocol"] = protocol + if storage_options is not None: + replacements[index]["storage_options"] = storage_options + if path is not None: + replacements[self._path_index]["path"] = path + + for idx, items in replacements.items(): + segments[idx] = segments[idx]._replace(**items) + + return type(self)(*segments, index=index) + + def to_list(self) -> list[ChainSegment]: + return list(self._segments) + + @classmethod + def from_list(cls, segments: list[ChainSegment], index: int = 0) -> Self: + return cls(*segments, index=index) + + def nest(self) -> ChainSegment: + """return a nested target_* structure""" + # see: fsspec.core.url_to_fs + inkwargs: dict[str, Any] = {} + # Reverse iterate the chain, creating a nested target_* structure + chain = self._segments + _prev = chain[-1].path + for i, ch in enumerate(reversed(chain)): + urls, protocol, kw = ch + if urls is None: + urls = _prev + _prev = urls + if i == len(chain) - 1: + inkwargs = dict(**kw, **inkwargs) + continue + inkwargs["target_options"] = dict(**kw, **inkwargs) + inkwargs["target_protocol"] = protocol + inkwargs["fo"] = urls # codespell:ignore fo + urlpath, protocol, _ = chain[0] + return ChainSegment(urlpath, protocol, inkwargs) + + +class FSSpecChainParser: + """parse an fsspec chained urlpath""" + + def __init__(self) -> None: + self.link: str = "::" + self.known_protocols: Set[str] = set() + + def unchain(self, path: str, kwargs: dict[str, Any]) -> list[ChainSegment]: + """implements same behavior as fsspec.core._un_chain + + two differences: + 1. it sets the urlpath to None for upstream filesystems that passthrough + 2. it checks against the known protocols for exact matches + + """ + # TODO: upstream to fsspec + first_bit_protocol: str | None = kwargs.pop("protocol", None) + it_bits = iter(path.split(self.link)) + bits: list[str] + if first_bit_protocol is not None: + bits = [next(it_bits)] + else: + bits = [] + for p in it_bits: + if "://" in p: # uri-like, fast-path + bits.append(p) + elif "/" in p: # path-like, fast-path + bits.append(p) + elif p in self.known_protocols: # exact match a fsspec protocol + bits.append(f"{p}://") + elif p in (m := set(available_implementations(fallback=True))): + self.known_protocols = m + bits.append(f"{p}://") + else: + bits.append(p) + + # [[url, protocol, kwargs], ...] + out: list[ChainSegment] = [] + previous_bit: str | None = None + kwargs = kwargs.copy() + first_bit_idx = len(bits) - 1 + for idx, bit in enumerate(reversed(bits)): + if idx == first_bit_idx: + protocol = first_bit_protocol or get_upath_protocol(bit) or "" + else: + protocol = get_upath_protocol(bit) or "" + flavour = WrappedFileSystemFlavour.from_protocol(protocol) + extra_kwargs = flavour.get_kwargs_from_url(bit) + kws = kwargs.pop(protocol, {}) + if bit is bits[0]: + kws.update(kwargs) + kw = dict(**extra_kwargs) + kw.update(kws) + bit = flavour.strip_protocol(bit) or flavour.root_marker + if ( + protocol in {"blockcache", "filecache", "simplecache"} + and "target_protocol" not in kw + ): + out.append(ChainSegment(None, protocol, kw)) + if previous_bit is not None: + bit = previous_bit + else: + out.append(ChainSegment(bit, protocol, kw)) + previous_bit = bit + out.reverse() + return out + + def chain(self, segments: Sequence[ChainSegment]) -> tuple[str, dict[str, Any]]: + """returns a chained urlpath from the segments""" + urlpaths = [] + kwargs = {} + for segment in segments: + if segment.protocol and segment.path is not None: + # FIXME: currently unstrip_protocol is only implemented by + # AbstractFileSystem, LocalFileSystem, and OSSFileSystem + # so to make this work we just implement it ourselves here. + # To do this properly we would need to instantiate the + # filesystem with its storage options and call + # fs.unstrip_protocol(segment.path) + if segment.path.startswith(f"{segment.protocol}:/"): + urlpath = segment.path + else: + urlpath = f"{segment.protocol}://{segment.path}" + elif segment.protocol: + urlpath = segment.protocol + elif segment.path is not None: + urlpath = segment.path + else: + warnings.warn( + f"skipping invalid segment {segment}", + RuntimeWarning, + stacklevel=2, + ) + continue + urlpaths.append(urlpath) + # TODO: ensure roundtrip with unchain behavior + kwargs[segment.protocol] = segment.storage_options + return self.link.join(urlpaths), kwargs + + +DEFAULT_CHAIN_PARSER = FSSpecChainParser() + + +if __name__ == "__main__": + from pprint import pp + + from fsspec.core import _un_chain + + chained_path = "simplecache::zip://haha.csv::gcs://bucket/file.zip" + chained_kw = {"zip": {"allowZip64": False}} + print(chained_path, chained_kw) + out0 = _un_chain(chained_path, chained_kw) + out1 = FSSpecChainParser().unchain(chained_path, chained_kw) + + pp(out0) + pp(out1) + + rechained_path, rechained_kw = FSSpecChainParser().chain(out1) + print(rechained_path, rechained_kw) + + # UPath should store segments and access the path to operate on + # through segments.current.path + segments0 = Chain.from_list(segments=out1, index=1) + assert segments0.current.protocol == "zip" + + # try to switch out zip path + segments1 = segments0.replace(path="/newfile.csv") + new_path, new_kw = FSSpecChainParser().chain(segments1.to_list()) + print(new_path, new_kw) diff --git a/upath/_flavour.py b/upath/_flavour.py index b0fa0366..25d35c3a 100644 --- a/upath/_flavour.py +++ b/upath/_flavour.py @@ -403,7 +403,9 @@ def __get__( self, obj: UPath | None, objtype: type[UPath] | None = None ) -> WrappedFileSystemFlavour: if obj is not None: - return WrappedFileSystemFlavour.from_protocol(obj.protocol) + return WrappedFileSystemFlavour.from_protocol( + obj._chain.active_path_protocol + ) elif self._default_protocol: # type: ignore return WrappedFileSystemFlavour.from_protocol(self._default_protocol) else: diff --git a/upath/_protocol.py b/upath/_protocol.py index a8897ba7..db1fead1 100644 --- a/upath/_protocol.py +++ b/upath/_protocol.py @@ -2,10 +2,14 @@ import os import re +from collections import ChainMap from pathlib import PurePath from typing import TYPE_CHECKING from typing import Any +from fsspec.registry import known_implementations as _known_implementations +from fsspec.registry import registry as _registry + if TYPE_CHECKING: from upath.types import JoinablePath @@ -18,7 +22,7 @@ # Regular expression to match fsspec style protocols. # Matches single slash usage too for compatibility. _PROTOCOL_RE = re.compile( - r"^(?P[A-Za-z][A-Za-z0-9+]+):(?P//?)(?P.*)" + r"^(?P[A-Za-z][A-Za-z0-9+]+):(?:(?P//?)|:)(?P.*)" ) # Matches data URIs @@ -33,6 +37,28 @@ def _match_protocol(pth: str) -> str: return "" +_fsspec_registry_map = ChainMap(_registry, _known_implementations) + + +def _fsspec_protocol_equals(p0: str, p1: str) -> bool: + """check if two fsspec protocols are equivalent""" + p0 = p0 or "file" + p1 = p1 or "file" + if p0 == p1: + return True + + try: + o0 = _fsspec_registry_map[p0] + except KeyError: + raise ValueError(f"Protocol not known: {p0}") + try: + o1 = _fsspec_registry_map[p1] + except KeyError: + raise ValueError(f"Protocol not known: {p1}") + + return o0 == o1 + + def get_upath_protocol( pth: str | os.PathLike[str] | PurePath | JoinablePath, *, @@ -54,7 +80,11 @@ def get_upath_protocol( pth_protocol = _match_protocol(str(pth)) # if storage_options and not protocol and not pth_protocol: # protocol = "file" - if protocol and pth_protocol and not pth_protocol.startswith(protocol): + if ( + protocol + and pth_protocol + and not _fsspec_protocol_equals(pth_protocol, protocol) + ): raise ValueError( f"requested protocol {protocol!r} incompatible with {pth_protocol!r}" ) @@ -63,7 +93,7 @@ def get_upath_protocol( def normalize_empty_netloc(pth: str) -> str: if m := _PROTOCOL_RE.match(pth): - if len(m.group("slashes")) == 1: + if m.group("slashes") == "/": protocol = m.group("protocol") path = m.group("path") pth = f"{protocol}:///{path}" @@ -80,6 +110,6 @@ def compatible_protocol( # consider protocols equivalent if they match up to the first "+" other_protocol = other_protocol.partition("+")[0] # protocols: only identical (or empty "") protocols can combine - if other_protocol and other_protocol != protocol: + if other_protocol and not _fsspec_protocol_equals(other_protocol, protocol): return False return True diff --git a/upath/core.py b/upath/core.py index 4b120614..5c889ca9 100644 --- a/upath/core.py +++ b/upath/core.py @@ -23,7 +23,11 @@ from fsspec.registry import get_filesystem_class from fsspec.spec import AbstractFileSystem +from upath._chain import DEFAULT_CHAIN_PARSER +from upath._chain import Chain +from upath._chain import FSSpecChainParser from upath._flavour import LazyFlavourDescriptor +from upath._flavour import WrappedFileSystemFlavour from upath._flavour import upath_get_kwargs_from_url from upath._flavour import upath_urijoin from upath._protocol import compatible_protocol @@ -70,18 +74,6 @@ def _make_instance(cls, args, kwargs): return cls(*args, **kwargs) -def _explode_path(path, parser): - split = parser.split - path = parser.strip_protocol(path) - parent, name = parser.split(path) - names = [] - while path != parent: - names.append(name) - path = parent - parent, name = split(path) - return path, names - - def _buffering2blocksize(mode: str, buffering: int) -> int | None: if not isinstance(buffering, int): raise TypeError("buffering must be an integer") @@ -95,17 +87,26 @@ def _buffering2blocksize(mode: str, buffering: int) -> int | None: return buffering -if sys.version_info >= (3, 11): - _UPathMeta = ABCMeta - -else: - - class _UPathMeta(ABCMeta): +class _UPathMeta(ABCMeta): + if sys.version_info < (3, 11): # pathlib 3.9 and 3.10 supported `Path[str]` but # did not return a GenericAlias but the class itself? def __getitem__(cls, key): return cls + def __call__(cls, *args, **kwargs): + # create a copy if UPath class + try: + (arg0,) = args + except ValueError: + pass + else: + if isinstance(arg0, UPath) and not kwargs: + return copy(arg0) + inst = cls.__new__(cls, *args, **kwargs) + inst.__init__(*args, **kwargs) + return inst + class _UPathMixin(metaclass=_UPathMeta): __slots__ = () @@ -116,21 +117,39 @@ def parser(self) -> UPathParser: raise NotImplementedError @property - @abstractmethod def _protocol(self) -> str: - raise NotImplementedError + return self._chain.nest().protocol @_protocol.setter def _protocol(self, value: str) -> None: - raise NotImplementedError + self._chain = self._chain.replace(protocol=value) @property - @abstractmethod def _storage_options(self) -> dict[str, Any]: - raise NotImplementedError + return self._chain.nest().storage_options @_storage_options.setter def _storage_options(self, value: dict[str, Any]) -> None: + self._chain = self._chain.replace(storage_options=value) + + @property + @abstractmethod + def _chain(self) -> Chain: + raise NotImplementedError + + @_chain.setter + @abstractmethod + def _chain(self, value: Chain) -> None: + raise NotImplementedError + + @property + @abstractmethod + def _chain_parser(self) -> FSSpecChainParser: + raise NotImplementedError + + @_chain_parser.setter + @abstractmethod + def _chain_parser(self, value: FSSpecChainParser) -> None: raise NotImplementedError @property @@ -235,19 +254,13 @@ def __new__( cls, *args: JoinablePathLike, protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, **storage_options: Any, ) -> UPath: # narrow type - assert issubclass(cls, UPath), "_UPathMixin should never be instantiated" - - # fill empty arguments - if not args: - args = (".",) - - # create a copy if UPath class - part0, *parts = args - if not parts and not storage_options and isinstance(part0, cls): - return copy(part0) + assert issubclass( + cls, UPath + ), "UPath.__new__ can't instantiate non-UPath classes" # deprecate 'scheme' if "scheme" in storage_options: @@ -260,7 +273,9 @@ def __new__( # determine the protocol pth_protocol = get_upath_protocol( - part0, protocol=protocol, storage_options=storage_options + args[0] if args else "", + protocol=protocol, + storage_options=storage_options, ) # determine which UPath subclass to dispatch to if cls._protocol_dispatch or cls._protocol_dispatch is None: @@ -277,30 +292,13 @@ def __new__( # for all supported user protocols. upath_cls = cls - # create a new instance - if cls is UPath: - # we called UPath() directly, and want an instance based on the - # provided or detected protocol (i.e. upath_cls) - obj: UPath = object.__new__(upath_cls) - obj._protocol = pth_protocol - - if cls not in upath_cls.mro(): - # we are not in the upath_cls mro, so we need to - # call __init__ of the upath_cls - upath_cls.__init__(obj, *args, protocol=pth_protocol, **storage_options) - - elif issubclass(cls, upath_cls): - # we called a sub- or sub-sub-class of UPath, i.e. S3Path() and the - # corresponding upath_cls based on protocol is equal-to or a - # parent-of the cls. - obj = object.__new__(cls) - obj._protocol = pth_protocol - - elif issubclass(cls, UPath): - # we called a subclass of UPath directly, i.e. S3Path() but the - # detected protocol would return a non-related UPath subclass, i.e. - # S3Path("file:///abc"). This behavior is going to raise an error - # in future versions + if issubclass(upath_cls, cls): + pass + + elif not issubclass(upath_cls, UPath): + raise RuntimeError("UPath.__new__ expected cls to be subclass of UPath") + + else: msg_protocol = repr(pth_protocol) if not pth_protocol: msg_protocol += " (empty string)" @@ -315,65 +313,74 @@ def __new__( f" registering the {cls.__name__} implementation with protocol" f" {msg_protocol!s} replacing the default implementation." ) - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - obj = object.__new__(upath_cls) - obj._protocol = pth_protocol - - upath_cls.__init__( - obj, *args, protocol=pth_protocol, **storage_options - ) # type: ignore - - else: - raise RuntimeError("UPath.__new__ expected cls to be subclass of UPath") + warnings.warn( + msg, + DeprecationWarning, + stacklevel=2, + ) + upath_cls = cls - return obj + return object.__new__(upath_cls) def __init__( self, *args: JoinablePathLike, protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, **storage_options: Any, ) -> None: - # allow subclasses to customize __init__ arg parsing - base_options = getattr(self, "_storage_options", {}) + + # todo: avoid duplicating this call from __new__ + protocol = get_upath_protocol( + args[0] if args else "", + protocol=protocol, + storage_options=storage_options, + ) args, protocol, storage_options = type(self)._transform_init_args( - args, protocol or self._protocol, {**base_options, **storage_options} + args, protocol, storage_options ) - if self._protocol != protocol and protocol: - self._protocol = protocol - # retrieve storage_options + # check that UPath subclasses in args are compatible + # TODO: + # Future versions of UPath could verify that storage_options + # can be combined between UPath instances. Not sure if this + # is really necessary though. A warning might be enough... + if not compatible_protocol(protocol, *args): + raise ValueError("can't combine incompatible UPath protocols") + if args: args0 = args[0] if isinstance(args0, UPath): - self._storage_options = {**args0.storage_options, **storage_options} + storage_options = { + **args0._chain.nest().storage_options, + **storage_options, + } + str_args0 = str(args0) + else: - if hasattr(args0, "__fspath__"): - _args0 = args0.__fspath__() + if hasattr(args0, "__fspath__") and args0.__fspath__ is not None: + str_args0 = args0.__fspath__() else: - _args0 = str(args0) - self._storage_options = type(self)._parse_storage_options( - _args0, protocol, storage_options + str_args0 = str(args0) + storage_options = type(self)._parse_storage_options( + str_args0, protocol, storage_options + ) + if len(args) > 1: + str_args0 = WrappedFileSystemFlavour.from_protocol(protocol).join( + str_args0, *args[1:] ) else: - self._storage_options = storage_options.copy() + str_args0 = "." - # check that UPath subclasses in args are compatible - # TODO: - # Future versions of UPath could verify that storage_options - # can be combined between UPath instances. Not sure if this - # is really necessary though. A warning might be enough... - if not compatible_protocol(self._protocol, *args): - raise ValueError("can't combine incompatible UPath protocols") - - if hasattr(self, "_raw_urlpaths"): - return + segments = chain_parser.unchain( + str_args0, {"protocol": protocol, **storage_options} + ) + self._chain = Chain.from_list(segments) + self._chain_parser = chain_parser self._raw_urlpaths = args # --- deprecated attributes --------------------------------------- - # deprecation @property def _url(self) -> SplitResult: # TODO: @@ -384,15 +391,15 @@ def _url(self) -> SplitResult: class UPath(_UPathMixin, OpenablePath): __slots__ = ( - "_protocol", - "_storage_options", + "_chain", + "_chain_parser", "_fs_cached", "_raw_urlpaths", ) if TYPE_CHECKING: - _protocol: str - _storage_options: dict[str, Any] + _chain: Chain + _chain_parser: FSSpecChainParser _fs_cached: bool _raw_urlpaths: Sequence[JoinablePathLike] @@ -408,15 +415,7 @@ def with_segments(self, *pathsegments: JoinablePathLike) -> Self: ) def __str__(self) -> str: - path = self.parser.join(*self._raw_urlpaths) - if self._protocol: - if path.startswith(f"{self._protocol}://"): - return path - elif path.startswith(f"{self._protocol}:/"): - return path.replace(":/", "://", 1) - else: - return f"{self._protocol}://{path}" - return path + return self._chain_parser.chain(self._chain.to_list())[0] def __repr__(self) -> str: return f"{type(self).__name__}({self.path!r}, protocol={self._protocol!r})" @@ -425,9 +424,29 @@ def __repr__(self) -> str: @property def parts(self) -> Sequence[str]: - anchor, parts = _explode_path(str(self), self.parser) - if anchor: - parts.append(anchor) + split = self.parser.split + sep = self.parser.sep + + path = self._chain.active_path + drive = self.parser.splitdrive(self._chain.active_path)[0] + stripped_path = self.parser.strip_protocol(path) + if stripped_path: + _, _, tail = path.partition(stripped_path) + path = stripped_path + tail + + parent, name = split(path) + names = [] + while path != parent: + names.append(name) + path = parent + parent, name = split(path) + + if names and names[-1] == drive: + names = names[:-1] + if names and names[-1].startswith(sep): + parts = [*names[:-1], names[-1].removeprefix(sep), drive + sep] + else: + parts = [*names, drive + sep] return tuple(reversed(parts)) def with_name(self, name) -> Self: @@ -439,6 +458,10 @@ def with_name(self, name) -> Self: path = path.removesuffix(split(path)[1]) + name return self.with_segments(path) + @property + def anchor(self) -> str: + return self.drive + self.root + # === ReadablePath attributes ===================================== @property @@ -865,7 +888,7 @@ def replace(self, target: WritablePathLike) -> Self: @property def drive(self) -> str: - return self.parser.splitdrive(str(self))[0] + return self.parser.splitroot(str(self))[0] @property def root(self) -> str: diff --git a/upath/extensions.py b/upath/extensions.py index bdd8c627..3c9cec06 100644 --- a/upath/extensions.py +++ b/upath/extensions.py @@ -16,6 +16,8 @@ from fsspec import AbstractFileSystem +from upath._chain import Chain +from upath._chain import ChainSegment from upath._stat import UPathStatResult from upath.core import UPath from upath.types import UNSET_DEFAULT @@ -54,13 +56,7 @@ class ProxyUPath: # _fs_factory # _protocol_dispatch - def __init__( - self, - *args: JoinablePathLike, - protocol: str | None = None, - **storage_options: Any, - ) -> None: - self.__wrapped__ = UPath(*args, protocol=protocol, **storage_options) + # === non-public methods / attributes ============================= @classmethod def _from_upath(cls, upath: UPath, /) -> Self: @@ -71,6 +67,31 @@ def _from_upath(cls, upath: UPath, /) -> Self: obj.__wrapped__ = upath return obj + @property + def _chain(self): + try: + return self.__wrapped__._chain + except AttributeError: + return Chain( + ChainSegment( + path=self.__wrapped__.path, + protocol=self.__wrapped__.protocol, + storage_options=dict(self.__wrapped__.storage_options), + ), + [], + [], + ) + + # === wrapped interface =========================================== + + def __init__( + self, + *args: JoinablePathLike, + protocol: str | None = None, + **storage_options: Any, + ) -> None: + self.__wrapped__ = UPath(*args, protocol=protocol, **storage_options) + @property def parser(self) -> UPathParser: return self.__wrapped__.parser diff --git a/upath/implementations/cached.py b/upath/implementations/cached.py new file mode 100644 index 00000000..2f4071f3 --- /dev/null +++ b/upath/implementations/cached.py @@ -0,0 +1,5 @@ +from upath.core import UPath + + +class SimpleCachePath(UPath): + pass diff --git a/upath/implementations/cloud.py b/upath/implementations/cloud.py index 865e3e95..fabb7383 100644 --- a/upath/implementations/cloud.py +++ b/upath/implementations/cloud.py @@ -44,6 +44,10 @@ def _transform_init_args( break return super()._transform_init_args(args, protocol, storage_options) + @property + def root(self) -> str: + return self.parser.sep + def mkdir( self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False ) -> None: diff --git a/upath/implementations/github.py b/upath/implementations/github.py index 741dfa12..1d37724d 100644 --- a/upath/implementations/github.py +++ b/upath/implementations/github.py @@ -2,6 +2,8 @@ GitHub file system implementation """ +from collections.abc import Sequence + import upath.core @@ -21,3 +23,11 @@ def iterdir(self): if self.is_file(): raise NotADirectoryError(str(self)) yield from super().iterdir() + + @property + def parts(self) -> Sequence[str]: + parts = super().parts + if parts and parts[0] == "/": + return parts[1:] + else: + return parts diff --git a/upath/implementations/http.py b/upath/implementations/http.py index 9b49cd12..4e4f4dbd 100644 --- a/upath/implementations/http.py +++ b/upath/implementations/http.py @@ -3,7 +3,6 @@ import sys import warnings from collections.abc import Iterator -from collections.abc import Sequence from itertools import chain from typing import TYPE_CHECKING from typing import Any @@ -38,11 +37,6 @@ def _transform_init_args( args = (f"{protocol}://{str(args[0]).lstrip('/')}", *args[1:]) return args, protocol, storage_options - @property - def parts(self) -> Sequence[str]: - _parts = super().parts - return f"{_parts[0]}/", *_parts[1:] - def __str__(self) -> str: sr = urlsplit(super().__str__()) return sr._replace(path=sr.path or "/").geturl() diff --git a/upath/implementations/local.py b/upath/implementations/local.py index c2065228..8d71c7be 100644 --- a/upath/implementations/local.py +++ b/upath/implementations/local.py @@ -12,6 +12,10 @@ from fsspec import AbstractFileSystem +from upath._chain import DEFAULT_CHAIN_PARSER +from upath._chain import Chain +from upath._chain import ChainSegment +from upath._chain import FSSpecChainParser from upath._protocol import compatible_protocol from upath.core import UPath from upath.core import _UPathMixin @@ -64,13 +68,13 @@ def _warn_protocol_storage_options( class LocalPath(_UPathMixin, pathlib.Path): __slots__ = ( - "_protocol", - "_storage_options", + "_chain", + "_chain_parser", "_fs_cached", ) if TYPE_CHECKING: - _protocol: str - _storage_options: dict[str, Any] + _chain: Chain + _chain_parser: FSSpecChainParser _fs_cached: AbstractFileSystem parser = os.path # type: ignore[misc,assignment] @@ -86,49 +90,60 @@ def _raw_urlpaths(self, value: Sequence[JoinablePathLike]) -> None: if sys.version_info >= (3, 12): def __init__( - self, *args, protocol: str | None = None, **storage_options: Any + self, + *args, + protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, + **storage_options: Any, ) -> None: super(_UPathMixin, self).__init__(*args) - self._protocol = protocol or "" - self._storage_options = storage_options + self._chain = Chain(ChainSegment(str(self), "", storage_options)) + self._chain_parser = chain_parser elif sys.version_info >= (3, 10): def __init__( - self, *args, protocol: str | None = None, **storage_options: Any + self, + *args, + protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, + **storage_options: Any, ) -> None: + # super(_UPathMixin, self).__init__(*args) _warn_protocol_storage_options(type(self), protocol, storage_options) self._drv, self._root, self._parts = self._parse_args(args) # type: ignore[attr-defined] # noqa: E501 - self._protocol = "" - self._storage_options = {} + self._chain = Chain(ChainSegment(str(self), "", {})) + self._chain_parser = chain_parser @classmethod def _from_parts(cls, args): obj = super()._from_parts(args) - obj._protocol = "" - obj._storage_options = {} + obj._chain = Chain(ChainSegment(str(obj), "", {})) return obj @classmethod def _from_parsed_parts(cls, drv, root, parts): obj = super()._from_parsed_parts(drv, root, parts) - obj._protocol = "" - obj._storage_options = {} + obj._chain = Chain(ChainSegment(str(obj), "", {})) return obj else: def __init__( - self, *args, protocol: str | None = None, **storage_options: Any + self, + *args, + protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, + **storage_options: Any, ) -> None: _warn_protocol_storage_options(type(self), protocol, storage_options) self._drv, self._root, self._parts = self._parse_args(args) # type: ignore[attr-defined] # noqa: E501 self._init() + self._chain_parser = chain_parser def _init(self, **kwargs: Any) -> None: super()._init(**kwargs) # type: ignore[misc] - self._protocol = "" - self._storage_options = {} + self._chain = Chain(ChainSegment(str(self), "", {})) def with_segments(self, *pathsegments: str | os.PathLike[str]) -> Self: return type(self)( @@ -170,7 +185,11 @@ class WindowsUPath(LocalPath, pathlib.WindowsPath): if os.name != "nt": def __new__( - cls, *args, protocol: str | None = None, **storage_options: Any + cls, + *args, + protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, + **storage_options: Any, ) -> WindowsUPath: raise NotImplementedError( f"cannot instantiate {cls.__name__} on your system" @@ -183,7 +202,11 @@ class PosixUPath(LocalPath, pathlib.PosixPath): if os.name == "nt": def __new__( - cls, *args, protocol: str | None = None, **storage_options: Any + cls, + *args, + protocol: str | None = None, + chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER, + **storage_options: Any, ) -> PosixUPath: raise NotImplementedError( f"cannot instantiate {cls.__name__} on your system" diff --git a/upath/implementations/webdav.py b/upath/implementations/webdav.py index 48552651..9ec47b8f 100644 --- a/upath/implementations/webdav.py +++ b/upath/implementations/webdav.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Mapping +from collections.abc import Sequence from typing import Any from urllib.parse import urlsplit @@ -60,3 +61,11 @@ def _parse_storage_options( urlpath = url._replace(scheme="", netloc="").geturl() or "/" so.setdefault("base_url", base) return super()._parse_storage_options(urlpath, "webdav", so) + + @property + def parts(self) -> Sequence[str]: + parts = super().parts + if parts and parts[0] == "/": + return parts[1:] + else: + return parts diff --git a/upath/registry.py b/upath/registry.py index fa86d4ab..4ef5a7f3 100644 --- a/upath/registry.py +++ b/upath/registry.py @@ -77,6 +77,7 @@ class _Registry(MutableMapping[str, "type[upath.UPath]"]): "memory": "upath.implementations.memory.MemoryPath", "s3": "upath.implementations.cloud.S3Path", "s3a": "upath.implementations.cloud.S3Path", + "simplecache": "upath.implementations.cached.SimpleCachePath", "sftp": "upath.implementations.sftp.SFTPPath", "ssh": "upath.implementations.sftp.SFTPPath", "webdav": "upath.implementations.webdav.WebdavPath", diff --git a/upath/tests/conftest.py b/upath/tests/conftest.py index 98549961..67d307f3 100644 --- a/upath/tests/conftest.py +++ b/upath/tests/conftest.py @@ -34,6 +34,9 @@ def _strip_protocol(cls, path): path = path[5:] return make_path_posix(path).rstrip("/") or cls.root_marker + def unstrip_protocol(self, path): + return f"mock://{self._strip_protocol(path)}" + @pytest.fixture(scope="session") def clear_registry(): diff --git a/upath/tests/implementations/test_s3.py b/upath/tests/implementations/test_s3.py index 4ca44aa6..c18f089a 100644 --- a/upath/tests/implementations/test_s3.py +++ b/upath/tests/implementations/test_s3.py @@ -78,7 +78,7 @@ def test_no_bucket_joinpath(self, joiner): def test_creating_s3path_with_bucket(self): path = UPath("s3://", bucket="bucket", anon=self.anon, **self.s3so) - assert str(path) == "s3://bucket/" + assert str(path) == "s3://bucket" def test_iterdir_with_plus_in_name(self, s3_with_plus_chr_name): bucket, anon, s3so = s3_with_plus_chr_name diff --git a/upath/tests/test_chain.py b/upath/tests/test_chain.py new file mode 100644 index 00000000..dff872f8 --- /dev/null +++ b/upath/tests/test_chain.py @@ -0,0 +1,89 @@ +from pathlib import Path + +import pytest +from fsspec.implementations.memory import MemoryFileSystem + +from upath import UPath + + +@pytest.mark.parametrize( + "urlpath,expected", + [ + ("simplecache::file://tmp", "simplecache"), + ("zip://file.txt::file://tmp.zip", "zip"), + ], +) +def test_chaining_upath_protocol(urlpath, expected): + pth = UPath(urlpath) + assert pth.protocol == expected + + +@pytest.mark.parametrize( + "urlpath,expected", + [ + ( + "simplecache::file:///tmp", + { + "target_protocol": "file", + "fo": Path("/tmp").absolute().as_posix(), + "target_options": {}, + }, + ), + ], +) +def test_chaining_upath_storage_options(urlpath, expected): + pth = UPath(urlpath) + assert dict(pth.storage_options) == expected + + +@pytest.mark.parametrize( + "urlpath,expected", + [ + ("simplecache::memory://tmp", ("/", "tmp")), + ], +) +def test_chaining_upath_parts(urlpath, expected): + pth = UPath(urlpath) + assert pth.parts == expected + + +@pytest.mark.parametrize( + "urlpath,expected", + [ + ("simplecache::memory:///tmp", "simplecache::memory:///tmp"), + ], +) +def test_chaining_upath_str(urlpath, expected): + pth = UPath(urlpath) + assert str(pth) == expected + + +@pytest.fixture +def clear_memory_fs(): + fs = MemoryFileSystem() + store = fs.store + pseudo_dirs = fs.pseudo_dirs + try: + yield fs + finally: + fs.store.clear() + fs.store.update(store) + fs.pseudo_dirs[:] = pseudo_dirs + + +@pytest.fixture +def memory_file_urlpath(clear_memory_fs): + fs = clear_memory_fs + fs.pipe_file("/abc/file.txt", b"hello world") + yield fs.unstrip_protocol("/abc/file.txt") + + +def test_read_file(memory_file_urlpath): + pth = UPath(f"simplecache::{memory_file_urlpath}") + assert pth.read_bytes() == b"hello world" + + +def test_write_file(clear_memory_fs): + pth = UPath("simplecache::memory://abc.txt") + pth.write_bytes(b"hello world") + assert clear_memory_fs.cat_file("abc.txt") == b"hello world" diff --git a/upath/tests/test_core.py b/upath/tests/test_core.py index 8bc11dd5..02183480 100644 --- a/upath/tests/test_core.py +++ b/upath/tests/test_core.py @@ -5,6 +5,7 @@ import warnings from urllib.parse import SplitResult +import pathlib_abc import pytest from upath import UPath @@ -108,9 +109,11 @@ class MyPath(UPath): DeprecationWarning, match=r"MyPath\(...\) detected protocol '' .*" ): path = MyPath(local_testdir) - assert str(path) == str(pathlib.Path(local_testdir)) + assert str(path) == pathlib.Path(local_testdir).as_posix() assert issubclass(MyPath, UPath) - assert isinstance(path, pathlib.Path) + assert isinstance(path, pathlib_abc.ReadablePath) + assert isinstance(path, pathlib_abc.WritablePath) + assert not isinstance(path, pathlib.Path) def test_subclass_with_gcs(): diff --git a/upath/tests/test_drive_root_anchor_parts.py b/upath/tests/test_drive_root_anchor_parts.py new file mode 100644 index 00000000..71419636 --- /dev/null +++ b/upath/tests/test_drive_root_anchor_parts.py @@ -0,0 +1,70 @@ +from pathlib import Path + +import pytest + +from upath import UPath + +DRIVE_ROOT_ANCHOR_TESTS = [ + # cloud + ("s3://bucket", "bucket", "/", "bucket/", ("bucket/",)), + ("s3://bucket/a", "bucket", "/", "bucket/", ("bucket/", "a")), + ("gs://bucket", "bucket", "/", "bucket/", ("bucket/",)), + ("gs://bucket/a", "bucket", "/", "bucket/", ("bucket/", "a")), + ("az://bucket", "bucket", "/", "bucket/", ("bucket/",)), + ("az://bucket/a", "bucket", "/", "bucket/", ("bucket/", "a")), + # data + ( + "data:text/plain,A%20brief%20note", + "", + "", + "", + ("data:text/plain,A%20brief%20note",), + ), + # github + ("github://user:token@repo/abc", "", "", "", ("abc",)), + # hdfs + ("hdfs://a/b/c", "", "/", "/", ("/", "b", "c")), + ("hdfs:///a/b/c", "", "/", "/", ("/", "a", "b", "c")), + # http + ("http://a/", "http://a", "/", "http://a/", ("http://a/", "")), + ("http://a/b/c", "http://a", "/", "http://a/", ("http://a/", "b", "c")), + ("https://a/b/c", "https://a", "/", "https://a/", ("https://a/", "b", "c")), + # memory + ("memory://a/b/c", "", "/", "/", ("/", "a", "b", "c")), + ("memory:///a/b/c", "", "/", "/", ("/", "a", "b", "c")), + # sftp + ("sftp://a/b/c", "", "/", "/", ("/", "b", "c")), + ("sftp:///a/b/c", "", "/", "/", ("/", "a", "b", "c")), + # smb + ("smb://a/b/c", "", "/", "/", ("/", "b", "c")), + ("smb:///a/b/c", "", "/", "/", ("/", "a", "b", "c")), + # webdav + ("webdav+http://host.com/a/b/c", "", "", "", ("a", "b", "c")), + ("webdav+http://host.com/a/b/c", "", "", "", ("a", "b", "c")), + # local + ( + "file:///a/b/c", + Path("/a/b/c").absolute().drive, + Path("/").absolute().root.replace("\\", "/"), + Path("/").absolute().anchor.replace("\\", "/"), + tuple(x.replace("\\", "/") for x in Path("/a/b/c").absolute().parts), + ), +] + + +@pytest.mark.parametrize( + "path,drive,root,anchor", + [x[0:4] for x in DRIVE_ROOT_ANCHOR_TESTS], +) +def test_drive_root_anchor(path, drive, root, anchor): + p = UPath(path) + assert (p.drive, p.root, p.anchor) == (drive, root, anchor) + + +@pytest.mark.parametrize( + "path,parts", + [(x[0], x[4]) for x in DRIVE_ROOT_ANCHOR_TESTS], +) +def test_parts(path, parts): + p = UPath(path) + assert p.parts == parts diff --git a/upath/tests/test_registry.py b/upath/tests/test_registry.py index 294b1323..3d6e1a38 100644 --- a/upath/tests/test_registry.py +++ b/upath/tests/test_registry.py @@ -29,6 +29,7 @@ "memory", "s3", "s3a", + "simplecache", "sftp", "smb", "ssh",