diff --git a/daft/udf.py b/daft/udf.py index 846e49fe14..ab1410ec3a 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -269,7 +269,10 @@ def __call__(self, *args, **kwargs) -> Expression: # Validate that initialization arguments are provided if the __init__ signature indicates that there are # parameters without defaults init_sig = inspect.signature(self.cls.__init__) # type: ignore - if any(param.default is param.empty for param in init_sig.parameters.values()) and self.init_args is None: + if ( + any(param.default is param.empty for param in init_sig.parameters.values() if param.name != "self") + and self.init_args is None + ): raise ValueError( "Cannot call StatefulUDF without initialization arguments. Please either specify default arguments in your __init__ or provide " "initialization arguments using `.with_init_args(...)`." @@ -290,6 +293,13 @@ def with_init_args(self, *args, **kwargs) -> StatefulUDF: """Replace initialization arguments for the UDF when calling __init__ at runtime on each instance of the UDF. """ + init_sig = inspect.signature(self.cls.__init__) # type: ignore + init_sig.bind( + # Placeholder for `self` + None, + *args, + **kwargs, + ) return dataclasses.replace(self, init_args=(args, kwargs)) def bind_func(self, *args, **kwargs) -> inspect.BoundArguments: diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py index d14399f93b..2bd1f12d9b 100644 --- a/tests/expressions/test_udf.py +++ b/tests/expressions/test_udf.py @@ -97,6 +97,19 @@ def __call__(self, data): assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} +def test_class_udf_init_args_bad_args(): + @udf(return_dtype=DataType.string()) + class RepeatN: + def __init__(self, initial_n): + self.n = initial_n + + def __call__(self, data): + return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + + with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"): + RepeatN.with_init_args(wrong=5) + + def test_udf_kwargs(): table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})