Skip to content

Commit

Permalink
Add dict[str,Any] as supported input and output (#547)
Browse files Browse the repository at this point in the history
* Add dict[str,Any] as supported input and output

* update

* fix
  • Loading branch information
goodwanghan committed Jun 28, 2024
1 parent 1adc576 commit c6a7f7c
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 23 deletions.
91 changes: 89 additions & 2 deletions fugue/dataframe/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def run( # noqa: C901
p.update(kwargs)
has_kw = False
rargs: Dict[str, Any] = {}
row_param_info: Any = None
for k, v in self._params.items():
if isinstance(v, (PositionalParam, KeywordParam)):
if isinstance(v, KeywordParam):
Expand All @@ -90,7 +91,14 @@ def run( # noqa: C901
isinstance(p[k], DataFrame),
lambda: TypeError(f"{p[k]} is not a DataFrame"),
)
rargs[k] = v.to_input_data(p[k], ctx=ctx)
if v.is_per_row:
assert_or_throw(
row_param_info is None,
lambda: ValueError("only one row parameter is allowed"),
)
row_param_info = (k, v, p[k])
else:
rargs[k] = v.to_input_data(p[k], ctx=ctx)
else:
rargs[k] = p[k] # TODO: should we do auto type conversion?
del p[k]
Expand All @@ -100,12 +108,38 @@ def run( # noqa: C901
rargs.update(p)
elif not ignore_unknown and len(p) > 0:
raise ValueError(f"{p} are not acceptable parameters")
if row_param_info is None:
return self._run_func(rargs, output, output_schema, ctx, raw=False)
else: # input contains row parameter

def _dfs() -> Iterable[Any]:
k, v, df = row_param_info
for row in v.to_input_rows(df, ctx):
rargs[k] = None
_rargs = rargs.copy()
_rargs[k] = row
yield self._run_func(_rargs, output, output_schema, ctx, raw=True)

if not output:
sum(1 for _ in _dfs())
return
else:
return self._rt.iterable_to_output_df(_dfs(), output_schema, ctx)

def _run_func(
self,
rargs: Dict[str, Any],
output: bool,
output_schema: Any,
ctx: Any,
raw: bool,
) -> Any:
rt = self._func(**rargs)
if not output:
if isinstance(self._rt, _DataFrameParamBase):
self._rt.count(rt)
return
if isinstance(self._rt, _DataFrameParamBase):
if not raw and isinstance(self._rt, _DataFrameParamBase):
return self._rt.to_output_df(rt, output_schema, ctx=ctx)
return rt

Expand Down Expand Up @@ -145,14 +179,30 @@ def __init__(self, param: Optional[inspect.Parameter]):
super().__init__(param)
assert_or_throw(self.required, lambda: TypeError(f"{self} must be required"))

@property
def is_per_row(self) -> bool:
return False

def to_input_data(self, df: DataFrame, ctx: Any) -> Any: # pragma: no cover
raise NotImplementedError

def to_input_rows(
self,
df: DataFrame,
ctx: Any,
) -> Iterable[Any]:
raise NotImplementedError # pragma: no cover

def to_output_df(
self, df: Any, schema: Any, ctx: Any
) -> DataFrame: # pragma: no cover
raise NotImplementedError

def iterable_to_output_df(
self, dfs: Iterable[Any], schema: Any, ctx: Any
) -> DataFrame: # pragma: no cover
raise NotImplementedError

def count(self, df: Any) -> int: # pragma: no cover
raise NotImplementedError

Expand Down Expand Up @@ -182,6 +232,34 @@ def count(self, df: Any) -> int:
return sum(1 for _ in df.as_array_iterable())


@fugue_annotated_param(DataFrame, "r", child_can_reuse_code=True)
class RowParam(_DataFrameParamBase):
@property
def is_per_row(self) -> bool:
return True

def count(self, df: Any) -> int:
return 1


@fugue_annotated_param(Dict[str, Any])
class DictParam(RowParam):
def to_input_rows(self, df: DataFrame, ctx: Any) -> Iterable[Any]:
yield from df.as_dict_iterable()

def to_output_df(self, output: Dict[str, Any], schema: Any, ctx: Any) -> DataFrame:
return ArrayDataFrame([list(output.values())], schema)

def iterable_to_output_df(
self, dfs: Iterable[Dict[str, Any]], schema: Any, ctx: Any
) -> DataFrame: # pragma: no cover
params: Dict[str, Any] = {}
if schema is not None:
params["schema"] = Schema(schema).pa_schema
adf = pa.Table.from_pylist(list(dfs), **params)
return ArrowDataFrame(adf)


@fugue_annotated_param(AnyDataFrame)
class _AnyDataFrameParam(DataFrameParam):
def to_output_df(self, output: AnyDataFrame, schema: Any, ctx: Any) -> DataFrame:
Expand All @@ -207,6 +285,15 @@ def to_output_df(self, output: LocalDataFrame, schema: Any, ctx: Any) -> DataFra
)
return output

def iterable_to_output_df(
self, dfs: Iterable[Any], schema: Any, ctx: Any
) -> DataFrame: # pragma: no cover
def _dfs() -> Iterable[DataFrame]:
for df in dfs:
yield self.to_output_df(df, schema, ctx)

return LocalDataFrameIterableDataFrame(_dfs(), schema=schema)

def count(self, df: LocalDataFrame) -> int:
if df.is_bounded:
return df.count()
Expand Down
8 changes: 4 additions & 4 deletions fugue/extensions/transformer/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def from_func(
assert_arg_not_none(schema, "schema")
tr = _FuncAsTransformer()
tr._wrapper = DataFrameFunctionWrapper( # type: ignore
func, "^[lspq][fF]?x*z?$", "^[lspq]$"
func, "^[lspqr][fF]?x*z?$", "^[lspqr]$"
)
tr._output_schema_arg = schema # type: ignore
tr._validation_rules = validation_rules # type: ignore
Expand Down Expand Up @@ -410,7 +410,7 @@ def from_func(
validation_rules.update(parse_validation_rules_from_comment(func))
tr = _FuncAsOutputTransformer()
tr._wrapper = DataFrameFunctionWrapper( # type: ignore
func, "^[lspq][fF]?x*z?$", "^[lspnq]$"
func, "^[lspqr][fF]?x*z?$", "^[lspnqr]$"
)
tr._output_schema_arg = None # type: ignore
tr._validation_rules = validation_rules # type: ignore
Expand Down Expand Up @@ -503,7 +503,7 @@ def from_func(
assert_arg_not_none(schema, "schema")
tr = _FuncAsCoTransformer()
tr._wrapper = DataFrameFunctionWrapper( # type: ignore
func, "^(c|[lspq]+)[fF]?x*z?$", "^[lspq]$"
func, "^(c|[lspq]+)[fF]?x*z?$", "^[lspqr]$"
)
tr._dfs_input = tr._wrapper.input_code[0] == "c" # type: ignore
tr._output_schema_arg = schema # type: ignore
Expand Down Expand Up @@ -562,7 +562,7 @@ def from_func(

tr = _FuncAsOutputCoTransformer()
tr._wrapper = DataFrameFunctionWrapper( # type: ignore
func, "^(c|[lspq]+)[fF]?x*z?$", "^[lspnq]$"
func, "^(c|[lspq]+)[fF]?x*z?$", "^[lspnqr]$"
)
tr._dfs_input = tr._wrapper.input_code[0] == "c" # type: ignore
tr._output_schema_arg = None # type: ignore
Expand Down
37 changes: 22 additions & 15 deletions fugue_ray/_utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from packaging import version
from pyarrow import csv as pacsv
from pyarrow import json as pajson
from ray.data.datasource import FileExtensionFilter

from triad.collections import Schema
from triad.collections.dict import ParamDict
from triad.utils.assertion import assert_or_throw
Expand All @@ -21,6 +21,27 @@

from .._constants import RAY_VERSION

try:
from ray.data.datasource import FileExtensionFilter

class _FileFiler(FileExtensionFilter): # pragma: no cover
def __init__(
self, file_extensions: Union[str, List[str]], exclude: Iterable[str]
):
super().__init__(file_extensions, allow_if_no_extension=True)
self._exclude = set(exclude)

def _is_valid(self, path: str) -> bool:
return pathlib.Path(
path
).name not in self._exclude and self._file_has_extension(path)

def __call__(self, paths: List[str]) -> List[str]:
return [path for path in paths if self._is_valid(path)]

except ImportError: # pragma: no cover
pass # ray >=2.10


class RayIO(object):
def __init__(self, engine: ExecutionEngine):
Expand Down Expand Up @@ -248,17 +269,3 @@ def _read_json() -> RayDataFrame: # pragma: no cover

def _remote_args(self) -> Dict[str, Any]:
return {"num_cpus": 1}


class _FileFiler(FileExtensionFilter): # pragma: no cover
def __init__(self, file_extensions: Union[str, List[str]], exclude: Iterable[str]):
super().__init__(file_extensions, allow_if_no_extension=True)
self._exclude = set(exclude)

def _is_valid(self, path: str) -> bool:
return pathlib.Path(
path
).name not in self._exclude and self._file_has_extension(path)

def __call__(self, paths: List[str]) -> List[str]:
return [path for path in paths if self._is_valid(path)]
37 changes: 36 additions & 1 deletion fugue_test/builtin_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,23 @@ def mt_arrow_2(dfs: Iterable[pa.Table]) -> Iterator[pa.Table]:
dag.df([], "a:int,b:int").assert_eq(b)
dag.run(self.engine)

def test_transform_row_wise(self):
def t1(row: Dict[str, Any]) -> Dict[str, Any]:
row["b"] = 1
return row

def t2(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
return rows[0]

with fa.engine_context(self.engine):
a = pd.DataFrame([[3, 4], [1, 2], [3, 5]], columns=["a", "b"])
b = fa.transform(a, t1, schema="*")
assert sorted(fa.as_array(b)) == [[1, 1], [3, 1], [3, 1]]
b = fa.transform(
a, t2, schema="*", partition={"by": "a", "presort": "b"}
)
assert sorted(fa.as_array(b)) == [[1, 2], [3, 4]]

def test_transform_binary(self):
with FugueWorkflow() as dag:
a = dag.df([[1, pickle.dumps([0, "a"])]], "a:int,b:bytes")
Expand Down Expand Up @@ -548,6 +565,8 @@ def test_cotransform(self):
e = dag.df([[1, 2, 1, 10]], "a:int,ct1:int,ct2:int,x:int")
e.assert_eq(c)

a.zip(b).transform(mock_co_tf1_d, params=dict(p=10)).assert_eq(e)

# interfaceless
c = dag.transform(
a.zip(b),
Expand Down Expand Up @@ -676,6 +695,13 @@ def t10(df: pd.DataFrame) -> Iterable[pa.Table]:
incr()
yield pa.Table.from_pandas(df)

def t11(row: Dict[str, Any]) -> Dict[str, Any]:
incr()
return row

def t12(row: Dict[str, Any]) -> None:
incr()

with FugueWorkflow() as dag:
a = dag.df([[1, 2], [3, 4]], "a:double,b:int")
a.out_transform(t1) # +2
Expand All @@ -688,14 +714,16 @@ def t10(df: pd.DataFrame) -> Iterable[pa.Table]:
a.out_transform(t8, ignore_errors=[NotImplementedError]) # +1
a.out_transform(t9) # +1
a.out_transform(t10) # +1
a.out_transform(t11) # +2
a.out_transform(t12) # +2
raises(FugueWorkflowCompileValidationError, lambda: a.out_transform(t2))
raises(FugueWorkflowCompileValidationError, lambda: a.out_transform(t3))
raises(FugueWorkflowCompileValidationError, lambda: a.out_transform(t4))
raises(FugueWorkflowCompileValidationError, lambda: a.out_transform(t5))
raises(FugueWorkflowCompileValidationError, lambda: a.out_transform(T7))
dag.run(self.engine)

assert 13 <= incr()
assert 17 <= incr()

def test_out_cotransform(self): # noqa: C901
tmpdir = str(self.tmpdir)
Expand Down Expand Up @@ -2001,6 +2029,13 @@ def mock_co_tf1(
return [[df1[0]["a"], len(df1), len(df2), p]]


@cotransformer(lambda dfs, **kwargs: "a:int,ct1:int,ct2:int,x:int")
def mock_co_tf1_d(
df1: List[Dict[str, Any]], df2: List[List[Any]], p=1
) -> Dict[str, Any]:
return dict(a=df1[0]["a"], ct1=len(df1), ct2=len(df2), x=p)


def mock_co_tf2(dfs: DataFrames, p=1) -> List[List[Any]]:
return [[dfs[0].peek_dict()["a"], dfs[0].count(), dfs[1].count(), p]]

Expand Down
2 changes: 1 addition & 1 deletion fugue_version/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.1"
__version__ = "0.9.2"
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def test__to_transformer():
assert isinstance(g, CoTransformer)
i = _to_transformer("t7", "a:int,b:int")
assert isinstance(i, CoTransformer)
j = _to_transformer("t8", "a:int,b:int")
assert isinstance(j, CoTransformer)


def test__register():
Expand Down Expand Up @@ -135,6 +137,12 @@ def t7(
yield df


def t8(
df1: pd.DataFrame, df2: pd.DataFrame, c: callable, **kwargs
) -> Dict[str, Any]:
return {}


class MockTransformer(CoTransformer):
def get_output_schema(self, dfs):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test__to_output_transformer():
assert isinstance(g, CoTransformer)
i = _to_output_transformer("t7")
assert isinstance(i, CoTransformer)
j = _to_output_transformer("t8")
assert isinstance(j, CoTransformer)


def test__register():
Expand Down Expand Up @@ -106,6 +108,12 @@ def t7(
pass


def t8(
df1: Iterable[List[Any]], df2: pd.DataFrame, c: Callable
) -> Dict[str, Any]:
pass


class MockTransformer(CoTransformer):
def get_output_schema(self, dfs):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def test__to_output_transformer():
assert isinstance(h, Transformer)
i = _to_output_transformer("t8")
assert isinstance(i, Transformer)
j = _to_output_transformer("t9")
assert isinstance(j, Transformer)


def test__register():
Expand Down Expand Up @@ -156,6 +158,10 @@ def t8(df: pd.DataFrame, c: Callable[[str], str]) -> Iterable[pd.DataFrame]:
pass


def t9(df: Dict[str, Any], c: Callable[[str], str]) -> Dict[str, Any]:
pass


class MockTransformer(Transformer):
def __init__(self, x=""):
self._x = x
Expand Down
Loading

0 comments on commit c6a7f7c

Please sign in to comment.