Commit

init
allisonwang-db committed Nov 26, 2024
1 parent c240a22 commit a4262f2
Showing 7 changed files with 277 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
@@ -159,4 +159,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
167 changes: 167 additions & 0 deletions demo.ipynb
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "125a1871-6cab-4dc4-9fd5-4e5dbd63ada6",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "38dc7e9e-35fd-4604-9be3-1a1a8749fbcb",
"metadata": {},
"outputs": [],
"source": [
"from pyspark_huggingface import HuggingFaceDatasets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "620d3ecb-b9cb-480c-b300-69198cce7a9c",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9255ffcb-0b61-43dc-b57a-2b8af01a8432",
"metadata": {},
"outputs": [],
"source": [
"spark = SparkSession.builder.getOrCreate()"
]
},
{
"cell_type": "code",
"id": "7c4501a8-26f4-4f52-9dc8-a70393d567b4",
"metadata": {},
"source": [
"spark.dataSource.register(HuggingFaceDatasets)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
"metadata": {},
"outputs": [],
"source": [
"df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8866bdfb-0782-4430-8b1e-09c65e699f41",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 5:> (0 + 1) / 1]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----+\n",
"| text|label|\n",
"+--------------------+-----+\n",
"|the rock is desti...| 1|\n",
"|the gorgeously el...| 1|\n",
"|effective but too...| 1|\n",
"|if you sometimes ...| 1|\n",
"|emerges as someth...| 1|\n",
"|the film provides...| 1|\n",
"|offers that rare ...| 1|\n",
"|perhaps no pictur...| 1|\n",
"|steers turns in a...| 1|\n",
"|take care of my c...| 1|\n",
"|this is a film we...| 1|\n",
"|what really surpr...| 1|\n",
"|( wendigo is ) wh...| 1|\n",
"|one of the greate...| 1|\n",
"|ultimately , it p...| 1|\n",
"|an utterly compel...| 1|\n",
"|illuminating if o...| 1|\n",
"|a masterpiece fou...| 1|\n",
"|the movie's ripe ...| 1|\n",
"|offers a breath o...| 1|\n",
"+--------------------+-----+\n",
"only showing top 20 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" "
]
}
],
"source": [
"df.show()"
]
},
{
"cell_type": "code",
"id": "873bb4fc-1424-4816-b835-6c2b839d3de4",
"metadata": {},
"source": [
"df.count()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a1b895f-fe20-4520-a90d-b17df8e691e4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pyspark_huggingface",
"language": "python",
"name": "pyspark_huggingface"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
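For reference, here is the notebook's flow as a plain script. This is a minimal sketch, assuming the packages in requirements.txt are installed and a local Spark session can start:

from pyspark.sql import SparkSession

from pyspark_huggingface import HuggingFaceDatasets

spark = SparkSession.builder.getOrCreate()

# Register the data source class; it becomes available under its name, "huggingface".
spark.dataSource.register(HuggingFaceDatasets)

# Stream the default (train) split of rotten_tomatoes into a DataFrame.
df = spark.read.format("huggingface").load("rotten_tomatoes")
df.show()
print(df.count())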
1 change: 1 addition & 0 deletions pyspark_huggingface/__init__.py
@@ -0,0 +1 @@
from pyspark_huggingface.huggingface import HuggingFaceDatasets
90 changes: 90 additions & 0 deletions pyspark_huggingface/huggingface.py
@@ -0,0 +1,90 @@
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructField, StructType, StringType


# TODO: Use `DefaultSource`
class HuggingFaceDatasets(DataSource):
    """
    A DataSource for reading and writing HuggingFace Datasets in Spark.

    This data source allows reading public datasets from the HuggingFace Hub directly into Spark
    DataFrames. The schema is automatically inferred from the dataset features. The split can be
    specified using the `split` option. The default split is `train`.

    Name: `huggingface`

    Notes
    -----
    - The HuggingFace `datasets` library is required to use this data source. Make sure it is installed.
    - If the schema is automatically inferred, it will use string type for all fields.
    - Currently it can only be used with public datasets. Private or gated ones are not supported.

    Examples
    --------
    Load a public dataset from the HuggingFace Hub.

    >>> spark.read.format("huggingface").load("imdb").show()
    +--------------------+-----+
    |                text|label|
    +--------------------+-----+
    |I rented I AM CUR...|    0|
    |"I Am Curious: Ye...|    0|
    |                 ...|  ...|
    +--------------------+-----+

    Load a specific split from a public dataset from the HuggingFace Hub.

    >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
    +--------------------+-----+
    |                text|label|
    +--------------------+-----+
    |I love sci-fi and...|    0|
    |Worth the enterta...|    0|
    |                 ...|  ...|
    +--------------------+-----+
    """

    def __init__(self, options):
        super().__init__(options)
        if "path" not in options or not options["path"]:
            raise Exception("You must specify a dataset name in `.load()`.")

    @classmethod
    def name(cls):
        return "huggingface"

    def schema(self):
        from datasets import load_dataset_builder
        dataset_name = self.options["path"]
        ds_builder = load_dataset_builder(dataset_name)
        features = ds_builder.info.features
        if features is None:
            raise Exception(
                "Unable to automatically determine the schema using the dataset features. "
                "Please specify the schema manually using `.schema()`."
            )
        schema = StructType()
        for key, value in features.items():
            # For simplicity, use string for all values.
            schema.add(StructField(key, StringType(), True))
        return schema

    def reader(self, schema: StructType) -> "DataSourceReader":
        return HuggingFaceDatasetsReader(schema, self.options)


class HuggingFaceDatasetsReader(DataSourceReader):
    def __init__(self, schema: StructType, options: dict):
        self.schema = schema
        self.dataset_name = options["path"]
        # TODO: validate the split value.
        self.split = options.get("split", "train")  # Default to the train split.

    def read(self, partition):
        from datasets import load_dataset
        columns = [field.name for field in self.schema.fields]
        iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True)
        for example in iter_dataset:
            # TODO: the next Spark 4.0.0 dev release will include the feature to yield an iterator of pa.RecordBatch
            yield tuple(example.get(column) for column in columns)
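When schema inference fails (the raise in `schema()` above), the docstring points callers to `.schema()`. A minimal sketch of that path, with an illustrative two-column schema:

from pyspark.sql.types import StructField, StructType, StringType

# Illustrative schema; the column names must match the dataset's features.
schema = StructType([
    StructField("text", StringType(), True),
    StructField("label", StringType(), True),
])

df = (
    spark.read.format("huggingface")
    .schema(schema)
    .option("split", "test")
    .load("rotten_tomatoes")
)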
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
datasets==3.1.0
pyspark[connect]==4.0.0.dev2
pytest
Empty file added tests/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions tests/test_huggingface.py
@@ -0,0 +1,15 @@
import pytest
from pyspark.sql import SparkSession
from pyspark_huggingface import HuggingFaceDatasets


@pytest.fixture
def spark():
    spark = SparkSession.builder.getOrCreate()
    yield spark


def test_basic_load(spark):
    spark.dataSource.register(HuggingFaceDatasets)
    df = spark.read.format("huggingface").load("rotten_tomatoes")
    assert df.count() == 8530  # length of the training dataset

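A natural follow-on test, sketched here as an assumption rather than part of this commit, would exercise the `split` option; it asserts only non-emptiness since the exact test-split size is not established above:

def test_load_test_split(spark):
    spark.dataSource.register(HuggingFaceDatasets)
    df = spark.read.format("huggingface").option("split", "test").load("rotten_tomatoes")
    # The test split's exact size isn't stated in this commit, so only check non-emptiness.
    assert df.count() > 0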