Skip to content
2 changes: 1 addition & 1 deletion airflow-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ dependencies = [
"termcolor>=3.0.0",
"typing-extensions>=4.14.1",
# https://github.com/apache/airflow/issues/56369 , rework universal-pathlib usage
"universal-pathlib>=0.2.6,<0.3.0",
"universal-pathlib>=0.3.8",
"uuid6>=2024.7.10",
"apache-airflow-task-sdk<1.3.0,>=1.2.0",
# pre-installed providers
Expand Down
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
aarch
abc
AbstractFileSystem
accessor
AccessSecretVersionResponse
accountmaking
Expand Down Expand Up @@ -733,6 +734,7 @@ fqdn
frontend
fs
fsGroup
fsspec
fullname
func
Fundera
Expand Down
123 changes: 76 additions & 47 deletions task-sdk/src/airflow/sdk/io/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,21 @@

from __future__ import annotations

import contextlib
import os
import shutil
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar
from urllib.parse import urlsplit

from fsspec.utils import stringify_path
from upath.implementations.cloud import CloudPath
from upath.registry import get_upath_class
from upath import UPath
from upath.extensions import ProxyUPath

from airflow.sdk.io.stat import stat_result
from airflow.sdk.io.store import attach

if TYPE_CHECKING:
from fsspec import AbstractFileSystem
from typing_extensions import Self
from upath.types import JoinablePathLike


class _TrackingFileWrapper:
Expand Down Expand Up @@ -77,42 +76,48 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._obj.__exit__(exc_type, exc_val, exc_tb)


class ObjectStoragePath(CloudPath):
class ObjectStoragePath(ProxyUPath):
"""A path-like object for object storage."""

__version__: ClassVar[int] = 1

_protocol_dispatch = False

sep: ClassVar[str] = "/"
root_marker: ClassVar[str] = "/"

__slots__ = ("_hash_cached",)

@classmethod
def _transform_init_args(
cls,
args: tuple[str | os.PathLike, ...],
protocol: str,
storage_options: dict[str, Any],
) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
"""Extract conn_id from the URL and set it as a storage option."""
def __init__(
self,
*args: JoinablePathLike,
protocol: str | None = None,
conn_id: str | None = None,
**storage_options: Any,
) -> None:
# ensure conn_id is always set in storage_options
storage_options.setdefault("conn_id", None)
# parse conn_id from args if provided
if args:
arg0 = args[0]
parsed_url = urlsplit(stringify_path(arg0))
userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
if have_info:
storage_options.setdefault("conn_id", userinfo or None)
parsed_url = parsed_url._replace(netloc=hostinfo)
args = (parsed_url.geturl(),) + args[1:]
protocol = protocol or parsed_url.scheme
return args, protocol, storage_options
if isinstance(arg0, type(self)):
storage_options["conn_id"] = arg0.storage_options.get("conn_id")
else:
parsed_url = urlsplit(stringify_path(arg0))
userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
if have_info:
conn_id = storage_options["conn_id"] = userinfo or None
parsed_url = parsed_url._replace(netloc=hostinfo)
args = (parsed_url.geturl(),) + args[1:]
protocol = protocol or parsed_url.scheme
# override conn_id if explicitly provided
if conn_id is not None:
storage_options["conn_id"] = conn_id
super().__init__(*args, protocol=protocol, **storage_options)

@classmethod
def _fs_factory(
cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
) -> AbstractFileSystem:
return attach(protocol or "file", storage_options.get("conn_id")).fs
@property
def fs(self) -> AbstractFileSystem:
"""Return the filesystem for this path, using airflow's attach mechanism."""
conn_id = self.storage_options.get("conn_id")
return attach(self.protocol or "file", conn_id).fs

def __hash__(self) -> int:
self._hash_cached: int
Expand Down Expand Up @@ -181,12 +186,7 @@ def samefile(self, other_path: Any) -> bool:
and st["ino"] == other_st["ino"]
)

def _scandir(self):
# Emulate os.scandir(), which returns an object that can be used as a
# context manager.
return contextlib.nullcontext(self.iterdir())

def replace(self, target) -> ObjectStoragePath:
def replace(self, target) -> Self:
"""
Rename this path to the target path, overwriting if that path exists.

Expand All @@ -199,16 +199,12 @@ def replace(self, target) -> ObjectStoragePath:
return self.rename(target)

@classmethod
def cwd(cls):
if cls is ObjectStoragePath:
return get_upath_class("").cwd()
raise NotImplementedError
def cwd(cls) -> Self:
return cls._from_upath(UPath.cwd())

@classmethod
def home(cls):
if cls is ObjectStoragePath:
return get_upath_class("").home()
raise NotImplementedError
def home(cls) -> Self:
return cls._from_upath(UPath.home())

# EXTENDED OPERATIONS

Expand Down Expand Up @@ -299,7 +295,7 @@ def _cp_file(self, dst: ObjectStoragePath, **kwargs):
# make use of system dependent buffer size
shutil.copyfileobj(f1, f2, **kwargs)

def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:
def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: # type: ignore[override]
"""
Copy file(s) from this path to another location.

