
Commit ee414d2

Authored by Jeff Yang, with co-authors SeanNaren and tchaton
Switch to PyTorch 1.6 in Drone CI (#4393)
* switch to 1.6
* readme
* 1.7
* back to normal [ci skip]
* horovodrun --verbose
* try with apex
* add apex test
* change base
* description
* test with 1.7
* back to 1.6
* no gradient_clip_val
* re-add gradient_clip_val
* no amp
* temp skip torch.cuda.amp + horovod test
* Apply suggestion from code review
  Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
* Fix formatting
* ddp
* Moved extended model outside of function to prevent pickling issue for drone
* typo
* resolve bug
* extract automatic_automization

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: chaton <thomas@grid.ai>
1 parent a32bffc commit ee414d2

File tree

5 files changed: +51 -23 lines

.drone.yml
README.md
pytorch_lightning/plugins/native_amp.py
tests/models/test_horovod.py
tests/plugins/test_amp_plugin.py


.drone.yml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ name: torch-GPU
 
 steps:
   - name: testing
-    image: pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.5
+    image: pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6
 
     environment:
       CODECOV_TOKEN:
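
The image bump matters because native mixed precision (torch.cuda.amp) only exists from PyTorch 1.6 onward, and several tests touched by this commit gate on that. As a rough illustration (not part of the diff), a version guard in the spirit of the ones used in the tests below could look like:

    from distutils.version import LooseVersion

    import torch

    # torch.cuda.amp (autocast / GradScaler) first shipped with PyTorch 1.6,
    # which is exactly what the new CI image provides.
    NATIVE_AMP_SUPPORTED = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")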

README.md

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ Lightning can automatically export to ONNX or TorchScript for those cases.
 | System / PyTorch ver. | 1.3 (min. req.)* | 1.4 | 1.5 | 1.6 | 1.7 (latest) |
 | :---: | :---: | :---: | :---: | :---: | :---: |
 | Conda py3.7 [linux] | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) | [![PyTorch & Conda](https://github.com/PyTorchLightning/pytorch-lightning/workflows/PyTorch%20&%20Conda/badge.svg)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22PyTorch+%26+Conda%22+branch%3Amaster) |
-| Linux py3.7 [GPUs**] | - | - | [![Build Status](http://104.154.220.231/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://104.154.220.231/PyTorchLightning/pytorch-lightning) | - | - |
+| Linux py3.7 [GPUs**] | - | - | - | [![Build Status](http://104.154.220.231/api/badges/PyTorchLightning/pytorch-lightning/status.svg)](http://104.154.220.231/PyTorchLightning/pytorch-lightning) | - |
 | Linux py3.7 [TPUs***] | - | - | - | [![TPU tests](https://github.com/PyTorchLightning/pytorch-lightning/workflows/TPU%20tests/badge.svg)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22TPU+tests%22+branch%3Amaster) | - |
 | Linux py3.6 / py3.7 / py3.8 | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) |
 | OSX py3.6 / py3.7 | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) | - | - | [![CI complete testing](https://github.com/PyTorchLightning/pytorch-lightning/workflows/CI%20complete%20testing/badge.svg?event=push)](https://github.com/PyTorchLightning/pytorch-lightning/actions?query=workflow%3A%22CI+testing%22) |

pytorch_lightning/plugins/native_amp.py

Lines changed: 4 additions & 2 deletions
@@ -29,8 +29,10 @@ def connect(self, model, optimizers):
     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         closure_loss = self.trainer.scaler.scale(closure_loss)
 
+        automatic_optimization = self.trainer.train_loop.automatic_optimization
+
         # do backward pass
-        if self.trainer.train_loop.automatic_optimization:
+        if automatic_optimization:
             model = self.trainer.get_model()
             model.backward(closure_loss, optimizer, opt_idx)
         else:
@@ -40,7 +42,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         closure_loss = closure_loss.detach()
 
         # unscale gradient to allow analyze within `on_after_backward`
-        if not self.trainer.train_loop.should_accumulate():
+        if not self.trainer.train_loop.should_accumulate() and automatic_optimization:
             self.trainer.scaler.unscale_(optimizer)
 
         return closure_loss
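
The extra automatic_optimization condition follows from the torch.cuda.amp contract: scaler.unscale_(optimizer) may be called at most once per optimizer step, between the backward pass and scaler.step(), so presumably the plugin should not unscale on the user's behalf when the user is driving optimization manually. A minimal, self-contained sketch of that contract in plain PyTorch (illustrative only, not code from this repository):

    import torch

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(3):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(8, 4, device="cuda")).sum()
        scaler.scale(loss).backward()   # gradients are still scaled here
        scaler.unscale_(optimizer)      # at most once per step; hooks/clipping now see true values
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        scaler.step(optimizer)          # skips the update if inf/nan gradients were found
        scaler.update()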

tests/models/test_horovod.py

Lines changed: 28 additions & 1 deletion
@@ -25,6 +25,7 @@
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
 from tests.base import EvalModelTemplate
 from tests.base.models import BasicGAN
 
@@ -126,8 +127,33 @@ def test_horovod_multi_gpu(tmpdir):
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+def test_horovod_apex(tmpdir):
+    """Test Horovod with multi-GPU support using apex amp."""
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        gpus=2,
+        deterministic=True,
+        distributed_backend='horovod',
+        amp_backend='apex',
+        precision=16,
+    )
+    _run_horovod(trainer_options, on_gpu=True)
+
+
+@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
+@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires torch.cuda.amp")
 def test_horovod_amp(tmpdir):
-    """Test Horovod with multi-GPU support."""
+    """Test Horovod with multi-GPU support using native amp."""
     trainer_options = dict(
         default_root_dir=str(tmpdir),
         weights_save_path=str(tmpdir),
@@ -139,6 +165,7 @@ def test_horovod_amp(tmpdir):
         gpus=2,
         deterministic=True,
         distributed_backend='horovod',
+        amp_backend='native',
         precision=16,
     )
     _run_horovod(trainer_options, on_gpu=True)
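
Stripped of the test harness, the new Apex test exercises the ordinary user-facing Trainer configuration; roughly equivalent usage would look like this (a sketch, assuming Apex and Horovod are installed and two GPUs are visible):

    from pytorch_lightning import Trainer

    # Same knobs the test passes through _run_horovod, minus the test-only paths and batch limits.
    trainer = Trainer(
        gpus=2,
        distributed_backend='horovod',
        amp_backend='apex',      # 'native' is skipped until Horovod supports torch.cuda.amp
        precision=16,
        gradient_clip_val=1.0,
        max_epochs=1,
    )
    # trainer.fit(model)  # model: any LightningModule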

tests/plugins/test_amp_plugin.py

Lines changed: 17 additions & 18 deletions
@@ -86,20 +86,19 @@ def on_fit_start(self, trainer, pl_module):
     trainer.fit(model)
 
 
+class GradientUnscaleBoringModel(BoringModel):
+    def on_after_backward(self):
+        norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+        if not (torch.isinf(norm) or torch.isnan(norm)):
+            assert norm.item() < 15.
+
+
 @pytest.mark.skipif(
     LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
     reason="Minimal PT version is set to 1.6")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_amp_gradient_unscale(tmpdir):
-
-    class ExtendedBoringModel(BoringModel):
-
-        def on_after_backward(self):
-            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
-            if not (torch.isinf(norm) or torch.isnan(norm)):
-                assert norm.item() < 15.
-
-    model = ExtendedBoringModel()
+    model = GradientUnscaleBoringModel()
 
     trainer = Trainer(
         max_epochs=2,
@@ -117,19 +116,19 @@ def on_after_backward(self):
     trainer.fit(model)
 
 
+class UnscaleAccumulateGradBatchesBoringModel(BoringModel):
+
+    def on_after_backward(self):
+        norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+        if not (torch.isinf(norm) or torch.isnan(norm)):
+            assert norm.item() < 15.
+
+
 @pytest.mark.skipif(
     LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
-
-    class ExtendedBoringModel(BoringModel):
-
-        def on_after_backward(self):
-            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
-            if not (torch.isinf(norm) or torch.isnan(norm)):
-                assert norm.item() < 15.
-
-    model = ExtendedBoringModel()
+    model = UnscaleAccumulateGradBatchesBoringModel()
 
     trainer = Trainer(
         max_epochs=2,
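
The two ExtendedBoringModel definitions are hoisted to module level because, as the commit message notes, locally defined classes break pickling, and the ddp run on Drone needs to pickle the model when spawning worker processes. A tiny standalone illustration of that constraint (hypothetical names, unrelated to Lightning's API):

    import pickle


    class TopLevelModel:
        """Defined at module scope, so pickle can find it by qualified name."""


    def make_local_model():
        class LocalModel:
            """Defined inside a function; pickle cannot import it by name."""
        return LocalModel()


    pickle.dumps(TopLevelModel())  # works
    try:
        pickle.dumps(make_local_model())
    except (AttributeError, pickle.PicklingError) as err:
        print("cannot pickle a locally defined class:", err)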
