From 77ef8a3ff9ec38c3621284c9691c7500fe1ffc0b Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 18:52:43 +0900 Subject: [PATCH 001/117] add stimulus_domain module --- bdpy/dl/torch/stimulus_domain/__init__.py | 0 bdpy/dl/torch/stimulus_domain/core.py | 77 +++++++++ bdpy/dl/torch/stimulus_domain/image_domain.py | 158 ++++++++++++++++++ 3 files changed, 235 insertions(+) create mode 100644 bdpy/dl/torch/stimulus_domain/__init__.py create mode 100644 bdpy/dl/torch/stimulus_domain/core.py create mode 100644 bdpy/dl/torch/stimulus_domain/image_domain.py diff --git a/bdpy/dl/torch/stimulus_domain/__init__.py b/bdpy/dl/torch/stimulus_domain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bdpy/dl/torch/stimulus_domain/core.py b/bdpy/dl/torch/stimulus_domain/core.py new file mode 100644 index 00000000..2fa7853e --- /dev/null +++ b/bdpy/dl/torch/stimulus_domain/core.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch.nn as nn + + +class Domain(nn.Module, ABC): + """Base class for stimulus domain. + + This class is used to convert stimulus between each domain and library's internal common space. + """ + + @abstractmethod + def send(self, x: torch.Tensor) -> torch.Tensor: + """Send stimulus to the internal common space from each domain. + + Parameters + ---------- + x : torch.Tensor + Stimulus in the original domain. + + Returns + ------- + torch.Tensor + Stimulus in the internal common space. + """ + pass + + @abstractmethod + def receive(self, x: torch.Tensor) -> torch.Tensor: + """Receive stimulus from the internal common space to each domain. + + Parameters + ---------- + x : torch.Tensor + Stimulus in the internal common space. + + Returns + ------- + torch.Tensor + Stimulus in the original domain. + """ + pass + + +class IrreversibleDomain(Domain): + """The domain which cannot be reversed. + + This class is used to convert stimulus between each domain and library's + internal common space. Note that the subclasses of this class do not + guarantee the reversibility of `send` and `receive` methods. + """ + + def send(self, x: torch.Tensor) -> torch.Tensor: + return x + + def receive(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class ComposedDomain(Domain): + """The domain composed of multiple domains.""" + + def __init__(self, domains: list[Domain]) -> None: + super().__init__() + self.domains = nn.ModuleList(domains) + + def send(self, x: torch.Tensor) -> torch.Tensor: + for domain in reversed(self.domains): + x = domain.send(x) + return x + + def receive(self, x: torch.Tensor) -> torch.Tensor: + for domain in self.domains: + x = domain.receive(x) + return x diff --git a/bdpy/dl/torch/stimulus_domain/image_domain.py b/bdpy/dl/torch/stimulus_domain/image_domain.py new file mode 100644 index 00000000..bb0be8f1 --- /dev/null +++ b/bdpy/dl/torch/stimulus_domain/image_domain.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import warnings + +import numpy as np +import torch + +from .core import Domain, IrreversibleDomain, ComposedDomain + + +def _bgr2rgb(images: torch.Tensor) -> torch.Tensor: + """Convert images from BGR to RGB""" + return images[:, [2, 1, 0], ...] + + +def _rgb2bgr(images: torch.Tensor) -> torch.Tensor: + """Convert images from RGB to BGR""" + return images[:, [2, 1, 0], ...] + + +def _to_channel_first(images: torch.Tensor) -> torch.Tensor: + """Convert images from channel last to channel first""" + return images.permute(0, 3, 1, 2) + + +def _to_channel_last(images: torch.Tensor) -> torch.Tensor: + """Convert images from channel first to channel last""" + return images.permute(0, 2, 3, 1) + + + +class Zero2OneImageDomain(Domain): + """Image domain for images in [0, 1]. + + - Channel axis: 1 + - Pixel range: [0, 1] + - Image size: arbitrary + - Color space: RGB + """ + + def send(self, images: torch.Tensor) -> torch.Tensor: + return images + + def receive(self, images: torch.Tensor) -> torch.Tensor: + return images + + +InternalImageDomain = Zero2OneImageDomain + + +class AffineDomain(Domain): + """Image domain shifted by center and scaled by scale. + + This domain is used to convert images in [0, 1] to images in [-center, scale-center]. + In other words, the pixel intensity p in [0, 1] is converted to p * scale - center. + """ + + def __init__( + self, + center: np.ndarray, + scale: float | np.ndarray, + *, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + + if center.ndim == 1: # 1D vector (C,) + center = center[np.newaxis, :, np.newaxis, np.newaxis] + elif center.ndim == 3: # 3D vector (1, C, W, H) + center = center[np.newaxis] + else: + raise ValueError( + f"center must be 1D or 3D vector, but got {center.ndim}D vector." + ) + if isinstance(scale, (float, int)) or scale.ndim == 0: + scale = np.array([scale])[np.newaxis, np.newaxis, np.newaxis] + elif scale.ndim == 1: # 1D vector (C,) + scale = scale[np.newaxis, :, np.newaxis, np.newaxis] + elif scale.ndim == 3: # 3D vector (1, C, W, H) + scale = scale[np.newaxis] + else: + raise ValueError( + f"scale must be scalar or 1D or 3D vector, but got {scale.ndim}D vector." + ) + + self._center = torch.from_numpy(center).to(device=device, dtype=dtype) + self._scale = torch.from_numpy(scale).to(device=device, dtype=dtype) + + def send(self, images: torch.Tensor) -> torch.Tensor: + return (images + self._center) / self._scale + + def receive(self, images: torch.Tensor) -> torch.Tensor: + return images * self._scale - self._center + + +class BGRDomain(Domain): + """Image domain for BGR images.""" + + def send(self, images: torch.Tensor) -> torch.Tensor: + return _bgr2rgb(images) + + def receive(self, images: torch.Tensor) -> torch.Tensor: + return _rgb2bgr(images) + + +class PILDomainWithExplicitCrop(IrreversibleDomain): + """Image domain for PIL images. + + - Channel axis: 3 + - Pixel range: [0, 255] + - Image size: arbitrary + - Color space: RGB + """ + + def send(self, images: torch.Tensor) -> torch.Tensor: + warnings.warn( + "PILDomainWithExplicitCrop is an irreversible domain. " \ + "It does not guarantee the reversibility of `send` and `receive` " \ + "methods. Please use PILDomainWithExplicitCrop.send() with caution.", + RuntimeWarning, + ) + return _to_channel_first(images) / 255.0 # to [0, 1.0] + + def receive(self, images: torch.Tensor) -> torch.Tensor: + images = _to_channel_last(images) * 255.0 + + # Crop values to [0, 255] + return torch.clamp(images, 0, 255) + + +class BdPyVGGDomain(ComposedDomain): + """Image domain for VGG architecture defined in BdPy. + + - Channel axis: 1 + - Pixel range: + - red: [-123, 132] + - green: [-117, 138] + - blue: [-104, 151] + # These values are calculated from the mean vector of ImageNet ([123, 117, 104]). + - Image size: arbitrary + - Color space: BGR + """ + + def __init__( + self, *, device: torch.device | None = None, dtype: torch.dtype | None = None + ) -> None: + super().__init__( + [ + AffineDomain( + center=np.array([123.0, 117.0, 104.0]), + scale=255.0, + device=device, + dtype=dtype, + ), + BGRDomain(), + ] + ) From bf1d4ecf6dff81d7bb1c1f1d60e2ef44687c0b2a Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 18:56:32 +0900 Subject: [PATCH 002/117] update type annotation --- bdpy/dl/torch/stimulus_domain/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bdpy/dl/torch/stimulus_domain/core.py b/bdpy/dl/torch/stimulus_domain/core.py index 2fa7853e..9acec49a 100644 --- a/bdpy/dl/torch/stimulus_domain/core.py +++ b/bdpy/dl/torch/stimulus_domain/core.py @@ -1,9 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Iterable, TYPE_CHECKING import torch.nn as nn +if TYPE_CHECKING: + import torch + class Domain(nn.Module, ABC): """Base class for stimulus domain. @@ -62,7 +66,7 @@ def receive(self, x: torch.Tensor) -> torch.Tensor: class ComposedDomain(Domain): """The domain composed of multiple domains.""" - def __init__(self, domains: list[Domain]) -> None: + def __init__(self, domains: Iterable[Domain]) -> None: super().__init__() self.domains = nn.ModuleList(domains) From 3ba67e62306a4d80ba89cba804229beff82b2190 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 19:14:18 +0900 Subject: [PATCH 003/117] add dataset module --- bdpy/dl/torch/dataset.py | 195 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 bdpy/dl/torch/dataset.py diff --git a/bdpy/dl/torch/dataset.py b/bdpy/dl/torch/dataset.py new file mode 100644 index 00000000..c821592f --- /dev/null +++ b/bdpy/dl/torch/dataset.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from typing import Iterable, Callable, Dict + +from pathlib import Path + +from PIL import Image +import numpy as np +from torch.utils.data import Dataset + +from bdpy.dataform import DecodedFeatures, Features + + +_FeatureTypeNP = Dict[str, np.ndarray] + + +def _removesuffix(s: str, suffix: str) -> str: + """Remove suffix from string. + + Note + ---- + This function is available from Python 3.9 as `str.removesuffix`. We can + remove this function when we drop support for Python 3.8. + + Parameters + ---------- + s : str + String. + suffix : str + Suffix to remove. + + Returns + ------- + str + String without suffix. + """ + if suffix and s.endswith(suffix): + return s[: -len(suffix)] + return s[:] + + +class FeaturesDataset(Dataset): + """Dataset of features. + + Parameters + ---------- + root_path : str | Path + Path to the root directory of features. + layer_path_names : Iterable[str] + List of layer path names. Each layer path name is used to get features + from the root directory so that the layer path name must be a part of + the path to the layer. + stimulus_names : list[str], optional + List of stimulus names. If None, all stimulus names are used. + transform : callable, optional + Callable object which is used to transform features. The callable object + must take a dict of features and return a dict of features. + """ + + def __init__( + self, + root_path: str | Path, + layer_path_names: Iterable[str], + stimulus_names: list[str] | None = None, + transform: Callable[[_FeatureTypeNP], _FeatureTypeNP] | None = None, + ): + self._features_store = Features(Path(root_path).as_posix()) + self._layer_path_names = layer_path_names + if stimulus_names is None: + stimulus_names = self._features_store.labels + self._stimulus_names = stimulus_names + self._transform = transform + + def __len__(self) -> int: + return len(self._stimulus_names) + + def __getitem__(self, index: int) -> _FeatureTypeNP: + stimulus_name = self._stimulus_names[index] + features = {} + for layer_path_name in self._layer_path_names: + feature = self._features_store.get( + layer=layer_path_name, label=stimulus_name + ) + feature = feature[0] # NOTE: remove batch axis + features[layer_path_name] = feature + if self._transform is not None: + features = self._transform(features) + return features + + +class DecodedFeaturesDataset(Dataset): + """Dataset of decoded features. + + Parameters + ---------- + root_path : str | Path + Path to the root directory of decoded features. + layer_path_names : Iterable[str] + List of layer path names. Each layer path name is used to get features + from the root directory so that the layer path name must be a part of + the path to the layer. + subject_id : str + ID of the subject. + roi : str + ROI name. + stimulus_names : list[str], optional + List of stimulus names. If None, all stimulus names are used. + transform : callable, optional + Callable object which is used to transform features. The callable object + must take a dict of features and return a dict of features. + """ + + def __init__( + self, + root_path: str | Path, + layer_path_names: Iterable[str], + subject_id: str, + roi: str, + stimulus_names: list[str] | None = None, + transform: Callable[[_FeatureTypeNP], _FeatureTypeNP] | None = None, + ): + self._decoded_features_store = DecodedFeatures(Path(root_path).as_posix()) + self._layer_path_names = layer_path_names + self._subject_id = subject_id + self._roi = roi + if stimulus_names is None: + stimulus_names = self._decoded_features_store.labels + assert stimulus_names is not None + self._stimulus_names = stimulus_names + self._transform = transform + + def __len__(self) -> int: + return len(self._stimulus_names) + + def __getitem__(self, index: int) -> _FeatureTypeNP: + stimulus_name = self._stimulus_names[index] + decoded_features = {} + for layer_path_name in self._layer_path_names: + decoded_feature = self._decoded_features_store.get( + layer=layer_path_name, + label=stimulus_name, + subject=self._subject_id, + roi=self._roi, + ) + decoded_feature = decoded_feature[0] # NOTE: remove batch axis + decoded_features[layer_path_name] = decoded_feature + if self._transform is not None: + decoded_features = self._transform(decoded_features) + return decoded_features + + +class ImageDataset(Dataset): + """Dataset of images. + + Parameters + ---------- + root_path : str | Path + Path to the root directory of images. + stimulus_names : list[str], optional + List of stimulus names. If None, all stimulus names are used. + extension : str, optional + Extension of the image files. + """ + + def __init__( + self, + root_path: str | Path, + stimulus_names: list[str] | None = None, + extension: str = "jpg", + ): + self.root_path = root_path + if stimulus_names is None: + stimulus_names = [ + _removesuffix(path.name, "." + extension) + for path in Path(root_path).glob(f"*{extension}") + ] + self._stimulus_names = stimulus_names + self._extension = extension + + def __len__(self): + return len(self._stimulus_names) + + def __getitem__(self, index: int): + stimulus_name = self._stimulus_names[index] + image = Image.open(Path(self.root_path) / f"{stimulus_name}.{self._extension}") + image = image.convert("RGB") + return np.array(image) / 255.0, stimulus_name + + +class RenameFeatureKeys: + def __init__(self, mapping: dict[str, str]): + self._mapping = mapping + + def __call__(self, features: _FeatureTypeNP) -> _FeatureTypeNP: + return {self._mapping.get(key, key): value for key, value in features.items()} From f3a84d584cb0a10233362ace7b3a431774c4f25e Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 19:57:54 +0900 Subject: [PATCH 004/117] add interface for the feature inversion pipeline --- bdpy/dl/torch/stimulus_domain/__init__.py | 1 + bdpy/recon/torch/interface.py | 39 +++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 bdpy/recon/torch/interface.py diff --git a/bdpy/dl/torch/stimulus_domain/__init__.py b/bdpy/dl/torch/stimulus_domain/__init__.py index e69de29b..a6743925 100644 --- a/bdpy/dl/torch/stimulus_domain/__init__.py +++ b/bdpy/dl/torch/stimulus_domain/__init__.py @@ -0,0 +1 @@ +from .core import Domain, IrreversibleDomain, ComposedDomain \ No newline at end of file diff --git a/bdpy/recon/torch/interface.py b/bdpy/recon/torch/interface.py new file mode 100644 index 00000000..ea86948f --- /dev/null +++ b/bdpy/recon/torch/interface.py @@ -0,0 +1,39 @@ +from typing import Dict, Protocol, Iterable, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + import torch.nn as nn + + FeatureType = Dict[str, torch.Tensor] + + +class Encoder(Protocol): + def __call__(self, image: torch.Tensor) -> FeatureType: + ... + + +class Generator(Protocol): + def __call__(self, latent: torch.Tensor) -> torch.Tensor: + ... + + def parameters(self) -> Iterable[torch.Tensor]: + ... + + def reset_state(self) -> None: + ... + + +class Latent(Protocol): + def __call__(self) -> torch.Tensor: + ... + + def parameters(self) -> Iterable[torch.Tensor]: + ... + + def reset_state(self) -> None: + ... + + +class Critic(Protocol): + def __call__(self, features: FeatureType, target_features: FeatureType) -> torch.Tensor: + ... From ab354a2d98db59413772caf2fe60142e19fbbd08 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 19:58:07 +0900 Subject: [PATCH 005/117] add encoder module --- bdpy/recon/torch/modules/__init__.py | 0 bdpy/recon/torch/modules/encoder.py | 83 ++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 bdpy/recon/torch/modules/__init__.py create mode 100644 bdpy/recon/torch/modules/encoder.py diff --git a/bdpy/recon/torch/modules/__init__.py b/bdpy/recon/torch/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py new file mode 100644 index 00000000..accfebc2 --- /dev/null +++ b/bdpy/recon/torch/modules/encoder.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Iterable + +import torch +import torch.nn as nn +from bdpy.dl.torch import FeatureExtractor +from bdpy.dl.torch.stimulus_domain import Domain, image_domain + + +class EncoderBase(nn.Module): + """Encoder network module. + + Parameters + ---------- + feature_network : nn.Module + Feature network. This network should have a method `forward` that takes + an image tensor and propagates it through the network. + layer_names : list[str] + Layer names to extract features from. + domain : Domain + Domain of the input images to receive. + device : torch.device + Device to use. + """ + + def __init__( + self, + feature_network: nn.Module, + layer_names: Iterable[str], + domain: Domain, + device: str | torch.device, + ) -> None: + super().__init__() + self._feature_extractor = FeatureExtractor( + encoder=feature_network, layers=layer_names, detach=False, device=device + ) + self._domain = domain + self._feature_network = self._feature_extractor._encoder + + def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + """Forward pass through the encoder network. + + Parameters + ---------- + images : torch.Tensor + Images. + + Returns + ------- + dict[str, torch.Tensor] + Features indexed by the layer names. + """ + images = self._domain.receive(images) + return self._feature_extractor(images) + + +def build_encoder( + feature_network: nn.Module, + layer_names: Iterable[str], + domain: Domain = image_domain.Zero2OneImageDomain(), + device: str | torch.device = "cpu", +) -> EncoderBase: + """Build an encoder network. + + Parameters + ---------- + feature_network : nn.Module + Feature network. This network should have a method `forward` that takes + an image tensor and propagates it through the network. + layer_names : list[str] + Layer names to extract features from. + domain : Domain, optional + Domain of the input images to receive (default: Zero2OneImageDomain()). + device : torch.device, optional + Device to use. (default: "cpu"). + + Returns + ------- + EncoderBase + Encoder network. + """ + return EncoderBase(feature_network, layer_names, domain, device) From 083369fda016129911e7abce31b096e1c0b99b1b Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 19:58:15 +0900 Subject: [PATCH 006/117] refactor type annotations --- bdpy/dl/torch/torch.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/bdpy/dl/torch/torch.py b/bdpy/dl/torch/torch.py index 5b927d30..7d325248 100644 --- a/bdpy/dl/torch/torch.py +++ b/bdpy/dl/torch/torch.py @@ -1,5 +1,7 @@ '''PyTorch module.''' +from __future__ import annotations + from typing import Iterable, List, Dict, Union, Tuple, Any, Callable, Optional import os @@ -56,10 +58,10 @@ def __init__( layer_object = models._parse_layer_name(self._encoder, layer) layer_object.register_forward_hook(self._extractor) - def __call__(self, x: _tensor_t) -> Dict[str, _tensor_t]: + def __call__(self, x: _tensor_t) -> Dict[str, np.ndarray] | Dict[str, torch.Tensor]: return self.run(x) - def run(self, x: _tensor_t) -> Dict[str, _tensor_t]: + def run(self, x: _tensor_t) -> Dict[str, np.ndarray] | Dict[str, torch.Tensor]: '''Extract feature activations from the specified layers. Parameters @@ -82,17 +84,17 @@ def run(self, x: _tensor_t) -> Dict[str, _tensor_t]: self._encoder.forward(xt) - features: Dict[str, _tensor_t] = { + features: Dict[str, torch.Tensor] = { layer: self._extractor.outputs[i] for i, layer in enumerate(self.__layers) } - if self.__detach: - features = { - k: v.cpu().detach().numpy() - for k, v in features.items() - } + if not self.__detach: + return features - return features + return { + k: v.cpu().detach().numpy() + for k, v in features.items() + } class FeatureExtractorHandle(object): From b29f79f2d595a0a15c44e0d84f0280c9eade1d76 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 20:34:44 +0900 Subject: [PATCH 007/117] update encoder interface --- bdpy/recon/torch/modules/encoder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index accfebc2..94e9b122 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -18,18 +18,18 @@ class EncoderBase(nn.Module): an image tensor and propagates it through the network. layer_names : list[str] Layer names to extract features from. - domain : Domain - Domain of the input images to receive. - device : torch.device - Device to use. + domain : Domain, optional + Domain of the input images to receive. (default: Zero2OneImageDomain()) + device : torch.device, optional + Device to use. (default: "cpu"). """ def __init__( self, feature_network: nn.Module, layer_names: Iterable[str], - domain: Domain, - device: str | torch.device, + domain: Domain = image_domain.Zero2OneImageDomain(), + device: str | torch.device = "cpu", ) -> None: super().__init__() self._feature_extractor = FeatureExtractor( From 5d81a82bc95aeaa52772719328449b4923727095 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 20:34:52 +0900 Subject: [PATCH 008/117] add generator module --- bdpy/recon/torch/interface.py | 4 +- bdpy/recon/torch/modules/generator.py | 245 ++++++++++++++++++++++++++ 2 files changed, 247 insertions(+), 2 deletions(-) create mode 100644 bdpy/recon/torch/modules/generator.py diff --git a/bdpy/recon/torch/interface.py b/bdpy/recon/torch/interface.py index ea86948f..58b2c2ef 100644 --- a/bdpy/recon/torch/interface.py +++ b/bdpy/recon/torch/interface.py @@ -16,7 +16,7 @@ class Generator(Protocol): def __call__(self, latent: torch.Tensor) -> torch.Tensor: ... - def parameters(self) -> Iterable[torch.Tensor]: + def parameters(self, recurse: bool = True) -> Iterable[torch.Tensor]: ... def reset_state(self) -> None: @@ -27,7 +27,7 @@ class Latent(Protocol): def __call__(self) -> torch.Tensor: ... - def parameters(self) -> Iterable[torch.Tensor]: + def parameters(self, recurse: bool = True) -> Iterable[torch.Tensor]: ... def reset_state(self) -> None: diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py new file mode 100644 index 00000000..54c9d958 --- /dev/null +++ b/bdpy/recon/torch/modules/generator.py @@ -0,0 +1,245 @@ +from abc import ABC, abstractmethod + +from typing import Callable, Iterator + +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter +from bdpy.dl.torch.stimulus_domain import Domain, image_domain + + +@torch.no_grad() +def reset_all_parameters(module: nn.Module) -> None: + """Reset the parameters of the module.""" + reset_parameters = getattr(module, "reset_parameters", None) + if callable(reset_parameters): + module.reset_parameters() + + +class GeneratorBase(nn.Module, ABC): + """Generator module.""" + + @abstractmethod + def reset_states(self) -> None: + """Reset the state of the generator.""" + pass + + @abstractmethod + def forward(self, latent: torch.Tensor) -> torch.Tensor: + """Forward pass through the generator network. + + Parameters + ---------- + latent : torch.Tensor + Latent vector. + + Returns + ------- + torch.Tensor + Generated image. The generated images must be in the range [0, 1]. + """ + pass + + +class BareGenerator(GeneratorBase): + """Bare generator module. + + This module does not have any trainable parameters. + + Parameters + ---------- + activation : Callable[[torch.Tensor], torch.Tensor], optional + Activation function to apply to the output of the generator, by default nn.Identity() + + Examples + -------- + >>> import torch + >>> from bdpy.recon.torch.modules.generator import BareGenerator + >>> generator = BareGenerator(activation=torch.sigmoid) + >>> latent = torch.randn(1, 3, 64, 64) + >>> generated_image = generator(latent) + >>> generated_image.shape + torch.Size([1, 3, 64, 64]) + """ + + def __init__(self, activation: Callable[[torch.Tensor], torch.Tensor] = nn.Identity()) -> None: + """Initialize the generator.""" + super().__init__() + self._activation = activation + self._domain = image_domain.Zero2OneImageDomain() + + def reset_states(self) -> None: + """Reset the state of the generator.""" + pass + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + """Forward pass through the generator network. + + Parameters + ---------- + latent : torch.Tensor + Latent vector. + + Returns + ------- + torch.Tensor + Generated image. The generated images must be in the range [0, 1]. + """ + return self._domain.send(self._activation(latent)) + + +class DNNGenerator(GeneratorBase): + """DNN generator module. + + This module has the generator network as a submodule and its parameters are + trainable. + + Parameters + ---------- + generator_network : nn.Module + Generator network. This network should have a method `forward` that takes + a latent vector and propagates it through the network. + domain : Domain, optional + Domain of the input images to receive. (default: Zero2OneImageDomain()) + reset_fn : Callable[[nn.Module], None], optional + Function to reset the parameters of the generator network, by default + reset_all_parameters. + + Examples + -------- + >>> import torch + >>> from bdpy.recon.torch.modules.generator import DNNGenerator + >>> generator_network = nn.Sequential( + ... nn.ConvTranspose2d(3, 3, 3), + ... nn.ReLU(), + ... ) + >>> generator = DNNGenerator(generator_network) + >>> latent = torch.randn(1, 3, 64, 64) + >>> generated_image = generator(latent) + >>> generated_image.shape + torch.Size([1, 3, 66, 66]) + """ + + def __init__( + self, + generator_network: nn.Module, + domain: Domain = image_domain.Zero2OneImageDomain(), + reset_fn: Callable[[nn.Module], None] = reset_all_parameters, + ) -> None: + """Initialize the generator.""" + super().__init__() + self._generator_network = generator_network + self._domain = domain + self._reset_fn = reset_fn + + def reset_states(self) -> None: + """Reset the state of the generator.""" + self._generator_network.apply(self._reset_fn) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + """Forward pass through the generator network. + + Parameters + ---------- + latent : torch.Tensor + Latent vector. + + Returns + ------- + torch.Tensor + Generated image. The generated images must be in the range [0, 1]. + """ + return self._domain.send(self._generator_network(latent)) + + +class FrozenGenerator(DNNGenerator): + """Frozen generator module. + + This module has the generator network as a submodule and its parameters are + frozen. + + Parameters + ---------- + generator_network : nn.Module + Generator network. This network should have a method `forward` that takes + a latent vector and propagates it through the network. + domain : Domain, optional + Domain of the input images to receive. (default: Zero2OneImageDomain()) + + Examples + -------- + >>> import torch + >>> from bdpy.recon.torch.modules.generator import FrozenGenerator + >>> generator_network = nn.Sequential( + ... nn.ConvTranspose2d(3, 3, 3), + ... nn.ReLU(), + ... ) + >>> generator = FrozenGenerator(generator_network) + >>> latent = torch.randn(1, 3, 64, 64) + >>> generated_image = generator(latent) + >>> generated_image.shape + torch.Size([1, 3, 66, 66]) + """ + + def __init__( + self, + generator_network: nn.Module, + domain: Domain = image_domain.Zero2OneImageDomain() + ) -> None: + """Initialize the generator.""" + super().__init__(generator_network, domain=domain, reset_fn=lambda _: None) + for param in self._generator_network.parameters(): + param.requires_grad = False + + def reset_states(self) -> None: + """Reset the state of the generator.""" + pass + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + return iter([]) + + +def build_generator( + generator_network: nn.Module, + domain: Domain = image_domain.Zero2OneImageDomain(), + reset_fn: Callable[[nn.Module], None] = reset_all_parameters, + frozen: bool = True, +) -> GeneratorBase: + """Build a generator module. + + Parameters + ---------- + generator_network : nn.Module + Generator network. This network should have a method `forward` that takes + a latent vector and propagates it through the network. + domain : Domain, optional + Domain of the input images to receive. (default: Zero2OneImageDomain()) + reset_fn : Callable[[nn.Module], None], optional + Function to reset the parameters of the generator network, by default + reset_all_parameters. + frozen : bool, optional + Whether to freeze the parameters of the generator network, by default True. + + Returns + ------- + GeneratorBase + Generator module. + + Examples + -------- + >>> import torch + >>> from bdpy.recon.torch.modules.generator import build_generator + >>> generator_network = nn.Sequential( + ... nn.ConvTranspose2d(3, 3, 3), + ... nn.ReLU(), + ... ) + >>> generator = build_generator(generator_network) + >>> latent = torch.randn(1, 3, 64, 64) + >>> generated_image = generator(latent) + >>> generated_image.shape + torch.Size([1, 3, 66, 66]) + """ + if frozen: + return FrozenGenerator(generator_network, domain=domain) + else: + return DNNGenerator(generator_network, domain=domain, reset_fn=reset_fn) From 524367335f7c49deff232b157255ebc33d239ea4 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 14 Dec 2023 20:37:53 +0900 Subject: [PATCH 009/117] update generator module --- bdpy/recon/torch/modules/generator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 54c9d958..ff30797e 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -188,8 +188,7 @@ def __init__( ) -> None: """Initialize the generator.""" super().__init__(generator_network, domain=domain, reset_fn=lambda _: None) - for param in self._generator_network.parameters(): - param.requires_grad = False + self._generator_network.eval() def reset_states(self) -> None: """Reset the state of the generator.""" From 2e0db92f08d4a373e2cba778bf55e1052a9f0d74 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 11:26:36 +0900 Subject: [PATCH 010/117] update docstring --- bdpy/recon/torch/modules/encoder.py | 30 +++++++++++++++++++++++++++ bdpy/recon/torch/modules/generator.py | 3 +++ 2 files changed, 33 insertions(+) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 94e9b122..ae84e7c5 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -22,6 +22,21 @@ class EncoderBase(nn.Module): Domain of the input images to receive. (default: Zero2OneImageDomain()) device : torch.device, optional Device to use. (default: "cpu"). + + Examples + -------- + >>> import torch + >>> import torch.nn as nn + >>> from bdpy.recon.torch.modules.encoder import EncoderBase + >>> feature_network = nn.Sequential( + ... nn.Conv2d(3, 3, 3), + ... nn.ReLU(), + ... ) + >>> encoder = EncoderBase(feature_network, ['0']) + >>> image = torch.randn(1, 3, 64, 64) + >>> features = encoder(image) + >>> features['0'].shape + torch.Size([1, 3, 62, 62]) """ def __init__( @@ -79,5 +94,20 @@ def build_encoder( ------- EncoderBase Encoder network. + + Examples + -------- + >>> import torch + >>> import torch.nn as nn + >>> from bdpy.recon.torch.modules.encoder import build_encoder + >>> feature_network = nn.Sequential( + ... nn.Conv2d(3, 3, 3), + ... nn.ReLU(), + ... ) + >>> encoder = build_encoder(feature_network, ['0']) + >>> image = torch.randn(1, 3, 64, 64) + >>> features = encoder(image) + >>> features['0'].shape + torch.Size([1, 3, 62, 62]) """ return EncoderBase(feature_network, layer_names, domain, device) diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index ff30797e..e4df390a 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -108,6 +108,7 @@ class DNNGenerator(GeneratorBase): Examples -------- >>> import torch + >>> import torch.nn as nn >>> from bdpy.recon.torch.modules.generator import DNNGenerator >>> generator_network = nn.Sequential( ... nn.ConvTranspose2d(3, 3, 3), @@ -169,6 +170,7 @@ class FrozenGenerator(DNNGenerator): Examples -------- >>> import torch + >>> import torch.nn as nn >>> from bdpy.recon.torch.modules.generator import FrozenGenerator >>> generator_network = nn.Sequential( ... nn.ConvTranspose2d(3, 3, 3), @@ -227,6 +229,7 @@ def build_generator( Examples -------- >>> import torch + >>> import torch.nn as nn >>> from bdpy.recon.torch.modules.generator import build_generator >>> generator_network = nn.Sequential( ... nn.ConvTranspose2d(3, 3, 3), From 463f7b555a6ce0bb8ecca737c787897d6b058d21 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 11:37:58 +0900 Subject: [PATCH 011/117] add latent module --- bdpy/recon/torch/modules/latent.py | 69 ++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 bdpy/recon/torch/modules/latent.py diff --git a/bdpy/recon/torch/modules/latent.py b/bdpy/recon/torch/modules/latent.py new file mode 100644 index 00000000..00ac3377 --- /dev/null +++ b/bdpy/recon/torch/modules/latent.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable + +import torch +import torch.nn as nn + + +class LatentBase(nn.Module, ABC): + """Latent variable module.""" + + @abstractmethod + def reset_states(self) -> None: + """Reset the state of the latent variable.""" + pass + + @abstractmethod + def forward(self) -> torch.Tensor: + """Generate a latent variable. + + Returns + ------- + torch.Tensor + Latent variable. + """ + pass + + +class ArbitraryLatent(LatentBase): + """Latent variable with arbitrary shape and initialization function. + + Parameters + ---------- + shape : tuple[int, ...] + Shape of the latent variable including the batch dimension. + init_fn : Callable[[torch.Tensor], None] + Function to initialize the latent variable. + + Examples + -------- + >>> from functools import partial + >>> import torch + >>> import torch.nn as nn + >>> from bdpy.recon.torch.modules.latent import ArbitraryLatent + >>> latent = ArbitraryLatent((1, 3, 64, 64), partial(nn.init.normal_, mean=0, std=1)) + >>> latent().shape + torch.Size([1, 3, 64, 64]) + """ + + def __init__(self, shape: tuple[int, ...], init_fn: Callable[[torch.Tensor], None]) -> None: + super().__init__() + self._shape = shape + self._init_fn = init_fn + self._latent = torch.empty(shape) + + def reset_states(self) -> None: + """Reset the state of the latent variable.""" + self._init_fn(self._latent) + + def forward(self) -> torch.Tensor: + """Generate a latent variable. + + Returns + ------- + torch.Tensor + Latent variable. + """ + return self._latent From 3e50a3cb7b50fadc4d1826e984bdcf5d671efdcf Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 11:41:21 +0900 Subject: [PATCH 012/117] remove unused import --- bdpy/recon/torch/interface.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bdpy/recon/torch/interface.py b/bdpy/recon/torch/interface.py index 58b2c2ef..75d52cc6 100644 --- a/bdpy/recon/torch/interface.py +++ b/bdpy/recon/torch/interface.py @@ -2,7 +2,6 @@ if TYPE_CHECKING: import torch - import torch.nn as nn FeatureType = Dict[str, torch.Tensor] From c8f99c35e2c030adf9f035830243cc1e28f72c37 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 11:45:23 +0900 Subject: [PATCH 013/117] update encoder interface --- bdpy/recon/torch/modules/encoder.py | 33 +++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index ae84e7c5..7da42d74 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from typing import Iterable import torch @@ -8,8 +9,28 @@ from bdpy.dl.torch.stimulus_domain import Domain, image_domain -class EncoderBase(nn.Module): - """Encoder network module. +class EncoderBase(nn.Module, ABC): + """Encoder network module.""" + + @abstractmethod + def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + """Forward pass through the encoder network. + + Parameters + ---------- + images : torch.Tensor + Images. + + Returns + ------- + dict[str, torch.Tensor] + Features indexed by the layer names. + """ + pass + + +class SimpleEncoder(EncoderBase): + """Encoder network module with a naive feature extractor. Parameters ---------- @@ -27,12 +48,12 @@ class EncoderBase(nn.Module): -------- >>> import torch >>> import torch.nn as nn - >>> from bdpy.recon.torch.modules.encoder import EncoderBase + >>> from bdpy.recon.torch.modules.encoder import SimpleEncoder >>> feature_network = nn.Sequential( ... nn.Conv2d(3, 3, 3), ... nn.ReLU(), ... ) - >>> encoder = EncoderBase(feature_network, ['0']) + >>> encoder = SimpleEncoder(feature_network, ['0']) >>> image = torch.randn(1, 3, 64, 64) >>> features = encoder(image) >>> features['0'].shape @@ -76,7 +97,7 @@ def build_encoder( domain: Domain = image_domain.Zero2OneImageDomain(), device: str | torch.device = "cpu", ) -> EncoderBase: - """Build an encoder network. + """Build an encoder network with a naive feature extractor. Parameters ---------- @@ -110,4 +131,4 @@ def build_encoder( >>> features['0'].shape torch.Size([1, 3, 62, 62]) """ - return EncoderBase(feature_network, layer_names, domain, device) + return SimpleEncoder(feature_network, layer_names, domain, device) From 5e9e77c042a9407e1376475fccba37d0e990c2b7 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 11:46:21 +0900 Subject: [PATCH 014/117] update docstring --- bdpy/recon/torch/modules/encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 7da42d74..2a62c823 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -125,7 +125,7 @@ def build_encoder( ... nn.Conv2d(3, 3, 3), ... nn.ReLU(), ... ) - >>> encoder = build_encoder(feature_network, ['0']) + >>> encoder = build_encoder(feature_network, layer_names=['0']) >>> image = torch.randn(1, 3, 64, 64) >>> features = encoder(image) >>> features['0'].shape From 1f6032fde17786f6af253c7007af32fc92495c8f Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 11:57:06 +0900 Subject: [PATCH 015/117] add critic module --- bdpy/recon/torch/modules/critic.py | 90 ++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 bdpy/recon/torch/modules/critic.py diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py new file mode 100644 index 00000000..73c058f0 --- /dev/null +++ b/bdpy/recon/torch/modules/critic.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + + +class CriticBase(nn.Module, ABC): + """Critic network module.""" + + @abstractmethod + def criterion( + self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str + ) -> torch.Tensor: + """Loss function per layer. + + Parameters + ---------- + feature : torch.Tensor + Feature tensor of the layer specified by `layer_name`. + target_feature : torch.Tensor + Target feature tensor of the layer specified by `layer_name`. + layer_name : str + Layer name. + + Returns + ------- + torch.Tensor + Loss value of the layer specified by `layer_name`. + """ + pass + + def forward( + self, + features: dict[str, torch.Tensor], + target_features: dict[str, torch.Tensor], + ) -> torch.Tensor: + """Forward pass through the critic network. + + Parameters + ---------- + features : dict[str, torch.Tensor] + Features indexed by the layer names. + target_features : dict[str, torch.Tensor] + Target features indexed by the layer names. + + Returns + ------- + torch.Tensor + Loss value. + """ + loss = 0.0 + counts = 0 + for layer_name, feature in features.items(): + target_feature = target_features[layer_name] + layer_wise_loss = self.criterion( + feature, target_feature, layer_name=layer_name + ) + loss += layer_wise_loss + counts += 1 + return loss / counts + + +class TargetNormalizedMSE(CriticBase): + """MSE loss divided by the squared norm of the target feature.""" + + def criterion( + self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str + ) -> torch.Tensor: + """Loss function per layer. + + Parameters + ---------- + feature : torch.Tensor + Feature tensor of the layer specified by `layer_name`. + target_feature : torch.Tensor + Target feature tensor of the layer specified by `layer_name`. + layer_name : str + Layer name. + + Returns + ------- + torch.Tensor + Loss value of the layer specified by `layer_name`. + """ + squared_norm = (target_feature ** 2).sum(dim=tuple(range(1, target_feature.ndim))) + return ((feature - target_feature) ** 2).sum( + dim=tuple(range(1, feature.ndim)) + ) / squared_norm From 07196b39ec19279ef0b099cfd7cfe680e98b21d7 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 12:15:24 +0900 Subject: [PATCH 016/117] update docstring --- bdpy/recon/torch/modules/generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index e4df390a..be5a070c 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -197,6 +197,7 @@ def reset_states(self) -> None: pass def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + """Return an empty iterator.""" return iter([]) From c780262ab14cd210c6b815c4983fa6774764c806 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 12:15:40 +0900 Subject: [PATCH 017/117] add feature inversion pipeline --- bdpy/recon/torch/modules/__init__.py | 5 ++ bdpy/recon/torch/pipeline/__init__.py | 0 bdpy/recon/torch/pipeline/inversion.py | 120 +++++++++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 bdpy/recon/torch/pipeline/__init__.py create mode 100644 bdpy/recon/torch/pipeline/inversion.py diff --git a/bdpy/recon/torch/modules/__init__.py b/bdpy/recon/torch/modules/__init__.py index e69de29b..da4456a0 100644 --- a/bdpy/recon/torch/modules/__init__.py +++ b/bdpy/recon/torch/modules/__init__.py @@ -0,0 +1,5 @@ +from .encoder import build_encoder +from .generator import build_generator +from .latent import ArbitraryLatent +from .critic import TargetNormalizedMSE + diff --git a/bdpy/recon/torch/pipeline/__init__.py b/bdpy/recon/torch/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py new file mode 100644 index 00000000..e20f0a48 --- /dev/null +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -0,0 +1,120 @@ +from typing import Dict + +from itertools import chain + +import torch + +from ..interface import Encoder, Generator, Latent, Critic + +FeatureType = Dict[str, torch.Tensor] + + +class FeatureInversionPipeline: + """Feature inversion pipeline. + + Parameters + ---------- + encoder : Encoder + Encoder module. + generator : Generator + Generator module. + latent : Latent + Latent variable module. + critic : Critic + Critic module. + optimizer : torch.optim.Optimizer + Optimizer. + scheduler : torch.optim.lr_scheduler.LRScheduler, optional + Learning rate scheduler, by default None. + num_iterations : int, optional + Number of iterations, by default 1. + log_interval : int, optional + Log interval, by default -1. If -1, logging is disabled. + + Examples + -------- + >>> import torch + >>> import torch.nn as nn + >>> from bdpy.recon.torch.pipeline import FeatureInversionPipeline + >>> from bdpy.recon.torch.modules import build_encoder, build_generator, ArbitraryLatent, TargetNormalizedMSE + >>> encoder = build_encoder(...) + >>> generator = build_generator(...) + >>> latent = ArbitraryLatent(...) + >>> critic = TargetNormalizedMSE(...) + >>> optimizer = torch.optim.Adam(latent.parameters()) + >>> pipeline = FeatureInversionPipeline( + ... encoder, generator, latent, critic, optimizer + ... ) + >>> target_features = encoder(target_image) + >>> pipeline.reset_state() + >>> reconstructed_image = pipeline(target_features) + """ + + def __init__( + self, + encoder: Encoder, + generator: Generator, + latent: Latent, + critic: Critic, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler = None, + num_iterations: int = 1, + log_interval: int = -1, + ) -> None: + self._encoder = encoder + self._generator = generator + self._latent = latent + self._critic = critic + self._optimizer = optimizer + self._scheduler = scheduler + + self._num_iterations = num_iterations + self._log_interval = log_interval + + def __call__( + self, + target_features: FeatureType, + ) -> torch.Tensor: + """Run feature inversion given target features. + + Parameters + ---------- + target_features : FeatureType + Target features. + + Returns + ------- + torch.Tensor + Reconstructed images which have the similar features to the target features. + """ + for step in range(self._num_iterations): + self._optimizer.zero_grad() + + latent = self._latent() + generated_image = self._generator(latent) + + features = self._encoder(generated_image) + + loss = self._critic(features, target_features) + loss.backward() + + self._optimizer.step() + if self._scheduler is not None: + self._scheduler.step() + + if self._log_interval > 0 and step % self._log_interval == 0: + print(f"Step: [{step+1}/{self._num_iterations}], Loss: {loss.item():.4f}") + + return self._generator(self._latent()).detach() + + def reset_state(self) -> None: + """Reset the state of the pipeline.""" + self._generator.reset_state() + self._latent.reset_state() + self._optimizer = self._optimizer.__class__( + chain( + self._generator.parameters(), + self._latent.parameters(), + ), + **self._optimizer.defaults + ) From b753140e086b45dfc56ee378143887baaa7f38de Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 12:18:38 +0900 Subject: [PATCH 018/117] update imports --- bdpy/recon/torch/pipeline/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bdpy/recon/torch/pipeline/__init__.py b/bdpy/recon/torch/pipeline/__init__.py index e69de29b..0569d5ab 100644 --- a/bdpy/recon/torch/pipeline/__init__.py +++ b/bdpy/recon/torch/pipeline/__init__.py @@ -0,0 +1 @@ +from .inversion import FeatureInversionPipeline \ No newline at end of file From 7a361df09e8fef890dc1a8d9868f6352d84bc4da Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 15:23:56 +0900 Subject: [PATCH 019/117] rename XYZBase -> BaseXYZ --- bdpy/recon/torch/modules/critic.py | 4 ++-- bdpy/recon/torch/modules/encoder.py | 8 ++++---- bdpy/recon/torch/modules/generator.py | 10 +++++----- bdpy/recon/torch/modules/latent.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index 73c058f0..7567a6dd 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -6,7 +6,7 @@ import torch.nn as nn -class CriticBase(nn.Module, ABC): +class BaseCritic(nn.Module, ABC): """Critic network module.""" @abstractmethod @@ -62,7 +62,7 @@ def forward( return loss / counts -class TargetNormalizedMSE(CriticBase): +class TargetNormalizedMSE(BaseCritic): """MSE loss divided by the squared norm of the target feature.""" def criterion( diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 2a62c823..fb18f91b 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -9,7 +9,7 @@ from bdpy.dl.torch.stimulus_domain import Domain, image_domain -class EncoderBase(nn.Module, ABC): +class BaseEncoder(nn.Module, ABC): """Encoder network module.""" @abstractmethod @@ -29,7 +29,7 @@ def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: pass -class SimpleEncoder(EncoderBase): +class SimpleEncoder(BaseEncoder): """Encoder network module with a naive feature extractor. Parameters @@ -96,7 +96,7 @@ def build_encoder( layer_names: Iterable[str], domain: Domain = image_domain.Zero2OneImageDomain(), device: str | torch.device = "cpu", -) -> EncoderBase: +) -> BaseEncoder: """Build an encoder network with a naive feature extractor. Parameters @@ -113,7 +113,7 @@ def build_encoder( Returns ------- - EncoderBase + BaseEncoder Encoder network. Examples diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index be5a070c..215c3801 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -16,7 +16,7 @@ def reset_all_parameters(module: nn.Module) -> None: module.reset_parameters() -class GeneratorBase(nn.Module, ABC): +class BaseGenerator(nn.Module, ABC): """Generator module.""" @abstractmethod @@ -41,7 +41,7 @@ def forward(self, latent: torch.Tensor) -> torch.Tensor: pass -class BareGenerator(GeneratorBase): +class BareGenerator(BaseGenerator): """Bare generator module. This module does not have any trainable parameters. @@ -88,7 +88,7 @@ def forward(self, latent: torch.Tensor) -> torch.Tensor: return self._domain.send(self._activation(latent)) -class DNNGenerator(GeneratorBase): +class DNNGenerator(BaseGenerator): """DNN generator module. This module has the generator network as a submodule and its parameters are @@ -206,7 +206,7 @@ def build_generator( domain: Domain = image_domain.Zero2OneImageDomain(), reset_fn: Callable[[nn.Module], None] = reset_all_parameters, frozen: bool = True, -) -> GeneratorBase: +) -> BaseGenerator: """Build a generator module. Parameters @@ -224,7 +224,7 @@ def build_generator( Returns ------- - GeneratorBase + BaseGenerator Generator module. Examples diff --git a/bdpy/recon/torch/modules/latent.py b/bdpy/recon/torch/modules/latent.py index 00ac3377..d73bade1 100644 --- a/bdpy/recon/torch/modules/latent.py +++ b/bdpy/recon/torch/modules/latent.py @@ -7,7 +7,7 @@ import torch.nn as nn -class LatentBase(nn.Module, ABC): +class BaseLatent(nn.Module, ABC): """Latent variable module.""" @abstractmethod @@ -27,7 +27,7 @@ def forward(self) -> torch.Tensor: pass -class ArbitraryLatent(LatentBase): +class ArbitraryLatent(BaseLatent): """Latent variable with arbitrary shape and initialization function. Parameters From 8bb09b6c47d14ee0f4d7c78e287041cb52188bf8 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 15:59:43 +0900 Subject: [PATCH 020/117] Remove dependency on nn.Module from base class && use base class for type annotation --- bdpy/recon/torch/interface.py | 38 ------------------------ bdpy/recon/torch/modules/__init__.py | 9 +++--- bdpy/recon/torch/modules/critic.py | 40 +++++++++++++++++++++++--- bdpy/recon/torch/modules/encoder.py | 6 ++-- bdpy/recon/torch/modules/generator.py | 23 ++++++++++++--- bdpy/recon/torch/modules/latent.py | 23 ++++++++++++--- bdpy/recon/torch/pipeline/inversion.py | 18 ++++++------ 7 files changed, 90 insertions(+), 67 deletions(-) delete mode 100644 bdpy/recon/torch/interface.py diff --git a/bdpy/recon/torch/interface.py b/bdpy/recon/torch/interface.py deleted file mode 100644 index 75d52cc6..00000000 --- a/bdpy/recon/torch/interface.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Dict, Protocol, Iterable, TYPE_CHECKING - -if TYPE_CHECKING: - import torch - - FeatureType = Dict[str, torch.Tensor] - - -class Encoder(Protocol): - def __call__(self, image: torch.Tensor) -> FeatureType: - ... - - -class Generator(Protocol): - def __call__(self, latent: torch.Tensor) -> torch.Tensor: - ... - - def parameters(self, recurse: bool = True) -> Iterable[torch.Tensor]: - ... - - def reset_state(self) -> None: - ... - - -class Latent(Protocol): - def __call__(self) -> torch.Tensor: - ... - - def parameters(self, recurse: bool = True) -> Iterable[torch.Tensor]: - ... - - def reset_state(self) -> None: - ... - - -class Critic(Protocol): - def __call__(self, features: FeatureType, target_features: FeatureType) -> torch.Tensor: - ... diff --git a/bdpy/recon/torch/modules/__init__.py b/bdpy/recon/torch/modules/__init__.py index da4456a0..7b37bc7e 100644 --- a/bdpy/recon/torch/modules/__init__.py +++ b/bdpy/recon/torch/modules/__init__.py @@ -1,5 +1,4 @@ -from .encoder import build_encoder -from .generator import build_generator -from .latent import ArbitraryLatent -from .critic import TargetNormalizedMSE - +from .encoder import build_encoder, BaseEncoder +from .generator import build_generator, BaseGenerator +from .latent import ArbitraryLatent, BaseLatent +from .critic import TargetNormalizedMSE, BaseCritic diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index 7567a6dd..bce0b68b 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -1,14 +1,35 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Dict import torch import torch.nn as nn -class BaseCritic(nn.Module, ABC): +_FeatureType = Dict[str, torch.Tensor] + + +class BaseCritic(ABC): """Critic network module.""" + @abstractmethod + def __call__(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: + """Compute the total loss value given the features and the target features. + + Parameters + ---------- + features : dict[str, torch.Tensor] + Features indexed by the layer names. + target_features : dict[str, torch.Tensor] + Target features indexed by the layer names. + + Returns + ------- + torch.Tensor + Loss value. + """ + @abstractmethod def criterion( self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str @@ -31,10 +52,21 @@ def criterion( """ pass + +class NNModuleCritic(BaseCritic, nn.Module): + """Critic network module uses __call__ method of nn.Module.""" + + def __call__(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: + return nn.Module.__call__(self, features, target_features) + + +class LayerWiseAverageCritic(NNModuleCritic): + """Compute the average of the layer-wise loss values.""" + def forward( self, - features: dict[str, torch.Tensor], - target_features: dict[str, torch.Tensor], + features: _FeatureType, + target_features: _FeatureType, ) -> torch.Tensor: """Forward pass through the critic network. @@ -62,7 +94,7 @@ def forward( return loss / counts -class TargetNormalizedMSE(BaseCritic): +class TargetNormalizedMSE(LayerWiseAverageCritic): """MSE loss divided by the squared norm of the target feature.""" def criterion( diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index fb18f91b..10cf3aac 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -9,11 +9,11 @@ from bdpy.dl.torch.stimulus_domain import Domain, image_domain -class BaseEncoder(nn.Module, ABC): +class BaseEncoder(ABC): """Encoder network module.""" @abstractmethod - def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: """Forward pass through the encoder network. Parameters @@ -74,7 +74,7 @@ def __init__( self._domain = domain self._feature_network = self._feature_extractor._encoder - def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: """Forward pass through the encoder network. Parameters diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 215c3801..29f623b4 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -16,7 +16,7 @@ def reset_all_parameters(module: nn.Module) -> None: module.reset_parameters() -class BaseGenerator(nn.Module, ABC): +class BaseGenerator(ABC): """Generator module.""" @abstractmethod @@ -25,7 +25,12 @@ def reset_states(self) -> None: pass @abstractmethod - def forward(self, latent: torch.Tensor) -> torch.Tensor: + def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: + """Return the parameters of the generator.""" + pass + + @abstractmethod + def __call__(self, latent: torch.Tensor) -> torch.Tensor: """Forward pass through the generator network. Parameters @@ -41,7 +46,17 @@ def forward(self, latent: torch.Tensor) -> torch.Tensor: pass -class BareGenerator(BaseGenerator): +class NNModuleGenerator(BaseGenerator, nn.Module): + """Generator module uses __call__ method and parameters method of nn.Module.""" + + def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: + return nn.Module.parameters(self, recurse=recurse) + + def __call__(self, latent: torch.Tensor) -> torch.Tensor: + return nn.Module.__call__(self, latent) + + +class BareGenerator(NNModuleGenerator): """Bare generator module. This module does not have any trainable parameters. @@ -88,7 +103,7 @@ def forward(self, latent: torch.Tensor) -> torch.Tensor: return self._domain.send(self._activation(latent)) -class DNNGenerator(BaseGenerator): +class DNNGenerator(NNModuleGenerator): """DNN generator module. This module has the generator network as a submodule and its parameters are diff --git a/bdpy/recon/torch/modules/latent.py b/bdpy/recon/torch/modules/latent.py index d73bade1..823a5e5d 100644 --- a/bdpy/recon/torch/modules/latent.py +++ b/bdpy/recon/torch/modules/latent.py @@ -1,13 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Iterator import torch import torch.nn as nn -class BaseLatent(nn.Module, ABC): +class BaseLatent(ABC): """Latent variable module.""" @abstractmethod @@ -16,7 +16,12 @@ def reset_states(self) -> None: pass @abstractmethod - def forward(self) -> torch.Tensor: + def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: + """Return the parameters of the latent variable.""" + pass + + @abstractmethod + def __call__(self) -> torch.Tensor: """Generate a latent variable. Returns @@ -27,7 +32,17 @@ def forward(self) -> torch.Tensor: pass -class ArbitraryLatent(BaseLatent): +class NNModuleLatent(BaseLatent, nn.Module): + """Latent variable module uses __call__ method and parameters method of nn.Module.""" + + def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: + return nn.Module.parameters(self, recurse=recurse) + + def __call__(self) -> torch.Tensor: + return nn.Module.__call__(self) + + +class ArbitraryLatent(NNModuleLatent): """Latent variable with arbitrary shape and initialization function. Parameters diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index e20f0a48..7d2afde8 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -4,7 +4,7 @@ import torch -from ..interface import Encoder, Generator, Latent, Critic +from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic FeatureType = Dict[str, torch.Tensor] @@ -14,13 +14,13 @@ class FeatureInversionPipeline: Parameters ---------- - encoder : Encoder + encoder : BaseEncoder Encoder module. - generator : Generator + generator : BaseGenerator Generator module. - latent : Latent + latent : BaseLatent Latent variable module. - critic : Critic + critic : BaseCritic Critic module. optimizer : torch.optim.Optimizer Optimizer. @@ -52,10 +52,10 @@ class FeatureInversionPipeline: def __init__( self, - encoder: Encoder, - generator: Generator, - latent: Latent, - critic: Critic, + encoder: BaseEncoder, + generator: BaseGenerator, + latent: BaseLatent, + critic: BaseCritic, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler = None, num_iterations: int = 1, From d0e4543322febb0c2049700ad3224f3f73f064ba Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 15 Dec 2023 16:13:16 +0900 Subject: [PATCH 021/117] add naive MSE critic --- bdpy/recon/torch/modules/critic.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index bce0b68b..3237a1a9 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -94,6 +94,33 @@ def forward( return loss / counts +class MSE(LayerWiseAverageCritic): + """MSE loss.""" + + def criterion( + self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str + ) -> torch.Tensor: + """Loss function per layer. + + Parameters + ---------- + feature : torch.Tensor + Feature tensor of the layer specified by `layer_name`. + target_feature : torch.Tensor + Target feature tensor of the layer specified by `layer_name`. + layer_name : str + Layer name. + + Returns + ------- + torch.Tensor + Loss value of the layer specified by `layer_name`. + """ + return ((feature - target_feature) ** 2).sum( + dim=tuple(range(1, feature.ndim)) + ) + + class TargetNormalizedMSE(LayerWiseAverageCritic): """MSE loss divided by the squared norm of the target feature.""" From 5afb9a667f14e155660fbbce8658bfa6bcab8bdc Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Tue, 19 Dec 2023 17:23:06 +0900 Subject: [PATCH 022/117] tentative implementation of callback --- bdpy/recon/torch/pipeline/inversion.py | 100 ++++++++++++++++++++++--- bdpy/util/callback.py | 74 ++++++++++++++++++ 2 files changed, 163 insertions(+), 11 deletions(-) create mode 100644 bdpy/util/callback.py diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 7d2afde8..7129c7e7 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -1,14 +1,78 @@ -from typing import Dict +from __future__ import annotations + +from typing import Dict, Iterable from itertools import chain import torch from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic +from bdpy.util.callback import CallbackHandler, BaseCallback, unused FeatureType = Dict[str, torch.Tensor] +class FeatureInversionCallback(BaseCallback): + @unused + def on_iteration_start(self, *, step: int) -> None: + """Callback on iteration start.""" + pass + + @unused + def on_image_generated(self, *, step: int, image: torch.Tensor) -> None: + """Callback on image generated.""" + pass + + @unused + def on_feature_extracted(self, *, step: int, features: torch.Tensor) -> None: + """Callback on feature extracted.""" + pass + + @unused + def on_layerwise_loss_calculated(self, *, step: int, layer_loss: torch.Tensor, layer_name: str) -> None: + """Callback on layerwise loss calculated.""" + pass + + @unused + def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: + """Callback on loss calculated.""" + pass + + @unused + def on_backward_end(self, *, step: int) -> None: + """Callback on backward end.""" + pass + + @unused + def on_optimizer_step(self, *, step: int) -> None: + """Callback on optimizer step.""" + pass + + @unused + def on_iteration_end(self, step: int) -> None: + """Called at the end of each iteration.""" + pass + + +class CUILoggingCallback(FeatureInversionCallback): + def __init__(self, interval: int = 1, total_steps: int = -1) -> None: + self._interval = interval + self._total_steps = total_steps + self._loss: int | float = -1 + + def _step_str(self, step: int) -> str: + if self._total_steps > 0: + return f"{step+1}/{self._total_steps}" + else: + return f"{step+1}" + + def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: + self._loss = loss.item() + + def on_iteration_end(self, step: int) -> None: + if step % self._interval == 0: + print(f"Step: [{self._step_str(step)}], Loss: {self._loss:.4f}") + class FeatureInversionPipeline: """Feature inversion pipeline. @@ -28,8 +92,8 @@ class FeatureInversionPipeline: Learning rate scheduler, by default None. num_iterations : int, optional Number of iterations, by default 1. - log_interval : int, optional - Log interval, by default -1. If -1, logging is disabled. + callbacks : FeatureInversionCallback | Iterable[FeatureInversionCallback] | None, optional + Callbacks, by default None. Examples -------- @@ -59,7 +123,7 @@ def __init__( optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler = None, num_iterations: int = 1, - log_interval: int = -1, + callbacks: FeatureInversionCallback | Iterable[FeatureInversionCallback] | None = None, ) -> None: self._encoder = encoder self._generator = generator @@ -69,7 +133,8 @@ def __init__( self._scheduler = scheduler self._num_iterations = num_iterations - self._log_interval = log_interval + + self._callback_handler = CallbackHandler(callbacks) def __call__( self, @@ -87,30 +152,39 @@ def __call__( torch.Tensor Reconstructed images which have the similar features to the target features. """ + self._callback_handler.fire("on_pipeline_start") for step in range(self._num_iterations): + self._callback_handler.fire("on_iteration_start", step=step) self._optimizer.zero_grad() latent = self._latent() generated_image = self._generator(latent) + self._callback_handler.fire("on_image_generated", step=step, image=generated_image) features = self._encoder(generated_image) + self._callback_handler.fire("on_feature_extracted", step=step, features=features) loss = self._critic(features, target_features) + self._callback_handler.fire("on_loss_calculated", step=step, loss=loss) loss.backward() + self._callback_handler.fire("on_backward_end", step=step) self._optimizer.step() + self._callback_handler.fire("on_optimizer_step", step=step) if self._scheduler is not None: self._scheduler.step() - if self._log_interval > 0 and step % self._log_interval == 0: - print(f"Step: [{step+1}/{self._num_iterations}], Loss: {loss.item():.4f}") + self._callback_handler.fire("on_iteration_end", step=step) - return self._generator(self._latent()).detach() + generated_image = self._generator(self._latent()).detach() - def reset_state(self) -> None: + self._callback_handler.fire("on_pipeline_end") + return generated_image + + def reset_states(self) -> None: """Reset the state of the pipeline.""" - self._generator.reset_state() - self._latent.reset_state() + self._generator.reset_states() + self._latent.reset_states() self._optimizer = self._optimizer.__class__( chain( self._generator.parameters(), @@ -118,3 +192,7 @@ def reset_state(self) -> None: ), **self._optimizer.defaults ) + + def register_callback(self, callback: FeatureInversionCallback) -> None: + """Register a callback.""" + self._callback_handler.register(callback) diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py new file mode 100644 index 00000000..3812c015 --- /dev/null +++ b/bdpy/util/callback.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Callable, Type, Any, Iterable +from typing_extensions import Annotated, ParamSpec + +from collections import defaultdict +from functools import wraps + + +_P = ParamSpec("_P") +_Unused = Annotated[None, "unused"] + + +def _is_unused(fn: Callable) -> bool: + return_type: Type | None = fn.__annotations__.get("return", None) + if return_type is None: + return False + return return_type == _Unused + + +def unused(fn: Callable[_P, Any]) -> Callable[_P, _Unused]: + @wraps(fn) # NOTE: preserve name, docstring, etc. of the original function + def _unused(*args: _P.args, **kwargs: _P.kwargs) -> _Unused: + raise RuntimeError(f"Function {fn} is decorated with @unused and must not be called.") + + # NOTE: change the return type to Unused + _unused.__annotations__["return"] = _Unused + + return _unused + + +class BaseCallback: + @unused + def on_pipeline_start(self) -> None: + """Callback on pipeline start.""" + pass + + @unused + def on_pipeline_end(self) -> None: + """Callback on pipeline end.""" + pass + + +class CallbackHandler: + _callbacks: list[BaseCallback] + _registered_functions: defaultdict[str, list[Callable]] + + def __init__(self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None) -> None: + self._callbacks = [] + self._registered_functions = defaultdict(list) + if callbacks is not None: + if isinstance(callbacks, BaseCallback): + callbacks = [callbacks] + for callback in callbacks: + self.register(callback) + + def register(self, callback: BaseCallback) -> None: + self._callbacks.append(callback) + for event_type in dir(callback): + callback_method = getattr(callback, event_type) + if not callable(callback_method): + continue + if _is_unused(callback_method): + continue + if event_type.startswith("_"): + continue + if event_type.startswith("on_"): + self._registered_functions[event_type].append(callback_method) + continue + + def fire(self, event_type: str, **kwargs) -> None: + for callback_method in self._registered_functions[event_type]: + callback_method(**kwargs) + From 9305df9d70d75d056bfb9e95832ea7fa478bc2c1 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Tue, 19 Dec 2023 17:35:10 +0900 Subject: [PATCH 023/117] improve APIs and docstring --- bdpy/util/callback.py | 87 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index 3812c015..d4a7f458 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -19,6 +19,33 @@ def _is_unused(fn: Callable) -> bool: def unused(fn: Callable[_P, Any]) -> Callable[_P, _Unused]: + """Decorate a function to raise an error when called. + + This decorator marks a function as unused and raises an error when called. + The type of the return value is changed to `Annotated[None, "unused"]`. + + Parameters + ---------- + fn : Callable + Function to decorate. + + Returns + ------- + Callable + Decorated function. + + Examples + -------- + >>> @unused + ... def f(a: int, b: int, c: int = 0) -> int: + ... return a + b + c + ... + >>> f(1, 2, 3) + Traceback (most recent call last): + ... + RuntimeError: Function is decorated with @unused and must not be called. + """ + @wraps(fn) # NOTE: preserve name, docstring, etc. of the original function def _unused(*args: _P.args, **kwargs: _P.kwargs) -> _Unused: raise RuntimeError(f"Function {fn} is decorated with @unused and must not be called.") @@ -42,6 +69,35 @@ def on_pipeline_end(self) -> None: class CallbackHandler: + """Callback handler. + + This class manages the callback functions registered to the event types. + The callback functions are registered by calling the `register` method. + The callback functions are executed by calling the `fire` method. + The callback functions must be defined as methods of the class that inherits + `BaseCallback` and starts with "on_". + + Parameters + ---------- + callbacks : BaseCallback | Iterable[BaseCallback] | None, optional + Callbacks to register, by default None + + Examples + -------- + >>> class Callback(BaseCallback): + ... def on_pipeline_start(self): + ... print("Pipeline started.") + ... + ... def on_pipeline_end(self): + ... print("Pipeline ended.") + ... + >>> handler = CallbackHandler(Callback()) + >>> handler.fire("on_pipeline_start") + Pipeline started. + >>> handler.fire("on_pipeline_end") + Pipeline ended. + """ + _callbacks: list[BaseCallback] _registered_functions: defaultdict[str, list[Callable]] @@ -55,6 +111,21 @@ def __init__(self, callbacks: BaseCallback | Iterable[BaseCallback] | None = Non self.register(callback) def register(self, callback: BaseCallback) -> None: + """Register a callback. + + Parameters + ---------- + callback : BaseCallback + Callback to register. + + Raises + ------ + TypeError + If the callback is not an instance of BaseCallback. + """ + if not isinstance(callback, BaseCallback): + raise TypeError(f"Callback must be an instance of BaseCallback, not {type(callback)}.") + self._callbacks.append(callback) for event_type in dir(callback): callback_method = getattr(callback, event_type) @@ -68,7 +139,21 @@ def register(self, callback: BaseCallback) -> None: self._registered_functions[event_type].append(callback_method) continue - def fire(self, event_type: str, **kwargs) -> None: + def fire(self, event_type: str, **kwargs: dict[str, Any]) -> None: + """Execute the callback functions registered to the event type. + + Parameters + ---------- + event_type : str + Event type to fire, which must start with "on_". + kwargs : dict[str, Any] + Keyword arguments to pass to the callback functions. + + Raises + ------ + KeyError + If the event type is not registered. + """ for callback_method in self._registered_functions[event_type]: callback_method(**kwargs) From b51c57336828be295d0c4238bed109ceb7e777b9 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Tue, 19 Dec 2023 17:38:34 +0900 Subject: [PATCH 024/117] update docstring --- bdpy/recon/torch/pipeline/inversion.py | 14 ++++++++++++++ bdpy/util/callback.py | 9 ++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 7129c7e7..b61cd6e4 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -13,6 +13,8 @@ class FeatureInversionCallback(BaseCallback): + """Callback for feature inversion pipeline.""" + @unused def on_iteration_start(self, *, step: int) -> None: """Callback on iteration start.""" @@ -55,6 +57,18 @@ def on_iteration_end(self, step: int) -> None: class CUILoggingCallback(FeatureInversionCallback): + """Callback for logging on CUI. + + Parameters + ---------- + interval : int, optional + Logging interval, by default 1. If `interval` is 1, the callback logs + every iteration. + total_steps : int, optional + Total number of iterations, by default -1. If `total_steps` is -1, + the callback does not show the total number of iterations. + """ + def __init__(self, interval: int = 1, total_steps: int = -1) -> None: self._interval = interval self._total_steps = total_steps diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index d4a7f458..0253b05e 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -71,11 +71,10 @@ def on_pipeline_end(self) -> None: class CallbackHandler: """Callback handler. - This class manages the callback functions registered to the event types. - The callback functions are registered by calling the `register` method. - The callback functions are executed by calling the `fire` method. - The callback functions must be defined as methods of the class that inherits - `BaseCallback` and starts with "on_". + This class manages the callback objects and fires the callback functions + registered to the event type. The callback functions must be defined as + methods of the callback classes. The callback functions must be named as + "on_". Parameters ---------- From 23234a3a10caf079c1ab0cd9308cbd67ed666ec8 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Tue, 19 Dec 2023 18:07:43 +0900 Subject: [PATCH 025/117] update docstrings --- bdpy/recon/torch/pipeline/inversion.py | 27 +++++++++++++++++--------- bdpy/util/callback.py | 14 +++++++++++-- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index b61cd6e4..bb0a0c03 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Iterable +from typing import Dict, Iterable, Callable from itertools import chain @@ -12,8 +12,17 @@ FeatureType = Dict[str, torch.Tensor] +def _apply_to_features(fn: Callable[[torch.Tensor], torch.Tensor], features: FeatureType) -> FeatureType: + return {k: fn(v) for k, v in features.items()} + + class FeatureInversionCallback(BaseCallback): - """Callback for feature inversion pipeline.""" + """Callback for feature inversion pipeline. + + As a design principle, the callback functions must not have any side effects + on the pipeline results. It should be used only for logging, visualization, + etc. + """ @unused def on_iteration_start(self, *, step: int) -> None: @@ -26,12 +35,12 @@ def on_image_generated(self, *, step: int, image: torch.Tensor) -> None: pass @unused - def on_feature_extracted(self, *, step: int, features: torch.Tensor) -> None: + def on_feature_extracted(self, *, step: int, features: FeatureType) -> None: """Callback on feature extracted.""" pass @unused - def on_layerwise_loss_calculated(self, *, step: int, layer_loss: torch.Tensor, layer_name: str) -> None: + def on_layerwise_loss_calculated(self, *, layer_loss: torch.Tensor, layer_name: str) -> None: """Callback on layerwise loss calculated.""" pass @@ -51,7 +60,7 @@ def on_optimizer_step(self, *, step: int) -> None: pass @unused - def on_iteration_end(self, step: int) -> None: + def on_iteration_end(self, *, step: int) -> None: """Called at the end of each iteration.""" pass @@ -83,7 +92,7 @@ def _step_str(self, step: int) -> str: def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: self._loss = loss.item() - def on_iteration_end(self, step: int) -> None: + def on_iteration_end(self, *, step: int) -> None: if step % self._interval == 0: print(f"Step: [{self._step_str(step)}], Loss: {self._loss:.4f}") @@ -173,13 +182,13 @@ def __call__( latent = self._latent() generated_image = self._generator(latent) - self._callback_handler.fire("on_image_generated", step=step, image=generated_image) + self._callback_handler.fire("on_image_generated", step=step, image=generated_image.detach()) features = self._encoder(generated_image) - self._callback_handler.fire("on_feature_extracted", step=step, features=features) + self._callback_handler.fire("on_feature_extracted", step=step, features=_apply_to_features(lambda x: x.detach(), features)) loss = self._critic(features, target_features) - self._callback_handler.fire("on_loss_calculated", step=step, loss=loss) + self._callback_handler.fire("on_loss_calculated", step=step, loss=loss.detach()) loss.backward() self._callback_handler.fire("on_backward_end", step=step) diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index 0253b05e..3ccb6415 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Callable, Type, Any, Iterable -from typing_extensions import Annotated, ParamSpec +from typing_extensions import Annotated, ParamSpec, Unpack from collections import defaultdict from functools import wraps @@ -57,6 +57,16 @@ def _unused(*args: _P.args, **kwargs: _P.kwargs) -> _Unused: class BaseCallback: + """Base class for callbacks. + + Callbacks are used to hook into the pipeline and execute custom functions + at specific events. Callback functions must be defined as methods of the + callback classes. The callback functions must be named as "on_". + As a design principle, the callback functions must not have any side effects + on the pipeline results. It should be used only for logging, visualization, + etc. + """ + @unused def on_pipeline_start(self) -> None: """Callback on pipeline start.""" @@ -138,7 +148,7 @@ def register(self, callback: BaseCallback) -> None: self._registered_functions[event_type].append(callback_method) continue - def fire(self, event_type: str, **kwargs: dict[str, Any]) -> None: + def fire(self, event_type: str, **kwargs: Any) -> None: """Execute the callback functions registered to the event type. Parameters From 6b3c51f2c92aac37fb872fdb00143c9acdfb97e6 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Tue, 19 Dec 2023 19:26:16 +0900 Subject: [PATCH 026/117] tentative implementation of the callback for W&B --- bdpy/recon/torch/modules/critic.py | 12 ++++++- bdpy/recon/torch/pipeline/inversion.py | 45 ++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index 3237a1a9..178b7629 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -1,11 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict +from typing import Dict, Iterable import torch import torch.nn as nn +from bdpy.util.callback import CallbackHandler, BaseCallback + _FeatureType = Dict[str, torch.Tensor] @@ -13,6 +15,9 @@ class BaseCritic(ABC): """Critic network module.""" + def __init__(self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None) -> None: + self._callback_handler = CallbackHandler(callbacks) + @abstractmethod def __call__(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: """Compute the total loss value given the features and the target features. @@ -89,6 +94,11 @@ def forward( layer_wise_loss = self.criterion( feature, target_feature, layer_name=layer_name ) + self._callback_handler.fire( + "on_layerwise_loss_calculated", + layer_name=layer_name, + layer_loss=layer_wise_loss, + ) loss += layer_wise_loss counts += 1 return loss / counts diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index bb0a0c03..7e380636 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -96,6 +96,51 @@ def on_iteration_end(self, *, step: int) -> None: if step % self._interval == 0: print(f"Step: [{self._step_str(step)}], Loss: {self._loss:.4f}") + +class WandBLoggingCallback(FeatureInversionCallback): + """Callback for logging on Weights & Biases. + + Parameters + ---------- + run : wandb.sdk.wandb_run.Run + Run object of Weights & Biases. + interval : int, optional + Logging interval, by default 1. If `interval` is 1, the callback logs + every iteration. + media_interval : int, optional + Logging interval for media, by default 1. If `media_interval` is 1, + the callback logs every iteration. + + Notes + ----- + TODO: Currently it does not work because the dependency (wandb) is not installed. + """ + + def __init__(self, run: wandb.sdk.wandb_run.Run, interval: int = 1, media_interval: int = 1) -> None: + self._run = run + self._interval = interval + self._media_interval = media_interval + self._step = 0 + + def on_iteration_start(self, *, step: int) -> None: + # NOTE: We need to store the global step because we cannot access it + # in `on_layerwise_loss_calculated` by design. + self._step = step + + def on_image_generated(self, *, step: int, image: torch.Tensor) -> None: + if self._step % self._media_interval == 0: + image = wandb.Image(image) + self._run.log({"generated_image": image}, step=self._step) + + def on_layerwise_loss_calculated(self, *, layer_loss: torch.Tensor, layer_name: str) -> None: + if self._step % self._interval == 0: + self._run.log({f"critic/{layer_name}": layer_loss.item()}, step=self._step) + + def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: + if self._step % self._interval == 0: + self._run.log({"loss": loss.item()}, step=self._step) + + class FeatureInversionPipeline: """Feature inversion pipeline. From 34c665ce5a9873ed51ef5c3b81641adead636750 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 11:32:17 +0900 Subject: [PATCH 027/117] change generator API --- bdpy/recon/torch/modules/generator.py | 30 +++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 29f623b4..f74625c9 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -30,8 +30,23 @@ def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: pass @abstractmethod + def generate(self, latent: torch.Tensor) -> torch.Tensor: + """Generate image given latent variable. + + Parameters + ---------- + latent : torch.Tensor + Latent variable. + + Returns + ------- + torch.Tensor + Generated image. The generated images must be in the range [0, 1]. + """ + pass + def __call__(self, latent: torch.Tensor) -> torch.Tensor: - """Forward pass through the generator network. + """Call self.generate. Parameters ---------- @@ -43,7 +58,7 @@ def __call__(self, latent: torch.Tensor) -> torch.Tensor: torch.Tensor Generated image. The generated images must be in the range [0, 1]. """ - pass + return self.generate(latent) class NNModuleGenerator(BaseGenerator, nn.Module): @@ -52,6 +67,9 @@ class NNModuleGenerator(BaseGenerator, nn.Module): def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: return nn.Module.parameters(self, recurse=recurse) + def forward(self, latent: torch.Tensor) -> torch.Tensor: + return self.generate(latent) + def __call__(self, latent: torch.Tensor) -> torch.Tensor: return nn.Module.__call__(self, latent) @@ -87,8 +105,8 @@ def reset_states(self) -> None: """Reset the state of the generator.""" pass - def forward(self, latent: torch.Tensor) -> torch.Tensor: - """Forward pass through the generator network. + def generate(self, latent: torch.Tensor) -> torch.Tensor: + """Naively pass the latent vector to the activation function. Parameters ---------- @@ -152,8 +170,8 @@ def reset_states(self) -> None: """Reset the state of the generator.""" self._generator_network.apply(self._reset_fn) - def forward(self, latent: torch.Tensor) -> torch.Tensor: - """Forward pass through the generator network. + def generate(self, latent: torch.Tensor) -> torch.Tensor: + """Generate image using the generator network. Parameters ---------- From 7dc43e458f7b4847a42f16e3331e31630aa3e58c Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 11:43:42 +0900 Subject: [PATCH 028/117] change critic API --- bdpy/recon/torch/modules/critic.py | 37 +++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index 178b7629..450d7391 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -18,8 +18,29 @@ class BaseCritic(ABC): def __init__(self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None) -> None: self._callback_handler = CallbackHandler(callbacks) - @abstractmethod def __call__(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: + """Call self.compare. + + Parameters + ---------- + features : dict[str, torch.Tensor] + Features indexed by the layer names. + target_features : dict[str, torch.Tensor] + Target features indexed by the layer names. + + Returns + ------- + torch.Tensor + Loss value. + """ + return self.compare(features, target_features) + + @abstractmethod + def compare( + self, + features: _FeatureType, + target_features: _FeatureType, + ) -> torch.Tensor: """Compute the total loss value given the features and the target features. Parameters @@ -34,9 +55,10 @@ def __call__(self, features: _FeatureType, target_features: _FeatureType) -> tor torch.Tensor Loss value. """ + pass @abstractmethod - def criterion( + def compare_layer( self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str ) -> torch.Tensor: """Loss function per layer. @@ -64,16 +86,19 @@ class NNModuleCritic(BaseCritic, nn.Module): def __call__(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: return nn.Module.__call__(self, features, target_features) + def forward(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: + return self.compare(features, target_features) + class LayerWiseAverageCritic(NNModuleCritic): """Compute the average of the layer-wise loss values.""" - def forward( + def compare( self, features: _FeatureType, target_features: _FeatureType, ) -> torch.Tensor: - """Forward pass through the critic network. + """Compute the total loss value given the features and the target features. Parameters ---------- @@ -107,7 +132,7 @@ def forward( class MSE(LayerWiseAverageCritic): """MSE loss.""" - def criterion( + def compare_layer( self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str ) -> torch.Tensor: """Loss function per layer. @@ -134,7 +159,7 @@ def criterion( class TargetNormalizedMSE(LayerWiseAverageCritic): """MSE loss divided by the squared norm of the target feature.""" - def criterion( + def compare_layer( self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str ) -> torch.Tensor: """Loss function per layer. From bd202146f1153c4aefe90264cdbca4938da4cf66 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 11:45:18 +0900 Subject: [PATCH 029/117] change encoder API --- bdpy/recon/torch/modules/encoder.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 10cf3aac..30ae2d34 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -13,8 +13,8 @@ class BaseEncoder(ABC): """Encoder network module.""" @abstractmethod - def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: - """Forward pass through the encoder network. + def encode(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + """Encode images as a hierarchical feature representation. Parameters ---------- @@ -28,6 +28,21 @@ def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: """ pass + def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + """Call self.encode. + + Parameters + ---------- + images : torch.Tensor + Images. + + Returns + ------- + dict[str, torch.Tensor] + Features indexed by the layer names. + """ + return self.encode(images) + class SimpleEncoder(BaseEncoder): """Encoder network module with a naive feature extractor. @@ -74,8 +89,8 @@ def __init__( self._domain = domain self._feature_network = self._feature_extractor._encoder - def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: - """Forward pass through the encoder network. + def encode(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + """Encode images as a hierarchical feature representation. Parameters ---------- From e4e6e13bb8940b963870707c88bb529edd7a4a3f Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 11:46:27 +0900 Subject: [PATCH 030/117] change latent API --- bdpy/recon/torch/modules/latent.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/modules/latent.py b/bdpy/recon/torch/modules/latent.py index 823a5e5d..f1a9e7db 100644 --- a/bdpy/recon/torch/modules/latent.py +++ b/bdpy/recon/torch/modules/latent.py @@ -21,7 +21,7 @@ def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: pass @abstractmethod - def __call__(self) -> torch.Tensor: + def generate(self) -> torch.Tensor: """Generate a latent variable. Returns @@ -31,6 +31,16 @@ def __call__(self) -> torch.Tensor: """ pass + def __call__(self) -> torch.Tensor: + """Call self.generate. + + Returns + ------- + torch.Tensor + Latent variable. + """ + return self.generate() + class NNModuleLatent(BaseLatent, nn.Module): """Latent variable module uses __call__ method and parameters method of nn.Module.""" @@ -41,6 +51,9 @@ def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]: def __call__(self) -> torch.Tensor: return nn.Module.__call__(self) + def forward(self) -> torch.Tensor: + return self.generate() + class ArbitraryLatent(NNModuleLatent): """Latent variable with arbitrary shape and initialization function. @@ -73,7 +86,7 @@ def reset_states(self) -> None: """Reset the state of the latent variable.""" self._init_fn(self._latent) - def forward(self) -> torch.Tensor: + def generate(self) -> torch.Tensor: """Generate a latent variable. Returns From ad7e53f44790d242ec6599d8ff07a564e3fbdf57 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 11:48:45 +0900 Subject: [PATCH 031/117] remove unused import --- bdpy/util/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index 3ccb6415..b417d048 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Callable, Type, Any, Iterable -from typing_extensions import Annotated, ParamSpec, Unpack +from typing_extensions import Annotated, ParamSpec from collections import defaultdict from functools import wraps From 43dc7dbec2a3e6c0cb35c03d5d414806e80f4a0c Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 12:25:30 +0900 Subject: [PATCH 032/117] add test cases for critic module --- tests/recon/__init__.py | 0 tests/recon/torch/__init__.py | 0 tests/recon/torch/modules/__init__.py | 0 tests/recon/torch/modules/test_critic.py | 103 +++++++++++++++++++++++ 4 files changed, 103 insertions(+) create mode 100644 tests/recon/__init__.py create mode 100644 tests/recon/torch/__init__.py create mode 100644 tests/recon/torch/modules/__init__.py create mode 100644 tests/recon/torch/modules/test_critic.py diff --git a/tests/recon/__init__.py b/tests/recon/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/recon/torch/__init__.py b/tests/recon/torch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/recon/torch/modules/__init__.py b/tests/recon/torch/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/recon/torch/modules/test_critic.py b/tests/recon/torch/modules/test_critic.py new file mode 100644 index 00000000..72ca14d3 --- /dev/null +++ b/tests/recon/torch/modules/test_critic.py @@ -0,0 +1,103 @@ +"""Tests for bdpy.recon.torch.modules.critic.""" + +import unittest + +import torch + +from bdpy.recon.torch.modules import critic as critic_module + + +class TestBaseCritic(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.critic.BaseCritic.""" + def setUp(self): + self.features = { + "conv1": torch.tensor([1.0], requires_grad=True), + "conv2": torch.tensor([2.0], requires_grad=True), + "conv3": torch.tensor([3.0], requires_grad=True), + } + self.target_features = { + "conv1": torch.tensor([0.0]), + "conv2": torch.tensor([1.0]), + "conv3": torch.tensor([2.0]), + } + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, critic_module.BaseCritic) + + def test_call(self): + """Test __call__.""" + class ReturnZeroCritic(critic_module.BaseCritic): + def compare(self, features, target_features): + return 0.0 + + critic = ReturnZeroCritic() + self.assertEqual(critic(self.features, self.target_features), 0.0) + + def test_loss_computation(self): + """Test loss computation.""" + class SumCritic(critic_module.BaseCritic): + def compare(self, features, target_features): + loss = 0.0 + for feature, target_feature in zip(features.values(), target_features.values()): + loss += torch.sum(torch.abs(feature - target_feature)) + return loss + + critic = SumCritic() + self.assertEqual(critic(self.features, self.target_features), 3.0) + + for feature in self.features.values(): + feature.grad = None + loss = critic(self.features, self.target_features) + loss.backward() + for feature in self.features.values(): + self.assertIsNotNone(feature.grad) + self.assertEqual(feature.grad, torch.ones_like(feature)) + + +class TestNNModuleCritic(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.critic.NNModuleCritic.""" + def setUp(self): + self.features = { + "conv1": torch.tensor([1.0], requires_grad=True), + "conv2": torch.tensor([2.0], requires_grad=True), + "conv3": torch.tensor([3.0], requires_grad=True), + } + self.target_features = { + "conv1": torch.tensor([0.0]), + "conv2": torch.tensor([1.0]), + "conv3": torch.tensor([2.0]), + } + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, critic_module.NNModuleCritic) + + def test_call(self): + """Test __call__.""" + class ReturnZeroCritic(critic_module.NNModuleCritic): + def compare(self, features, target_features): + return 0.0 + + critic = ReturnZeroCritic() + self.assertEqual(critic(self.features, self.target_features), 0.0) + + def test_loss_computation(self): + """Test loss computation.""" + class SumCritic(critic_module.NNModuleCritic): + def compare(self, features, target_features): + loss = 0.0 + for feature, target_feature in zip(features.values(), target_features.values()): + loss += torch.sum(torch.abs(feature - target_feature)) + return loss + + critic = SumCritic() + self.assertEqual(critic(self.features, self.target_features), 3.0) + + for feature in self.features.values(): + feature.grad = None + loss = critic(self.features, self.target_features) + loss.backward() + for feature in self.features.values(): + self.assertIsNotNone(feature.grad) + self.assertEqual(feature.grad, torch.ones_like(feature)) From 5af9091238992d2b63a6dce558582b877db7e4bf Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 12:25:45 +0900 Subject: [PATCH 033/117] update ciritic API so that it pass test cases --- bdpy/recon/torch/modules/critic.py | 49 ++++++++++++++++-------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index 450d7391..18009eb0 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -57,31 +57,12 @@ def compare( """ pass - @abstractmethod - def compare_layer( - self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str - ) -> torch.Tensor: - """Loss function per layer. - - Parameters - ---------- - feature : torch.Tensor - Feature tensor of the layer specified by `layer_name`. - target_feature : torch.Tensor - Target feature tensor of the layer specified by `layer_name`. - layer_name : str - Layer name. - - Returns - ------- - torch.Tensor - Loss value of the layer specified by `layer_name`. - """ - pass - class NNModuleCritic(BaseCritic, nn.Module): """Critic network module uses __call__ method of nn.Module.""" + def __init__(self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None) -> None: + BaseCritic.__init__(self, callbacks) + nn.Module.__init__(self) def __call__(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: return nn.Module.__call__(self, features, target_features) @@ -116,7 +97,7 @@ def compare( counts = 0 for layer_name, feature in features.items(): target_feature = target_features[layer_name] - layer_wise_loss = self.criterion( + layer_wise_loss = self.compare_layer( feature, target_feature, layer_name=layer_name ) self._callback_handler.fire( @@ -128,6 +109,28 @@ def compare( counts += 1 return loss / counts + @abstractmethod + def compare_layer( + self, feature: torch.Tensor, target_feature: torch.Tensor, layer_name: str + ) -> torch.Tensor: + """Loss function per layer. + + Parameters + ---------- + feature : torch.Tensor + Feature tensor of the layer specified by `layer_name`. + target_feature : torch.Tensor + Target feature tensor of the layer specified by `layer_name`. + layer_name : str + Layer name. + + Returns + ------- + torch.Tensor + Loss value of the layer specified by `layer_name`. + """ + pass + class MSE(LayerWiseAverageCritic): """MSE loss.""" From 6f300246c781440cc7972c6c206f8b0bb54dc308 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 14:24:57 +0900 Subject: [PATCH 034/117] update test cases for critic module --- tests/recon/torch/modules/test_critic.py | 52 ++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/recon/torch/modules/test_critic.py b/tests/recon/torch/modules/test_critic.py index 72ca14d3..7ea1e5fa 100644 --- a/tests/recon/torch/modules/test_critic.py +++ b/tests/recon/torch/modules/test_critic.py @@ -101,3 +101,55 @@ def compare(self, features, target_features): for feature in self.features.values(): self.assertIsNotNone(feature.grad) self.assertEqual(feature.grad, torch.ones_like(feature)) + + +class TestLayerWiseAverageCritic(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.critic.LayerWiseAverageCritic.""" + def setUp(self): + self.features = { + "conv1": torch.tensor([1.0], requires_grad=True), + "conv2": torch.tensor([2.0], requires_grad=True), + "conv3": torch.tensor([3.0], requires_grad=True), + } + self.target_features = { + "conv1": torch.tensor([0.0]), + "conv2": torch.tensor([1.0]), + "conv3": torch.tensor([2.0]), + } + + def test_call(self): + """Test __call__.""" + class ReturnZeroCritic(critic_module.LayerWiseAverageCritic): + def compare_layer(self, feature, target_feature, layer_name): + return 0.0 + + critic = ReturnZeroCritic() + self.assertEqual(critic(self.features, self.target_features), 0.0) + + def test_loss_computation(self): + """Test loss computation.""" + class AbsCritic(critic_module.LayerWiseAverageCritic): + def compare_layer(self, feature, target_feature, layer_name): + return torch.abs(feature - target_feature) + + critic = AbsCritic() + self.assertEqual(critic(self.features, self.target_features), 1) + + for feature in self.features.values(): + feature.grad = None + loss = critic(self.features, self.target_features) + loss.backward() + for feature in self.features.values(): + self.assertIsNotNone(feature.grad) + self.assertEqual(feature.grad, torch.ones_like(feature)/3) + + +class TestMSE(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.critic.MSE.""" + def test_compare_layer(self): + """Test compare_layer.""" + critic = critic_module.MSE() + feature = torch.randn(13, 7) + target_feature = torch.randn_like(feature) + loss = critic.compare_layer(feature, target_feature, "conv1") + self.assertTrue(torch.allclose(loss, torch.sum((feature - target_feature)**2, dim=1))) From ac87094b13aac32871f662c3766b40505e6c2670 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 14:34:16 +0900 Subject: [PATCH 035/117] update test cases --- tests/recon/torch/modules/test_critic.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/recon/torch/modules/test_critic.py b/tests/recon/torch/modules/test_critic.py index 7ea1e5fa..4c474c14 100644 --- a/tests/recon/torch/modules/test_critic.py +++ b/tests/recon/torch/modules/test_critic.py @@ -153,3 +153,17 @@ def test_compare_layer(self): target_feature = torch.randn_like(feature) loss = critic.compare_layer(feature, target_feature, "conv1") self.assertTrue(torch.allclose(loss, torch.sum((feature - target_feature)**2, dim=1))) + + +class TestTargetNormalizedMSE(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.critic.TargetNormalizedMSE.""" + def test_compare_layer(self): + """Test compare_layer.""" + critic = critic_module.TargetNormalizedMSE() + feature = torch.randn(13, 7) + target_feature = torch.randn_like(feature) + loss = critic.compare_layer(feature, target_feature, "conv1") + self.assertTrue(torch.allclose( + loss, + torch.sum((feature - target_feature)**2, dim=1)/torch.sum(target_feature**2, dim=1) + )) From 19aeefc93f2f0707c73b5c90d9b9d220645f381a Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 16:17:06 +0900 Subject: [PATCH 036/117] update docstring --- bdpy/dl/torch/stimulus_domain/core.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/bdpy/dl/torch/stimulus_domain/core.py b/bdpy/dl/torch/stimulus_domain/core.py index 9acec49a..73529bc2 100644 --- a/bdpy/dl/torch/stimulus_domain/core.py +++ b/bdpy/dl/torch/stimulus_domain/core.py @@ -13,6 +13,21 @@ class Domain(nn.Module, ABC): """Base class for stimulus domain. This class is used to convert stimulus between each domain and library's internal common space. + Suppose that we have two functions `f: X -> Y_1` and `g: Y_2 -> Z` and want to compose them. + Here, `X`, `Y_1`, `Y_2`, and `Z` are different domains and assume that `Y_1` and `Y_2` are + the similar domain that can be converted to each other. + Then, we can compose `f` and `g` as `g . t . f(x)`, where `t: Y_1 -> Y_2` is the domain + conversion function. This class is used to implement `t`. + + The subclasses of this class should implement `send` and `receive` methods. The `send` method + converts stimulus from the original domain (`Y_1` or `Y_2`) to the internal common space (`Y_0`), + and the `receive` method converts stimulus from the internal common space to the original domain. + By implementing domain class for `Y_1` and `Y_2`, we can construct the domain conversion function + `t` as `t = Y_2.receive . Y_1.send`. + + Note that the subclasses of this class do not necessarily guarantee the reversibility of `send` + and `receive` methods. If the domain conversion is irreversible, the subclasses should inherit + `IrreversibleDomain` class instead of this class. """ @abstractmethod @@ -64,7 +79,7 @@ def receive(self, x: torch.Tensor) -> torch.Tensor: class ComposedDomain(Domain): - """The domain composed of multiple domains.""" + """The domain composed of multiple sub-domains.""" def __init__(self, domains: Iterable[Domain]) -> None: super().__init__() From 39253030ba227b9231df2dad21f6db7ec6c176e6 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 16:36:46 +0900 Subject: [PATCH 037/117] update domain API --- .../{stimulus_domain => domain}/__init__.py | 0 .../torch/{stimulus_domain => domain}/core.py | 18 +++++++++--------- .../image_domain.py | 0 bdpy/recon/torch/modules/encoder.py | 2 +- bdpy/recon/torch/modules/generator.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) rename bdpy/dl/torch/{stimulus_domain => domain}/__init__.py (100%) rename bdpy/dl/torch/{stimulus_domain => domain}/core.py (79%) rename bdpy/dl/torch/{stimulus_domain => domain}/image_domain.py (100%) diff --git a/bdpy/dl/torch/stimulus_domain/__init__.py b/bdpy/dl/torch/domain/__init__.py similarity index 100% rename from bdpy/dl/torch/stimulus_domain/__init__.py rename to bdpy/dl/torch/domain/__init__.py diff --git a/bdpy/dl/torch/stimulus_domain/core.py b/bdpy/dl/torch/domain/core.py similarity index 79% rename from bdpy/dl/torch/stimulus_domain/core.py rename to bdpy/dl/torch/domain/core.py index 73529bc2..097bb0e8 100644 --- a/bdpy/dl/torch/stimulus_domain/core.py +++ b/bdpy/dl/torch/domain/core.py @@ -12,7 +12,7 @@ class Domain(nn.Module, ABC): """Base class for stimulus domain. - This class is used to convert stimulus between each domain and library's internal common space. + This class is used to convert data between each domain and library's internal common space. Suppose that we have two functions `f: X -> Y_1` and `g: Y_2 -> Z` and want to compose them. Here, `X`, `Y_1`, `Y_2`, and `Z` are different domains and assume that `Y_1` and `Y_2` are the similar domain that can be converted to each other. @@ -20,8 +20,8 @@ class Domain(nn.Module, ABC): conversion function. This class is used to implement `t`. The subclasses of this class should implement `send` and `receive` methods. The `send` method - converts stimulus from the original domain (`Y_1` or `Y_2`) to the internal common space (`Y_0`), - and the `receive` method converts stimulus from the internal common space to the original domain. + converts data from the original domain (`Y_1` or `Y_2`) to the internal common space (`Y_0`), + and the `receive` method converts data from the internal common space to the original domain. By implementing domain class for `Y_1` and `Y_2`, we can construct the domain conversion function `t` as `t = Y_2.receive . Y_1.send`. @@ -37,28 +37,28 @@ def send(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x : torch.Tensor - Stimulus in the original domain. + Data in the original domain. Returns ------- torch.Tensor - Stimulus in the internal common space. + Data in the internal common space. """ pass @abstractmethod def receive(self, x: torch.Tensor) -> torch.Tensor: - """Receive stimulus from the internal common space to each domain. + """Receive data from the internal common space to each domain. Parameters ---------- x : torch.Tensor - Stimulus in the internal common space. + Data in the internal common space. Returns ------- torch.Tensor - Stimulus in the original domain. + Data in the original domain. """ pass @@ -66,7 +66,7 @@ def receive(self, x: torch.Tensor) -> torch.Tensor: class IrreversibleDomain(Domain): """The domain which cannot be reversed. - This class is used to convert stimulus between each domain and library's + This class is used to convert data between each domain and library's internal common space. Note that the subclasses of this class do not guarantee the reversibility of `send` and `receive` methods. """ diff --git a/bdpy/dl/torch/stimulus_domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py similarity index 100% rename from bdpy/dl/torch/stimulus_domain/image_domain.py rename to bdpy/dl/torch/domain/image_domain.py diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 30ae2d34..14b2e5a6 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from bdpy.dl.torch import FeatureExtractor -from bdpy.dl.torch.stimulus_domain import Domain, image_domain +from bdpy.dl.torch.domain import Domain, image_domain class BaseEncoder(ABC): diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index f74625c9..ed70b118 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter -from bdpy.dl.torch.stimulus_domain import Domain, image_domain +from bdpy.dl.torch.domain import Domain, image_domain @torch.no_grad() From fdb957f909a719cffe8d4c9293ed30c0e51e0b87 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 17:49:18 +0900 Subject: [PATCH 038/117] WIP update of domain API --- bdpy/dl/torch/domain/__init__.py | 2 +- bdpy/dl/torch/domain/core.py | 87 +++++++++++++++++++++----- bdpy/dl/torch/domain/feature_domain.py | 44 +++++++++++++ bdpy/dl/torch/domain/image_domain.py | 12 ++++ 4 files changed, 129 insertions(+), 16 deletions(-) create mode 100644 bdpy/dl/torch/domain/feature_domain.py diff --git a/bdpy/dl/torch/domain/__init__.py b/bdpy/dl/torch/domain/__init__.py index a6743925..1e24530e 100644 --- a/bdpy/dl/torch/domain/__init__.py +++ b/bdpy/dl/torch/domain/__init__.py @@ -1 +1 @@ -from .core import Domain, IrreversibleDomain, ComposedDomain \ No newline at end of file +from .core import Domain, IrreversibleDomain, ComposedDomain, KeyValueDomain \ No newline at end of file diff --git a/bdpy/dl/torch/domain/core.py b/bdpy/dl/torch/domain/core.py index 097bb0e8..a1453018 100644 --- a/bdpy/dl/torch/domain/core.py +++ b/bdpy/dl/torch/domain/core.py @@ -1,15 +1,17 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Iterable, TYPE_CHECKING +from typing import Iterable, TYPE_CHECKING, TypeVar, Generic import torch.nn as nn if TYPE_CHECKING: import torch +_T = TypeVar("_T") -class Domain(nn.Module, ABC): + +class Domain(nn.Module, ABC, Generic[_T]): """Base class for stimulus domain. This class is used to convert data between each domain and library's internal common space. @@ -31,39 +33,39 @@ class Domain(nn.Module, ABC): """ @abstractmethod - def send(self, x: torch.Tensor) -> torch.Tensor: + def send(self, x: _T) -> _T: """Send stimulus to the internal common space from each domain. Parameters ---------- - x : torch.Tensor + x : _T Data in the original domain. Returns ------- - torch.Tensor + _T Data in the internal common space. """ pass @abstractmethod - def receive(self, x: torch.Tensor) -> torch.Tensor: + def receive(self, x: _T) -> _T: """Receive data from the internal common space to each domain. Parameters ---------- - x : torch.Tensor + x : _T Data in the internal common space. Returns ------- - torch.Tensor + _T Data in the original domain. """ pass -class IrreversibleDomain(Domain): +class IrreversibleDomain(Domain, Generic[_T]): """The domain which cannot be reversed. This class is used to convert data between each domain and library's @@ -71,26 +73,81 @@ class IrreversibleDomain(Domain): guarantee the reversibility of `send` and `receive` methods. """ - def send(self, x: torch.Tensor) -> torch.Tensor: + def send(self, x: _T) -> _T: return x - def receive(self, x: torch.Tensor) -> torch.Tensor: + def receive(self, x: _T) -> _T: return x -class ComposedDomain(Domain): - """The domain composed of multiple sub-domains.""" +class ComposedDomain(Domain, Generic[_T]): + """The domain composed of multiple sub-domains. + + Suppose we have list of domain objects `domains = [d_0, d_1, ..., d_n]`. + Then, the data in the original domain `D` can be accessed as + `d_0[0].receive . d_1[1].receive . ... . d_n[n].receive(x)`. + + Parameters + ---------- + domains : Iterable[Domain] + Sub-domains to compose. + + Examples + -------- + >>> import numpy as np + >>> import torch + >>> from bdpy.dl.torch.domain import ComposedDomain + >>> from bdpy.dl.torch.domain.image_domain import AffineDomain, BGRDomain + >>> composed_domain = ComposedDomain([ + ... AffineDomain(np.array([0.5]), 1), + ... BGRDomain(), + ... ]) + >>> image = torch.randn(1, 3, 64, 64).clamp(-0.5, 0.5) + >>> image.shape + torch.Size([1, 3, 64, 64]) + >>> composed_domain.send(image).shape + torch.Size([1, 3, 64, 64]) + >>> print(composed_domain.send(image).min().item(), composed_domain.send(image).max().item()) + 0.0 1.0 + """ def __init__(self, domains: Iterable[Domain]) -> None: super().__init__() self.domains = nn.ModuleList(domains) - def send(self, x: torch.Tensor) -> torch.Tensor: + def send(self, x: _T) -> _T: for domain in reversed(self.domains): x = domain.send(x) return x - def receive(self, x: torch.Tensor) -> torch.Tensor: + def receive(self, x: _T) -> _T: for domain in self.domains: x = domain.receive(x) return x + + +class KeyValueDomain(Domain, Generic[_T]): + """The domain which converts key-value pairs. + + This class is used to convert key-value pairs between each domain and library's + internal common space. + + Parameters + ---------- + domain_mapper : dict[str, Domain] + Dictionary that maps keys to domains. + """ + + def __init__(self, domain_mapper: dict[str, Domain]) -> None: + super().__init__() + self.domain_mapper = domain_mapper + + def send(self, x: dict[str, _T]) -> dict[str, _T]: + return { + key: self.domain_mapper[key].send(value) for key, value in x.items() + } + + def receive(self, x: dict[str, _T]) -> dict[str, _T]: + return { + key: self.domain_mapper[key].receive(value) for key, value in x.items() + } diff --git a/bdpy/dl/torch/domain/feature_domain.py b/bdpy/dl/torch/domain/feature_domain.py new file mode 100644 index 00000000..ad658384 --- /dev/null +++ b/bdpy/dl/torch/domain/feature_domain.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import Dict + +import torch + +from .core import Domain + +_FeatureType = Dict[str, torch.Tensor] + + +def _lnd2nld(feature: torch.Tensor) -> torch.Tensor: + """Convert features having the shape of (L, N, D) to (N, L, D).""" + return feature.permute(1, 0, 2) + +def _nld2lnd(feature: torch.Tensor) -> torch.Tensor: + """Convert features having the shape of (N, L, D) to (L, N, D).""" + return feature.permute(1, 0, 2) + + +class ArbitraryFeatureKeyDomain(Domain): + def __init__( + self, + to_internal: dict[str, str] | None = None, + to_self: dict[str, str] | None = None, + ): + super().__init__() + + if to_internal is None and to_self is None: + raise ValueError("Either to_internal or to_self must be specified.") + + if to_internal is None: + to_internal = {value: key for key, value in to_self.items()} + elif to_self is None: + to_self = {value: key for key, value in to_internal.items()} + + self._to_internal = to_internal + self._to_self = to_self + + def send(self, features: _FeatureType) -> _FeatureType: + return {self._to_internal.get(key, key): value for key, value in features.items()} + + def receive(self, features: _FeatureType) -> _FeatureType: + return {self._to_self.get(key, key): value for key, value in features.items()} diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index bb0be8f1..f875bfd6 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -1,3 +1,15 @@ +"""Image domains for PyTorch. + +This module provides image domains for PyTorch. The image domains are used to +convert images between each domain and library's internal common space. +The internal common space is defined as follows: + +- Channel axis: 1 +- Pixel range: [0, 1] +- Image size: arbitrary +- Color space: RGB +""" + from __future__ import annotations import warnings From 393887b189b325178a73c4ee26bf4bca4b4d03a2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 17:53:27 +0900 Subject: [PATCH 039/117] update domain API --- bdpy/dl/torch/domain/image_domain.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index f875bfd6..13c75e9d 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -65,11 +65,28 @@ class AffineDomain(Domain): This domain is used to convert images in [0, 1] to images in [-center, scale-center]. In other words, the pixel intensity p in [0, 1] is converted to p * scale - center. + + Parameters + ---------- + center : float | np.ndarray + Center of the affine transformation. + If center.ndim == 0, it must be scalar. + If center.ndim == 1, it must be 1D vector (C,). + If center.ndim == 3, it must be 3D vector (1, C, W, H). + scale : float | np.ndarray + Scale of the affine transformation. + If scale.ndim == 0, it must be scalar. + If scale.ndim == 1, it must be 1D vector (C,). + If scale.ndim == 3, it must be 3D vector (1, C, W, H). + device : torch.device | None + Device to send/receive images. + dtype : torch.dtype | None + Data type to send/receive images. """ def __init__( self, - center: np.ndarray, + center: float | np.ndarray, scale: float | np.ndarray, *, device: torch.device | None = None, @@ -77,6 +94,8 @@ def __init__( ) -> None: super().__init__() + if isinstance(center, (float, int)) or center.ndim == 0: + center = np.array([center])[np.newaxis, np.newaxis, np.newaxis] if center.ndim == 1: # 1D vector (C,) center = center[np.newaxis, :, np.newaxis, np.newaxis] elif center.ndim == 3: # 3D vector (1, C, W, H) @@ -152,6 +171,13 @@ class BdPyVGGDomain(ComposedDomain): # These values are calculated from the mean vector of ImageNet ([123, 117, 104]). - Image size: arbitrary - Color space: BGR + + Parameters + ---------- + device : torch.device | None + Device to send/receive images. + dtype : torch.dtype | None + Data type to send/receive images. """ def __init__( From aeb1679ca7a521100ee03ed335bc420d14aebede Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:01:29 +0900 Subject: [PATCH 040/117] update docstring and warnings --- bdpy/dl/torch/domain/core.py | 16 ++++++++++++---- bdpy/dl/torch/domain/image_domain.py | 11 ++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/bdpy/dl/torch/domain/core.py b/bdpy/dl/torch/domain/core.py index a1453018..63778a45 100644 --- a/bdpy/dl/torch/domain/core.py +++ b/bdpy/dl/torch/domain/core.py @@ -1,13 +1,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Iterable, TYPE_CHECKING, TypeVar, Generic +from typing import Iterable, TypeVar, Generic +import warnings import torch.nn as nn -if TYPE_CHECKING: - import torch - _T = TypeVar("_T") @@ -73,6 +71,16 @@ class IrreversibleDomain(Domain, Generic[_T]): guarantee the reversibility of `send` and `receive` methods. """ + def __init__(self) -> None: + super().__init__() + warnings.warn( + f"{self.__class__.__name__} is an irreversible domain. " \ + "It does not guarantee the reversibility of `send` and `receive` " \ + "methods. Please use the combination of `send` and `receive` methods " \ + "with caution.", + RuntimeWarning, + ) + def send(self, x: _T) -> _T: return x diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index 13c75e9d..ea58cd78 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -145,15 +145,16 @@ class PILDomainWithExplicitCrop(IrreversibleDomain): """ def send(self, images: torch.Tensor) -> torch.Tensor: + return _to_channel_first(images) / 255.0 # to [0, 1.0] + + def receive(self, images: torch.Tensor) -> torch.Tensor: warnings.warn( - "PILDomainWithExplicitCrop is an irreversible domain. " \ - "It does not guarantee the reversibility of `send` and `receive` " \ - "methods. Please use PILDomainWithExplicitCrop.send() with caution.", + "`PILDominWithExplicitCrop.receive` performs explicit cropping. " \ + "It could be affected to the gradient computation. " \ + "Please do not use this domain inside the optimization pipeline.", RuntimeWarning, ) - return _to_channel_first(images) / 255.0 # to [0, 1.0] - def receive(self, images: torch.Tensor) -> torch.Tensor: images = _to_channel_last(images) * 255.0 # Crop values to [0, 255] From 4147d7fd89de45841a13bb8a2a34a05d0e1061f2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:04:21 +0900 Subject: [PATCH 041/117] update docstring --- bdpy/dl/torch/domain/image_domain.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index ea58cd78..4873d12a 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -169,7 +169,6 @@ class BdPyVGGDomain(ComposedDomain): - red: [-123, 132] - green: [-117, 138] - blue: [-104, 151] - # These values are calculated from the mean vector of ImageNet ([123, 117, 104]). - Image size: arbitrary - Color space: BGR @@ -179,6 +178,10 @@ class BdPyVGGDomain(ComposedDomain): Device to send/receive images. dtype : torch.dtype | None Data type to send/receive images. + + Notes + ----- + The pixel ranges of this domain are derived from the mean vector of ImageNet ([123, 117, 104]). """ def __init__( From f26b97132e99fec5323bd63e144158ffcb0ea212 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:09:32 +0900 Subject: [PATCH 042/117] update docstring --- bdpy/dl/torch/domain/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bdpy/dl/torch/domain/core.py b/bdpy/dl/torch/domain/core.py index 63778a45..7acc016b 100644 --- a/bdpy/dl/torch/domain/core.py +++ b/bdpy/dl/torch/domain/core.py @@ -92,8 +92,8 @@ class ComposedDomain(Domain, Generic[_T]): """The domain composed of multiple sub-domains. Suppose we have list of domain objects `domains = [d_0, d_1, ..., d_n]`. - Then, the data in the original domain `D` can be accessed as - `d_0[0].receive . d_1[1].receive . ... . d_n[n].receive(x)`. + Then, `ComposedDomain(domains)` accesses the data in the original domain `D` + as `d_n.receive . ... d_1.receive . d_0.receive(x)` from the internal common space `D_0`. Parameters ---------- @@ -107,7 +107,7 @@ class ComposedDomain(Domain, Generic[_T]): >>> from bdpy.dl.torch.domain import ComposedDomain >>> from bdpy.dl.torch.domain.image_domain import AffineDomain, BGRDomain >>> composed_domain = ComposedDomain([ - ... AffineDomain(np.array([0.5]), 1), + ... AffineDomain(0.5, 1), ... BGRDomain(), ... ]) >>> image = torch.randn(1, 3, 64, 64).clamp(-0.5, 0.5) From 678b32387d03e89814d35eea4212657fba1122d3 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:24:00 +0900 Subject: [PATCH 043/117] test cases for encoder --- tests/recon/torch/modules/test_encoder.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/recon/torch/modules/test_encoder.py diff --git a/tests/recon/torch/modules/test_encoder.py b/tests/recon/torch/modules/test_encoder.py new file mode 100644 index 00000000..058b5719 --- /dev/null +++ b/tests/recon/torch/modules/test_encoder.py @@ -0,0 +1,25 @@ +"""Tests for bdpy.recon.torch.modules.encoder.""" + +import unittest + +import torch + +from bdpy.recon.torch.modules import encoder as encoder_module + + +class TestBaseEncoder(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.encoder.BaseEncoder.""" + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, encoder_module.BaseEncoder) + + def test_call(self): + """Test __call__.""" + class ReturnAsIsEncoder(encoder_module.BaseEncoder): + def encode(self, images): + return {"image": images} + + encoder = ReturnAsIsEncoder() + images = torch.randn(1, 3, 64, 64) + features = encoder(images) + self.assertDictEqual(features, {"image": images}) From 7f15a52a99e45b669df6dc6f4179d393264b5e43 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:38:29 +0900 Subject: [PATCH 044/117] update docstring based on the comments --- bdpy/recon/torch/pipeline/inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 7e380636..804db472 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -175,10 +175,10 @@ class FeatureInversionPipeline: >>> critic = TargetNormalizedMSE(...) >>> optimizer = torch.optim.Adam(latent.parameters()) >>> pipeline = FeatureInversionPipeline( - ... encoder, generator, latent, critic, optimizer + ... encoder, generator, latent, critic, optimizer, num_iterations=200, ... ) >>> target_features = encoder(target_image) - >>> pipeline.reset_state() + >>> pipeline.reset_states() >>> reconstructed_image = pipeline(target_features) """ From 18218f9eb70d0f97fe287c52068ac14fbaf3e7b3 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:57:38 +0900 Subject: [PATCH 045/117] fixed docstring --- bdpy/recon/torch/modules/encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 14b2e5a6..6574bf64 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -68,10 +68,10 @@ class SimpleEncoder(BaseEncoder): ... nn.Conv2d(3, 3, 3), ... nn.ReLU(), ... ) - >>> encoder = SimpleEncoder(feature_network, ['0']) + >>> encoder = SimpleEncoder(feature_network, ['[0]']) >>> image = torch.randn(1, 3, 64, 64) >>> features = encoder(image) - >>> features['0'].shape + >>> features['[0]'].shape torch.Size([1, 3, 62, 62]) """ From f838bf08635ddf99b538d8f16d2358665f364865 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 18:57:55 +0900 Subject: [PATCH 046/117] fixed bug --- bdpy/util/callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index b417d048..3ae0ba49 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -140,10 +140,10 @@ def register(self, callback: BaseCallback) -> None: callback_method = getattr(callback, event_type) if not callable(callback_method): continue - if _is_unused(callback_method): - continue if event_type.startswith("_"): continue + if _is_unused(callback_method): + continue if event_type.startswith("on_"): self._registered_functions[event_type].append(callback_method) continue From cdc1baf49638501edaa57aead11d772f6ee888a6 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 19:04:35 +0900 Subject: [PATCH 047/117] update control flow --- bdpy/util/callback.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index 3ae0ba49..40120a17 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -140,13 +140,11 @@ def register(self, callback: BaseCallback) -> None: callback_method = getattr(callback, event_type) if not callable(callback_method): continue - if event_type.startswith("_"): + if not event_type.startswith("on_"): continue if _is_unused(callback_method): continue - if event_type.startswith("on_"): - self._registered_functions[event_type].append(callback_method) - continue + self._registered_functions[event_type].append(callback_method) def fire(self, event_type: str, **kwargs: Any) -> None: """Execute the callback functions registered to the event type. From 8ea441d67f0383e0234658543594212ec62381a6 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 19:16:56 +0900 Subject: [PATCH 048/117] update docstring for callback --- bdpy/recon/torch/pipeline/inversion.py | 6 +++-- bdpy/util/callback.py | 37 ++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 804db472..47e4aed3 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -21,7 +21,8 @@ class FeatureInversionCallback(BaseCallback): As a design principle, the callback functions must not have any side effects on the pipeline results. It should be used only for logging, visualization, - etc. + etc. Please refer to `bdpy.util.callback.BaseCallback` for details of the + usage of callbacks. """ @unused @@ -161,7 +162,8 @@ class FeatureInversionPipeline: num_iterations : int, optional Number of iterations, by default 1. callbacks : FeatureInversionCallback | Iterable[FeatureInversionCallback] | None, optional - Callbacks, by default None. + Callbacks, by default None. Please refer to `bdpy.util.callback.Callback` + and `bdpy.recon.torch.pipeline.FeatureInversionCallback` for details. Examples -------- diff --git a/bdpy/util/callback.py b/bdpy/util/callback.py index 40120a17..b679e4bc 100644 --- a/bdpy/util/callback.py +++ b/bdpy/util/callback.py @@ -65,6 +65,28 @@ class BaseCallback: As a design principle, the callback functions must not have any side effects on the pipeline results. It should be used only for logging, visualization, etc. + + For example, the following callback class logs the start and end of the + pipeline. + + >>> class Callback(BaseCallback): + ... def on_pipeline_start(self): + ... print("Pipeline started.") + ... + ... def on_pipeline_end(self): + ... print("Pipeline ended.") + ... + >>> callback = Callback() + >>> some_pipeline = SomePipeline() # Initialize a pipeline object + >>> some_pipeline.register_callback(callback) + >>> outputs = some_pipeline(inputs) # Run the pipeline + Pipeline started. + Pipeline ended. + + The set of available events that can be hooked into depends on the pipeline. + See the base class of the corresponding pipeline for the list of all events. + `@unused` decorator can be used to mark a callback function as unused, so + that the callback handler does not fire the function. """ @unused @@ -94,17 +116,22 @@ class CallbackHandler: Examples -------- >>> class Callback(BaseCallback): + ... def __init__(self, name): + ... self._name = name + ... ... def on_pipeline_start(self): - ... print("Pipeline started.") + ... print(f"Pipeline started (name={self._name}).") ... ... def on_pipeline_end(self): - ... print("Pipeline ended.") + ... print(f"Pipeline ended (name={self._name}).") ... - >>> handler = CallbackHandler(Callback()) + >>> handler = CallbackHandler([Callback("A"), Callback("B")]) >>> handler.fire("on_pipeline_start") - Pipeline started. + Pipeline started (name=A). + Pipeline started (name=B). >>> handler.fire("on_pipeline_end") - Pipeline ended. + Pipeline ended (name=A). + Pipeline ended (name=B). """ _callbacks: list[BaseCallback] From 71ce6128b888dc7b1bd4ac3eebed8841135955e1 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 19:27:17 +0900 Subject: [PATCH 049/117] explicitly define all the available event types in the base callback class --- bdpy/recon/torch/pipeline/inversion.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 47e4aed3..0ae7a88f 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -25,6 +25,11 @@ class FeatureInversionCallback(BaseCallback): usage of callbacks. """ + @unused + def on_pipeline_start(self) -> None: + """Callback on pipeline start.""" + pass + @unused def on_iteration_start(self, *, step: int) -> None: """Callback on iteration start.""" @@ -65,6 +70,11 @@ def on_iteration_end(self, *, step: int) -> None: """Called at the end of each iteration.""" pass + @unused + def on_pipeline_end(self) -> None: + """Callback on pipeline end.""" + pass + class CUILoggingCallback(FeatureInversionCallback): """Callback for logging on CUI. From 34ecc3ac9a6cc60e11ef5394d4b254a564d221db Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 19:32:34 +0900 Subject: [PATCH 050/117] change the default media_interval to be -1 --- bdpy/recon/torch/pipeline/inversion.py | 44 +++++++++++++++++++------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 0ae7a88f..614e89a9 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -12,7 +12,9 @@ FeatureType = Dict[str, torch.Tensor] -def _apply_to_features(fn: Callable[[torch.Tensor], torch.Tensor], features: FeatureType) -> FeatureType: +def _apply_to_features( + fn: Callable[[torch.Tensor], torch.Tensor], features: FeatureType +) -> FeatureType: return {k: fn(v) for k, v in features.items()} @@ -46,7 +48,9 @@ def on_feature_extracted(self, *, step: int, features: FeatureType) -> None: pass @unused - def on_layerwise_loss_calculated(self, *, layer_loss: torch.Tensor, layer_name: str) -> None: + def on_layerwise_loss_calculated( + self, *, layer_loss: torch.Tensor, layer_name: str + ) -> None: """Callback on layerwise loss calculated.""" pass @@ -119,20 +123,26 @@ class WandBLoggingCallback(FeatureInversionCallback): Logging interval, by default 1. If `interval` is 1, the callback logs every iteration. media_interval : int, optional - Logging interval for media, by default 1. If `media_interval` is 1, - the callback logs every iteration. + Logging interval for media, by default -1. If `media_interval` is -1, + the callback does not log media. Notes ----- TODO: Currently it does not work because the dependency (wandb) is not installed. """ - def __init__(self, run: wandb.sdk.wandb_run.Run, interval: int = 1, media_interval: int = 1) -> None: + def __init__( + self, run: wandb.sdk.wandb_run.Run, interval: int = 1, media_interval: int = -1 + ) -> None: self._run = run self._interval = interval self._media_interval = media_interval self._step = 0 + if media_interval < 0: + # NOTE: Decorate `on_image_generated` to do nothing. + self.on_image_generated = unused(self.on_image_generated) + def on_iteration_start(self, *, step: int) -> None: # NOTE: We need to store the global step because we cannot access it # in `on_layerwise_loss_calculated` by design. @@ -143,7 +153,9 @@ def on_image_generated(self, *, step: int, image: torch.Tensor) -> None: image = wandb.Image(image) self._run.log({"generated_image": image}, step=self._step) - def on_layerwise_loss_calculated(self, *, layer_loss: torch.Tensor, layer_name: str) -> None: + def on_layerwise_loss_calculated( + self, *, layer_loss: torch.Tensor, layer_name: str + ) -> None: if self._step % self._interval == 0: self._run.log({f"critic/{layer_name}": layer_loss.item()}, step=self._step) @@ -203,7 +215,9 @@ def __init__( optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler = None, num_iterations: int = 1, - callbacks: FeatureInversionCallback | Iterable[FeatureInversionCallback] | None = None, + callbacks: FeatureInversionCallback + | Iterable[FeatureInversionCallback] + | None = None, ) -> None: self._encoder = encoder self._generator = generator @@ -239,13 +253,21 @@ def __call__( latent = self._latent() generated_image = self._generator(latent) - self._callback_handler.fire("on_image_generated", step=step, image=generated_image.detach()) + self._callback_handler.fire( + "on_image_generated", step=step, image=generated_image.detach() + ) features = self._encoder(generated_image) - self._callback_handler.fire("on_feature_extracted", step=step, features=_apply_to_features(lambda x: x.detach(), features)) + self._callback_handler.fire( + "on_feature_extracted", + step=step, + features=_apply_to_features(lambda x: x.detach(), features), + ) loss = self._critic(features, target_features) - self._callback_handler.fire("on_loss_calculated", step=step, loss=loss.detach()) + self._callback_handler.fire( + "on_loss_calculated", step=step, loss=loss.detach() + ) loss.backward() self._callback_handler.fire("on_backward_end", step=step) @@ -270,7 +292,7 @@ def reset_states(self) -> None: self._generator.parameters(), self._latent.parameters(), ), - **self._optimizer.defaults + **self._optimizer.defaults, ) def register_callback(self, callback: FeatureInversionCallback) -> None: From 1c7129b5d1da5b200474a236d4d3cbe53a593483 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 19:41:11 +0900 Subject: [PATCH 051/117] bugfix in docstring --- bdpy/recon/torch/pipeline/inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 614e89a9..f65a1107 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -184,7 +184,7 @@ class FeatureInversionPipeline: num_iterations : int, optional Number of iterations, by default 1. callbacks : FeatureInversionCallback | Iterable[FeatureInversionCallback] | None, optional - Callbacks, by default None. Please refer to `bdpy.util.callback.Callback` + Callbacks, by default None. Please refer to `bdpy.util.callback.BaseCallback` and `bdpy.recon.torch.pipeline.FeatureInversionCallback` for details. Examples From 45995b46bb596113d9b9baaaa4a465239d2158a2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 20:05:40 +0900 Subject: [PATCH 052/117] update docstring --- bdpy/recon/torch/pipeline/inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index f65a1107..db573bd8 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -244,7 +244,7 @@ def __call__( Returns ------- torch.Tensor - Reconstructed images which have the similar features to the target features. + Reconstructed images on the libraries internal domain. """ self._callback_handler.fire("on_pipeline_start") for step in range(self._num_iterations): From e9912bfc3fc97d1b10691b8b326cdbb781cb0b6c Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 20 Dec 2023 20:09:41 +0900 Subject: [PATCH 053/117] update docstring so that it explictly states which domain to use --- bdpy/recon/torch/modules/encoder.py | 4 ++-- bdpy/recon/torch/modules/generator.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 6574bf64..5bc133b5 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -34,7 +34,7 @@ def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: Parameters ---------- images : torch.Tensor - Images. + Images on the libraries internal domain. Returns ------- @@ -95,7 +95,7 @@ def encode(self, images: torch.Tensor) -> dict[str, torch.Tensor]: Parameters ---------- images : torch.Tensor - Images. + Images on the libraries internal domain. Returns ------- diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index ed70b118..76a5cbcd 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -41,7 +41,7 @@ def generate(self, latent: torch.Tensor) -> torch.Tensor: Returns ------- torch.Tensor - Generated image. The generated images must be in the range [0, 1]. + Generated image on the libraries internal domain. """ pass @@ -116,7 +116,7 @@ def generate(self, latent: torch.Tensor) -> torch.Tensor: Returns ------- torch.Tensor - Generated image. The generated images must be in the range [0, 1]. + Generated image on the libraries internal domain. """ return self._domain.send(self._activation(latent)) @@ -181,7 +181,7 @@ def generate(self, latent: torch.Tensor) -> torch.Tensor: Returns ------- torch.Tensor - Generated image. The generated images must be in the range [0, 1]. + Generated image on the libraries internal domain. """ return self._domain.send(self._generator_network(latent)) From bb0ff97750a50b4fcb6e59174f7d15b0fd08f7f2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 11:08:04 +0900 Subject: [PATCH 054/117] update test case --- tests/recon/torch/modules/test_encoder.py | 60 +++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/recon/torch/modules/test_encoder.py b/tests/recon/torch/modules/test_encoder.py index 058b5719..f8bec2e2 100644 --- a/tests/recon/torch/modules/test_encoder.py +++ b/tests/recon/torch/modules/test_encoder.py @@ -3,18 +3,38 @@ import unittest import torch +import torch.nn as nn +from bdpy.dl.torch.domain.image_domain import Zero2OneImageDomain from bdpy.recon.torch.modules import encoder as encoder_module +class MLP(nn.Module): + """A simple MLP.""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(64 * 64 * 3, 256) + self.fc2 = nn.Linear(256, 128) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = torch.relu(x) + x = self.fc2(x) + return x + + class TestBaseEncoder(unittest.TestCase): """Tests for bdpy.recon.torch.modules.encoder.BaseEncoder.""" + def test_instantiation(self): """Test instantiation.""" self.assertRaises(TypeError, encoder_module.BaseEncoder) def test_call(self): """Test __call__.""" + class ReturnAsIsEncoder(encoder_module.BaseEncoder): def encode(self, images): return {"image": images} @@ -23,3 +43,43 @@ def encode(self, images): images = torch.randn(1, 3, 64, 64) features = encoder(images) self.assertDictEqual(features, {"image": images}) + + +class TestSimpleEncoder(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.encoder.SimpleEncoder.""" + + def test_call(self): + """Test __call__.""" + encoder = encoder_module.SimpleEncoder( + MLP(), ["fc1", "fc2"], domain=Zero2OneImageDomain() + ) + images = torch.randn(1, 3, 64, 64).clamp(0, 1) + features = encoder(images) + self.assertIsInstance(features, dict) + self.assertEqual(len(features), 2) + self.assertEqual(features["fc1"].shape, (1, 256)) + self.assertEqual(features["fc2"].shape, (1, 128)) + + +class TestBuildEncoder(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.encoder.build_encoder.""" + + def test_build_encoder(self): + """Test build_encoder.""" + mlp = MLP() + encoder_from_builder = encoder_module.build_encoder( + feature_network=mlp, + layer_names=["fc1", "fc2"], + domain=Zero2OneImageDomain(), + ) + encoder = encoder_module.SimpleEncoder( + mlp, ["fc1", "fc2"], domain=Zero2OneImageDomain() + ) + + images = torch.randn(1, 3, 64, 64).clamp(0, 1) + features_from_builder = encoder_from_builder(images) + features = encoder(images) + self.assertEqual(type(encoder_from_builder), type(encoder)) + self.assertEqual(features_from_builder.keys(), features.keys()) + for key in features_from_builder.keys(): + self.assertTrue(torch.allclose(features_from_builder[key], features[key])) From 1b8be764a78267aa6c18d12c214244e64e9967f2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 11:08:11 +0900 Subject: [PATCH 055/117] fix docstring --- bdpy/recon/torch/modules/encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 5bc133b5..6c25eed5 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -140,10 +140,10 @@ def build_encoder( ... nn.Conv2d(3, 3, 3), ... nn.ReLU(), ... ) - >>> encoder = build_encoder(feature_network, layer_names=['0']) + >>> encoder = build_encoder(feature_network, layer_names=['[0]']) >>> image = torch.randn(1, 3, 64, 64) >>> features = encoder(image) - >>> features['0'].shape + >>> features['[0]'].shape torch.Size([1, 3, 62, 62]) """ return SimpleEncoder(feature_network, layer_names, domain, device) From f4a5f36f1d9e9aac3cee17c628f7e625bf8ba090 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 11:13:19 +0900 Subject: [PATCH 056/117] remove device option from encoder API --- bdpy/recon/torch/modules/encoder.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 6c25eed5..a89745fc 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -56,8 +56,6 @@ class SimpleEncoder(BaseEncoder): Layer names to extract features from. domain : Domain, optional Domain of the input images to receive. (default: Zero2OneImageDomain()) - device : torch.device, optional - Device to use. (default: "cpu"). Examples -------- @@ -80,11 +78,10 @@ def __init__( feature_network: nn.Module, layer_names: Iterable[str], domain: Domain = image_domain.Zero2OneImageDomain(), - device: str | torch.device = "cpu", ) -> None: super().__init__() self._feature_extractor = FeatureExtractor( - encoder=feature_network, layers=layer_names, detach=False, device=device + encoder=feature_network, layers=layer_names, detach=False, device=None ) self._domain = domain self._feature_network = self._feature_extractor._encoder @@ -110,7 +107,6 @@ def build_encoder( feature_network: nn.Module, layer_names: Iterable[str], domain: Domain = image_domain.Zero2OneImageDomain(), - device: str | torch.device = "cpu", ) -> BaseEncoder: """Build an encoder network with a naive feature extractor. @@ -123,8 +119,6 @@ def build_encoder( Layer names to extract features from. domain : Domain, optional Domain of the input images to receive (default: Zero2OneImageDomain()). - device : torch.device, optional - Device to use. (default: "cpu"). Returns ------- @@ -146,4 +140,4 @@ def build_encoder( >>> features['[0]'].shape torch.Size([1, 3, 62, 62]) """ - return SimpleEncoder(feature_network, layer_names, domain, device) + return SimpleEncoder(feature_network, layer_names, domain) From 8b320b615bb8f74930d66fde99963cc22c31b6f4 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 11:22:11 +0900 Subject: [PATCH 057/117] main statement in test cases --- tests/recon/torch/modules/test_critic.py | 4 ++++ tests/recon/torch/modules/test_encoder.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/recon/torch/modules/test_critic.py b/tests/recon/torch/modules/test_critic.py index 4c474c14..c4600f56 100644 --- a/tests/recon/torch/modules/test_critic.py +++ b/tests/recon/torch/modules/test_critic.py @@ -167,3 +167,7 @@ def test_compare_layer(self): loss, torch.sum((feature - target_feature)**2, dim=1)/torch.sum(target_feature**2, dim=1) )) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/recon/torch/modules/test_encoder.py b/tests/recon/torch/modules/test_encoder.py index f8bec2e2..31caa746 100644 --- a/tests/recon/torch/modules/test_encoder.py +++ b/tests/recon/torch/modules/test_encoder.py @@ -83,3 +83,7 @@ def test_build_encoder(self): self.assertEqual(features_from_builder.keys(), features.keys()) for key in features_from_builder.keys(): self.assertTrue(torch.allclose(features_from_builder[key], features[key])) + + +if __name__ == "__main__": + unittest.main() From f67228e64c7b4fa8877234d2b9c4822fcb6ed069 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 11:51:04 +0900 Subject: [PATCH 058/117] update reset function --- bdpy/recon/torch/modules/generator.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 76a5cbcd..5fc2b36b 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Callable, Iterator @@ -8,12 +10,24 @@ from bdpy.dl.torch.domain import Domain, image_domain +def _get_reset_module_fn(module: nn.Module) -> Callable[[], None] | None: + """Get the function to reset the parameters of the module.""" + reset_parameters = getattr(module, "reset_parameters", None) + if callable(reset_parameters): + return reset_parameters + # NOTE: This is needed for nn.MultiheadAttention + reset_parameters = getattr(module, "_reset_parameters", None) + if callable(reset_parameters): + return reset_parameters + return None + + @torch.no_grad() def reset_all_parameters(module: nn.Module) -> None: """Reset the parameters of the module.""" - reset_parameters = getattr(module, "reset_parameters", None) - if callable(reset_parameters): - module.reset_parameters() + reset_parameters = _get_reset_module_fn(module) + if reset_parameters is not None: + reset_parameters() class BaseGenerator(ABC): From c9aa64d56036d06411bf331e5c5801e02681a271 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 13:30:45 +0900 Subject: [PATCH 059/117] update docstring --- bdpy/recon/torch/modules/encoder.py | 11 ++++++++++- bdpy/recon/torch/modules/generator.py | 6 +++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index a89745fc..0d5e15f8 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -110,15 +110,24 @@ def build_encoder( ) -> BaseEncoder: """Build an encoder network with a naive feature extractor. + The function builds an encoder module from a feature network that takes + images on its own domain as input and processes them. The encoder module + receives images on the library's internal domain and returns features on the + library's internal domain indexed by `layer_names`. `domain` is used to + convert the input images to the feature network's domain from the library's + internal domain. + Parameters ---------- feature_network : nn.Module Feature network. This network should have a method `forward` that takes - an image tensor and propagates it through the network. + an image tensor and propagates it through the network. The images should + be on the network's own domain. layer_names : list[str] Layer names to extract features from. domain : Domain, optional Domain of the input images to receive (default: Zero2OneImageDomain()). + One needs to specify the equivalent domain of the feature network. Returns ------- diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 5fc2b36b..a5de2718 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -256,11 +256,15 @@ def build_generator( ) -> BaseGenerator: """Build a generator module. + This function builds a generator module from a generator network that takes + a latent vector as an input and returns an image on its own domain. One + needs to specify the domain of the generator network. + Parameters ---------- generator_network : nn.Module Generator network. This network should have a method `forward` that takes - a latent vector and propagates it through the network. + a latent vector and returns an image on its own domain. domain : Domain, optional Domain of the input images to receive. (default: Zero2OneImageDomain()) reset_fn : Callable[[nn.Module], None], optional From 87f92c386e4867b1c7e597566dec46df1c27eb81 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 13:32:18 +0900 Subject: [PATCH 060/117] update docstring --- bdpy/recon/torch/modules/encoder.py | 3 ++- bdpy/recon/torch/modules/generator.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 0d5e15f8..662038db 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -127,7 +127,8 @@ def build_encoder( Layer names to extract features from. domain : Domain, optional Domain of the input images to receive (default: Zero2OneImageDomain()). - One needs to specify the equivalent domain of the feature network. + One needs to specify the domain that corresponds to the feature network's + input domain. Returns ------- diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index a5de2718..15a1b95b 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -266,7 +266,9 @@ def build_generator( Generator network. This network should have a method `forward` that takes a latent vector and returns an image on its own domain. domain : Domain, optional - Domain of the input images to receive. (default: Zero2OneImageDomain()) + Domain of the input images to receive. (default: Zero2OneImageDomain()). + One needs to specify the domain that corresponds to the generator + network's output domain. reset_fn : Callable[[nn.Module], None], optional Function to reset the parameters of the generator network, by default reset_all_parameters. From 8c09756b9da8011142fad3392eeaee6f93b27bc5 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 14:57:50 +0900 Subject: [PATCH 061/117] test case for generator module --- tests/recon/torch/modules/test_generator.py | 87 +++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests/recon/torch/modules/test_generator.py diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py new file mode 100644 index 00000000..44cfe06d --- /dev/null +++ b/tests/recon/torch/modules/test_generator.py @@ -0,0 +1,87 @@ +"""Tests for bdpy.recon.torch.modules.generator.""" + +import unittest + +import copy + +import torch +import torch.nn as nn + +from bdpy.recon.torch.modules import generator as generator_module + + +class TestResetAllParameters(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.reset_all_parameters.""" + + def test_reset_all_parameters(self): + """Test reset_all_parameters.""" + pass + + +class TestBaseGenerator(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.BaseGenerator.""" + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, generator_module.BaseGenerator) + + def test_call(self): + """Test __call__.""" + + class ReturnAsIsGenerator(generator_module.BaseGenerator): + def generate(self, latent): + return latent + + def reset_states(self) -> None: + pass + + def parameters(self, recurse=True): + return iter([]) + + generator = ReturnAsIsGenerator() + latent = torch.randn(1, 3, 64, 64) + generated_image = generator(latent) + self.assertEqual(generated_image.shape, (1, 3, 64, 64)) + + +class TestNNModuleGenerator(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.NNModuleGenerator.""" + + def setUp(self): + """Set up.""" + + class LinearGenerator(generator_module.NNModuleGenerator): + def __init__(self): + super().__init__() + self.fc = nn.Linear(64, 64) + + def generate(self, latent): + return self.fc(latent) + + def reset_states(self) -> None: + self.fc.apply(generator_module.reset_all_parameters) + + self.generator = LinearGenerator() + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, generator_module.NNModuleGenerator) + + def test_call(self): + """Test __call__.""" + latent = torch.randn(1, 64) + generated_image = self.generator(latent) + self.assertEqual(generated_image.shape, (1, 64)) + + def test_reset_states(self): + """Test reset_states.""" + generator_copy = copy.deepcopy(self.generator) + for p1, p2 in zip(self.generator.parameters(), generator_copy.parameters()): + self.assertTrue(torch.equal(p1, p2)) + self.generator.reset_states() + for p1, p2 in zip(self.generator.parameters(), generator_copy.parameters()): + self.assertFalse(torch.equal(p1, p2)) + + +if __name__ == "__main__": + unittest.main() From e8d88af0e34c8be6e15eff2af85a39d799818635 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 15:02:34 +0900 Subject: [PATCH 062/117] update test case --- tests/recon/torch/modules/test_generator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py index 44cfe06d..08187efd 100644 --- a/tests/recon/torch/modules/test_generator.py +++ b/tests/recon/torch/modules/test_generator.py @@ -53,7 +53,7 @@ def setUp(self): class LinearGenerator(generator_module.NNModuleGenerator): def __init__(self): super().__init__() - self.fc = nn.Linear(64, 64) + self.fc = nn.Linear(64, 10) def generate(self, latent): return self.fc(latent) @@ -71,7 +71,9 @@ def test_call(self): """Test __call__.""" latent = torch.randn(1, 64) generated_image = self.generator(latent) - self.assertEqual(generated_image.shape, (1, 64)) + self.assertEqual(generated_image.shape, (1, 10)) + generated_image.sum().backward() + self.assertIsNotNone(self.generator.fc.weight.grad) def test_reset_states(self): """Test reset_states.""" From ceda3370dfe0b852f0d4fdd49245845970873d85 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 17:18:13 +0900 Subject: [PATCH 063/117] bugfix --- bdpy/recon/torch/modules/latent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bdpy/recon/torch/modules/latent.py b/bdpy/recon/torch/modules/latent.py index f1a9e7db..28195c70 100644 --- a/bdpy/recon/torch/modules/latent.py +++ b/bdpy/recon/torch/modules/latent.py @@ -80,7 +80,7 @@ def __init__(self, shape: tuple[int, ...], init_fn: Callable[[torch.Tensor], Non super().__init__() self._shape = shape self._init_fn = init_fn - self._latent = torch.empty(shape) + self._latent = nn.Parameter(torch.empty(shape)) def reset_states(self) -> None: """Reset the state of the latent variable.""" From 8f13ebb34b99db4826cfdb9dc9dbd364f81ce8b7 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 17:24:31 +0900 Subject: [PATCH 064/117] make the list of event types in the callback object as minimum as possible --- bdpy/recon/torch/pipeline/inversion.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index db573bd8..f8ba25e6 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -42,11 +42,6 @@ def on_image_generated(self, *, step: int, image: torch.Tensor) -> None: """Callback on image generated.""" pass - @unused - def on_feature_extracted(self, *, step: int, features: FeatureType) -> None: - """Callback on feature extracted.""" - pass - @unused def on_layerwise_loss_calculated( self, *, layer_loss: torch.Tensor, layer_name: str @@ -59,16 +54,6 @@ def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: """Callback on loss calculated.""" pass - @unused - def on_backward_end(self, *, step: int) -> None: - """Callback on backward end.""" - pass - - @unused - def on_optimizer_step(self, *, step: int) -> None: - """Callback on optimizer step.""" - pass - @unused def on_iteration_end(self, *, step: int) -> None: """Called at the end of each iteration.""" From 72021f8036596ed1160ef1885d704c9ce5100292 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 17:32:33 +0900 Subject: [PATCH 065/117] deprecation warning --- bdpy/dl/torch/torch.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bdpy/dl/torch/torch.py b/bdpy/dl/torch/torch.py index 7d325248..de2a8c78 100644 --- a/bdpy/dl/torch/torch.py +++ b/bdpy/dl/torch/torch.py @@ -5,6 +5,7 @@ from typing import Iterable, List, Dict, Union, Tuple, Any, Callable, Optional import os +import warnings import numpy as np from PIL import Image @@ -171,6 +172,13 @@ def __init__( - Images are converted to RGB. Alpha channels in RGBA images are ignored. ''' + warnings.warn( + "dl.torch.torch.ImageDataset is deprecated. Please consider using " \ + "bdpy.dl.torch.dataset.ImageDataset instead.", + DeprecationWarning, + stacklevel=2 + ) + self.transform = transform # Custom transforms self.__shape = shape From 54add00aaa3271770551fe01f50cd21194fbed9d Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 17:34:38 +0900 Subject: [PATCH 066/117] move callback.py to bdpy.pipeline --- bdpy/{util => pipeline}/callback.py | 0 bdpy/recon/torch/modules/critic.py | 2 +- bdpy/recon/torch/pipeline/inversion.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename bdpy/{util => pipeline}/callback.py (100%) diff --git a/bdpy/util/callback.py b/bdpy/pipeline/callback.py similarity index 100% rename from bdpy/util/callback.py rename to bdpy/pipeline/callback.py diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index 18009eb0..af537e46 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from bdpy.util.callback import CallbackHandler, BaseCallback +from bdpy.pipeline.callback import CallbackHandler, BaseCallback _FeatureType = Dict[str, torch.Tensor] diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index f8ba25e6..8174b01a 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -7,7 +7,7 @@ import torch from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic -from bdpy.util.callback import CallbackHandler, BaseCallback, unused +from bdpy.pipeline.callback import CallbackHandler, BaseCallback, unused FeatureType = Dict[str, torch.Tensor] From 75ddefcacba5c4d0508e8fd535cd98ff18fcd498 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 17:53:22 +0900 Subject: [PATCH 067/117] add callback validation --- bdpy/pipeline/callback.py | 35 +++++++++++++++++++++++--- bdpy/recon/torch/pipeline/inversion.py | 8 +++++- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/bdpy/pipeline/callback.py b/bdpy/pipeline/callback.py index b679e4bc..fe409ba2 100644 --- a/bdpy/pipeline/callback.py +++ b/bdpy/pipeline/callback.py @@ -48,7 +48,9 @@ def unused(fn: Callable[_P, Any]) -> Callable[_P, _Unused]: @wraps(fn) # NOTE: preserve name, docstring, etc. of the original function def _unused(*args: _P.args, **kwargs: _P.kwargs) -> _Unused: - raise RuntimeError(f"Function {fn} is decorated with @unused and must not be called.") + raise RuntimeError( + f"Function {fn} is decorated with @unused and must not be called." + ) # NOTE: change the return type to Unused _unused.__annotations__["return"] = _Unused @@ -56,6 +58,28 @@ def _unused(*args: _P.args, **kwargs: _P.kwargs) -> _Unused: return _unused +def _validate_callback(callback: BaseCallback, base_class: Type[BaseCallback]) -> None: + if not isinstance(callback, base_class): + raise TypeError( + f"Callback must be an instance of {base_class}, not {type(callback)}." + ) + acceptable_events = [] + for event_type in dir(base_class): + if event_type.startswith("on_") and callable(getattr(base_class, event_type)): + acceptable_events.append(event_type) + for event_type in dir(callback): + if not ( + event_type.startswith("on_") and callable(getattr(callback, event_type)) + ): + continue + if event_type not in acceptable_events: + raise ValueError( + f"{event_type} is not an acceptable event type. " + f"Acceptable event types are {acceptable_events}. " + f"Please refer to the documentation of {base_class.__name__} for the list of acceptable event types." + ) + + class BaseCallback: """Base class for callbacks. @@ -137,7 +161,9 @@ class CallbackHandler: _callbacks: list[BaseCallback] _registered_functions: defaultdict[str, list[Callable]] - def __init__(self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None) -> None: + def __init__( + self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None + ) -> None: self._callbacks = [] self._registered_functions = defaultdict(list) if callbacks is not None: @@ -160,7 +186,9 @@ def register(self, callback: BaseCallback) -> None: If the callback is not an instance of BaseCallback. """ if not isinstance(callback, BaseCallback): - raise TypeError(f"Callback must be an instance of BaseCallback, not {type(callback)}.") + raise TypeError( + f"Callback must be an instance of BaseCallback, not {type(callback)}." + ) self._callbacks.append(callback) for event_type in dir(callback): @@ -190,4 +218,3 @@ def fire(self, event_type: str, **kwargs: Any) -> None: """ for callback_method in self._registered_functions[event_type]: callback_method(**kwargs) - diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 8174b01a..cc220b9a 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -7,7 +7,7 @@ import torch from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic -from bdpy.pipeline.callback import CallbackHandler, BaseCallback, unused +from bdpy.pipeline.callback import CallbackHandler, BaseCallback, unused, _validate_callback FeatureType = Dict[str, torch.Tensor] @@ -27,6 +27,10 @@ class FeatureInversionCallback(BaseCallback): usage of callbacks. """ + def __init__(self) -> None: + super().__init__() + _validate_callback(self, FeatureInversionCallback) + @unused def on_pipeline_start(self) -> None: """Callback on pipeline start.""" @@ -79,6 +83,7 @@ class CUILoggingCallback(FeatureInversionCallback): """ def __init__(self, interval: int = 1, total_steps: int = -1) -> None: + super().__init__() self._interval = interval self._total_steps = total_steps self._loss: int | float = -1 @@ -119,6 +124,7 @@ class WandBLoggingCallback(FeatureInversionCallback): def __init__( self, run: wandb.sdk.wandb_run.Run, interval: int = 1, media_interval: int = -1 ) -> None: + super().__init__() self._run = run self._interval = interval self._media_interval = media_interval From 6c46bbf93a9c108c78562660883d3c15e25b9321 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 17:58:32 +0900 Subject: [PATCH 068/117] bugfix --- bdpy/recon/torch/pipeline/inversion.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index cc220b9a..660f5259 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -249,21 +249,14 @@ def __call__( ) features = self._encoder(generated_image) - self._callback_handler.fire( - "on_feature_extracted", - step=step, - features=_apply_to_features(lambda x: x.detach(), features), - ) loss = self._critic(features, target_features) self._callback_handler.fire( "on_loss_calculated", step=step, loss=loss.detach() ) loss.backward() - self._callback_handler.fire("on_backward_end", step=step) self._optimizer.step() - self._callback_handler.fire("on_optimizer_step", step=step) if self._scheduler is not None: self._scheduler.step() From 2d5bd86c9c159f83d2726c820b6f2fec870753ac Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 18:01:37 +0900 Subject: [PATCH 069/117] use .clone().detach() when pipeline passes tensor to callback --- bdpy/recon/torch/pipeline/inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/pipeline/inversion.py index 660f5259..8f9001f7 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/pipeline/inversion.py @@ -245,14 +245,14 @@ def __call__( latent = self._latent() generated_image = self._generator(latent) self._callback_handler.fire( - "on_image_generated", step=step, image=generated_image.detach() + "on_image_generated", step=step, image=generated_image.clone().detach() ) features = self._encoder(generated_image) loss = self._critic(features, target_features) self._callback_handler.fire( - "on_loss_calculated", step=step, loss=loss.detach() + "on_loss_calculated", step=step, loss=loss.clone().detach() ) loss.backward() From eeb64c6be65bb8afcc278246c4017ba0651f5e0b Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 18:40:38 +0900 Subject: [PATCH 070/117] add dependency on torchvision --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d1c27baa..5af3637f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ caffe = [ ] torch = [ "torch", + "torchvision", "Pillow" ] fig = [ From c3babf4e0cadf862f063a89cc144a56a9c7cec00 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 19:02:03 +0900 Subject: [PATCH 071/117] add FixedResolutionDomain --- bdpy/dl/torch/domain/image_domain.py | 41 ++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index 4873d12a..b53b7fc8 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -16,6 +16,7 @@ import numpy as np import torch +from torchvision.transforms import InterpolationMode, Resize from .core import Domain, IrreversibleDomain, ComposedDomain @@ -198,3 +199,43 @@ def __init__( BGRDomain(), ] ) + + +class FixedResolutionDomain(IrreversibleDomain): + """Image domain for images with fixed resolution. + + Parameters + ---------- + image_shape : tuple[int, int] + Spatial resolution of the images. + interpolation : InterpolationMode, optional + Interpolation mode for resizing. (default: InterpolationMode.BILINEAR) + antialias : bool, optional + Whether to use antialiasing. (default: True) + """ + + def __init__( + self, + image_shape: tuple[int, int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = True, + ) -> None: + super().__init__() + self._image_shape = image_shape + self._interpolation = interpolation + self._antialias = antialias + + self._resizer = Resize( + size=self._image_shape, + interpolation=self._interpolation, + antialias=self._antialias + ) + + def send(self, images: torch.Tensor) -> torch.Tensor: + raise RuntimeError( + "FixedResolutionDomain is not supposed to be used for sending images " \ + "because the internal image resolution could not be determined." + ) + + def receive(self, images: torch.Tensor) -> torch.Tensor: + return self._resizer(images) From 9458842992e7c3e135db35fc238d958fc3e02db0 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 21 Dec 2023 20:02:06 +0900 Subject: [PATCH 072/117] update test case for reset_all_parameters --- tests/recon/torch/modules/test_generator.py | 42 ++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py index 08187efd..d4f735d5 100644 --- a/tests/recon/torch/modules/test_generator.py +++ b/tests/recon/torch/modules/test_generator.py @@ -6,16 +6,56 @@ import torch import torch.nn as nn +from torchvision.models import get_model from bdpy.recon.torch.modules import generator as generator_module class TestResetAllParameters(unittest.TestCase): """Tests for bdpy.recon.torch.modules.generator.reset_all_parameters.""" + def setUp(self): + self.model_ids = [ + "alexnet", + "efficientnet_b0", + "fasterrcnn_resnet50_fpn", + "inception_v3", + "resnet18", + "vgg11", + "vit_b_16", + ] + # NOTE: The following modules are excluded from validation because they + # initialize their parameters as constants every time. + self.excluded_modules = [ + nn.modules.batchnorm._BatchNorm, + nn.LayerNorm, + ] + + def _validate_module(self, module: nn.Module, module_copy: nn.Module, parent_name: str = ""): + if isinstance(module, tuple(self.excluded_modules)): + return + for (name_p1, p1), (_, p2) in zip(module.named_parameters(recurse=False), module_copy.named_parameters(recurse=False)): + # NOTE: skip parameters that are prbably not randomly initialized + if "weight" not in name_p1: + continue + self.assertFalse( + torch.equal(p1, p2), + msg=f"Parameter {parent_name}.{name_p1} does not change after reset_all_parameters." + ) + for (name_m1, m1), (_, m2) in zip(module.named_children(), module_copy.named_children()): + self._validate_module(m1, m2, f"{parent_name}.{name_m1}") def test_reset_all_parameters(self): """Test reset_all_parameters.""" - pass + for model_id in self.model_ids: + model = get_model(model_id) + model_copy = copy.deepcopy(model) + for (name_p1, p1), (_, p2) in zip(model.named_parameters(), model_copy.named_parameters()): + self.assertTrue( + torch.equal(p1, p2), + msg=f"Parameter {name_p1} of {model_id} has been changed by deepcopy." + ) + model.apply(generator_module.reset_all_parameters) + self._validate_module(model, model_copy, model_id) class TestBaseGenerator(unittest.TestCase): From 718702af19fb1b4ca27de86eb60581c8e65caae2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 22 Dec 2023 13:42:47 +0900 Subject: [PATCH 073/117] update test case --- tests/recon/torch/modules/test_generator.py | 59 ++++++++++++++++----- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py index d4f735d5..63eb9090 100644 --- a/tests/recon/torch/modules/test_generator.py +++ b/tests/recon/torch/modules/test_generator.py @@ -11,6 +11,18 @@ from bdpy.recon.torch.modules import generator as generator_module +class LinearGenerator(generator_module.NNModuleGenerator): + def __init__(self): + super().__init__() + self.fc = nn.Linear(64, 10) + + def generate(self, latent): + return self.fc(latent) + + def reset_states(self) -> None: + self.fc.apply(generator_module.reset_all_parameters) + + class TestResetAllParameters(unittest.TestCase): """Tests for bdpy.recon.torch.modules.generator.reset_all_parameters.""" def setUp(self): @@ -89,18 +101,6 @@ class TestNNModuleGenerator(unittest.TestCase): def setUp(self): """Set up.""" - - class LinearGenerator(generator_module.NNModuleGenerator): - def __init__(self): - super().__init__() - self.fc = nn.Linear(64, 10) - - def generate(self, latent): - return self.fc(latent) - - def reset_states(self) -> None: - self.fc.apply(generator_module.reset_all_parameters) - self.generator = LinearGenerator() def test_instantiation(self): @@ -125,5 +125,40 @@ def test_reset_states(self): self.assertFalse(torch.equal(p1, p2)) +class TestBareGenerator(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.BareGenerator.""" + + def test_call(self): + """Test __call__.""" + generator = generator_module.BareGenerator(activation=torch.sigmoid) + latent = torch.randn(1, 3, 64, 64) + generated_image = generator(latent) + self.assertEqual(generated_image.shape, (1, 3, 64, 64)) + torch.testing.assert_close(generated_image, torch.sigmoid(latent)) + + +class TestDNNGenerator(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.DNNGenerator.""" + def test_call(self): + """Test __call__.""" + generator_network = LinearGenerator() + generator = generator_module.DNNGenerator(generator_network) + latent = torch.randn(1, 64) + generated_image = generator(latent) + self.assertEqual(generated_image.shape, (1, 10)) + generated_image.sum().backward() + self.assertIsNotNone(generator_network.fc.weight.grad) + + def test_reset_states(self): + """Test reset_states.""" + generator = generator_module.DNNGenerator(LinearGenerator()) + generator_copy = copy.deepcopy(generator) + for p1, p2 in zip(generator.parameters(), generator_copy.parameters()): + self.assertTrue(torch.equal(p1, p2)) + generator.reset_states() + for p1, p2 in zip(generator.parameters(), generator_copy.parameters()): + self.assertFalse(torch.equal(p1, p2)) + + if __name__ == "__main__": unittest.main() From cd6529ae6cee6b9be8ef8f8c03412f5ffc787468 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 22 Dec 2023 14:09:02 +0900 Subject: [PATCH 074/117] NNModuleEncoder --- bdpy/recon/torch/modules/encoder.py | 24 +++++++++++++++++- tests/recon/torch/modules/test_encoder.py | 31 +++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 662038db..027d2d83 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -44,7 +44,29 @@ def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: return self.encode(images) -class SimpleEncoder(BaseEncoder): +class NNModuleEncoder(BaseEncoder, nn.Module): + """Encoder network module subclassed from nn.Module.""" + + def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + """Call self.encode. + + Parameters + ---------- + images : torch.Tensor + Images on the library's internal domain. + + Returns + ------- + dict[str, torch.Tensor] + Features indexed by the layer names. + """ + return self.encode(images) + + def __call__(self, images: torch.Tensor) -> dict[str, torch.Tensor]: + return nn.Module.__call__(self, images) + + +class SimpleEncoder(NNModuleEncoder): """Encoder network module with a naive feature extractor. Parameters diff --git a/tests/recon/torch/modules/test_encoder.py b/tests/recon/torch/modules/test_encoder.py index 31caa746..7d36ce8c 100644 --- a/tests/recon/torch/modules/test_encoder.py +++ b/tests/recon/torch/modules/test_encoder.py @@ -45,6 +45,34 @@ def encode(self, images): self.assertDictEqual(features, {"image": images}) +class TestNNModuleEncoder(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.encoder.NNModuleEncoder.""" + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, encoder_module.NNModuleEncoder) + + def test_call(self): + """Test __call__.""" + + class ReturnAsIsEncoder(encoder_module.NNModuleEncoder): + def __init__(self) -> None: + super().__init__() + def encode(self, images): + return {"image": images} + + encoder = ReturnAsIsEncoder() + + images = torch.randn(1, 3, 64, 64) + images.requires_grad = True + features = encoder(images) + self.assertIsInstance(features, dict) + self.assertEqual(len(features), 1) + self.assertEqual(features["image"].shape, (1, 3, 64, 64)) + features["image"].sum().backward() + self.assertIsNotNone(images.grad) + + class TestSimpleEncoder(unittest.TestCase): """Tests for bdpy.recon.torch.modules.encoder.SimpleEncoder.""" @@ -54,11 +82,14 @@ def test_call(self): MLP(), ["fc1", "fc2"], domain=Zero2OneImageDomain() ) images = torch.randn(1, 3, 64, 64).clamp(0, 1) + images.requires_grad = True features = encoder(images) self.assertIsInstance(features, dict) self.assertEqual(len(features), 2) self.assertEqual(features["fc1"].shape, (1, 256)) self.assertEqual(features["fc2"].shape, (1, 128)) + features["fc2"].sum().backward() + self.assertIsNotNone(images.grad) class TestBuildEncoder(unittest.TestCase): From 7977cd440d2cdf61232b661b78c0ed229136fc05 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Mon, 25 Dec 2023 18:01:34 +0900 Subject: [PATCH 075/117] update test case --- tests/recon/torch/modules/test_generator.py | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py index 63eb9090..f3a270d4 100644 --- a/tests/recon/torch/modules/test_generator.py +++ b/tests/recon/torch/modules/test_generator.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +import torch.optim as optim from torchvision.models import get_model from bdpy.recon.torch.modules import generator as generator_module @@ -160,5 +161,38 @@ def test_reset_states(self): self.assertFalse(torch.equal(p1, p2)) +class TestFrozenGenerator(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.FrozenGenerator.""" + def test_call(self): + """Test __call__.""" + generator_network = LinearGenerator() + generator = generator_module.FrozenGenerator(generator_network) + latent = torch.randn(1, 64) + generated_image = generator(latent) + self.assertEqual(generated_image.shape, (1, 10)) + self.assertRaises(ValueError, optim.SGD, generator.parameters()) + + def test_reset_states(self): + """Test reset_states.""" + generator = generator_module.FrozenGenerator(LinearGenerator()) + generator_copy = copy.deepcopy(generator) + for p1, p2 in zip(generator.parameters(), generator_copy.parameters()): + self.assertTrue(torch.equal(p1, p2)) + generator.reset_states() + for p1, p2 in zip(generator.parameters(), generator_copy.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + +class TestBuildGenerator(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.build_generator.""" + def test_build_generator(self): + """Test build_generator.""" + generator_network = LinearGenerator() + generator = generator_module.build_generator(generator_network) + self.assertIsInstance(generator, generator_module.DNNGenerator) + generator = generator_module.build_generator(generator_network, frozen=True) + self.assertIsInstance(generator, generator_module.FrozenGenerator) + + if __name__ == "__main__": unittest.main() From 5dee9d980b253a240ccd5cc76cdf55ac1d1804ff Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 27 Dec 2023 11:36:12 +0900 Subject: [PATCH 076/117] rename pipeline -> task --- bdpy/dl/torch/domain/image_domain.py | 2 +- bdpy/recon/torch/modules/critic.py | 2 +- bdpy/recon/torch/pipeline/__init__.py | 1 - bdpy/recon/torch/task/__init__.py | 1 + .../torch/{pipeline => task}/inversion.py | 34 +++++------ bdpy/task/__init__.py | 0 bdpy/{pipeline => task}/callback.py | 60 +++++++++---------- 7 files changed, 49 insertions(+), 51 deletions(-) delete mode 100644 bdpy/recon/torch/pipeline/__init__.py create mode 100644 bdpy/recon/torch/task/__init__.py rename bdpy/recon/torch/{pipeline => task}/inversion.py (90%) create mode 100644 bdpy/task/__init__.py rename bdpy/{pipeline => task}/callback.py (82%) diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index b53b7fc8..34810761 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -152,7 +152,7 @@ def receive(self, images: torch.Tensor) -> torch.Tensor: warnings.warn( "`PILDominWithExplicitCrop.receive` performs explicit cropping. " \ "It could be affected to the gradient computation. " \ - "Please do not use this domain inside the optimization pipeline.", + "Please do not use this domain inside the optimization.", RuntimeWarning, ) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index af537e46..bbf31157 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from bdpy.pipeline.callback import CallbackHandler, BaseCallback +from bdpy.task.callback import CallbackHandler, BaseCallback _FeatureType = Dict[str, torch.Tensor] diff --git a/bdpy/recon/torch/pipeline/__init__.py b/bdpy/recon/torch/pipeline/__init__.py deleted file mode 100644 index 0569d5ab..00000000 --- a/bdpy/recon/torch/pipeline/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .inversion import FeatureInversionPipeline \ No newline at end of file diff --git a/bdpy/recon/torch/task/__init__.py b/bdpy/recon/torch/task/__init__.py new file mode 100644 index 00000000..a5292096 --- /dev/null +++ b/bdpy/recon/torch/task/__init__.py @@ -0,0 +1 @@ +from .inversion import FeatureInversionTask \ No newline at end of file diff --git a/bdpy/recon/torch/pipeline/inversion.py b/bdpy/recon/torch/task/inversion.py similarity index 90% rename from bdpy/recon/torch/pipeline/inversion.py rename to bdpy/recon/torch/task/inversion.py index 8f9001f7..2a62426e 100644 --- a/bdpy/recon/torch/pipeline/inversion.py +++ b/bdpy/recon/torch/task/inversion.py @@ -7,7 +7,7 @@ import torch from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic -from bdpy.pipeline.callback import CallbackHandler, BaseCallback, unused, _validate_callback +from bdpy.task.callback import CallbackHandler, BaseCallback, unused, _validate_callback FeatureType = Dict[str, torch.Tensor] @@ -19,10 +19,10 @@ def _apply_to_features( class FeatureInversionCallback(BaseCallback): - """Callback for feature inversion pipeline. + """Callback for feature inversion task. As a design principle, the callback functions must not have any side effects - on the pipeline results. It should be used only for logging, visualization, + on the task results. It should be used only for logging, visualization, etc. Please refer to `bdpy.util.callback.BaseCallback` for details of the usage of callbacks. """ @@ -32,8 +32,8 @@ def __init__(self) -> None: _validate_callback(self, FeatureInversionCallback) @unused - def on_pipeline_start(self) -> None: - """Callback on pipeline start.""" + def on_task_start(self) -> None: + """Callback on task start.""" pass @unused @@ -64,8 +64,8 @@ def on_iteration_end(self, *, step: int) -> None: pass @unused - def on_pipeline_end(self) -> None: - """Callback on pipeline end.""" + def on_task_end(self) -> None: + """Callback on task end.""" pass @@ -155,8 +155,8 @@ def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: self._run.log({"loss": loss.item()}, step=self._step) -class FeatureInversionPipeline: - """Feature inversion pipeline. +class FeatureInversionTask: + """Feature inversion Task. Parameters ---------- @@ -176,25 +176,25 @@ class FeatureInversionPipeline: Number of iterations, by default 1. callbacks : FeatureInversionCallback | Iterable[FeatureInversionCallback] | None, optional Callbacks, by default None. Please refer to `bdpy.util.callback.BaseCallback` - and `bdpy.recon.torch.pipeline.FeatureInversionCallback` for details. + and `bdpy.recon.torch.task.FeatureInversionCallback` for details. Examples -------- >>> import torch >>> import torch.nn as nn - >>> from bdpy.recon.torch.pipeline import FeatureInversionPipeline + >>> from bdpy.recon.torch.task import FeatureInversionTask >>> from bdpy.recon.torch.modules import build_encoder, build_generator, ArbitraryLatent, TargetNormalizedMSE >>> encoder = build_encoder(...) >>> generator = build_generator(...) >>> latent = ArbitraryLatent(...) >>> critic = TargetNormalizedMSE(...) >>> optimizer = torch.optim.Adam(latent.parameters()) - >>> pipeline = FeatureInversionPipeline( + >>> task = FeatureInversionTask( ... encoder, generator, latent, critic, optimizer, num_iterations=200, ... ) >>> target_features = encoder(target_image) - >>> pipeline.reset_states() - >>> reconstructed_image = pipeline(target_features) + >>> task.reset_states() + >>> reconstructed_image = task(target_features) """ def __init__( @@ -237,7 +237,7 @@ def __call__( torch.Tensor Reconstructed images on the libraries internal domain. """ - self._callback_handler.fire("on_pipeline_start") + self._callback_handler.fire("on_task_start") for step in range(self._num_iterations): self._callback_handler.fire("on_iteration_start", step=step) self._optimizer.zero_grad() @@ -264,11 +264,11 @@ def __call__( generated_image = self._generator(self._latent()).detach() - self._callback_handler.fire("on_pipeline_end") + self._callback_handler.fire("on_task_end") return generated_image def reset_states(self) -> None: - """Reset the state of the pipeline.""" + """Reset the state of the task.""" self._generator.reset_states() self._latent.reset_states() self._optimizer = self._optimizer.__class__( diff --git a/bdpy/task/__init__.py b/bdpy/task/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bdpy/pipeline/callback.py b/bdpy/task/callback.py similarity index 82% rename from bdpy/pipeline/callback.py rename to bdpy/task/callback.py index fe409ba2..6668aff1 100644 --- a/bdpy/pipeline/callback.py +++ b/bdpy/task/callback.py @@ -83,44 +83,42 @@ def _validate_callback(callback: BaseCallback, base_class: Type[BaseCallback]) - class BaseCallback: """Base class for callbacks. - Callbacks are used to hook into the pipeline and execute custom functions + Callbacks are used to hook into the task and execute custom functions at specific events. Callback functions must be defined as methods of the callback classes. The callback functions must be named as "on_". As a design principle, the callback functions must not have any side effects - on the pipeline results. It should be used only for logging, visualization, - etc. + on the task results. It should be used only for logging, visualization, etc. - For example, the following callback class logs the start and end of the - pipeline. + For example, the following callback class logs the start and end of the task. >>> class Callback(BaseCallback): - ... def on_pipeline_start(self): - ... print("Pipeline started.") + ... def on_task_start(self): + ... print("Task started.") ... - ... def on_pipeline_end(self): - ... print("Pipeline ended.") + ... def on_task_end(self): + ... print("Task ended.") ... >>> callback = Callback() - >>> some_pipeline = SomePipeline() # Initialize a pipeline object - >>> some_pipeline.register_callback(callback) - >>> outputs = some_pipeline(inputs) # Run the pipeline - Pipeline started. - Pipeline ended. - - The set of available events that can be hooked into depends on the pipeline. - See the base class of the corresponding pipeline for the list of all events. + >>> some_task = SomeTask() # Initialize a task object + >>> some_task.register_callback(callback) + >>> outputs = some_task(inputs) # Run the task + Task started. + Task ended. + + The set of available events that can be hooked into depends on the task. + See the base class of the corresponding task for the list of all events. `@unused` decorator can be used to mark a callback function as unused, so that the callback handler does not fire the function. """ @unused - def on_pipeline_start(self) -> None: - """Callback on pipeline start.""" + def on_task_start(self) -> None: + """Callback on task start.""" pass @unused - def on_pipeline_end(self) -> None: - """Callback on pipeline end.""" + def on_task_end(self) -> None: + """Callback on task end.""" pass @@ -143,19 +141,19 @@ class CallbackHandler: ... def __init__(self, name): ... self._name = name ... - ... def on_pipeline_start(self): - ... print(f"Pipeline started (name={self._name}).") + ... def on_task_start(self): + ... print(f"Task started (name={self._name}).") ... - ... def on_pipeline_end(self): - ... print(f"Pipeline ended (name={self._name}).") + ... def on_task_end(self): + ... print(f"Task ended (name={self._name}).") ... >>> handler = CallbackHandler([Callback("A"), Callback("B")]) - >>> handler.fire("on_pipeline_start") - Pipeline started (name=A). - Pipeline started (name=B). - >>> handler.fire("on_pipeline_end") - Pipeline ended (name=A). - Pipeline ended (name=B). + >>> handler.fire("on_task_start") + Task started (name=A). + Task started (name=B). + >>> handler.fire("on_task_end") + Task ended (name=A). + Task ended (name=B). """ _callbacks: list[BaseCallback] From 4828f45deb1b9a2b80c64c549551d6f668132b87 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 27 Dec 2023 11:53:59 +0900 Subject: [PATCH 077/117] base class for the task object --- bdpy/recon/torch/task/inversion.py | 10 +++------ bdpy/task/__init__.py | 1 + bdpy/task/callback.py | 15 +++++++------ bdpy/task/core.py | 34 ++++++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 13 deletions(-) create mode 100644 bdpy/task/core.py diff --git a/bdpy/recon/torch/task/inversion.py b/bdpy/recon/torch/task/inversion.py index 2a62426e..5236de97 100644 --- a/bdpy/recon/torch/task/inversion.py +++ b/bdpy/recon/torch/task/inversion.py @@ -7,6 +7,7 @@ import torch from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic +from bdpy.task import BaseTask from bdpy.task.callback import CallbackHandler, BaseCallback, unused, _validate_callback FeatureType = Dict[str, torch.Tensor] @@ -155,7 +156,7 @@ def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: self._run.log({"loss": loss.item()}, step=self._step) -class FeatureInversionTask: +class FeatureInversionTask(BaseTask): """Feature inversion Task. Parameters @@ -210,6 +211,7 @@ def __init__( | Iterable[FeatureInversionCallback] | None = None, ) -> None: + super().__init__(callbacks) self._encoder = encoder self._generator = generator self._latent = latent @@ -219,8 +221,6 @@ def __init__( self._num_iterations = num_iterations - self._callback_handler = CallbackHandler(callbacks) - def __call__( self, target_features: FeatureType, @@ -278,7 +278,3 @@ def reset_states(self) -> None: ), **self._optimizer.defaults, ) - - def register_callback(self, callback: FeatureInversionCallback) -> None: - """Register a callback.""" - self._callback_handler.register(callback) diff --git a/bdpy/task/__init__.py b/bdpy/task/__init__.py index e69de29b..b1ee439e 100644 --- a/bdpy/task/__init__.py +++ b/bdpy/task/__init__.py @@ -0,0 +1 @@ +from .core import BaseTask diff --git a/bdpy/task/callback.py b/bdpy/task/callback.py index 6668aff1..8bd91bbf 100644 --- a/bdpy/task/callback.py +++ b/bdpy/task/callback.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Type, Any, Iterable +from typing import Callable, Type, Any, Iterable, TypeVar, Generic from typing_extensions import Annotated, ParamSpec from collections import defaultdict @@ -122,7 +122,10 @@ def on_task_end(self) -> None: pass -class CallbackHandler: +_CallbackType = TypeVar("_CallbackType", bound=BaseCallback) + + +class CallbackHandler(Generic[_CallbackType]): """Callback handler. This class manages the callback objects and fires the callback functions @@ -156,21 +159,21 @@ class CallbackHandler: Task ended (name=B). """ - _callbacks: list[BaseCallback] + _callbacks: list[_CallbackType] _registered_functions: defaultdict[str, list[Callable]] def __init__( - self, callbacks: BaseCallback | Iterable[BaseCallback] | None = None + self, callbacks: _CallbackType | Iterable[_CallbackType] | None = None ) -> None: self._callbacks = [] self._registered_functions = defaultdict(list) if callbacks is not None: - if isinstance(callbacks, BaseCallback): + if not isinstance(callbacks, Iterable): callbacks = [callbacks] for callback in callbacks: self.register(callback) - def register(self, callback: BaseCallback) -> None: + def register(self, callback: _CallbackType) -> None: """Register a callback. Parameters diff --git a/bdpy/task/core.py b/bdpy/task/core.py new file mode 100644 index 00000000..98a95d9f --- /dev/null +++ b/bdpy/task/core.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from typing import Iterable, Any, TypeVar, Generic + +from bdpy.task.callback import CallbackHandler, BaseCallback + + +_CallbackType = TypeVar("_CallbackType", bound=BaseCallback) + + +class BaseTask(ABC, Generic[_CallbackType]): + """Base class for tasks.""" + + def __init__( + self, callbacks: _CallbackType | Iterable[_CallbackType] | None = None + ) -> None: + self._callback_handler = CallbackHandler(callbacks) + + @abstractmethod + def __call__(self, *inputs, **parameters) -> Any: + """Run the task.""" + pass + + def register_callback(self, callback: _CallbackType) -> None: + """Register a callback. + + Parameters + ---------- + callback : BaseCallback + Callback to register. + """ + self._callback_handler.register(callback) From b01c9557e12ed076ccb408af3655defb10da196c Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 27 Dec 2023 14:15:06 +0900 Subject: [PATCH 078/117] organize APIs --- bdpy/recon/torch/task/inversion.py | 2 +- bdpy/task/core.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bdpy/recon/torch/task/inversion.py b/bdpy/recon/torch/task/inversion.py index 5236de97..818729a7 100644 --- a/bdpy/recon/torch/task/inversion.py +++ b/bdpy/recon/torch/task/inversion.py @@ -8,7 +8,7 @@ from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic from bdpy.task import BaseTask -from bdpy.task.callback import CallbackHandler, BaseCallback, unused, _validate_callback +from bdpy.task.callback import BaseCallback, unused, _validate_callback FeatureType = Dict[str, torch.Tensor] diff --git a/bdpy/task/core.py b/bdpy/task/core.py index 98a95d9f..8f5c0e1c 100644 --- a/bdpy/task/core.py +++ b/bdpy/task/core.py @@ -13,6 +13,8 @@ class BaseTask(ABC, Generic[_CallbackType]): """Base class for tasks.""" + _callback_handler: CallbackHandler[_CallbackType] + def __init__( self, callbacks: _CallbackType | Iterable[_CallbackType] | None = None ) -> None: From 00fb71a481e5767eb0190083ce50dbc66e7b306b Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Wed, 27 Dec 2023 19:09:08 +0900 Subject: [PATCH 079/117] docstring --- bdpy/task/core.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/bdpy/task/core.py b/bdpy/task/core.py index 8f5c0e1c..ad6b5b9d 100644 --- a/bdpy/task/core.py +++ b/bdpy/task/core.py @@ -11,7 +11,27 @@ class BaseTask(ABC, Generic[_CallbackType]): - """Base class for tasks.""" + """Base class for tasks. + + Parameters + ---------- + callbacks : BaseCallback | Iterable[BaseCallback] | None + Callbacks to register. If `None`, no callbacks are registered. + + Attributes + ---------- + _callback_handler : CallbackHandler + Callback handler. + + Notes + ----- + This class is designed to be used as a base class for tasks. The task + implementation should override the `__call__` method. The actual interface + of `__call__` depends on the task. For example, the task may take a single + input and return a single output, or it may take multiple inputs and return + multiple outputs. The task may also take keyword arguments. Please refer to + the documentation of the specific task for details. + """ _callback_handler: CallbackHandler[_CallbackType] From 5a521b4c1a0ed0234be5d9c4333d9c6643e707ea Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 28 Dec 2023 15:12:43 +0900 Subject: [PATCH 080/117] rename `reset_all_parameters` -> `call_reset_parameters` and add warning --- bdpy/recon/torch/modules/generator.py | 20 ++++++++++++++------ tests/recon/torch/modules/test_generator.py | 14 +++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 15a1b95b..52e3960d 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -1,8 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod - from typing import Callable, Iterator +import warnings import torch import torch.nn as nn @@ -23,8 +23,16 @@ def _get_reset_module_fn(module: nn.Module) -> Callable[[], None] | None: @torch.no_grad() -def reset_all_parameters(module: nn.Module) -> None: +def call_reset_parameters(module: nn.Module) -> None: """Reset the parameters of the module.""" + warnings.warn( + "`call_reset_parameters` calls the instance method named `reset_parameters` " \ + "or `_reset_parameters` of the module. This method does not guarantee that " \ + "all the parameters of the module are reset. Please use this method with " \ + "caution.", + UserWarning, + stacklevel=2, + ) reset_parameters = _get_reset_module_fn(module) if reset_parameters is not None: reset_parameters() @@ -150,7 +158,7 @@ class DNNGenerator(NNModuleGenerator): Domain of the input images to receive. (default: Zero2OneImageDomain()) reset_fn : Callable[[nn.Module], None], optional Function to reset the parameters of the generator network, by default - reset_all_parameters. + call_reset_parameters. Examples -------- @@ -172,7 +180,7 @@ def __init__( self, generator_network: nn.Module, domain: Domain = image_domain.Zero2OneImageDomain(), - reset_fn: Callable[[nn.Module], None] = reset_all_parameters, + reset_fn: Callable[[nn.Module], None] = call_reset_parameters, ) -> None: """Initialize the generator.""" super().__init__() @@ -251,7 +259,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: def build_generator( generator_network: nn.Module, domain: Domain = image_domain.Zero2OneImageDomain(), - reset_fn: Callable[[nn.Module], None] = reset_all_parameters, + reset_fn: Callable[[nn.Module], None] = call_reset_parameters, frozen: bool = True, ) -> BaseGenerator: """Build a generator module. @@ -271,7 +279,7 @@ def build_generator( network's output domain. reset_fn : Callable[[nn.Module], None], optional Function to reset the parameters of the generator network, by default - reset_all_parameters. + call_reset_parameters. frozen : bool, optional Whether to freeze the parameters of the generator network, by default True. diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py index f3a270d4..965e391e 100644 --- a/tests/recon/torch/modules/test_generator.py +++ b/tests/recon/torch/modules/test_generator.py @@ -21,11 +21,11 @@ def generate(self, latent): return self.fc(latent) def reset_states(self) -> None: - self.fc.apply(generator_module.reset_all_parameters) + self.fc.apply(generator_module.call_reset_parameters) -class TestResetAllParameters(unittest.TestCase): - """Tests for bdpy.recon.torch.modules.generator.reset_all_parameters.""" +class TestCallResetParameters(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.generator.call_reset_parameters.""" def setUp(self): self.model_ids = [ "alexnet", @@ -52,13 +52,13 @@ def _validate_module(self, module: nn.Module, module_copy: nn.Module, parent_nam continue self.assertFalse( torch.equal(p1, p2), - msg=f"Parameter {parent_name}.{name_p1} does not change after reset_all_parameters." + msg=f"Parameter {parent_name}.{name_p1} does not change after calling reset_parameters." ) for (name_m1, m1), (_, m2) in zip(module.named_children(), module_copy.named_children()): self._validate_module(m1, m2, f"{parent_name}.{name_m1}") - def test_reset_all_parameters(self): - """Test reset_all_parameters.""" + def test_call_reset_parameters(self): + """Test call_reset_parameters.""" for model_id in self.model_ids: model = get_model(model_id) model_copy = copy.deepcopy(model) @@ -67,7 +67,7 @@ def test_reset_all_parameters(self): torch.equal(p1, p2), msg=f"Parameter {name_p1} of {model_id} has been changed by deepcopy." ) - model.apply(generator_module.reset_all_parameters) + model.apply(generator_module.call_reset_parameters) self._validate_module(model, model_copy, model_id) From 0404f717d3f106c8e236518c0f20ed731f11123d Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 4 Jan 2024 13:46:22 +0900 Subject: [PATCH 081/117] move definition of internal domain to core module --- bdpy/dl/torch/domain/__init__.py | 2 +- bdpy/dl/torch/domain/core.py | 14 ++++++++++++++ bdpy/dl/torch/domain/image_domain.py | 26 +++++++------------------- bdpy/recon/torch/modules/encoder.py | 10 +++++----- bdpy/recon/torch/modules/generator.py | 16 ++++++++-------- 5 files changed, 35 insertions(+), 33 deletions(-) diff --git a/bdpy/dl/torch/domain/__init__.py b/bdpy/dl/torch/domain/__init__.py index 1e24530e..05a08fa1 100644 --- a/bdpy/dl/torch/domain/__init__.py +++ b/bdpy/dl/torch/domain/__init__.py @@ -1 +1 @@ -from .core import Domain, IrreversibleDomain, ComposedDomain, KeyValueDomain \ No newline at end of file +from .core import Domain, InternalDomain, IrreversibleDomain, ComposedDomain, KeyValueDomain \ No newline at end of file diff --git a/bdpy/dl/torch/domain/core.py b/bdpy/dl/torch/domain/core.py index 7acc016b..6792cb95 100644 --- a/bdpy/dl/torch/domain/core.py +++ b/bdpy/dl/torch/domain/core.py @@ -63,6 +63,20 @@ def receive(self, x: _T) -> _T: pass +class InternalDomain(Domain, Generic[_T]): + """The internal common space. + + The domain class which defines the internal common space. This class + receives and sends data as it is. + """ + + def send(self, x: _T) -> _T: + return x + + def receive(self, x: _T) -> _T: + return x + + class IrreversibleDomain(Domain, Generic[_T]): """The domain which cannot be reversed. diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index 34810761..dcf9440d 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -18,7 +18,7 @@ import torch from torchvision.transforms import InterpolationMode, Resize -from .core import Domain, IrreversibleDomain, ComposedDomain +from .core import Domain, InternalDomain, IrreversibleDomain, ComposedDomain def _bgr2rgb(images: torch.Tensor) -> torch.Tensor: @@ -41,24 +41,12 @@ def _to_channel_last(images: torch.Tensor) -> torch.Tensor: return images.permute(0, 2, 3, 1) - -class Zero2OneImageDomain(Domain): - """Image domain for images in [0, 1]. - - - Channel axis: 1 - - Pixel range: [0, 1] - - Image size: arbitrary - - Color space: RGB - """ - - def send(self, images: torch.Tensor) -> torch.Tensor: - return images - - def receive(self, images: torch.Tensor) -> torch.Tensor: - return images - - -InternalImageDomain = Zero2OneImageDomain +# NOTE: The internal common space for images is defined as follows: +# - Channel axis: 1 +# - Pixel range: [0, 1] +# - Image size: arbitrary +# - Color space: RGB +Zero2OneImageDomain = InternalDomain[torch.Tensor] class AffineDomain(Domain): diff --git a/bdpy/recon/torch/modules/encoder.py b/bdpy/recon/torch/modules/encoder.py index 027d2d83..ec53bb12 100644 --- a/bdpy/recon/torch/modules/encoder.py +++ b/bdpy/recon/torch/modules/encoder.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from bdpy.dl.torch import FeatureExtractor -from bdpy.dl.torch.domain import Domain, image_domain +from bdpy.dl.torch.domain import Domain, InternalDomain class BaseEncoder(ABC): @@ -77,7 +77,7 @@ class SimpleEncoder(NNModuleEncoder): layer_names : list[str] Layer names to extract features from. domain : Domain, optional - Domain of the input images to receive. (default: Zero2OneImageDomain()) + Domain of the input stimuli to receive. (default: InternalDomain()) Examples -------- @@ -99,7 +99,7 @@ def __init__( self, feature_network: nn.Module, layer_names: Iterable[str], - domain: Domain = image_domain.Zero2OneImageDomain(), + domain: Domain = InternalDomain(), ) -> None: super().__init__() self._feature_extractor = FeatureExtractor( @@ -128,7 +128,7 @@ def encode(self, images: torch.Tensor) -> dict[str, torch.Tensor]: def build_encoder( feature_network: nn.Module, layer_names: Iterable[str], - domain: Domain = image_domain.Zero2OneImageDomain(), + domain: Domain = InternalDomain(), ) -> BaseEncoder: """Build an encoder network with a naive feature extractor. @@ -148,7 +148,7 @@ def build_encoder( layer_names : list[str] Layer names to extract features from. domain : Domain, optional - Domain of the input images to receive (default: Zero2OneImageDomain()). + Domain of the input stimuli to receive (default: InternalDomain()). One needs to specify the domain that corresponds to the feature network's input domain. diff --git a/bdpy/recon/torch/modules/generator.py b/bdpy/recon/torch/modules/generator.py index 52e3960d..98f92109 100644 --- a/bdpy/recon/torch/modules/generator.py +++ b/bdpy/recon/torch/modules/generator.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter -from bdpy.dl.torch.domain import Domain, image_domain +from bdpy.dl.torch.domain import Domain, InternalDomain def _get_reset_module_fn(module: nn.Module) -> Callable[[], None] | None: @@ -121,7 +121,7 @@ def __init__(self, activation: Callable[[torch.Tensor], torch.Tensor] = nn.Ident """Initialize the generator.""" super().__init__() self._activation = activation - self._domain = image_domain.Zero2OneImageDomain() + self._domain = InternalDomain() def reset_states(self) -> None: """Reset the state of the generator.""" @@ -155,7 +155,7 @@ class DNNGenerator(NNModuleGenerator): Generator network. This network should have a method `forward` that takes a latent vector and propagates it through the network. domain : Domain, optional - Domain of the input images to receive. (default: Zero2OneImageDomain()) + Domain of the input stimuli to receive. (default: InternalDomain()) reset_fn : Callable[[nn.Module], None], optional Function to reset the parameters of the generator network, by default call_reset_parameters. @@ -179,7 +179,7 @@ class DNNGenerator(NNModuleGenerator): def __init__( self, generator_network: nn.Module, - domain: Domain = image_domain.Zero2OneImageDomain(), + domain: Domain = InternalDomain(), reset_fn: Callable[[nn.Module], None] = call_reset_parameters, ) -> None: """Initialize the generator.""" @@ -220,7 +220,7 @@ class FrozenGenerator(DNNGenerator): Generator network. This network should have a method `forward` that takes a latent vector and propagates it through the network. domain : Domain, optional - Domain of the input images to receive. (default: Zero2OneImageDomain()) + Domain of the input stimuli to receive. (default: InternalDomain()) Examples -------- @@ -241,7 +241,7 @@ class FrozenGenerator(DNNGenerator): def __init__( self, generator_network: nn.Module, - domain: Domain = image_domain.Zero2OneImageDomain() + domain: Domain = InternalDomain(), ) -> None: """Initialize the generator.""" super().__init__(generator_network, domain=domain, reset_fn=lambda _: None) @@ -258,7 +258,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: def build_generator( generator_network: nn.Module, - domain: Domain = image_domain.Zero2OneImageDomain(), + domain: Domain = InternalDomain(), reset_fn: Callable[[nn.Module], None] = call_reset_parameters, frozen: bool = True, ) -> BaseGenerator: @@ -274,7 +274,7 @@ def build_generator( Generator network. This network should have a method `forward` that takes a latent vector and returns an image on its own domain. domain : Domain, optional - Domain of the input images to receive. (default: Zero2OneImageDomain()). + Domain of the input images to receive. (default: InternalDomain()). One needs to specify the domain that corresponds to the generator network's output domain. reset_fn : Callable[[nn.Module], None], optional From e0bae890e5c219995ff9f42a497a9e7df200720b Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Thu, 4 Jan 2024 18:32:32 +0900 Subject: [PATCH 082/117] test cases for callback module --- bdpy/task/callback.py | 42 +++++++++++++++++++++++- tests/task/__init__.py | 0 tests/task/test_callback.py | 64 +++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 1 deletion(-) create mode 100644 tests/task/__init__.py create mode 100644 tests/task/test_callback.py diff --git a/bdpy/task/callback.py b/bdpy/task/callback.py index 8bd91bbf..8ef95eba 100644 --- a/bdpy/task/callback.py +++ b/bdpy/task/callback.py @@ -12,6 +12,8 @@ def _is_unused(fn: Callable) -> bool: + if not hasattr(fn, "__annotations__"): + return False return_type: Type | None = fn.__annotations__.get("return", None) if return_type is None: return False @@ -43,7 +45,7 @@ def unused(fn: Callable[_P, Any]) -> Callable[_P, _Unused]: >>> f(1, 2, 3) Traceback (most recent call last): ... - RuntimeError: Function is decorated with @unused and must not be called. + RuntimeError: Function is decorated with @unused and must not be called. """ @wraps(fn) # NOTE: preserve name, docstring, etc. of the original function @@ -59,6 +61,44 @@ def _unused(*args: _P.args, **kwargs: _P.kwargs) -> _Unused: def _validate_callback(callback: BaseCallback, base_class: Type[BaseCallback]) -> None: + """Validate a callback. + + Parameters + ---------- + callback : BaseCallback + Callback to validate. + base_class : Type[BaseCallback] + Base class of the callback. + + Raises + ------ + TypeError + If the callback is not an instance of the base class. + ValueError + If the callback has an event type that is not acceptable. + + Examples + -------- + >>> class TaskBaseCallback(BaseCallback): + ... @unused + ... def on_task_start(self): + ... pass + ... + ... @unused + ... def on_task_end(self): + ... pass + ... + >>> class SomeTaskCallback(TaskBaseCallback): + ... def on_unacceptable_event(self): + ... # do something + ... + >>> callback = SomeTaskCallback() + >>> _validate_callback(callback, TaskBaseCallback) + Traceback (most recent call last): + ... + ValueError: on_unacceptable_event is not an acceptable event type. ... + """ + if not isinstance(callback, base_class): raise TypeError( f"Callback must be an instance of {base_class}, not {type(callback)}." diff --git a/tests/task/__init__.py b/tests/task/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/task/test_callback.py b/tests/task/test_callback.py new file mode 100644 index 00000000..0ad3b8ce --- /dev/null +++ b/tests/task/test_callback.py @@ -0,0 +1,64 @@ +"""Tests for bdpy.task.callback.""" + +from __future__ import annotations + +import unittest +from typing import Any, Callable + +from bdpy.task import callback + + +# NOTE: setup functions +def setup_fns() -> list[tuple[Callable, tuple[Any], Any]]: + + def f1(input_: Any) -> None: + pass + + def f2(input_): + pass + + def f3(a: int, b: int) -> int: + return a + b + + class F4: + def __call__(self, input_: Any) -> None: + pass + + return [ + (f1, (None,), None), + (f2, (None,), None), + (f3, (1, 2), 3), + (F4(), (None,), None), + ] + + +class TestUnused(unittest.TestCase): + """Tests for unused decorator.""" + + def test_unused(self): + """Test unused decorator. + + Unused decorator should change the return type of the decorated function + to Annotated[None, "unused"]. The decorated function should raise a + RuntimeError when called. + + Examples + -------- + >>> @unused + ... def f(a: int, b: int, c: int = 0) -> int: + ... return a + b + c + ... + >>> f(1, 2, 3) + Traceback (most recent call last): + ... + RuntimeError: Function is decorated with @unused and must not be called. + """ + + params = setup_fns() + for fn, inputs_, output in params: + self.assertFalse(callback._is_unused(fn)) + self.assertEqual(fn(*inputs_), output) + self.assertTrue(callback._is_unused(callback.unused(fn))) + with self.assertRaises(RuntimeError): + callback.unused(fn)(*inputs_) + From 2e637f7945267e6b8e1550b1b5204e2d42d8a295 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 5 Jan 2024 10:57:25 +0900 Subject: [PATCH 083/117] update test case for callback module --- tests/task/test_callback.py | 43 +++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/task/test_callback.py b/tests/task/test_callback.py index 0ad3b8ce..057a62a9 100644 --- a/tests/task/test_callback.py +++ b/tests/task/test_callback.py @@ -10,7 +10,6 @@ # NOTE: setup functions def setup_fns() -> list[tuple[Callable, tuple[Any], Any]]: - def f1(input_: Any) -> None: pass @@ -56,9 +55,45 @@ def test_unused(self): params = setup_fns() for fn, inputs_, output in params: - self.assertFalse(callback._is_unused(fn)) + self.assertTrue( + not hasattr(fn, "__annotations__") + or fn.__annotations__.get("return", None) != callback._Unused + ) self.assertEqual(fn(*inputs_), output) - self.assertTrue(callback._is_unused(callback.unused(fn))) + unused_fn = callback.unused(fn) + self.assertTrue( + hasattr(unused_fn, "__annotations__") + and unused_fn.__annotations__.get("return", None) == callback._Unused + ) with self.assertRaises(RuntimeError): - callback.unused(fn)(*inputs_) + unused_fn(*inputs_) + + def test_is_unused(self): + params = setup_fns() + for fn, _, _ in params: + self.assertFalse(callback._is_unused(fn)) + self.assertTrue(callback._is_unused(callback.unused(fn))) + +class TestBaseCallback(unittest.TestCase): + def setUp(self): + self.callback = callback.BaseCallback() + self.expected_method_names = { + "on_task_start", + "on_task_end", + } + + def test_instance_methods(self): + method_names = { + event_type + for event_type in dir(self.callback) + if event_type.startswith("on_") and callable(getattr(self.callback, event_type)) + } + self.assertEqual(method_names, self.expected_method_names) + for event_type in method_names: + fn = getattr(self.callback, event_type) + self.assertRaises(RuntimeError, fn) + + +if __name__ == "__main__": + unittest.main() From 772c0493830bc9aa6ec9583ee0ea388b8b7c65e6 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 5 Jan 2024 11:36:58 +0900 Subject: [PATCH 084/117] update test cases for BaseCallback --- tests/task/test_callback.py | 63 ++++++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/tests/task/test_callback.py b/tests/task/test_callback.py index 057a62a9..64cea877 100644 --- a/tests/task/test_callback.py +++ b/tests/task/test_callback.py @@ -5,7 +5,7 @@ import unittest from typing import Any, Callable -from bdpy.task import callback +from bdpy.task import callback as callback_module # NOTE: setup functions @@ -31,6 +31,22 @@ def __call__(self, input_: Any) -> None: ] +def setup_callback_classes(): + class TaskBaseCallback(callback_module.BaseCallback): + @callback_module.unused + def on_some_event(self, input_): + pass + + class AppendCallback(TaskBaseCallback): + def __init__(self): + self._storage = [] + + def on_some_event(self, input_): + self._storage.append(input_) + + return TaskBaseCallback, AppendCallback + + class TestUnused(unittest.TestCase): """Tests for unused decorator.""" @@ -57,13 +73,14 @@ def test_unused(self): for fn, inputs_, output in params: self.assertTrue( not hasattr(fn, "__annotations__") - or fn.__annotations__.get("return", None) != callback._Unused + or fn.__annotations__.get("return", None) != callback_module._Unused ) self.assertEqual(fn(*inputs_), output) - unused_fn = callback.unused(fn) + unused_fn = callback_module.unused(fn) self.assertTrue( hasattr(unused_fn, "__annotations__") - and unused_fn.__annotations__.get("return", None) == callback._Unused + and unused_fn.__annotations__.get("return", None) + == callback_module._Unused ) with self.assertRaises(RuntimeError): unused_fn(*inputs_) @@ -71,13 +88,13 @@ def test_unused(self): def test_is_unused(self): params = setup_fns() for fn, _, _ in params: - self.assertFalse(callback._is_unused(fn)) - self.assertTrue(callback._is_unused(callback.unused(fn))) + self.assertFalse(callback_module._is_unused(fn)) + self.assertTrue(callback_module._is_unused(callback_module.unused(fn))) class TestBaseCallback(unittest.TestCase): def setUp(self): - self.callback = callback.BaseCallback() + self.callback = callback_module.BaseCallback() self.expected_method_names = { "on_task_start", "on_task_end", @@ -87,13 +104,43 @@ def test_instance_methods(self): method_names = { event_type for event_type in dir(self.callback) - if event_type.startswith("on_") and callable(getattr(self.callback, event_type)) + if event_type.startswith("on_") + and callable(getattr(self.callback, event_type)) } self.assertEqual(method_names, self.expected_method_names) for event_type in method_names: fn = getattr(self.callback, event_type) self.assertRaises(RuntimeError, fn) + def test_subclass_definition(self): + TaskBaseCallback, _ = setup_callback_classes() + callback = TaskBaseCallback() + expected_method_names = {"on_task_start", "on_some_event", "on_task_end"} + method_names = { + event_type + for event_type in dir(callback) + if event_type.startswith("on_") + and callable(getattr(callback, event_type)) + } + + self.assertEqual(method_names, expected_method_names) + for event_type in method_names: + fn = getattr(callback, event_type) + self.assertRaises(RuntimeError, fn) + + def test_validate_callback(self): + TaskBaseCallback, AppendCallback = setup_callback_classes() + class Unrelated(callback_module.BaseCallback): + pass + + class HasUnkownEvent(TaskBaseCallback): + def on_unknown_event(self): + pass + + self.assertIsNone(callback_module._validate_callback(AppendCallback(), TaskBaseCallback)) + self.assertRaises(TypeError, callback_module._validate_callback, Unrelated(), TaskBaseCallback) + self.assertRaises(ValueError, callback_module._validate_callback, HasUnkownEvent(), TaskBaseCallback) + if __name__ == "__main__": unittest.main() From 60ff8091c61977d26b82c1953d75e71f4b33c3d6 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 5 Jan 2024 13:32:29 +0900 Subject: [PATCH 085/117] test case for callback handler --- tests/task/test_callback.py | 82 ++++++++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 6 deletions(-) diff --git a/tests/task/test_callback.py b/tests/task/test_callback.py index 64cea877..9ad615a8 100644 --- a/tests/task/test_callback.py +++ b/tests/task/test_callback.py @@ -119,8 +119,7 @@ def test_subclass_definition(self): method_names = { event_type for event_type in dir(callback) - if event_type.startswith("on_") - and callable(getattr(callback, event_type)) + if event_type.startswith("on_") and callable(getattr(callback, event_type)) } self.assertEqual(method_names, expected_method_names) @@ -130,16 +129,87 @@ def test_subclass_definition(self): def test_validate_callback(self): TaskBaseCallback, AppendCallback = setup_callback_classes() + class Unrelated(callback_module.BaseCallback): pass - class HasUnkownEvent(TaskBaseCallback): + class HasUnknownEvent(TaskBaseCallback): def on_unknown_event(self): pass - self.assertIsNone(callback_module._validate_callback(AppendCallback(), TaskBaseCallback)) - self.assertRaises(TypeError, callback_module._validate_callback, Unrelated(), TaskBaseCallback) - self.assertRaises(ValueError, callback_module._validate_callback, HasUnkownEvent(), TaskBaseCallback) + self.assertIsNone( + callback_module._validate_callback(AppendCallback(), TaskBaseCallback) + ) + self.assertRaises( + TypeError, callback_module._validate_callback, Unrelated(), TaskBaseCallback + ) + self.assertRaises( + ValueError, + callback_module._validate_callback, + HasUnknownEvent(), + TaskBaseCallback, + ) + + +class TestCallbackHandler(unittest.TestCase): + def test_initialization(self): + _, AppendCallback = setup_callback_classes() + c1, c2 = AppendCallback(), AppendCallback() + + handler = callback_module.CallbackHandler() + self.assertListEqual(handler._callbacks, []) + self.assertDictEqual(handler._registered_functions, {}) + + handler = callback_module.CallbackHandler(c1) + self.assertListEqual(handler._callbacks, [c1]) + self.assertDictEqual( + handler._registered_functions, + {"on_some_event": [c1.on_some_event]}, + ) + + handler = callback_module.CallbackHandler([c1, c2]) + self.assertListEqual(handler._callbacks, [c1, c2]) + self.assertDictEqual( + handler._registered_functions, + {"on_some_event": [c1.on_some_event, c2.on_some_event]}, + ) + + def test_register(self): + handler = callback_module.CallbackHandler() + _, AppendCallback = setup_callback_classes() + cb = AppendCallback() + + self.assertListEqual(handler._callbacks, []) + self.assertDictEqual(handler._registered_functions, {}) + handler.register(cb) + self.assertListEqual(handler._callbacks, [cb]) + self.assertDictEqual( + handler._registered_functions, + {"on_some_event": [cb.on_some_event]}, + ) + + def test_fire(self): + handler = callback_module.CallbackHandler() + _, AppendCallback = setup_callback_classes() + cb = AppendCallback() + handler.register(cb) + + self.assertListEqual(cb._storage, []) + + handler.fire("on_task_start") + self.assertListEqual(cb._storage, []) + + handler.fire("on_some_event", input_=1) + self.assertListEqual(cb._storage, [1]) + + handler.fire("on_some_event", input_=2) + self.assertListEqual(cb._storage, [1, 2]) + + # NOTE: fire() should only accept keyword arguments + self.assertRaises(TypeError, handler.fire, "on_some_event", 1, 2) + + handler.fire("on_task_end") + self.assertListEqual(cb._storage, [1, 2]) if __name__ == "__main__": From ed03f17e2de0a8415dad7644ec91dfd1308dd49e Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 5 Jan 2024 14:34:17 +0900 Subject: [PATCH 086/117] remove unsupported codes (WandB logging) --- bdpy/recon/torch/task/inversion.py | 53 ------------------------------ 1 file changed, 53 deletions(-) diff --git a/bdpy/recon/torch/task/inversion.py b/bdpy/recon/torch/task/inversion.py index 818729a7..dddcc74b 100644 --- a/bdpy/recon/torch/task/inversion.py +++ b/bdpy/recon/torch/task/inversion.py @@ -103,59 +103,6 @@ def on_iteration_end(self, *, step: int) -> None: print(f"Step: [{self._step_str(step)}], Loss: {self._loss:.4f}") -class WandBLoggingCallback(FeatureInversionCallback): - """Callback for logging on Weights & Biases. - - Parameters - ---------- - run : wandb.sdk.wandb_run.Run - Run object of Weights & Biases. - interval : int, optional - Logging interval, by default 1. If `interval` is 1, the callback logs - every iteration. - media_interval : int, optional - Logging interval for media, by default -1. If `media_interval` is -1, - the callback does not log media. - - Notes - ----- - TODO: Currently it does not work because the dependency (wandb) is not installed. - """ - - def __init__( - self, run: wandb.sdk.wandb_run.Run, interval: int = 1, media_interval: int = -1 - ) -> None: - super().__init__() - self._run = run - self._interval = interval - self._media_interval = media_interval - self._step = 0 - - if media_interval < 0: - # NOTE: Decorate `on_image_generated` to do nothing. - self.on_image_generated = unused(self.on_image_generated) - - def on_iteration_start(self, *, step: int) -> None: - # NOTE: We need to store the global step because we cannot access it - # in `on_layerwise_loss_calculated` by design. - self._step = step - - def on_image_generated(self, *, step: int, image: torch.Tensor) -> None: - if self._step % self._media_interval == 0: - image = wandb.Image(image) - self._run.log({"generated_image": image}, step=self._step) - - def on_layerwise_loss_calculated( - self, *, layer_loss: torch.Tensor, layer_name: str - ) -> None: - if self._step % self._interval == 0: - self._run.log({f"critic/{layer_name}": layer_loss.item()}, step=self._step) - - def on_loss_calculated(self, *, step: int, loss: torch.Tensor) -> None: - if self._step % self._interval == 0: - self._run.log({"loss": loss.item()}, step=self._step) - - class FeatureInversionTask(BaseTask): """Feature inversion Task. From 4d3c8afac048587766017b3e9c47fbccfeb8be2b Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 5 Jan 2024 14:54:47 +0900 Subject: [PATCH 087/117] changed interface around base callback --- bdpy/recon/torch/task/inversion.py | 3 +-- bdpy/task/callback.py | 4 ++++ tests/task/test_callback.py | 14 ++++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/bdpy/recon/torch/task/inversion.py b/bdpy/recon/torch/task/inversion.py index dddcc74b..3a873290 100644 --- a/bdpy/recon/torch/task/inversion.py +++ b/bdpy/recon/torch/task/inversion.py @@ -29,8 +29,7 @@ class FeatureInversionCallback(BaseCallback): """ def __init__(self) -> None: - super().__init__() - _validate_callback(self, FeatureInversionCallback) + super().__init__(base_class=FeatureInversionCallback) @unused def on_task_start(self) -> None: diff --git a/bdpy/task/callback.py b/bdpy/task/callback.py index 8ef95eba..93f097f4 100644 --- a/bdpy/task/callback.py +++ b/bdpy/task/callback.py @@ -150,6 +150,10 @@ class BaseCallback: `@unused` decorator can be used to mark a callback function as unused, so that the callback handler does not fire the function. """ + def __init__(self, base_class: Type[BaseCallback] | None = None) -> None: + if base_class is None: + base_class = BaseCallback + _validate_callback(self, base_class) @unused def on_task_start(self) -> None: diff --git a/tests/task/test_callback.py b/tests/task/test_callback.py index 9ad615a8..caae1cd2 100644 --- a/tests/task/test_callback.py +++ b/tests/task/test_callback.py @@ -33,6 +33,9 @@ def __call__(self, input_: Any) -> None: def setup_callback_classes(): class TaskBaseCallback(callback_module.BaseCallback): + def __init__(self): + super().__init__(base_class=TaskBaseCallback) + @callback_module.unused def on_some_event(self, input_): pass @@ -131,9 +134,13 @@ def test_validate_callback(self): TaskBaseCallback, AppendCallback = setup_callback_classes() class Unrelated(callback_module.BaseCallback): + """Valid callback object but is not a subclass of TaskBaseCallback""" + pass class HasUnknownEvent(TaskBaseCallback): + """Having invalid instance method `on_unknown_event` as a subclass of TaskBaseCallback""" + def on_unknown_event(self): pass @@ -143,12 +150,7 @@ def on_unknown_event(self): self.assertRaises( TypeError, callback_module._validate_callback, Unrelated(), TaskBaseCallback ) - self.assertRaises( - ValueError, - callback_module._validate_callback, - HasUnknownEvent(), - TaskBaseCallback, - ) + self.assertRaises(ValueError, HasUnknownEvent) class TestCallbackHandler(unittest.TestCase): From 49b35422ad0cf53657849e525a3f9bcab4bcf8bc Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:38:46 +0900 Subject: [PATCH 088/117] first commit From fb203f14e291c84fb1fe6219ba30c0827e2ba190 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:47:23 +0900 Subject: [PATCH 089/117] create test files --- tests/dl/torch/test_dataset.py | 0 tests/recon/torch/modules/test_latent.py | 0 tests/recon/torch/task/__init__.py | 0 tests/recon/torch/task/test_inversion.py | 0 tests/task/test_core.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/dl/torch/test_dataset.py create mode 100644 tests/recon/torch/modules/test_latent.py create mode 100644 tests/recon/torch/task/__init__.py create mode 100644 tests/recon/torch/task/test_inversion.py create mode 100644 tests/task/test_core.py diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/recon/torch/task/__init__.py b/tests/recon/torch/task/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/task/test_core.py b/tests/task/test_core.py new file mode 100644 index 00000000..e69de29b From 3210670cf16a9e8da463cf525ce8f264b2491187 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Wed, 24 Jan 2024 10:51:58 +0900 Subject: [PATCH 090/117] add TestBaseLatent --- tests/dl/torch/test_dataset.py | 27 +++++++++++ tests/recon/torch/modules/test_latent.py | 60 ++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py index e69de29b..56656a4c 100644 --- a/tests/dl/torch/test_dataset.py +++ b/tests/dl/torch/test_dataset.py @@ -0,0 +1,27 @@ +import unittest + +import torch +import torch.nn as nn + +from bdpy.dl.torch import models + + +class TestFeatureDataset(unittest.TestCase): + def setUp(self): + #self.dataset = + pass + pass + +class TestDecodedFeatureDataset(unittest.TestCase): + pass + +class TestImageDataset(unittest.TestCase): + pass + +class TestRenameFeatureKeys(unittest.TestCase): + pass + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index e69de29b..a3d0525a 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -0,0 +1,60 @@ +import torch +import unittest +from abc import ABC, abstractmethod +from typing import Iterator +import torch.nn as nn +from bdpy.recon.torch.modules import latent as latent_module + + +class DummyLatent(latent_module.BaseLatent): + def __init__(self): + self.latent = torch.tensor([1.0]) + + def reset_states(self): + self.latent = torch.zeros_like(self.latent) + + def parameters(self, recurse): + return iter([nn.Parameter(self.latent)]) + + def generate(self): + return self.latent + +class TestBaseLatent(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.latent.BaseLatent.""" + def setUp(self): + self.latent = torch.tensor([1.0]) + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, latent_module.BaseLatent) + + def test_call(self): + """Test __call__.""" + + latent = DummyLatent() + + self.assertEqual(latent(), self.latent) + + def test_parameters(self): + """test parameters""" + latent = DummyLatent() + params = latent.parameters(recurse=True) + + self.assertIsInstance(params, Iterator) + self.assertEqual(next(params).item(), 1.0) + + def test_reset_states(self): + """test reset_states""" + latent = DummyLatent() + latent.reset_states() + params = latent.parameters(recurse=True) + + self.assertEqual(next(params).item(), 0.0) + + + + + + +if __name__ == '__main__': + unittest.main() From 31d3ba50645b9370e6f3a18bd95433ec1e0531fd Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Wed, 24 Jan 2024 13:46:27 +0900 Subject: [PATCH 091/117] Update tests/recon/torch/modules/test_latent.py Co-authored-by: Yoshihiro Nagano --- tests/recon/torch/modules/test_latent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index a3d0525a..a99560c6 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -22,7 +22,7 @@ def generate(self): class TestBaseLatent(unittest.TestCase): """Tests for bdpy.recon.torch.modules.latent.BaseLatent.""" def setUp(self): - self.latent = torch.tensor([1.0]) + self.latent_value_expected = torch.tensor([1.0]) def test_instantiation(self): """Test instantiation.""" From 4f367b5b01ccd4712a5c38524b7f69bc108efbdd Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Wed, 24 Jan 2024 13:46:39 +0900 Subject: [PATCH 092/117] Update tests/recon/torch/modules/test_latent.py Co-authored-by: Yoshihiro Nagano --- tests/recon/torch/modules/test_latent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index a99560c6..af0af0d2 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -8,7 +8,7 @@ class DummyLatent(latent_module.BaseLatent): def __init__(self): - self.latent = torch.tensor([1.0]) + self.latent = nn.Parameter(torch.tensor([1.0])) def reset_states(self): self.latent = torch.zeros_like(self.latent) From 3660f4043092d83bbaee15340cb50fb3d7745673 Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Wed, 24 Jan 2024 13:46:48 +0900 Subject: [PATCH 093/117] Update tests/recon/torch/modules/test_latent.py Co-authored-by: Yoshihiro Nagano --- tests/recon/torch/modules/test_latent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index af0af0d2..131589a8 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -11,7 +11,8 @@ def __init__(self): self.latent = nn.Parameter(torch.tensor([1.0])) def reset_states(self): - self.latent = torch.zeros_like(self.latent) + with torch.no_grad(): + self.latent.fill_(0.0) def parameters(self, recurse): return iter([nn.Parameter(self.latent)]) From 33f2c9fbb2c19373b3394716299108779aaaef5a Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:15:51 +0900 Subject: [PATCH 094/117] update test latent --- tests/recon/torch/modules/test_latent.py | 84 +++++++++++++++++++++--- 1 file changed, 74 insertions(+), 10 deletions(-) diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index 131589a8..76b5ca32 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -3,19 +3,21 @@ from abc import ABC, abstractmethod from typing import Iterator import torch.nn as nn +from functools import partial from bdpy.recon.torch.modules import latent as latent_module +from IPython import embed class DummyLatent(latent_module.BaseLatent): def __init__(self): - self.latent = nn.Parameter(torch.tensor([1.0])) + self.latent = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) def reset_states(self): with torch.no_grad(): self.latent.fill_(0.0) def parameters(self, recurse): - return iter([nn.Parameter(self.latent)]) + return iter(self.latent) def generate(self): return self.latent @@ -23,7 +25,8 @@ def generate(self): class TestBaseLatent(unittest.TestCase): """Tests for bdpy.recon.torch.modules.latent.BaseLatent.""" def setUp(self): - self.latent_value_expected = torch.tensor([1.0]) + self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) + self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) def test_instantiation(self): """Test instantiation.""" @@ -31,31 +34,92 @@ def test_instantiation(self): def test_call(self): """Test __call__.""" - latent = DummyLatent() - - self.assertEqual(latent(), self.latent) + self.assertTrue(torch.equal(latent(), self.latent_value_expected)) def test_parameters(self): """test parameters""" latent = DummyLatent() params = latent.parameters(recurse=True) - self.assertIsInstance(params, Iterator) - self.assertEqual(next(params).item(), 1.0) def test_reset_states(self): """test reset_states""" latent = DummyLatent() latent.reset_states() - params = latent.parameters(recurse=True) + self.assertTrue(torch.equal(latent(), self.latent_reset_value_expected)) + +class DummyNNModuleLatent(latent_module.NNModuleLatent): + def __init__(self): + super().__init__() + self.latent = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) + + def reset_states(self): + with torch.no_grad(): + self.latent.fill_(0.0) + + def generate(self): + return self.latent + +class TestNNModuleLatent(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.latent.NNModuleLatent.""" + def setUp(self): + self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) + self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) + + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, latent_module.NNModuleLatent) + + def test_call(self): + """Test __call__.""" + latent = DummyNNModuleLatent() + self.assertTrue(torch.equal(latent(), self.latent_value_expected)) - self.assertEqual(next(params).item(), 0.0) + def test_parameters(self): + """test parameters""" + latent = DummyNNModuleLatent() + params = latent.parameters(recurse=True) + self.assertIsInstance(params, Iterator) + + def test_reset_states(self): + """test reset_states""" + latent = DummyNNModuleLatent() + latent.reset_states() + self.assertTrue(torch.equal(latent(), self.latent_reset_value_expected)) +class DummyArbitraryLatent(latent_module.ArbitraryLatent): + def parameters(self, recurse): + return iter(self._latent) +class TestArbitraryLatent(unittest.TestCase): + """Tests for bdpy.recon.torch.modules.latent.ArbitraryLatent.""" + def setUp(self): + self.latent = DummyArbitraryLatent((1, 3, 64, 64), partial(nn.init.normal_, mean=0, std=1)) + self.latent_shape_expected = (1, 3, 64, 64) + self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) + self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, latent_module.ArbitraryLatent) + def test_call(self): + """Test __call__.""" + self.assertEqual(self.latent().size(), self.latent_shape_expected) + def test_parameters(self): + """test parameters""" + params = self.latent.parameters(recurse=True) + self.assertIsInstance(params, Iterator) + + def test_reset_states(self): + """test reset_states""" + self.latent.reset_states() + mean = self.latent().mean().item() + std = self.latent().std().item() + self.assertAlmostEqual(mean, 0, places=1) + self.assertAlmostEqual(std, 1, places=1) if __name__ == '__main__': unittest.main() From b03b68b64477ebad9d6bf8cbec559b7618c655d1 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:37:05 +0900 Subject: [PATCH 095/117] create test_core --- tests/dl/torch/domain/__init__.py | 0 tests/dl/torch/domain/test_core.py | 83 ++++++++++++++++++++ tests/dl/torch/domain/test_feature_domain.py | 0 tests/dl/torch/domain/test_image_domain.py | 0 4 files changed, 83 insertions(+) create mode 100644 tests/dl/torch/domain/__init__.py create mode 100644 tests/dl/torch/domain/test_core.py create mode 100644 tests/dl/torch/domain/test_feature_domain.py create mode 100644 tests/dl/torch/domain/test_image_domain.py diff --git a/tests/dl/torch/domain/__init__.py b/tests/dl/torch/domain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py new file mode 100644 index 00000000..e774f38c --- /dev/null +++ b/tests/dl/torch/domain/test_core.py @@ -0,0 +1,83 @@ +"""Tests for bdpy.dl.torch.domain.core.""" + +import unittest +import torch +from bdpy.dl.torch.domain import core as core_module + + +class DummyAddDomain(core_module.Domain): + def send(self, num): + return num + 1 + + def receive(self, num): + return num - 1 + +class DummyDoubleDomain(core_module.Domain): + def send(self, num): + return num * 2 + + def receive(self, num): + return num // 2 + +class TestDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.core.Domain.""" + def setUp(self): + self.domian = DummyAddDomain() + self.original_space_num = 0 + self.internal_space_num = 1 + + def test_send(self): + """test send""" + self.assertEqual(self.domian.send(self.original_space_num), self.internal_space_num) + + def test_receive(self): + """test receive""" + self.assertEqual(self.domian.receive(self.internal_space_num), self.original_space_num) + +class TestInternalDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.core.InternalDomain.""" + def setUp(self): + self.domian = core_module.InternalDomain() + self.num = 1 + + def test_send(self): + """test send""" + self.assertEqual(self.domian.send(self.num), self.num) + + def test_receive(self): + """test receive""" + self.assertEqual(self.domian.receive(self.num), self.num) + +class TestIrreversibleDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.core.IrreversibleDomain.""" + def setUp(self): + self.domian = core_module.IrreversibleDomain() + self.num = 1 + + def test_send(self): + """test send""" + self.assertEqual(self.domian.send(self.num), self.num) + + def test_receive(self): + """test receive""" + self.assertEqual(self.domian.receive(self.num), self.num) + +class TestComposedDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.core.ComposedDomain.""" + def setUp(self): + self.composed_domian = core_module.ComposedDomain([ + DummyDoubleDomain(), + DummyAddDomain(), + ]) + self.original_space_num = 0 + self.internal_space_num = 2 + + def test_send(self): + """test send""" + self.assertEqual(self.composed_domian.send(self.original_space_num), self.internal_space_num) + + def test_receive(self): + """test receive""" + self.assertEqual(self.composed_domian.receive(self.internal_space_num), self.original_space_num) +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/dl/torch/domain/test_feature_domain.py b/tests/dl/torch/domain/test_feature_domain.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py new file mode 100644 index 00000000..e69de29b From 6cef2f2a1dbe7331bf2c38babead84de41b2d8ec Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:55:38 +0900 Subject: [PATCH 096/117] update test_core --- tests/dl/torch/domain/test_core.py | 28 ++++++++++++++++++++ tests/dl/torch/domain/test_feature_domain.py | 5 ++++ 2 files changed, 33 insertions(+) diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py index e774f38c..f03d1ecb 100644 --- a/tests/dl/torch/domain/test_core.py +++ b/tests/dl/torch/domain/test_core.py @@ -19,6 +19,13 @@ def send(self, num): def receive(self, num): return num // 2 +class DummyUpperCaseDomain(core_module.Domain): + def send(self, text): + return text.upper() + + def receive(self, value): + return value.lower() + class TestDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.core.Domain.""" def setUp(self): @@ -79,5 +86,26 @@ def test_send(self): def test_receive(self): """test receive""" self.assertEqual(self.composed_domian.receive(self.internal_space_num), self.original_space_num) + +class TestKeyValueDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.core.KeyValueDomain.""" + def setUp(self): + self.key_value_domian = core_module.KeyValueDomain({ + "name": DummyUpperCaseDomain(), + "age": DummyDoubleDomain() + }) + self.original_space_data = {"name": "alice", "age": 30} + self.internal_space_data = {"name": "ALICE", "age": 60} + + def test_send(self): + """test send""" + self.assertEqual(self.key_value_domian.send(self.original_space_data), self.internal_space_data) + + def test_receive(self): + """test receive""" + self.assertEqual(self.key_value_domian.receive(self.internal_space_data), self.original_space_data) + + + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/dl/torch/domain/test_feature_domain.py b/tests/dl/torch/domain/test_feature_domain.py index e69de29b..3fd137a9 100644 --- a/tests/dl/torch/domain/test_feature_domain.py +++ b/tests/dl/torch/domain/test_feature_domain.py @@ -0,0 +1,5 @@ +"""Tests for bdpy.dl.torch.domain.feature_domain.""" + +import unittest +import torch +from bdpy.dl.torch.domain import feature_domain as feature_domain_module \ No newline at end of file From 547d337028a8c0a97f8e18ed98527e71a3d11b97 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:35:21 +0900 Subject: [PATCH 097/117] create test_feature_domain --- tests/dl/torch/domain/test_feature_domain.py | 80 +++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/tests/dl/torch/domain/test_feature_domain.py b/tests/dl/torch/domain/test_feature_domain.py index 3fd137a9..72d7ba92 100644 --- a/tests/dl/torch/domain/test_feature_domain.py +++ b/tests/dl/torch/domain/test_feature_domain.py @@ -2,4 +2,82 @@ import unittest import torch -from bdpy.dl.torch.domain import feature_domain as feature_domain_module \ No newline at end of file +from bdpy.dl.torch.domain import feature_domain as feature_domain_module + +class TestMethods(unittest.TestCase): + def setUp(self): + self.lnd_tensor = torch.empty((12, 196, 768)) + self.nld_tensor = torch.empty((196, 12, 768)) + + def test_lnd2nld(self): + """test _lnd2nld""" + self.assertEqual(feature_domain_module._lnd2nld(self.lnd_tensor).shape, self.nld_tensor.shape) + + def test_nld2lnd(self): + """test _nld2lnd""" + self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape) + +class TestArbitraryFeatureKeyDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain.""" + def setUp(self): + self.to_internal_mapping = { + "self_key1": "internal_key1", + "self_key2": "internal_key2" + } + self.to_self_mapping = { + "internal_key1": "self_key1", + "internal_key2": "self_key2" + } + self.features = { + "self_key1": 123, + "self_key2": 456 + } + self.internal_features = { + "internal_key1": 123, + "internal_key2": 456 + } + + def test_send(self): + """test send""" + # when both are specified + domain = feature_domain_module.ArbitraryFeatureKeyDomain( + to_internal=self.to_internal_mapping, + to_self=self.to_self_mapping + ) + self.assertEqual(domain.send(self.features), self.internal_features) + + # when only to_self is specified + domain = feature_domain_module.ArbitraryFeatureKeyDomain( + to_self=self.to_self_mapping + ) + self.assertEqual(domain.send(self.features), self.internal_features) + + # when only to_internal is specified + domain = feature_domain_module.ArbitraryFeatureKeyDomain( + to_internal=self.to_internal_mapping + ) + self.assertEqual(domain.send(self.features), self.internal_features) + + def test_receive(self): + """test receive""" + # when both are specified + domain = feature_domain_module.ArbitraryFeatureKeyDomain( + to_internal=self.to_internal_mapping, + to_self=self.to_self_mapping + ) + self.assertEqual(domain.receive(self.internal_features), self.features) + + # when only to_self is specified + domain = feature_domain_module.ArbitraryFeatureKeyDomain( + to_self=self.to_self_mapping + ) + self.assertEqual(domain.receive(self.internal_features), self.features) + + # when only to_internal is specified + domain = feature_domain_module.ArbitraryFeatureKeyDomain( + to_internal=self.to_internal_mapping + ) + self.assertEqual(domain.receive(self.internal_features), self.features) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From b038a757e8f049ac1e3d28a711055dc7b3e9e3d2 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Sat, 27 Jan 2024 02:45:41 +0900 Subject: [PATCH 098/117] create test_image_domain --- bdpy/dl/torch/domain/image_domain.py | 2 +- tests/dl/torch/domain/test_core.py | 4 + tests/dl/torch/domain/test_image_domain.py | 120 +++++++++++++++++++++ 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/bdpy/dl/torch/domain/image_domain.py b/bdpy/dl/torch/domain/image_domain.py index dcf9440d..74b49b49 100644 --- a/bdpy/dl/torch/domain/image_domain.py +++ b/bdpy/dl/torch/domain/image_domain.py @@ -85,7 +85,7 @@ def __init__( if isinstance(center, (float, int)) or center.ndim == 0: center = np.array([center])[np.newaxis, np.newaxis, np.newaxis] - if center.ndim == 1: # 1D vector (C,) + elif center.ndim == 1: # 1D vector (C,) center = center[np.newaxis, :, np.newaxis, np.newaxis] elif center.ndim == 3: # 3D vector (1, C, W, H) center = center[np.newaxis] diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py index f03d1ecb..bb00e2ea 100644 --- a/tests/dl/torch/domain/test_core.py +++ b/tests/dl/torch/domain/test_core.py @@ -33,6 +33,10 @@ def setUp(self): self.original_space_num = 0 self.internal_space_num = 1 + def test_instantiation(self): + """Test instantiation.""" + self.assertRaises(TypeError, core_module.Domain) + def test_send(self): """test send""" self.assertEqual(self.domian.send(self.original_space_num), self.internal_space_num) diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py index e69de29b..5c1d8635 100644 --- a/tests/dl/torch/domain/test_image_domain.py +++ b/tests/dl/torch/domain/test_image_domain.py @@ -0,0 +1,120 @@ +"""Tests for bdpy.dl.torch.domain.image_domain.""" + +import unittest +import torch +import numpy as np +import warnings +from bdpy.dl.torch.domain import image_domain as iamge_domain_module + +class TestAffineDomain(unittest.TestCase): + """Tests for bdpy.dl.torch.domain.image_domain.AffineDomain""" + def setUp(self): + self.center0d = 0.0 + self.center1d = np.random.randn(3) + self.center2d = np.random.randn(32, 32) + self.center3d = np.random.randn(3, 32, 32) + self.scale0d = 1 + self.scale1d = np.random.randn(3) + self.scale2d = np.random.randn(32, 32) + self.scale3d = np.random.randn(3, 32, 32) + self.image = torch.rand((1, 3, 32, 32)) + + def test_instantiation(self): + """Test instantiation.""" + # Succeeds when center and scale are 0-dimensional + affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d) + self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) + + # Succeeds when center and scale are 1-dimensional + affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d) + self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) + + # Succeeds when center and scale are 3-dimensional + affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) + self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) + + # Failss when the center is neither 1-dimensional nor 3-dimensional + with self.assertRaises(ValueError): + iamge_domain_module.AffineDomain(self.center2d, self.scale0d) + + # Failss when the scale is neither 1-dimensional nor 3-dimensional + with self.assertRaises(ValueError): + iamge_domain_module.AffineDomain(self.center0d, self.scale2d) + + def test_send_and_receive(self): + """Test send and receive""" + # when 0d + affine_domain = iamge_domain_module.AffineDomain(self.center0d, self.scale0d) + transformed_image = affine_domain.send(self.image) + center0d = torch.from_numpy(np.array([self.center0d])[np.newaxis, np.newaxis, np.newaxis]) + scale0d = torch.from_numpy(np.array([self.scale0d])[np.newaxis, np.newaxis, np.newaxis]) + expected_transformed_image = (self.image + center0d) / self.scale0d + torch.testing.assert_close(transformed_image, expected_transformed_image) + received_image = affine_domain.receive(transformed_image) + expected_received_image = expected_transformed_image * scale0d - center0d + torch.testing.assert_close(received_image, expected_received_image) + + # when 1d + affine_domain = iamge_domain_module.AffineDomain(self.center1d, self.scale1d) + transformed_image = affine_domain.send(self.image) + center1d = self.center1d[np.newaxis, :, np.newaxis, np.newaxis] + scale1d = self.scale1d[np.newaxis, :, np.newaxis, np.newaxis] + expected_transformed_image = (self.image + center1d) / scale1d + torch.testing.assert_close(transformed_image, expected_transformed_image) + received_image = affine_domain.receive(transformed_image) + expected_received_image = expected_transformed_image * scale1d - center1d + torch.testing.assert_close(received_image, expected_received_image) + + # when 3d + affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) + transformed_image = affine_domain.send(self.image) + center3d = self.center3d[np.newaxis] + scale3d = self.scale3d[np.newaxis] + expected_transformed_image = (self.image + center3d) / scale3d + torch.testing.assert_close(transformed_image, expected_transformed_image) + received_image = affine_domain.receive(transformed_image) + expected_received_image = expected_transformed_image * scale3d - center3d + torch.testing.assert_close(received_image, expected_received_image) + +class TestRGBDomain(unittest.TestCase): + """Tests fot bdpy.dl.torch.domain.image_domain.BGRDomain""" + + def setUp(self): + self.bgr_image = torch.rand((1, 3, 32, 32)) + self.rgb_image = self.bgr_image[:, [2, 1, 0], ...] + + def test_send(self): + """Test send""" + bgr_domain = iamge_domain_module.BGRDomain() + transformed_image = bgr_domain.send(self.bgr_image) + torch.testing.assert_close(transformed_image, self.rgb_image) + + def test_receive(self): + """Tests receive""" + bgr_domain = iamge_domain_module.BGRDomain() + received_image = bgr_domain.receive(self.rgb_image) + torch.testing.assert_close(received_image, self.bgr_image) + +class TestPILDomainWithExplicitCrop(unittest.TestCase): + """Tests fot bdpy.dl.torch.domain.image_domain.PILDomainWithExplicitCrop""" + def setUp(self): + self.expected_transformed_image = torch.rand((1, 3, 32, 32)) + self.image = self.expected_transformed_image.permute(0, 2, 3, 1) * 255 + + def test_send(self): + """Test send""" + pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop() + transformed_image = pdwe_domain.send(self.image) + torch.testing.assert_close(transformed_image, self.expected_transformed_image) + + def test_receive(self): + """Tests receive""" + pdwe_domain = iamge_domain_module.PILDomainWithExplicitCrop() + with warnings.catch_warnings(record=True) as w: + received_image = pdwe_domain.receive(self.expected_transformed_image) + self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w)) + torch.testing.assert_close(received_image, self.image) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 6605f280e5ebf4cd5933e479c1c50ac5638f8612 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:58:51 +0900 Subject: [PATCH 099/117] Update test_image_domain.py --- tests/dl/torch/domain/test_image_domain.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py index 5c1d8635..9f658fee 100644 --- a/tests/dl/torch/domain/test_image_domain.py +++ b/tests/dl/torch/domain/test_image_domain.py @@ -5,6 +5,7 @@ import numpy as np import warnings from bdpy.dl.torch.domain import image_domain as iamge_domain_module +from IPython import embed class TestAffineDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.image_domain.AffineDomain""" @@ -114,7 +115,25 @@ def test_receive(self): received_image = pdwe_domain.receive(self.expected_transformed_image) self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w)) torch.testing.assert_close(received_image, self.image) - + +class TestFixedResolutionDomain(unittest.TestCase): + """Tests fot bdpy.dl.torch.domain.image_domain.FixedResolutionDomain""" + def setUp(self): + self.expected_received_image_size = (1, 3, 16, 16) + self.image =torch.rand((1, 3, 32, 32)) + + def test_send(self): + """Test send""" + fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16)) + with self.assertRaises(RuntimeError): + fr_domain.send(self.image) + + def test_receive(self): + """Tests receive""" + fr_domain = iamge_domain_module.FixedResolutionDomain((16, 16)) + + received_image = fr_domain.receive(self.image) + self.assertEqual(received_image.size(), self.expected_received_image_size) if __name__ == "__main__": unittest.main() \ No newline at end of file From 7cbfc0ecf23311bb3c9024342950f2b811e166d2 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:41:23 +0900 Subject: [PATCH 100/117] create task/test_core --- tests/dl/torch/domain/test_image_domain.py | 1 - tests/recon/torch/modules/test_latent.py | 1 - tests/task/test_core.py | 57 ++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py index 9f658fee..c52bfa4f 100644 --- a/tests/dl/torch/domain/test_image_domain.py +++ b/tests/dl/torch/domain/test_image_domain.py @@ -5,7 +5,6 @@ import numpy as np import warnings from bdpy.dl.torch.domain import image_domain as iamge_domain_module -from IPython import embed class TestAffineDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.image_domain.AffineDomain""" diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index 76b5ca32..5910ad03 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -5,7 +5,6 @@ import torch.nn as nn from functools import partial from bdpy.recon.torch.modules import latent as latent_module -from IPython import embed class DummyLatent(latent_module.BaseLatent): diff --git a/tests/task/test_core.py b/tests/task/test_core.py index e69de29b..3fdd2a5e 100644 --- a/tests/task/test_core.py +++ b/tests/task/test_core.py @@ -0,0 +1,57 @@ +"""Tests for bdpy.task.core.""" + +from __future__ import annotations + +import unittest + +from bdpy.task import core as core_module + +class MockCallback(core_module.BaseCallback): + """Mock callback for testing.""" + def __init__(self): + self._storage = [] + + def on_some_event(self, input_): + self._storage.append(input_) + +class MockTask(core_module.BaseTask[MockCallback]): + """Mock task for testing BaseTask.""" + def __call__(self, *inputs, **parameters): + self._callback_handler.fire("on_some_event", input_=1) + return inputs, parameters + +class TestBaseTask(unittest.TestCase): + """Tests forbdpy.task.core.BaseTask """ + def setUp(self): + self.input1 = 1.0 + self.input2 = 2.0 + self.task_name = "reconstruction" + + def test_initialization_without_callbacks(self): + """Test initialization without callbacks.""" + task = MockTask() + self.assertIsInstance(task._callback_handler, core_module.CallbackHandler) + self.assertEqual(len(task._callback_handler._callbacks), 0) + + def test_initialization_with_callbacks(self): + """Test initialization with callbacks.""" + mock_callback = MockCallback() + task = MockTask(callbacks=mock_callback) + self.assertEqual(len(task._callback_handler._callbacks), 1) + self.assertIn(mock_callback, task._callback_handler._callbacks) + + def test_register_callback(self): + """Test register_callback method.""" + task = MockTask() + mock_callback = MockCallback() + task.register_callback(mock_callback) + self.assertIn(mock_callback, task._callback_handler._callbacks) + + def test_call(self): + """Test __call__""" + mock_callback = MockCallback() + task = MockTask(callbacks=mock_callback) + task_inputs, task_parameters = task(self.input1, self.input2, name=self.task_name) + self.assertEqual(task_inputs, (self.input1, self.input2)) + self.assertEqual(task_parameters["name"], self.task_name) + self.assertEqual(mock_callback._storage, [1]) From ca3b1629a033457542e3803a08f2dfb926f9bf64 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Tue, 6 Feb 2024 22:34:13 +0900 Subject: [PATCH 101/117] create inversion test --- tests/recon/torch/task/test_inversion.py | 97 ++++++++++++++++++++++++ tests/task/test_core.py | 3 + 2 files changed, 100 insertions(+) diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py index e69de29b..81f412c8 100644 --- a/tests/recon/torch/task/test_inversion.py +++ b/tests/recon/torch/task/test_inversion.py @@ -0,0 +1,97 @@ +"""Tests for bdpy.recon.torch.task.inversion""" + +from __future__ import annotations + +import unittest +from unittest.mock import patch +import torch + +from bdpy.recon.torch.task import inversion as inversion_module +from bdpy.task import callback as callback_module + + +class TaskFeatureInversionCallback(inversion_module.FeatureInversionCallback): + def __init__(self): + super().__init__() + + def on_task_start(self): + print('task start') + +class TestFeatureInversionCallback(unittest.TestCase): + """Tests for bdpy.recon.torch.task.inversion.FeatureInversionCallback""" + def setUp(self): + self.callback = inversion_module.FeatureInversionCallback() + self.expected_method_names = { + "on_task_start", + "on_iteration_start", + "on_image_generated", + "on_layerwise_loss_calculated", + "on_loss_calculated", + "on_iteration_end", + "on_task_end", + } + + def test_instance_methods(self): + method_names = { + event_type + for event_type in dir(self.callback) + if event_type.startswith("on_") + and callable(getattr(self.callback, event_type)) + } + self.assertEqual(method_names, self.expected_method_names) + for event_type in method_names: + fn = getattr(self.callback, event_type) + self.assertRaises(RuntimeError, fn) + + + def test_validate_callback(self): + + class Unrelated(callback_module.BaseCallback): + """Valid callback object but is not a subclass of TaskFeatureInversionCallback""" + + pass + + class HasUnknownEvent(TaskFeatureInversionCallback): + """Having invalid instance method `on_unknown_event` as a subclass of TaskFeatureInversionCallback""" + + def on_unknown_event(self): + pass + + self.assertIsNone( + callback_module._validate_callback(TaskFeatureInversionCallback(), inversion_module.FeatureInversionCallback) + ) + self.assertRaises( + TypeError, callback_module._validate_callback, Unrelated(), inversion_module.FeatureInversionCallback + ) + self.assertRaises(ValueError, HasUnknownEvent) + + +class TestCUILoggingCallback(unittest.TestCase): + """Tests for bdpy.recon.torch.task.inversion.CUILoggingCallback""" + def setUp(self): + self.callback = inversion_module.CUILoggingCallback() + self.expected_loss = torch.tensor([1.0]) + + def test_on_loss_culculated(self): + self.callback.on_loss_calculated(step=0, loss=self.expected_loss) + self.assertEqual(self.callback._loss, self.expected_loss.item()) + + @patch('builtins.print') + def test_on_iteration_end(self, mock_print): + self.callback.on_iteration_end(step=0) + mock_print.assert_called_once_with("Step: [1], Loss: -1.0000") + +class TestFeatureInversionTask(unittest.TestCase): + """Tests for bdpy.recon.torch.task.inversion.FeatureInversionTask""" + def setUp(self): + pass + + + + + + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/task/test_core.py b/tests/task/test_core.py index 3fdd2a5e..98b92527 100644 --- a/tests/task/test_core.py +++ b/tests/task/test_core.py @@ -55,3 +55,6 @@ def test_call(self): self.assertEqual(task_inputs, (self.input1, self.input2)) self.assertEqual(task_parameters["name"], self.task_name) self.assertEqual(mock_callback._storage, [1]) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 948d5a59e02ab8e3dcfe73af2d5dc6e9ae4586d2 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:09:33 +0900 Subject: [PATCH 102/117] Update test_inversion.py --- tests/recon/torch/task/test_inversion.py | 143 +++++++++++++++++++++-- 1 file changed, 132 insertions(+), 11 deletions(-) diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py index 81f412c8..9a532cbc 100644 --- a/tests/recon/torch/task/test_inversion.py +++ b/tests/recon/torch/task/test_inversion.py @@ -3,19 +3,53 @@ from __future__ import annotations import unittest -from unittest.mock import patch +from unittest.mock import patch, call +import copy import torch +import torch.nn as nn +import torch.optim as optim from bdpy.recon.torch.task import inversion as inversion_module from bdpy.task import callback as callback_module +from bdpy.dl.torch.domain.image_domain import Zero2OneImageDomain +from bdpy.recon.torch.modules import encoder as encoder_module +from bdpy.recon.torch.modules import generator as generator_module +from bdpy.recon.torch.modules import latent as latent_module +from bdpy.recon.torch.modules import critic as critic_module +from IPython import embed -class TaskFeatureInversionCallback(inversion_module.FeatureInversionCallback): - def __init__(self): +class DummyFeatureInversionCallback(inversion_module.FeatureInversionCallback): + def __init__(self, total_steps = 1): super().__init__() + self._total_steps = total_steps + self._loss = 0 + + def _step_str(self, step: int) -> str: + if self._total_steps > 0: + return f"{step+1}/{self._total_steps}" + else: + return f"{step+1}" def on_task_start(self): print('task start') + + def on_iteration_start(self, step): + print(f"Step [{self._step_str(step)}] start") + + def on_image_generated(self, step, image): + print(f"Step [{self._step_str(step)}], {image.shape}") + + def on_loss_calculated(self, step, loss): + self._loss = loss.item() + + def on_iteration_end(self, step): + print(f"Step [{self._step_str(step)}] end") + + def on_task_end(self): + print('task end') + + class TestFeatureInversionCallback(unittest.TestCase): """Tests for bdpy.recon.torch.task.inversion.FeatureInversionCallback""" @@ -51,14 +85,14 @@ class Unrelated(callback_module.BaseCallback): pass - class HasUnknownEvent(TaskFeatureInversionCallback): + class HasUnknownEvent(DummyFeatureInversionCallback): """Having invalid instance method `on_unknown_event` as a subclass of TaskFeatureInversionCallback""" def on_unknown_event(self): pass self.assertIsNone( - callback_module._validate_callback(TaskFeatureInversionCallback(), inversion_module.FeatureInversionCallback) + callback_module._validate_callback(DummyFeatureInversionCallback(), inversion_module.FeatureInversionCallback) ) self.assertRaises( TypeError, callback_module._validate_callback, Unrelated(), inversion_module.FeatureInversionCallback @@ -81,17 +115,104 @@ def test_on_iteration_end(self, mock_print): self.callback.on_iteration_end(step=0) mock_print.assert_called_once_with("Step: [1], Loss: -1.0000") -class TestFeatureInversionTask(unittest.TestCase): - """Tests for bdpy.recon.torch.task.inversion.FeatureInversionTask""" - def setUp(self): - pass +class MLP(nn.Module): + """A simple MLP.""" - + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(7 * 7 * 3, 32) + self.fc2 = nn.Linear(32, 10) + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = torch.relu(x) + x = self.fc2(x) + return x +class LinearGenerator(generator_module.NNModuleGenerator): + def __init__(self): + super().__init__() + self.fc = nn.Linear(10, 7 * 7 * 3) - + def generate(self, latent): + return self.fc(latent) + + def reset_states(self) -> None: + self.fc.apply(generator_module.call_reset_parameters) + +class DummyNNModuleLatent(latent_module.NNModuleLatent): + def __init__(self, base_latent): + super().__init__() + self.latent = nn.Parameter(base_latent) + def reset_states(self): + with torch.no_grad(): + self.latent.fill_(0.0) + + def generate(self): + return self.latent + +class TestFeatureInversionTask(unittest.TestCase): + """Tests for bdpy.recon.torch.task.inversion.FeatureInversionTask""" + def setUp(self): + self.init_latent = torch.randn(1, 10) + self.target_feature = { + 'fc1': torch.randn(1, 32), + 'fc2': torch.randn(1, 10) + } + self.encoder = encoder_module.SimpleEncoder( + MLP(), ["fc1", "fc2"], domain=Zero2OneImageDomain() + ) + self.generator = generator_module.DNNGenerator(LinearGenerator()) + self.latent = DummyNNModuleLatent(self.init_latent.clone()) + self.critic = critic_module.MSE() + self.optimizer = optim.SGD([self.latent.latent], lr=0.1) + self.callbacks = DummyFeatureInversionCallback() + + self.inversion_task = inversion_module.FeatureInversionTask( + encoder=self.encoder, + generator=copy.deepcopy(self.generator), + latent=copy.deepcopy(self.latent), + critic=self.critic, + optimizer=self.optimizer, + callbacks=self.callbacks + ) + + @patch('builtins.print') + def test_call(self, mock_print): + """Test __call__.""" + generated_image = self.inversion_task(self.target_feature) + self.assertTrue(len(self.inversion_task._callback_handler._callbacks) > 0) + + # test for process + self.assertEqual(generated_image.shape, (1, 7 * 7 * 3)) + self.assertIsNotNone(self.inversion_task._generator._generator_network.fc.weight.grad) + self.assertFalse(torch.equal(self.inversion_task._latent.latent, self.init_latent)) + + + # test for callbacks + self.assertTrue(self.inversion_task._callback_handler._callbacks[0]._loss > 0 ) + mock_print.assert_has_calls([ + call('task start'), + call('Step [1/1] start'), + call('Step [1/1], torch.Size([1, 147])'), + call('Step [1/1] end'), + call('task end'), + ]) + + def test_reset_state(self): + """Test reset_states.""" + generator_copy = copy.deepcopy(self.inversion_task._generator) + latent_copy = copy.deepcopy(self.inversion_task._latent) + for p1, p2 in zip(self.inversion_task._generator.parameters(), generator_copy.parameters()): + self.assertTrue(torch.equal(p1, p2)) + torch.testing.assert_close(self.inversion_task._latent.latent, latent_copy.latent) + self.inversion_task.reset_states() + + for p1, p2 in zip(self.inversion_task._generator.parameters(), generator_copy.parameters()): + self.assertFalse(torch.equal(p1, p2)) + self.assertFalse(torch.equal(self.inversion_task._latent.latent, latent_copy.latent)) if __name__ == "__main__": unittest.main() \ No newline at end of file From bd2e180816d41156517ee7df1cc6d81a462123f9 Mon Sep 17 00:00:00 2001 From: yu1120 <65484599+myaokai@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:18:19 +0900 Subject: [PATCH 103/117] Update test_inversion.py --- tests/recon/torch/task/test_inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py index 9a532cbc..05503b0e 100644 --- a/tests/recon/torch/task/test_inversion.py +++ b/tests/recon/torch/task/test_inversion.py @@ -172,8 +172,8 @@ def setUp(self): self.inversion_task = inversion_module.FeatureInversionTask( encoder=self.encoder, - generator=copy.deepcopy(self.generator), - latent=copy.deepcopy(self.latent), + generator=self.generator, + latent=self.latent, critic=self.critic, optimizer=self.optimizer, callbacks=self.callbacks From 3f693810cae7466e6c791f666672b4c19e6fc83f Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:23:24 +0900 Subject: [PATCH 104/117] Update tests/dl/torch/domain/test_core.py Co-authored-by: Yoshihiro Nagano --- tests/dl/torch/domain/test_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py index bb00e2ea..82b8d1ec 100644 --- a/tests/dl/torch/domain/test_core.py +++ b/tests/dl/torch/domain/test_core.py @@ -1,7 +1,6 @@ """Tests for bdpy.dl.torch.domain.core.""" import unittest -import torch from bdpy.dl.torch.domain import core as core_module From 291807fde6d887cdec8afd6ffc20732748b6aa43 Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:23:36 +0900 Subject: [PATCH 105/117] Update tests/dl/torch/domain/test_core.py Co-authored-by: Yoshihiro Nagano --- tests/dl/torch/domain/test_core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py index 82b8d1ec..ca8687d4 100644 --- a/tests/dl/torch/domain/test_core.py +++ b/tests/dl/torch/domain/test_core.py @@ -11,6 +11,7 @@ def send(self, num): def receive(self, num): return num - 1 + class DummyDoubleDomain(core_module.Domain): def send(self, num): return num * 2 From c87d8bcacd9f641b87827573d71acc2fff6bcce5 Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:23:43 +0900 Subject: [PATCH 106/117] Update tests/dl/torch/domain/test_image_domain.py Co-authored-by: Yoshihiro Nagano --- tests/dl/torch/domain/test_image_domain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py index c52bfa4f..81682040 100644 --- a/tests/dl/torch/domain/test_image_domain.py +++ b/tests/dl/torch/domain/test_image_domain.py @@ -33,7 +33,7 @@ def test_instantiation(self): affine_domain = iamge_domain_module.AffineDomain(self.center3d, self.scale3d) self.assertIsInstance(affine_domain, iamge_domain_module.AffineDomain) - # Failss when the center is neither 1-dimensional nor 3-dimensional + # Fails when the center is neither 1-dimensional nor 3-dimensional with self.assertRaises(ValueError): iamge_domain_module.AffineDomain(self.center2d, self.scale0d) From 7f0c3235ab6fac5517cfcca29020933edac672c0 Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:23:49 +0900 Subject: [PATCH 107/117] Update tests/dl/torch/domain/test_image_domain.py Co-authored-by: Yoshihiro Nagano --- tests/dl/torch/domain/test_image_domain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py index 81682040..e8d79cc5 100644 --- a/tests/dl/torch/domain/test_image_domain.py +++ b/tests/dl/torch/domain/test_image_domain.py @@ -37,7 +37,7 @@ def test_instantiation(self): with self.assertRaises(ValueError): iamge_domain_module.AffineDomain(self.center2d, self.scale0d) - # Failss when the scale is neither 1-dimensional nor 3-dimensional + # Fails when the scale is neither 1-dimensional nor 3-dimensional with self.assertRaises(ValueError): iamge_domain_module.AffineDomain(self.center0d, self.scale2d) From 13da19fab63e15ddc550ef46a76daca8f0a4b5dd Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:24:21 +0900 Subject: [PATCH 108/117] Update tests/dl/torch/domain/test_core.py Co-authored-by: Yoshihiro Nagano --- tests/dl/torch/domain/test_core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py index ca8687d4..92442ffe 100644 --- a/tests/dl/torch/domain/test_core.py +++ b/tests/dl/torch/domain/test_core.py @@ -18,7 +18,8 @@ def send(self, num): def receive(self, num): return num // 2 - + + class DummyUpperCaseDomain(core_module.Domain): def send(self, text): return text.upper() From 46d4d10fa58aa60e1a9b31c902cefc3061c83715 Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:24:27 +0900 Subject: [PATCH 109/117] Update tests/recon/torch/modules/test_latent.py Co-authored-by: Yoshihiro Nagano --- tests/recon/torch/modules/test_latent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index 5910ad03..a668114e 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -1,6 +1,5 @@ import torch import unittest -from abc import ABC, abstractmethod from typing import Iterator import torch.nn as nn from functools import partial From 75a45ab795a6e6bde9f38e0cef760cd91588d4ff Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:24:35 +0900 Subject: [PATCH 110/117] Update tests/recon/torch/task/test_inversion.py Co-authored-by: Yoshihiro Nagano --- tests/recon/torch/task/test_inversion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py index 05503b0e..cf04b56f 100644 --- a/tests/recon/torch/task/test_inversion.py +++ b/tests/recon/torch/task/test_inversion.py @@ -16,7 +16,6 @@ from bdpy.recon.torch.modules import generator as generator_module from bdpy.recon.torch.modules import latent as latent_module from bdpy.recon.torch.modules import critic as critic_module -from IPython import embed class DummyFeatureInversionCallback(inversion_module.FeatureInversionCallback): From bbae3d7b835ae12534852fb18b79977203c5325f Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:53:51 +0900 Subject: [PATCH 111/117] Update test_core.py --- tests/dl/torch/domain/test_core.py | 58 ++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/tests/dl/torch/domain/test_core.py b/tests/dl/torch/domain/test_core.py index 92442ffe..092c3487 100644 --- a/tests/dl/torch/domain/test_core.py +++ b/tests/dl/torch/domain/test_core.py @@ -26,11 +26,12 @@ def send(self, text): def receive(self, value): return value.lower() - + + class TestDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.core.Domain.""" def setUp(self): - self.domian = DummyAddDomain() + self.domain = DummyAddDomain() self.original_space_num = 0 self.internal_space_num = 1 @@ -40,44 +41,59 @@ def test_instantiation(self): def test_send(self): """test send""" - self.assertEqual(self.domian.send(self.original_space_num), self.internal_space_num) + self.assertEqual(self.domain.send(self.original_space_num), self.internal_space_num) def test_receive(self): """test receive""" - self.assertEqual(self.domian.receive(self.internal_space_num), self.original_space_num) + self.assertEqual(self.domain.receive(self.internal_space_num), self.original_space_num) + + def test_invertibility(self): + input_candidates = [-1, 0, 1, 0.5] + for x in input_candidates: + assert x == self.domain.send(self.domain.receive(x)) + assert x == self.domain.receive(self.domain.send(x)) + class TestInternalDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.core.InternalDomain.""" def setUp(self): - self.domian = core_module.InternalDomain() + self.domain = core_module.InternalDomain() self.num = 1 def test_send(self): """test send""" - self.assertEqual(self.domian.send(self.num), self.num) + self.assertEqual(self.domain.send(self.num), self.num) def test_receive(self): """test receive""" - self.assertEqual(self.domian.receive(self.num), self.num) + self.assertEqual(self.domain.receive(self.num), self.num) + + def test_invertibility(self): + input_candidates = [-1, 0, 1, 0.5] + for x in input_candidates: + assert x == self.domain.send(self.domain.receive(x)) + assert x == self.domain.receive(self.domain.send(x)) + class TestIrreversibleDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.core.IrreversibleDomain.""" def setUp(self): - self.domian = core_module.IrreversibleDomain() + self.domain = core_module.IrreversibleDomain() self.num = 1 def test_send(self): """test send""" - self.assertEqual(self.domian.send(self.num), self.num) + self.assertEqual(self.domain.send(self.num), self.num) def test_receive(self): """test receive""" - self.assertEqual(self.domian.receive(self.num), self.num) + self.assertEqual(self.domain.receive(self.num), self.num) + class TestComposedDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.core.ComposedDomain.""" def setUp(self): - self.composed_domian = core_module.ComposedDomain([ + self.composed_domain = core_module.ComposedDomain([ DummyDoubleDomain(), DummyAddDomain(), ]) @@ -86,16 +102,17 @@ def setUp(self): def test_send(self): """test send""" - self.assertEqual(self.composed_domian.send(self.original_space_num), self.internal_space_num) + self.assertEqual(self.composed_domain.send(self.original_space_num), self.internal_space_num) def test_receive(self): """test receive""" - self.assertEqual(self.composed_domian.receive(self.internal_space_num), self.original_space_num) + self.assertEqual(self.composed_domain.receive(self.internal_space_num), self.original_space_num) + class TestKeyValueDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.core.KeyValueDomain.""" def setUp(self): - self.key_value_domian = core_module.KeyValueDomain({ + self.key_value_domain = core_module.KeyValueDomain({ "name": DummyUpperCaseDomain(), "age": DummyDoubleDomain() }) @@ -104,13 +121,18 @@ def setUp(self): def test_send(self): """test send""" - self.assertEqual(self.key_value_domian.send(self.original_space_data), self.internal_space_data) + self.assertEqual(self.key_value_domain.send(self.original_space_data), self.internal_space_data) def test_receive(self): """test receive""" - self.assertEqual(self.key_value_domian.receive(self.internal_space_data), self.original_space_data) - + self.assertEqual(self.key_value_domain.receive(self.internal_space_data), self.original_space_data) if __name__ == "__main__": - unittest.main() \ No newline at end of file + #unittest.main() + composed_domain = core_module.ComposedDomain([ + DummyDoubleDomain(), + DummyAddDomain(), + ]) + print(composed_domain.receive(-1)) + print(composed_domain.send(-2)) From 556e6f0af8c803297e7fa2d2fafc962f6dee9fec Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:03:33 +0900 Subject: [PATCH 112/117] fix --- tests/dl/torch/test_dataset.py | 27 ------------------------ tests/recon/torch/modules/test_latent.py | 11 +++++----- tests/recon/torch/task/test_inversion.py | 6 +++++- 3 files changed, 11 insertions(+), 33 deletions(-) delete mode 100644 tests/dl/torch/test_dataset.py diff --git a/tests/dl/torch/test_dataset.py b/tests/dl/torch/test_dataset.py deleted file mode 100644 index 56656a4c..00000000 --- a/tests/dl/torch/test_dataset.py +++ /dev/null @@ -1,27 +0,0 @@ -import unittest - -import torch -import torch.nn as nn - -from bdpy.dl.torch import models - - -class TestFeatureDataset(unittest.TestCase): - def setUp(self): - #self.dataset = - pass - pass - -class TestDecodedFeatureDataset(unittest.TestCase): - pass - -class TestImageDataset(unittest.TestCase): - pass - -class TestRenameFeatureKeys(unittest.TestCase): - pass - - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/tests/recon/torch/modules/test_latent.py b/tests/recon/torch/modules/test_latent.py index a668114e..6e24d50a 100644 --- a/tests/recon/torch/modules/test_latent.py +++ b/tests/recon/torch/modules/test_latent.py @@ -19,7 +19,8 @@ def parameters(self, recurse): def generate(self): return self.latent - + + class TestBaseLatent(unittest.TestCase): """Tests for bdpy.recon.torch.modules.latent.BaseLatent.""" def setUp(self): @@ -47,6 +48,7 @@ def test_reset_states(self): latent.reset_states() self.assertTrue(torch.equal(latent(), self.latent_reset_value_expected)) + class DummyNNModuleLatent(latent_module.NNModuleLatent): def __init__(self): super().__init__() @@ -59,6 +61,7 @@ def reset_states(self): def generate(self): return self.latent + class TestNNModuleLatent(unittest.TestCase): """Tests for bdpy.recon.torch.modules.latent.NNModuleLatent.""" def setUp(self): @@ -86,14 +89,11 @@ def test_reset_states(self): latent.reset_states() self.assertTrue(torch.equal(latent(), self.latent_reset_value_expected)) -class DummyArbitraryLatent(latent_module.ArbitraryLatent): - def parameters(self, recurse): - return iter(self._latent) class TestArbitraryLatent(unittest.TestCase): """Tests for bdpy.recon.torch.modules.latent.ArbitraryLatent.""" def setUp(self): - self.latent = DummyArbitraryLatent((1, 3, 64, 64), partial(nn.init.normal_, mean=0, std=1)) + self.latent = latent_module.ArbitraryLatent((1, 3, 64, 64), partial(nn.init.normal_, mean=0, std=1)) self.latent_shape_expected = (1, 3, 64, 64) self.latent_value_expected = nn.Parameter(torch.tensor([0.0, 1.0, 2.0])) self.latent_reset_value_expected = nn.Parameter(torch.tensor([0.0, 0.0, 0.0])) @@ -119,5 +119,6 @@ def test_reset_states(self): self.assertAlmostEqual(mean, 0, places=1) self.assertAlmostEqual(std, 1, places=1) + if __name__ == '__main__': unittest.main() diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py index cf04b56f..dc548cb9 100644 --- a/tests/recon/torch/task/test_inversion.py +++ b/tests/recon/torch/task/test_inversion.py @@ -140,6 +140,7 @@ def generate(self, latent): def reset_states(self) -> None: self.fc.apply(generator_module.call_reset_parameters) + class DummyNNModuleLatent(latent_module.NNModuleLatent): def __init__(self, base_latent): super().__init__() @@ -151,7 +152,8 @@ def reset_states(self): def generate(self): return self.latent - + + class TestFeatureInversionTask(unittest.TestCase): """Tests for bdpy.recon.torch.task.inversion.FeatureInversionTask""" def setUp(self): @@ -185,6 +187,7 @@ def test_call(self, mock_print): self.assertTrue(len(self.inversion_task._callback_handler._callbacks) > 0) # test for process + assert isinstance(generated_image, torch.Tensor) self.assertEqual(generated_image.shape, (1, 7 * 7 * 3)) self.assertIsNotNone(self.inversion_task._generator._generator_network.fc.weight.grad) self.assertFalse(torch.equal(self.inversion_task._latent.latent, self.init_latent)) @@ -213,5 +216,6 @@ def test_reset_state(self): self.assertFalse(torch.equal(p1, p2)) self.assertFalse(torch.equal(self.inversion_task._latent.latent, latent_copy.latent)) + if __name__ == "__main__": unittest.main() \ No newline at end of file From 7ddd4fbcb4a37e650f3e96a8406d76ab94a67bc7 Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:05:24 +0900 Subject: [PATCH 113/117] fix --- tests/dl/torch/domain/test_feature_domain.py | 3 +++ tests/dl/torch/domain/test_image_domain.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/tests/dl/torch/domain/test_feature_domain.py b/tests/dl/torch/domain/test_feature_domain.py index 72d7ba92..970950b3 100644 --- a/tests/dl/torch/domain/test_feature_domain.py +++ b/tests/dl/torch/domain/test_feature_domain.py @@ -4,6 +4,7 @@ import torch from bdpy.dl.torch.domain import feature_domain as feature_domain_module + class TestMethods(unittest.TestCase): def setUp(self): self.lnd_tensor = torch.empty((12, 196, 768)) @@ -17,6 +18,7 @@ def test_nld2lnd(self): """test _nld2lnd""" self.assertEqual(feature_domain_module._nld2lnd(self.nld_tensor).shape, self.lnd_tensor.shape) + class TestArbitraryFeatureKeyDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.feature_domain.ArbitraryFeatureKeyDomain.""" def setUp(self): @@ -79,5 +81,6 @@ def test_receive(self): ) self.assertEqual(domain.receive(self.internal_features), self.features) + if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/dl/torch/domain/test_image_domain.py b/tests/dl/torch/domain/test_image_domain.py index e8d79cc5..6140ab26 100644 --- a/tests/dl/torch/domain/test_image_domain.py +++ b/tests/dl/torch/domain/test_image_domain.py @@ -6,6 +6,7 @@ import warnings from bdpy.dl.torch.domain import image_domain as iamge_domain_module + class TestAffineDomain(unittest.TestCase): """Tests for bdpy.dl.torch.domain.image_domain.AffineDomain""" def setUp(self): @@ -76,6 +77,7 @@ def test_send_and_receive(self): expected_received_image = expected_transformed_image * scale3d - center3d torch.testing.assert_close(received_image, expected_received_image) + class TestRGBDomain(unittest.TestCase): """Tests fot bdpy.dl.torch.domain.image_domain.BGRDomain""" @@ -95,6 +97,7 @@ def test_receive(self): received_image = bgr_domain.receive(self.rgb_image) torch.testing.assert_close(received_image, self.bgr_image) + class TestPILDomainWithExplicitCrop(unittest.TestCase): """Tests fot bdpy.dl.torch.domain.image_domain.PILDomainWithExplicitCrop""" def setUp(self): @@ -115,6 +118,7 @@ def test_receive(self): self.assertTrue(any(isinstance(warn.message, RuntimeWarning) for warn in w)) torch.testing.assert_close(received_image, self.image) + class TestFixedResolutionDomain(unittest.TestCase): """Tests fot bdpy.dl.torch.domain.image_domain.FixedResolutionDomain""" def setUp(self): @@ -134,5 +138,6 @@ def test_receive(self): received_image = fr_domain.receive(self.image) self.assertEqual(received_image.size(), self.expected_received_image_size) + if __name__ == "__main__": unittest.main() \ No newline at end of file From d86e526b20b11218a93daf56023cb5cc1ff04c2e Mon Sep 17 00:00:00 2001 From: myaokai <65484599+myaokai@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:07:47 +0900 Subject: [PATCH 114/117] fix --- tests/recon/torch/task/test_inversion.py | 2 ++ tests/task/test_core.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/recon/torch/task/test_inversion.py b/tests/recon/torch/task/test_inversion.py index dc548cb9..080fe930 100644 --- a/tests/recon/torch/task/test_inversion.py +++ b/tests/recon/torch/task/test_inversion.py @@ -114,6 +114,7 @@ def test_on_iteration_end(self, mock_print): self.callback.on_iteration_end(step=0) mock_print.assert_called_once_with("Step: [1], Loss: -1.0000") + class MLP(nn.Module): """A simple MLP.""" @@ -129,6 +130,7 @@ def forward(self, x): x = self.fc2(x) return x + class LinearGenerator(generator_module.NNModuleGenerator): def __init__(self): super().__init__() diff --git a/tests/task/test_core.py b/tests/task/test_core.py index 98b92527..0ba74d4e 100644 --- a/tests/task/test_core.py +++ b/tests/task/test_core.py @@ -3,9 +3,9 @@ from __future__ import annotations import unittest - from bdpy.task import core as core_module + class MockCallback(core_module.BaseCallback): """Mock callback for testing.""" def __init__(self): @@ -14,12 +14,14 @@ def __init__(self): def on_some_event(self, input_): self._storage.append(input_) + class MockTask(core_module.BaseTask[MockCallback]): """Mock task for testing BaseTask.""" def __call__(self, *inputs, **parameters): self._callback_handler.fire("on_some_event", input_=1) return inputs, parameters + class TestBaseTask(unittest.TestCase): """Tests forbdpy.task.core.BaseTask """ def setUp(self): @@ -56,5 +58,6 @@ def test_call(self): self.assertEqual(task_parameters["name"], self.task_name) self.assertEqual(mock_callback._storage, [1]) + if __name__ == "__main__": unittest.main() \ No newline at end of file From f87016c3a8366100a2dc3fdbbd5f9066969c80df Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 26 Jul 2024 17:56:15 +0900 Subject: [PATCH 115/117] change the name of the abstractmethod of a task to run() --- bdpy/task/core.py | 6 +++++- tests/task/test_core.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/bdpy/task/core.py b/bdpy/task/core.py index ad6b5b9d..d345a43e 100644 --- a/bdpy/task/core.py +++ b/bdpy/task/core.py @@ -40,8 +40,12 @@ def __init__( ) -> None: self._callback_handler = CallbackHandler(callbacks) - @abstractmethod def __call__(self, *inputs, **parameters) -> Any: + """Run the task.""" + return self.run(*inputs, **parameters) + + @abstractmethod + def run(self, *inputs, **parameters) -> Any: """Run the task.""" pass diff --git a/tests/task/test_core.py b/tests/task/test_core.py index 0ba74d4e..32fba222 100644 --- a/tests/task/test_core.py +++ b/tests/task/test_core.py @@ -10,14 +10,14 @@ class MockCallback(core_module.BaseCallback): """Mock callback for testing.""" def __init__(self): self._storage = [] - + def on_some_event(self, input_): self._storage.append(input_) class MockTask(core_module.BaseTask[MockCallback]): """Mock task for testing BaseTask.""" - def __call__(self, *inputs, **parameters): + def run(self, *inputs, **parameters): self._callback_handler.fire("on_some_event", input_=1) return inputs, parameters From 614601233e4f5f3df5646d4f4af9105bc8fbd61d Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 26 Jul 2024 18:00:05 +0900 Subject: [PATCH 116/117] automatically run reset_states() when the task is fired --- bdpy/recon/torch/task/inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bdpy/recon/torch/task/inversion.py b/bdpy/recon/torch/task/inversion.py index 3a873290..8a3b40b8 100644 --- a/bdpy/recon/torch/task/inversion.py +++ b/bdpy/recon/torch/task/inversion.py @@ -140,7 +140,6 @@ class FeatureInversionTask(BaseTask): ... encoder, generator, latent, critic, optimizer, num_iterations=200, ... ) >>> target_features = encoder(target_image) - >>> task.reset_states() >>> reconstructed_image = task(target_features) """ @@ -167,7 +166,7 @@ def __init__( self._num_iterations = num_iterations - def __call__( + def run( self, target_features: FeatureType, ) -> torch.Tensor: @@ -184,6 +183,7 @@ def __call__( Reconstructed images on the libraries internal domain. """ self._callback_handler.fire("on_task_start") + self.reset_states() for step in range(self._num_iterations): self._callback_handler.fire("on_iteration_start", step=step) self._optimizer.zero_grad() From 53e1403cb269ead216ae9c2521f018ca596765c2 Mon Sep 17 00:00:00 2001 From: Yoshihiro Nagano Date: Fri, 26 Jul 2024 18:10:20 +0900 Subject: [PATCH 117/117] critic.compare() -> critic.evaluate() --- bdpy/recon/torch/modules/critic.py | 8 ++++---- tests/recon/torch/modules/test_critic.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bdpy/recon/torch/modules/critic.py b/bdpy/recon/torch/modules/critic.py index bbf31157..b43d5ba4 100644 --- a/bdpy/recon/torch/modules/critic.py +++ b/bdpy/recon/torch/modules/critic.py @@ -33,10 +33,10 @@ def __call__(self, features: _FeatureType, target_features: _FeatureType) -> tor torch.Tensor Loss value. """ - return self.compare(features, target_features) + return self.evaluate(features, target_features) @abstractmethod - def compare( + def evaluate( self, features: _FeatureType, target_features: _FeatureType, @@ -68,13 +68,13 @@ def __call__(self, features: _FeatureType, target_features: _FeatureType) -> tor return nn.Module.__call__(self, features, target_features) def forward(self, features: _FeatureType, target_features: _FeatureType) -> torch.Tensor: - return self.compare(features, target_features) + return self.evaluate(features, target_features) class LayerWiseAverageCritic(NNModuleCritic): """Compute the average of the layer-wise loss values.""" - def compare( + def evaluate( self, features: _FeatureType, target_features: _FeatureType, diff --git a/tests/recon/torch/modules/test_critic.py b/tests/recon/torch/modules/test_critic.py index c4600f56..191b27f8 100644 --- a/tests/recon/torch/modules/test_critic.py +++ b/tests/recon/torch/modules/test_critic.py @@ -28,7 +28,7 @@ def test_instantiation(self): def test_call(self): """Test __call__.""" class ReturnZeroCritic(critic_module.BaseCritic): - def compare(self, features, target_features): + def evaluate(self, features, target_features): return 0.0 critic = ReturnZeroCritic() @@ -37,7 +37,7 @@ def compare(self, features, target_features): def test_loss_computation(self): """Test loss computation.""" class SumCritic(critic_module.BaseCritic): - def compare(self, features, target_features): + def evaluate(self, features, target_features): loss = 0.0 for feature, target_feature in zip(features.values(), target_features.values()): loss += torch.sum(torch.abs(feature - target_feature)) @@ -76,7 +76,7 @@ def test_instantiation(self): def test_call(self): """Test __call__.""" class ReturnZeroCritic(critic_module.NNModuleCritic): - def compare(self, features, target_features): + def evaluate(self, features, target_features): return 0.0 critic = ReturnZeroCritic() @@ -85,7 +85,7 @@ def compare(self, features, target_features): def test_loss_computation(self): """Test loss computation.""" class SumCritic(critic_module.NNModuleCritic): - def compare(self, features, target_features): + def evaluate(self, features, target_features): loss = 0.0 for feature, target_feature in zip(features.values(), target_features.values()): loss += torch.sum(torch.abs(feature - target_feature))