Commit 3b20f2a

Merge branch 'master' into model-checkpoint-save-last-fix
2 parents 3841079 + c7f30a2 commit 3b20f2a

126 files changed, +1919 -1133 lines changed

.github/workflows/ci_test-conda.yml

Lines changed: 3 additions & 7 deletions
@@ -27,24 +27,20 @@ jobs:
       run: |
         conda info
         conda list
+        # adjust versions according installed Torch version
+        python ./requirements/adjust_versions.py requirements/extra.txt
+        python ./requirements/adjust_versions.py requirements/examples.txt
         pip install --requirement requirements/devel.txt --upgrade-strategy only-if-needed
         pip list
 
     - name: Pull checkpoints from S3
-      # todo: consider adding coma caching, but ATM all models have less then 100KB
       run: |
         # enter legacy and update checkpoints from S3
         cd legacy
         curl https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip --output checkpoints.zip
         unzip -o checkpoints.zip
         ls -l checkpoints/
 
-    # todo: require proper fix in docker image
-    - name: Hotfix dependency
-      run: |
-        pip install torchtext==0.6.0 -U
-      shell: bash
-
     - name: Tests
       run: |
         # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
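
Reviewer note: this hunk replaces the pinned `torchtext` hotfix with calls to `requirements/adjust_versions.py`, whose source is not part of the visible diff. The sketch below is only a guess at the general idea (rewriting requirement lines to versions compatible with the installed Torch release). The `COMPAT` table, the function name, and the matching rule are all assumptions made for illustration, not the script's actual contents.

```python
import re

# Hypothetical compatibility table, for illustration only; it is NOT taken
# from the commit. Keys are Torch "major.minor" releases.
COMPAT = {
    "1.8": {"torchvision": "0.9.0", "torchtext": "0.9.0"},
    "1.7": {"torchvision": "0.8.2", "torchtext": "0.8.1"},
    "1.6": {"torchvision": "0.7.0", "torchtext": "0.7.0"},
}


def adjust_requirements(text: str, torch_version: str) -> str:
    """Pin known companion packages to versions matching `torch_version`.

    Lines for packages that are not in the table (and comments or blank
    lines) are passed through unchanged.
    """
    major_minor = ".".join(torch_version.split(".")[:2])
    pins = COMPAT.get(major_minor, {})
    adjusted = []
    for line in text.splitlines():
        # The package name is the leading run of name characters.
        match = re.match(r"^([A-Za-z0-9_.-]+)", line.strip())
        name = match.group(1).lower() if match else None
        if name in pins:
            adjusted.append(f"{name}=={pins[name]}")
        else:
            adjusted.append(line)
    return "\n".join(adjusted) + "\n"


example = adjust_requirements("torchvision>=0.5\nhorovod>=0.21\n", "1.7.1+cpu")
print(example)
```

In a workflow like the one above, such a function would be applied to `requirements/extra.txt` and `requirements/examples.txt` before `pip install` runs, so that pip never has to resolve an incompatible pair.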

.github/workflows/ci_test-full.yml

Lines changed: 9 additions & 13 deletions
@@ -104,20 +104,17 @@ jobs:
         HOROVOD_WITHOUT_MXNET: 1
         HOROVOD_WITHOUT_TENSORFLOW: 1
       run: |
-        # python -m pip install --upgrade --user pip
-        pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade
-        pip install --requirement ./requirements/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade
         python --version
         pip --version
+        # python -m pip install --upgrade --user pip
+        pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
+        # adjust versions according installed Torch version
+        python ./requirements/adjust_versions.py requirements/extra.txt
+        python ./requirements/adjust_versions.py requirements/examples.txt
+        pip install --requirement ./requirements/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
         pip list
       shell: bash
 
-    # todo: require proper fix in docker image
-    - name: Hotfix dependency
-      run: |
-        pip install torchtext==0.6.0 -U
-      shell: bash
-
     - name: Reinstall Horovod if necessary
       if: runner.os != 'windows'
       env:
@@ -143,10 +140,9 @@ jobs:
         # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
         coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml
 
-    # todo: put this back just when TorchVision can download datasets
-    #- name: Examples
-    # run: |
-    # python -m pytest pl_examples -v --durations=10
+    - name: Examples
+      run: |
+        python -m pytest pl_examples -v --durations=10
 
     - name: Upload pytest test results
       uses: actions/upload-artifact@v2

.github/workflows/docs-checks.yml

Lines changed: 5 additions & 5 deletions
@@ -41,15 +41,15 @@ jobs:
 
     - name: Install dependencies
       run: |
+        python --version
+        pip --version
         # remove Horovod from requirements
         python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)"
         # python -m pip install --upgrade --user pip
         pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
         pip install --requirement requirements/extra.txt
         pip install --requirement requirements/loggers.txt
         pip install --requirement requirements/docs.txt
-        python --version
-        pip --version
         pip list
       shell: bash
 
@@ -84,12 +84,12 @@ jobs:
 
     - name: Install dependencies
       run: |
-        pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
+        python --version
+        pip --version
+        # pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
         pip install --requirement requirements/docs.txt
         # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
         sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
-        python --version
-        pip --version
         pip list
       shell: bash
 
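
Reviewer note: the `python -c` one-liner in the first hunk filters every line starting with `horovod` out of `requirements/extra.txt` and rewrites the file in place. For readability, the same logic can be sketched as a small function. The function name, the temporary file, and the sample requirement lines are illustrative assumptions, not part of the workflow.

```python
import tempfile
from pathlib import Path


def drop_requirement(path, prefix="horovod"):
    """Remove every line starting with `prefix` from a requirements file.

    Mirrors the workflow one-liner: read all lines, keep those that do not
    start with the prefix, write the file back. Returns the number of
    lines removed.
    """
    req_file = Path(path)
    lines = req_file.read_text().splitlines(keepends=True)
    kept = [line for line in lines if not line.startswith(prefix)]
    req_file.write_text("".join(kept))
    return len(lines) - len(kept)


# Demo on a throwaway file; the requirement lines are made up for the example.
with tempfile.TemporaryDirectory() as tmp:
    req = Path(tmp) / "extra.txt"
    req.write_text("horovod>=0.21.2\nmatplotlib>3.1\n")
    removed = drop_requirement(req)
    remaining = req.read_text()

print(removed, repr(remaining))
```

Like the one-liner, this matches on a plain string prefix, so a line with leading whitespace or a differently cased package name would not be removed.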

.github/workflows/events-nightly.yml

Lines changed: 0 additions & 2 deletions
@@ -102,8 +102,6 @@ jobs:
       id: extend
 
     - name: Publish CUDA to Docker Hub
-      # ToDo: extend also building for Nightly from pip
-      if: matrix.pytorch_version < 1.8
       # publish master/release
       uses: docker/build-push-action@v2
       with:

CHANGELOG.md

Lines changed: 44 additions & 11 deletions
@@ -15,28 +15,46 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
 
+- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
+
+
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
+- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+
 ### Changed
 
-- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
+- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
 
 
-- Changed default for DeepSpeed CPU Offload to False, due to prohibitively slow speeds at smaller scale ([#6262](https://github.com/PyTorchLightning/pytorch-lightning/pull/6262))
+- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
 
 
-- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
+- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
 
 
 ### Deprecated
 
 
+- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
 ### Removed
 
 - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
 
 
+- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+
 - Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))
 
 
@@ -57,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
 
 
+- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
 ### Fixed
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
@@ -77,33 +98,45 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
 
 
-- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
 
 
-- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
+- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
 
 
-- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
 
 
-- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
 
 
-- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
+- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
 
 
-- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))
+- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
 
 
 - Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
 
 
-- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+## [1.2.2] - 2021-03-02
 
+### Added
 
-- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
+### Changed
 
+- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
+- Changed default for DeepSpeed CPU Offload to False, due to prohibitively slow speeds at smaller scale ([#6262](https://github.com/PyTorchLightning/pytorch-lightning/pull/6262))
+
+### Fixed
+
+- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
+- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
+- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
+- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
+- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)
 
 
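
Reviewer note: the changelog deprecates `trainer.running_sanity_check` in favor of `trainer.sanity_checking`. One common way to keep an old attribute working during a deprecation period is a property that warns and forwards to the new name. The sketch below shows that pattern on a hypothetical `TrainerSketch` class; it is an assumption about the general approach, not the code from the referenced pull request.

```python
import warnings


class TrainerSketch:
    """Hypothetical stand-in for a trainer; not Lightning's real class."""

    def __init__(self):
        # New, preferred attribute.
        self.sanity_checking = False

    @property
    def running_sanity_check(self):
        # Old name: still readable, but emits a DeprecationWarning that
        # points at the caller (stacklevel=2) and forwards to the new name.
        warnings.warn(
            "`running_sanity_check` is deprecated; use `sanity_checking` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.sanity_checking


trainer = TrainerSketch()
trainer.sanity_checking = True
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_value = trainer.running_sanity_check

print(old_value, [str(w.message) for w in caught])
```

Existing callers keep getting the right value while being told which attribute to migrate to, and the old property can be dropped in a later release.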

MANIFEST.in

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ recursive-include docs/source/_static/images/general/ pl_overview* tf_* tutorial
 
 # Include the Requirements
 recursive-include requirements *.txt
-recursive-exclude requirements *.sh
+recursive-exclude requirements *.sh *.py
 include requirements.txt
 include pyproject.toml
 

README.md

Lines changed: 2 additions & 2 deletions
@@ -318,9 +318,9 @@ class LitAutoEncoder(pl.LightningModule):
         super().__init__()
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
-        (opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
+        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
 
         loss_a = ...
         self.manual_backward(loss_a, opt_a)

azure-pipelines.yml

Lines changed: 9 additions & 12 deletions
@@ -71,11 +71,6 @@ jobs:
       python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'"
     displayName: 'Env details'
 
-  # todo: require proper fix in docker image
-  - bash: |
-      pip install torchtext==0.7 -U
-    displayName: 'HotFix'
-
   - bash: |
       wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/
       unzip -o legacy/checkpoints.zip -d legacy/
@@ -100,10 +95,12 @@ jobs:
       python -m pytest benchmarks -v --maxfail=2 --durations=0
     displayName: 'Testing: benchmarks'
 
-  # todo: put this back just when TorchVision can download datasets
-  #- bash: |
-  # python -m pytest pl_examples -v --maxfail=2 --durations=0
-  # python setup.py install --user --quiet
-  # bash pl_examples/run_ddp-example.sh
-  # pip uninstall -y pytorch-lightning
-  # displayName: 'Examples'
+  - bash: |
+      python -m pytest pl_examples -v --maxfail=2 --durations=0
+      python setup.py install --user --quiet
+      bash pl_examples/run_ddp-example.sh
+      cd pl_examples/basic_examples
+      bash submit_ddp_job.sh
+      bash submit_ddp2_job.sh
+      pip uninstall -y pytorch-lightning
+    displayName: 'Examples'

benchmarks/test_sharded_parity.py

Lines changed: 10 additions & 30 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import platform
 import time
 from typing import Type
 
@@ -22,25 +21,20 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.plugins import DDPSpawnShardedPlugin
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
 from tests.accelerators import DDPLauncher
 from tests.helpers.boring_model import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=1, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_one_gpu():
     plugin_parity_test(
         gpus=1,
         model_cls=SeedTrainLoaderModel,
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=1, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
     plugin_parity_test(
         gpus=1,
@@ -50,9 +44,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
 
 
 @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -61,10 +53,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -74,10 +63,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -87,8 +73,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, fairscale=True)
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
@@ -101,8 +86,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
     )
 
 
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, fairscale=True)
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
@@ -116,9 +100,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
     Ensures same results using multiple optimizers across multiple GPUs
@@ -131,9 +113,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
     Ensures using multiple optimizers across multiple GPUs with manual optimization