Address comments again

maddiedawson committed May 10, 2023
1 parent 006a3d1 commit a6e11ed
Showing 3 changed files with 114 additions and 15 deletions.
2 changes: 2 additions & 0 deletions docs/source/use_with_spark.mdx
@@ -31,6 +31,8 @@ Alternatively, you can skip materialization by using [`IterableDataset.from_spark`
 ... columns=["id", "name"],
 ... )
 >>> ds = IterableDataset.from_spark(df)
+>>> print(next(iter(ds)))
+{"id": 1, "name": "Elia"}
 ```

 ### Caching
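For context, the docs example above can be extended into a small end-to-end sketch of what the streamed dataset supports once the partition-order changes in `spark.py` below land. This is a hedged illustration only: it assumes a local SparkSession, the existing `IterableDataset.shuffle(seed=..., buffer_size=...)` API from `datasets`, and illustrative data values.

```python
from pyspark.sql import SparkSession
from datasets import IterableDataset

spark = SparkSession.builder.master("local[*]").getOrCreate()
# Illustrative data; only the first row matches the docs example above.
df = spark.createDataFrame([(1, "Elia"), (2, "Teo"), (3, "Fang")], ["id", "name"])

# Stream the DataFrame without materializing it to disk.
ds = IterableDataset.from_spark(df)

# Shuffling an IterableDataset reorders its shards (here: Spark partitions)
# and buffers rows; the buffer_size value is illustrative.
shuffled = ds.shuffle(seed=42, buffer_size=100)
print(next(iter(shuffled)))
```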
35 changes: 20 additions & 15 deletions src/datasets/packaged_modules/spark/spark.py
@@ -33,13 +33,19 @@ class SparkConfig(datasets.BuilderConfig):

 def _generate_iterable_examples(
     df: "pyspark.sql.DataFrame",
-    partition_order: List[int] = None,
+    partition_order: List[int],
 ):
+    import pyspark
+
     def generate_fn():
-        row_id = 0
-        for row in df.rdd.toLocalIterator(True):
-            yield row_id, row.asDict()
-            row_id += 1
+        df_with_partition_id = df.select("*", pyspark.sql.functions.spark_partition_id().alias("part_id"))
+        for partition_id in partition_order:
+            partition_df = df_with_partition_id.select("*").where(f"part_id = {partition_id}").drop("part_id")
+            rows = partition_df.collect()
+            row_id = 0
+            for row in rows:
+                yield f"{partition_id}_{row_id}", row.asDict()
+                row_id += 1

     return generate_fn

@@ -51,25 +57,24 @@ def __init__(
         partition_order=None,
     ):
         self.df = df
-        self.generate_examples_fn = _generate_iterable_examples(
-            self.df,
-            partition_order or range(self.df.rdd.getNumPartitions()),
-        )
+        self.partition_order = partition_order or range(self.df.rdd.getNumPartitions())
+        self.generate_examples_fn = _generate_iterable_examples(self.df, self.partition_order)

     def __iter__(self):
         yield from self.generate_examples_fn()

     def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamplesIterable":
-        partition_ids = range(self.df.rdd.getNumPartitions())
-        generator.shuffle(partition_ids)
-        return SparkExamplesIterable(self.df, partition_order=partition_ids)
+        partition_order = list(range(self.df.rdd.getNumPartitions()))
+        generator.shuffle(partition_order)
+        return SparkExamplesIterable(self.df, partition_order=partition_order)

-    def shard_data_sources(self, shard_indices: List[int]) -> "SparkExamplesIterable":
-        return SparkExamplesIterable(self.df, partition_order=shard_indices)
+    def shard_data_sources(self, worker_id: int, num_workers: int) -> "SparkExamplesIterable":
+        partition_order = self.split_shard_indices_by_worker(worker_id, num_workers)
+        return SparkExamplesIterable(self.df, partition_order=partition_order)

     @property
     def n_shards(self) -> int:
-        return self.df.rdd.getNumPartitions()
+        return len(self.partition_order)


 class Spark(datasets.DatasetBuilder):
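One piece not visible in this diff is `split_shard_indices_by_worker`, which the new `shard_data_sources(worker_id, num_workers)` delegates to. Judging from the tests below (with 4 partitions and 2 workers, worker 0 receives partitions [0, 2] and worker 1 receives [1, 3]), it appears to hand out shards round-robin by index. A standalone sketch of that assumed behaviour:

```python
from typing import List


def split_shard_indices_by_worker(worker_id: int, num_workers: int, n_shards: int) -> List[int]:
    """Assumed round-robin split: shard i goes to worker (i % num_workers).

    The real helper lives on the base examples-iterable class in `datasets`
    and is not shown in this diff; this only illustrates the behaviour the
    tests below expect.
    """
    return list(range(worker_id, n_shards, num_workers))


# With 4 Spark partitions and 2 dataloader workers:
assert split_shard_indices_by_worker(worker_id=0, num_workers=2, n_shards=4) == [0, 2]
assert split_shard_indices_by_worker(worker_id=1, num_workers=2, n_shards=4) == [1, 3]
```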
92 changes: 92 additions & 0 deletions tests/packaged_modules/test_spark.py
@@ -0,0 +1,92 @@
+from unittest.mock import patch
+
+import pyspark
+
+from datasets.packaged_modules.spark.spark import (
+    SparkExamplesIterable,
+    _generate_iterable_examples,
+)
+
+from ..utils import (
+    require_dill_gt_0_3_2,
+    require_not_windows,
+)
+
+
+def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order):
+    expected_row_ids_and_row_dicts = []
+    for part_id in partition_order:
+        partition = df.where(f"SPARK_PARTITION_ID() = {part_id}").collect()
+        for row_idx, row in enumerate(partition):
+            expected_row_ids_and_row_dicts.append((f"{part_id}_{row_idx}", row.asDict()))
+    return expected_row_ids_and_row_dicts
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_generate_iterable_examples():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(10).repartition(2)
+    partition_order = [1, 0]
+    generate_fn = _generate_iterable_examples(df, partition_order)  # Reverse the partitions.
+    expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
+
+    for i, (row_id, row_dict) in enumerate(generate_fn()):
+        expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i]
+        assert row_id == expected_row_id
+        assert row_dict == expected_row_dict
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_spark_examples_iterable():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(10).repartition(1)
+    it = SparkExamplesIterable(df)
+    assert it.n_shards == 1
+    for i, (row_id, row_dict) in enumerate(it):
+        assert row_id == f"0_{i}"
+        assert row_dict == {"id": i}
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_spark_examples_iterable_shuffle():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(30).repartition(3)
+    # Mock the generator so that shuffle reverses the partition indices.
+    with patch("numpy.random.Generator") as generator_mock:
+        generator_mock.shuffle.side_effect = lambda x: x.reverse()
+        expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [2, 1, 0])
+
+        shuffled_it = SparkExamplesIterable(df).shuffle_data_sources(generator_mock)
+        assert shuffled_it.n_shards == 3
+        for i, (row_id, row_dict) in enumerate(shuffled_it):
+            expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i]
+            assert row_id == expected_row_id
+            assert row_dict == expected_row_dict
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_spark_examples_iterable_shard():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(20).repartition(4)
+
+    # Partitions 0 and 2
+    shard_it_1 = SparkExamplesIterable(df).shard_data_sources(worker_id=0, num_workers=2)
+    assert shard_it_1.n_shards == 2
+    expected_row_ids_and_row_dicts_1 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [0, 2])
+    for i, (row_id, row_dict) in enumerate(shard_it_1):
+        expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_1[i]
+        assert row_id == expected_row_id
+        assert row_dict == expected_row_dict
+
+    # Partitions 1 and 3
+    shard_it_2 = SparkExamplesIterable(df).shard_data_sources(worker_id=1, num_workers=2)
+    assert shard_it_2.n_shards == 2
+    expected_row_ids_and_row_dicts_2 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [1, 3])
+    for i, (row_id, row_dict) in enumerate(shard_it_2):
+        expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_2[i]
+        assert row_id == expected_row_id
+        assert row_dict == expected_row_dict

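To run just this new test module locally, one option is to invoke pytest programmatically. This is a hedged example that assumes `pyspark`, `dill` > 0.3.2, and the repository's test dependencies are installed, and that it is run from the repository root:

```python
# Hypothetical local run from the repo root; the path matches the new test file above.
import pytest

raise SystemExit(pytest.main(["-v", "tests/packaged_modules/test_spark.py"]))
```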