Extend support for logging a collection #7771
Changes from all commits: 1aefd92, b9982bc, 064b27f, 9c268e6, 1da392d
```diff
@@ -46,7 +46,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

 if TYPE_CHECKING:
```
```diff
@@ -261,7 +261,7 @@ def forward(self, x):
     def log(
         self,
         name: str,
-        value: Any,
+        value: _METRIC_COLLECTION,
         prog_bar: bool = False,
         logger: bool = True,
         on_step: Optional[bool] = None,
```
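Typing `value` as `_METRIC_COLLECTION` is what opens `self.log` up to collections: one call can take a dictionary or list of tensors/numbers/metrics instead of a single scalar. A minimal sketch of the call site this enables (the module body, key names, and loss values are illustrative, not part of the diff):

```python
import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        recon_loss = torch.tensor(0.3)  # placeholder values for illustration
        kl_loss = torch.tensor(0.1)
        # One call logs every leaf of the collection; each tensor/number
        # leaf is validated and synced individually.
        self.log("loss", {"recon": recon_loss, "kl": kl_loss})
        return recon_loss + kl_loss
```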
```diff
@@ -324,6 +324,9 @@ def log(
             ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
         )

+        # check for none values
+        apply_to_collection(value, type(None), partial(self.__check_none, name, value))
+
         # set the default depending on the fx_name
         on_step = self.__auto_choose_log_on_step(on_step)
         on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
```
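`apply_to_collection` recursively walks a collection and applies the given function to every element whose type matches the `dtype` argument, so passing `type(None)` turns it into a guard: the bound callback raises as soon as any `None` leaf is found. A standalone sketch of the pattern, with a hypothetical `_reject_none` standing in for the private `__check_none`:

```python
from functools import partial

from pytorch_lightning.utilities.apply_func import apply_to_collection


def _reject_none(name, value, _leaf):
    # Stand-in for `__check_none`; the leaf itself is ignored, only its presence matters.
    raise ValueError(f'`self.log({name}, {value})` was called, but `None` values cannot be logged')


metrics = {"loss": 0.5, "acc": None}
# Raises ValueError: one leaf of the collection is None.
apply_to_collection(metrics, type(None), partial(_reject_none, "metrics", metrics))
```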
```diff
@@ -335,14 +338,15 @@ def log(
     if "/dataloader_idx_" in name:
         raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")

-        value = self.__sync(
-            value,
+        sync_fn = partial(
+            self.__sync,
             sync_fn=self.trainer.training_type_plugin.reduce,
             sync_dist=sync_dist,
             sync_dist_op=sync_dist_op,
             sync_dist_group=sync_dist_group,
             device=self.device,
         )
+        value = apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn)

         assert self._results is not None
         self._results.log(
```
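This is the heart of the change: instead of syncing one value, the sync settings are pre-bound with `partial`, and `apply_to_collection` maps the resulting one-argument function over every tensor or number leaf while preserving the container structure. A self-contained sketch of the same pattern, with a toy `sync` in place of the private `__sync`:

```python
import numbers
from functools import partial

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def sync(value, scale=1.0):
    # Toy stand-in for `__sync`: pretend the cross-worker reduction
    # scales each value instead of all-reducing it.
    if isinstance(value, numbers.Number):
        value = torch.tensor(value, dtype=torch.float)
    return value * scale


sync_fn = partial(sync, scale=0.5)
value = {"loss": torch.tensor(4.0), "metrics": [2.0, torch.tensor(8.0)]}
# Leaves are transformed; the dict/list structure is preserved.
print(apply_to_collection(value, (torch.Tensor, numbers.Number), sync_fn))
# {'loss': tensor(2.), 'metrics': [tensor(1.), tensor(4.)]}
```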
```diff
@@ -359,7 +363,7 @@ def log(

     def log_dict(
         self,
-        dictionary: dict,
+        dictionary: Dict[str, _METRIC_COLLECTION],
```
**Contributor**

The type here is interesting. If this is read-only, let's use `Mapping` instead. Since `Dict` implies mutability (e.g., the function might mutate the dictionary), the type checker defines it as invariant. As such, `Dict[str, torch.Tensor]` is NOT a valid type to pass to a function accepting `Dict[str, Union[torch.Tensor, Metric]]`. See python/mypy#2300 (comment) for discussion.

**Contributor (Author)**

Interesting! I did not know that. Feel free to open a patch with the change. Thanks!

**Contributor**

Since this is a minor type-only change, sent it out with #7851.
```diff
         prog_bar: bool = False,
         logger: bool = True,
         on_step: Optional[bool] = None,
```
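The invariance point from the review thread above is easy to reproduce with mypy. A minimal illustration (function names are hypothetical): `Mapping` is read-only and therefore covariant in its value type, so it accepts the narrower dictionary, while `Dict` rejects it:

```python
from typing import Dict, Mapping, Union

import torch


def takes_dict(d: Dict[str, Union[torch.Tensor, float]]) -> None:
    pass  # illustrative only


def takes_mapping(d: Mapping[str, Union[torch.Tensor, float]]) -> None:
    pass  # illustrative only


tensors: Dict[str, torch.Tensor] = {"loss": torch.tensor(0.1)}
takes_dict(tensors)     # mypy error: Dict is invariant in its value type
takes_mapping(tensors)  # OK: Mapping is covariant in its value type
```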
```diff
@@ -416,29 +420,26 @@ def log_dict(
     @staticmethod
     def __sync(
-        value: _METRIC,
+        value: Union[torch.Tensor, numbers.Number],
         sync_fn: Optional[Callable] = None,
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
         sync_dist_group: Optional[Any] = None,
         device: torch.device = None,
-    ) -> _METRIC:
+    ) -> torch.Tensor:
         """Sync across workers when using distributed training"""
-        if not isinstance(value, (torch.Tensor, numbers.Number)):
-            return value
-
+        if isinstance(value, numbers.Number):
+            value = torch.tensor(value, device=device, dtype=torch.float)
         sync_fn = sync_fn or sync_ddp_if_available
         dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
         if not sync_dist or not dist_available:
             return value

-        # TODO: Find a way to make the reduction only once, so we don't need to clone.
-        if isinstance(value, torch.Tensor):
-            value = value.clone()
-        else:
-            value = torch.tensor(value, device=device, dtype=torch.float)
         return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

+    @staticmethod
+    def __check_none(name: str, value: Any, _) -> Any:
+        raise ValueError(f'`self.log({name}, {value})` was called, but `None` values cannot be logged')
+
     def write_prediction(
         self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
     ):
```
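To make the `__sync` control flow concrete, here is a simplified, self-contained version — a sketch only, with the TPU branch and the reduce-op/group plumbing dropped. Since `apply_to_collection` now guarantees the input is a tensor or a number, the isinstance guard reduces to a single number-to-tensor promotion:

```python
import numbers
from typing import Callable, Optional, Union

import torch


def sync(
    value: Union[torch.Tensor, numbers.Number],
    sync_fn: Optional[Callable] = None,
    sync_dist: bool = False,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # Numbers are promoted to float tensors first, so everything below
    # only ever deals with tensors.
    if isinstance(value, numbers.Number):
        value = torch.tensor(value, device=device, dtype=torch.float)
    dist_available = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not sync_dist or not dist_available:
        return value  # single process: nothing to reduce
    return sync_fn(value)


print(sync(3))              # tensor(3.) -- plain number promoted to a tensor
print(sync(torch.ones(2)))  # tensor([1., 1.]) -- returned as-is without DDP
```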