diff --git a/airflow-core/docs/core-concepts/objectstorage.rst b/airflow-core/docs/core-concepts/objectstorage.rst index 7811b761b4d0c..8c79dab2b22ad 100644 --- a/airflow-core/docs/core-concepts/objectstorage.rst +++ b/airflow-core/docs/core-concepts/objectstorage.rst @@ -71,7 +71,7 @@ object you want to interact with. For example, to point to a bucket in s3, you w .. code-block:: python - from airflow.io.path import ObjectStoragePath + from airflow.sdk import ObjectStoragePath base = ObjectStoragePath("s3://aws_default@my-bucket/") @@ -153,8 +153,8 @@ would do the following: .. code-block:: python - from airflow.io.path import ObjectStoragePath - from airflow.io.store import attach + from airflow.sdk import ObjectStoragePath + from airflow.sdk.io import attach from fsspec.implementations.dbfs import DBFSFileSystem @@ -163,7 +163,7 @@ would do the following: .. note:: - To reuse the registration across tasks make sure to attach the backend at the top-level of your DAG. + To reuse the registration across tasks, make sure to attach the backend at the top-level of your DAG. Otherwise, the backend will not be available across multiple tasks. @@ -177,7 +177,7 @@ and builds upon `Universal Pathlib the same API to interact with object storage as you would with a local filesystem. In this section we only list the differences between the two APIs. Extended operations beyond the standard Path API, like copying and moving, are listed in the next section. For details about each operation, like what arguments they take, see the documentation of -the :class:`~airflow.io.path.ObjectStoragePath` class. +the :class:`~airflow.sdk.definitions.path.ObjectStoragePath` class. mkdir @@ -321,7 +321,7 @@ are used to connect to s3 and a parquet file, indicated by a ``ObjectStoragePath .. code-block:: python import duckdb - from airflow.io.path import ObjectStoragePath + from airflow.sdk import ObjectStoragePath path = ObjectStoragePath("s3://my-bucket/my-table.parquet", conn_id="aws_default") conn = duckdb.connect(database=":memory:") diff --git a/airflow-core/newsfragments/45425.significant.rst b/airflow-core/newsfragments/45425.significant.rst new file mode 100644 index 0000000000000..8fb3c3d477083 --- /dev/null +++ b/airflow-core/newsfragments/45425.significant.rst @@ -0,0 +1,21 @@ +The ``airflow.io`` class ``ObjectStoragePath`` and function ``attach`` are moved to ``airflow.sdk``. + +* Types of change + + * [x] Dag changes + * [ ] Config changes + * [ ] API changes + * [ ] CLI changes + * [ ] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency changes + * [x] Code interface changes + +* Migration rules needed + + * ruff + + * AIR302 + + * [ ] ``airflow.io.path.ObjectStoragePath`` → ``airflow.sdk.ObjectStoragePath`` + * [ ] ``airflow.io.attach`` → ``airflow.sdk.io.attach`` diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 61b67d19c4504..527ddec8d854a 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -84,7 +84,6 @@ dependencies = [ # 0.115.10 fastapi was a bad release that broke our API's and static checks. 
# Related fastapi issue here: https://github.com/fastapi/fastapi/discussions/13431 "fastapi[standard]>=0.115.0,!=0.115.10", - "fsspec>=2023.10.0", "gitpython>=3.1.40", "gunicorn>=20.1.0", "httpx>=0.25.0", diff --git a/airflow-core/src/airflow/example_dags/tutorial_objectstorage.py b/airflow-core/src/airflow/example_dags/tutorial_objectstorage.py index 4c2efca631947..8d5d57626ff38 100644 --- a/airflow-core/src/airflow/example_dags/tutorial_objectstorage.py +++ b/airflow-core/src/airflow/example_dags/tutorial_objectstorage.py @@ -22,8 +22,7 @@ import pendulum import requests -from airflow.io.path import ObjectStoragePath -from airflow.sdk import dag, task +from airflow.sdk import ObjectStoragePath, dag, task # [END import_module] diff --git a/airflow-core/src/airflow/io/path.py b/airflow-core/src/airflow/io/path.py index 5a68789f8ebac..bc323d0030bc5 100644 --- a/airflow-core/src/airflow/io/path.py +++ b/airflow-core/src/airflow/io/path.py @@ -14,405 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -import contextlib -import os -import shutil -import typing -from collections.abc import Mapping -from typing import Any -from urllib.parse import urlsplit - -from fsspec.utils import stringify_path -from upath.implementations.cloud import CloudPath -from upath.registry import get_upath_class - -from airflow.io.store import attach -from airflow.io.utils.stat import stat_result -from airflow.lineage.hook import get_hook_lineage_collector -from airflow.utils.log.logging_mixin import LoggingMixin - -if typing.TYPE_CHECKING: - from fsspec import AbstractFileSystem - - -PT = typing.TypeVar("PT", bound="ObjectStoragePath") - -default = "file" - - -class TrackingFileWrapper(LoggingMixin): - """Wrapper that tracks file operations to intercept lineage.""" - - def __init__(self, path: ObjectStoragePath, obj): - super().__init__() - self._path: ObjectStoragePath = path - self._obj = obj - - def __getattr__(self, name): - attr = getattr(self._obj, name) - if callable(attr): - # If the attribute is a method, wrap it in another method to intercept the call - def wrapper(*args, **kwargs): - self.log.debug("Calling method: %s", name) - if name == "read": - get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path)) - elif name == "write": - get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path)) - result = attr(*args, **kwargs) - return result - - return wrapper - return attr - - def __getitem__(self, key): - # Intercept item access - return self._obj[key] - - def __enter__(self): - self._obj.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._obj.__exit__(exc_type, exc_val, exc_tb) - - -class ObjectStoragePath(CloudPath): - """A path-like object for object storage.""" - - __version__: typing.ClassVar[int] = 1 - - _protocol_dispatch = False - - sep: typing.ClassVar[str] = "/" - root_marker: typing.ClassVar[str] = "/" - - __slots__ = ("_hash_cached",) - - @classmethod - def _transform_init_args( - cls, - args: tuple[str | os.PathLike, ...], - protocol: str, - storage_options: dict[str, Any], - ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]: - """Extract conn_id from the URL and set it as a storage option.""" - if args: - arg0 = args[0] - parsed_url = urlsplit(stringify_path(arg0)) - userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@") - if have_info: - 
storage_options.setdefault("conn_id", userinfo or None) - parsed_url = parsed_url._replace(netloc=hostinfo) - args = (parsed_url.geturl(),) + args[1:] - protocol = protocol or parsed_url.scheme - return args, protocol, storage_options - - @classmethod - def _fs_factory( - cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any] - ) -> AbstractFileSystem: - return attach(protocol or "file", storage_options.get("conn_id")).fs - - def __hash__(self) -> int: - self._hash_cached: int - try: - return self._hash_cached - except AttributeError: - self._hash_cached = hash(str(self)) - return self._hash_cached - - def __eq__(self, other: typing.Any) -> bool: - return self.samestore(other) and str(self) == str(other) - - def samestore(self, other: typing.Any) -> bool: - return ( - isinstance(other, ObjectStoragePath) - and self.protocol == other.protocol - and self.storage_options.get("conn_id") == other.storage_options.get("conn_id") - ) - - @property - def container(self) -> str: - return self.bucket - - @property - def bucket(self) -> str: - if self._url: - return self._url.netloc - else: - return "" - - @property - def key(self) -> str: - if self._url: - # per convention, we strip the leading slashes to ensure a relative key is returned - # we keep the trailing slash to allow for directory-like semantics - return self._url.path.lstrip(self.sep) - else: - return "" - - @property - def namespace(self) -> str: - return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol - - def open(self, mode="r", **kwargs): - """Open the file pointed to by this path.""" - kwargs.setdefault("block_size", kwargs.pop("buffering", None)) - return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs)) - - def stat(self) -> stat_result: # type: ignore[override] - """Call ``stat`` and return the result.""" - return stat_result( - self.fs.stat(self.path), - protocol=self.protocol, - conn_id=self.storage_options.get("conn_id"), - ) - - def samefile(self, other_path: typing.Any) -> bool: - """Return whether other_path is the same or not as this file.""" - if not isinstance(other_path, ObjectStoragePath): - return False - - st = self.stat() - other_st = other_path.stat() - - return ( - st["protocol"] == other_st["protocol"] - and st["conn_id"] == other_st["conn_id"] - and st["ino"] == other_st["ino"] - ) - - def _scandir(self): - # Emulate os.scandir(), which returns an object that can be used as a - # context manager. - return contextlib.nullcontext(self.iterdir()) - - def replace(self, target) -> ObjectStoragePath: - """ - Rename this path to the target path, overwriting if that path exists. - - The target path may be absolute or relative. Relative paths are - interpreted relative to the current working directory, *not* the - directory of the Path object. - - Returns the new Path instance pointing to the target path. 
- """ - return self.rename(target) - - @classmethod - def cwd(cls): - if cls is ObjectStoragePath: - return get_upath_class("").cwd() - else: - raise NotImplementedError - - @classmethod - def home(cls): - if cls is ObjectStoragePath: - return get_upath_class("").home() - else: - raise NotImplementedError - - # EXTENDED OPERATIONS - def ukey(self) -> str: - """Hash of file properties, to tell if it has changed.""" - return self.fs.ukey(self.path) - - def checksum(self) -> int: - """Return the checksum of the file at this path.""" - # we directly access the fs here to avoid changing the abstract interface - return self.fs.checksum(self.path) - - def read_block(self, offset: int, length: int, delimiter=None): - r""" - Read a block of bytes. - - Starting at ``offset`` of the file, read ``length`` bytes. If - ``delimiter`` is set then we ensure that the read starts and stops at - delimiter boundaries that follow the locations ``offset`` and ``offset - + length``. If ``offset`` is zero then we start at zero. The - bytestring returned WILL include the end delimiter string. - - If offset+length is beyond the eof, reads to eof. - - :param offset: int - Byte offset to start read - :param length: int - Number of bytes to read. If None, read to the end. - :param delimiter: bytes (optional) - Ensure reading starts and stops at delimiter bytestring - - Examples - -------- - >>> read_block(0, 13) - b'Alice, 100\\nBo' - >>> read_block(0, 13, delimiter=b"\\n") - b'Alice, 100\\nBob, 200\\n' - - Use ``length=None`` to read to the end of the file. - >>> read_block(0, None, delimiter=b"\\n") - b'Alice, 100\\nBob, 200\\nCharlie, 300' - - See Also - -------- - :func:`fsspec.utils.read_block` - """ - return self.fs.read_block(self.path, offset=offset, length=length, delimiter=delimiter) - - def sign(self, expiration: int = 100, **kwargs): - """ - Create a signed URL representing the given path. - - Some implementations allow temporary URLs to be generated, as a - way of delegating credentials. - - :param path: str - The path on the filesystem - :param expiration: int - Number of seconds to enable the URL for (if supported) - - :returns URL: str - The signed URL - - :raises NotImplementedError: if the method is not implemented for a store - """ - return self.fs.sign(self.path, expiration=expiration, **kwargs) - - def size(self) -> int: - """Size in bytes of the file at this path.""" - return self.fs.size(self.path) - - def _cp_file(self, dst: ObjectStoragePath, **kwargs): - """Copy a single file from this path to another location by streaming the data.""" - # create the directory or bucket if required - if dst.key.endswith(self.sep) or not dst.key: - dst.mkdir(exist_ok=True, parents=True) - dst = dst / self.key - elif dst.is_dir(): - dst = dst / self.key - - # streaming copy - with self.open("rb") as f1, dst.open("wb") as f2: - # make use of system dependent buffer size - shutil.copyfileobj(f1, f2, **kwargs) - - def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: - """ - Copy file(s) from this path to another location. - - For remote to remote copies, the key used for the destination will be the same as the source. - So that s3://src_bucket/foo/bar will be copied to gcs://dst_bucket/foo/bar and not - gcs://dst_bucket/bar. - - :param dst: Destination path - :param recursive: If True, copy directories recursively. - - kwargs: Additional keyword arguments to be passed to the underlying implementation. 
- """ - if isinstance(dst, str): - dst = ObjectStoragePath(dst) - - if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file": - # only emit this in "optimized" variants - else lineage will be captured by file writes/reads - get_hook_lineage_collector().add_input_asset(context=self, uri=str(self)) - get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst)) - - # same -> same - if self.samestore(dst): - self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs) - return - - # use optimized path for local -> remote or remote -> local - if self.protocol == "file": - dst.fs.put(self.path, dst.path, recursive=recursive, **kwargs) - return - - if dst.protocol == "file": - self.fs.get(self.path, dst.path, recursive=recursive, **kwargs) - return - - if not self.exists(): - raise FileNotFoundError(f"{self} does not exist") - - # remote dir -> remote dir - if self.is_dir(): - if dst.is_file(): - raise ValueError("Cannot copy directory to a file.") - - dst.mkdir(exist_ok=True, parents=True) - - out = self.fs.expand_path(self.path, recursive=True, **kwargs) - - for path in out: - # this check prevents one extra call to is_dir() as - # glob returns self as well - if path == self.path: - continue - - src_obj = ObjectStoragePath( - path, - protocol=self.protocol, - conn_id=self.storage_options.get("conn_id"), - ) - - # skip directories, empty directories will not be created - if src_obj.is_dir(): - continue - - src_obj._cp_file(dst) - return - - # remote file -> remote dir - self._cp_file(dst, **kwargs) - - def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: - """ - Move file(s) from this path to another location. - - :param path: Destination path - :param recursive: bool - If True, move directories recursively. - - kwargs: Additional keyword arguments to be passed to the underlying implementation. - """ - if isinstance(path, str): - path = ObjectStoragePath(path) - - if self.samestore(path): - get_hook_lineage_collector().add_input_asset(context=self, uri=str(self)) - get_hook_lineage_collector().add_output_asset(context=path, uri=str(path)) - return self.fs.move(self.path, path.path, recursive=recursive, **kwargs) - - # non-local copy - self.copy(path, recursive=recursive, **kwargs) - self.unlink() - - def serialize(self) -> dict[str, typing.Any]: - _kwargs = {**self.storage_options} - conn_id = _kwargs.pop("conn_id", None) - - return { - "path": str(self), - "conn_id": conn_id, - "kwargs": _kwargs, - } - - @classmethod - def deserialize(cls, data: dict, version: int) -> ObjectStoragePath: - if version > cls.__version__: - raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.") - - _kwargs = data.pop("kwargs") - path = data.pop("path") - conn_id = data.pop("conn_id", None) +from __future__ import annotations - return ObjectStoragePath(path, conn_id=conn_id, **_kwargs) +from airflow.sdk import ObjectStoragePath - def __str__(self): - conn_id = self.storage_options.get("conn_id") - if self._protocol and conn_id: - return f"{self._protocol}://{conn_id}@{self.path}" - return super().__str__() +__all__ = ["ObjectStoragePath"] diff --git a/airflow-core/src/airflow/io/storage.py b/airflow-core/src/airflow/io/storage.py new file mode 100644 index 0000000000000..4723e8a15f65a --- /dev/null +++ b/airflow-core/src/airflow/io/storage.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk.io import attach + +__all__ = ["attach"] diff --git a/airflow-core/src/airflow/lineage/hook.py b/airflow-core/src/airflow/lineage/hook.py index 796deb8466b91..9aeb65c277147 100644 --- a/airflow-core/src/airflow/lineage/hook.py +++ b/airflow-core/src/airflow/lineage/hook.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from airflow.hooks.base import BaseHook - from airflow.io.path import ObjectStoragePath + from airflow.sdk import ObjectStoragePath # Store context what sent lineage. LineageContext = Union[BaseHook, ObjectStoragePath] diff --git a/airflow-core/tests/unit/io/test_path.py b/airflow-core/tests/unit/io/test_path.py deleted file mode 100644 index f079357c398ff..0000000000000 --- a/airflow-core/tests/unit/io/test_path.py +++ /dev/null @@ -1,440 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import uuid -from stat import S_ISDIR, S_ISREG -from tempfile import NamedTemporaryFile -from typing import Any, ClassVar -from unittest import mock - -import pytest -from fsspec.implementations.local import LocalFileSystem -from fsspec.implementations.memory import MemoryFileSystem -from fsspec.registry import _registry as _fsspec_registry, register_implementation - -from airflow.io import _register_filesystems, get_fs -from airflow.io.path import ObjectStoragePath -from airflow.io.store import _STORE_CACHE, ObjectStore, attach -from airflow.sdk.definitions.asset import Asset -from airflow.utils.module_loading import qualname - -FAKE = "file:///fake" -MNT = "file:///mnt/warehouse" -FOO = "file:///mnt/warehouse/foo" -BAR = FOO - - -class FakeLocalFileSystem(MemoryFileSystem): - protocol = ("file", "local") - root_marker = "/" - store: ClassVar[dict[str, Any]] = {} - pseudo_dirs = [""] - - def __init__(self, *args, **kwargs): - self.conn_id = kwargs.pop("conn_id", None) - super().__init__(*args, **kwargs) - - @classmethod - def _strip_protocol(cls, path): - for protocol in cls.protocol: - if path.startswith(f"{protocol}://"): - return path[len(f"{protocol}://") :] - if "::" in path or "://" in path: - return path.rstrip("/") - path = path.lstrip("/").rstrip("/") - return path - - -class FakeRemoteFileSystem(MemoryFileSystem): - protocol = ("s3", "fakefs", "ffs", "ffs2") - root_marker = "" - store: ClassVar[dict[str, Any]] = {} - pseudo_dirs = [""] - - def __init__(self, *args, **kwargs): - self.conn_id = kwargs.pop("conn_id", None) - super().__init__(*args, **kwargs) - - @classmethod - def _strip_protocol(cls, path): - for protocol in cls.protocol: - if path.startswith(f"{protocol}://"): - return path[len(f"{protocol}://") :] - if "::" in path or "://" in path: - return path.rstrip("/") - path = path.lstrip("/").rstrip("/") - return path - - -def get_fs_no_storage_options(_: str): - return LocalFileSystem() - - -class TestFs: - def setup_class(self): - self._store_cache = _STORE_CACHE.copy() - self._fsspec_registry = _fsspec_registry.copy() - for protocol in FakeRemoteFileSystem.protocol: - register_implementation(protocol, FakeRemoteFileSystem, clobber=True) - - def teardown(self): - _STORE_CACHE.clear() - _STORE_CACHE.update(self._store_cache) - _fsspec_registry.clear() - _fsspec_registry.update(self._fsspec_registry) - - def test_alias(self): - store = attach("file", alias="local") - assert isinstance(store.fs, LocalFileSystem) - assert "local" in _STORE_CACHE - - def test_init_objectstoragepath(self): - attach("s3", fs=FakeRemoteFileSystem()) - - path = ObjectStoragePath("s3://bucket/key/part1/part2") - assert path.bucket == "bucket" - assert path.key == "key/part1/part2" - assert path.protocol == "s3" - assert path.path == "bucket/key/part1/part2" - - path2 = ObjectStoragePath(path / "part3") - assert path2.bucket == "bucket" - assert path2.key == "key/part1/part2/part3" - assert path2.protocol == "s3" - assert path2.path == "bucket/key/part1/part2/part3" - - path3 = ObjectStoragePath(path2 / "2023") - assert path3.bucket == "bucket" - assert path3.key == "key/part1/part2/part3/2023" - assert path3.protocol == "s3" - assert path3.path == "bucket/key/part1/part2/part3/2023" - - def test_read_write(self): - o = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}") - - with o.open("wb") as f: - f.write(b"foo") - - assert o.open("rb").read() == b"foo" - - o.unlink() - - @pytest.mark.parametrize( - "scheme", - [ - pytest.param("local", 
id="local_scheme"), - pytest.param("file", id="file_scheme"), - ], - ) - def test_ls(self, scheme): - dirname = str(uuid.uuid4()) - filename = str(uuid.uuid4()) - - d = ObjectStoragePath(f"{scheme}:///tmp/{dirname}") - d.mkdir(parents=True) - o = d / filename - o.touch() - - data = list(d.iterdir()) - assert len(data) == 1 - assert data[0] == o - - d.rmdir(recursive=True) - - assert not o.exists() - - def test_objectstoragepath_init_conn_id_in_uri(self): - attach(protocol="fake", conn_id="fake", fs=FakeRemoteFileSystem(conn_id="fake")) - p = ObjectStoragePath("fake://fake@bucket/path") - p.touch() - fsspec_info = p.fs.info(p.path) - assert p.stat() == {**fsspec_info, "conn_id": "fake", "protocol": "fake"} - - @pytest.fixture - def fake_local_files(self): - obj = FakeLocalFileSystem() - obj.touch(FOO) - try: - yield - finally: - FakeLocalFileSystem.store.clear() - FakeLocalFileSystem.pseudo_dirs[:] = [""] - - @pytest.mark.parametrize( - "fn, args, fn2, path, expected_args, expected_kwargs", - [ - ("checksum", {}, "checksum", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}), - ("size", {}, "size", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}), - ( - "sign", - {"expiration": 200, "extra": "xtra"}, - "sign", - FOO, - FakeLocalFileSystem._strip_protocol(BAR), - {"expiration": 200, "extra": "xtra"}, - ), - ("ukey", {}, "ukey", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}), - ( - "read_block", - {"offset": 0, "length": 1}, - "read_block", - FOO, - FakeLocalFileSystem._strip_protocol(BAR), - {"delimiter": None, "length": 1, "offset": 0}, - ), - ], - ) - def test_standard_extended_api( - self, fake_local_files, fn, args, fn2, path, expected_args, expected_kwargs - ): - fs = FakeLocalFileSystem() - with mock.patch.object(fs, fn2) as method: - attach(protocol="file", conn_id="fake", fs=fs) - o = ObjectStoragePath(path, conn_id="fake") - - getattr(o, fn)(**args) - method.assert_called_once_with(expected_args, **expected_kwargs) - - def test_stat(self): - with NamedTemporaryFile() as f: - o = ObjectStoragePath(f"file://{f.name}") - assert o.stat().st_size == 0 - assert S_ISREG(o.stat().st_mode) - assert S_ISDIR(o.parent.stat().st_mode) - - def test_bucket_key_protocol(self): - attach(protocol="s3", fs=FakeRemoteFileSystem()) - - bucket = "bkt" - key = "yek" - protocol = "s3" - - o = ObjectStoragePath(f"{protocol}://{bucket}/{key}") - assert o.bucket == bucket - assert o.container == bucket - assert o.key == f"{key}" - assert o.protocol == protocol - - def test_cwd_home(self): - assert ObjectStoragePath.cwd() - assert ObjectStoragePath.home() - - def test_replace(self): - o = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}") - i = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}") - - o.touch() - i.touch() - - assert i.size() == 0 - - txt = "foo" - o.write_text(txt) - e = o.replace(i) - assert o.exists() is False - assert i == e - assert e.size() == len(txt) - - e.unlink() - - def test_is_relative_to(self): - uuid_dir = f"/tmp/{str(uuid.uuid4())}" - o1 = ObjectStoragePath(f"file://{uuid_dir}/aaa") - o2 = ObjectStoragePath(f"file://{uuid_dir}") - o3 = ObjectStoragePath(f"file://{str(uuid.uuid4())}") - assert o1.is_relative_to(o2) - assert not o1.is_relative_to(o3) - - def test_relative_to(self): - uuid_dir = f"/tmp/{str(uuid.uuid4())}" - o1 = ObjectStoragePath(f"file://{uuid_dir}/aaa") - o2 = ObjectStoragePath(f"file://{uuid_dir}") - o3 = ObjectStoragePath(f"file://{str(uuid.uuid4())}") - - _ = o1.relative_to(o2) # Should not raise any error - - with pytest.raises(ValueError): - 
o1.relative_to(o3) - - def test_move_local(self, hook_lineage_collector): - _from_path = f"file:///tmp/{str(uuid.uuid4())}" - _to_path = f"file:///tmp/{str(uuid.uuid4())}" - _from = ObjectStoragePath(_from_path) - _to = ObjectStoragePath(_to_path) - - _from.touch() - _from.move(_to) - assert _to.exists() - assert not _from.exists() - - _to.unlink() - - collected_assets = hook_lineage_collector.collected_assets - - assert len(collected_assets.inputs) == 1 - assert len(collected_assets.outputs) == 1 - assert collected_assets.inputs[0].asset == Asset(uri=_from_path) - assert collected_assets.outputs[0].asset == Asset(uri=_to_path) - - def test_move_remote(self, hook_lineage_collector): - attach("fakefs", fs=FakeRemoteFileSystem()) - - _from_path = f"file:///tmp/{str(uuid.uuid4())}" - _to_path = f"fakefs:///tmp/{str(uuid.uuid4())}" - - _from = ObjectStoragePath(_from_path) - _to = ObjectStoragePath(_to_path) - - _from.touch() - _from.move(_to) - assert not _from.exists() - assert _to.exists() - - _to.unlink() - - collected_assets = hook_lineage_collector.collected_assets - - assert len(collected_assets.inputs) == 1 - assert len(collected_assets.outputs) == 1 - assert collected_assets.inputs[0].asset == Asset(uri=str(_from)) - assert collected_assets.outputs[0].asset == Asset(uri=str(_to)) - - def test_copy_remote_remote(self, hook_lineage_collector): - attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True)) - attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True)) - - dir_src = f"bucket1/{str(uuid.uuid4())}" - dir_dst = f"bucket2/{str(uuid.uuid4())}" - key = "foo/bar/baz.txt" - - _from_path = f"ffs://{dir_src}" - _from = ObjectStoragePath(_from_path) - _from_file = _from / key - _from_file.touch() - assert _from.bucket == "bucket1" - assert _from_file.exists() - - _to_path = f"ffs2://{dir_dst}" - _to = ObjectStoragePath(_to_path) - _from.copy(_to) - - assert _to.bucket == "bucket2" - assert _to.exists() - assert _to.is_dir() - assert (_to / _from.key / key).exists() - assert (_to / _from.key / key).is_file() - - _from.rmdir(recursive=True) - _to.rmdir(recursive=True) - - assert len(hook_lineage_collector.collected_assets.inputs) == 1 - assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset(uri=str(_from_file)) - - # Empty file - shutil.copyfileobj does nothing - assert len(hook_lineage_collector.collected_assets.outputs) == 0 - - def test_serde_objectstoragepath(self): - path = "file:///bucket/key/part1/part2" - o = ObjectStoragePath(path) - - s = o.serialize() - assert s["path"] == path - d = ObjectStoragePath.deserialize(s, 1) - assert o == d - - o = ObjectStoragePath(path, my_setting="foo") - s = o.serialize() - assert "my_setting" in s["kwargs"] - d = ObjectStoragePath.deserialize(s, 1) - assert o == d - - store = attach("filex", conn_id="mock") - o = ObjectStoragePath(path, store=store) - s = o.serialize() - assert s["kwargs"]["store"] == store - - d = ObjectStoragePath.deserialize(s, 1) - assert o == d - - def test_serde_store(self): - store = attach("file", conn_id="mock") - s = store.serialize() - d = ObjectStore.deserialize(s, 1) - - assert s["protocol"] == "file" - assert s["conn_id"] == "mock" - assert s["filesystem"] == qualname(LocalFileSystem) - assert store == d - - store = attach("localfs", fs=LocalFileSystem()) - s = store.serialize() - d = ObjectStore.deserialize(s, 1) - - assert s["protocol"] == "localfs" - assert s["conn_id"] is None - assert s["filesystem"] == qualname(LocalFileSystem) - assert store == d - - def 
test_backwards_compat(self): - _register_filesystems.cache_clear() - from airflow.io import _BUILTIN_SCHEME_TO_FS as SCHEMES - - try: - SCHEMES["file"] = get_fs_no_storage_options # type: ignore[call-arg] - - assert get_fs("file") - - with pytest.raises(AttributeError): - get_fs("file", storage_options={"foo": "bar"}) - - finally: - # Reset the cache to avoid side effects - _register_filesystems.cache_clear() - - def test_asset(self): - attach("s3", fs=FakeRemoteFileSystem()) - - p = "s3" - f = "bucket/object" - i = Asset(uri=f"{p}://{f}", name="test-asset", extra={"foo": "bar"}) - o = ObjectStoragePath(i) - assert o.protocol == p - assert o.path == f - - def test_hash(self): - file_uri_1 = f"file:///tmp/{str(uuid.uuid4())}" - file_uri_2 = f"file:///tmp/{str(uuid.uuid4())}" - s = set() - for _ in range(10): - s.add(ObjectStoragePath(file_uri_1)) - s.add(ObjectStoragePath(file_uri_2)) - assert len(s) == 2 - - def test_lazy_load(self): - o = ObjectStoragePath("file:///tmp/foo") - with pytest.raises(AttributeError): - assert o._fs_cached - - assert o.fs is not None - assert o._fs_cached - - @pytest.mark.parametrize("input_str", ("file:///tmp/foo", "s3://conn_id@bucket/test.txt")) - def test_str(self, input_str): - o = ObjectStoragePath(input_str) - assert str(o) == input_str diff --git a/airflow-core/tests/unit/io/test_wrapper.py b/airflow-core/tests/unit/io/test_wrapper.py index 35469326794ea..1fbb5bc50fba3 100644 --- a/airflow-core/tests/unit/io/test_wrapper.py +++ b/airflow-core/tests/unit/io/test_wrapper.py @@ -19,8 +19,7 @@ import uuid from unittest.mock import patch -from airflow.io.path import ObjectStoragePath -from airflow.sdk.definitions.asset import Asset +from airflow.sdk import Asset, ObjectStoragePath @patch("airflow.providers_manager.ProvidersManager") diff --git a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py index 840ef796e598d..0faec858d1c75 100644 --- a/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py +++ b/providers/common/io/src/airflow/providers/common/io/operators/file_transfer.py @@ -20,12 +20,18 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.io.path import ObjectStoragePath -from airflow.models import BaseOperator +from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: from airflow.providers.openlineage.extractors import OperatorLineage - from airflow.utils.context import Context + from airflow.sdk import Context + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import ObjectStoragePath + from airflow.sdk.bases.operator import BaseOperator +else: + from airflow.io.path import ObjectStoragePath # type: ignore[no-redef] + from airflow.models import BaseOperator # type: ignore[no-redef] class FileTransferOperator(BaseOperator): diff --git a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py index 79bedc3977eb5..5fc86b8f182ee 100644 --- a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py +++ b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py @@ -27,7 +27,6 @@ import fsspec.utils from airflow.configuration import conf -from airflow.io.path import ObjectStoragePath from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.json import XComDecoder, XComEncoder @@ -37,8 +36,10 @@ from 
airflow.sdk.execution_time.comms import XComResult if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import ObjectStoragePath from airflow.sdk.bases.xcom import BaseXCom else: + from airflow.io.path import ObjectStoragePath # type: ignore[no-redef] from airflow.models.xcom import BaseXCom # type: ignore[no-redef] T = TypeVar("T") diff --git a/providers/common/io/tests/system/common/io/example_file_transfer_local_to_s3.py b/providers/common/io/tests/system/common/io/example_file_transfer_local_to_s3.py index 04204f3c38be7..d5a522ac89cbc 100644 --- a/providers/common/io/tests/system/common/io/example_file_transfer_local_to_s3.py +++ b/providers/common/io/tests/system/common/io/example_file_transfer_local_to_s3.py @@ -23,10 +23,15 @@ from airflow import DAG from airflow.decorators import task -from airflow.io.path import ObjectStoragePath from airflow.providers.common.io.operators.file_transfer import FileTransferOperator +from airflow.providers.common.io.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.trigger_rule import TriggerRule +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import ObjectStoragePath +else: + from airflow.io.path import ObjectStoragePath # type: ignore[no-redef] + ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_file_transfer_local_to_s3" diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py index 9f80f2d8a962c..ab3cdd56008cd 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py @@ -33,7 +33,6 @@ ) from uuid6 import uuid7 -from airflow.io.path import ObjectStoragePath from airflow.models.taskinstance import TaskInstance from airflow.providers.common.compat.lineage.entities import Column, File, Table, User from airflow.providers.openlineage.extractors import OperatorLineage @@ -78,13 +77,15 @@ def hook_lineage_collector(): if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import ObjectStoragePath from airflow.sdk.api.datamodels._generated import BundleInfo, TaskInstance as SDKTaskInstance from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import StartupDetails from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse else: - from airflow.models.baseoperator import BaseOperator + from airflow.io.path import ObjectStoragePath # type: ignore[no-redef] + from airflow.models import BaseOperator SDKTaskInstance = ... # type: ignore task_runner = ... 
# type: ignore diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 13138406ba505..f3f4e39c413e5 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -24,6 +24,7 @@ requires-python = ">=3.9, <3.13" dependencies = [ "aiologic>=0.14.0", "attrs>=24.2.0, !=25.2.0", + "fsspec>=2023.10.0", "httpx>=0.27.0", "jinja2>=3.1.5", "methodtools>=0.4.7", diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index a7cad123c6dbd..8b933942df504 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -35,6 +35,7 @@ "EdgeModifier", "Label", "Metadata", + "ObjectStoragePath", "Param", "PokeReturnValue", "TaskGroup", @@ -75,6 +76,7 @@ from airflow.sdk.definitions.template import literal from airflow.sdk.definitions.variable import Variable from airflow.sdk.definitions.xcom_arg import XComArg + from airflow.sdk.io.path import ObjectStoragePath __lazy_imports: dict[str, str] = { "Asset": ".definitions.asset", @@ -92,6 +94,7 @@ "EdgeModifier": ".definitions.edges", "Label": ".definitions.edges", "Metadata": ".definitions.asset.metadata", + "ObjectStoragePath": ".io.path", "Param": ".definitions.param", "PokeReturnValue": ".bases.sensor", "TaskGroup": ".definitions.taskgroup", diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py index a091411349657..d141b0328b015 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -27,11 +27,11 @@ import jinja2.nativetypes import jinja2.sandbox +from airflow.sdk import ObjectStoragePath from airflow.sdk.definitions._internal.mixins import ResolveMixin from airflow.utils.helpers import render_template_as_native, render_template_to_string if TYPE_CHECKING: - from airflow.io.path import ObjectStoragePath from airflow.models.operator import Operator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG @@ -154,13 +154,6 @@ def render_template( *RecursionError* on circular dependencies) :return: Templated content """ - try: - # Delay this to runtime, it invokes provider manager otherwise - from airflow.io.path import ObjectStoragePath - except ImportError: - # A placeholder class so isinstance checks work - class ObjectStoragePath: ... # type: ignore[no-redef] - # "content" is a bad name, but we're stuck to it being public API. value = content del content diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 43f0d7c2c1359..f156df5895365 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -35,6 +35,7 @@ from urllib.parse import SplitResult from airflow.models.asset import AssetModel + from airflow.sdk.io.path import ObjectStoragePath from airflow.serialization.serialized_objects import SerializedAssetWatcher from airflow.triggers.base import BaseEventTrigger @@ -129,13 +130,14 @@ def _get_normalized_scheme(uri: str) -> str: return parsed.scheme.lower() -def _sanitize_uri(uri: str) -> str: +def _sanitize_uri(inp: str | ObjectStoragePath) -> str: """ Sanitize an asset URI. This checks for URI validity, and normalizes the URI if needed. A fully normalized URI is returned. """ + uri = str(inp) parsed = urllib.parse.urlsplit(uri) if not parsed.scheme and not parsed.netloc: # Does not look like a URI. 
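+        # e.g. a bare asset name with no scheme -- nothing to normalize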
return uri @@ -313,7 +315,7 @@ class Asset(os.PathLike, BaseAsset): def __init__( self, name: str, - uri: str, + uri: str | ObjectStoragePath, *, group: str = ..., extra: dict | None = None, @@ -336,7 +338,7 @@ def __init__( def __init__( self, *, - uri: str, + uri: str | ObjectStoragePath, group: str = ..., extra: dict | None = None, watchers: list[AssetWatcher | SerializedAssetWatcher] = ..., @@ -346,7 +348,7 @@ def __init__( def __init__( self, name: str | None = None, - uri: str | None = None, + uri: str | ObjectStoragePath | None = None, *, group: str | None = None, extra: dict | None = None, @@ -355,7 +357,7 @@ def __init__( if name is None and uri is None: raise TypeError("Asset() requires either 'name' or 'uri'") elif name is None: - name = uri + name = str(uri) elif uri is None: uri = name diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 8ad4dfb2a1288..d899dd3718336 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -29,9 +29,9 @@ if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterator, Mapping - from airflow.io.path import ObjectStoragePath - from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey - from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg + from airflow.sdk import DAG, AssetAlias, ObjectStoragePath + from airflow.sdk.definitions.asset import AssetUniqueKey + from airflow.sdk.definitions.dag import DagStateChangeCallback, ScheduleArg from airflow.sdk.definitions.param import ParamsDict from airflow.serialization.dag_dependency import DagDependency from airflow.triggers.base import BaseTrigger diff --git a/task-sdk/src/airflow/sdk/io/__init__.py b/task-sdk/src/airflow/sdk/io/__init__.py new file mode 100644 index 0000000000000..4247c5e4acf23 --- /dev/null +++ b/task-sdk/src/airflow/sdk/io/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk.io.path import ObjectStoragePath +from airflow.sdk.io.store import attach + +__all__ = ["ObjectStoragePath", "attach"] diff --git a/task-sdk/src/airflow/sdk/io/path.py b/task-sdk/src/airflow/sdk/io/path.py new file mode 100644 index 0000000000000..8c519d0ee27df --- /dev/null +++ b/task-sdk/src/airflow/sdk/io/path.py @@ -0,0 +1,416 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import contextlib +import os +import shutil +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, ClassVar +from urllib.parse import urlsplit + +from fsspec.utils import stringify_path +from upath.implementations.cloud import CloudPath +from upath.registry import get_upath_class + +from airflow.sdk.io.stat import stat_result +from airflow.sdk.io.store import attach + +if TYPE_CHECKING: + from fsspec import AbstractFileSystem + + +class _TrackingFileWrapper: + """Wrapper that tracks file operations to intercept lineage.""" + + def __init__(self, path: ObjectStoragePath, obj): + super().__init__() + self._path = path + self._obj = obj + + def __getattr__(self, name): + from airflow.lineage.hook import get_hook_lineage_collector + + if not callable(attr := getattr(self._obj, name)): + return attr + + # If the attribute is a method, wrap it in another method to intercept the call + def wrapper(*args, **kwargs): + if name == "read": + get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path)) + elif name == "write": + get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path)) + result = attr(*args, **kwargs) + return result + + return wrapper + + def __getitem__(self, key): + # Intercept item access + return self._obj[key] + + def __enter__(self): + self._obj.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._obj.__exit__(exc_type, exc_val, exc_tb) + + +class ObjectStoragePath(CloudPath): + """A path-like object for object storage.""" + + __version__: ClassVar[int] = 1 + + _protocol_dispatch = False + + sep: ClassVar[str] = "/" + root_marker: ClassVar[str] = "/" + + __slots__ = ("_hash_cached",) + + @classmethod + def _transform_init_args( + cls, + args: tuple[str | os.PathLike, ...], + protocol: str, + storage_options: dict[str, Any], + ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]: + """Extract conn_id from the URL and set it as a storage option.""" + if args: + arg0 = args[0] + parsed_url = urlsplit(stringify_path(arg0)) + userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@") + if have_info: + storage_options.setdefault("conn_id", userinfo or None) + parsed_url = parsed_url._replace(netloc=hostinfo) + args = (parsed_url.geturl(),) + args[1:] + protocol = protocol or parsed_url.scheme + return args, protocol, storage_options + + @classmethod + def _fs_factory( + cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any] + ) -> AbstractFileSystem: + return attach(protocol or "file", storage_options.get("conn_id")).fs + + def __hash__(self) -> int: + self._hash_cached: int + try: + return self._hash_cached + except AttributeError: + self._hash_cached = hash(str(self)) + return self._hash_cached + + def __eq__(self, other: Any) -> bool: + return self.samestore(other) and str(self) == str(other) + + def samestore(self, other: Any) -> bool: 
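+        # Same protocol and same Airflow connection id; __eq__ above also
+        # requires the rendered URLs to match.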
+ return ( + isinstance(other, ObjectStoragePath) + and self.protocol == other.protocol + and self.storage_options.get("conn_id") == other.storage_options.get("conn_id") + ) + + @property + def container(self) -> str: + return self.bucket + + @property + def bucket(self) -> str: + if self._url: + return self._url.netloc + else: + return "" + + @property + def key(self) -> str: + if self._url: + # per convention, we strip the leading slashes to ensure a relative key is returned + # we keep the trailing slash to allow for directory-like semantics + return self._url.path.lstrip(self.sep) + else: + return "" + + @property + def namespace(self) -> str: + return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol + + def open(self, mode="r", **kwargs): + """Open the file pointed to by this path.""" + kwargs.setdefault("block_size", kwargs.pop("buffering", None)) + return _TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs)) + + def stat(self) -> stat_result: # type: ignore[override] + """Call ``stat`` and return the result.""" + return stat_result( + self.fs.stat(self.path), + protocol=self.protocol, + conn_id=self.storage_options.get("conn_id"), + ) + + def samefile(self, other_path: Any) -> bool: + """Return whether other_path is the same or not as this file.""" + if not isinstance(other_path, ObjectStoragePath): + return False + + st = self.stat() + other_st = other_path.stat() + + return ( + st["protocol"] == other_st["protocol"] + and st["conn_id"] == other_st["conn_id"] + and st["ino"] == other_st["ino"] + ) + + def _scandir(self): + # Emulate os.scandir(), which returns an object that can be used as a + # context manager. + return contextlib.nullcontext(self.iterdir()) + + def replace(self, target) -> ObjectStoragePath: + """ + Rename this path to the target path, overwriting if that path exists. + + The target path may be absolute or relative. Relative paths are + interpreted relative to the current working directory, *not* the + directory of the Path object. + + Returns the new Path instance pointing to the target path. + """ + return self.rename(target) + + @classmethod + def cwd(cls): + if cls is ObjectStoragePath: + return get_upath_class("").cwd() + else: + raise NotImplementedError + + @classmethod + def home(cls): + if cls is ObjectStoragePath: + return get_upath_class("").home() + else: + raise NotImplementedError + + # EXTENDED OPERATIONS + + def ukey(self) -> str: + """Hash of file properties, to tell if it has changed.""" + return self.fs.ukey(self.path) + + def checksum(self) -> int: + """Return the checksum of the file at this path.""" + # we directly access the fs here to avoid changing the abstract interface + return self.fs.checksum(self.path) + + def read_block(self, offset: int, length: int, delimiter=None): + r""" + Read a block of bytes. + + Starting at ``offset`` of the file, read ``length`` bytes. If + ``delimiter`` is set then we ensure that the read starts and stops at + delimiter boundaries that follow the locations ``offset`` and ``offset + + length``. If ``offset`` is zero then we start at zero. The + bytestring returned WILL include the end delimiter string. + + If offset+length is beyond the eof, reads to eof. + + :param offset: int + Byte offset to start read + :param length: int + Number of bytes to read. If None, read to the end. 
+ :param delimiter: bytes (optional) + Ensure reading starts and stops at delimiter bytestring + + Examples + -------- + >>> read_block(0, 13) + b'Alice, 100\\nBo' + >>> read_block(0, 13, delimiter=b"\\n") + b'Alice, 100\\nBob, 200\\n' + + Use ``length=None`` to read to the end of the file. + >>> read_block(0, None, delimiter=b"\\n") + b'Alice, 100\\nBob, 200\\nCharlie, 300' + + See Also + -------- + :func:`fsspec.utils.read_block` + """ + return self.fs.read_block(self.path, offset=offset, length=length, delimiter=delimiter) + + def sign(self, expiration: int = 100, **kwargs): + """ + Create a signed URL representing the given path. + + Some implementations allow temporary URLs to be generated, as a + way of delegating credentials. + + :param path: str + The path on the filesystem + :param expiration: int + Number of seconds to enable the URL for (if supported) + + :returns URL: str + The signed URL + + :raises NotImplementedError: if the method is not implemented for a store + """ + return self.fs.sign(self.path, expiration=expiration, **kwargs) + + def size(self) -> int: + """Size in bytes of the file at this path.""" + return self.fs.size(self.path) + + def _cp_file(self, dst: ObjectStoragePath, **kwargs): + """Copy a single file from this path to another location by streaming the data.""" + # create the directory or bucket if required + if dst.key.endswith(self.sep) or not dst.key: + dst.mkdir(exist_ok=True, parents=True) + dst = dst / self.key + elif dst.is_dir(): + dst = dst / self.key + + # streaming copy + with self.open("rb") as f1, dst.open("wb") as f2: + # make use of system dependent buffer size + shutil.copyfileobj(f1, f2, **kwargs) + + def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: + """ + Copy file(s) from this path to another location. + + For remote to remote copies, the key used for the destination will be the same as the source. + So that s3://src_bucket/foo/bar will be copied to gcs://dst_bucket/foo/bar and not + gcs://dst_bucket/bar. + + :param dst: Destination path + :param recursive: If True, copy directories recursively. + + kwargs: Additional keyword arguments to be passed to the underlying implementation. 
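+
+        A minimal usage sketch (bucket names are illustrative and assume
+        configured connections)::
+
+            src = ObjectStoragePath("s3://src_bucket/foo")
+            src.copy("gcs://dst_bucket/", recursive=True)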
+ """ + from airflow.lineage.hook import get_hook_lineage_collector + + if isinstance(dst, str): + dst = ObjectStoragePath(dst) + + if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file": + # only emit this in "optimized" variants - else lineage will be captured by file writes/reads + get_hook_lineage_collector().add_input_asset(context=self, uri=str(self)) + get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst)) + + # same -> same + if self.samestore(dst): + self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs) + return + + # use optimized path for local -> remote or remote -> local + if self.protocol == "file": + dst.fs.put(self.path, dst.path, recursive=recursive, **kwargs) + return + + if dst.protocol == "file": + self.fs.get(self.path, dst.path, recursive=recursive, **kwargs) + return + + if not self.exists(): + raise FileNotFoundError(f"{self} does not exist") + + # remote dir -> remote dir + if self.is_dir(): + if dst.is_file(): + raise ValueError("Cannot copy directory to a file.") + + dst.mkdir(exist_ok=True, parents=True) + + out = self.fs.expand_path(self.path, recursive=True, **kwargs) + + for path in out: + # this check prevents one extra call to is_dir() as + # glob returns self as well + if path == self.path: + continue + + src_obj = ObjectStoragePath( + path, + protocol=self.protocol, + conn_id=self.storage_options.get("conn_id"), + ) + + # skip directories, empty directories will not be created + if src_obj.is_dir(): + continue + + src_obj._cp_file(dst) + return + + # remote file -> remote dir + self._cp_file(dst, **kwargs) + + def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None: + """ + Move file(s) from this path to another location. + + :param path: Destination path + :param recursive: bool + If True, move directories recursively. + + kwargs: Additional keyword arguments to be passed to the underlying implementation. 
+ """ + from airflow.lineage.hook import get_hook_lineage_collector + + if isinstance(path, str): + path = ObjectStoragePath(path) + + if self.samestore(path): + get_hook_lineage_collector().add_input_asset(context=self, uri=str(self)) + get_hook_lineage_collector().add_output_asset(context=path, uri=str(path)) + return self.fs.move(self.path, path.path, recursive=recursive, **kwargs) + + # non-local copy + self.copy(path, recursive=recursive, **kwargs) + self.unlink() + + def serialize(self) -> dict[str, Any]: + _kwargs = {**self.storage_options} + conn_id = _kwargs.pop("conn_id", None) + + return { + "path": str(self), + "conn_id": conn_id, + "kwargs": _kwargs, + } + + @classmethod + def deserialize(cls, data: dict, version: int) -> ObjectStoragePath: + if version > cls.__version__: + raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.") + + _kwargs = data.pop("kwargs") + path = data.pop("path") + conn_id = data.pop("conn_id", None) + + return ObjectStoragePath(path, conn_id=conn_id, **_kwargs) + + def __str__(self): + conn_id = self.storage_options.get("conn_id") + if self._protocol and conn_id: + return f"{self._protocol}://{conn_id}@{self.path}" + return super().__str__() diff --git a/airflow-core/src/airflow/io/utils/stat.py b/task-sdk/src/airflow/sdk/io/stat.py similarity index 100% rename from airflow-core/src/airflow/io/utils/stat.py rename to task-sdk/src/airflow/sdk/io/stat.py diff --git a/airflow-core/src/airflow/io/store/__init__.py b/task-sdk/src/airflow/sdk/io/store.py similarity index 95% rename from airflow-core/src/airflow/io/store/__init__.py rename to task-sdk/src/airflow/sdk/io/store.py index d1e897c5de01c..2a3d9732bdb9c 100644 --- a/airflow-core/src/airflow/io/store/__init__.py +++ b/task-sdk/src/airflow/sdk/io/store.py @@ -20,9 +20,6 @@ from functools import cached_property from typing import TYPE_CHECKING, ClassVar -from airflow.io import get_fs, has_fs -from airflow.utils.module_loading import qualname - if TYPE_CHECKING: from fsspec import AbstractFileSystem @@ -30,7 +27,11 @@ class ObjectStore: - """Manages a filesystem or object storage.""" + """ + Manages a filesystem or object storage. + + To use this class, call :meth:`.attach` instead. 
+ """ __version__: ClassVar[int] = 1 @@ -57,6 +58,8 @@ def __str__(self): @cached_property def fs(self) -> AbstractFileSystem: + from airflow.io import get_fs + # if the fs is provided in init, the next statement will be ignored return get_fs(self.protocol, self.conn_id) @@ -76,6 +79,8 @@ def fsid(self) -> str: return f"{self.fs.protocol}-{self.conn_id or 'env'}" def serialize(self): + from airflow.utils.module_loading import qualname + return { "protocol": self.protocol, "conn_id": self.conn_id, @@ -85,6 +90,8 @@ def serialize(self): @classmethod def deserialize(cls, data: dict[str, str], version: int): + from airflow.io import has_fs + if version > cls.__version__: raise ValueError(f"Cannot deserialize version {version} for {cls.__name__}") diff --git a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py index 7e33c0b409291..b396b596df0ef 100644 --- a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py +++ b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py @@ -22,8 +22,8 @@ import jinja2 import pytest +from airflow.sdk import DAG, ObjectStoragePath from airflow.sdk.definitions._internal.templater import LiteralValue, SandboxedEnvironment, Templater -from airflow.sdk.definitions.dag import DAG class TestTemplater: @@ -55,9 +55,6 @@ def test_resolve_template_files_logs_exception(self, caplog): assert "Failed to resolve template field 'message'" in caplog.text def test_render_object_storage_path(self): - # TODO: Move this import to top-level after https://github.com/apache/airflow/issues/45425 - from airflow.io.path import ObjectStoragePath - templater = Templater() path = ObjectStoragePath("s3://bucket/key/{{ ds }}/part") context = {"ds": "2006-02-01"} diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index b084946fe7374..5b7c922cde1e2 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -37,6 +37,7 @@ _sanitize_uri, ) from airflow.sdk.definitions.dag import DAG +from airflow.sdk.io import ObjectStoragePath from airflow.serialization.serialized_objects import SerializedDAG ASSET_MODULE_PATH = "airflow.sdk.definitions.asset" @@ -131,6 +132,12 @@ def test_uri_without_scheme(): EmptyOperator(task_id="task1", outlets=[asset]) +def test_objectstoragepath(): + o = ObjectStoragePath("file:///123/456") + a = Asset(name="o", uri=o) + assert a.uri == "file:///123/456" + + def test_fspath(): uri = "s3://example/asset" asset = Asset(uri=uri) diff --git a/airflow-core/src/airflow/io/utils/__init__.py b/task-sdk/tests/task_sdk/io/__init__.py similarity index 100% rename from airflow-core/src/airflow/io/utils/__init__.py rename to task-sdk/tests/task_sdk/io/__init__.py diff --git a/task-sdk/tests/task_sdk/io/test_path.py b/task-sdk/tests/task_sdk/io/test_path.py new file mode 100644 index 0000000000000..18b805dc4e0a5 --- /dev/null +++ b/task-sdk/tests/task_sdk/io/test_path.py @@ -0,0 +1,343 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
diff --git a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
index 7e33c0b409291..b396b596df0ef 100644
--- a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
+++ b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py
@@ -22,8 +22,8 @@
 import jinja2
 import pytest
 
+from airflow.sdk import DAG, ObjectStoragePath
 from airflow.sdk.definitions._internal.templater import LiteralValue, SandboxedEnvironment, Templater
-from airflow.sdk.definitions.dag import DAG
 
 
 class TestTemplater:
@@ -55,9 +55,6 @@ def test_resolve_template_files_logs_exception(self, caplog):
         assert "Failed to resolve template field 'message'" in caplog.text
 
     def test_render_object_storage_path(self):
-        # TODO: Move this import to top-level after https://github.com/apache/airflow/issues/45425
-        from airflow.io.path import ObjectStoragePath
-
         templater = Templater()
         path = ObjectStoragePath("s3://bucket/key/{{ ds }}/part")
         context = {"ds": "2006-02-01"}
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py
index b084946fe7374..5b7c922cde1e2 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset.py
@@ -37,6 +37,7 @@
     _sanitize_uri,
 )
 from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.io import ObjectStoragePath
 from airflow.serialization.serialized_objects import SerializedDAG
 
 ASSET_MODULE_PATH = "airflow.sdk.definitions.asset"
@@ -131,6 +132,12 @@ def test_uri_without_scheme():
         EmptyOperator(task_id="task1", outlets=[asset])
 
 
+def test_objectstoragepath():
+    o = ObjectStoragePath("file:///123/456")
+    a = Asset(name="o", uri=o)
+    assert a.uri == "file:///123/456"
+
+
 def test_fspath():
     uri = "s3://example/asset"
     asset = Asset(uri=uri)
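The new ``test_objectstoragepath`` case pins down a convenient interop: an ``ObjectStoragePath`` can be passed directly as an asset URI, and (as ``test_asset`` in the new path tests below shows) an ``Asset`` can in turn seed a path. A short sketch of both directions; the bucket and asset names are illustrative:

```python
from airflow.sdk import Asset, ObjectStoragePath

# Path -> Asset: the path is coerced to its string URI.
path = ObjectStoragePath("s3://my-bucket/data/file.parquet")
asset = Asset(name="my-asset", uri=path)
assert asset.uri == "s3://my-bucket/data/file.parquet"

# Asset -> Path: ObjectStoragePath accepts an Asset and adopts its URI.
roundtrip = ObjectStoragePath(asset)
assert roundtrip.protocol == "s3"
```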
"fake", "protocol": "fake"} + + @pytest.mark.parametrize( + "fn, args, fn2, path, expected_args, expected_kwargs", + [ + ("checksum", {}, "checksum", FOO, _FakeRemoteFileSystem._strip_protocol(BAR), {}), + ("size", {}, "size", FOO, _FakeRemoteFileSystem._strip_protocol(BAR), {}), + ( + "sign", + {"expiration": 200, "extra": "xtra"}, + "sign", + FOO, + _FakeRemoteFileSystem._strip_protocol(BAR), + {"expiration": 200, "extra": "xtra"}, + ), + ("ukey", {}, "ukey", FOO, _FakeRemoteFileSystem._strip_protocol(BAR), {}), + ( + "read_block", + {"offset": 0, "length": 1}, + "read_block", + FOO, + _FakeRemoteFileSystem._strip_protocol(BAR), + {"delimiter": None, "length": 1, "offset": 0}, + ), + ], + ) + def test_standard_extended_api(self, fake_files, fn, args, fn2, path, expected_args, expected_kwargs): + fs = _FakeRemoteFileSystem() + attach(protocol="ffs", conn_id="fake", fs=fs) + with mock.patch.object(fs, fn2) as method: + o = ObjectStoragePath(path, conn_id="fake") + getattr(o, fn)(**args) + method.assert_called_once_with(expected_args, **expected_kwargs) + + +class TestRemotePath: + @pytest.fixture(autouse=True) + def fake_fs(self, monkeypatch): + monkeypatch.setattr(ObjectStoragePath, "_fs_factory", lambda *a, **k: _FakeRemoteFileSystem()) + + def test_bucket_key_protocol(self): + bucket = "bkt" + key = "yek" + protocol = "s3" + + o = ObjectStoragePath(f"{protocol}://{bucket}/{key}") + assert o.bucket == bucket + assert o.container == bucket + assert o.key == key + assert o.protocol == protocol + + +class TestLocalPath: + @pytest.fixture + def target(self, tmp_path): + tmp = tmp_path.joinpath(str(uuid.uuid4())) + tmp.touch() + return tmp.as_posix() + + @pytest.fixture + def another(self, tmp_path): + tmp = tmp_path.joinpath(str(uuid.uuid4())) + tmp.touch() + return tmp.as_posix() + + def test_ls(self, tmp_path, target): + d = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + o = ObjectStoragePath(f"file://{target}") + + data = list(d.iterdir()) + assert len(data) == 1 + assert data[0] == o + + d.rmdir(recursive=True) + assert not o.exists() + + def test_read_write(self, target): + o = ObjectStoragePath(f"file://{target}") + with o.open("wb") as f: + f.write(b"foo") + assert o.open("rb").read() == b"foo" + o.unlink() + + def test_stat(self, target): + o = ObjectStoragePath(f"file://{target}") + assert o.stat().st_size == 0 + assert S_ISREG(o.stat().st_mode) + assert S_ISDIR(o.parent.stat().st_mode) + + def test_replace(self, target, another): + o = ObjectStoragePath(f"file://{target}") + i = ObjectStoragePath(f"file://{another}") + assert i.size() == 0 + + txt = "foo" + o.write_text(txt) + e = o.replace(i) + assert o.exists() is False + assert i == e + assert e.size() == len(txt) + + def test_hash(self, target, another): + file_uri_1 = f"file://{target}" + file_uri_2 = f"file://{another}" + s = set() + for _ in range(10): + s.add(ObjectStoragePath(file_uri_1)) + s.add(ObjectStoragePath(file_uri_2)) + assert len(s) == 2 + + def test_is_relative_to(self, tmp_path, target): + o1 = ObjectStoragePath(f"file://{target}") + o2 = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + o3 = ObjectStoragePath(f"file:///{uuid.uuid4()}") + assert o1.is_relative_to(o2) + assert not o1.is_relative_to(o3) + + def test_relative_to(self, tmp_path, target): + o1 = ObjectStoragePath(f"file://{target}") + o2 = ObjectStoragePath(f"file://{tmp_path.as_posix()}") + o3 = ObjectStoragePath(f"file:///{uuid.uuid4()}") + assert o1.relative_to(o2) == o1 + with pytest.raises(ValueError): + o1.relative_to(o3) + + def 
+
+    def test_asset(self):
+        p = "s3"
+        f = "bucket/object"
+        i = Asset(uri=f"{p}://{f}", name="test-asset", extra={"foo": "bar"})
+        o = ObjectStoragePath(i)
+        assert o.protocol == p
+        assert o.path == f
+
+    def test_move_local(self, hook_lineage_collector, tmp_path, target):
+        o1 = ObjectStoragePath(f"file://{target}")
+        o2 = ObjectStoragePath(f"file://{tmp_path}/{uuid.uuid4()}")
+        assert o1.exists()
+        assert not o2.exists()
+
+        o1.move(o2)
+        assert o2.exists()
+        assert not o1.exists()
+
+        collected_assets = hook_lineage_collector.collected_assets
+        assert len(collected_assets.inputs) == 1
+        assert len(collected_assets.outputs) == 1
+        assert collected_assets.inputs[0].asset.uri == str(o1)
+        assert collected_assets.outputs[0].asset.uri == str(o2)
+
+    def test_serde_objectstoragepath(self):
+        path = "file:///bucket/key/part1/part2"
+        o = ObjectStoragePath(path)
+
+        s = o.serialize()
+        assert s["path"] == path
+        d = ObjectStoragePath.deserialize(s, 1)
+        assert o == d
+
+        o = ObjectStoragePath(path, my_setting="foo")
+        s = o.serialize()
+        assert "my_setting" in s["kwargs"]
+        d = ObjectStoragePath.deserialize(s, 1)
+        assert o == d
+
+        store = attach("filex", conn_id="mock")
+        o = ObjectStoragePath(path, store=store)
+        s = o.serialize()
+        assert s["kwargs"]["store"] == store
+
+        d = ObjectStoragePath.deserialize(s, 1)
+        assert o == d
+
+    def test_serde_store(self):
+        store = attach("file", conn_id="mock")
+        s = store.serialize()
+        d = ObjectStore.deserialize(s, 1)
+
+        assert s["protocol"] == "file"
+        assert s["conn_id"] == "mock"
+        assert s["filesystem"] == qualname(LocalFileSystem)
+        assert store == d
+
+        store = attach("localfs", fs=LocalFileSystem())
+        s = store.serialize()
+        d = ObjectStore.deserialize(s, 1)
+
+        assert s["protocol"] == "localfs"
+        assert s["conn_id"] is None
+        assert s["filesystem"] == qualname(LocalFileSystem)
+        assert store == d
+
+
+class TestBackwardsCompatibility:
+    @pytest.fixture(autouse=True)
+    def reset(self):
+        from airflow.io import _register_filesystems
+
+        _register_filesystems.cache_clear()
+        yield
+        _register_filesystems.cache_clear()
+
+    def test_backwards_compat(self):
+        from airflow.io import _BUILTIN_SCHEME_TO_FS, get_fs
+
+        def get_fs_no_storage_options(_: str):
+            return LocalFileSystem()
+
+        _BUILTIN_SCHEME_TO_FS["file"] = get_fs_no_storage_options  # type: ignore[call-arg]
+        assert get_fs("file")
+        with pytest.raises(AttributeError):
+            get_fs("file", storage_options={"foo": "bar"})
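``TestBackwardsCompatibility`` pins down the shim for filesystems registered with a legacy one-argument factory: plain ``get_fs("file")`` calls keep working, but passing ``storage_options`` to a factory that cannot accept them raises ``AttributeError`` rather than silently dropping the options. For contrast, a sketch of a factory shape that does accept ``storage_options``; the body is illustrative only:

```python
from __future__ import annotations

from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem


def get_fs(conn_id: str | None, storage_options: dict[str, str] | None = None) -> AbstractFileSystem:
    # A storage_options-aware factory can forward filesystem tuning
    # instead of tripping the compatibility check exercised above.
    return LocalFileSystem(**(storage_options or {}))
```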