Skip to content

Commit

Permalink
feat(datasets): add aggregate, basic and reshape classes, rename tran…
Browse files Browse the repository at this point in the history
…sform to combine, process to plot, filter to slice
  • Loading branch information
entelecheia committed Aug 8, 2023
1 parent 177fa5f commit 142c74a
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 8 deletions.
11 changes: 11 additions & 0 deletions src/hyfi/utils/datasets/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
This module contains the class for aggregating datasets.
"""
from hyfi.utils.logging import LOGGING

logger = LOGGING.getLogger(__name__)


class DSAggregate:
    """Placeholder class for dataset aggregation helpers; it defines no behaviour yet."""

    def __init__(self) -> None:
        """Create an instance without setting up any state."""
        return None
11 changes: 11 additions & 0 deletions src/hyfi/utils/datasets/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
This file contains the basic dataset functions.
"""
from hyfi.utils.logging import LOGGING

logger = LOGGING.getLogger(__name__)


class DSBasic:
    """Placeholder class for basic dataset helpers; it defines no behaviour yet."""

    def __init__(self) -> None:
        """Create an instance without setting up any state."""
        return None
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Dataset transformation functions. Concatenate, merge, join, etc.
"""
from typing import Dict, List, Optional, Sequence, Union

import datasets as hfds
Expand All @@ -13,7 +16,7 @@
logger = LOGGING.getLogger(__name__)


class DSTransform:
class DSCombine:
@staticmethod
def concatenate_data(
data: Union[Dict[str, pd.DataFrame], Sequence[pd.DataFrame], List[DatasetType]],
Expand All @@ -28,14 +31,14 @@ def concatenate_data(
) -> Union[pd.DataFrame, DatasetType]:
# if data is a list of datasets, concatenate them
if isinstance(data, list) and isinstance(data[0], Dataset):
return DSTransform.concatenate_datasets(
return DSCombine.concatenate_datasets(
data,
axis=axis,
split=split,
**kwargs,
)
else:
return DSTransform.concatenate_dataframes(
return DSCombine.concatenate_dataframes(
data,
columns=columns,
add_split_key_column=add_split_key_column,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""
Plotting functions for datasets.
"""
from hyfi.utils.logging import LOGGING

logger = LOGGING.getLogger(__name__)


class DSProcess:
class DSPlot:
    """Placeholder class for dataset plotting helpers; it defines no behaviour yet."""

    def __init__(self) -> None:
        """Create an instance without setting up any state."""
        return None
11 changes: 11 additions & 0 deletions src/hyfi/utils/datasets/reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
This module contains the class for reshaping the dataset. Pivot, melt, etc.
"""
from hyfi.utils.logging import LOGGING

logger = LOGGING.getLogger(__name__)


class DSReshape:
    """Placeholder class for dataset reshaping helpers; it defines no behaviour yet."""

    def __init__(self) -> None:
        """Create an instance without setting up any state."""
        return None
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
Slice, sample, and filter datasets.
"""
import random
from typing import List, Optional, Union
from typing import List, Optional, Sequence, Union

import numpy as np
import pandas as pd
from datasets.arrow_dataset import Dataset
from datasets.dataset_dict import DatasetDict
Expand All @@ -14,7 +18,7 @@
logger = LOGGING.getLogger(__name__)


class DSFilter:
class DSSlice:
@staticmethod
def sample_dataset(
data: DatasetLikeType,
Expand Down Expand Up @@ -71,15 +75,15 @@ def filter_and_sample_data(
"""
# Filter by queries
if queries:
data_ = DSFilter.filter_data_by_queries(data, queries, verbose=verbose)
data_ = DSSlice.filter_data_by_queries(data, queries, verbose=verbose)
else:
logger.warning("No query specified")
data_ = data.copy()

# Create a sample for analysis
sample = None
if sample_size:
sample = DSFilter.sample_data(
sample = DSSlice.sample_data(
data_,
sample_size_per_group=sample_size,
sample_seed=sample_seed,
Expand Down Expand Up @@ -237,3 +241,24 @@ def sample_data(
print(grp_dists)

return _sample

@staticmethod
def split_dataframe(
    data,
    indices_or_sections: Union[int, Sequence[int]],
    verbose: bool = False,
) -> List[pd.DataFrame]:
    """
    Split a dataframe row-wise into consecutive chunks.

    Chunk boundaries follow the ``numpy.array_split`` convention, so the
    number of rows does not have to be divisible by the number of sections.

    Args:
        data (pd.DataFrame): dataframe to split
        indices_or_sections (int or sequence of ints): if int, number of
            chunks to split the dataframe into (the leading chunks get one
            extra row when the rows do not divide evenly); if a sequence,
            the row positions at which to cut
        verbose (bool): if True, log the requested split before splitting

    Returns:
        List[pd.DataFrame]: list of dataframes, in the original row order

    Raises:
        ValueError: if an integer number of sections is not positive
    """
    if verbose:
        logger.info("Splitting dataframe into %s", indices_or_sections)

    if not hasattr(data, "iloc"):
        # Not a pandas object: keep the plain array-like behaviour.
        return np.array_split(data, indices_or_sections)

    # Slice positionally instead of handing the frame to np.array_split:
    # depending on the NumPy/pandas versions, that call goes through the
    # deprecated DataFrame.swapaxes and emits a FutureWarning.
    n_rows = len(data)
    try:
        # An iterable argument is read as explicit cut positions.
        cut_points = list(indices_or_sections)
    except TypeError:
        # A scalar argument is read as the number of chunks.
        n_sections = int(indices_or_sections)
        if n_sections <= 0:
            raise ValueError("number sections must be larger than 0.") from None
        base_size, n_larger = divmod(n_rows, n_sections)
        bounds = [0]
        for position in range(n_sections):
            chunk_size = base_size + 1 if position < n_larger else base_size
            bounds.append(bounds[-1] + chunk_size)
    else:
        bounds = [0, *cut_points, n_rows]

    return [data.iloc[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]

0 comments on commit 142c74a

Please sign in to comment.