Skip to content

Commit

Permalink
fix(hyfi/pipe): add verbose logging options, import Dataset from arrow_dataset instead of datasets, add num_heads and num_tails parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Jul 26, 2023
1 parent 9d8f663 commit 6f8f8f8
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/hyfi/pipe/datasets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import random
from pathlib import Path
from typing import Union
from typing import Optional, Union

from datasets import Dataset # type: ignore
from datasets.arrow_dataset import Dataset

from hyfi.main import HyFI

Expand All @@ -12,18 +12,22 @@
def save_dataset_to_disk(
data: Dataset,
dataset_path: Union[str, Path],
verbose: bool = False,
) -> Dataset:
"""
Save a dataset.
"""
data.save_to_disk(str(dataset_path))
logger.info("Dataset saved to %s.", dataset_path)
if verbose:
logger.info("Dataset saved to %s.", dataset_path)

return data


def load_dataset_from_disk(
dataset_path: str,
num_heads: Optional[int] = 1,
num_tails: Optional[int] = 1,
verbose: bool = False,
) -> Dataset:
"""
Expand All @@ -32,8 +36,12 @@ def load_dataset_from_disk(
data = Dataset.load_from_disk(dataset_path)
logger.info("Dataset loaded from %s.", dataset_path)
if verbose:
print(data[0])
print(data[-1])
if num_heads:
num_heads = min(num_heads, len(data))
print(data[:num_heads])
if num_tails:
num_tails = min(num_tails, len(data))
print(data[-num_tails:])
logger.info("Dataset features: %s", data.features)
logger.info("Number of samples: %s", len(data))

Expand All @@ -45,6 +53,8 @@ def sample_dataset(
num_samples: int = 100,
randomize: bool = True,
random_seed: int = 42,
num_heads: Optional[int] = 1,
num_tails: Optional[int] = 1,
verbose: bool = False,
) -> Dataset:
"""
Expand All @@ -60,6 +70,11 @@ def sample_dataset(
data = data.select(idx)
logger.info("Sampling done.")
if verbose:
print(data)
if num_heads:
num_heads = min(num_heads, len(data))
print(data[:num_heads])
if num_tails:
num_tails = min(num_tails, len(data))
print(data[-num_tails:])

return data

0 comments on commit 6f8f8f8

Please sign in to comment.