Skip to content

Commit

Permalink
fix and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
jaychia committed Oct 7, 2022
1 parent 4ef5fd1 commit 24124ed
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def explode(self, *columns: ColumnInputType) -> DataFrame:
exprs_to_explode = self.__column_input_to_expression(columns)
explode_op = logical_plan.Explode(
self._plan,
exprs_to_explode,
ExpressionList([e._explode() for e in exprs_to_explode]),
)
return DataFrame(explode_op)

Expand Down
11 changes: 7 additions & 4 deletions daft/logical/logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Any, Generic, List, Optional, Tuple, TypeVar, Union

from daft.datasources import SourceInfo, StorageType
from daft.expressions import ColumnExpression, Expression
from daft.execution.operators import OperatorEnum
from daft.expressions import CallExpression, ColumnExpression, Expression
from daft.internal.treenode import TreeNode
from daft.logical.map_partition_ops import ExplodeOp, MapPartitionOp
from daft.logical.schema import ExpressionList
Expand Down Expand Up @@ -360,9 +361,11 @@ def eval_partition(self, partition: vPartition) -> vPartition:

class Explode(MapPartition[ExplodeOp]):
def __init__(self, input: LogicalPlan, explode_expressions: ExpressionList):
explode_expressions_resolved = ExpressionList([e._explode() for e in explode_expressions]).resolve(
input.schema()
)
# Validate that every supplied expression is already an EXPLODE call.
# The checks must be wrapped in all(): asserting on the bare list would pass for any
# non-empty input, because a non-empty list is truthy even when every element is False.
assert all(
    isinstance(e, CallExpression) and e._operator == OperatorEnum.EXPLODE for e in explode_expressions
), "Expressions supplied to Explode LogicalPlan must be a CallExpression with OperatorEnum.EXPLODE"

explode_expressions_resolved = explode_expressions.resolve(input.schema())
map_partition_op = ExplodeOp(explode_columns=explode_expressions_resolved)
super().__init__(
input,
Expand Down
20 changes: 13 additions & 7 deletions tests/dataframe_cookbook/test_explodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,23 @@ def add_one(o: MyObj):
return None
return o._x + 1

data = {"explode": [[MyObj(1), MyObj(2), MyObj(3)], [MyObj(4), MyObj(5)], [], None], "repeat": ["a", "b", "c", "d"]}
data = {
"explode": [[MyObj(1), MyObj(2), MyObj(3)], [MyObj(4), MyObj(5)], [], None],
"repeat": ["a", "b", "c", "d"],
"repeat2": ["a", "b", "c", "d"],
}
df = DataFrame.from_pydict(data).repartition(nrepartitions)
df = df.explode(col("explode"))
df = df.with_column("explode_plus1", col("explode").apply(add_one))

assert df.schema()["explode"].daft_type == ExpressionType.python_object()

df = df.select(col("explode_plus1"))

df = df.with_column("explode_plus1", col("explode").apply(add_one))
df = df.select(col("explode_plus1"), col("repeat"))
assert df.schema()["explode_plus1"].daft_type == ExpressionType.python_object()

pd_df = pd.DataFrame(data)
pd_df = pd_df.explode("explode")
pd_df["explode_plus1"] = pd_df["explode"].apply(add_one)
pd_df = pd_df[["explode_plus1"]]
pd_df = pd_df[["explode_plus1", "repeat"]]

df.collect()
daft_pd_df = pd.DataFrame(df._result.to_pydict())
Expand All @@ -51,7 +53,11 @@ def add_one(o: MyObj):

@pytest.mark.parametrize("nrepartitions", [1, 5])
def test_explode_single_col_arrow(nrepartitions):
data = {"explode": pa.array([[1, 2, 3], [4, 5], [], None]), "repeat": ["a", "b", "c", "d"]}
data = {
"explode": pa.array([[1, 2, 3], [4, 5], [], None]),
"repeat": ["a", "b", "c", "d"],
"repeat2": ["a", "b", "c", "d"],
}
df = DataFrame.from_pydict(data).repartition(nrepartitions)
df = df.explode(col("explode"))

Expand Down

0 comments on commit 24124ed

Please sign in to comment.