@@ -3,7 +3,7 @@
 import functools
 import logging
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
 
 from datasets import (
     Dataset,
@@ -12,8 +12,6 @@
     load_dataset,
     load_from_disk,
 )
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import HFValidationError
 from transformers import PreTrainedTokenizerBase
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -42,6 +40,7 @@
     UnsupportedPrompter,
 )
 from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+from axolotl.utils.data.shared import load_dataset_w_config
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
     md5,
@@ -255,195 +254,9 @@ def for_d_in_datasets(dataset_configs):
 
     # pylint: disable=invalid-name
     for config_dataset in for_d_in_datasets(cfg_datasets):
-        ds: Optional[Union[Dataset, DatasetDict]] = None
-        ds_from_hub = False
-        ds_trust_remote_code = config_dataset.trust_remote_code
-        try:
-            # this is just a basic check to see if the path is a
-            # valid HF dataset that's loadable
-            load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
-                streaming=True,
-                token=use_auth_token,
-                revision=config_dataset.revision,
-                trust_remote_code=ds_trust_remote_code,
-            )
-            ds_from_hub = True
-        except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
-            pass
-
-        ds_from_cloud = False
-        storage_options = {}
-        remote_file_system = None
-        if config_dataset.path.startswith("s3://"):
-            try:
-                import aiobotocore.session  # type: ignore
-                import s3fs  # type: ignore
-            except ImportError as exc:
-                raise ImportError(
-                    "s3:// paths require aiobotocore and s3fs to be installed"
-                ) from exc
-
-            # Takes credentials from ~/.aws/credentials for default profile
-            s3_session = aiobotocore.session.AioSession(profile="default")
-            storage_options = {"session": s3_session}
-            remote_file_system = s3fs.S3FileSystem(**storage_options)
-        elif config_dataset.path.startswith(
-            "gs://"
-        ) or config_dataset.path.startswith("gcs://"):
-            try:
-                import gcsfs  # type: ignore
-            except ImportError as exc:
-                raise ImportError(
-                    "gs:// or gcs:// paths require gcsfs to be installed"
-                ) from exc
-
-            # gcsfs will use default credentials from the environment else anon
-            # https://gcsfs.readthedocs.io/en/latest/#credentials
-            storage_options = {"token": None}
-            remote_file_system = gcsfs.GCSFileSystem(**storage_options)
-        # TODO: Figure out how to get auth creds passed
-        # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
-        #     try:
-        #         import adlfs
-        #     except ImportError as exc:
-        #         raise ImportError(
-        #             "adl:// or abfs:// paths require adlfs to be installed"
-        #         ) from exc
-
-        #     # Gen 1
-        #     storage_options = {
-        #         "tenant_id": TENANT_ID,
-        #         "client_id": CLIENT_ID,
-        #         "client_secret": CLIENT_SECRET,
-        #     }
-        #     # Gen 2
-        #     storage_options = {
-        #         "account_name": ACCOUNT_NAME,
-        #         "account_key": ACCOUNT_KEY,
-        #     }
-
-        #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
-        try:
-            if remote_file_system and remote_file_system.exists(
-                config_dataset.path
-            ):
-                ds_from_cloud = True
-        except (FileNotFoundError, ConnectionError):
-            pass
-
-        # prefer local dataset, even if hub exists
-        local_path = Path(config_dataset.path)
-        if local_path.exists():
-            if local_path.is_dir():
-                if config_dataset.data_files:
-                    ds_type = get_ds_type(config_dataset)
-                    ds = load_dataset(
-                        ds_type,
-                        name=config_dataset.name,
-                        data_files=config_dataset.data_files,
-                        streaming=False,
-                        split=None,
-                    )
-                else:
-                    try:
-                        ds = load_from_disk(config_dataset.path)
-                    except FileNotFoundError:
-                        ds = load_dataset(
-                            config_dataset.path,
-                            name=config_dataset.name,
-                            streaming=False,
-                            split=None,
-                        )
-            elif local_path.is_file():
-                ds_type = get_ds_type(config_dataset)
-
-                ds = load_dataset(
-                    ds_type,
-                    name=config_dataset.name,
-                    data_files=config_dataset.path,
-                    streaming=False,
-                    split=None,
-                )
-            else:
-                raise ValueError(
-                    "unhandled dataset load: local path exists, but is neither a directory or a file"
-                )
-        elif ds_from_hub:
-            load_ds_kwargs = {}
-            if config_dataset.split:
-                load_ds_kwargs["split"] = config_dataset.split
-            ds = load_dataset(
-                config_dataset.path,
-                name=config_dataset.name,
-                streaming=False,
-                data_files=config_dataset.data_files,
-                token=use_auth_token,
-                revision=config_dataset.revision,
-                trust_remote_code=config_dataset.trust_remote_code,
-                **load_ds_kwargs,
-            )
-        elif ds_from_cloud and remote_file_system:
-            if remote_file_system.isdir(config_dataset.path):
-                ds = load_from_disk(
-                    config_dataset.path,
-                    storage_options=storage_options,
-                )
-            elif remote_file_system.isfile(config_dataset.path):
-                ds_type = get_ds_type(config_dataset)
-                ds = load_dataset(
-                    ds_type,
-                    name=config_dataset.name,
-                    data_files=config_dataset.path,
-                    streaming=False,
-                    split=None,
-                    storage_options=storage_options,
-                    trust_remote_code=config_dataset.trust_remote_code,
-                )
-        elif config_dataset.path.startswith("https://"):
-            ds_type = get_ds_type(config_dataset)
-            ds = load_dataset(
-                ds_type,
-                name=config_dataset.name,
-                data_files=config_dataset.path,
-                streaming=False,
-                split=None,
-                storage_options=storage_options,
-                trust_remote_code=config_dataset.trust_remote_code,
-            )
-        else:
-            if isinstance(config_dataset.data_files, str):
-                fp = hf_hub_download(
-                    repo_id=config_dataset.path,
-                    repo_type="dataset",
-                    filename=config_dataset.data_files,
-                    revision=config_dataset.revision,
-                )
-            elif isinstance(config_dataset.data_files, list):
-                fp = []
-                for file in config_dataset.data_files:
-                    fp.append(
-                        hf_hub_download(
-                            repo_id=config_dataset.path,
-                            repo_type="dataset",
-                            filename=file,
-                            revision=config_dataset.revision,
-                        )
-                    )
-            else:
-                raise ValueError(
-                    "data_files must be either a string or list of strings"
-                )
-            ds = load_dataset(
-                "json",
-                name=config_dataset.name,
-                data_files=fp,
-                streaming=False,
-                split=None,
-            )
-        if not ds:
-            raise ValueError("unhandled dataset load")
+        ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
+            config_dataset, use_auth_token
+        )
 
         d_base_type = d_prompt_style = None
         d_type = config_dataset.type
@@ -513,24 +326,6 @@ def for_d_in_datasets(dataset_configs):
     return dataset, prompters
 
 
-def get_ds_type(config_dataset: DictDefault):
-    """
-    Get the dataset type from the path if it's not specified
-    """
-    ds_type = "json"
-    if config_dataset.ds_type:
-        ds_type = config_dataset.ds_type
-    elif ".parquet" in config_dataset.path:
-        ds_type = "parquet"
-    elif ".arrow" in config_dataset.path:
-        ds_type = "arrow"
-    elif ".csv" in config_dataset.path:
-        ds_type = "csv"
-    elif ".txt" in config_dataset.path:
-        ds_type = "text"
-    return ds_type
-
-
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,