diff --git a/greenplumpython/func.py b/greenplumpython/func.py index c92082b7..dff81877 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -117,10 +117,18 @@ def apply( ): return_annotation = inspect.signature(self._function._wrapped_func).return_annotation # type: ignore reportUnknownArgumentType _serialize_to_type_name(return_annotation, db=db, for_return=True) + input_args = self._args + if len(input_args) == 0: + raise Exception("No input data specified, please specify a DataFrame or Columns") + input_clause = ( + "*" + if (len(input_args) == 1 and isinstance(input_args[0], DataFrame)) + else ",".join([arg._serialize(db=db) for arg in input_args]) + ) return DataFrame( f""" SELECT * FROM plcontainer_apply(TABLE( - SELECT * {from_clause}), '{self._function._qualified_name_str}', 4096) AS + SELECT {input_clause} {from_clause}), '{self._function._qualified_name_str}', 4096) AS {_defined_types[return_annotation.__args__[0]]._serialize(db=db)} """, db=db, @@ -370,6 +378,7 @@ def _serialize(self, db: Database) -> str: f" import sys as {sys_lib_name}\n" f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n" f" raise ModuleNotFoundError\n" + f" {sys_lib_name}.modules['plpy']=plpy\n" f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n" f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n" f" except ModuleNotFoundError:\n" diff --git a/tests/test_plcontainer.py b/tests/test_plcontainer.py index 6b16f051..1a1b39e4 100644 --- a/tests/test_plcontainer.py +++ b/tests/test_plcontainer.py @@ -1,25 +1,56 @@ from dataclasses import dataclass +import pytest + import greenplumpython as gp from tests import db -def test_simple_func(db: gp.Database): - @dataclass - class Int: - i: int +@dataclass +class Int: + i: int + + +@dataclass +class Pair: + i: int + j: int + + +@pytest.fixture +def t(db: gp.Database): + rows = [(i, i) for i in range(10)] + return db.create_dataframe(rows=rows, column_names=["a", "b"]) + + +@gp.create_function(language_handler="plcontainer", runtime="plc_python_example") +def add_one(x: list[Int]) -> list[Int]: + return [{"i": arg["i"] + 1} for arg in x] - @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") - def add_one(x: list[Int]) -> list[Int]: - return [{"i": arg["i"] + 1} for arg in x] +def test_simple_func(db: gp.Database): assert ( len( list( db.create_dataframe(columns={"i": range(10)}).apply( - lambda _: add_one(), expand=True + lambda t: add_one(t), expand=True ) ) ) == 10 ) + + +def test_func_no_input(db: gp.Database): + + with pytest.raises(Exception) as exc_info: # no input data for func raises Exception + db.create_dataframe(columns={"i": range(10)}).apply(lambda _: add_one(), expand=True) + assert "No input data specified, please specify a DataFrame or Columns" in str(exc_info.value) + + +def test_func_column(db: gp.Database, t: gp.DataFrame): + @gp.create_function(language_handler="plcontainer", runtime="plc_python_example") + def add(x: list[Pair]) -> list[Int]: + return [{"i": arg["i"] + arg["j"]} for arg in x] + + assert len(list(t.apply(lambda t: add(t["a"], t["b"]), expand=True))) == 10