From 1a117b3a9a1be96836991c954522f3f6be45bb60 Mon Sep 17 00:00:00 2001 From: Andreas Poehlmann Date: Sun, 18 Feb 2024 18:31:08 +0100 Subject: [PATCH] Implement UPath.joinuri (#189) * tests: add tests for query passthrough and joinuri * upath._flavour: add upath_urijoin * upath: add UPath.joinuri method * upath: UPath().name returns last non-empty part --- upath/_flavour.py | 62 ++++++++++++++++++++++++ upath/core.py | 24 +++++++++ upath/tests/implementations/test_http.py | 42 ++++++++++++++++ 3 files changed, 128 insertions(+) diff --git a/upath/_flavour.py b/upath/_flavour.py index 3b64e0f..aba592e 100644 --- a/upath/_flavour.py +++ b/upath/_flavour.py @@ -27,6 +27,7 @@ __all__ = [ "FSSpecFlavour", + "upath_urijoin", ] @@ -299,3 +300,64 @@ def splitroot(p): return splitroot else: raise NotImplementedError(f"unsupported module: {mod!r}") + + +def upath_urijoin(base: str, uri: str) -> str: + """Join a base URI and a possibly relative URI to form an absolute + interpretation of the latter.""" + # see: + # https://github.com/python/cpython/blob/ae6c01d9d2/Lib/urllib/parse.py#L539-L605 + # modifications: + # - removed allow_fragments parameter + # - all schemes are considered to allow relative paths + # - all schemes are considered to allow netloc (revisit this) + # - no bytes support (removes encoding and decoding) + if not base: + return uri + if not uri: + return base + + bs = urlsplit(base, scheme="") + us = urlsplit(uri, scheme=bs.scheme) + + if us.scheme != bs.scheme: # or us.scheme not in uses_relative: + return uri + # if us.scheme in uses_netloc: + if us.netloc: + return us.geturl() + else: + us = us._replace(netloc=bs.netloc) + # end if + if not us.path and not us.fragment: + us = us._replace(path=bs.path, fragment=bs.fragment) + if not us.query: + us = us._replace(query=bs.query) + return us.geturl() + + base_parts = bs.path.split("/") + if base_parts[-1] != "": + del base_parts[-1] + + if us.path[:1] == "/": + segments = us.path.split("/") + else: + segments = base_parts + us.path.split("/") + segments[1:-1] = filter(None, segments[1:-1]) + + resolved_path = [] + + for seg in segments: + if seg == "..": + try: + resolved_path.pop() + except IndexError: + pass + elif seg == ".": + continue + else: + resolved_path.append(seg) + + if segments[-1] in (".", ".."): + resolved_path.append("") + + return us._replace(path="/".join(resolved_path) or "/").geturl() diff --git a/upath/core.py b/upath/core.py index c541fb2..31343c1 100644 --- a/upath/core.py +++ b/upath/core.py @@ -20,6 +20,7 @@ from upath._compat import str_remove_prefix from upath._compat import str_remove_suffix from upath._flavour import FSSpecFlavour +from upath._flavour import upath_urijoin from upath._protocol import get_upath_protocol from upath._stat import UPathStatResult from upath.registry import get_upath_class @@ -253,6 +254,18 @@ def fs(self) -> AbstractFileSystem: def path(self) -> str: return super().__str__() + def joinuri(self, uri: str | os.PathLike[str]) -> UPath: + """Join with urljoin behavior for UPath instances""" + # short circuit if the new uri uses a different protocol + other_protocol = get_upath_protocol(uri) + if other_protocol and other_protocol != self._protocol: + return UPath(uri) + return UPath( + upath_urijoin(str(self), str(uri)), + protocol=other_protocol or self._protocol, + **self.storage_options, + ) + # === upath.UPath CUSTOMIZABLE API ================================ @classmethod @@ -590,6 +603,17 @@ def is_relative_to(self, other, /, *_deprecated): return False return super().is_relative_to(other, *_deprecated) + @property + def name(self): + tail = self._tail + if not tail: + return "" + name = tail[-1] + if not name and len(tail) >= 2: + return tail[-2] + else: + return name + # === pathlib.Path ================================================ def stat(self, *, follow_symlinks=True) -> UPathStatResult: diff --git a/upath/tests/implementations/test_http.py b/upath/tests/implementations/test_http.py index 7541780..6effd5c 100644 --- a/upath/tests/implementations/test_http.py +++ b/upath/tests/implementations/test_http.py @@ -143,3 +143,45 @@ def test_empty_parts(args, parts): pth = UPath(args) pth_parts = pth.parts assert pth_parts == parts + + +def test_query_parameters_passthrough(): + pth = UPath("http://example.com/?a=1&b=2") + assert pth.parts == ("http://example.com/", "?a=1&b=2") + + +@pytest.mark.parametrize( + "base,rel,expected", + [ + ( + "http://www.example.com/a/b/index.html", + "image.png?version=1", + "http://www.example.com/a/b/image.png?version=1", + ), + ( + "http://www.example.com/a/b/index.html", + "../image.png", + "http://www.example.com/a/image.png", + ), + ( + "http://www.example.com/a/b/index.html", + "/image.png", + "http://www.example.com/image.png", + ), + ( + "http://www.example.com/a/b/index.html", + "ftp://other.com/image.png", + "ftp://other.com/image.png", + ), + ( + "http://www.example.com/a/b/index.html", + "//other.com/image.png", + "http://other.com/image.png", + ), + ], +) +def test_joinuri_behavior(base, rel, expected): + p0 = UPath(base) + pr = p0.joinuri(rel) + pe = UPath(expected) + assert pr == pe