Remove legacy Result parameters (#6016)
carmocca authored Mar 28, 2021
1 parent 0e45220 commit f0c5479
Showing 13 changed files with 35 additions and 184 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))


- Removed legacy references for magic keys in the `Result` object ([#6016](https://github.com/PyTorchLightning/pytorch-lightning/pull/6016))


- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


1 change: 0 additions & 1 deletion docs/source/common/lightning_module.rst
@@ -602,7 +602,6 @@ For cases like production, you might want to iterate different models inside a LightningModule
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
# loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
metrics = {'val_acc': acc, 'val_loss': loss}
self.log_dict(metrics)
return metrics
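For reference, a minimal sketch of the full docs pattern around this hunk, assuming the imports used elsewhere on that page (`F` for `torch.nn.functional`, `FM` for the functional metrics module); checkpointing is no longer driven by a `'checkpoint_on'` magic key but by whichever logged metric a callback is told to monitor:

    import torch.nn.functional as F
    from pytorch_lightning.metrics import functional as FM  # assumed import, as in the docs of that release

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        # no 'checkpoint_on' key anymore: just log the metrics; ModelCheckpoint
        # and EarlyStopping pick one of them up through their `monitor` argument
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics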
8 changes: 1 addition & 7 deletions docs/source/common/trainer.rst
@@ -1478,15 +1478,9 @@ with the hidden
def training_step(self, batch, batch_idx, hiddens):
# hiddens are the hiddens from the previous truncated backprop step
out, hiddens = self.lstm(data, hiddens)
# remember to detach() hiddens.
# If you don't, you will get a RuntimeError: Trying to backward through
# the graph a second time...
# Using hiddens.detach() allows each split to be disconnected.
return {
"loss": ...,
"hiddens": hiddens # remember to detach() this
"hiddens": hiddens
}
To modify how the batch is split,
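A sketch of the shortened TBPTT example in context (the batch unpacking and loss computation are assumed here, not part of the original snippet); the manual `detach()` advice could be dropped because the trainer now detaches the returned hiddens, see the change to `pytorch_lightning/trainer/logging.py` further down:

    def training_step(self, batch, batch_idx, hiddens):
        # hiddens are the hiddens from the previous truncated backprop split
        data, targets = batch                    # assumed batch structure
        out, hiddens = self.lstm(data, hiddens)
        loss = self.loss_fn(out, targets)        # assumed loss attribute
        return {
            "loss": loss,
            "hiddens": hiddens,  # detached by the trainer per the logging.py change below
        }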
1 change: 0 additions & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -90,7 +90,6 @@ def __init__(
self.wait_count = 0
self.stopped_epoch = 0
self.mode = mode
self.warned_result_obj = False

if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
1 change: 0 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -202,7 +202,6 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""
self.save_function = None
self.warned_result_obj = False

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
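With the `'early_stop_on'` / `'checkpoint_on'` keys gone, both callbacks are wired up through an explicit `monitor` argument instead. A short sketch, assuming a metric logged under the name `'val_loss'`:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3)
    checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
    trainer = Trainer(callbacks=[early_stop, checkpoint])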
79 changes: 15 additions & 64 deletions pytorch_lightning/core/step_result.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
"""Result class for easier logging and epoch-wise reduction."""

import numbers
import os
from copy import copy
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

@@ -27,33 +26,14 @@

class Result(Dict):

def __init__(
self,
minimize: Optional[Tensor] = None,
early_stop_on: Optional[Tensor] = None,
checkpoint_on: Optional[Union[Tensor, bool]] = None,
hiddens: Optional[Tensor] = None,
):

def __init__(self, minimize: Optional[Tensor] = None):
super().__init__()

# temporary until dict results are deprecated
os.environ['PL_USING_RESULT_OBJ'] = '1'

if early_stop_on is not None:
self.early_stop_on = early_stop_on
if checkpoint_on is not None and checkpoint_on:
self.checkpoint_on = checkpoint_on
if hiddens is not None:
self.hiddens = hiddens.detach()
if minimize is not None:
err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
self._assert_grad_tensor_metric('minimize', minimize, err)
self.minimize = minimize

if minimize is not None and checkpoint_on is None:
self.checkpoint_on = minimize.detach()

self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}}

def __getitem__(self, key: Union[str, Any]) -> Any:
@@ -64,9 +44,7 @@ def __getitem__(self, key: Union[str, Any]) -> Any:

def __getattr__(self, key: str) -> Any:
try:
if key == 'callback_metrics':
return self.get_callback_metrics()
elif key == 'batch_log_metrics':
if key == 'batch_log_metrics':
return self.get_batch_log_metrics()
elif key == 'batch_pbar_metrics':
return self.get_batch_pbar_metrics()
@@ -80,16 +58,9 @@ def __getattr__(self, key: str) -> Any:
return None

def __setattr__(self, key: str, val: Union[Tensor, Any]):
# ensure reserve keys are tensors and detached
if key in {'checkpoint_on', 'early_stop_on'}:
self._assert_tensor_metric(key, val)
if val is not None and isinstance(val, torch.Tensor):
val = val.detach()

# ensure anything else that is a tensor is detached
elif isinstance(val, torch.Tensor) and key != 'minimize':
# ensure tensors are detached
if isinstance(val, torch.Tensor) and key != 'minimize':
val = val.detach()

self[key] = val

def __getstate__(self):
@@ -98,11 +69,6 @@ def __getstate__(self):
def __setstate__(self, d):
self.update(d)

def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
if potential_metric is not None and not isinstance(potential_metric, bool):
if not isinstance(potential_metric, Tensor):
raise TypeError(f'{name} must be a torch.Tensor')

def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
if x is not None:
if not isinstance(x, Tensor):
@@ -272,11 +238,6 @@ def get_batch_sizes(self):
meta = self['meta']
return torch.tensor(meta['_internal']['batch_sizes'])

def get_callback_metrics(self) -> dict:
result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on}

return result

def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
if dataloader_idx is not None and add_dataloader_idx:
return f"{k}/dataloader_idx_{dataloader_idx}"
@@ -495,25 +456,22 @@ def padded_gather(cls, outputs):
# find the padding used for other values
default_padding_idx = 0
for name, value in result.items():
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
default_padding_idx = meta[name]['tbptt_pad_token']
break
if (
name != 'minimize' and isinstance(value, list) and len(value) > 0
and isinstance(value[0], torch.Tensor)
):
default_padding_idx = meta[name]['tbptt_pad_token']
break

# pad across each key individually
for name, value in result.items():
is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):

if is_reserved:
padding_key = default_padding_idx
else:
padding_key = meta[name]['tbptt_pad_token']
if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)):
padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token']
padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
result[name] = padded

# also update the result
if meta and not is_reserved:
if meta and name != "minimize":
meta[name]['value'] = padded
if meta:
result['meta'] = meta
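The branch above pads every logged list of per-split tensors with `torch.nn.utils.rnn.pad_sequence`; a small standalone illustration of what that call produces:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    splits = [torch.tensor([1., 2., 3.]), torch.tensor([4.])]
    padded = pad_sequence(splits, batch_first=True, padding_value=0)
    # padded == tensor([[1., 2., 3.],
    #                   [4., 0., 0.]])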
@@ -581,10 +539,7 @@ def reduce_across_time(cls, time_outputs):
continue

# pick the reduce fx
if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
tbptt_reduce_fx = torch.mean
else:
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx']

if isinstance(value, list):
value = torch.tensor(value)
@@ -612,10 +567,6 @@ def dp_reduce(self):
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']

def drop_hiddens(self):
if 'hiddens' in self:
del self['hiddens']

def rename_keys(self, map_dict: dict):
"""
Maps key values to the target values. Useful when renaming variables in mass.
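A conceptual sketch of the slimmed-down `Result` (an internal class, shown for illustration only): the constructor now takes just `minimize`, and `__setattr__` detaches every other tensor while leaving `minimize` attached to its graph:

    import torch
    from pytorch_lightning.core.step_result import Result

    loss = torch.tensor(2.0, requires_grad=True) * 3     # has a grad_fn, as `minimize` must
    result = Result(minimize=loss)
    result.train_acc = torch.tensor(0.9, requires_grad=True)

    result.minimize.requires_grad    # True  - kept attached so it can be backpropagated
    result.train_acc.requires_grad   # False - detached by __setattr__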
@@ -317,10 +317,6 @@ def _track_callback_metrics(self, eval_results):
elif isinstance(eval_result, dict):
flat = flatten_dict(eval_result)

# removing val_loss magic word to map to checkpoint + ES callback
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
@@ -331,11 +327,6 @@ def _track_callback_metrics(self, eval_results):
else:
flat = flatten_dict(eval_results)

# removing val_loss magic word to map to checkpoint + ES callback
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']

self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
@@ -370,26 +361,13 @@ def on_train_epoch_end(self):
# inform cached logger connector epoch finished
self.cached_results.has_batch_loop_finished = True

def log_train_epoch_end_metrics(
self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers
):
def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
# epoch output is a list. Each item in that list has all the outputs per optimizer
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)

model = self.trainer.lightning_module

epoch_callback_metrics = {}

# -----------------------
# Calculate epoch callback values if given
# -----------------------
if checkpoint_accumulator.num_values > 0:
epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()

if early_stopping_accumulator.num_values > 0:
epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()

# ------------------------
# determine if using a result obj
# ------------------------
@@ -437,9 +415,6 @@ def log_train_epoch_end_metrics(
self.log_metrics(epoch_log_metrics, {})
self._callback_metrics.update(epoch_log_metrics)

# add metrics to callbacks
self._callback_metrics.update(epoch_callback_metrics)

# add metrics to progress_bar and callbacks
if len(epoch_progress_bar_metrics) > 0:
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
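Because the `'val_loss'` magic key is no longer copied into `checkpoint_on` / `early_stop_on`, the value a callback monitors has to be logged explicitly so it ends up in `trainer.callback_metrics`. A minimal sketch (the `_shared_step` helper is assumed, not from the source):

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)   # assumed helper computing the loss
        # explicitly logged metrics land in trainer.callback_metrics, where
        # EarlyStopping / ModelCheckpoint look them up via their `monitor` key
        self.log("val_loss", loss, prog_bar=True)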
9 changes: 0 additions & 9 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -251,11 +251,6 @@ def __gather_epoch_end_eval_results(self, outputs):
eval_results = []
for epoch_output in outputs:
result = epoch_output[0].__class__.gather(epoch_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()

eval_results.append(result)

# with 1 dataloader don't pass in a list
@@ -269,10 +264,6 @@ def __auto_reduce_result_objs(self, outputs):
for dl_output in outputs:
result = dl_output[0]
result = result.__class__.reduce_on_epoch_end(dl_output)
if 'checkpoint_on' in result:
result.checkpoint_on = result.checkpoint_on.mean()
if 'early_stop_on' in result:
result.early_stop_on = result.early_stop_on.mean()
eval_results.append(result)

return eval_results
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/logging.py
@@ -14,7 +14,7 @@

import inspect
from abc import ABC
from typing import Mapping
from collections import Mapping

import torch

@@ -76,10 +76,7 @@ def process_dict_result(self, output, train=False):
# --------------------------
# single scalar returned from a xx_step
if isinstance(output, torch.Tensor):
progress_bar_metrics = {}
log_metrics = {}
hiddens = None
return output, progress_bar_metrics, log_metrics, hiddens
return output, {}, {}, None

# ---------------
# EXTRACT PROGRESS BAR KEYS
@@ -140,6 +137,8 @@
# EXTRACT HIDDEN
# ---------------
hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
if hiddens is not None:
hiddens = hiddens.detach()

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
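The two lines added above detach the extracted hiddens in one central place. The failure they prevent is the one the removed docs comment described; a plain-PyTorch sketch (sizes are arbitrary, chosen just for the illustration):

    import torch

    lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
    hiddens = None
    for split in torch.randn(3, 2, 5, 4):        # 3 splits of (batch=2, seq=5, features=4)
        out, hiddens = lstm(split, hiddens)
        loss = out.sum()
        loss.backward()                          # frees the graph of this split
        # without this detach, the next backward() would reach into the freed
        # graph -> "Trying to backward through the graph a second time"
        hiddens = tuple(h.detach() for h in hiddens)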
15 changes: 0 additions & 15 deletions pytorch_lightning/trainer/supporters.py
@@ -104,21 +104,6 @@ def _agg_memory(self, how: str):
return getattr(self.memory[:self.current_idx], how)()


class Accumulator(object):

def __init__(self):
self.num_values = 0
self.total = 0

def accumulate(self, x):
with torch.no_grad():
self.total += x
self.num_values += 1

def mean(self):
return self.total / self.num_values


class PredictionCollection(object):

def __init__(self, global_rank: int, world_size: int):
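The deleted `Accumulator` only kept a running sum and count, and was used for the now-removed checkpoint/early-stopping accumulators; an equivalent epoch-level mean can be computed directly over collected step values, e.g.:

    import torch

    step_values = [torch.tensor(0.5), torch.tensor(1.5), torch.tensor(1.0)]
    epoch_mean = torch.stack(step_values).mean()   # tensor(1.)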