Skip to content

Commit

Permalink
Add Polars loading code (#2997)
Browse files Browse the repository at this point in the history
* c

* update test

* fix tests and lint

* show LOGIN_COMMENT when login required
  • Loading branch information
nameexhaustion authored Jul 26, 2024
1 parent 3493e05 commit 894e1d9
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 1 deletion.
2 changes: 1 addition & 1 deletion services/worker/src/worker/dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class IsValidResponse(TypedDict):
statistics: bool


DatasetLibrary = Literal["mlcroissant", "webdataset", "datasets", "pandas", "dask"]
DatasetLibrary = Literal["mlcroissant", "webdataset", "datasets", "pandas", "dask", "polars"]
DatasetFormat = Literal["json", "csv", "parquet", "imagefolder", "audiofolder", "webdataset", "text", "arrow"]
ProgrammingLanguage = Literal["python"]

Expand Down
110 changes: 110 additions & 0 deletions services/worker/src/worker/job_runners/dataset/compatible_libraries.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,104 @@ def get_compatible_libraries_for_webdataset(
return {"language": "python", "library": library, "function": function, "loading_codes": loading_codes}


def get_polars_compatible_library(
    builder_name: str, dataset: str, hf_token: Optional[str], login_required: bool
) -> Optional[CompatibleLibrary]:
    """Build Polars loading code for a dataset, if its builder is supported.

    Supported builders are "parquet", "csv" and "json"; any other builder
    returns ``None``. For each builder config, a Python snippet is generated
    that loads one split with the matching ``pl.read_*`` function via the
    ``hf://datasets/...`` filesystem protocol.

    Args:
        builder_name: name of the datasets builder module (e.g. "parquet").
        dataset: dataset repository id on the Hub.
        hf_token: optional token used to access gated/private files.
        login_required: when True, a login hint comment is prepended to the
            generated snippets.

    Returns:
        A ``CompatibleLibrary`` dict for Polars, or ``None`` if the builder
        is not supported.

    Raises:
        DatasetWithTooComplexDataFilesPatternsError: if any split's data
            files cannot be simplified down to a single pattern.
    """
    if builder_name in ["parquet", "csv", "json"]:
        builder_configs = get_builder_configs_with_simplified_data_files(
            dataset, module_name=builder_name, hf_token=hf_token
        )

        # Polars snippets reference a single path per split, so bail out when
        # a split's data files cannot be reduced to exactly one pattern.
        for config in builder_configs:
            if any(len(data_files) != 1 for data_files in config.data_files.values()):
                raise DatasetWithTooComplexDataFilesPatternsError(
                    # Report the actual builder (was hard-coded to "parquet").
                    f"Failed to simplify {builder_name} data files pattern: {config.data_files}"
                )

    loading_codes: list[LoadingCode] = [
        {
            "config_name": config.name,
            "arguments": {
                "splits": {str(split): data_files[0] for split, data_files in config.data_files.items()}
            },
            "code": "",  # filled in below once the read function is known
        }
        for config in builder_configs
    ]

    compatible_library: CompatibleLibrary = {
        "language": "python",
        "library": "polars",
        "function": "",  # set below, e.g. "pl.read_parquet"
        "loading_codes": [],
    }

    def fmt_code(
        *,
        read_func: str,
        splits: dict[str, str],
        args: str,
        dataset: str = dataset,
        login_required: bool = login_required,
    ) -> str:
        """Render a self-contained Polars snippet for the given splits."""
        # Extra arguments must be empty or start with ", " so they can be
        # spliced directly after the path argument.
        if not (args.startswith(", ") or args == ""):
            msg = f"incorrect args format: {args = !s}"
            raise ValueError(msg)

        login_comment = LOGIN_COMMENT if login_required else ""

        if len(splits) == 1:
            path = next(iter(splits.values()))
            return f"""\
import polars as pl
{login_comment}
df = pl.{read_func}('hf://datasets/{dataset}/{path}'{args})
"""
        else:
            # Several splits: expose them as a dict and load the first one.
            first_split = next(iter(splits))
            return f"""\
import polars as pl
{login_comment}
splits = {splits}
df = pl.{read_func}('hf://datasets/{dataset}/' + splits['{first_split}']{args})
"""

    args = ""

    if builder_name == "parquet":
        read_func = "read_parquet"
        compatible_library["function"] = f"pl.{read_func}"
    elif builder_name == "csv":
        read_func = "read_csv"
        compatible_library["function"] = f"pl.{read_func}"

        first_file = next(iter(loading_codes[0]["arguments"]["splits"].values()))

        # TSV files need an explicit tab separator.
        if ".tsv" in first_file:
            args = f"{args}, separator='\\t'"

    elif builder_name == "json":
        first_file = next(iter(loading_codes[0]["arguments"]["splits"].values()))
        # Decide between ndjson (JSON Lines) and regular JSON. The extension
        # is checked first so no remote file is opened for obvious cases;
        # otherwise peek at the first byte ("[" means a JSON array).
        if ".jsonl" in first_file:
            is_json_lines = True
        else:
            # Close the remote handle (was leaked by the previous version).
            with HfFileSystem(token=hf_token).open(first_file, "r") as f:
                is_json_lines = f.read(1) != "["

        if is_json_lines:
            read_func = "read_ndjson"
        else:
            read_func = "read_json"

        compatible_library["function"] = f"pl.{read_func}"
    else:
        # Unsupported builder: no Polars snippet.
        return None

    for loading_code in loading_codes:
        splits = loading_code["arguments"]["splits"]
        loading_code["code"] = fmt_code(read_func=read_func, splits=splits, args=args)

    compatible_library["loading_codes"] = loading_codes

    return compatible_library


get_compatible_library_for_builder: dict[str, Callable[[str, Optional[str], bool], CompatibleLibrary]] = {
"webdataset": get_compatible_libraries_for_webdataset,
"json": get_compatible_libraries_for_json,
Expand Down Expand Up @@ -701,6 +799,18 @@ def compute_compatible_libraries_response(
libraries.append(
get_mlcroissant_compatible_library(dataset, infos, login_required=login_required, partial=partial)
)
# polars
if (
isinstance(builder_name, str)
and (
v := get_polars_compatible_library(
builder_name, dataset, hf_token=hf_token, login_required=login_required
)
)
is not None
):
libraries.append(v)

return DatasetCompatibleLibrariesResponse(libraries=libraries, formats=formats)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,24 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None:
}
],
},
{
"function": "pl.read_parquet",
"language": "python",
"library": "polars",
"loading_codes": [
{
"arguments": {"splits": {"test": "test.parquet", "train": "train.parquet"}},
"code": "import polars as pl\n"
"\n"
"splits = {'train': 'train.parquet', 'test': "
"'test.parquet'}\n"
"df = "
"pl.read_parquet('hf://datasets/parquet-dataset/' "
"+ splits['train'])\n",
"config_name": "default",
}
],
},
],
},
1.0,
Expand Down Expand Up @@ -193,6 +211,26 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None:
}
],
},
{
"function": "pl.read_parquet",
"language": "python",
"library": "polars",
"loading_codes": [
{
"arguments": {"splits": {"test": "test.parquet", "train": "train.parquet"}},
"code": "import polars as pl\n"
"\n"
"# Login using e.g. `huggingface-cli login` to "
"access this dataset\n"
"splits = {'train': 'train.parquet', 'test': "
"'test.parquet'}\n"
"df = "
"pl.read_parquet('hf://datasets/parquet-dataset-login_required/' "
"+ splits['train'])\n",
"config_name": "default",
}
],
},
],
},
1.0,
Expand Down

0 comments on commit 894e1d9

Please sign in to comment.