-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #81 from KamitaniLab/feature-inversion-pipeline
Feature inversion pipeline for modular iCNN construction
- Loading branch information
Showing
33 changed files
with
3,475 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Iterable, Callable, Dict | ||
|
||
from pathlib import Path | ||
|
||
from PIL import Image | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
from bdpy.dataform import DecodedFeatures, Features | ||
|
||
|
||
_FeatureTypeNP = Dict[str, np.ndarray] | ||
|
||
|
||
def _removesuffix(s: str, suffix: str) -> str: | ||
"""Remove suffix from string. | ||
Note | ||
---- | ||
This function is available from Python 3.9 as `str.removesuffix`. We can | ||
remove this function when we drop support for Python 3.8. | ||
Parameters | ||
---------- | ||
s : str | ||
String. | ||
suffix : str | ||
Suffix to remove. | ||
Returns | ||
------- | ||
str | ||
String without suffix. | ||
""" | ||
if suffix and s.endswith(suffix): | ||
return s[: -len(suffix)] | ||
return s[:] | ||
|
||
|
||
class FeaturesDataset(Dataset): | ||
"""Dataset of features. | ||
Parameters | ||
---------- | ||
root_path : str | Path | ||
Path to the root directory of features. | ||
layer_path_names : Iterable[str] | ||
List of layer path names. Each layer path name is used to get features | ||
from the root directory so that the layer path name must be a part of | ||
the path to the layer. | ||
stimulus_names : list[str], optional | ||
List of stimulus names. If None, all stimulus names are used. | ||
transform : callable, optional | ||
Callable object which is used to transform features. The callable object | ||
must take a dict of features and return a dict of features. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
root_path: str | Path, | ||
layer_path_names: Iterable[str], | ||
stimulus_names: list[str] | None = None, | ||
transform: Callable[[_FeatureTypeNP], _FeatureTypeNP] | None = None, | ||
): | ||
self._features_store = Features(Path(root_path).as_posix()) | ||
self._layer_path_names = layer_path_names | ||
if stimulus_names is None: | ||
stimulus_names = self._features_store.labels | ||
self._stimulus_names = stimulus_names | ||
self._transform = transform | ||
|
||
def __len__(self) -> int: | ||
return len(self._stimulus_names) | ||
|
||
def __getitem__(self, index: int) -> _FeatureTypeNP: | ||
stimulus_name = self._stimulus_names[index] | ||
features = {} | ||
for layer_path_name in self._layer_path_names: | ||
feature = self._features_store.get( | ||
layer=layer_path_name, label=stimulus_name | ||
) | ||
feature = feature[0] # NOTE: remove batch axis | ||
features[layer_path_name] = feature | ||
if self._transform is not None: | ||
features = self._transform(features) | ||
return features | ||
|
||
|
||
class DecodedFeaturesDataset(Dataset): | ||
"""Dataset of decoded features. | ||
Parameters | ||
---------- | ||
root_path : str | Path | ||
Path to the root directory of decoded features. | ||
layer_path_names : Iterable[str] | ||
List of layer path names. Each layer path name is used to get features | ||
from the root directory so that the layer path name must be a part of | ||
the path to the layer. | ||
subject_id : str | ||
ID of the subject. | ||
roi : str | ||
ROI name. | ||
stimulus_names : list[str], optional | ||
List of stimulus names. If None, all stimulus names are used. | ||
transform : callable, optional | ||
Callable object which is used to transform features. The callable object | ||
must take a dict of features and return a dict of features. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
root_path: str | Path, | ||
layer_path_names: Iterable[str], | ||
subject_id: str, | ||
roi: str, | ||
stimulus_names: list[str] | None = None, | ||
transform: Callable[[_FeatureTypeNP], _FeatureTypeNP] | None = None, | ||
): | ||
self._decoded_features_store = DecodedFeatures(Path(root_path).as_posix()) | ||
self._layer_path_names = layer_path_names | ||
self._subject_id = subject_id | ||
self._roi = roi | ||
if stimulus_names is None: | ||
stimulus_names = self._decoded_features_store.labels | ||
assert stimulus_names is not None | ||
self._stimulus_names = stimulus_names | ||
self._transform = transform | ||
|
||
def __len__(self) -> int: | ||
return len(self._stimulus_names) | ||
|
||
def __getitem__(self, index: int) -> _FeatureTypeNP: | ||
stimulus_name = self._stimulus_names[index] | ||
decoded_features = {} | ||
for layer_path_name in self._layer_path_names: | ||
decoded_feature = self._decoded_features_store.get( | ||
layer=layer_path_name, | ||
label=stimulus_name, | ||
subject=self._subject_id, | ||
roi=self._roi, | ||
) | ||
decoded_feature = decoded_feature[0] # NOTE: remove batch axis | ||
decoded_features[layer_path_name] = decoded_feature | ||
if self._transform is not None: | ||
decoded_features = self._transform(decoded_features) | ||
return decoded_features | ||
|
||
|
||
class ImageDataset(Dataset): | ||
"""Dataset of images. | ||
Parameters | ||
---------- | ||
root_path : str | Path | ||
Path to the root directory of images. | ||
stimulus_names : list[str], optional | ||
List of stimulus names. If None, all stimulus names are used. | ||
extension : str, optional | ||
Extension of the image files. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
root_path: str | Path, | ||
stimulus_names: list[str] | None = None, | ||
extension: str = "jpg", | ||
): | ||
self.root_path = root_path | ||
if stimulus_names is None: | ||
stimulus_names = [ | ||
_removesuffix(path.name, "." + extension) | ||
for path in Path(root_path).glob(f"*{extension}") | ||
] | ||
self._stimulus_names = stimulus_names | ||
self._extension = extension | ||
|
||
def __len__(self): | ||
return len(self._stimulus_names) | ||
|
||
def __getitem__(self, index: int): | ||
stimulus_name = self._stimulus_names[index] | ||
image = Image.open(Path(self.root_path) / f"{stimulus_name}.{self._extension}") | ||
image = image.convert("RGB") | ||
return np.array(image) / 255.0, stimulus_name | ||
|
||
|
||
class RenameFeatureKeys: | ||
def __init__(self, mapping: dict[str, str]): | ||
self._mapping = mapping | ||
|
||
def __call__(self, features: _FeatureTypeNP) -> _FeatureTypeNP: | ||
return {self._mapping.get(key, key): value for key, value in features.items()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .core import Domain, InternalDomain, IrreversibleDomain, ComposedDomain, KeyValueDomain |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Iterable, TypeVar, Generic | ||
import warnings | ||
|
||
import torch.nn as nn | ||
|
||
_T = TypeVar("_T") | ||
|
||
|
||
class Domain(nn.Module, ABC, Generic[_T]): | ||
"""Base class for stimulus domain. | ||
This class is used to convert data between each domain and library's internal common space. | ||
Suppose that we have two functions `f: X -> Y_1` and `g: Y_2 -> Z` and want to compose them. | ||
Here, `X`, `Y_1`, `Y_2`, and `Z` are different domains and assume that `Y_1` and `Y_2` are | ||
the similar domain that can be converted to each other. | ||
Then, we can compose `f` and `g` as `g . t . f(x)`, where `t: Y_1 -> Y_2` is the domain | ||
conversion function. This class is used to implement `t`. | ||
The subclasses of this class should implement `send` and `receive` methods. The `send` method | ||
converts data from the original domain (`Y_1` or `Y_2`) to the internal common space (`Y_0`), | ||
and the `receive` method converts data from the internal common space to the original domain. | ||
By implementing domain class for `Y_1` and `Y_2`, we can construct the domain conversion function | ||
`t` as `t = Y_2.receive . Y_1.send`. | ||
Note that the subclasses of this class do not necessarily guarantee the reversibility of `send` | ||
and `receive` methods. If the domain conversion is irreversible, the subclasses should inherit | ||
`IrreversibleDomain` class instead of this class. | ||
""" | ||
|
||
@abstractmethod | ||
def send(self, x: _T) -> _T: | ||
"""Send stimulus to the internal common space from each domain. | ||
Parameters | ||
---------- | ||
x : _T | ||
Data in the original domain. | ||
Returns | ||
------- | ||
_T | ||
Data in the internal common space. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def receive(self, x: _T) -> _T: | ||
"""Receive data from the internal common space to each domain. | ||
Parameters | ||
---------- | ||
x : _T | ||
Data in the internal common space. | ||
Returns | ||
------- | ||
_T | ||
Data in the original domain. | ||
""" | ||
pass | ||
|
||
|
||
class InternalDomain(Domain, Generic[_T]): | ||
"""The internal common space. | ||
The domain class which defines the internal common space. This class | ||
receives and sends data as it is. | ||
""" | ||
|
||
def send(self, x: _T) -> _T: | ||
return x | ||
|
||
def receive(self, x: _T) -> _T: | ||
return x | ||
|
||
|
||
class IrreversibleDomain(Domain, Generic[_T]): | ||
"""The domain which cannot be reversed. | ||
This class is used to convert data between each domain and library's | ||
internal common space. Note that the subclasses of this class do not | ||
guarantee the reversibility of `send` and `receive` methods. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
warnings.warn( | ||
f"{self.__class__.__name__} is an irreversible domain. " \ | ||
"It does not guarantee the reversibility of `send` and `receive` " \ | ||
"methods. Please use the combination of `send` and `receive` methods " \ | ||
"with caution.", | ||
RuntimeWarning, | ||
) | ||
|
||
def send(self, x: _T) -> _T: | ||
return x | ||
|
||
def receive(self, x: _T) -> _T: | ||
return x | ||
|
||
|
||
class ComposedDomain(Domain, Generic[_T]): | ||
"""The domain composed of multiple sub-domains. | ||
Suppose we have list of domain objects `domains = [d_0, d_1, ..., d_n]`. | ||
Then, `ComposedDomain(domains)` accesses the data in the original domain `D` | ||
as `d_n.receive . ... d_1.receive . d_0.receive(x)` from the internal common space `D_0`. | ||
Parameters | ||
---------- | ||
domains : Iterable[Domain] | ||
Sub-domains to compose. | ||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> import torch | ||
>>> from bdpy.dl.torch.domain import ComposedDomain | ||
>>> from bdpy.dl.torch.domain.image_domain import AffineDomain, BGRDomain | ||
>>> composed_domain = ComposedDomain([ | ||
... AffineDomain(0.5, 1), | ||
... BGRDomain(), | ||
... ]) | ||
>>> image = torch.randn(1, 3, 64, 64).clamp(-0.5, 0.5) | ||
>>> image.shape | ||
torch.Size([1, 3, 64, 64]) | ||
>>> composed_domain.send(image).shape | ||
torch.Size([1, 3, 64, 64]) | ||
>>> print(composed_domain.send(image).min().item(), composed_domain.send(image).max().item()) | ||
0.0 1.0 | ||
""" | ||
|
||
def __init__(self, domains: Iterable[Domain]) -> None: | ||
super().__init__() | ||
self.domains = nn.ModuleList(domains) | ||
|
||
def send(self, x: _T) -> _T: | ||
for domain in reversed(self.domains): | ||
x = domain.send(x) | ||
return x | ||
|
||
def receive(self, x: _T) -> _T: | ||
for domain in self.domains: | ||
x = domain.receive(x) | ||
return x | ||
|
||
|
||
class KeyValueDomain(Domain, Generic[_T]): | ||
"""The domain which converts key-value pairs. | ||
This class is used to convert key-value pairs between each domain and library's | ||
internal common space. | ||
Parameters | ||
---------- | ||
domain_mapper : dict[str, Domain] | ||
Dictionary that maps keys to domains. | ||
""" | ||
|
||
def __init__(self, domain_mapper: dict[str, Domain]) -> None: | ||
super().__init__() | ||
self.domain_mapper = domain_mapper | ||
|
||
def send(self, x: dict[str, _T]) -> dict[str, _T]: | ||
return { | ||
key: self.domain_mapper[key].send(value) for key, value in x.items() | ||
} | ||
|
||
def receive(self, x: dict[str, _T]) -> dict[str, _T]: | ||
return { | ||
key: self.domain_mapper[key].receive(value) for key, value in x.items() | ||
} |
Oops, something went wrong.
3cd3c8d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report
3cd3c8d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report
3cd3c8d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report
3cd3c8d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Coverage Report