Merged
59 commits
fbebccb
update
tchaton Dec 17, 2020
f84085c
clean test
tchaton Dec 17, 2020
a309878
still in progress
tchaton Dec 17, 2020
ae08761
udpdate test
tchaton Dec 17, 2020
3ef910f
Merge branch 'master' into bugfix/5165_enable_pl_optimizer
tchaton Dec 17, 2020
f5a5d1e
update
tchaton Dec 17, 2020
7edec88
Merge branch 'bugfix/5165_enable_pl_optimizer' of https://github.com/…
tchaton Dec 17, 2020
b4181ea
update
tchaton Dec 17, 2020
be48064
resolve flake
tchaton Dec 17, 2020
379d2be
add test for zero_grad
tchaton Dec 17, 2020
fd51f32
update
tchaton Dec 17, 2020
05a838e
works without accumulated_grad
tchaton Dec 17, 2020
82c2602
update
tchaton Dec 17, 2020
386f6d4
update
tchaton Dec 17, 2020
b007c9d
resolve amp
tchaton Dec 17, 2020
5007e68
Merge branch 'master' into bugfix/5165_enable_pl_optimizer
tchaton Dec 17, 2020
88c5c63
revert back to True
tchaton Dec 17, 2020
3accce3
Merge branch 'bugfix/5165_enable_pl_optimizer' of https://github.com/…
tchaton Dec 17, 2020
7fc56ee
update
tchaton Dec 18, 2020
8d13893
clean tests
tchaton Dec 18, 2020
e7abee6
cleaned out
tchaton Dec 18, 2020
14475e7
typo
tchaton Dec 18, 2020
b47db5e
update test
tchaton Dec 18, 2020
6a79921
git repare bug
tchaton Dec 18, 2020
c106828
remove print
tchaton Dec 18, 2020
85e4e96
udpate
tchaton Dec 18, 2020
40f7c54
Fix formatting/optimizer imports
Dec 18, 2020
e6f9945
Refactor the test for cleanliness
Dec 18, 2020
9d4fd68
Add vanilla model to the test, better var names
Dec 18, 2020
f71ce5d
Fixed var names, let's clean up these mock tests
Dec 18, 2020
5c98b0f
repare test
Dec 19, 2020
cfd63ea
update test
Dec 19, 2020
ca6c184
resolve flake8
tchaton Dec 19, 2020
a8c0c20
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 19, 2020
6b5af8b
add manual_optimization
tchaton Dec 19, 2020
feaa861
Merge branch 'bugfix/5165_enable_pl_optimizer_refactor' of https://gi…
tchaton Dec 19, 2020
c1e9d14
update tests
Dec 19, 2020
1352a49
resolve flake8
tchaton Dec 19, 2020
c0afb3b
add random accumulate_grad_batches
Dec 19, 2020
2d8b9bb
Merge branch 'bugfix/5165_enable_pl_optimizer_refactor' of https://gi…
tchaton Dec 19, 2020
9a43d8e
improve test
tchaton Dec 19, 2020
12b3554
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Dec 19, 2020
a126e56
Update tests/trainer/optimization/test_parity_automatic_optimization.py
tchaton Dec 19, 2020
9d083d5
update
tchaton Dec 19, 2020
7126b2d
clean tests
tchaton Dec 19, 2020
b6c7ad0
correct bug
Dec 19, 2020
f5ec5f5
Apply suggestions from code review
Borda Dec 19, 2020
a9c1f7e
format
Borda Dec 19, 2020
151790d
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 20, 2020
b33ee49
adress comments
tchaton Dec 20, 2020
196d8b4
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 20, 2020
1677b6c
Merge branch 'bugfix/5165_enable_pl_optimizer_refactor' of https://gi…
tchaton Dec 20, 2020
02ded96
update on comments
tchaton Dec 21, 2020
94d3b4b
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 21, 2020
1e8a11e
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 21, 2020
6e68e31
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 21, 2020
47d047c
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 22, 2020
05678f5
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 23, 2020
68ed65e
Merge branch 'master' into bugfix/5165_enable_pl_optimizer_refactor
tchaton Dec 23, 2020
13 changes: 5 additions & 8 deletions pytorch_lightning/core/lightning.py
@@ -14,15 +14,15 @@

"""nn.Module with additional great features."""

from abc import ABC
from argparse import Namespace
import collections
import copy
import inspect
import os
from pathlib import Path
import re
import tempfile
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
@@ -35,9 +35,9 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
@@ -1252,9 +1252,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer.zero_grad()

"""
if not isinstance(optimizer, LightningOptimizer):
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
optimizer.step(closure=optimizer_closure)

def optimizer_zero_grad(
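With the wrapping removed from the default hook (the training loop now wraps the optimizer before invoking it, see `training_loop.py` below), an override of `optimizer_step` receives a `LightningOptimizer` directly. A minimal sketch of what such an override could look like after this change; the class name, warm-up logic, and constants are illustrative and not part of this PR:

```python
from pytorch_lightning import LightningModule


class WarmupModel(LightningModule):
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu=False,
                       using_native_amp=False, using_lbfgs=False):
        # `optimizer` already arrives as a LightningOptimizer here, so the override
        # needs no manual wrapping and no manual zero_grad (see the updated test below).
        if self.trainer.global_step < 500:
            # illustrative linear learning-rate warm-up over the first 500 steps;
            # assumes the wrapper exposes `param_groups` like the raw optimizer
            scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = scale * 1e-3
        optimizer.step(closure=optimizer_closure)
```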
2 changes: 2 additions & 0 deletions pytorch_lightning/core/optimizer.py
@@ -103,6 +103,8 @@ def _on_trainer_init(self, trainer):

@classmethod
def to_lightning_optimizer(cls, optimizer, trainer):
if isinstance(optimizer, LightningOptimizer):
return optimizer
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
return optimizer
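The two added lines make `to_lightning_optimizer` idempotent: an optimizer that is already wrapped (for example because the training loop wrapped it before handing it to a hook) is returned unchanged instead of being wrapped a second time. A standalone sketch of the same guard pattern, not the library code itself:

```python
import torch


class _Wrapper:
    """Toy stand-in for an optimizer wrapper with an idempotent factory."""

    def __init__(self, optimizer):
        self._optimizer = optimizer

    @classmethod
    def wrap(cls, optimizer, trainer=None):
        if isinstance(optimizer, cls):
            return optimizer          # already wrapped: return as-is
        wrapped = cls(optimizer)
        # trainer-specific initialisation would happen here
        return wrapped


base = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
once = _Wrapper.wrap(base)
twice = _Wrapper.wrap(once)        # second call is a no-op
assert once is twice
assert twice._optimizer is base    # inner optimizer is still the raw torch.optim.SGD
```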
6 changes: 5 additions & 1 deletion pytorch_lightning/plugins/native_amp.py
@@ -16,6 +16,7 @@
import torch
from torch.optim import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin


@@ -52,7 +53,10 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

# unscale gradient to allow analyze within `on_after_backward`
if not self.trainer.train_loop.should_accumulate() and automatic_optimization:
self.trainer.scaler.unscale_(optimizer)
if isinstance(optimizer, LightningOptimizer):
self.trainer.scaler.unscale_(optimizer._optimizer)
else:
self.trainer.scaler.unscale_(optimizer)

return closure_loss

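The unwrap matters because `torch.cuda.amp.GradScaler` keeps its per-optimizer state keyed on the exact object it is handed; passing the raw `torch.optim.Optimizer` (`optimizer._optimizer`) rather than the `LightningOptimizer` wrapper keeps that bookkeeping consistent with the scaler calls made later in the step. A hedged sketch of the unwrap pattern used above; the helper name is hypothetical:

```python
import torch


def unscale_gradients(scaler: "torch.cuda.amp.GradScaler", optimizer) -> None:
    """Hypothetical helper mirroring the plugin change: unscale on the raw optimizer."""
    # LightningOptimizer keeps the wrapped torch optimizer in `_optimizer`;
    # fall back to the object itself when it is not wrapped.
    inner = getattr(optimizer, "_optimizer", optimizer)
    scaler.unscale_(inner)
```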
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -15,16 +15,15 @@
"""Trainer to automate the training."""

import os
import warnings
from typing import Dict, Iterable, List, Optional, Union
import warnings

import torch
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
@@ -47,6 +46,7 @@
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -56,7 +56,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn, DeviceType
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
13 changes: 12 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -26,7 +26,7 @@
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing
from pytorch_lightning.utilities import AMPType, parsing, TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -489,6 +489,9 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')

# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)

# model hook
model_ref.optimizer_step(
self.trainer.current_epoch,
@@ -831,6 +834,8 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):

# backward can be called manually in the training loop
if isinstance(result, torch.Tensor):
# scale loss under accumulate_grad_batches > 1 and manual_backward
result = self.scale_closure_loss(result)
self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs)
else:
result.closure_loss = self.trainer.accelerator_backend.backward(
@@ -975,3 +980,9 @@ def update_running_loss(self):

# reset for next set of accumulated grads
self.accumulated_loss.reset()

def scale_closure_loss(self, loss: torch.Tensor) -> torch.Tensor:
model_ref = self.trainer.get_model()
if model_ref._running_manual_backward:
loss /= self.trainer.accumulate_grad_batches
return loss
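The new `scale_closure_loss` helper divides the loss by `accumulate_grad_batches` only while `manual_backward` is running (`_running_manual_backward` is set), so manually accumulated gradients end up averaged over the accumulation window, matching what automatic optimization already produces. A self-contained numeric check of that arithmetic with toy tensors (not Lightning code):

```python
import torch

accumulate_grad_batches = 4
w = torch.tensor([2.0], requires_grad=True)

# four "micro-batch" losses backpropagated one by one, each scaled by 1/4
for x in (1.0, 2.0, 3.0, 4.0):
    loss = (w * x).sum()
    (loss / accumulate_grad_batches).backward()   # gradients accumulate into w.grad

# gradient of the mean loss over the same four micro-batches
mean_loss = sum((w * x).sum() for x in (1.0, 2.0, 3.0, 4.0)) / accumulate_grad_batches
expected = torch.autograd.grad(mean_loss, w)[0]

assert torch.allclose(w.grad, expected)   # both equal 2.5
```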
3 changes: 2 additions & 1 deletion tests/base/boring_model.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset

from pytorch_lightning import LightningModule


class RandomDictDataset(Dataset):
def __init__(self, size, length):
10 changes: 3 additions & 7 deletions tests/core/test_lightning_module.py
@@ -11,17 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser
import pickle
from typing import Optional
from unittest.mock import MagicMock, patch

import pytest
import torch
from torch.optim import SGD, Adam
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, random_split

from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from pytorch_lightning import LightningDataModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel

@@ -75,16 +75,12 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
if batch_idx % 2 == 0:
assert isinstance(optimizer, SGD)
optimizer.step(closure=optimizer_closure)
if not enable_pl_optimizer:
optimizer.zero_grad()

# update discriminator opt every 4 steps
if optimizer_idx == 1:
if batch_idx % 4 == 0:
assert isinstance(optimizer, Adam)
optimizer.step(closure=optimizer_closure)
if not enable_pl_optimizer:
optimizer.zero_grad()

model = TestModel()
model.training_epoch_end = None
4 changes: 3 additions & 1 deletion tests/core/test_lightning_optimizer.py
@@ -14,16 +14,18 @@
import os
from unittest.mock import patch

import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, Optimizer

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden
from tests.base.boring_model import BoringModel, RandomDataset, RandomDictDataset, RandomDictStringDataset


Expand Down