Expand Down Expand Up @@ -370,7 +366,23 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
# remote file -> remote dir
self._cp_file(dst, **kwargs)

def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:
def copy_into(self, target_dir: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: # type: ignore[override]
"""
Copy file(s) from this path into another directory.

:param target_dir: Destination directory
:param recursive: If True, copy directories recursively.

kwargs: Additional keyword arguments to be passed to the underlying implementation.
"""
if isinstance(target_dir, str):
target_dir = ObjectStoragePath(target_dir)
if not target_dir.is_dir():
raise NotADirectoryError(f"Destination {target_dir} is not a directory.")
dst_path = target_dir / self.name
self.copy(dst_path, recursive=recursive, **kwargs)

def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: # type: ignore[override]
"""
Move file(s) from this path to another location.

Expand All @@ -394,6 +406,23 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
self.copy(path, recursive=recursive, **kwargs)
self.unlink()

def move_into(self, target_dir: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: # type: ignore[override]
"""
Move file(s) from this path into another directory.

:param target_dir: Destination directory
:param recursive: bool
If True, move directories recursively.

kwargs: Additional keyword arguments to be passed to the underlying implementation.
"""
if isinstance(target_dir, str):
target_dir = ObjectStoragePath(target_dir)
if not target_dir.is_dir():
raise NotADirectoryError(f"Destination {target_dir} is not a directory.")
dst_path = target_dir / self.name
self.move(dst_path, recursive=recursive, **kwargs)

def serialize(self) -> dict[str, Any]:
_kwargs = {**self.storage_options}
conn_id = _kwargs.pop("conn_id", None)
Expand All @@ -417,6 +446,6 @@ def deserialize(cls, data: dict, version: int) -> ObjectStoragePath:

def __str__(self):
conn_id = self.storage_options.get("conn_id")
if self._protocol and conn_id:
return f"{self._protocol}://{conn_id}@{self.path}"
if self.protocol and conn_id:
return f"{self.protocol}://{conn_id}@{self.path}"
return super().__str__()
34 changes: 26 additions & 8 deletions task-sdk/tests/task_sdk/io/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,19 @@ def test_home():
def test_lazy_load():
o = ObjectStoragePath("file:///tmp/foo")
with pytest.raises(AttributeError):
assert o._fs_cached
assert o.__wrapped__._fs_cached

# ObjectStoragePath overrides .fs and provides cached filesystems via the STORE_CACHE
assert o.fs is not None
assert o._fs_cached

with pytest.raises(AttributeError):
assert o.__wrapped__._fs_cached
# Clear the cache to avoid side effects in other tests below
_STORE_CACHE.clear()


class _FakeRemoteFileSystem(MemoryFileSystem):
protocol = ("s3", "fakefs", "ffs", "ffs2")
protocol = ("s3", "fake", "fakefs", "ffs", "ffs2")
root_marker = ""
store: ClassVar[dict[str, Any]] = {}
pseudo_dirs = [""]
Expand All @@ -99,6 +102,21 @@ def _strip_protocol(cls, path):
return path


@pytest.fixture(scope="module", autouse=True)
def register_fake_remote_filesystem():
# Register the fake filesystem with fsspec so UPath can discover it
from fsspec.registry import _registry as fsspec_implementation_registry, register_implementation

old_registry = fsspec_implementation_registry.copy()
try:
for proto in _FakeRemoteFileSystem.protocol:
register_implementation(proto, _FakeRemoteFileSystem, clobber=True)
yield
finally:
fsspec_implementation_registry.clear()
fsspec_implementation_registry.update(old_registry)


class TestAttach:
FAKE = "ffs:///fake"
MNT = "ffs:///mnt/warehouse"
Expand Down Expand Up @@ -168,10 +186,6 @@ def test_standard_extended_api(self, fake_files, fn, args, fn2, path, expected_a


class TestRemotePath:
@pytest.fixture(autouse=True)
def fake_fs(self, monkeypatch):
monkeypatch.setattr(ObjectStoragePath, "_fs_factory", lambda *a, **k: _FakeRemoteFileSystem())

def test_bucket_key_protocol(self):
bucket = "bkt"
key = "yek"
Expand Down Expand Up @@ -262,7 +276,11 @@ def test_relative_to(self, tmp_path, target):
o1 = ObjectStoragePath(f"file://{target}")
o2 = ObjectStoragePath(f"file://{tmp_path.as_posix()}")
o3 = ObjectStoragePath(f"file:///{uuid.uuid4()}")
assert o1.relative_to(o2) == o1
# relative_to returns the relative path from o2 to o1
relative = o1.relative_to(o2)
# The relative path should be the basename (uuid) of the target
expected_relative = target.split("/")[-1]
assert str(relative) == expected_relative
with pytest.raises(ValueError, match="is not in the subpath of"):
o1.relative_to(o3)

Expand Down