New datamodules design (#572)

* move sample generation to datamodule instead of dataset

* move sample generation from init to setup

* remove inference stage and add base classes

* replace dataset classes with AnomalibDataset

* move setup to base class, create samples as class method

* update docstrings

* refactor btech to new format

* allow training with no anomalous data

* remove MVTec name from comment

* raise NotImplementedError in base class

* allow both png and bmp images for btech

* use label_index to check if dataset contains anomalous images

* refactor getitem in dataset class

* use iloc for indexing

* move dataloader getters to base class

* refactor to add validate stage in setup

* implement alternative datamodules solution

* small improvements

* improve design

* remove unused constructor arguments

* adapt btech to new design

* add prepare_data method for mvtec

* implement more generic random splitting function

* update docstrings for folder module

* ensure type consistency when performing operations on dataset

* change imports

* change variable names

* replace pass with NotImplementedError

* allow training on folder without test images

* use relative path for normal_test_dir

* fix dataset tests

* update validation set parameter in configs

* change default argument

* use setter for samples

* hint options for val_split_mode

* update assert message and docstring

* revert name change dataset vs datamodule

* typing and docstrings

* remove samples argument from dataset constructor

* val/test -> eval

* remove Split.Full from enum

* sort samples when setting

* update warn message

* formatting

* use setter when creating samples in dataset classes

* add tests for new dataset class

* add test case for label aware random split

* update parameter name in inferencers

* move _setup implementation to base class

* address codacy issues

* fix pylint issues

* codacy

* update example dataset config in docs

* fix test

* move base classes to separate files (avoid circular import)

* add base classes

* update docstring

* fix imports

* validation_split_mode -> val_split_mode

* update docs

* Update anomalib/data/base/dataset.py

Co-authored-by: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com>

* get length from self.samples

* assert unique indices

* check is_setup for individual datasets

Co-authored-by: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com>

* remove assert in __getitem__

Co-authored-by: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com>

* Update anomalib/data/btech.py

Co-authored-by: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com>

* clearer assert message

* clarify list inversion in comment

* comments and typing

* validate contents of samples dataframe before setting

* add file paths check

* add seed to random_split function

* fix expected columns

* fix typo

* add seed parameter to datamodules

* set global seed in test entrypoint

* add NONE option to valsplitmode

* clarify setup behaviour in docstring

* fix typo

Co-authored-by: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com>

djdameln and jpcbertoldo authored Oct 31, 2022
1 parent d78f995 commit b21045b
Showing 29 changed files with 1,018 additions and 1,234 deletions.
21 changes: 21 additions & 0 deletions anomalib/config/config.py
@@ -136,6 +136,27 @@ def get_configurable_parameters(
     if "format" not in config.dataset.keys():
         config.dataset.format = "mvtec"
 
+    if "create_validation_set" in config.dataset.keys():
+        warn(
+            "The 'create_validation_set' parameter is deprecated and will be removed in v0.4.0. Please use "
+            "'val_split_mode' instead."
+        )
+        config.dataset.val_split_mode = "from_test" if config.dataset.create_validation_set else "same_as_test"
+
+    if "test_batch_size" in config.dataset.keys():
+        warn(
+            "The 'test_batch_size' parameter is deprecated and will be removed in v0.4.0. Please use "
+            "'eval_batch_size' instead."
+        )
+        config.dataset.eval_batch_size = config.dataset.test_batch_size
+
+    if "transform_config" in config.dataset.keys() and "val" in config.dataset.transform_config.keys():
+        warn(
+            "The 'transform_config.val' parameter is deprecated and will be removed in v0.4.0. Please use "
+            "'transform_config.eval' instead."
+        )
+        config.dataset.transform_config.eval = config.dataset.transform_config.val
+
     config = update_input_size_config(config)
 
     # Project Configs
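
For config authors migrating to the new keys, here is a minimal sketch of what the compatibility shim above does. Only the key names come from the diff; the values and the use of OmegaConf.create are illustrative:

from omegaconf import OmegaConf

# Hypothetical legacy config that still uses the deprecated keys.
config = OmegaConf.create(
    {"dataset": {"create_validation_set": True, "test_batch_size": 32, "transform_config": {"val": None}}}
)

# The shim rewrites the deprecated keys to their replacements:
config.dataset.val_split_mode = "from_test" if config.dataset.create_validation_set else "same_as_test"
config.dataset.eval_batch_size = config.dataset.test_batch_size  # -> 32
config.dataset.transform_config.eval = config.dataset.transform_config.val
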
31 changes: 14 additions & 17 deletions anomalib/data/__init__.py
@@ -7,8 +7,8 @@
 from typing import Union
 
 from omegaconf import DictConfig, ListConfig
-from pytorch_lightning import LightningDataModule
 
+from .base import AnomalibDataModule, AnomalibDataset
 from .btech import BTech
 from .folder import Folder
 from .inference import InferenceDataset
@@ -17,7 +17,7 @@
 logger = logging.getLogger(__name__)
 
 
-def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule:
+def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
     """Get Anomaly Datamodule.
 
     Args:
@@ -28,37 +28,33 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
     """
     logger.info("Loading the datamodule")
 
-    datamodule: LightningDataModule
+    datamodule: AnomalibDataModule
 
     if config.dataset.format.lower() == "mvtec":
         datamodule = MVTec(
             # TODO: Remove config values. IAAALD-211
             root=config.dataset.path,
             category=config.dataset.category,
             image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
             train_batch_size=config.dataset.train_batch_size,
-            test_batch_size=config.dataset.test_batch_size,
+            eval_batch_size=config.dataset.eval_batch_size,
             num_workers=config.dataset.num_workers,
             seed=config.project.seed,
             task=config.dataset.task,
             transform_config_train=config.dataset.transform_config.train,
-            transform_config_val=config.dataset.transform_config.val,
-            create_validation_set=config.dataset.create_validation_set,
+            transform_config_eval=config.dataset.transform_config.eval,
+            val_split_mode=config.dataset.val_split_mode,
         )
     elif config.dataset.format.lower() == "btech":
         datamodule = BTech(
             # TODO: Remove config values. IAAALD-211
             root=config.dataset.path,
             category=config.dataset.category,
             image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
             train_batch_size=config.dataset.train_batch_size,
-            test_batch_size=config.dataset.test_batch_size,
+            eval_batch_size=config.dataset.eval_batch_size,
             num_workers=config.dataset.num_workers,
             seed=config.project.seed,
             task=config.dataset.task,
             transform_config_train=config.dataset.transform_config.train,
-            transform_config_val=config.dataset.transform_config.val,
-            create_validation_set=config.dataset.create_validation_set,
+            transform_config_eval=config.dataset.transform_config.eval,
+            val_split_mode=config.dataset.val_split_mode,
        )
    elif config.dataset.format.lower() == "folder":
        datamodule = Folder(
@@ -70,14 +66,13 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
             mask_dir=config.dataset.mask,
             extensions=config.dataset.extensions,
             split_ratio=config.dataset.split_ratio,
             seed=config.project.seed,
             image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
             train_batch_size=config.dataset.train_batch_size,
-            test_batch_size=config.dataset.test_batch_size,
+            eval_batch_size=config.dataset.eval_batch_size,
             num_workers=config.dataset.num_workers,
             transform_config_train=config.dataset.transform_config.train,
-            transform_config_val=config.dataset.transform_config.val,
-            create_validation_set=config.dataset.create_validation_set,
+            transform_config_eval=config.dataset.transform_config.eval,
+            val_split_mode=config.dataset.val_split_mode,
         )
     else:
         raise ValueError(
@@ -90,6 +85,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
 
 
 __all__ = [
+    "AnomalibDataset",
+    "AnomalibDataModule",
     "get_datamodule",
     "BTech",
     "Folder",
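
A hedged usage sketch of the updated factory follows; the path, category and batch sizes are illustrative, and it assumes a local copy of the MVTec AD dataset:

from omegaconf import OmegaConf

from anomalib.data import get_datamodule

config = OmegaConf.create(
    {
        "dataset": {
            "format": "mvtec",
            "path": "./datasets/MVTec",  # assumed dataset location
            "category": "bottle",
            "image_size": [256, 256],
            "train_batch_size": 32,
            "eval_batch_size": 32,
            "num_workers": 8,
            "task": "segmentation",
            "transform_config": {"train": None, "eval": None},
            "val_split_mode": "from_test",
        },
        "project": {"seed": 42},
    }
)

datamodule = get_datamodule(config)
datamodule.setup()  # first call builds train/test, then carves val out of test
train_loader = datamodule.train_dataloader()
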
10 changes: 10 additions & 0 deletions anomalib/data/base/__init__.py
@@ -0,0 +1,10 @@
+"""Base classes for custom dataset and datamodules."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from .datamodule import AnomalibDataModule
+from .dataset import AnomalibDataset
+
+__all__ = ["AnomalibDataset", "AnomalibDataModule"]
108 changes: 108 additions & 0 deletions anomalib/data/base/datamodule.py
@@ -0,0 +1,108 @@
+"""Anomalib datamodule base class."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from abc import ABC
+from typing import Optional
+
+from pandas import DataFrame
+from pytorch_lightning import LightningDataModule
+from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from torch.utils.data import DataLoader
+
+from anomalib.data.base.dataset import AnomalibDataset
+from anomalib.data.utils import ValSplitMode, random_split
+
+logger = logging.getLogger(__name__)
+
+
+class AnomalibDataModule(LightningDataModule, ABC):
+    """Base Anomalib data module.
+
+    Args:
+        train_batch_size (int): Batch size used by the train dataloader.
+        eval_batch_size (int): Batch size used by the val and test dataloaders.
+        num_workers (int): Number of workers used by the train, val and test dataloaders.
+        val_split_mode (ValSplitMode): Determines how the validation subset is obtained.
+        seed (Optional[int], optional): Seed used during random subset splitting.
+    """
+
+    def __init__(
+        self,
+        train_batch_size: int,
+        eval_batch_size: int,
+        num_workers: int,
+        val_split_mode: ValSplitMode,
+        seed: Optional[int] = None,
+    ):
+        super().__init__()
+        self.train_batch_size = train_batch_size
+        self.eval_batch_size = eval_batch_size
+        self.num_workers = num_workers
+        self.val_split_mode = val_split_mode
+        self.seed = seed
+
+        self.train_data: Optional[AnomalibDataset] = None
+        self.val_data: Optional[AnomalibDataset] = None
+        self.test_data: Optional[AnomalibDataset] = None
+
+        self._samples: Optional[DataFrame] = None
+
+    def setup(self, stage: Optional[str] = None):
+        """Set up train, validation and test data.
+
+        Args:
+            stage: Optional[str]: Train/Val/Test stages. (Default value = None)
+        """
+        if not self.is_setup:
+            self._setup(stage)
+        assert self.is_setup
+
+    def _setup(self, _stage: Optional[str] = None) -> None:
+        """Set up the datasets and perform dynamic subset splitting.
+
+        This method may be overridden in subclasses for custom splitting behaviour.
+
+        Note: The stage argument is not used here. This is because, for a given instance of an AnomalibDataModule
+        subclass, all three subsets are created at the first call of setup(). This is to accommodate the subset
+        splitting behaviour of anomaly tasks, where the validation set is usually extracted from the test set, and
+        the test set must therefore be created as early as the `fit` stage.
+        """
+        assert self.train_data is not None
+        assert self.test_data is not None
+
+        self.train_data.setup()
+        self.test_data.setup()
+        if self.val_split_mode == ValSplitMode.FROM_TEST:
+            self.val_data, self.test_data = random_split(self.test_data, [0.5, 0.5], label_aware=True, seed=self.seed)
+        elif self.val_split_mode == ValSplitMode.SAME_AS_TEST:
+            self.val_data = self.test_data
+        elif self.val_split_mode != ValSplitMode.NONE:
+            raise ValueError(f"Unknown validation split mode: {self.val_split_mode}")
+
+    @property
+    def is_setup(self) -> bool:
+        """Checks if setup() has been called."""
+        # at least one of [train_data, val_data, test_data] should be set up
+        if self.train_data is not None and self.train_data.is_setup:
+            return True
+        if self.val_data is not None and self.val_data.is_setup:
+            return True
+        if self.test_data is not None and self.test_data.is_setup:
+            return True
+        return False
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        """Get train dataloader."""
+        return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers)
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        """Get validation dataloader."""
+        return DataLoader(self.val_data, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers)
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        """Get test dataloader."""
+        return DataLoader(self.test_data, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers)
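
To make the base-class contract concrete, here is a minimal hypothetical subclass; MyDataModule and its constructor arguments are placeholders, not part of this commit:

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.utils import ValSplitMode


class MyDataModule(AnomalibDataModule):
    """Hypothetical datamodule: assign train/test datasets and let the base class split."""

    def __init__(self, train_set: AnomalibDataset, test_set: AnomalibDataset):
        super().__init__(
            train_batch_size=32,
            eval_batch_size=32,
            num_workers=8,
            val_split_mode=ValSplitMode.FROM_TEST,  # val is carved out of test, label-aware
            seed=42,
        )
        self.train_data = train_set
        self.test_data = test_set
        # No _setup() override needed: the base implementation calls setup() on both
        # datasets and derives val_data according to val_split_mode.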