Skip to content

Commit

Permalink
Catch bad arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Aug 9, 2024
1 parent 14db577 commit 0a9fd9b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
12 changes: 11 additions & 1 deletion daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ def __call__(self, *args, **kwargs) -> Expression:
# Validate that initialization arguments are provided if the __init__ signature indicates that there are
# parameters without defaults
init_sig = inspect.signature(self.cls.__init__) # type: ignore
if any(param.default is param.empty for param in init_sig.parameters.values()) and self.init_args is None:
if (
any(param.default is param.empty for param in init_sig.parameters.values() if param.name != "self")
and self.init_args is None
):
raise ValueError(
"Cannot call StatefulUDF without initialization arguments. Please either specify default arguments in your __init__ or provide "
"initialization arguments using `.with_init_args(...)`."
Expand All @@ -290,6 +293,13 @@ def with_init_args(self, *args, **kwargs) -> StatefulUDF:
"""Replace initialization arguments for the UDF when calling __init__ at runtime
on each instance of the UDF.
"""
init_sig = inspect.signature(self.cls.__init__) # type: ignore
init_sig.bind(
# Placeholder for `self`
None,
*args,
**kwargs,
)
return dataclasses.replace(self, init_args=(args, kwargs))

def bind_func(self, *args, **kwargs) -> inspect.BoundArguments:
Expand Down
13 changes: 13 additions & 0 deletions tests/expressions/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ def __call__(self, data):
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


def test_class_udf_init_args_bad_args():
    """Check that ``with_init_args`` rejects arguments that do not fit ``__init__``.

    ``RepeatN.__init__`` requires ``initial_n``. Supplying only an unrelated
    keyword (``wrong``) must raise ``TypeError`` from the ``with_init_args``
    call itself, with a message naming the missing parameter.
    """

    @udf(return_dtype=DataType.string())
    class RepeatN:
        def __init__(self, initial_n):
            self.n = initial_n

        def __call__(self, data):
            repeated = []
            for d in data.to_arrow():
                repeated.append(d.as_py() * self.n)
            return Series.from_pylist(repeated)

    # `initial_n` is never supplied, so the failure is reported as a missing
    # required argument.
    expected_message = "missing a required argument: 'initial_n'"
    with pytest.raises(TypeError, match=expected_message):
        RepeatN.with_init_args(wrong=5)


def test_udf_kwargs():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})

Expand Down

0 comments on commit 0a9fd9b

Please sign in to comment.