Skip to content

Commit

Permalink
Merge pull request #353 from donglihe-hub/data_utils
Browse files Browse the repository at this point in the history
Enhance data_utils
  • Loading branch information
Eleven1Liu authored Jan 24, 2024
2 parents 290734f + f00da0a commit a7ec069
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions libmultilabel/nn/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
This is effective only when is_test=False. Defaults to False.
Returns:
pandas.DataFrame: Data composed of index, label, and tokenized text.
dict: [{(optional: "index": ..., )"label": ..., "text": ...}, ...]
"""
assert isinstance(data, str) or isinstance(data, pd.DataFrame), "Data must be from a file or pandas dataframe."
if isinstance(data, str):
Expand Down Expand Up @@ -222,15 +222,12 @@ def load_datasets(
Returns:
dict: A dictionary of datasets.
"""
if isinstance(training_data, str) or isinstance(test_data, str):
assert training_data or test_data, "At least one of `training_data` and `test_data` must be specified."
elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame):
assert (
not training_data.empty or not test_data.empty
), "At least one of `training_data` and `test_data` must be specified."
if training_data is None and test_data is None:
raise ValueError("At least one of `training_data` and `test_data` must be specified.")

datasets = {}
if training_data is not None:
logging.info(f"Loading training data")
datasets["train"] = _load_raw_data(
training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
)
Expand All @@ -243,11 +240,12 @@ def load_datasets(
datasets["train"], datasets["val"] = train_test_split(datasets["train"], test_size=val_size, random_state=42)

if test_data is not None:
logging.info(f"Loading test data")
datasets["test"] = _load_raw_data(
test_data, is_test=True, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
)

if merge_train_val:
if merge_train_val and "val" in datasets:
datasets["train"] = datasets["train"] + datasets["val"]
for i in range(len(datasets["train"])):
datasets["train"][i]["index"] = i
Expand Down

0 comments on commit a7ec069

Please sign in to comment.