Skip to content

Commit

Permalink
feat(datasets): add type hints to functions, improve code readability
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Jul 31, 2023
1 parent 07d84b9 commit c42f446
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions src/hyfi/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def concatenate_dataframes(
def load_data(
path: Optional[str] = "pandas",
name: Optional[str] = None,
data_dir: Optional[str] = "",
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str]]] = None,
split: Optional[str] = "train",
filetype: Optional[str] = "",
filetype: Optional[str] = None,
concatenate: Optional[bool] = False,
use_cached: bool = False,
verbose: Optional[bool] = False,
Expand Down Expand Up @@ -151,7 +151,7 @@ def get_data_files(
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
] = None,
data_dir: Optional[str] = None,
split: str = "",
split: Optional[str] = None,
recursive: bool = True,
use_cached: bool = False,
verbose: bool = False,
Expand Down Expand Up @@ -182,9 +182,9 @@ def get_data_files(
@staticmethod
def load_dataframes(
data_files: Union[str, Sequence[str]],
data_dir: str = "",
filetype: str = "",
split: str = "",
data_dir: Optional[str] = None,
filetype: Optional[str] = None,
split: Optional[str] = None,
concatenate: bool = False,
ignore_index: bool = False,
use_cached: bool = False,
Expand Down Expand Up @@ -241,8 +241,8 @@ def load_dataframes(
@staticmethod
def load_dataframe(
data_file: str,
data_dir: str = "",
filetype: str = "parquet",
data_dir: Optional[str] = None,
filetype: Optional[str] = None,
columns: Optional[Sequence[str]] = None,
index_col: Union[str, int, Sequence[str], Sequence[int], None] = None,
verbose: bool = False,
Expand All @@ -266,6 +266,7 @@ def load_dataframe(
if data_file.split(".")[-1] in ["csv", "tsv", "parquet"]
else filetype
)
filetype = filetype or "csv"
filetype = filetype.replace(".", "")
if filetype not in ["csv", "tsv", "parquet"]:
raise ValueError("`file` should be a csv or a parquet file.")
Expand Down Expand Up @@ -299,11 +300,11 @@ def load_dataframe(
def save_dataframes(
data: Union[pd.DataFrame, dict],
data_file: str,
data_dir: str = "",
data_dir: Optional[str] = None,
columns: Optional[Sequence[str]] = None,
index: bool = False,
filetype: str = "parquet",
suffix: str = "",
filetype: Optional[str] = "parquet",
suffix: Optional[str] = None,
verbose: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -358,7 +359,12 @@ def save_dataframes(
raise ValueError(f"Unsupported data type: {type(data)}")

@staticmethod
def to_datetime(data, _format=None, _columns=None, **kwargs):
def to_datetime(
data,
_format: Optional[str] = None,
_columns: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
"""Convert a string, int, or datetime object to a datetime object"""
from datetime import datetime

Expand All @@ -385,9 +391,9 @@ def to_datetime(data, _format=None, _columns=None, **kwargs):
@staticmethod
def to_numeric(
data,
_columns=None,
errors="coerce",
downcast=None,
_columns: Optional[Union[str, Sequence[str]]] = None,
errors: Optional[str] = "coerce",
downcast: Optional[str] = None,
**kwargs,
):
"""Convert a string, int, or float object to a float object"""
Expand All @@ -407,11 +413,11 @@ def to_numeric(

@staticmethod
def dict_to_dataframe(
data,
orient="columns",
data: Dict[Any, Any],
orient: str = "columns",
dtype=None,
columns=None,
):
) -> pd.DataFrame:
"""Convert a dictionary to a pandas dataframe"""
import pandas as pd

Expand All @@ -425,7 +431,7 @@ def records_to_dataframe(
columns=None,
coerce_float=False,
nrows=None,
):
) -> pd.DataFrame:
"""Convert a list of records to a pandas dataframe"""
import pandas as pd

Expand Down

0 comments on commit c42f446

Please sign in to comment.