Skip to content

Commit

Permalink
fix(utils/datasets): rename save_to_disk to save_dataset_to_disk and …
Browse files Browse the repository at this point in the history
…improve functionality
  • Loading branch information
entelecheia committed Jul 31, 2023
1 parent 3adc15e commit d3ea336
Showing 1 changed file with 97 additions and 16 deletions.
113 changes: 97 additions & 16 deletions src/hyfi/utils/datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import random
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
Expand All @@ -22,7 +23,13 @@
logger = LOGGING.getLogger(__name__)

DatasetType = Union[Dataset, IterableDataset]
DatasetLikeType = Union[Dataset, IterableDataset, DatasetDict, IterableDatasetDict]
DatasetDictType = Union[DatasetDict, IterableDatasetDict]
DatasetLikeType = Union[
Dataset,
IterableDataset,
DatasetDict,
IterableDatasetDict,
]


class DATASETs:
Expand Down Expand Up @@ -533,30 +540,104 @@ def load_dataset(
)

@staticmethod
def save_to_disk(
dset,
def load_dataset_from_disk(
dataset_path: str,
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
num_heads: Optional[int] = 1,
num_tails: Optional[int] = 1,
verbose: bool = False,
) -> Union[Dataset, DatasetDict]:
"""Load a dataset from the filesystem."""
data = hfds.load_from_disk(
dataset_path=dataset_path,
keep_in_memory=keep_in_memory,
storage_options=storage_options,
)
logger.info("Dataset loaded from %s.", dataset_path)
if verbose:
if isinstance(data, DatasetDict or IterableDatasetDict):
for split in data:
logger.info("Split: %s", split)
logger.info("Dataset features: %s", data[split].features)
logger.info("Number of records: %s", len(data[split]))
else:
if num_heads:
num_heads = min(num_heads, len(data))
print(data[:num_heads])
if num_tails:
num_tails = min(num_tails, len(data))
print(data[-num_tails:])
logger.info("Dataset features: %s", data.features)
logger.info("Number of records: %s", len(data))

return data

@staticmethod
def sample_dataset(
data: DatasetLikeType,
split: Optional[Union[str, Split]] = None,
num_samples: int = 100,
randomize: bool = True,
random_seed: int = 42,
num_heads: Optional[int] = 1,
num_tails: Optional[int] = 1,
verbose: bool = False,
) -> Dataset:
"""Sample a dataset."""
if not isinstance(data, Dataset or DatasetDict):
if split is None:
raise ValueError(
"Please provide a split name when sampling a DatasetDict."
)
data = data[split]
if random_seed > 0:
random.seed(random_seed)
if randomize:
idx = random.sample(range(len(data)), num_samples)
else:
idx = range(num_samples)

data = data.select(idx)
logger.info("Sampling done.")
if verbose:
if num_heads:
num_heads = min(num_heads, len(data))
print(data[:num_heads])
if num_tails:
num_tails = min(num_tails, len(data))
print(data[-num_tails:])

return data

@staticmethod
def save_dataset_to_disk(
dset: DatasetLikeType,
dataset_path: PathLike,
max_shard_size: Optional[Union[str, int]] = None,
num_shards: Optional[int] = None,
num_proc: Optional[int] = None,
storage_options: Optional[dict] = None,
):
verbose: bool = False,
**kwargs,
) -> DatasetLikeType:
"""Save a dataset or a dataset dict to the filesystem."""
dset.save_to_disk(
dataset_path=dataset_path,
dataset_path,
max_shard_size=max_shard_size,
num_shards=num_shards,
num_proc=num_proc,
storage_options=storage_options,
)
logger.info("Dataset saved to %s.", dataset_path)
if verbose:
if isinstance(dset, DatasetDict or IterableDatasetDict):
for split in dset:
logger.info("Split: %s", split)
logger.info("Dataset features: %s", dset[split].features)
logger.info("Number of records: %s", len(dset[split]))
else:
logger.info("Dataset features: %s", dset.features)
logger.info("Number of records: %s", len(dset))

@staticmethod
def load_from_disk(
dataset_path: str,
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
) -> Union[Dataset, DatasetDict]:
return hfds.load_from_disk(
dataset_path=dataset_path,
keep_in_memory=keep_in_memory,
storage_options=storage_options,
)
return dset

0 comments on commit d3ea336

Please sign in to comment.