From 6e15643cba86a85a7d8da46cea41c36833c7234a Mon Sep 17 00:00:00 2001 From: Christoph Clement <36116534+chris-clem@users.noreply.github.com> Date: Mon, 18 Jan 2021 09:55:17 +0100 Subject: [PATCH] Update TensorboardGenerativeModelImageSampler args (#494) * Update args * Add docs * Fix codefactor and update changelog * Update docs * Apply yapf * chlog Co-authored-by: Christoph Clement Co-authored-by: Jirka Borovec --- CHANGELOG.md | 1 + pl_bolts/callbacks/vision/image_generation.py | 50 +++++++++++++++++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bdc457628..64adeb2add 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactored `pl_bolts.callbacks` ([#477](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/477)) - Refactored the rest of `pl_bolts.models.self_supervised` ([#481](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/481), [#479](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/479) +- Update [`torchvision.utils.make_grid`(https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid)] kwargs to `TensorboardGenerativeModelImageSampler` ([#494](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/494)) ### Fixed diff --git a/pl_bolts/callbacks/vision/image_generation.py b/pl_bolts/callbacks/vision/image_generation.py index 5572cbb363..83ee748d05 100644 --- a/pl_bolts/callbacks/vision/image_generation.py +++ b/pl_bolts/callbacks/vision/image_generation.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple + import torch from pytorch_lightning import Callback, LightningModule, Trainer @@ -6,7 +8,7 @@ try: import torchvision except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover + warn_missing_pkg("torchvision") # pragma: no-cover class TensorboardGenerativeModelImageSampler(Callback): @@ -30,9 +32,39 @@ class TensorboardGenerativeModelImageSampler(Callback): trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()]) """ - def __init__(self, num_samples: int = 3) -> None: + def __init__( + self, + num_samples: int = 3, + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + norm_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: int = 0, + ) -> None: + """ + Args: + num_samples: Number of images displayed in the grid. Default: ``3``. + nrow: Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding: Amount of padding. Default: ``2``. + normalize: If ``True``, shift the image to the range (0, 1), + by the min and max values specified by :attr:`range`. Default: ``False``. + norm_range: Tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each: If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value: Value for the padded pixels. Default: ``0``. + """ super().__init__() - self.num_samples: int = num_samples + self.num_samples = num_samples + self.nrow = nrow + self.padding = padding + self.normalize = normalize + self.norm_range = norm_range + self.scale_each = scale_each + self.pad_value = pad_value def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: dim = (self.num_samples, pl_module.hparams.latent_dim) # type: ignore[union-attr] @@ -48,6 +80,14 @@ def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: img_dim = pl_module.img_dim images = images.view(self.num_samples, *img_dim) - grid = torchvision.utils.make_grid(images) - str_title = f'{pl_module.__class__.__name__}_images' + grid = torchvision.utils.make_grid( + tensor=images, + nrow=self.nrow, + padding=self.padding, + normalize=self.normalize, + range=self.norm_range, + scale_each=self.scale_each, + pad_value=self.pad_value, + ) + str_title = f"{pl_module.__class__.__name__}_images" trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)