
[Bug]: Training EfficientAd failed: ZeroDivisionError: integer division or modulo by zero #1679

Closed
glucasol opened this issue Jan 26, 2024 · 16 comments · Fixed by #1705

Comments


glucasol commented Jan 26, 2024

Describe the bug

Hi, I have tried to run the following example with the latest anomalib version:

from anomalib.data import MVTec
from anomalib.models import EfficientAd
from anomalib.engine import Engine

from anomalib.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping

datamodule = MVTec()
model = EfficientAd()
engine = Engine()

engine.fit(datamodule=datamodule, model=model)

But it gives the following error:

ZeroDivisionError                         Traceback (most recent call last)

<ipython-input-7-4052201ff8ac> in <cell line: 3>()
      1 # start training
      2 engine = Engine(task=TaskType.SEGMENTATION)
----> 3 engine.fit(model=model, datamodule=datamodule)

15 frames

/content/anomalib/src/anomalib/engine/engine.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    354             self.trainer.validate(model, val_dataloaders, datamodule=datamodule, ckpt_path=ckpt_path)
    355         else:
--> 356             self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    357 
    358     def validate(

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542         self.state.status = TrainerStatus.RUNNING
    543         self.training = True
--> 544         call._call_and_handle_interrupt(
    545             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546         )

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         if trainer.strategy.launcher is not None:
     43             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44         return trainer_fn(*args, **kwargs)
     45 
     46     except _TunerExitException:

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    578             model_connected=self.lightning_module is not None,
    579         )
--> 580         self._run(model, ckpt_path=ckpt_path)
    581 
    582         assert self.state.stopped

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _run(self, model, ckpt_path)
    987         # RUN THE TRAINER
    988         # ----------------------------
--> 989         results = self._run_stage()
    990 
    991         # ----------------------------

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _run_stage(self)
   1033                 self._run_sanity_check()
   1034             with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035                 self.fit_loop.run()
   1036             return None
   1037         raise RuntimeError(f"Unexpected state {self.state}")

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py in run(self)
    200             try:
    201                 self.on_advance_start()
--> 202                 self.advance()
    203                 self.on_advance_end()
    204                 self._restarting = False

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py in advance(self)
    357         with self.trainer.profiler.profile("run_training_epoch"):
    358             assert self._data_fetcher is not None
--> 359             self.epoch_loop.run(self._data_fetcher)
    360 
    361     def on_advance_end(self) -> None:

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py in run(self, data_fetcher)
    134         while not self.done:
    135             try:
--> 136                 self.advance(data_fetcher)
    137                 self.on_advance_end(data_fetcher)
    138                 self._restarting = False

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py in advance(self, data_fetcher)
    248         self.update_lr_schedulers("step", update_plateau_schedulers=False)
    249         if self._num_ready_batches_reached():
--> 250             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
    251 
    252         if using_dataloader_iter:

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py in update_lr_schedulers(self, interval, update_plateau_schedulers)
    335         if interval == "step" and self._should_accumulate():
    336             return
--> 337         self._update_learning_rates(interval=interval, update_plateau_schedulers=update_plateau_schedulers)
    338 
    339     def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool) -> None:

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py in _update_learning_rates(self, interval, update_plateau_schedulers)
    386 
    387                 # update LR
--> 388                 call._call_lightning_module_hook(
    389                     trainer,
    390                     "lr_scheduler_step",

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    155 
    156     with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157         output = fn(*args, **kwargs)
    158 
    159     # restore current_fx when nested context

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py in lr_scheduler_step(self, scheduler, metric)
   1246         """
   1247         if metric is None:
-> 1248             scheduler.step()  # type: ignore[call-arg]
   1249         else:
   1250             scheduler.step(metric)

/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py in step(self, epoch)
    145             if epoch is None:
    146                 self.last_epoch += 1
--> 147                 values = self.get_lr()
    148             else:
    149                 warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)

/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py in get_lr(self)
    385                           "please use `get_last_lr()`.", UserWarning)
    386 
--> 387         if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
    388             return [group['lr'] for group in self.optimizer.param_groups]
    389         return [group['lr'] * self.gamma

ZeroDivisionError: integer division or modulo by zero

Dataset

MVTec

Model

Other (please specify in the field below)

Steps to reproduce the behavior

Model: EfficientAd

OS information

  • Google Colab

Expected behavior

Train the model successfully.

Screenshots

No response

Pip/GitHub

pip

What version/branch did you use?

Branch: main
anomalib version: v1

Configuration YAML

Default

Logs

Default

Code of Conduct

  • I agree to follow this project's Code of Conduct

isaacdominguez commented Jan 29, 2024

Hi @glucasol, it might not be the fix, but you have declared Padim instead of EfficientAD. I have also run into this bug when I add auto_find_lr to the trainer parameters in an anomalib >0.7.0 config file. Still testing before I open a new issue.

@glucasol (Author)

Hi @isaacdominguez, thanks for the reply! My mistake, I had pasted the wrong import in the issue, but it is fixed now. The error is not due to the import. I will check the auto_find_lr parameter, thanks!

@samet-akcay any idea what it could be?


isaacdominguez commented Jan 29, 2024

OK, so I found it was not auto_find_lr. Doing some tests today I hit the same error:

  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 221, in advance
    self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 395, in update_lr_schedulers
    self._update_learning_rates(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 456, in _update_learning_rates
    self.trainer._call_lightning_module_hook(
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1356, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/module.py", line 1671, in lr_scheduler_step
    scheduler.step()  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py", line 147, in step
    values = self.get_lr()
  File "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py", line 387, in get_lr
    if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
ZeroDivisionError: integer division or modulo by zero
Epoch 0:  79%|███████▉  | 2756/3478 [2:00:49<31:39,  2.63s/it, loss=12, v_num=e833, train_st_step=26.90, train_ae_step=1.910, train_stae_step=0.253, train_loss_step=29.10] 

Still don't know why, but this only happened to me with the EfficientAD model.
BTW @glucasol, try to fill in the branch/release in the issue template, so the developers have more info on where this is happening.


samet-akcay commented Jan 29, 2024

I'm a bit occupied with some other tasks, will try to have a look asap


blaz-r commented Jan 29, 2024

Judging from the issue, the problem is in step_size, which is equal to 0. I think I encountered this once when I wanted to train EfficientAD for only 2 epochs, so something is probably wrong with the epoch setup.


blaz-r commented Jan 29, 2024

Something probably goes wrong here:

def configure_optimizers(self) -> torch.optim.Optimizer:
    """Configure optimizers."""
    optimizer = torch.optim.Adam(
        list(self.model.student.parameters()) + list(self.model.ae.parameters()),
        lr=self.lr,
        weight_decay=self.weight_decay,
    )
    num_steps = min(
        self.trainer.max_steps,
        self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()),
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.95 * num_steps), gamma=0.1)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}

so step_size is set to 0, which probably means num_steps is 0 or negative, which in turn means something is not right with the trainer's max_steps and max_epochs.

@glucasol (Author)

@blaz-r Thanks for the reply!
You're right, step_size is set to 0 because trainer.max_steps keeps its default value (-1), so num_steps = min(-1, ...) = -1 and step_size = int(0.95 * (-1)) = 0.
Any reason for num_steps to be the minimum of trainer.max_steps and trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())?
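For reference, a minimal standalone sketch (plain PyTorch, not anomalib code) that reproduces the same crash once step_size collapses to 0:

import torch

# With the Lightning defaults, num_steps = min(max_steps=-1, ...) stays at -1,
# and int(0.95 * -1) truncates toward zero.
step_size = int(0.95 * -1)
print(step_size)  # 0

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)

optimizer.step()
scheduler.step()  # ZeroDivisionError: integer division or modulo by zero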


blaz-r commented Jan 30, 2024

It's due to the way EfficientAD training is specified, so it's set like this in the config:
https://github.com/openvinotoolkit/anomalib/blob/main/src/configs/model/efficient_ad.yaml#L16-L18
The default config has this, but it causes issues if you use the model directly.


NilsB98 commented Jan 30, 2024

Just ran into the same issue and the solution that @blaz-r mentioned solves it.
Using the API instead of the CLI you can then adjust your Engine like this:
engine = Engine(max_epochs=200, max_steps=70_000)
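Putting that together with the repro from the top of the thread, a minimal sketch of the full workaround (the exact limits are up to you):

from anomalib.data import MVTec
from anomalib.models import EfficientAd
from anomalib.engine import Engine

datamodule = MVTec()
model = EfficientAd()

# Finite limits let EfficientAd's configure_optimizers() compute a non-zero
# StepLR step_size instead of int(0.95 * -1) == 0.
engine = Engine(max_epochs=200, max_steps=70_000)

engine.fit(datamodule=datamodule, model=model)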

@glucasol (Author)

For me the solution @blaz-r mentioned worked too. Thank you guys!


blaz-r commented Jan 31, 2024

Yeah, maybe we should add a check for that somewhere in code. @alexriedel1 what do you think would be the best option here?

@alexriedel1 (Contributor)

@blaz-r Thanks for the reply! You're right, step_size is set to 0 because trainer.max_steps keeps its default value (-1), so num_steps = min(-1, ...) = -1 and step_size = int(0.95 * (-1)) = 0. Any reason for num_steps to be the minimum of trainer.max_steps and trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())?

Without setting any steps or epochs, the trainer will default to max_epochs = 1000
(https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api)

The easiest way (which would also make the model a bit more epoch-agnostic) would be to use max instead of min in the configure_optimizers snippet above.

However, the original paper says they use 70k steps for training. From my point of view this was an arbitrary choice for training on typical anomaly datasets. On real-world datasets, completely different numbers of steps and epochs might be necessary, so I have no bad feeling about using the number of epochs instead of steps.


blaz-r commented Jan 31, 2024

Thanks for the answer.
I think that using max would break the mechanism, though. If I understand this correctly, the point is to schedule the LR step at 95% of the total number of training steps.
Because of how the Trainer works, it stops at whichever limit is reached first (max_steps or max_epochs), so for the scheduler (which counts steps) to work as intended, we need the actual number of steps the model will run. In code this translates to the minimum of max_steps and epochs * "steps in each epoch".

If we used max here, and say max_epochs = 100, max_steps = 10, and "steps in each epoch" = 4, then max would evaluate to 100 * 4 = 400, but training would stop at max_steps = 10, which means the scheduler would never make the reduction.
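Plugging those numbers in makes the difference concrete:

# max_epochs = 100, max_steps = 10, steps per epoch = 4
print(min(10, 100 * 4))  # 10  -> the LR step lands inside the run, which stops at step 10
print(max(10, 100 * 4))  # 400 -> the LR step would land long after training has stopped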

So, with all this information, I would actually say we should add an if statement like this:

max_steps = self.trainer.max_steps
max_steps_from_epochs = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())

if max_steps == -1 or max_steps_from_epochs == -1:
    num_steps = max(max_steps, max_steps_from_epochs)
else:
    num_steps = min(max_steps, max_steps_from_epochs)

which would take the minimum of the two if both are defined. If only one is defined, max picks the defined one. If by any chance infinite training is specified, this would fail again, so it would need another guard, but I'm not sure it's realistic to expect infinite training here?

@alexriedel1 (Contributor)

Ah yes you're right!

We want to recreate this behaviour of pytorch lightning:

max_epochs (Optional[int]) – Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1.

max_steps (int) – Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = None, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1.

if self.trainer.max_epochs < 0:
    raise ValueError("A finite number of steps or epochs must be defined")

num_steps = min(
    self.trainer.max_steps,
    self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader()),
)

if self.trainer.max_steps == -1:
    num_steps = self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())
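A quick standalone check of how that logic behaves (the helper name and the steps-per-epoch value are made up for illustration; the branches mirror the snippet above):

def resolve_num_steps(max_steps: int, max_epochs: int, steps_per_epoch: int) -> int:
    """Mirror of the proposed guard, with the trainer attributes passed in explicitly."""
    if max_epochs < 0:
        raise ValueError("A finite number of steps or epochs must be defined")
    num_steps = min(max_steps, max_epochs * steps_per_epoch)
    if max_steps == -1:
        num_steps = max_epochs * steps_per_epoch
    return num_steps

# Lightning defaults (max_steps=-1, max_epochs=1000) no longer collapse to -1:
print(resolve_num_steps(-1, 1000, 240))      # 240000
# Both limits set, e.g. the 70k steps from the paper:
print(resolve_num_steps(70_000, 1000, 240))  # 70000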


blaz-r commented Jan 31, 2024

Yeah. This would probably be the best solution. Do you want to make a PR or should I?

@alexriedel1 (Contributor)

Yeah. This would probably be the best solution. Do you want to make a PR or should I?

you can go for it! thanks!
