🐛 Hook orders are different to what is documented
Description
The documentation describes the order of hook calls in the train loop as follows:
def train_loop():
    on_train_epoch_start()
    train_outs = []
    for train_batch in train_dataloader():
        on_train_batch_start()

        out = training_step(batch)
        train_outs.append(out)

        loss = out.loss
        backward()
        on_after_backward()
        optimizer_step()
        on_before_zero_grad()
        optimizer_zero_grad()

        on_train_batch_end(out)
The documentation furthermore says in the description of on_after_backward:
Called in the training loop after loss.backward() and before optimizers do anything. This is the ideal place to inspect or log gradient information.
For on_before_zero_grad it says:
Called after optimizer.step() and before optimizer.zero_grad().
Both of these match the training loop defined above.
However, if I use these methods as in the code below, on_before_zero_grad is always called before on_after_backward.
Reproduction
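The following is a minimal, self-contained example (a toy model rather than my actual training code, so the module name and shapes are placeholders) that prints from both hooks so their call order shows up in the output:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class HookOrderModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def on_after_backward(self):
        # According to the docs, this fires after loss.backward()
        # and before the optimizer does anything.
        print("on_after_backward")

    def on_before_zero_grad(self, optimizer):
        # According to the docs, this fires after optimizer.step()
        # and before optimizer.zero_grad().
        print("on_before_zero_grad")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=16)
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=4)
    trainer.fit(HookOrderModel(), data)

I would expect on_after_backward to be printed before on_before_zero_grad for every batch, but I see the opposite order.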
Expected behavior
The methods should be called in the order described in the documentation.
Environment
environment.yml file:
name: hook-order
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- python>=3.7.0
- pytorch>=1.8.0
- torchvision>=0.9.0
- cudatoolkit>=11.1
- scipy
- torchcsprng
- pytest
- mypy
- black
- scikit-learn
- pytorch-lightning
- matplotlib
- rope
- pip
- pip:
  - testfixtures
  - segmentation-models-pytorch
- PyTorch Version (e.g., 1.0): 1.8.0
- OS (e.g., Linux): Ubuntu 20.04
- How you installed PyTorch (conda, pip, source): conda env create -f environment.yml
- Python version: 3.9.2
Additional context
I just need a method that is called right after optimizer_step(), so if there is any alternative, please let me know.
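One workaround I am considering (just my own sketch, not something the documentation suggests; the exact optimizer_step signature differs between Lightning versions, which is why I only forward *args/**kwargs) would be to override optimizer_step in the module above and run my logic after the parent call:

    def optimizer_step(self, *args, **kwargs):
        # Let Lightning perform the actual optimizer step first.
        super().optimizer_step(*args, **kwargs)
        # Anything placed here runs right after the optimizer has stepped,
        # e.g. inspecting or logging parameter statistics.
        print("right after optimizer_step")

If there is a cleaner hook for this, that would of course be preferable.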
Thanks in advance.