diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 5f4b49cf9e..3a6a603e0e 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -185,7 +185,15 @@ def _create_test_split(self) -> None: def _create_val_split(self) -> None: """Obtain the validation set based on the settings in the config.""" - if self.val_split_mode == ValSplitMode.FROM_TEST: + if self.val_split_mode == ValSplitMode.FROM_TRAIN: + # randomly sampled from train set + self.train_data, self.val_data = random_split( + self.train_data, + self.val_split_ratio, + label_aware=True, + seed=self.seed, + ) + elif self.val_split_mode == ValSplitMode.FROM_TEST: # randomly sampled from test set self.test_data, self.val_data = random_split( self.test_data, diff --git a/src/anomalib/data/utils/split.py b/src/anomalib/data/utils/split.py index 27d1b4d770..536ccc72c0 100644 --- a/src/anomalib/data/utils/split.py +++ b/src/anomalib/data/utils/split.py @@ -47,6 +47,7 @@ class ValSplitMode(str, Enum): NONE = "none" SAME_AS_TEST = "same_as_test" + FROM_TRAIN = "from_train" FROM_TEST = "from_test" SYNTHETIC = "synthetic"