🐛 Handle empty preprocessor in datasets
arxyzan committed Jun 14, 2024
1 parent 7c65ce1 commit b3e27b7
Showing 2 changed files with 19 additions and 12 deletions.
29 changes: 18 additions & 11 deletions hezar/data/datasets/dataset.py
@@ -16,7 +16,9 @@
     SplitType,
 )
 from ...preprocessors import Preprocessor, PreprocessorsContainer
-from ...utils import get_module_config_class, list_repo_files, verify_dependencies
+from ...utils import Logger, get_module_config_class, list_repo_files, verify_dependencies
+
+logger = Logger(__name__)
 
 
 class Dataset(TorchDataset):
@@ -26,7 +28,7 @@ class Dataset(TorchDataset):
     Args:
         config: The configuration object for the dataset.
         split: Dataset split name e.g, train, test, validation, etc.
-        preprocessor: Optional preprocessor object (note that most datasets require this argument)
+        preprocessor: Preprocessor object or path (note that Hezar datasets classes require this argument).
         **kwargs: Additional keyword arguments.
 
     Attributes:
@@ -69,7 +71,11 @@ def create_preprocessor(preprocessor: str | Preprocessor | PreprocessorsContainer):
             preprocessor (str | Preprocessor | PreprocessorsContainer): Preprocessor for the dataset
         """
         if preprocessor is None:
-            return preprocessor
+            logger.warning(
+                "Since v0.39.0, `Dataset` classes require the `preprocessor` parameter and cannot be None or it will "
+                "lead to errors later on! (This warning will change to an error in the future)"
+            )
+            return PreprocessorsContainer()
 
         if isinstance(preprocessor, str):
             preprocessor = Preprocessor.load(preprocessor)
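For context, a minimal sketch of the new behavior (assuming `Dataset` is importable from `hezar.data` and `create_preprocessor` is callable as a static method, which the diff suggests but does not show): passing `None` no longer propagates `None` downstream, it logs the warning above and returns an empty `PreprocessorsContainer`.

```python
from hezar.data import Dataset
from hezar.preprocessors import PreprocessorsContainer

# Passing None now logs a warning and returns an empty container instead of None,
# so later attribute/key lookups on the container no longer crash outright.
container = Dataset.create_preprocessor(None)
assert isinstance(container, PreprocessorsContainer)

# The other paths are unchanged: a hub/local path is loaded via Preprocessor.load,
# and Preprocessor / PreprocessorsContainer objects pass through as before,
# e.g. Dataset.create_preprocessor("hezarai/bert-base-fa")  # illustrative hub id
```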
@@ -91,12 +97,13 @@ def __str__(self):
 
     def __len__(self):
         """
-        Returns the length of the dataset. The `max_size` parameter in the config can overwrite this value.
+        Returns the length of the dataset. The `max_size` parameter in the config can overwrite this value. Override
+        with caution!
         """
         if isinstance(self.config.max_size, float) and 0 < self.config.max_size <= 1:
             return math.ceil(self.config.max_size * len(self.data))
 
-        elif (isinstance(self.config.max_size, int) and 0 < self.config.max_size < len(self.data)):
+        elif isinstance(self.config.max_size, int) and 0 < self.config.max_size < len(self.data):
             return self.config.max_size
 
         return len(self.data)
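The `max_size` handling above is small enough to restate on its own: a float in (0, 1] is treated as a fraction of the data, a positive int as an absolute cap. The logic below is copied from the diff; `effective_length` is only an illustrative helper, not a Hezar API.

```python
import math

def effective_length(data_length: int, max_size) -> int:
    # Fractional cap: e.g. 0.1 keeps 10% of the samples (rounded up).
    if isinstance(max_size, float) and 0 < max_size <= 1:
        return math.ceil(max_size * data_length)
    # Absolute cap: only applies if it is smaller than the full dataset.
    elif isinstance(max_size, int) and 0 < max_size < data_length:
        return max_size
    return data_length

print(effective_length(10_000, 0.1))   # 1000
print(effective_length(10_000, 2500))  # 2500
print(effective_length(10_000, None))  # 10000
```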
@@ -118,10 +125,10 @@ def __getitem__(self, index):
     def load(
         cls,
         hub_path: str | os.PathLike,
-        config: DatasetConfig = None,
-        config_filename: str = None,
         split: str | SplitType = None,
         preprocessor: str | Preprocessor | PreprocessorsContainer = None,
+        config: DatasetConfig = None,
+        config_filename: str = None,
         cache_dir: str = None,
         **kwargs,
     ) -> "Dataset":
@@ -131,14 +138,14 @@ def load(
         Args:
             hub_path (str | os.PathLike):
                 Path to dataset from hub or locally.
-            config: (DatasetConfig):
-                A config object to ignore the config in the repo or in case the repo has no `dataset_config.yaml` file
-            config_filename (Optional[str]):
-                Dataset config file name. Falls back to `dataset_config.yaml` if not given.
             split (Optional[str | SplitType]):
                 Dataset split, defaults to "train".
             preprocessor (str | Preprocessor | PreprocessorsContainer):
                 Preprocessor object for the dataset
+            config: (DatasetConfig):
+                A config object to ignore the config in the repo or in case the repo has no `dataset_config.yaml` file
+            config_filename (Optional[str]):
+                Dataset config file name. Falls back to `dataset_config.yaml` if not given.
             cache_dir (str):
                 Path to cache directory, defaults to Hezar's cache directory
             **kwargs:
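Because `config` and `config_filename` now come after `split` and `preprocessor`, positional callers relying on the old order would pass arguments to the wrong parameters. A hedged usage sketch with keyword arguments (the repo ids are illustrative examples, not verified here):

```python
from hezar.data import Dataset

# Keyword arguments keep the call robust to the parameter reordering above.
dataset = Dataset.load(
    "hezarai/sentiment-dksf",             # example dataset repo id
    split="train",
    preprocessor="hezarai/bert-base-fa",  # example preprocessor path; expected in practice
)
print(len(dataset))
```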
2 changes: 1 addition & 1 deletion hezar/data/datasets/text_classification_dataset.py
@@ -54,7 +54,7 @@ def __init__(self, config: TextClassificationDatasetConfig, split=None, preprocessor...
         self.data_collator = TextPaddingDataCollator(
             tokenizer=self.tokenizer,
             max_length=self.config.max_length,
-        )
+        ) if self.tokenizer else None
 
     def _load(self, split):
         """
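To make the intent of this one-liner explicit: when the dataset was built without a preprocessor, `self.tokenizer` is `None`, so the padding collator is skipped instead of being constructed against a missing tokenizer. The sketch below reproduces the guard outside the class; the import path and the `TextPaddingDataCollator` signature are assumed from the diff and Hezar's exports, not verified.

```python
from hezar.data import TextPaddingDataCollator  # assumed public import path

tokenizer = None  # what a dataset ends up with when created with preprocessor=None

# Same guard as in the commit: only build the padding collator when a tokenizer exists.
data_collator = (
    TextPaddingDataCollator(tokenizer=tokenizer, max_length=128)
    if tokenizer
    else None
)
print(data_collator)  # None -> a preprocessor/collator must be supplied before training
```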
