
Commit 307cf7c

move the dataset loading from remote/disk to a shared function so we can re-use for RL (#2204)
1 parent 7054114 commit 307cf7c
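
The diff below swaps roughly 200 lines of inline loading logic in sft.py for a single call into axolotl.utils.data.shared.load_dataset_w_config, so an RL data pipeline can reuse the same helper. Here is a minimal caller sketch, assuming only the import path and the two-argument call shape visible in this diff; the DictDefault import path and the config fields are illustrative, not taken from this commit:

```python
# Hypothetical caller sketch: an RL data path reusing the shared loader.
# Only `load_dataset_w_config(config_dataset, use_auth_token)` and its import
# are taken from this diff; everything else below is illustrative.
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.dict import DictDefault  # assumed import path for axolotl's config mapping

config_dataset = DictDefault(
    {
        "path": "org/preference-pairs",  # HF Hub repo id, local path, or s3:// / gs:// URL
        "name": None,
        "split": "train",
        "type": "alpaca",  # prompt format handled later, not by the loader
    }
)
use_auth_token = True  # mirrors the second positional argument at the new call site

# Per the annotation at the new call site, this returns a datasets.Dataset or DatasetDict.
ds = load_dataset_w_config(config_dataset, use_auth_token)
```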

File tree

2 files changed: +227 -210 lines changed


src/axolotl/utils/data/sft.py

+5 -210
@@ -3,7 +3,7 @@
 import functools
 import logging
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
 
 from datasets import (
     Dataset,
@@ -12,8 +12,6 @@
     load_dataset,
     load_from_disk,
 )
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import HFValidationError
 from transformers import PreTrainedTokenizerBase
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -42,6 +40,7 @@
     UnsupportedPrompter,
 )
 from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+from axolotl.utils.data.shared import load_dataset_w_config
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
     md5,
@@ -255,195 +254,9 @@ def for_d_in_datasets(dataset_configs):
 
         # pylint: disable=invalid-name
        for config_dataset in for_d_in_datasets(cfg_datasets):
-            ds: Optional[Union[Dataset, DatasetDict]] = None
-            ds_from_hub = False
-            ds_trust_remote_code = config_dataset.trust_remote_code
-            try:
-                # this is just a basic check to see if the path is a
-                # valid HF dataset that's loadable
-                load_dataset(
-                    config_dataset.path,
-                    name=config_dataset.name,
-                    streaming=True,
-                    token=use_auth_token,
-                    revision=config_dataset.revision,
-                    trust_remote_code=ds_trust_remote_code,
-                )
-                ds_from_hub = True
-            except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
-                pass
-
-            ds_from_cloud = False
-            storage_options = {}
-            remote_file_system = None
-            if config_dataset.path.startswith("s3://"):
-                try:
-                    import aiobotocore.session  # type: ignore
-                    import s3fs  # type: ignore
-                except ImportError as exc:
-                    raise ImportError(
-                        "s3:// paths require aiobotocore and s3fs to be installed"
-                    ) from exc
-
-                # Takes credentials from ~/.aws/credentials for default profile
-                s3_session = aiobotocore.session.AioSession(profile="default")
-                storage_options = {"session": s3_session}
-                remote_file_system = s3fs.S3FileSystem(**storage_options)
-            elif config_dataset.path.startswith(
-                "gs://"
-            ) or config_dataset.path.startswith("gcs://"):
-                try:
-                    import gcsfs  # type: ignore
-                except ImportError as exc:
-                    raise ImportError(
-                        "gs:// or gcs:// paths require gcsfs to be installed"
-                    ) from exc
-
-                # gcsfs will use default credentials from the environment else anon
-                # https://gcsfs.readthedocs.io/en/latest/#credentials
-                storage_options = {"token": None}
-                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
-            # TODO: Figure out how to get auth creds passed
-            # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
-            #     try:
-            #         import adlfs
-            #     except ImportError as exc:
-            #         raise ImportError(
-            #             "adl:// or abfs:// paths require adlfs to be installed"
-            #         ) from exc
-
-            #     # Gen 1
-            #     storage_options = {
-            #         "tenant_id": TENANT_ID,
-            #         "client_id": CLIENT_ID,
-            #         "client_secret": CLIENT_SECRET,
-            #     }
-            #     # Gen 2
-            #     storage_options = {
-            #         "account_name": ACCOUNT_NAME,
-            #         "account_key": ACCOUNT_KEY,
-            #     }
-
-            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
-            try:
-                if remote_file_system and remote_file_system.exists(
-                    config_dataset.path
-                ):
-                    ds_from_cloud = True
-            except (FileNotFoundError, ConnectionError):
-                pass
-
-            # prefer local dataset, even if hub exists
-            local_path = Path(config_dataset.path)
-            if local_path.exists():
-                if local_path.is_dir():
-                    if config_dataset.data_files:
-                        ds_type = get_ds_type(config_dataset)
-                        ds = load_dataset(
-                            ds_type,
-                            name=config_dataset.name,
-                            data_files=config_dataset.data_files,
-                            streaming=False,
-                            split=None,
-                        )
-                    else:
-                        try:
-                            ds = load_from_disk(config_dataset.path)
-                        except FileNotFoundError:
-                            ds = load_dataset(
-                                config_dataset.path,
-                                name=config_dataset.name,
-                                streaming=False,
-                                split=None,
-                            )
-                elif local_path.is_file():
-                    ds_type = get_ds_type(config_dataset)
-
-                    ds = load_dataset(
-                        ds_type,
-                        name=config_dataset.name,
-                        data_files=config_dataset.path,
-                        streaming=False,
-                        split=None,
-                    )
-                else:
-                    raise ValueError(
-                        "unhandled dataset load: local path exists, but is neither a directory or a file"
-                    )
-            elif ds_from_hub:
-                load_ds_kwargs = {}
-                if config_dataset.split:
-                    load_ds_kwargs["split"] = config_dataset.split
-                ds = load_dataset(
-                    config_dataset.path,
-                    name=config_dataset.name,
-                    streaming=False,
-                    data_files=config_dataset.data_files,
-                    token=use_auth_token,
-                    revision=config_dataset.revision,
-                    trust_remote_code=config_dataset.trust_remote_code,
-                    **load_ds_kwargs,
-                )
-            elif ds_from_cloud and remote_file_system:
-                if remote_file_system.isdir(config_dataset.path):
-                    ds = load_from_disk(
-                        config_dataset.path,
-                        storage_options=storage_options,
-                    )
-                elif remote_file_system.isfile(config_dataset.path):
-                    ds_type = get_ds_type(config_dataset)
-                    ds = load_dataset(
-                        ds_type,
-                        name=config_dataset.name,
-                        data_files=config_dataset.path,
-                        streaming=False,
-                        split=None,
-                        storage_options=storage_options,
-                        trust_remote_code=config_dataset.trust_remote_code,
-                    )
-            elif config_dataset.path.startswith("https://"):
-                ds_type = get_ds_type(config_dataset)
-                ds = load_dataset(
-                    ds_type,
-                    name=config_dataset.name,
-                    data_files=config_dataset.path,
-                    streaming=False,
-                    split=None,
-                    storage_options=storage_options,
-                    trust_remote_code=config_dataset.trust_remote_code,
-                )
-            else:
-                if isinstance(config_dataset.data_files, str):
-                    fp = hf_hub_download(
-                        repo_id=config_dataset.path,
-                        repo_type="dataset",
-                        filename=config_dataset.data_files,
-                        revision=config_dataset.revision,
-                    )
-                elif isinstance(config_dataset.data_files, list):
-                    fp = []
-                    for file in config_dataset.data_files:
-                        fp.append(
-                            hf_hub_download(
-                                repo_id=config_dataset.path,
-                                repo_type="dataset",
-                                filename=file,
-                                revision=config_dataset.revision,
-                            )
-                        )
-                else:
-                    raise ValueError(
-                        "data_files must be either a string or list of strings"
-                    )
-                ds = load_dataset(
-                    "json",
-                    name=config_dataset.name,
-                    data_files=fp,
-                    streaming=False,
-                    split=None,
-                )
-            if not ds:
-                raise ValueError("unhandled dataset load")
+            ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
+                config_dataset, use_auth_token
+            )
 
             d_base_type = d_prompt_style = None
             d_type = config_dataset.type
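
For orientation, the deleted block resolved each dataset source in a fixed precedence order: an existing local path wins, then a repo the Hub streaming probe could load, then an s3://, gs://, or gcs:// filesystem path, then a bare https:// URL, and finally a fallback that downloads data_files via hf_hub_download. The sketch below only restates that dispatch order; it is not the implementation of load_dataset_w_config (which lives in src/axolotl/utils/data/shared.py and is not shown in this hunk), and dataset_exists_on_hub is a hypothetical name for the streaming probe.

```python
# Condensed restatement of the removed source-resolution order, for orientation only.
from pathlib import Path

from datasets import load_dataset


def dataset_exists_on_hub(path: str, **probe_kwargs) -> bool:
    """Hypothetical wrapper around the removed probe: a streaming load_dataset()
    call succeeding meant the path is loadable from the HF Hub."""
    try:
        load_dataset(path, streaming=True, **probe_kwargs)
        return True
    except Exception:  # the removed code caught FileNotFoundError, ConnectionError, HFValidationError, ValueError
        return False


def resolve_dataset_source(path: str) -> str:
    """Which branch of the deleted if/elif chain a given path would have taken."""
    if Path(path).exists():
        return "local"  # "prefer local dataset, even if hub exists"
    if dataset_exists_on_hub(path):
        return "hub"
    if path.startswith(("s3://", "gs://", "gcs://")):
        return "cloud-fs"  # s3fs / gcsfs with per-provider storage_options
    if path.startswith("https://"):
        return "https"
    return "hf_hub_download"  # last resort: download data_files from the Hub repo
```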
@@ -513,24 +326,6 @@ def for_d_in_datasets(dataset_configs):
     return dataset, prompters
 
 
-def get_ds_type(config_dataset: DictDefault):
-    """
-    Get the dataset type from the path if it's not specified
-    """
-    ds_type = "json"
-    if config_dataset.ds_type:
-        ds_type = config_dataset.ds_type
-    elif ".parquet" in config_dataset.path:
-        ds_type = "parquet"
-    elif ".arrow" in config_dataset.path:
-        ds_type = "arrow"
-    elif ".csv" in config_dataset.path:
-        ds_type = "csv"
-    elif ".txt" in config_dataset.path:
-        ds_type = "text"
-    return ds_type
-
-
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
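
The get_ds_type helper deleted above picked the datasets loader type from the file extension whenever ds_type was not set explicitly, defaulting to "json". A small self-contained sketch of that behavior, reproduced from the deleted lines; the SimpleNamespace configs stand in for axolotl's DictDefault and are illustrative, and this diff does not show where the helper lives now:

```python
from types import SimpleNamespace


def get_ds_type(config_dataset) -> str:
    """Get the dataset type from the path if it's not specified (as in the deleted helper)."""
    ds_type = "json"
    if config_dataset.ds_type:
        ds_type = config_dataset.ds_type
    elif ".parquet" in config_dataset.path:
        ds_type = "parquet"
    elif ".arrow" in config_dataset.path:
        ds_type = "arrow"
    elif ".csv" in config_dataset.path:
        ds_type = "csv"
    elif ".txt" in config_dataset.path:
        ds_type = "text"
    return ds_type


# An explicit ds_type always wins; otherwise the extension decides, else "json".
assert get_ds_type(SimpleNamespace(ds_type="csv", path="data/train.parquet")) == "csv"
assert get_ds_type(SimpleNamespace(ds_type=None, path="data/train.parquet")) == "parquet"
assert get_ds_type(SimpleNamespace(ds_type=None, path="data/train.jsonl")) == "json"
```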
