3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))


- Raise `ValueError` when a `None` value is `self.log`-ed ([#7771](https://github.com/PyTorchLightning/pytorch-lightning/pull/7771))


- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


33 changes: 17 additions & 16 deletions pytorch_lightning/core/lightning.py
@@ -46,7 +46,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
@@ -261,7 +261,7 @@ def forward(self, x):
def log(
self,
name: str,
value: Any,
value: _METRIC_COLLECTION,
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
Expand Down Expand Up @@ -324,6 +324,9 @@ def log(
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)

# check for none values
apply_to_collection(value, type(None), partial(self.__check_none, name, value))

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
@@ -335,14 +338,15 @@ def log(
if "/dataloader_idx_" in name:
raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")

value = self.__sync(
value,
sync_fn = partial(
self.__sync,
sync_fn=self.trainer.training_type_plugin.reduce,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)
value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn)

assert self._results is not None
self._results.log(
@@ -359,7 +363,7 @@ def log(

def log_dict(
self,
dictionary: dict,
dictionary: Dict[str, _METRIC_COLLECTION],
Contributor:

The type here is interesting.

If this is read-only, let's use `Mapping[str, _METRIC_COLLECTION]`, which is covariant in its value type (meaning that `Mapping[str, torch.Tensor]` and `Mapping[str, Metric]` are both subtypes), so type checking will pass.

Since `Dict` implies mutability (e.g. the function might mutate the dictionary), the type checker defines it as invariant. As such, `Dict[str, torch.Tensor]` is NOT a valid type to pass to a function accepting `Dict[str, Union[torch.Tensor, Metric]]`. See here for discussion: python/mypy#2300 (comment)

Contributor Author:

Interesting! I did not know that.

Feel free to open a patch with the change. Thanks!

Contributor:

Since this is a minor type-only change, sent it out with #7851.

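To make the variance point above concrete, here is a minimal, hypothetical sketch (not part of this PR) using `int` and `str` as stand-ins for `torch.Tensor` and `Metric`; running mypy over it reproduces the invariance error described in the thread:

```python
from typing import Dict, Mapping, Union

def takes_dict(d: Dict[str, Union[int, str]]) -> None:
    # A Dict parameter may be mutated, which is why the checker treats Dict as invariant.
    d["extra"] = "oops"

def takes_mapping(d: Mapping[str, Union[int, str]]) -> None:
    # Mapping is read-only, so it is covariant in its value type.
    for key, value in d.items():
        print(key, value)

metrics: Dict[str, int] = {"loss": 1}
takes_dict(metrics)     # mypy error: "Dict[str, int]" is not "Dict[str, Union[int, str]]"
takes_mapping(metrics)  # accepted: "Dict[str, int]" is a "Mapping[str, Union[int, str]]"
```
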
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
@@ -416,29 +420,26 @@ def log_dict(

@staticmethod
def __sync(
value: _METRIC,
value: Union[torch.Tensor, numbers.Number],
sync_fn: Optional[Callable] = None,
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
device: torch.device = None,
) -> _METRIC:
) -> torch.Tensor:
"""Sync across workers when using distributed training"""
if not isinstance(value, (torch.Tensor, numbers.Number)):
return value

if isinstance(value, numbers.Number):
value = torch.tensor(value, device=device, dtype=torch.float)
sync_fn = sync_fn or sync_ddp_if_available
dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
if not sync_dist or not dist_available:
return value

# TODO: Find a way to make the reduction only once, so we don't need to clone.
if isinstance(value, torch.Tensor):
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

@staticmethod
def __check_none(name: str, value: Any, _) -> Any:
raise ValueError(f'`self.log({name}, {value})` was called, but `None` values cannot be logged')

def write_prediction(
self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
):
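For readers tracing the `self.log` change above, the following is a standalone approximation (not the library implementation; `apply_to_leaves` is an invented name) of how `apply_to_collection` walks a nested metric dict so that the new `__check_none` can raise on any `None` leaf:

```python
from functools import partial
from typing import Any, Callable, Type

def apply_to_leaves(value: Any, dtype: Type, fn: Callable[..., Any]) -> Any:
    # Apply `fn` to every leaf of type `dtype`; recurse into dicts, pass everything else through.
    if isinstance(value, dtype):
        return fn(value)
    if isinstance(value, dict):
        return {k: apply_to_leaves(v, dtype, fn) for k, v in value.items()}
    return value

def check_none(name: str, logged: Any, _: Any) -> None:
    raise ValueError(f"`self.log({name}, {logged})` was called, but `None` values cannot be logged")

value = {"a": {"b": 1.0}}
apply_to_leaves(value, type(None), partial(check_none, "foo", value))  # no None leaves, passes through
# apply_to_leaves({"a": {"b": None}}, ...) would hit the None leaf and raise the ValueError above
```
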
20 changes: 10 additions & 10 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -14,7 +14,8 @@
import torch

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DDP2Plugin(DDPPlugin):
@@ -34,26 +35,25 @@ def setup(self, model):
self.task_idx = self.cluster_environment.local_rank()
# the difference to DDP is that we don't call children processes here

def reduce(self, tensor, *args, **kwargs):
def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
"""
Reduces a tensor from all processes to one aggregated tensor.
Reduces a collection of tensors from all processes. It can be applied to just a single tensor.
In DDP2, the reduction here is only across local devices within the node.

Args:
tensor: the tensor to sync and reduce
collection: The collection of tensors to sync and reduce.
*args: ignored for DDP2
**kwargs: ignored for DDP2

Return:
reduced value, except when the input was not a tensor the output remains is unchanged
Reduced tensor values or the same value if it was not or did not contain a tensor.
"""
if isinstance(tensor, Result):
tensor.dp_reduce()

elif isinstance(tensor, torch.Tensor):
tensor = tensor.mean()
def mean(t: torch.Tensor) -> torch.Tensor:
original_dtype = t.dtype
return t.float().mean().to(original_dtype)

return tensor
return apply_to_collection(collection, torch.Tensor, mean)

@property
def root_device(self):
24 changes: 9 additions & 15 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -18,8 +18,8 @@

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class DataParallelPlugin(ParallelPlugin):
@@ -52,30 +52,24 @@ def setup(self, model):
model.to(self.root_device)
self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)

def reduce(self, tensor, *args, **kwargs):
def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
"""
Reduces a tensor from all parallel processes to one aggregated tensor.
Reduces a collection of tensors from all processes. It can be applied to just a single tensor.

Args:
tensor: the tensor to sync and reduce
collection: The collection of tensors to sync and reduce.
*args: ignored for DP
**kwargs: ignored for DP

Return:
reduced value, except when the input was not a tensor the output remains is unchanged
Reduced tensor values or the same value if it was not or did not contain a tensor.
"""
if isinstance(tensor, Result):
tensor.dp_reduce()

else:
def mean(t: torch.Tensor) -> torch.Tensor:
original_dtype = t.dtype
return t.float().mean().to(original_dtype)

def _reduce(t: torch.Tensor):
dtype_tensor = t.dtype
return t.float().mean().type(dtype_tensor)

tensor = apply_to_collection(tensor, torch.Tensor, _reduce)

return tensor
return apply_to_collection(collection, torch.Tensor, mean)

@property
def root_device(self):
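A quick sketch of the dtype-preserving mean that both the DDP2 and DP `reduce` methods above now apply to a whole collection. The helper and function names come from the diff itself; only the example values are invented:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

def mean(t: torch.Tensor) -> torch.Tensor:
    # Integer tensors cannot be averaged directly, so reduce in float and cast back.
    original_dtype = t.dtype
    return t.float().mean().to(original_dtype)

collection = {"loss": torch.tensor([0.1, 0.3]), "count": torch.tensor([2, 4])}
reduced = apply_to_collection(collection, torch.Tensor, mean)
# reduced == {"loss": tensor(0.2000), "count": tensor(3)}
```
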
10 changes: 0 additions & 10 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -461,16 +461,6 @@ def reduce_across_time(cls, time_outputs):
result['meta'] = meta
return result

def dp_reduce(self):
for k, value in self.items():
if k == 'meta' or isinstance(value, Metric):
continue

if isinstance(value, list):
value = torch.tensor(value)

self[k] = value.mean(dim=-1)

@property
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/types.py
@@ -23,6 +23,8 @@
from torchmetrics import Metric

_METRIC = Union[Metric, torch.Tensor, Number]
# real type is `Union[_METRIC, Dict[str, '_METRIC_COLLECTION']]` but Sphinx fails with `RecursionError`
_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]]
STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
EPOCH_OUTPUT = List[STEP_OUTPUT]
_EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader
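To illustrate the note in `types.py` above, here is a small sketch contrasting the flattened alias that is shipped with the fully recursive form the comment mentions; the variable names and example values are invented:

```python
from numbers import Number
from typing import Dict, Union

import torch
from torchmetrics import Metric

_METRIC = Union[Metric, torch.Tensor, Number]

# The flattened alias used in types.py (one level of nesting, to keep Sphinx happy):
_METRIC_COLLECTION = Union[_METRIC, Dict[str, _METRIC]]

# The fully recursive form the comment refers to, written with a forward reference:
_METRIC_COLLECTION_RECURSIVE = Union[_METRIC, Dict[str, "_METRIC_COLLECTION_RECURSIVE"]]

# Example values: a bare tensor, a flat dict, and an arbitrarily nested dict
# (the last one only matches the recursive form, but is accepted at runtime).
flat = torch.tensor(0.25)
shallow = {"loss": torch.tensor(0.25), "acc": 0.9}
nested = {"train": {"loss": torch.tensor(0.25), "acc": 0.9}}
```
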
14 changes: 14 additions & 0 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -769,3 +769,17 @@ def validation_step(self, batch, batch_idx):

assert trainer.callback_metrics["val_acc"] == 8 / 32.
assert "train_loss" in trainer.callback_metrics


@pytest.mark.parametrize('value', [None, {'a': {'b': None}}])
def test_log_none_raises(tmpdir, value):

class TestModel(BoringModel):

def training_step(self, *args):
self.log("foo", value)

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
model = TestModel()
with pytest.raises(ValueError, match=rf"self.log\(foo, {value}\)` was called"):
trainer.fit(model)