Commit

Merge branch 'master' into jit-ignore-metric-forward
teddykoker authored Nov 2, 2020
2 parents 1bd64b5 + 102fa9e commit 3f5f1b9
Showing 15 changed files with 228 additions and 88 deletions.
12 changes: 9 additions & 3 deletions CHANGELOG.md
@@ -17,10 +17,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))

- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))

- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

### Changed

- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))

- Hook `on_after_backward` is called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

- Moved `track_and_norm_grad` into `training loop` and called only when `optimizer_step` is being called ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

### Deprecated

- Deprecated passing `ModelCheckpoint` instance to `checkpoint_callback` Trainer argument ([#4336](https://github.com/PyTorchLightning/pytorch-lightning/pull/4336))
@@ -31,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))

- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

## [1.0.4] - 2020-10-27

@@ -48,8 +57,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))

- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))

### Changed

- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
@@ -76,7 +83,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed WandbLogger not uploading checkpoint artifacts at the end of training ([#4341](https://github.com/PyTorchLightning/pytorch-lightning/pull/4341))


## [1.0.3] - 2020-10-20

### Added
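As a usage-level illustration of the `checkpoint_callback` deprecation listed in the changelog above, a minimal sketch of the deprecated and the preferred way to register a `ModelCheckpoint`; the `monitor` value and variable names are illustrative:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_cb = ModelCheckpoint(monitor='val_loss')

    # deprecated: passing the ModelCheckpoint instance via `checkpoint_callback`
    trainer = Trainer(checkpoint_callback=checkpoint_cb)

    # preferred: register it like any other callback
    trainer = Trainer(callbacks=[checkpoint_cb])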
4 changes: 4 additions & 0 deletions docs/source/optimizers.rst
@@ -48,6 +48,10 @@ to manually manage the optimization process. To do so, do the following:
opt_d.step()
opt_d.zero_grad()
# log losses
self.log('loss_a', loss_a)
self.log('loss_b', loss_b)
.. note:: This is only recommended for experts who need ultimate flexibility.

Manual optimization does not yet support accumulated gradients, but this will be available in 1.1.0.
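To put the added logging lines in context, here is a minimal sketch of a manual-optimization ``training_step`` in the spirit of this page. It assumes two optimizers returned from ``configure_optimizers`` and the ``self.optimizers()`` / ``self.manual_backward()`` helpers used elsewhere in these docs; the loss computations are placeholders.

.. code-block:: python

    def training_step(self, batch, batch_idx, optimizer_idx):
        # manual optimization: fetch both optimizers, ignore optimizer_idx
        (opt_a, opt_d) = self.optimizers()

        loss_a = self.compute_loss_a(batch)  # placeholder
        self.manual_backward(loss_a, opt_a)
        opt_a.step()
        opt_a.zero_grad()

        loss_b = self.compute_loss_b(batch)  # placeholder
        self.manual_backward(loss_b, opt_d)
        opt_d.step()
        opt_d.zero_grad()

        # log losses
        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)

In releases of this vintage, manual optimization is typically enabled via the Trainer's ``automatic_optimization=False`` flag.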
20 changes: 17 additions & 3 deletions docs/source/tpu.rst
@@ -128,13 +128,27 @@ That's it! Your model will train on all 8 TPU cores.

----------------

Single TPU core training
TPU core training

------------------------
Lightning supports training on a single TPU core. Just pass the TPU core ID [1-8] in a list.

Lightning supports training on a single TPU core or 8 TPU cores.

The Trainer parameter ``tpu_cores`` defines how many TPU cores to train on (1 or 8), or which single TPU core to train on when given as a one-element list (e.g. ``[1]``).

For single TPU core training, just pass the TPU core ID (1-8) in a list.

Single TPU core training. The model will train on TPU core ID 5.

.. code-block:: python
trainer = pl.Trainer(tpu_cores=[1])
trainer = pl.Trainer(tpu_cores=[5])
8 TPU core training. The model will train on all 8 TPU cores.

.. code-block:: python
trainer = pl.Trainer(tpu_cores=8)
----------------

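For completeness, a short sketch combining the two ``tpu_cores`` forms above with a ``fit`` call; ``MyModel`` is a placeholder ``LightningModule``.

.. code-block:: python

    import pytorch_lightning as pl

    model = MyModel()  # placeholder LightningModule

    # train on one specific TPU core (core ID 5)
    trainer = pl.Trainer(tpu_cores=[5])
    trainer.fit(model)

    # train on all 8 TPU cores
    trainer = pl.Trainer(tpu_cores=8)
    trainer.fit(model)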
5 changes: 0 additions & 5 deletions pytorch_lightning/accelerators/accelerator.py
@@ -131,11 +131,6 @@ def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer, clip_val=None):

if self.trainer.amp_backend == AMPType.NATIVE:
self.trainer.scaler.unscale_(optimizer)

# apply clip gradients
# TODO: separate TPU case from here
self._clip_gradients(optimizer, clip_val)

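For context, a hedged sketch of the user-facing side of gradient clipping, which this commit leaves unchanged; only the AMP unscale step has moved out of `clip_gradients` and into the native AMP plugin's backward (see the `native_amp.py` hunk further down). The values are illustrative:

    from pytorch_lightning import Trainer

    # clipping is still requested via the Trainer; with native AMP, gradients are
    # now unscaled during backward, before `on_after_backward` and before clipping
    trainer = Trainer(gradient_clip_val=0.5, precision=16, gpus=1)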
65 changes: 35 additions & 30 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
By default, filename is ``None`` and will be set to ``'{epoch}'``.
By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
Example::
@@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module):
monitor_candidates = self._monitor_candidates(trainer)

# ie: path/val_loss=0.5.ckpt
filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)

# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -360,16 +360,17 @@ def _format_checkpoint_name(
cls,
filename: Optional[str],
epoch: int,
step: int,
metrics: Dict[str, Any],
prefix: str = "",
) -> str:
if not filename:
# filename is not set, use default name
filename = "{epoch}"
filename = "{epoch}-{step}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if len(groups) >= 0:
metrics["epoch"] = epoch
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)
@@ -379,32 +380,32 @@ def _format_checkpoint_name(
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

def format_checkpoint_name(
self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.
Example::
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
'step=0.ckpt'
"""
filename = self._format_checkpoint_name(
self.filename, epoch, metrics, prefix=self.prefix
self.filename, epoch, step, metrics, prefix=self.prefix
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer):
)
raise MisconfigurationException(m)

def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(
epoch, ckpt_name_metrics, ver=version_cnt
)
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
return filepath
@@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics

def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
self.CHECKPOINT_NAME_LAST,
trainer.current_epoch,
trainer.global_step,
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")

@@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
if self.monitor is None:
self.best_model_path = self.last_model_path

def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
current = metrics.get(self.monitor)
epoch = metrics.get("epoch")
step = metrics.get("step")

if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)

if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
elif self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
)

def _is_valid_monitor_key(self, metrics):
@@ -544,11 +550,11 @@ def _update_best_and_save(
filepath: str,
current: torch.Tensor,
epoch: int,
step: int,
trainer,
pl_module,
):

k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

del_list = []
if len(self.best_k_models) == k and k > 0:
@@ -575,9 +581,8 @@ def _update_best_and_save(

if self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
f" saving model to {filepath} as top {k}"
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
)
self._save_model(filepath, trainer, pl_module)

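As a usage-level sketch of the `ModelCheckpoint` changes above (the new `{epoch}-{step}` default filename and the corrected `save_top_k=-1` bookkeeping); paths, the monitor key, and the concrete epoch/step values are illustrative:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # the default template is now '{epoch}-{step}', so a checkpoint saved at
    # epoch 2, global step 1500 is written as e.g. 'epoch=2-step=1500.ckpt'
    ckpt_default = ModelCheckpoint(dirpath='checkpoints/')

    # keep the best 3 checkpoints by validation loss
    ckpt_top3 = ModelCheckpoint(dirpath='checkpoints/', monitor='val_loss',
                                save_top_k=3, mode='min')

    # keep every checkpoint: the retention count is now keyed off
    # len(best_k_models) + 1 instead of epoch + 1, so several saves per epoch are kept
    ckpt_all = ModelCheckpoint(dirpath='checkpoints/', monitor='val_loss', save_top_k=-1)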
1 change: 0 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -1101,7 +1101,6 @@ def backward(self, loss, optimizer, optimizer_idx):
"""
loss.backward(*args, **kwargs)
self.trainer.train_loop.track_and_norm_grad(optimizer=optimizer)

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/native_amp.py
@@ -38,6 +38,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

# once backward has been applied, release graph
closure_loss = closure_loss.detach()

# unscale gradient to allow analyze within `on_after_backward`
if not self.trainer.train_loop.should_accumulate():
self.trainer.scaler.unscale_(optimizer)

return closure_loss

def training_step(self, fx, args):
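A sketch of what the unscaling change above enables: under native AMP, `on_after_backward` now sees real (unscaled) gradients outside accumulation phases, so they can be inspected directly. The module and attribute names below are illustrative, not part of this commit:

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def on_after_backward(self):
            # with native AMP, gradients arriving here are already unscaled,
            # so their magnitudes are meaningful for debugging or norm tracking
            norms = [p.grad.detach().norm(2) for p in self.parameters() if p.grad is not None]
            if norms:
                self.last_grad_norm = torch.stack(norms).norm(2)  # stash for later use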
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -250,9 +250,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
# depre warning
if eval_results is not None and user_reduced:
step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
m = f'The {step} should not return anything as of 9.1.' \
f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
self.warning_cache.warn(m)
self.warning_cache.warn(
f'The {step} should not return anything as of 9.1.'
' To log, use self.log(...) or self.write(...) directly in the LightningModule'
)

if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)
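A sketch of the pattern the reworded warning above points to: log with `self.log(...)` from the step and epoch-end hooks instead of returning results from `validation_epoch_end`. The model internals are placeholders:

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)  # assumes forward() is defined elsewhere
            self.log('val_loss', loss)
            return loss

        def validation_epoch_end(self, outputs):
            # aggregate and log; nothing needs to be returned from here
            self.log('val_loss_epoch', torch.stack(outputs).mean())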
25 changes: 16 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -652,11 +652,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

# checks if backward or backward + optimizer step (via closure)
accumulation_done = self._accumulated_batches_reached()
is_final_batch = self._num_training_batches_reached()
should_accumulate = not (accumulation_done or is_final_batch)

# lightning module hook
splits = self.tbptt_split_batch(batch)

@@ -676,7 +671,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
model = self.trainer.get_model()
model.toggle_optimizer(optimizer, opt_idx)

if should_accumulate:
if self.should_accumulate():
# For gradient accumulation

# -------------------
@@ -767,7 +762,7 @@ def train_step_and_backward_closure():
@contextmanager
def block_ddp_sync_behaviour(self):
if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
yield from self.trainer.model.no_sync()
yield self.trainer.model.no_sync()
else:
yield

@@ -817,8 +812,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
with self.trainer.profiler.profile("model_backward"):
self.backward(result, optimizer, opt_idx)

# hook
self.on_after_backward(result.training_step_output, batch_idx, result.loss)
# hook - call this hook only
# when gradients have finished to accumulate
if not self.should_accumulate():
self.on_after_backward(result.training_step_output, batch_idx, result.loss)

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
@@ -837,6 +834,10 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):
result.closure_loss, optimizer, opt_idx, *args, **kwargs
)

if not self.should_accumulate():
# track gradients
self.track_and_norm_grad(optimizer=optimizer)

def update_train_loop_lr_schedulers(self, monitor_metrics=None):
num_accumulated_batches_reached = self._accumulated_batches_reached()
num_training_batches_reached = self._num_training_batches_reached()
@@ -863,6 +864,12 @@ def _accumulated_batches_reached(self):
def _num_training_batches_reached(self):
return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches

def should_accumulate(self):
# checks if backward or backward + optimizer step (via closure)
accumulation_done = self._accumulated_batches_reached()
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

def should_check_val_fx(self, batch_idx, is_last_batch):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
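To make the new `should_accumulate` decision concrete, a small standalone sketch. It assumes `_accumulated_batches_reached` checks `(batch_idx + 1) % accumulate_grad_batches == 0` (its body is collapsed in this diff); the final-batch check mirrors `_num_training_batches_reached` shown above:

    def should_accumulate(batch_idx, accumulate_grad_batches, num_training_batches):
        # keep accumulating until the accumulation window closes
        # or the last batch of the epoch is reached
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_final_batch = (batch_idx + 1) == num_training_batches
        return not (accumulation_done or is_final_batch)

    # e.g. with accumulate_grad_batches=4 and 10 batches per epoch:
    # [should_accumulate(i, 4, 10) for i in range(10)]
    # -> [True, True, True, False, True, True, True, False, True, False]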