Skip to content

Commit

Permalink
remove legacy udf decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed Sep 13, 2024
1 parent 944defc commit 1135c73
Show file tree
Hide file tree
Showing 13 changed files with 757 additions and 2,829 deletions.
3 changes: 0 additions & 3 deletions src/datachain/data_storage/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ def sys_columns():
),
]

def dir_expansion(self):
    """Return the result of ``self.dataset_dir_expansion(self)``.

    The object is looked up as the owner of ``dataset_dir_expansion`` and is
    also handed to it explicitly as the single positional argument.
    """
    expand = self.dataset_dir_expansion
    return expand(self)


PARTITION_COLUMN_ID = "partition_id"

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def add(self, file: tarfile.TarInfo):
f"file with extension '.{ext}' already exists in the archive",
)
else:
type_ = self._get_type(ext)
type_ = self._get_type(ext or "txt")
if type_ is None:
raise UnknownFileExtensionError(self._tar_stream, fstream.name, ext)

Expand Down
96 changes: 0 additions & 96 deletions src/datachain/query/builtins.py

This file was deleted.

40 changes: 1 addition & 39 deletions src/datachain/query/udf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import typing
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from functools import WRAPPER_ASSIGNMENTS
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -25,8 +24,6 @@
)
from .schema import (
UDFParameter,
UDFParamSpec,
normalize_param,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -66,41 +63,6 @@ def signal_names(self) -> Iterable[str]:
return self.output.keys()


def udf(
    params: Sequence[UDFParamSpec],
    output: UDFOutputSpec,
    *,
    method: Optional[str] = None,  # only used for class-based UDFs
    batch: int = 1,
):
    """
    Decorate a function or a class to be used as a UDF.

    The decorator expects both the outputs and inputs of the UDF to be specified.
    The outputs are defined as a mapping whose keys are the signal names
    (and whose values are the signal types).
    Parameters are defined as a list of column objects (e.g. C.name); a single
    string is accepted as shorthand for a one-element list.
    Optionally, UDFs can be run on batches of rows to improve performance, this
    is determined by the 'batch' parameter. When operating on batches of inputs,
    the UDF function will be called with a single argument - a list
    of tuples containing inputs (e.g. ((input1_a, input1_b), (input2_a, input2b))).

    Returns:
        A decorator that wraps a class in ``UDFClassWrapper`` and any other
        callable in ``UDFWrapper``.

    Raises:
        TypeError: if 'output' is not a mapping, or if the decorated object is
            neither a class nor a callable.
    """
    if isinstance(params, str):
        # A bare string is itself iterable; wrap it so it is normalized as one
        # parameter instead of being split into characters.
        params = (params,)
    if not isinstance(output, Mapping):
        raise TypeError(f"'output' must be a mapping, got {type(output).__name__}")

    properties = UDFProperties([normalize_param(p) for p in params], output, batch)

    def decorator(udf_base: Union[Callable, type]):
        # Classes are callable too, so the class check has to come first.
        if isclass(udf_base):
            return UDFClassWrapper(udf_base, properties, method=method)
        if callable(udf_base):
            return UDFWrapper(udf_base, properties)
        # Previously this fell through and returned None, silently rebinding
        # the decorated name to None. Fail loudly at decoration time instead.
        raise TypeError(
            "udf can only decorate a function or a class, "
            f"got {type(udf_base).__name__}"
        )

    return decorator


class UDFBase:
"""A base class for implementing stateful UDFs."""

Expand Down
116 changes: 0 additions & 116 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from tests.data import ENTRIES
from tests.utils import (
DEFAULT_TREE,
TARRED_TREE,
assert_row_names,
create_tar_dataset,
make_index,
skip_if_not_sqlite,
tree_from_path,
Expand Down Expand Up @@ -273,55 +271,6 @@ def test_cp_local_dataset(cloud_test_catalog, dogs_dataset):
}


@pytest.mark.parametrize("tree", [TARRED_TREE], indirect=True)
@pytest.mark.parametrize("suffix", ["/", "/*"])
@pytest.mark.parametrize("recursive", [False, True])
@pytest.mark.parametrize("dir_exists", [False, True])
@pytest.mark.xfail(reason="Missing support for v-objects in cp")
def test_cp_tar_root(cloud_test_catalog, suffix, recursive, dir_exists):
    """Copy the root of ``animals.tar`` from the ``tarred`` dataset.

    Covers every combination of source suffix, recursive flag and whether the
    destination directory exists beforehand. A ``/*`` source with a missing
    destination must raise ``FileNotFoundError``.
    """
    catalog = cloud_test_catalog.catalog
    create_tar_dataset(catalog, cloud_test_catalog.src_uri, "tarred")

    dest = cloud_test_catalog.working_dir / "data"
    if dir_exists:
        dest.mkdir()
    source = f"ds://tarred/animals.tar{suffix}"
    target = f"{dest}/"

    def copy():
        catalog.cp([source], target, recursive=recursive, no_edatachain_file=True)

    if suffix == "/*" and not dir_exists:
        with pytest.raises(FileNotFoundError):
            copy()
        return

    copy()

    if recursive:
        expected = DEFAULT_TREE.copy()
    elif suffix == "/":
        # Directories are not copied
        expected = {}
    else:
        # Directories are not copied: keep only the top-level files
        expected = {
            name: node
            for name, node in DEFAULT_TREE.items()
            if not isinstance(node, dict)
        }

    assert tree_from_path(dest) == expected


@pytest.mark.parametrize("tree", [TARRED_TREE], indirect=True)
@pytest.mark.xfail(reason="Missing support for v-objects in cp")
def test_cp_full_tar(cloud_test_catalog):
    """Recursively copy the whole ``tarred`` dataset into an existing directory.

    The resulting tree, read in binary mode, must equal ``TARRED_TREE``.
    """
    catalog = cloud_test_catalog.catalog
    create_tar_dataset(catalog, cloud_test_catalog.src_uri, "tarred")

    out_dir = cloud_test_catalog.working_dir / "data"
    out_dir.mkdir()
    catalog.cp(
        ["ds://tarred/"],
        str(out_dir),
        recursive=True,
        no_edatachain_file=True,
    )

    assert tree_from_path(out_dir, binary=True) == TARRED_TREE


@pytest.mark.parametrize(
"recursive,star,slash,dir_exists",
(
Expand Down Expand Up @@ -403,43 +352,6 @@ def test_cp_subdir(cloud_test_catalog, recursive, star, slash, dir_exists):
assert files_by_name["dogs/others/dog4"]["size"] == 4


@pytest.mark.parametrize("tree", [TARRED_TREE], indirect=True)
@pytest.mark.parametrize("path", ["*/dogs", "animals.tar/dogs"])
@pytest.mark.parametrize("suffix", ["", "/", "/*"])
@pytest.mark.parametrize("recursive", [False, True])
@pytest.mark.parametrize("dir_exists", [False, True])
@pytest.mark.xfail(reason="Missing support for v-objects in cp")
def test_cp_tar_subdir(cloud_test_catalog, path, suffix, recursive, dir_exists):
    """Copy the ``dogs`` subdirectory out of the ``tarred`` dataset.

    Covers both ways of addressing the subdirectory, every source suffix, the
    recursive flag and whether the destination exists beforehand. A ``/*``
    source with a missing destination must raise ``FileNotFoundError``.
    """
    catalog = cloud_test_catalog.catalog
    create_tar_dataset(catalog, cloud_test_catalog.src_uri, "tarred")

    dest = cloud_test_catalog.working_dir / "data"
    if dir_exists:
        dest.mkdir()
    source = f"ds://tarred/{path}{suffix}"

    def copy():
        catalog.cp([source], str(dest), recursive=recursive)

    if suffix == "/*" and not dir_exists:
        with pytest.raises(FileNotFoundError):
            copy()
        return

    copy()

    expected = DEFAULT_TREE["dogs"].copy()
    if dir_exists and suffix == "":
        # The directory itself lands inside the existing destination
        expected = {"dogs": expected}
    if not recursive:
        # Directories are not copied
        if suffix == "/" or not dir_exists:
            expected = {}
        else:
            expected = {
                name: node
                for name, node in expected.items()
                if not isinstance(node, dict)
            }

    assert tree_from_path(dest) == expected


@pytest.mark.parametrize(
"recursive,star,slash",
(
Expand Down Expand Up @@ -867,34 +779,6 @@ def clear_storages(catalog):
ds.db.execute(ds._storages.delete())


@pytest.mark.parametrize("tree", [TARRED_TREE], indirect=True)
@pytest.mark.xfail(reason="Missing support for datasets in ls")
def test_ls_subobjects(cloud_test_catalog):
    """List entries of the ``tarred`` dataset at several depths and globs.

    Each listing must contain exactly one result group, have no duplicate
    names, and match the expected set of names.
    """
    catalog = cloud_test_catalog.catalog
    create_tar_dataset(catalog, cloud_test_catalog.src_uri, "tarred")

    def list_names(target):
        # Exactly one (source, rows) pair is expected for a single target.
        listings = list(catalog.ls([target], fields=["name"]))
        ((_, rows),) = listings
        rows = list(rows)
        names = {row[0] for row in rows}
        # No name may appear twice in a listing.
        assert len(names) == len(rows)
        return names

    root = "ds://tarred"
    archive_members = {"description", "cats", "dogs"}
    cases = [
        (root, {"animals.tar"}),
        (f"{root}/animals.tar", {"animals.tar"}),
        (f"{root}/animals.tar/dogs", {"dog1", "dog2", "dog3", "others"}),
        (f"{root}/animals.tar/", archive_members),
        (f"{root}/*.tar/", archive_members),
        (f"{root}/*.tar/desc*", {"description"}),
    ]
    for target, expected in cases:
        assert list_names(target) == expected


def test_index_error(cloud_test_catalog):
protocol = cloud_test_catalog.src_uri.split("://", 1)[0]
# XXX: different clients raise inconsistent exceptions
Expand Down
Loading

0 comments on commit 1135c73

Please sign in to comment.