From 54f73a65bce99f6a0fa480f81f6eaace70165d54 Mon Sep 17 00:00:00 2001
From: Maddie Dawson
Date: Thu, 27 Apr 2023 10:09:38 -0700
Subject: [PATCH] Address comments

---
 src/datasets/arrow_dataset.py    | 18 ++++-----
 src/datasets/iterable_dataset.py | 43 +++++++++++++++++++++
 tests/test_arrow_dataset.py      | 58 ++--------------------------
 tests/test_iterable_dataset.py   | 65 +++++++++++++++++++++++++++++++-
 4 files changed, 116 insertions(+), 68 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 1ba71aa2aa4b..7fa5e62254db 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1231,13 +1231,12 @@ def from_spark(
         df: "pyspark.sql.DataFrame",
         split: Optional[NamedSplit] = None,
         features: Optional[Features] = None,
-        streaming: bool = True,
         keep_in_memory: bool = False,
         cache_dir: str = None,
         load_from_cache_file: bool = True,
         **kwargs,
     ):
-        """Create Dataset from Spark DataFrame. Dataset downloading is distributed over Spark workers.
+        """Create a Dataset from a Spark DataFrame. Dataset downloading is distributed over Spark workers.
 
         Args:
             df (`pyspark.sql.DataFrame`):
@@ -1246,16 +1245,13 @@ def from_spark(
                 Split name to be assigned to the dataset.
             features (`Features`, *optional*):
                 Dataset features.
-            streaming (`bool`):
-                Whether to stream the dataset from the dataframe by returning an IterableDataset. Otherwise, the
-                dataframe will be materialized to `cache_dir`, and a Dataset will be returned.
             cache_dir (`str`, *optional*, defaults to `"~/.cache/huggingface/datasets"`):
-                Directory to cache data (if not streaming). When using a multi-node Spark cluster, the cache_dir must be
-                accessible to both workers and the driver.
+                Directory to cache data. When using a multi-node Spark cluster, the cache_dir must be accessible to both
+                workers and the driver.
             keep_in_memory (`bool`):
-                When not streaming, whether to copy the data in-memory.
+                Whether to copy the data in memory.
             load_from_cache_file (`bool`):
-                When not streaming, whether to load the dataset from the cache if possible.
+                Whether to load the dataset from the cache if possible.
 
         Returns:
             [`Dataset`]
@@ -1274,13 +1270,13 @@ def from_spark(
         from .io.spark import SparkDatasetReader
 
         if sys.platform == "win32":
-            raise EnvironmentError("Datasets.from_spark is not currently supported on Windows")
+            raise EnvironmentError("Dataset.from_spark is not currently supported on Windows")
 
         return SparkDatasetReader(
             df,
             split=split,
             features=features,
-            streaming=streaming,
+            streaming=False,
             cache_dir=cache_dir,
             keep_in_memory=keep_in_memory,
             load_from_cache_file=load_from_cache_file,
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 83465dd85f50..da1ede5dcaea 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1027,6 +1027,49 @@ def from_generator(
             streaming=True,
         ).read()
 
+    @staticmethod
+    def from_spark(
+        df: "pyspark.sql.DataFrame",
+        split: Optional[NamedSplit] = None,
+        features: Optional[Features] = None,
+        **kwargs,
+    ) -> "IterableDataset":
+        """Create an IterableDataset from a Spark DataFrame. The dataset is streamed to the driver in batches.
+
+        Args:
+            df (`pyspark.sql.DataFrame`):
+                The DataFrame containing the desired data.
+            split (`NamedSplit`, *optional*):
+                Split name to be assigned to the dataset.
+            features (`Features`, *optional*):
+                Dataset features.
+
+        Returns:
+            [`IterableDataset`]
+
+        Example:
+
+        ```py
+        >>> df = spark.createDataFrame(
+        >>>     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
+        >>>     schema=["id", "name"],
+        >>> )
+        >>> ds = IterableDataset.from_spark(df)
+        ```
+        """
+        from .io.spark import SparkDatasetReader
+
+        if sys.platform == "win32":
+            raise EnvironmentError("IterableDataset.from_spark is not currently supported on Windows")
+
+        return SparkDatasetReader(
+            df,
+            split=split,
+            features=features,
+            streaming=True,
+            **kwargs,
+        ).read()
+
     def with_format(
         self,
         type: Optional[str] = None,
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index a18a2ab3cb05..06a75717b3dc 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3635,40 +3635,13 @@ def test_from_spark():
         ("3", 3, 3.0),
     ]
     df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
-    dataset = Dataset.from_spark(df, streaming=False)
+    dataset = Dataset.from_spark(df)
     assert isinstance(dataset, Dataset)
     assert dataset.num_rows == 4
     assert dataset.num_columns == 3
     assert dataset.column_names == ["col_1", "col_2", "col_3"]
 
 
-@require_not_windows
-@require_dill_gt_0_3_2
-@require_pyspark
-def test_from_spark_streaming():
-    import pyspark
-
-    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
-    data = [
-        ("0", 0, 0.0),
-        ("1", 1, 1.0),
-        ("2", 2, 2.0),
-        ("3", 3, 3.0),
-    ]
-    df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
-    dataset = Dataset.from_spark(df, streaming=True)
-    assert isinstance(dataset, IterableDataset)
-    results = []
-    for ex in dataset:
-        results.append(ex)
-    assert results == [
-        {"col_1": "0", "col_2": 0, "col_3": 0.0},
-        {"col_1": "1", "col_2": 1, "col_3": 1.0},
-        {"col_1": "2", "col_2": 2, "col_3": 2.0},
-        {"col_1": "3", "col_2": 3, "col_3": 3.0},
-    ]
-
-
 @require_not_windows
 @require_dill_gt_0_3_2
 @require_pyspark
@@ -3683,7 +3656,6 @@ def test_from_spark_features():
     dataset = Dataset.from_spark(
         df,
         features=features,
-        streaming=False,
     )
     assert isinstance(dataset, Dataset)
     assert dataset.num_rows == 1
@@ -3694,30 +3666,6 @@ def test_from_spark_features():
     assert_arrow_metadata_are_synced_with_dataset_features(dataset)
 
 
-@require_not_windows
-@require_dill_gt_0_3_2
-@require_pyspark
-def test_from_spark_streaming_features():
-    import PIL.Image
-    import pyspark
-
-    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
-    data = [(0, np.arange(4 * 4 * 3).reshape(4, 4, 3).tolist())]
-    df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
-    features = Features({"idx": Value("int64"), "image": Image()})
-    dataset = Dataset.from_spark(
-        df,
-        features=features,
-        streaming=True,
-    )
-    assert isinstance(dataset, IterableDataset)
-    results = []
-    for ex in dataset:
-        results.append(ex)
-    assert len(results) == 1
-    isinstance(results[0]["image"], PIL.Image.Image)
-
-
 @require_not_windows
 @require_dill_gt_0_3_2
 @require_pyspark
@@ -3726,10 +3674,10 @@ def test_from_spark_different_cache():
     import pyspark
     spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
     df = spark.createDataFrame([("0", 0)], "col_1: string, col_2: int")
-    dataset = Dataset.from_spark(df, streaming=False)
+    dataset = Dataset.from_spark(df)
     assert isinstance(dataset, Dataset)
     different_df = spark.createDataFrame([("1", 1)], "col_1: string, col_2: int")
-    different_dataset = Dataset.from_spark(different_df, streaming=False)
+    different_dataset = Dataset.from_spark(different_df)
     assert isinstance(different_dataset, Dataset)
     assert dataset[0]["col_1"] == "0"
     # Check to make sure that the second dataset wasn't read from the cache.
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 50d1e35de622..370b3d168ea9 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -7,7 +7,12 @@
 
 from datasets import load_dataset
 from datasets.combine import concatenate_datasets, interleave_datasets
-from datasets.features import ClassLabel, Features, Value
+from datasets.features import (
+    ClassLabel,
+    Features,
+    Image,
+    Value,
+)
 from datasets.formatting import get_format_type_from_alias
 from datasets.info import DatasetInfo
 from datasets.iterable_dataset import (
@@ -27,7 +32,13 @@
     _examples_to_batch,
 )
 
-from .utils import is_rng_equal, require_torch
+from .utils import (
+    is_rng_equal,
+    require_dill_gt_0_3_2,
+    require_not_windows,
+    require_pyspark,
+    require_torch,
+)
 
 
 DEFAULT_N_EXAMPLES = 20
@@ -645,6 +656,56 @@ def gen(shard_names):
     assert dataset.n_shards == len(shard_names)
 
 
+@require_not_windows
+@require_dill_gt_0_3_2
+@require_pyspark
+def test_from_spark_streaming():
+    import pyspark
+
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    data = [
+        ("0", 0, 0.0),
+        ("1", 1, 1.0),
+        ("2", 2, 2.0),
+        ("3", 3, 3.0),
+    ]
+    df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
+    dataset = IterableDataset.from_spark(df)
+    assert isinstance(dataset, IterableDataset)
+    results = []
+    for ex in dataset:
+        results.append(ex)
+    assert results == [
+        {"col_1": "0", "col_2": 0, "col_3": 0.0},
+        {"col_1": "1", "col_2": 1, "col_3": 1.0},
+        {"col_1": "2", "col_2": 2, "col_3": 2.0},
+        {"col_1": "3", "col_2": 3, "col_3": 3.0},
+    ]
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+@require_pyspark
+def test_from_spark_streaming_features():
+    import PIL.Image
+    import pyspark
+
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    data = [(0, np.arange(4 * 4 * 3).reshape(4, 4, 3).tolist())]
+    df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
+    features = Features({"idx": Value("int64"), "image": Image()})
+    dataset = IterableDataset.from_spark(
+        df,
+        features=features,
+    )
+    assert isinstance(dataset, IterableDataset)
+    results = []
+    for ex in dataset:
+        results.append(ex)
+    assert len(results) == 1
+    assert isinstance(results[0]["image"], PIL.Image.Image)
+
+
 @require_torch
 def test_iterable_dataset_torch_integration():
     ex_iterable = ExamplesIterable(generate_examples_fn, {})
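
Note: the snippet below is a minimal usage sketch of the two entry points this patch settles on, `Dataset.from_spark` (no `streaming` flag, materialized via the cache) and the new `IterableDataset.from_spark` (streamed in batches). The local SparkSession and the toy DataFrame are illustrative and not part of the patch; pyspark and dill > 0.3.2 are required, and neither path is supported on Windows.

```py
# Minimal sketch, not from the patch: assumes a local SparkSession and a small demo DataFrame.
from pyspark.sql import SparkSession

from datasets import Dataset, IterableDataset

spark = SparkSession.builder.master("local[*]").appName("demo").getOrCreate()
df = spark.createDataFrame([("0", 0), ("1", 1)], "col_1: string, col_2: int")

# Materialized path: downloads the data via Spark workers into the cache and returns a Dataset.
ds = Dataset.from_spark(df)
print(ds.column_names)  # ['col_1', 'col_2']

# Streaming path: returns an IterableDataset that streams batches to the driver.
ids = IterableDataset.from_spark(df)
for example in ids:
    print(example)  # e.g. {'col_1': '0', 'col_2': 0}
```

Both calls raise `EnvironmentError` on Windows, mirroring the `sys.platform == "win32"` guards in the patch.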