Fix DRAEM #1431

Merged
12 commits merged on Nov 8, 2023
12 changes: 7 additions & 5 deletions src/anomalib/models/draem/config.yaml
@@ -9,7 +9,7 @@ dataset:
num_workers: 8
image_size: 256 # dimensions to which images are resized (mandatory)
center_crop: null # dimensions to which images are center-cropped after resizing (optional)
normalization: imagenet # data distribution to which the images will be normalized: [none, imagenet]
normalization: none # data distribution to which the images will be normalized: [none, imagenet]
transform_config:
train: null
eval: null
@@ -29,12 +29,14 @@ model:
name: draem
anomaly_source_path: null # optional, e.g. ./datasets/dtd
lr: 0.0001
beta: [0.1, 1.0] # generated anomaly opacity parameter, either float or interval
enable_sspcab: false
sspcab_lambda: 0.1
early_stopping:
patience: 20
metric: pixel_AUROC
mode: max
# uncomment the following settings to enable early stopping
#early_stopping:
# patience: 20
# metric: pixel_AUROC
# mode: max
normalization_method: min_max # options: [none, min_max, cdf]

metrics:
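For reference, a minimal sketch of how the new beta entry travels from this file into the model. anomalib parses the config with OmegaConf (the PR itself checks for ListConfig), so an interval written as a YAML list such as [0.1, 1.0] arrives as a ListConfig and is converted to a plain tuple before being handed to the Augmenter. The direct OmegaConf.load call below is only for illustration; the library normally goes through its own config-loading helpers.

from omegaconf import ListConfig, OmegaConf

# Load the DRAEM config directly (illustrative; anomalib has its own loader).
config = OmegaConf.load("src/anomalib/models/draem/config.yaml")

beta = config.model.beta           # YAML list -> omegaconf ListConfig
if isinstance(beta, ListConfig):   # interval form, e.g. [0.1, 1.0]
    beta = tuple(beta)             # -> (0.1, 1.0), matching the Draem default
print(beta)                        # a plain float in the YAML would pass through unchanged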
37 changes: 27 additions & 10 deletions src/anomalib/models/draem/lightning_model.py
@@ -33,11 +33,15 @@ class Draem(AnomalyModule):
"""

def __init__(
self, enable_sspcab: bool = False, sspcab_lambda: float = 0.1, anomaly_source_path: str | None = None
self,
enable_sspcab: bool = False,
sspcab_lambda: float = 0.1,
anomaly_source_path: str | None = None,
beta: float | tuple[float, float] = (0.1, 1.0),
) -> None:
super().__init__()

self.augmenter = Augmenter(anomaly_source_path)
self.augmenter = Augmenter(anomaly_source_path, beta=beta)
self.model = DraemModel(sspcab=enable_sspcab)
self.loss = DraemLoss()
self.sspcab = enable_sspcab
@@ -121,10 +125,17 @@ class DraemLightning(Draem):
"""

def __init__(self, hparams: DictConfig | ListConfig) -> None:
# beta in config can be either float or sequence
beta = hparams.model.beta
# if sequence - change to tuple[float, float]
if isinstance(beta, ListConfig):
beta = tuple(beta)

super().__init__(
enable_sspcab=hparams.model.enable_sspcab,
sspcab_lambda=hparams.model.sspcab_lambda,
anomaly_source_path=hparams.model.anomaly_source_path,
beta=beta,
)
self.hparams: DictConfig | ListConfig # type: ignore
self.save_hyperparameters(hparams)
@@ -138,13 +149,19 @@ def configure_callbacks(self) -> list[EarlyStopping]:
deprecated, and callbacks will be configured from either
config.yaml file or from CLI.
"""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]
callbacks = []
if "early_stopping" in self.hparams.model:
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
callbacks.append(early_stopping)

return callbacks

def configure_optimizers(self) -> torch.optim.Optimizer:
def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch.optim.lr_scheduler.LRScheduler]]:
"""Configure the Adam optimizer."""
return torch.optim.Adam(params=self.model.parameters(), lr=self.hparams.model.lr)
optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.hparams.model.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 600], gamma=0.1)
return [optimizer], [scheduler]
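As a rough illustration of what the new beta argument controls: the Augmenter blends the generated anomaly into the training image with an opacity taken from this value. The actual Augmenter implementation is not part of this diff, so the sampling sketch below is an assumption about its behaviour, not the real code.

from __future__ import annotations

import torch

def sample_opacity(beta: float | tuple[float, float]) -> float:
    """Illustrative only: choose the blending opacity for one generated anomaly.

    A plain float is used as-is; an interval (low, high) is sampled uniformly,
    which is how the (0.1, 1.0) default from config.yaml would behave.
    """
    if isinstance(beta, tuple):
        low, high = beta
        return low + (high - low) * torch.rand(1).item()
    return beta

# e.g. sample_opacity((0.1, 1.0)) -> some value in [0.1, 1.0); sample_opacity(0.5) -> 0.5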
Comment on lines +165 to +167

Contributor:
We previously did not add this because early stopping is enabled, and the training rarely reaches 400 epochs. But since we recently decided to always use the configuration recommended in the paper as the default, I agree it would be good to add it now.

In this case we should probably also disable early stopping by default.
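A stand-alone sketch of the schedule in question, with stand-in parameters and a 700-epoch range chosen purely for illustration; it shows where the 1e-4 learning rate from config.yaml actually decays:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters, no real training
optimizer = torch.optim.Adam(params, lr=0.0001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400, 600], gamma=0.1)

previous_lr = scheduler.get_last_lr()[0]
for epoch in range(700):
    optimizer.step()   # a real run would perform the training epoch here
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    if current_lr != previous_lr:
        print(f"after epoch {epoch}: lr {previous_lr:g} -> {current_lr:g}")
        previous_lr = current_lr
# after epoch 399: lr 0.0001 -> 1e-05
# after epoch 599: lr 1e-05 -> 1e-06

With early stopping typically triggering well before epoch 400, the schedule would never take effect, which is why the two changes go together.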

Contributor Author:
From a few training runs I did, the last few epochs make only a small difference, but it is noticeable. With the current design, removing early stopping would also require changing this function:

def configure_callbacks(self) -> list[EarlyStopping]:
    """Configure model-specific callbacks.
    Note:
        This method is used for the existing CLI.
        When PL CLI is introduced, configure callback method will be
        deprecated, and callbacks will be configured from either
        config.yaml file or from CLI.
    """
    early_stopping = EarlyStopping(
        monitor=self.hparams.model.early_stopping.metric,
        patience=self.hparams.model.early_stopping.patience,
        mode=self.hparams.model.early_stopping.mode,
    )
    return [early_stopping]

Should I then remove it from the config and add an if statement here to check whether it's present, or should it be removed altogether?

Contributor:
I think the user should still be able to enable early stopping if they want to, so let's add a check.

Contributor Author:
Okay, this is now implemented.
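For completeness, a small sketch of how that check behaves, assuming the config is parsed with OmegaConf as elsewhere in anomalib: with the block commented out (the new default) the key is simply absent, and uncommenting it restores the callback.

from omegaconf import OmegaConf

# Default config: the early_stopping block is commented out, so the key is absent.
default_cfg = OmegaConf.create({"model": {"name": "draem", "lr": 0.0001}})
print("early_stopping" in default_cfg.model)   # False -> no EarlyStopping callback is created

# Uncommenting the block in config.yaml makes the key visible again.
enabled_cfg = OmegaConf.create(
    {"model": {"early_stopping": {"patience": 20, "metric": "pixel_AUROC", "mode": "max"}}}
)
print("early_stopping" in enabled_cfg.model)   # True -> EarlyStopping(monitor="pixel_AUROC", ...) is added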
