Skip to content

Commit

Permalink
update indicator for wrapped lr scheduler, in align with torch 2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Aug 1, 2024
1 parent 658d8e5 commit 9cc0e80
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
20 changes: 17 additions & 3 deletions fl_sim/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
from pathlib import Path
from typing import Any, Iterable, Union

import packaging.version
import torch.optim as opt
import torch_optimizer as topt
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from torch.torch_version import __version__ as torch_version
from torch_ecg.cfg import CFG
from torch_ecg.utils import add_docstring

Expand Down Expand Up @@ -140,7 +142,13 @@ def get_optimizer(
# ``Detected call of `lr_scheduler.step()` before `optimizer.step()`.``.
# The risk is one has to check that scheduler.step() is called after
# optimizer.step() in the training loop by himself.
optimizer.step._with_counter = True
if packaging.version.parse(torch_version) < packaging.version.parse("2.4.0"):
optimizer.step._with_counter = True
else:
# NOTE: new in torch 2.4.0,
# the check by `optimizer.step._with_counter` is replaced by
# `optimizer.step._wrapped_by_lr_sched`
optimizer.step._wrapped_by_lr_sched = True
return optimizer

try:
Expand All @@ -153,7 +161,10 @@ def get_optimizer(
optimizer.step,
**{k: v for k, v in _extra_kwargs.items() if k not in step_args},
)
optimizer.step._with_counter = True
if packaging.version.parse(torch_version) < packaging.version.parse("2.4.0"):
optimizer.step._with_counter = True
else:
optimizer.step._wrapped_by_lr_sched = True
# print(f"optimizer_name: {optimizer_name}")
return optimizer
except Exception:
Expand All @@ -170,7 +181,10 @@ def get_optimizer(
optimizer.step,
**{k: v for k, v in _extra_kwargs.items() if k not in step_args},
)
optimizer.step._with_counter = True
if packaging.version.parse(torch_version) < packaging.version.parse("2.4.0"):
optimizer.step._with_counter = True
else:
optimizer.step._wrapped_by_lr_sched = True
return optimizer
except Exception:
pass
Expand Down
1 change: 1 addition & 0 deletions requirements-no-torch.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ bib-lookup>=0.0.19
deprecate-kwargs
torch-ecg>=0.0.26
PyYAML
packaging
1 change: 1 addition & 0 deletions requirements-viz.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ seaborn
ipywidgets
ipython
termcolor
packaging
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ bib-lookup>=0.0.19
deprecate-kwargs
torch-ecg>=0.0.26
PyYAML
packaging

0 comments on commit 9cc0e80

Please sign in to comment.