[hotfix] Extend Optimizer + update doc (#5095)
* resolve urgent bug

* update pr

* update doc

* update

* remove typo

* add defaults

* Update pytorch_lightning/__init__.py

* Update setup.py

* update doc

* Update docs/source/optimizers.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update

* resolve doc

* debug test

* update test

* Update docs/source/optimizers.rst

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update docs/source/optimizers.rst

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update docs/source/optimizers.rst

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* remove useless import

* Update docs/source/optimizers.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people authored Dec 11, 2020
1 parent 74171ef commit 1a970b2
Showing 7 changed files with 134 additions and 85 deletions.
49 changes: 36 additions & 13 deletions docs/source/optimizers.rst
@@ -191,46 +191,69 @@ override the :meth:`optimizer_step` function.
For example, here we step optimizer A every 2 batches and optimizer B every 4 batches
.. testcode::
.. note:: When using ``Trainer(enable_pl_optimizer=True)``, there is no need to call ``.zero_grad()``.
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
optimizer.step()
.. testcode::
def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
optimizer.zero_grad()
# Alternating schedule for optimizer steps (ie: GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
# update generator opt every 2 steps
if optimizer_idx == 0:
if batch_nb % 2 == 0:
optimizer.step()
optimizer.zero_grad()
optimizer.step(closure=closure)
# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_nb % 4 == 0:
optimizer.step()
optimizer.zero_grad()
optimizer.step(closure=closure)
.. note:: When using ``Trainer(enable_pl_optimizer=True)``, ``.step`` accepts a boolean ``make_optimizer_step`` which can be used as follows.
.. testcode::
def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
optimizer.zero_grad()
# Alternating schedule for optimizer steps (ie: GANs)
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
# update generator opt every 2 steps
if optimizer_idx == 0:
optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)
# ...
# add as many optimizers as you want
# update discriminator opt every 4 steps
if optimizer_idx == 1:
optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)
Here we add a learning-rate warm-up
.. testcode::
# learning rate warm-up
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
for pg in optimizer.param_groups:
pg['lr'] = lr_scale * self.hparams.learning_rate
# update params
optimizer.step()
optimizer.zero_grad()
optimizer.step(closure=closure)
The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to perform the step properly.
.. testcode::
from pytorch_lightning.core.optimizer import LightningOptimizer
# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
if not isinstance(optimizer, LightningOptimizer):
# wrap into a LightningOptimizer only for running the step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
optimizer.step(closure=closure)
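For context, here is how the updated hook could fit into a full LightningModule; a minimal sketch assuming ``Trainer(enable_pl_optimizer=True)`` as described in the note above (the module name, layer sizes and learning rates are illustrative only, and ``training_step``/data loading are omitted):

    import torch
    from pytorch_lightning import LightningModule

    class AlternatingOptimizersModel(LightningModule):  # illustrative module, not part of this commit
        def __init__(self):
            super().__init__()
            self.generator = torch.nn.Linear(32, 32)
            self.discriminator = torch.nn.Linear(32, 1)

        def configure_optimizers(self):
            opt_g = torch.optim.SGD(self.generator.parameters(), lr=0.1)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.01)
            return [opt_g, opt_d]

        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
                           on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # the closure wraps the training step and backward pass;
            # make_optimizer_step decides whether the parameters are actually updated
            if optimizer_idx == 0:  # generator: step every 2 batches
                optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 2) == 0)
            if optimizer_idx == 1:  # discriminator: step every 4 batches
                optimizer.step(closure=closure, make_optimizer_step=(batch_nb % 4) == 0)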
----------
4 changes: 1 addition & 3 deletions pytorch_lightning/core/lightning.py
@@ -1170,7 +1170,6 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):

def optimizer_step(
self,
*args,
epoch: int = None,
batch_idx: int = None,
optimizer: Optimizer = None,
@@ -1179,7 +1178,6 @@ def optimizer_step(
on_tpu: bool = None,
using_native_amp: bool = None,
using_lbfgs: bool = None,
**kwargs,
) -> None:
r"""
Override this method to adjust the default way the
@@ -1254,7 +1252,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
if not isinstance(optimizer, LightningOptimizer):
# wrap into a LightningOptimizer only for running the step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
optimizer.step(closure=optimizer_closure, *args, **kwargs)
optimizer.step(closure=optimizer_closure)

def optimizer_zero_grad(
self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
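For readers less familiar with closure-based stepping, here is a rough standalone PyTorch sketch of what passing a closure to optimizer.step means; the toy model, data and loss below are placeholders and not part of this commit:

    import torch

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 2), torch.randn(8, 1)

    def closure():
        # stand-in for Lightning's train_step_and_backward_closure:
        # compute the loss, backpropagate, and return the loss
        optimizer.zero_grad()  # Lightning handles this via its optimizer_zero_grad hook instead
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    # torch optimizers invoke the closure themselves and then update the parameters
    optimizer.step(closure)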
34 changes: 26 additions & 8 deletions pytorch_lightning/core/optimizer.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Optional
from weakref import proxy
@@ -58,12 +57,35 @@ def __init__(self,
else:
self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

self._trainer = None
self._optimizer = optimizer
self._trainer = None
self._accumulate_grad_batches = accumulate_grad_batches
self._support_closure = 'closure' in inspect.signature(optimizer.step).parameters
self._optimizer_idx = None

@property
def defaults(self):
return self._optimizer.defaults

@defaults.setter
def defaults(self, defaults):
self._optimizer.defaults = defaults

@property
def state(self):
return self._optimizer.state

@state.setter
def state(self, state):
self._optimizer.state = state

@property
def param_groups(self):
return self._optimizer.param_groups

@param_groups.setter
def param_groups(self, param_groups):
self._optimizer.param_groups = param_groups

@property
def accumulate_grad_batches(self):
return self._accumulate_grad_batches
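The new defaults, state and param_groups properties simply forward to the wrapped optimizer, so code that inspects or mutates a plain torch.optim optimizer keeps working on the wrapper. A small sketch of the behaviour exercised by the updated test_state further down (the concrete optimizer and learning rate are arbitrary):

    import torch
    from pytorch_lightning.core.optimizer import LightningOptimizer

    model = torch.nn.Linear(3, 4)
    optimizer = torch.optim.Adam(model.parameters())
    lightning_optimizer = LightningOptimizer(optimizer)

    # reads are delegated to the underlying optimizer
    assert lightning_optimizer.defaults == optimizer.defaults
    assert lightning_optimizer.state == optimizer.state
    assert lightning_optimizer.param_groups == optimizer.param_groups

    # mutations through the wrapper reach the underlying optimizer, e.g. for LR warm-up
    for pg in lightning_optimizer.param_groups:
        pg["lr"] = 0.05
    assert optimizer.param_groups[0]["lr"] == 0.05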
@@ -111,11 +133,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n

else:
with trainer.profiler.profile(profiler_name):
if self._support_closure:
optimizer.step(closure=closure, *args, **kwargs)
else:
closure()
optimizer.step(*args, **kwargs)
optimizer.step(closure=closure, *args, **kwargs)

accelerator_backend = trainer.accelerator_backend
if accelerator_backend is not None and accelerator_backend.rpc_enabled:
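With the _support_closure check gone, __optimizer_step now always forwards the closure, so a custom optimizer used with Lightning is expected to accept a closure argument in its step method; this appears to be why the wrong-interface test further down is removed. A hypothetical sketch of a conforming custom optimizer:

    import torch

    class MyOptimizer(torch.optim.SGD):  # hypothetical custom optimizer
        def step(self, closure=None):
            # Lightning now always calls step(closure=...), so accept and run it
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()
            super().step()
            return loss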
14 changes: 6 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -477,7 +477,7 @@ def _process_result(self, training_step_output, split_batch):

return training_step_output_for_epoch_end

def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
model_ref = self.trainer.get_model()

is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
@@ -491,16 +491,14 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_

# model hook
model_ref.optimizer_step(
epoch=self.trainer.current_epoch,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=opt_idx,
optimizer_closure=train_step_and_backward_closure,
self.trainer.current_epoch,
batch_idx,
optimizer,
opt_idx,
train_step_and_backward_closure,
on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
*args,
**kwargs,
)

def on_before_zero_grad(self, optimizer):
44 changes: 44 additions & 0 deletions tests/core/test_lightning_module.py
@@ -103,3 +103,47 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
assert sgd_zero_grad.call_count == 4
assert adam_step.call_count == 2
assert adam_zero_grad.call_count == 2


@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_params_groups_and_state_are_accessible(enable_pl_optimizer, tmpdir):

with patch("torch.optim.SGD.step") as sgd_step, \
patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
patch("torch.optim.Adam.step") as adam_step, \
patch("torch.optim.Adam.zero_grad") as adam_zero_grad:

class TestModel(BoringModel):

def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
on_tpu=False, using_native_amp=False, using_lbfgs=False):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
for pg in optimizer.param_groups:
pg['lr'] = lr_scale * 0.01

optimizer.step(closure=closure)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
limit_train_batches=8,
accumulate_grad_batches=1,
enable_pl_optimizer=enable_pl_optimizer
)

trainer.fit(model)
68 changes: 18 additions & 50 deletions tests/core/test_lightning_optimizer.py
@@ -193,12 +193,29 @@ def test_state(tmpdir):
model = torch.nn.Linear(3, 4)
optimizer = torch.optim.Adam(model.parameters())
lightning_optimizer = LightningOptimizer(optimizer)

# test state
assert optimizer.state == lightning_optimizer.state
lightning_optimizer.state = optimizer.state
assert optimizer.state == lightning_optimizer.state

# test param_groups
assert optimizer.param_groups == lightning_optimizer.param_groups
lightning_optimizer.param_groups = optimizer.param_groups
assert optimizer.param_groups == lightning_optimizer.param_groups

# test defaults
assert optimizer.defaults == lightning_optimizer.defaults
lightning_optimizer.defaults = optimizer.defaults
assert optimizer.defaults == lightning_optimizer.defaults

assert isinstance(lightning_optimizer, LightningOptimizer)
assert isinstance(lightning_optimizer, Adam)
assert isinstance(lightning_optimizer, Optimizer)
lightning_dict = {}
special_attrs = ["_accumulate_grad_batches", "_optimizer", "_optimizer_idx", "_support_closure",
"_trainer"]
"_trainer", "__getstate__", "__setstate__", "state_dict", "load_state_dict",
"zero_grad", "__setstate__", "add_param_group"]
for k, v in lightning_optimizer.__dict__.items():
if k not in special_attrs:
lightning_dict[k] = v
Expand All @@ -207,55 +224,6 @@ def test_state(tmpdir):
assert optimizer.state == lightning_optimizer.state


def test_lightning_optimizer_with_wrong_optimizer_interface(tmpdir):
class OptimizerWrapper(object):
def __init__(self, optimizer):
self.optim = optimizer
self.state_dict = self.optim.state_dict
self.load_state_dict = self.optim.load_state_dict
self.zero_grad = self.optim.zero_grad
self.add_param_group = self.optim.add_param_group
self.__setstate__ = self.optim.__setstate__
self.__getstate__ = self.optim.__getstate__
self.__repr__ = self.optim.__repr__

@property
def __class__(self):
return Optimizer

@property
def state(self):
return self.optim.state

@property
def param_groups(self):
return self.optim.param_groups

@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value

def step(self):
# wrongly defined step. Should contain closure
self.optim.step(closure=None)

class TestLightningOptimizerModel(BoringModel):

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.1)
optimizer = OptimizerWrapper(optimizer)
return [optimizer]

model = TestLightningOptimizerModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model)


def test_lightning_optimizer_automatic_optimization(tmpdir):
"""
Test that LightningOptimizer works with make_optimizer_step in automatic_optimization
6 changes: 3 additions & 3 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -855,7 +855,7 @@ def automatic_optimization(self) -> bool:
)

trainer.fit(model)
expected_calls = [call() for s in range(2)]
expected_calls = [call(closure=ANY) for s in range(2)]
step_mock.assert_has_calls(expected_calls)


@@ -933,9 +933,9 @@ def automatic_optimization(self) -> bool:
)

trainer.fit(model)
expected_calls = [call(optim='sgd') for s in range(4)]
expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
mock_sgd_step.assert_has_calls(expected_calls)
expected_calls = [call() for s in range(2)]
expected_calls = [call(closure=ANY) for s in range(2)]
mock_adam_step.assert_has_calls(expected_calls)
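The assertions change because optimizer.step is now always invoked with a closure keyword, and unittest.mock's ANY matches whatever closure object was passed on each call. A minimal self-contained illustration, separate from the test suite:

    from unittest.mock import ANY, MagicMock, call

    step_mock = MagicMock()

    # simulate two training steps: the loop now always passes the closure keyword
    for _ in range(2):
        step_mock(closure=lambda: None)

    # call(closure=ANY) matches each recorded call regardless of the closure object
    step_mock.assert_has_calls([call(closure=ANY) for _ in range(2)])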

