Skip to content

Commit

Permalink
Optimize Spark Dask take function (#530)
Browse files: browse the repository at this point in the history.
* Optimize Spark Dask take function

* update

* update
  • Loading branch information
goodwanghan authored Jan 7, 2024
1 parent 29f105d commit 8008bfa
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
11 changes: 11 additions & 0 deletions fugue_dask/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,17 @@ def _partition_take(partition, n, presort):
).head(n)

else:
if len(_presort.keys()) == 0 and n == 1:
return DaskDataFrame(
d.drop_duplicates(
subset=partition_spec.partition_by,
ignore_index=True,
keep="first",
),
df.schema,
type_safe=False,
)

d = (
d.groupby(partition_spec.partition_by, dropna=False)
.apply(_partition_take, n=n, presort=_presort, meta=meta)
Expand Down
5 changes: 5 additions & 0 deletions fugue_spark/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ def _presort_to_col(_col: str, _asc: bool) -> Any:

# If partition exists
else:
if len(_presort.keys()) == 0 and n == 1:
return self._to_spark_df(
d.dropDuplicates(subset=partition_spec.partition_by), df.schema
)

w = Window.partitionBy([col(x) for x in partition_spec.partition_by])

if len(_presort.keys()) > 0:
Expand Down
40 changes: 40 additions & 0 deletions fugue_test/execution_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,46 @@ def test_take(self):
"a:str,b:int,c:long",
throw=True,
)
a = fa.as_fugue_engine_df(
e,
[
["a", 2, 3],
[None, 4, 2],
[None, 2, 1],
],
"a:str,b:int,c:long",
)
i = fa.take(a, n=1, partition="a", presort=None)
case1 = df_eq(
i,
[
["a", 2, 3],
[None, 4, 2],
],
"a:str,b:int,c:long",
throw=False,
)
case2 = df_eq(
i,
[
["a", 2, 3],
[None, 2, 1],
],
"a:str,b:int,c:long",
throw=False,
)
assert case1 or case2
j = fa.take(a, n=2, partition="a", presort=None)
df_eq(
j,
[
["a", 2, 3],
[None, 4, 2],
[None, 2, 1],
],
"a:str,b:int,c:long",
throw=True,
)
raises(ValueError, lambda: fa.take(a, n=0.5, presort=None))

def test_sample_n(self):
Expand Down

0 comments on commit 8008bfa

Please sign in to comment.