diff --git a/docs/source/use_with_spark.mdx b/docs/source/use_with_spark.mdx
index 071d4f4af85..07767ca447f 100644
--- a/docs/source/use_with_spark.mdx
+++ b/docs/source/use_with_spark.mdx
@@ -31,6 +31,8 @@ Alternatively, you can skip materialization by using [`IterableDataset.from_spar
 ...     columns=["id", "name"],
 ... )
 >>> ds = IterableDataset.from_spark(df)
+>>> print(next(iter(ds)))
+{"id": 1, "name": "Elia"}
 ```
 
 ### Caching
diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py
index bd69a330732..f3e5821b59c 100644
--- a/src/datasets/packaged_modules/spark/spark.py
+++ b/src/datasets/packaged_modules/spark/spark.py
@@ -33,13 +33,19 @@ class SparkConfig(datasets.BuilderConfig):
 
 def _generate_iterable_examples(
     df: "pyspark.sql.DataFrame",
-    partition_order: List[int] = None,
+    partition_order: List[int],
 ):
+    import pyspark
+
     def generate_fn():
-        row_id = 0
-        for row in df.rdd.toLocalIterator(True):
-            yield row_id, row.asDict()
-            row_id += 1
+        df_with_partition_id = df.select("*", pyspark.sql.functions.spark_partition_id().alias("part_id"))
+        for partition_id in partition_order:
+            partition_df = df_with_partition_id.select("*").where(f"part_id = {partition_id}").drop("part_id")
+            rows = partition_df.collect()
+            row_id = 0
+            for row in rows:
+                yield f"{partition_id}_{row_id}", row.asDict()
+                row_id += 1
 
     return generate_fn
 
@@ -51,25 +57,24 @@ def __init__(
         partition_order=None,
     ):
         self.df = df
-        self.generate_examples_fn = _generate_iterable_examples(
-            self.df,
-            partition_order or range(self.df.rdd.getNumPartitions()),
-        )
+        self.partition_order = partition_order or range(self.df.rdd.getNumPartitions())
+        self.generate_examples_fn = _generate_iterable_examples(self.df, self.partition_order)
 
     def __iter__(self):
         yield from self.generate_examples_fn()
 
     def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamplesIterable":
-        partition_ids = range(self.df.rdd.getNumPartitions())
-        generator.shuffle(partition_ids)
-        return SparkExamplesIterable(self.df, partition_order=partition_ids)
+        partition_order = list(range(self.df.rdd.getNumPartitions()))
+        generator.shuffle(partition_order)
+        return SparkExamplesIterable(self.df, partition_order=partition_order)
 
-    def shard_data_sources(self, shard_indices: List[int]) -> "SparkExamplesIterable":
-        return SparkExamplesIterable(self.df, partition_order=shard_indices)
+    def shard_data_sources(self, worker_id: int, num_workers: int) -> "SparkExamplesIterable":
+        partition_order = self.split_shard_indices_by_worker(worker_id, num_workers)
+        return SparkExamplesIterable(self.df, partition_order=partition_order)
 
     @property
     def n_shards(self) -> int:
-        return self.df.rdd.getNumPartitions()
+        return len(self.partition_order)
 
 
 class Spark(datasets.DatasetBuilder):
diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py
new file mode 100644
index 00000000000..1c073dbb662
--- /dev/null
+++ b/tests/packaged_modules/test_spark.py
@@ -0,0 +1,92 @@
+from unittest.mock import patch
+
+import pyspark
+
+from datasets.packaged_modules.spark.spark import (
+    SparkExamplesIterable,
+    _generate_iterable_examples,
+)
+
+from ..utils import (
+    require_dill_gt_0_3_2,
+    require_not_windows,
+)
+
+
+def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order):
+    expected_row_ids_and_row_dicts = []
+    for part_id in partition_order:
+        partition = df.where(f"SPARK_PARTITION_ID() = {part_id}").collect()
+        for row_idx, row in enumerate(partition):
+            expected_row_ids_and_row_dicts.append((f"{part_id}_{row_idx}", row.asDict()))
+    return expected_row_ids_and_row_dicts
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_generate_iterable_examples():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(10).repartition(2)
+    partition_order = [1, 0]
+    generate_fn = _generate_iterable_examples(df, partition_order)  # Reverse the partitions.
+    expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
+
+    for i, (row_id, row_dict) in enumerate(generate_fn()):
+        expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i]
+        assert row_id == expected_row_id
+        assert row_dict == expected_row_dict
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_spark_examples_iterable():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(10).repartition(1)
+    it = SparkExamplesIterable(df)
+    assert it.n_shards == 1
+    for i, (row_id, row_dict) in enumerate(it):
+        assert row_id == f"0_{i}"
+        assert row_dict == {"id": i}
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_spark_examples_iterable_shuffle():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(30).repartition(3)
+    # Mock the generator so that shuffle reverses the partition indices.
+    with patch("numpy.random.Generator") as generator_mock:
+        generator_mock.shuffle.side_effect = lambda x: x.reverse()
+        expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [2, 1, 0])
+
+        shuffled_it = SparkExamplesIterable(df).shuffle_data_sources(generator_mock)
+        assert shuffled_it.n_shards == 3
+        for i, (row_id, row_dict) in enumerate(shuffled_it):
+            expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i]
+            assert row_id == expected_row_id
+            assert row_dict == expected_row_dict
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+def test_spark_examples_iterable_shard():
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    df = spark.range(20).repartition(4)
+
+    # Partitions 0 and 2
+    shard_it_1 = SparkExamplesIterable(df).shard_data_sources(worker_id=0, num_workers=2)
+    assert shard_it_1.n_shards == 2
+    expected_row_ids_and_row_dicts_1 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [0, 2])
+    for i, (row_id, row_dict) in enumerate(shard_it_1):
+        expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_1[i]
+        assert row_id == expected_row_id
+        assert row_dict == expected_row_dict
+
+    # Partitions 1 and 3
+    shard_it_2 = SparkExamplesIterable(df).shard_data_sources(worker_id=1, num_workers=2)
+    assert shard_it_2.n_shards == 2
+    expected_row_ids_and_row_dicts_2 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [1, 3])
+    for i, (row_id, row_dict) in enumerate(shard_it_2):
+        expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_2[i]
+        assert row_id == expected_row_id
+        assert row_dict == expected_row_dict
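For reference, below is a minimal usage sketch (not part of the patch) of the reworked `SparkExamplesIterable`. It assumes a local Spark session and a `datasets` checkout that includes this change; example keys become `"{partition_id}_{row_id}"` strings, shuffling permutes whole partitions, and sharding hands whole partitions to each worker, as the new tests assert.

```python
# Sketch only: exercises the patched SparkExamplesIterable directly.
import numpy as np
import pyspark

from datasets.packaged_modules.spark.spark import SparkExamplesIterable

spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("sketch").getOrCreate()
df = spark.range(20).repartition(4)

examples = SparkExamplesIterable(df)
print(examples.n_shards)     # 4 -- one shard per Spark partition
print(next(iter(examples)))  # e.g. ("0_0", {"id": ...}) -- key is "<partition_id>_<row_id>"

# Sharding assigns whole partitions to a worker; with two workers the tests above
# expect worker 0 to read partitions [0, 2] and worker 1 to read partitions [1, 3].
worker_0 = SparkExamplesIterable(df).shard_data_sources(worker_id=0, num_workers=2)
print(worker_0.n_shards)     # 2

# Shuffling only permutes the partition order, so per-partition keys are unchanged.
shuffled = SparkExamplesIterable(df).shuffle_data_sources(np.random.default_rng(42))
print(next(iter(shuffled)))  # first row of whichever partition was shuffled to the front
```

Because `n_shards` is now derived from `partition_order` rather than queried from the RDD, a sharded or shuffled iterable reports exactly the partitions it will actually read.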