Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes for Datamodules feature branch #800

Merged
merged 2 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def get_configurable_parameters(
config.dataset.mask_dir = config.dataset.mask
if "path" in config.dataset:
warn(DeprecationWarning("path will be deprecated in favor of root in config.dataset in a future release."))
config.dataset.mask_dir = config.dataset.mask
config.dataset.root = config.dataset.path

# add category subfolder if needed
if config.dataset.format.lower() in ("btech", "mvtec"):
Expand Down
3 changes: 1 addition & 2 deletions anomalib/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def make_folder_dataset(
filenames += filename
labels += label

samples = DataFrame({"image_path": filenames, "label": labels})
samples = DataFrame({"image_path": filenames, "label": labels, "mask_path": ""})

# Create label index for normal (0) and abnormal (1) images.
samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0
Expand All @@ -146,7 +146,6 @@ def make_folder_dataset(
# If a path to mask is provided, add it to the sample dataframe.
if mask_dir is not None:
mask_dir = _check_and_convert_path(mask_dir)
samples["mask_path"] = ""
for index, row in samples.iterrows():
if row.label_index == 1:
rel_image_path = row.image_path.relative_to(abnormal_dir)
Expand Down
28 changes: 26 additions & 2 deletions tests/pre_merge/datasets/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def make_folder_data_module(
normal_dir="good",
abnormal_dir="broken_large",
normal_test_dir="good_test",
mask_dir="ground_truth/broken_large",
):
"""Create Folder Data Module."""
root = get_dataset_path(dataset="bottle")
Expand All @@ -88,7 +89,7 @@ def make_folder_data_module(
normal_dir=normal_dir,
abnormal_dir=abnormal_dir,
normal_test_dir=normal_test_dir,
mask_dir="ground_truth/broken_large",
mask_dir=mask_dir,
normal_split_ratio=0.2,
image_size=(256, 256),
train_batch_size=batch_size,
Expand Down Expand Up @@ -313,16 +314,39 @@ def test_equal_splits(self, make_data_module, dataset, test_split_mode):

@pytest.mark.parametrize("test_split_mode", ("from_dir", "synthetic"))
def test_normal_test_dir_omitted(self, make_data_module, test_split_mode):
"""The test set should always contain normal samples even when no normal_test_dir is provided."""
"""Tests if the data module functions properly when no normal_test_dir is provided."""
data_module = make_data_module(dataset="folder", test_split_mode=test_split_mode, normal_test_dir=None)
# check if we can retrieve a sample from every subset
next(iter(data_module.train_dataloader()))
next(iter(data_module.test_dataloader()))
next(iter(data_module.val_dataloader()))
# the test set should contain normal samples which are sampled from the train set
assert data_module.test_data.has_normal

def test_abnormal_dir_omitted_from_dir(self, make_data_module):
    """Check the folder data module when no abnormal_dir is given and the split mode is from_dir.

    Every subset must still yield a sample, and the test set must not contain anomalous samples.
    """
    folder_module = make_data_module(dataset="folder", test_split_mode="from_dir", abnormal_dir=None)
    # pull one batch from each subset, in train -> test -> val order
    for subset in ("train", "test", "val"):
        build_loader = getattr(folder_module, f"{subset}_dataloader")
        next(iter(build_loader()))
    # no anomalous samples are available, so none may show up in the test set
    assert not folder_module.test_data.has_anomalous

def test_abnormal_dir_omitted_synthetic(self, make_data_module):
    """Check the folder data module when no abnormal_dir is given and the split mode is synthetic.

    Every subset must still yield a sample, and the test set must contain anomalous samples.
    """
    folder_module = make_data_module(dataset="folder", test_split_mode="synthetic", abnormal_dir=None)
    # pull one batch from each subset, in train -> test -> val order
    for subset in ("train", "test", "val"):
        build_loader = getattr(folder_module, f"{subset}_dataloader")
        next(iter(build_loader()))
    # anomalous samples are expected here: they have been converted from normals
    assert folder_module.test_data.has_anomalous

def test_masks_dir_omitted(self, make_data_module):
"""Tests if the data module can be set up in classification mode when no masks are passed."""
data_module = make_data_module(dataset="folder", task="classification", mask_dir=None)
# check if we can retrieve a sample from every subset
next(iter(data_module.train_dataloader()))
next(iter(data_module.test_dataloader()))
next(iter(data_module.val_dataloader()))