-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c240a22
commit a4262f2
Showing
7 changed files
with
277 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"id": "125a1871-6cab-4dc4-9fd5-4e5dbd63ada6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import warnings\n", | ||
"warnings.filterwarnings('ignore')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "38dc7e9e-35fd-4604-9be3-1a1a8749fbcb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pyspark_huggingface import HuggingFaceDatasets" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "620d3ecb-b9cb-480c-b300-69198cce7a9c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from pyspark.sql import SparkSession" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"id": "9255ffcb-0b61-43dc-b57a-2b8af01a8432", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"spark = SparkSession.builder.getOrCreate()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "7c4501a8-26f4-4f52-9dc8-a70393d567b4", | ||
"metadata": {}, | ||
"source": [ | ||
"spark.dataSource.register(HuggingFaceDatasets)" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "b8580bde-3f64-4c71-a087-8b3f71099aee", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"id": "8866bdfb-0782-4430-8b1e-09c65e699f41", | ||
"metadata": { | ||
"editable": true, | ||
"slideshow": { | ||
"slide_type": "" | ||
}, | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"[Stage 5:> (0 + 1) / 1]" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"+--------------------+-----+\n", | ||
"| text|label|\n", | ||
"+--------------------+-----+\n", | ||
"|the rock is desti...| 1|\n", | ||
"|the gorgeously el...| 1|\n", | ||
"|effective but too...| 1|\n", | ||
"|if you sometimes ...| 1|\n", | ||
"|emerges as someth...| 1|\n", | ||
"|the film provides...| 1|\n", | ||
"|offers that rare ...| 1|\n", | ||
"|perhaps no pictur...| 1|\n", | ||
"|steers turns in a...| 1|\n", | ||
"|take care of my c...| 1|\n", | ||
"|this is a film we...| 1|\n", | ||
"|what really surpr...| 1|\n", | ||
"|( wendigo is ) wh...| 1|\n", | ||
"|one of the greate...| 1|\n", | ||
"|ultimately , it p...| 1|\n", | ||
"|an utterly compel...| 1|\n", | ||
"|illuminating if o...| 1|\n", | ||
"|a masterpiece fou...| 1|\n", | ||
"|the movie's ripe ...| 1|\n", | ||
"|offers a breath o...| 1|\n", | ||
"+--------------------+-----+\n", | ||
"only showing top 20 rows\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
" " | ||
] | ||
} | ||
], | ||
"source": [ | ||
"df.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "873bb4fc-1424-4816-b835-6c2b839d3de4", | ||
"metadata": {}, | ||
"source": [ | ||
"df.count()" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4a1b895f-fe20-4520-a90d-b17df8e691e4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "pyspark_huggingface", | ||
"language": "python", | ||
"name": "pyspark_huggingface" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from pyspark_huggingface.huggingface import HuggingFaceDatasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from pyspark.sql.datasource import DataSource, DataSourceReader | ||
from pyspark.sql.types import StructField, StructType, StringType | ||
|
||
|
||
# TODO: Use `DefaultSource`
class HuggingFaceDatasets(DataSource):
    """
    A DataSource for reading and writing HuggingFace Datasets in Spark.

    This data source allows reading public datasets from the HuggingFace Hub directly into Spark
    DataFrames. The schema is automatically inferred from the dataset features. The split can be
    specified using the `split` option. The default split is `train`.

    Name: `huggingface`

    Notes:
    -----
    - The HuggingFace `datasets` library is required to use this data source. Make sure it is installed.
    - If the schema is automatically inferred, it will use string type for all fields.
    - Currently it can only be used with public datasets. Private or gated ones are not supported.

    Examples
    --------
    Load a public dataset from the HuggingFace Hub.

    >>> spark.read.format("huggingface").load("imdb").show()
    +--------------------+-----+
    |                text|label|
    +--------------------+-----+
    |I rented I AM CUR...|    0|
    |"I Am Curious: Ye...|    0|
    |...                 |  ...|
    +--------------------+-----+

    Load a specific split from a public dataset from the HuggingFace Hub.

    >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
    +--------------------+-----+
    |                text|label|
    +--------------------+-----+
    |I love sci-fi and...|    0|
    |Worth the enterta...|    0|
    |...                 |  ...|
    +--------------------+-----+
    """

    def __init__(self, options):
        super().__init__(options)
        # Spark passes the `.load(...)` argument through as the "path" option;
        # here it carries the HuggingFace dataset name, which is mandatory.
        if "path" not in options or not options["path"]:
            raise Exception("You must specify a dataset name in `.load()`.")

    @classmethod
    def name(cls):
        # Short name used with `spark.read.format("huggingface")`.
        return "huggingface"

    def schema(self):
        """Infer the Spark schema from the dataset's declared features.

        Raises:
            Exception: if the dataset builder exposes no features, in which
                case the caller must provide a schema via `.schema()`.
        """
        # Imported lazily so registering the data source does not require
        # `datasets` to be installed until a read is actually attempted.
        from datasets import load_dataset_builder

        dataset_name = self.options["path"]
        ds_builder = load_dataset_builder(dataset_name)
        features = ds_builder.info.features
        if features is None:
            raise Exception(
                "Unable to automatically determine the schema using the dataset features. "
                "Please specify the schema manually using `.schema()`."
            )
        schema = StructType()
        # For simplicity, use string for all values; feature types are not
        # mapped to Spark types yet. Only the feature names are needed here.
        for field_name in features:
            schema.add(StructField(field_name, StringType(), True))
        return schema

    def reader(self, schema: StructType) -> "DataSourceReader":
        return HuggingFaceDatasetsReader(schema, self.options)
|
||
|
||
class HuggingFaceDatasetsReader(DataSourceReader):
    """Streams examples of one HuggingFace dataset split as Spark rows."""

    def __init__(self, schema: StructType, options: dict):
        self.schema = schema
        self.dataset_name = options["path"]
        # TODO: validate the split value.
        self.split = options.get("split", "train")  # Default using train split.

    def read(self, partition):
        """Yield one tuple per example, column-ordered to match the schema.

        Uses the `datasets` streaming mode so the dataset is iterated lazily
        rather than downloaded and materialized up front.
        """
        # Imported lazily so constructing the reader does not require
        # `datasets` to be installed.
        from datasets import load_dataset

        columns = [field.name for field in self.schema.fields]
        iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True)
        for example in iter_dataset:
            # TODO: next spark 4.0.0 dev release will include the feature to yield as an iterator of pa.RecordBatch
            # `.get` yields None for missing keys, keeping row width consistent
            # with the schema. Generator form avoids building a throwaway list.
            yield tuple(example.get(column) for column in columns)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
datasets==3.1.0 | ||
pyspark[connect]==4.0.0.dev2 | ||
pytest |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
from pyspark.sql import SparkSession | ||
from pyspark_huggingface import HuggingFaceDatasets | ||
|
||
|
||
@pytest.fixture
def spark():
    """Provide a SparkSession, stopping it after the test.

    The original fixture yielded the session without teardown, leaking the
    underlying JVM/session between test runs; `stop()` releases it.
    """
    session = SparkSession.builder.getOrCreate()
    yield session
    session.stop()
|
||
|
||
def test_basic_load(spark):
    """Smoke test: register the data source and count the default (train) split."""
    spark.dataSource.register(HuggingFaceDatasets)
    rotten_tomatoes = spark.read.format("huggingface").load("rotten_tomatoes")
    expected_train_rows = 8530  # length of the training dataset
    assert rotten_tomatoes.count() == expected_train_rows