From f40d08679d31ef6e705f1e0e5a66473c817325e1 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Mon, 2 Nov 2020 10:46:02 +0000 Subject: [PATCH 1/4] Add manual logging to training_step manual optimization (#4476) --- docs/source/optimizers.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index 00e8ef88aa686b..1e7baadb64480c 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -48,6 +48,10 @@ to manually manage the optimization process. To do so, do the following: opt_d.step() opt_d.zero_grad() + # log losses + self.log('loss_a', loss_a) + self.log('loss_b', loss_b) + .. note:: This is only recommended for experts who need ultimate flexibility Manual optimization does not yet support accumulated gradients but will be live in 1.1.0 From ef03c39ab75dc246edddbff3af84035113141ec7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 2 Nov 2020 15:05:58 +0100 Subject: [PATCH 2/4] Add step index in checkpoint name (#3807) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * true final value of global step * ch check * tests * save each validation interval * wip * add test * add test * wip * fix tests, revert old edits, fix merge conflicts, update doctests * test + bugfix * sort files * format test * suggestion by ananth * added changelog * naming * docs * example * suggestion Co-authored-by: Carlos Mocholí * fix test * pep * pep Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 2 + .../callbacks/model_checkpoint.py | 65 ++++++------ pytorch_lightning/trainer/evaluation_loop.py | 7 +- tests/checkpointing/test_model_checkpoint.py | 100 +++++++++++++----- tests/loggers/test_comet.py | 2 +- tests/loggers/test_mlflow.py | 4 +- tests/loggers/test_wandb.py | 2 +- tests/trainer/test_trainer.py | 2 +- 8 files changed, 117 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2e4107ef5ff8a..cdb9ddc804d72c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236)) +- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) + ### Changed - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 33ed1d71d9eb4d..f3eabf5611cf00 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -101,7 +101,7 @@ class ModelCheckpoint(Callback): ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}' ... ) - By default, filename is ``None`` and will be set to ``'{epoch}'``. + By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``. 
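+
+    For example, with the default template a checkpoint saved at epoch 1,
+    global step 100 is named ``epoch=1-step=100.ckpt``.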
Example:: @@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module): monitor_candidates = self._monitor_candidates(trainer) # ie: path/val_loss=0.5.ckpt - filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates) + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step) # callback supports multiple simultaneous modes # here we call each mode sequentially # Mode 1: save all checkpoints OR only the top k if self.save_top_k: - self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath) + self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath) # Mode 2: save the last checkpoint - self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath) + self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath) def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: @@ -360,16 +360,17 @@ def _format_checkpoint_name( cls, filename: Optional[str], epoch: int, + step: int, metrics: Dict[str, Any], prefix: str = "", ) -> str: if not filename: # filename is not set, use default name - filename = "{epoch}" + filename = "{epoch}-{step}" # check and parse user passed keys in the string groups = re.findall(r"(\{.*?)[:\}]", filename) if len(groups) >= 0: - metrics["epoch"] = epoch + metrics.update({"epoch": epoch, 'step': step}) for group in groups: name = group[1:] filename = filename.replace(group, name + "={" + name) @@ -379,7 +380,7 @@ def _format_checkpoint_name( return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt]) def format_checkpoint_name( - self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None + self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None ) -> str: """Generate a filename according to the defined template. 
@@ -387,24 +388,24 @@ def format_checkpoint_name( >>> tmpdir = os.path.dirname(__file__) >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={})) 'epoch=0.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}') - >>> os.path.basename(ckpt.format_checkpoint_name(5, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={})) 'epoch=005.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') - >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456))) + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) 'epoch=2-val_loss=0.12.ckpt' >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' - >>> ckpt = ModelCheckpoint(filename='{epoch}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) - 'epoch=0.ckpt' + >>> ckpt = ModelCheckpoint(filename='{step}') + >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {})) + 'step=0.ckpt' """ filename = self._format_checkpoint_name( - self.filename, epoch, metrics, prefix=self.prefix + self.filename, epoch, step, metrics, prefix=self.prefix ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) @@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer): ) raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics): - filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics) + def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int): + filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) version_cnt = 0 while self._fs.exists(filepath): - filepath = self.format_checkpoint_name( - epoch, ckpt_name_metrics, ver=version_cnt - ) + filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt) # this epoch called before version_cnt += 1 return filepath @@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer): ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics) ckpt_name_metrics.update(trainer.logger_connector.callback_metrics) ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics) + ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch}) return ckpt_name_metrics - def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath): + def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath): should_save_last = self.monitor is None or self.save_last if not should_save_last: return @@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi # when user ALSO asked for the 'last.ckpt' change the name if self.save_last: last_filepath = self._format_checkpoint_name( - self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix + self.CHECKPOINT_NAME_LAST, + trainer.current_epoch, + trainer.global_step, + ckpt_name_metrics, + prefix=self.prefix ) last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt") @@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi if self.monitor is None: self.best_model_path = self.last_model_path - def 
_save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath): + def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): current = metrics.get(self.monitor) + epoch = metrics.get("epoch") + step = metrics.get("step") if not isinstance(current, torch.Tensor) and current is not None: current = torch.tensor(current, device=pl_module.device) if self.check_monitor_top_k(current): - self._update_best_and_save(filepath, current, epoch, trainer, pl_module) + self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module) elif self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}" + f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}" ) def _is_valid_monitor_key(self, metrics): @@ -544,11 +550,11 @@ def _update_best_and_save( filepath: str, current: torch.Tensor, epoch: int, + step: int, trainer, pl_module, ): - - k = epoch + 1 if self.save_top_k == -1 else self.save_top_k + k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k del_list = [] if len(self.best_k_models) == k and k > 0: @@ -575,9 +581,8 @@ def _update_best_and_save( if self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} reached" - f" {current:0.5f} (best {self.best_model_score:0.5f})," - f" saving model to {filepath} as top {k}" + f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" + f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' ) self._save_model(filepath, trainer, pl_module) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 9dab036583dd85..ffc72f8f0022e8 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -250,9 +250,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): # depre warning if eval_results is not None and user_reduced: step = 'testing_epoch_end' if self.testing else 'validation_epoch_end' - m = f'The {step} should not return anything as of 9.1.' \ - f'to log, use self.log(...) or self.write(...) directly in the LightningModule' - self.warning_cache.warn(m) + self.warning_cache.warn( + f'The {step} should not return anything as of 9.1.' + ' To log, use self.log(...) or self.write(...) 
directly in the LightningModule' + ) if using_eval_result and not user_reduced: eval_results = self.__auto_reduce_result_objs(outputs) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index c2b5b7e9fc8a92..d3d5f67bcfeaa3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -100,7 +100,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k): path_yaml = os.path.join(tmpdir, 'best_k_models.yaml') checkpoint.to_yaml(path_yaml) d = yaml.full_load(open(path_yaml, 'r')) - best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} + best_k = {k: v for k, v in checkpoint.best_k_models.items()} assert d == best_k @@ -185,67 +185,72 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): def test_model_checkpoint_format_checkpoint_name(tmpdir): # empty filename: - ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {}) - assert ckpt_name == 'epoch=3' + ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {}) + assert ckpt_name == 'epoch=3-step=2' - ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test') - assert ckpt_name == 'test-epoch=3' + ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test') + assert ckpt_name == 'test-epoch=3-step=2' # no groups case: - ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test') + ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test') assert ckpt_name == 'test-ckpt' # no prefix - ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03}) + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03}) assert ckpt_name == 'epoch=003-acc=0.03' # prefix char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@' - ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test') + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test') assert ckpt_name == 'test@epoch=3,acc=0.03000' ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org # no dirpath set - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, {}) - assert ckpt_name == 'epoch=3.ckpt' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, {}) - assert ckpt_name == 'epoch=5.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {}) + assert ckpt_name == 'epoch=3-step=2.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {}) + assert ckpt_name == 'epoch=5-step=4.ckpt' # CWD - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, {}) - assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt') + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt') # with ver ckpt_name = ModelCheckpoint( monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test' - ).format_checkpoint_name(3, {}, ver=3) + ).format_checkpoint_name(3, 2, {}, ver=3) assert ckpt_name == tmpdir / 'test-name-v3.ckpt' # using slashes ckpt_name = ModelCheckpoint( monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}' - ).format_checkpoint_name(4, {'val/loss': 
0.03}) + ).format_checkpoint_name(4, 3, {'val/loss': 0.03}) assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' # TODO: Checks with filepath. To be removed in v1.2 # CWD - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {}) - assert ckpt_name == str(Path('.').resolve() / 'epoch=3.ckpt') + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 2, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt') # dir does not exist so it is used as filename filepath = tmpdir / 'dir' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {}) + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', filepath=filepath, prefix='test' + ).format_checkpoint_name(3, 2, {}) assert ckpt_name == tmpdir / 'test-dir.ckpt' # now, dir exists os.mkdir(filepath) - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {}) - assert ckpt_name == filepath / 'test-epoch=3.ckpt' + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', filepath=filepath, prefix='test' + ).format_checkpoint_name(3, 2, {}) + assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" + seed_everything() model = EvalModelTemplate() epochs = 3 ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' @@ -257,10 +262,15 @@ def test_model_checkpoint_save_last(tmpdir): logger=False, ) trainer.fit(model) - last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}) + last_filename = model_checkpoint._format_checkpoint_name( + ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {} + ) last_filename = last_filename + '.ckpt' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename]) + assert set(os.listdir(tmpdir)) == set( + [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename] + ) + ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last' @@ -295,6 +305,7 @@ def test_none_monitor_save_last(tmpdir): def test_model_checkpoint_none_monitor(tmpdir): """ Test that it is possible to save all checkpoints when monitor=None. 
""" + seed_everything() model = EvalModelTemplate() model.validation_step = model.validation_step_no_monitor model.validation_epoch_end = model.validation_epoch_end_no_monitor @@ -311,13 +322,13 @@ def test_model_checkpoint_none_monitor(tmpdir): # these should not be set if monitor is None assert checkpoint_callback.monitor is None - assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt' + assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt' assert checkpoint_callback.best_model_score == 0 assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == '' # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs)] + expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])] assert set(os.listdir(tmpdir)) == set(expected) @@ -325,13 +336,14 @@ def test_model_checkpoint_none_monitor(tmpdir): def test_model_checkpoint_period(tmpdir, period): model = EvalModelTemplate() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, period=period) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, max_epochs=epochs, limit_train_batches=0.1, limit_val_batches=0.1, + val_check_interval=1.0, logger=False, ) trainer.fit(model) @@ -372,12 +384,19 @@ def validation_epoch_end(self, outputs): return {'epoch': self.current_epoch} model = CustomModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", mode='max', save_top_k=-1) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename="{epoch}", + monitor="epoch", + mode='max', + save_top_k=-1, + ) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, max_epochs=epochs, logger=False, + val_check_interval=1.0, ) trainer.fit(model) @@ -439,7 +458,7 @@ def test_default_checkpoint_behavior(tmpdir): # make sure the checkpoint we saved has the metric in the name ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints')) assert len(ckpts) == 1 - assert ckpts[0] == 'epoch=2.ckpt' + assert ckpts[0] == 'epoch=2-step=14.ckpt' def test_ckpt_metric_names_results(tmpdir): @@ -497,7 +516,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): model = EvalModelTemplate() num_epochs = 3 model_checkpoint = ModelCheckpoint( - monitor='early_stop_on', dirpath=tmpdir, save_top_k=num_epochs, save_last=True + monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True ) trainer = Trainer( default_root_dir=tmpdir, @@ -509,6 +528,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") path_last = str(tmpdir / "last.ckpt") assert path_last == model_checkpoint.last_model_path + assert os.path.isfile(path_last_epoch) ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) @@ -791,3 +811,25 @@ def test_configure_model_checkpoint(tmpdir): with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"): Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs) + + +def test_val_check_interval_checkpoint_files(tmpdir): + """ Test correct checkpoint naming when validating/checkpointing multiple times per epoch. 
""" + model = EvalModelTemplate() + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=-1, + monitor="val_acc", + mode="max", + verbose=True + ) + trainer = Trainer( + default_root_dir=tmpdir, + val_check_interval=0.2, + max_epochs=1, + limit_train_batches=10, + callbacks=[model_checkpoint] + ) + trainer.fit(model) + files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")]) + assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]] diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 87af510e492194..fc61829645b6e3 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -159,7 +159,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} @patch('pytorch_lightning.loggers.comet.comet_ml') diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index db2c353dc4e2c6..a200fbf549e6a1 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -115,7 +115,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir): ) trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'} def test_mlflow_logger_dirs_creation(tmpdir): @@ -143,7 +143,7 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics') assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} @mock.patch('pytorch_lightning.loggers.mlflow.mlflow') diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index a7fd986bad642f..468ca819f91b18 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -116,7 +116,7 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): trainer.fit(model) assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} def test_wandb_sanitize_callable_params(tmpdir): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 35257e28704ba0..6fceae4b5e59d9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -430,7 +430,7 @@ def mock_save_function(filepath, *args): losses = [10, 9, 2.8, 5, 2.5] checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, + dirpath=tmpdir, filename='{epoch}', monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1 ) checkpoint_callback.save_function = mock_save_function From 9b8102d1a53d3b395ed7e1670b7430bc59461f4e Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 2 Nov 2020 15:23:24 +0000 Subject: [PATCH 3/4] [DOC] Clarify `tpu_cores` training. 
(#4475)

* better explanation around tpu_cores

* more details on tpu training

* Apply suggestions from code review

Co-authored-by: Jeff Yang
---
 docs/source/tpu.rst | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/docs/source/tpu.rst b/docs/source/tpu.rst
index ffebe41a9c856f..5f4c48076d8136 100644
--- a/docs/source/tpu.rst
+++ b/docs/source/tpu.rst
@@ -128,13 +128,27 @@ That's it! Your model will train on all 8 TPU cores.
 
 ----------------
 
-Single TPU core training
+TPU core training
+
 ------------------------
-Lightning supports training on a single TPU core. Just pass the TPU core ID [1-8] in a list.
+
+Lightning supports training on a single TPU core or on all 8 TPU cores.
+
+The Trainer parameter ``tpu_cores`` defines how many TPU cores to train on (1 or 8), or which single core to train on when given as a list.
+
+For single-TPU training, just pass the TPU core ID [1-8] in a list.
+
+Single TPU core training. The model will train on TPU core ID 5.
 
 .. code-block:: python
 
-    trainer = pl.Trainer(tpu_cores=[1])
+    trainer = pl.Trainer(tpu_cores=[5])
+
+Training on 8 TPU cores. The model will train on all 8 cores.
+
+.. code-block:: python
+
+    trainer = pl.Trainer(tpu_cores=8)
 
 ----------------

From 102fa9ee7dd087ee167b120ea7812360928408f7 Mon Sep 17 00:00:00 2001
From: chaton
Date: Mon, 2 Nov 2020 16:36:48 +0000
Subject: [PATCH 4/4] [BUGFIX] AMP + Precision unscale grad (#4441)

* move unscale within Native plugin

* remove gradient tracking from lightning backward

* forgot trainer.fit

* typo

* update

* cleanup

* set to 1.6

* typo

* skip if below 1.6 strict

* update changelog

* remove useless code

* Update tests/plugins/test_amp_plugin.py

Co-authored-by: Sean Naren

* Update tests/plugins/test_amp_plugin.py

Co-authored-by: Sean Naren

* update changelog

* Update CHANGELOG.md

Co-authored-by: Sean Naren
Co-authored-by: Jeff Yang
---
 CHANGELOG.md                                  | 10 ++-
 pytorch_lightning/accelerators/accelerator.py |  5 --
 pytorch_lightning/core/lightning.py           |  1 -
 pytorch_lightning/plugins/native_amp.py       |  5 ++
 pytorch_lightning/trainer/training_loop.py    | 25 +++++---
 tests/plugins/test_amp_plugin.py              | 62 +++++++++++++++++++
 6 files changed, 90 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cdb9ddc804d72c..84d483dd03f2c3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,12 +17,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236)) +- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340)) + - Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) ### Changed - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405)) +- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) + +- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) + ### Deprecated - Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336)) @@ -33,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209)) +- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) ## [1.0.4] - 2020-10-27 @@ -50,8 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656)) -- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340)) - ### Changed - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587)) @@ -78,7 +83,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))
-
 
 ## [1.0.3] - 2020-10-20
 
 ### Added
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 8e1969cc9368e2..8ece6c4ec1b101 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -131,11 +131,6 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
         model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
 
     def clip_gradients(self, optimizer, clip_val=None):
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            self.trainer.scaler.unscale_(optimizer)
-
-        # apply clip gradients
         # TODO: separate TPU case from here
         self._clip_gradients(optimizer, clip_val)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 22d63d0a03a74c..d7125eb171a9cc 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1101,7 +1101,6 @@ def backward(self, loss, optimizer, optimizer_idx):
         """
         loss.backward(*args, **kwargs)
-        self.trainer.train_loop.track_and_norm_grad(optimizer=optimizer)
 
     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
         """
diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py
index 6506540bde6e1b..b016b6c5d24fba 100644
--- a/pytorch_lightning/plugins/native_amp.py
+++ b/pytorch_lightning/plugins/native_amp.py
@@ -38,6 +38,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
 
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
+
+        # unscale gradients to allow analysis within `on_after_backward`
+        if not self.trainer.train_loop.should_accumulate():
+            self.trainer.scaler.unscale_(optimizer)
+
         return closure_loss
 
     def training_step(self, fx, args):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d1dfb3eec3733a..0d269c333b2698 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -652,11 +652,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
             if response == -1:
                 return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
 
-        # checks if backward or backward + optimizer step (via closure)
-        accumulation_done = self._accumulated_batches_reached()
-        is_final_batch = self._num_training_batches_reached()
-        should_accumulate = not (accumulation_done or is_final_batch)
-
         # lightning module hook
         splits = self.tbptt_split_batch(batch)
 
@@ -676,7 +671,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                     model = self.trainer.get_model()
                     model.toggle_optimizer(optimizer, opt_idx)
 
-                    if should_accumulate:
+                    if self.should_accumulate():
                         # For gradient accumulation
 
                         # -------------------
@@ -767,7 +762,7 @@ def train_step_and_backward_closure():
     @contextmanager
     def block_ddp_sync_behaviour(self):
         if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
-            yield from self.trainer.model.no_sync()
+            yield self.trainer.model.no_sync()
         else:
             yield
 
@@ -817,8 +812,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
             with self.trainer.profiler.profile("model_backward"):
                 self.backward(result, optimizer, opt_idx)
 
-            # hook
-            self.on_after_backward(result.training_step_output, batch_idx, result.loss)
+            # hook - call this hook only
+            # when gradients 
have finished accumulating
+            if not self.should_accumulate():
+                self.on_after_backward(result.training_step_output, batch_idx, result.loss)
 
             # check if loss or model weights are nan
             if self.trainer.terminate_on_nan:
@@ -837,6 +834,10 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
                 result.closure_loss, optimizer, opt_idx, *args, **kwargs
             )
 
+        if not self.should_accumulate():
+            # track gradients
+            self.track_and_norm_grad(optimizer=optimizer)
+
     def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()
@@ -863,6 +864,12 @@ def _accumulated_batches_reached(self):
     def _num_training_batches_reached(self):
         return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
 
+    def should_accumulate(self):
+        # checks if backward or backward + optimizer step (via closure)
+        accumulation_done = self._accumulated_batches_reached()
+        is_final_batch = self._num_training_batches_reached()
+        return not (accumulation_done or is_final_batch)
+
     def should_check_val_fx(self, batch_idx, is_last_batch):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
diff --git a/tests/plugins/test_amp_plugin.py b/tests/plugins/test_amp_plugin.py
index c0d5747b5fc7e0..6fd000b61d97f2 100644
--- a/tests/plugins/test_amp_plugin.py
+++ b/tests/plugins/test_amp_plugin.py
@@ -84,3 +84,65 @@ def on_fit_start(self, trainer, pl_module):
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
+    reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_amp_gradient_unscale(tmpdir):
+
+    class ExtendedBoringModel(BoringModel):
+
+        def on_after_backward(self):
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+            if not (torch.isinf(norm) or torch.isnan(norm)):
+                assert norm.item() < 15.
+
+    model = ExtendedBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(
+    LooseVersion(torch.__version__) < LooseVersion("1.6.0"), reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
+
+    class ExtendedBoringModel(BoringModel):
+
+        def on_after_backward(self):
+            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+            if not (torch.isinf(norm) or torch.isnan(norm)):
+                assert norm.item() < 15.
+
+    model = ExtendedBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+    )
+    trainer.fit(model)
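
Taken together, PATCH 4/4 moves gradient unscaling out of `Accelerator.clip_gradients` and into the native AMP plugin's `backward`, guarded by `should_accumulate()`, so that `on_after_backward` and gradient-norm tracking observe true, unscaled gradients exactly when an optimizer step will follow. The sketch below restates that pattern in plain PyTorch; it is illustrative only, not part of the patches, and assumes a CUDA device with torch >= 1.6:

.. code-block:: python

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(32, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()
    accumulate_grad_batches = 2

    for step in range(8):
        x = torch.randn(4, 32, device="cuda")
        with autocast():
            loss = model(x).sum()

        # backward runs on the scaled loss; gradients stay in scaled form
        scaler.scale(loss / accumulate_grad_batches).backward()

        # mirror `should_accumulate()`: act only when a step will follow
        if (step + 1) % accumulate_grad_batches == 0:
            # unscale before any inspection or clipping, so a hook like
            # `on_after_backward` would see true gradient values
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()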