diff --git a/services/worker/src/worker/dtos.py b/services/worker/src/worker/dtos.py
index 90433076c..da04930ec 100644
--- a/services/worker/src/worker/dtos.py
+++ b/services/worker/src/worker/dtos.py
@@ -312,7 +312,7 @@ class IsValidResponse(TypedDict):
     statistics: bool
 
 
-DatasetLibrary = Literal["mlcroissant", "webdataset", "datasets", "pandas", "dask"]
+DatasetLibrary = Literal["mlcroissant", "webdataset", "datasets", "pandas", "dask", "polars"]
 DatasetFormat = Literal["json", "csv", "parquet", "imagefolder", "audiofolder", "webdataset", "text", "arrow"]
 ProgrammingLanguage = Literal["python"]
 
diff --git a/services/worker/src/worker/job_runners/dataset/compatible_libraries.py b/services/worker/src/worker/job_runners/dataset/compatible_libraries.py
index 99a30965e..53ddc7250 100644
--- a/services/worker/src/worker/job_runners/dataset/compatible_libraries.py
+++ b/services/worker/src/worker/job_runners/dataset/compatible_libraries.py
@@ -619,6 +619,104 @@ def get_compatible_libraries_for_webdataset(
     return {"language": "python", "library": library, "function": function, "loading_codes": loading_codes}
 
 
+def get_polars_compatible_library(
+    builder_name: str, dataset: str, hf_token: Optional[str], login_required: bool
+) -> Optional[CompatibleLibrary]:
+    if builder_name in ["parquet", "csv", "json"]:
+        builder_configs = get_builder_configs_with_simplified_data_files(
+            dataset, module_name=builder_name, hf_token=hf_token
+        )
+
+        for config in builder_configs:
+            if any(len(data_files) != 1 for data_files in config.data_files.values()):
+                raise DatasetWithTooComplexDataFilesPatternsError(
+                    f"Failed to simplify {builder_name} data files pattern: {config.data_files}"
+                )
+
+        loading_codes: list[LoadingCode] = [
+            {
+                "config_name": config.name,
+                "arguments": {
+                    "splits": {str(split): data_files[0] for split, data_files in config.data_files.items()}
+                },
+                "code": "",
+            }
+            for config in builder_configs
+        ]
+
+        compatible_library: CompatibleLibrary = {
+            "language": "python",
+            "library": "polars",
+            "function": "",
+            "loading_codes": [],
+        }
+
+        def fmt_code(
+            *,
+            read_func: str,
+            splits: dict[str, str],
+            args: str,
+            dataset: str = dataset,
+            login_required: bool = login_required,
+        ) -> str:
+            if not (args.startswith(", ") or args == ""):
+                msg = f"incorrect args format: {args = !s}"
+                raise ValueError(msg)
+
+            login_comment = LOGIN_COMMENT if login_required else ""
+
+            if len(splits) == 1:
+                path = next(iter(splits.values()))
+                return f"""\
+import polars as pl
+{login_comment}
+df = pl.{read_func}('hf://datasets/{dataset}/{path}'{args})
+"""
+            else:
+                first_split = next(iter(splits))
+                return f"""\
+import polars as pl
+{login_comment}
+splits = {splits}
+df = pl.{read_func}('hf://datasets/{dataset}/' + splits['{first_split}']{args})
+"""
+
+        args = ""
+
+        if builder_name == "parquet":
+            read_func = "read_parquet"
+            compatible_library["function"] = f"pl.{read_func}"
+        elif builder_name == "csv":
+            read_func = "read_csv"
+            compatible_library["function"] = f"pl.{read_func}"
+
+            first_file = next(iter(loading_codes[0]["arguments"]["splits"].values()))
+
+            if ".tsv" in first_file:
+                args = f"{args}, separator='\\t'"
+
+        elif builder_name == "json":
+            first_file = next(iter(loading_codes[0]["arguments"]["splits"].values()))
+            is_json_lines = ".jsonl" in first_file or HfFileSystem(token=hf_token).open(first_file, "r").read(1) != "["
+
+            if is_json_lines:
+                read_func = "read_ndjson"
+            else:
+                read_func = "read_json"
+
+            compatible_library["function"] = f"pl.{read_func}"
+        else:
+            return None
+
+        for loading_code in loading_codes:
+            splits = loading_code["arguments"]["splits"]
+            loading_code["code"] = fmt_code(read_func=read_func, splits=splits, args=args)
+
+        compatible_library["loading_codes"] = loading_codes
+
+        return compatible_library
+
+
 get_compatible_library_for_builder: dict[str, Callable[[str, Optional[str], bool], CompatibleLibrary]] = {
     "webdataset": get_compatible_libraries_for_webdataset,
     "json": get_compatible_libraries_for_json,
@@ -701,6 +799,18 @@ def compute_compatible_libraries_response(
         libraries.append(
             get_mlcroissant_compatible_library(dataset, infos, login_required=login_required, partial=partial)
         )
+    # polars
+    if (
+        isinstance(builder_name, str)
+        and (
+            v := get_polars_compatible_library(
+                builder_name, dataset, hf_token=hf_token, login_required=login_required
+            )
+        )
+        is not None
+    ):
+        libraries.append(v)
+
     return DatasetCompatibleLibrariesResponse(libraries=libraries, formats=formats)
 
 
diff --git a/services/worker/tests/job_runners/dataset/test_compatible_libraries.py b/services/worker/tests/job_runners/dataset/test_compatible_libraries.py
index b91b43e31..91557fdf0 100644
--- a/services/worker/tests/job_runners/dataset/test_compatible_libraries.py
+++ b/services/worker/tests/job_runners/dataset/test_compatible_libraries.py
@@ -127,6 +127,24 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None:
                     }
                 ],
             },
+            {
+                "function": "pl.read_parquet",
+                "language": "python",
+                "library": "polars",
+                "loading_codes": [
+                    {
+                        "arguments": {"splits": {"test": "test.parquet", "train": "train.parquet"}},
+                        "code": "import polars as pl\n"
+                        "\n"
+                        "splits = {'train': 'train.parquet', 'test': "
+                        "'test.parquet'}\n"
+                        "df = "
+                        "pl.read_parquet('hf://datasets/parquet-dataset/' "
+                        "+ splits['train'])\n",
+                        "config_name": "default",
+                    }
+                ],
+            },
         ],
     },
     1.0,
@@ -193,6 +211,26 @@ def prepare_and_clean_mongo(app_config: AppConfig) -> None:
                     }
                 ],
             },
+            {
+                "function": "pl.read_parquet",
+                "language": "python",
+                "library": "polars",
+                "loading_codes": [
+                    {
+                        "arguments": {"splits": {"test": "test.parquet", "train": "train.parquet"}},
+                        "code": "import polars as pl\n"
+                        "\n"
+                        "# Login using e.g. `huggingface-cli login` to "
+                        "access this dataset\n"
+                        "splits = {'train': 'train.parquet', 'test': "
+                        "'test.parquet'}\n"
+                        "df = "
+                        "pl.read_parquet('hf://datasets/parquet-dataset-login_required/' "
+                        "+ splits['train'])\n",
+                        "config_name": "default",
+                    }
+                ],
+            },
         ],
     },
     1.0,
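For context, the single-split branch of `fmt_code` inlines the file path instead of emitting a `splits` dict, and the `csv` builder appends `separator='\t'` for TSV files. A sketch of the resulting snippet for a hypothetical single-split TSV dataset `user/tsv-dataset` with one data file `train.tsv` (the dataset id and file name are invented for illustration, not taken from the test fixtures):

```python
# Hypothetical output of fmt_code(read_func="read_csv",
# splits={"train": "train.tsv"}, args=", separator='\\t'") for the
# invented dataset "user/tsv-dataset", no login required.
import polars as pl

df = pl.read_csv('hf://datasets/user/tsv-dataset/train.tsv', separator='\t')
```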
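For the `json` builder, the choice between `pl.read_json` and `pl.read_ndjson` is made by sniffing: a `.jsonl` extension wins outright, otherwise the first character of the first data file is read (a regular JSON document starts with `[`). A minimal standalone restatement of that check, using a context manager so the file handle is closed (the inline expression in the diff leaves it open); the helper name `sniff_json_lines` is invented here:

```python
from typing import Optional

from huggingface_hub import HfFileSystem


def sniff_json_lines(first_file: str, hf_token: Optional[str] = None) -> bool:
    # ".jsonl" in the file name is taken at face value; otherwise a regular
    # JSON document starts with "[", while JSON Lines starts with a record.
    if ".jsonl" in first_file:
        return True
    with HfFileSystem(token=hf_token).open(first_file, "r") as f:
        return f.read(1) != "["
```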