
Commit

Address comments
maddiedawson committed Apr 27, 2023
1 parent e182947 commit 54f73a6
Showing 4 changed files with 116 additions and 68 deletions.
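After this commit, the Spark integration is split across two entry points: `Dataset.from_spark` always materializes the DataFrame to the cache and returns a map-style `Dataset` (its `streaming` argument is removed), while the new `IterableDataset.from_spark` streams the DataFrame to the driver in batches. The sketch below is a minimal usage example based on the docstrings in this diff; the SparkSession setup, app name, and sample schema are illustrative assumptions, not part of the commit.

```py
# Minimal usage sketch; the SparkSession setup and sample data are assumptions.
import pyspark
from datasets import Dataset, IterableDataset

spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("demo").getOrCreate()
df = spark.createDataFrame([(1, "Elia"), (2, "Teo"), (3, "Fang")], "id: int, name: string")

# Materializes the DataFrame to cache_dir and returns a map-style Dataset.
ds = Dataset.from_spark(df)

# Streams the DataFrame to the driver in batches and returns an IterableDataset.
ids = IterableDataset.from_spark(df)
for example in ids:
    print(example)  # e.g. {"id": 1, "name": "Elia"}
```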
src/datasets/arrow_dataset.py (7 additions, 11 deletions)
@@ -1231,13 +1231,12 @@ def from_spark(
         df: "pyspark.sql.DataFrame",
         split: Optional[NamedSplit] = None,
         features: Optional[Features] = None,
-        streaming: bool = True,
         keep_in_memory: bool = False,
         cache_dir: str = None,
         load_from_cache_file: bool = True,
         **kwargs,
     ):
-        """Create Dataset from Spark DataFrame. Dataset downloading is distributed over Spark workers.
+        """Create a Dataset from Spark DataFrame. Dataset downloading is distributed over Spark workers.

         Args:
             df (`pyspark.sql.DataFrame`):
@@ -1246,16 +1245,13 @@ def from_spark(
                 Split name to be assigned to the dataset.
             features (`Features`, *optional*):
                 Dataset features.
-            streaming (`bool`):
-                Whether to stream the dataset from the dataframe by returning an IterableDataset. Otherwise, the
-                dataframe will be materialized to `cache_dir`, and a Dataset will be returned.
             cache_dir (`str`, *optional*, defaults to `"~/.cache/huggingface/datasets"`):
-                Directory to cache data (if not streaming). When using a multi-node Spark cluster, the cache_dir must be
-                accessible to both workers and the driver.
+                Directory to cache data. When using a multi-node Spark cluster, the cache_dir must be accessible to both
+                workers and the driver.
             keep_in_memory (`bool`):
-                When not streaming, whether to copy the data in-memory.
+                Whether to copy the data in-memory.
             load_from_cache_file (`bool`):
-                When not streaming, whether to load the dataset from the cache if possible.
+                Whether to load the dataset from the cache if possible.

         Returns:
             [`Dataset`]
@@ -1274,13 +1270,13 @@ def from_spark(
         from .io.spark import SparkDatasetReader

         if sys.platform == "win32":
-            raise EnvironmentError("Datasets.from_spark is not currently supported on Windows")
+            raise EnvironmentError("Dataset.from_spark is not currently supported on Windows")

         return SparkDatasetReader(
             df,
             split=split,
             features=features,
-            streaming=streaming,
+            streaming=False,
             cache_dir=cache_dir,
             keep_in_memory=keep_in_memory,
             load_from_cache_file=load_from_cache_file,
src/datasets/iterable_dataset.py (43 additions, 0 deletions)
@@ -1027,6 +1027,49 @@ def from_generator(
             streaming=True,
         ).read()

+    @staticmethod
+    def from_spark(
+        df: "pyspark.sql.DataFrame",
+        split: Optional[NamedSplit] = None,
+        features: Optional[Features] = None,
+        **kwargs,
+    ) -> "IterableDataset":
+        """Create an IterableDataset from Spark DataFrame. The dataset is streamed to the driver in batches.
+
+        Args:
+            df (`pyspark.sql.DataFrame`):
+                The DataFrame containing the desired data.
+            split (`NamedSplit`, *optional*):
+                Split name to be assigned to the dataset.
+            features (`Features`, *optional*):
+                Dataset features.
+
+        Returns:
+            [`IterableDataset`]
+
+        Example:
+
+        ```py
+        >>> df = spark.createDataFrame(
+        >>>     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
+        >>>     columns=["id", "name"],
+        >>> )
+        >>> ds = IterableDataset.from_spark(df)
+        ```
+        """
+        from .io.spark import SparkDatasetReader
+
+        if sys.platform == "win32":
+            raise EnvironmentError("IterableDataset.from_spark is not currently supported on Windows")
+
+        return SparkDatasetReader(
+            df,
+            split=split,
+            features=features,
+            streaming=True,
+            **kwargs,
+        ).read()
+
     def with_format(
         self,
         type: Optional[str] = None,
tests/test_arrow_dataset.py (3 additions, 55 deletions)
@@ -3635,40 +3635,13 @@ def test_from_spark():
         ("3", 3, 3.0),
     ]
     df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
-    dataset = Dataset.from_spark(df, streaming=False)
+    dataset = Dataset.from_spark(df)
     assert isinstance(dataset, Dataset)
     assert dataset.num_rows == 4
     assert dataset.num_columns == 3
     assert dataset.column_names == ["col_1", "col_2", "col_3"]


-@require_not_windows
-@require_dill_gt_0_3_2
-@require_pyspark
-def test_from_spark_streaming():
-    import pyspark
-
-    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
-    data = [
-        ("0", 0, 0.0),
-        ("1", 1, 1.0),
-        ("2", 2, 2.0),
-        ("3", 3, 3.0),
-    ]
-    df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
-    dataset = Dataset.from_spark(df, streaming=True)
-    assert isinstance(dataset, IterableDataset)
-    results = []
-    for ex in dataset:
-        results.append(ex)
-    assert results == [
-        {"col_1": "0", "col_2": 0, "col_3": 0.0},
-        {"col_1": "1", "col_2": 1, "col_3": 1.0},
-        {"col_1": "2", "col_2": 2, "col_3": 2.0},
-        {"col_1": "3", "col_2": 3, "col_3": 3.0},
-    ]
-
-
 @require_not_windows
 @require_dill_gt_0_3_2
 @require_pyspark
@@ -3683,7 +3656,6 @@ def test_from_spark_features():
     dataset = Dataset.from_spark(
         df,
         features=features,
-        streaming=False,
     )
     assert isinstance(dataset, Dataset)
     assert dataset.num_rows == 1
@@ -3694,30 +3666,6 @@ def test_from_spark_features():
     assert_arrow_metadata_are_synced_with_dataset_features(dataset)


-@require_not_windows
-@require_dill_gt_0_3_2
-@require_pyspark
-def test_from_spark_streaming_features():
-    import PIL.Image
-    import pyspark
-
-    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
-    data = [(0, np.arange(4 * 4 * 3).reshape(4, 4, 3).tolist())]
-    df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
-    features = Features({"idx": Value("int64"), "image": Image()})
-    dataset = Dataset.from_spark(
-        df,
-        features=features,
-        streaming=True,
-    )
-    assert isinstance(dataset, IterableDataset)
-    results = []
-    for ex in dataset:
-        results.append(ex)
-    assert len(results) == 1
-    isinstance(results[0]["image"], PIL.Image.Image)
-
-
 @require_not_windows
 @require_dill_gt_0_3_2
 @require_pyspark
@@ -3726,10 +3674,10 @@ def test_from_spark_different_cache():

     spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
     df = spark.createDataFrame([("0", 0)], "col_1: string, col_2: int")
-    dataset = Dataset.from_spark(df, streaming=False)
+    dataset = Dataset.from_spark(df)
     assert isinstance(dataset, Dataset)
     different_df = spark.createDataFrame([("1", 1)], "col_1: string, col_2: int")
-    different_dataset = Dataset.from_spark(different_df, streaming=False)
+    different_dataset = Dataset.from_spark(different_df)
     assert isinstance(different_dataset, Dataset)
     assert dataset[0]["col_1"] == "0"
     # Check to make sure that the second dataset wasn't read from the cache.
tests/test_iterable_dataset.py (63 additions, 2 deletions)
@@ -7,7 +7,12 @@

 from datasets import load_dataset
 from datasets.combine import concatenate_datasets, interleave_datasets
-from datasets.features import ClassLabel, Features, Value
+from datasets.features import (
+    ClassLabel,
+    Features,
+    Image,
+    Value,
+)
 from datasets.formatting import get_format_type_from_alias
 from datasets.info import DatasetInfo
 from datasets.iterable_dataset import (
@@ -27,7 +32,13 @@
     _examples_to_batch,
 )

-from .utils import is_rng_equal, require_torch
+from .utils import (
+    is_rng_equal,
+    require_dill_gt_0_3_2,
+    require_not_windows,
+    require_pyspark,
+    require_torch,
+)


 DEFAULT_N_EXAMPLES = 20
@@ -645,6 +656,56 @@ def gen(shard_names):
     assert dataset.n_shards == len(shard_names)


+@require_not_windows
+@require_dill_gt_0_3_2
+@require_pyspark
+def test_from_spark_streaming():
+    import pyspark
+
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    data = [
+        ("0", 0, 0.0),
+        ("1", 1, 1.0),
+        ("2", 2, 2.0),
+        ("3", 3, 3.0),
+    ]
+    df = spark.createDataFrame(data, "col_1: string, col_2: int, col_3: float")
+    dataset = IterableDataset.from_spark(df)
+    assert isinstance(dataset, IterableDataset)
+    results = []
+    for ex in dataset:
+        results.append(ex)
+    assert results == [
+        {"col_1": "0", "col_2": 0, "col_3": 0.0},
+        {"col_1": "1", "col_2": 1, "col_3": 1.0},
+        {"col_1": "2", "col_2": 2, "col_3": 2.0},
+        {"col_1": "3", "col_2": 3, "col_3": 3.0},
+    ]
+
+
+@require_not_windows
+@require_dill_gt_0_3_2
+@require_pyspark
+def test_from_spark_streaming_features():
+    import PIL.Image
+    import pyspark
+
+    spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
+    data = [(0, np.arange(4 * 4 * 3).reshape(4, 4, 3).tolist())]
+    df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
+    features = Features({"idx": Value("int64"), "image": Image()})
+    dataset = IterableDataset.from_spark(
+        df,
+        features=features,
+    )
+    assert isinstance(dataset, IterableDataset)
+    results = []
+    for ex in dataset:
+        results.append(ex)
+    assert len(results) == 1
+    isinstance(results[0]["image"], PIL.Image.Image)
+
+
 @require_torch
 def test_iterable_dataset_torch_integration():
     ex_iterable = ExamplesIterable(generate_examples_fn, {})
