Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

polar #129

Closed
wants to merge 32 commits into from
Closed

polar #129

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d0330d3
base complexpyr working with torch1.8- still needs more test checks a…
Mar 18, 2021
7b8e72c
merge master
nikparth Mar 18, 2021
b851eb8
require torch1.8 for ffts
nikparth Mar 18, 2021
92e3f11
fix steerable tests
nikparth Mar 22, 2021
e0e4af7
merge master (new changes to remove pooling)
nikparth Mar 22, 2021
5ec176c
change animshow and imshow for new complex tensors
billbrod Mar 23, 2021
38bc78b
removes convert_pyrshow, adds pyrshow
billbrod Mar 23, 2021
4a698e4
adds pyrshow tests
billbrod Mar 23, 2021
fc6aba1
replaces pt.pyrshow with po.pyrshow
billbrod Mar 23, 2021
4452215
merge master (test speedups)
nikparth Mar 25, 2021
88ab60f
Merge branch 'complex_pyramid' of https://github.com/LabForComputatio…
nikparth Mar 25, 2021
1e240af
fix recon tests, fix some of steerpyr notebook
nikparth Mar 26, 2021
7a3b8d3
updates Display notebook
billbrod Mar 30, 2021
4be2fd0
change convert_pyr_to_tensor function, minor changes to notebook
nikparth Mar 30, 2021
7422ac3
fixes failing mad and metamer tests
billbrod Apr 1, 2021
5bd79c6
Merge branch 'master' of https://github.com/LabForComputationalVision…
pehf Apr 23, 2021
1e2e95b
Revert "Merge branch 'master' of https://github.com/LabForComputation…
pehf Apr 23, 2021
2219c24
Merge branch 'complex_pyramid' of https://github.com/LabForComputatio…
pehf May 25, 2021
5021b78
support complex representations in non-linearities, and clean-up of s…
pehf May 26, 2021
c7b4004
first pass implementation of factorized pyramid
pehf May 26, 2021
e12ec92
tests nonlinearities and factorized pyramid
pehf May 26, 2021
cfae313
remove some duplicates
pehf May 26, 2021
ae0689c
generalize and test factorized_pyr
pehf May 26, 2021
ed3a767
complex pyr update comments and misc fixes
nikparth May 27, 2021
aaab1a8
Merge branch 'master' of https://github.com/LabForComputationalVision…
pehf May 27, 2021
ec769e3
Fix merge conflicts
nikparth May 27, 2021
1a559dd
cosmetics
pehf May 27, 2021
9c2839b
Merge branch 'complex_pyramid' of https://github.com/LabForComputatio…
pehf May 27, 2021
ba6a399
Merge branch 'docs_fix' of https://github.com/LabForComputationalVisi…
pehf May 27, 2021
e6b9ac9
Merge branch 'main' of https://github.com/LabForComputationalVision/p…
pehf May 28, 2021
42f6262
Merge branch 'main' of https://github.com/LabForComputationalVision/p…
pehf Aug 25, 2021
07b2751
Merge branch 'main' into polar, modify factorized pyr
pehf Aug 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions examples/01_Linear_approximation_of_nonlinear_model.ipynb

Large diffs are not rendered by default.

272 changes: 136 additions & 136 deletions examples/09_Original_MAD.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion plenoptic/metric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .perceptual_distance import ssim, ms_ssim, nlpd, nspd, ssim_map
from .perceptual_distance import ssim, nlpd, nspd, ssim_map
from .model_metric import model_metric
from .naive import mse
from .classes import NLP
131 changes: 21 additions & 110 deletions plenoptic/metric/perceptual_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,21 @@ def _ssim_parts(img1, img2, dynamic_range):
C1 = (0.01 * dynamic_range) ** 2
C2 = (0.03 * dynamic_range) ** 2

# SSIM is the product of a luminance component, a contrast component, and a
# structure component. The contrast-structure component has to be separated
# when computing MS-SSIM.
luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
contrast_structure_map = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
map_ssim = luminance_map * contrast_structure_map
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2

# SSIM consists of a luminance component, a contrast component, and a
# structure component. This is the contrast component, which is used to
# compute MS-SSIM This is the contrast component, which is used to compute
# MS-SSIM.
contrast_map = v1 / v2

# the weight used for stability
weight = torch.log((1 + sigma1_sq/C2) * (1 + sigma2_sq/C2))
return map_ssim, contrast_structure_map, weight
weight = torch.log(torch.matmul((1+(sigma1_sq/C2)), (1+(sigma2_sq/C2))))

ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * contrast_map
return ssim_map, contrast_map, weight


def ssim(img1, img2, weighted=False, dynamic_range=1):
Expand All @@ -108,7 +113,7 @@ def ssim(img1, img2, weighted=False, dynamic_range=1):
As described in [1]_, the structural similarity index (SSIM) is a
perceptual distance metric, giving the distance between two images. SSIM is
based on three comparison measurements between the two images: luminance,
contrast, and structure. All of these are computed convolutionally across the
contrast, and structure. All of these are computed in windows across the
images. See the references for more information.

This implementation follows the original implementation, as found at [2]_,
Expand Down Expand Up @@ -146,8 +151,8 @@ def ssim(img1, img2, weighted=False, dynamic_range=1):
Returns
------
mssim : torch.Tensor
2d tensor of shape (batch, channel) containing the mean SSIM for each
image, averaged over the whole image
2d tensor containing the mean SSIM for each image, averaged over the
whole image

Notes
-----
Expand All @@ -164,7 +169,7 @@ def ssim(img1, img2, weighted=False, dynamic_range=1):
----------
.. [1] Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
quality assessment: From error measurement to structural similarity"
IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
.. [2] [matlab code](https://www.cns.nyu.edu/~lcv/ssim/ssim_index.m)
.. [3] [project page](https://www.cns.nyu.edu/~lcv/ssim/)
.. [4] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation (MAD)
Expand All @@ -181,10 +186,6 @@ def ssim(img1, img2, weighted=False, dynamic_range=1):
else:
mssim = (map_ssim*weight).sum((-1, -2)) / weight.sum((-1, -2))

if min(img1.shape[2], img1.shape[3]) < 11:
warnings.warn("SSIM uses 11x11 convolutional kernel, but the height and/or "
"the width of the input image is smaller than 11, so the "
"kernel size is set to be the minimum of these two numbers.")
return mssim


Expand All @@ -194,7 +195,7 @@ def ssim_map(img1, img2, dynamic_range=1):
As described in [1]_, the structural similarity index (SSIM) is a
perceptual distance metric, giving the distance between two images. SSIM is
based on three comparison measurements between the two images: luminance,
contrast, and structure. All of these are computed convolutionally across the
contrast, and structure. All of these are computed in windows across the
images. See the references for more information.

This implementation follows the original implementation, as found at [2]_,
Expand Down Expand Up @@ -237,7 +238,7 @@ def ssim_map(img1, img2, dynamic_range=1):
----------
.. [1] Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
quality assessment: From error measurement to structural similarity"
IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
.. [2] [matlab code](https://www.cns.nyu.edu/~lcv/ssim/ssim_index.m)
.. [3] [project page](https://www.cns.nyu.edu/~lcv/ssim/)
.. [4] Wang, Z., & Simoncelli, E. P. (2008). Maximum differentiation (MAD)
Expand All @@ -246,97 +247,9 @@ def ssim_map(img1, img2, dynamic_range=1):
http://dx.doi.org/10.1167/8.12.8

"""
if min(img1.shape[2], img1.shape[3]) < 11:
warnings.warn("SSIM uses 11x11 convolutional kernel, but the height and/or "
"the width of the input image is smaller than 11, so the "
"kernel size is set to be the minimum of these two numbers.")
return _ssim_parts(img1, img2, dynamic_range)[0]


def ms_ssim(img1, img2, dynamic_range=1, power_factors=None):
r"""Multiscale structural similarity index (MS-SSIM)

As described in [1]_, multiscale structural similarity index (MS-SSIM) is
an improvement upon structural similarity index (SSIM) that takes into
account the perceptual distance between two images on different scales.

SSIM is based on three comparison measurements between the two images:
luminance, contrast, and structure. All of these are computed convolutionally
across the images, producing three maps instead of scalars. The SSIM map is
the elementwise product of these three maps. See `metric.ssim` and
`metric.ssim_map` for a full description of SSIM.

To get images of different scales, average pooling operations with kernel
size 2 are performed recursively on the input images. The product of
contrast map and structure map (the "contrast-structure map") is computed
for all but the coarsest scales, and the overall SSIM map is only computed
for the coarsest scale. Their mean values are raised to exponents and
multiplied to produce MS-SSIM:
.. math::
MSSSIM = {SSIM}_M^{a_M} \prod_{i=1}^{M-1} ({CS}_i)^{a_i}
Here :math: `M` is the number of scales, :math: `{CS}_i` is the mean value
of the contrast-structure map for the i'th finest scale, and :math: `{SSIM}_M`
is the mean value of the SSIM map for the coarsest scale. If at least one
of these terms are negative, the value of MS-SSIM is zero. The values of
:math: `a_i, i=1,...,M` are taken from the argument `power_factors`.

Parameters
----------
img1 : torch.Tensor
4d tensor with first image to compare
img2 : torch.Tensor
4d tensor with second image to compare. Must have the same height and
width (last two dimensions) as `img1`
dynamic_range : int, optional.
dynamic range of the images. Note we assume that both images have the
same dynamic range. 1, the default, is appropriate for float images
between 0 and 1, as is common in synthesis. 2 is appropriate for float
images between -1 and 1, and 255 is appropriate for standard 8-bit
integer images. We'll raise a warning if it looks like your value is
not appropriate for `img1` or `img2`, but will calculate it anyway.
power_factors : 1D array, optional.
power exponents for the mean values of maps, for different scales (from
fine to coarse). The length of this array determines the number of scales.
By default, this is set to [0.0448, 0.2856, 0.3001, 0.2363, 0.1333],
which is what psychophysical experiments in [1]_ found.

Returns
------
msssim : torch.Tensor
2d tensor of shape (batch, channel) containing the MS-SSIM for each image

References
----------
.. [1] Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
structural similarity for image quality assessment." The Thrity-Seventh
Asilomar Conference on Signals, Systems & Computers, 2003. Vol. 2. IEEE, 2003.

"""
if power_factors is None:
power_factors = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]

def downsample(img):
img = F.pad(img, (0, img.shape[3] % 2, 0, img.shape[2] % 2), mode="replicate")
img = F.avg_pool2d(img, kernel_size=2)
return img

msssim = 1
for i in range(len(power_factors) - 1):
_, contrast_structure_map, _ = _ssim_parts(img1, img2, dynamic_range)
msssim *= F.relu(contrast_structure_map.mean((-1, -2))).pow(power_factors[i])
img1 = downsample(img1)
img2 = downsample(img2)
map_ssim, _, _ = _ssim_parts(img1, img2, dynamic_range)
msssim *= F.relu(map_ssim.mean((-1, -2))).pow(power_factors[-1])

if min(img1.shape[2], img1.shape[3]) < 11:
warnings.warn("SSIM uses 11x11 convolutional kernel, but for some scales "
"of the input image, the height and/or the width is smaller "
"than 11, so the kernel size in SSIM is set to be the "
"minimum of these two numbers for these scales.")
return msssim


def normalized_laplacian_pyramid(im):
"""computes the normalized Laplacian Pyramid using pre-optimized parameters

Expand All @@ -360,8 +273,7 @@ def normalized_laplacian_pyramid(im):
padd = 2
normalized_laplacian_activations = []
for N_b in range(0, N_scales):
filt = torch.tensor(spatialpooling_filters[N_b], dtype=torch.float32,
device=im.device).unsqueeze(0).unsqueeze(0)
filt = torch.tensor(spatialpooling_filters[N_b], dtype=torch.float32, device=im.device).unsqueeze(0).unsqueeze(0)
filtered_activations = F.conv2d(torch.abs(laplacian_activations[N_b]), filt, padding=padd, groups=channel)
normalized_laplacian_activations.append(laplacian_activations[N_b] / (sigmas[N_b] + filtered_activations))

Expand Down Expand Up @@ -402,8 +314,7 @@ def nlpd(IM_1, IM_2):

References
----------
.. [1] Laparra, V., Ballé, J., Berardino, A. and Simoncelli, E.P., 2016. Perceptual image quality
assessment using a normalized Laplacian pyramid. Electronic Imaging, 2016(16), pp.1-6.
.. [1] Laparra, V., Ballé, J., Berardino, A. and Simoncelli, E.P., 2016. Perceptual image quality assessment using a normalized Laplacian pyramid. Electronic Imaging, 2016(16), pp.1-6.
"""

y = normalized_laplacian_pyramid(torch.cat((IM_1, IM_2), 0))
Expand Down
33 changes: 10 additions & 23 deletions plenoptic/simulate/canonical_computations/non_linearities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from ...tools.signal import rectangular_to_polar, polar_to_rectangular


def rectangular_to_polar_dict(coeff_dict, dim=-1, residuals=False):
"""Return the complex modulus and the phase of each complex tensor
in a dictionary.
def rectangular_to_polar_dict(coeff_dict, residuals=True):
"""Wraps the rectangular to polar transform to act on all
the values in a dictionary.

Parameters
----------
coeff_dict : dict
A dictionary containing complex tensors.
dim : int, optional
The dimension that contains the real and imaginary components.
residuals: bool, optional
An option to carry around residuals in the energy branch.
Note that the transformation is not applied to the residuals,
that is dictionary elements with a key starting in "residual".

Returns
-------
Expand All @@ -27,24 +27,12 @@ def rectangular_to_polar_dict(coeff_dict, dim=-1, residuals=False):

Note
----
Since complex numbers are not supported by pytorch, we represent
complex tensors as having an extra dimension with two slices, where
one contains the real and the other contains the imaginary
components. E.g., ``1+2j`` would be represented as
``torch.tensor([1, 2])`` and ``[1+2j, 4+5j]`` would be
``torch.tensor([[1, 2], [4, 5]])``. In the cases represented here,
this "complex dimension" is the last one, and so the default
argument ``dim=-1`` would work.

Note that energy and state is not computed on the residuals.

Computing the state is local gain control in disguise, see
``rectangular_to_polar_real`` and ``local_gain_control``.
``local_gain_control`` and ``local_gain_control_dict``.
"""
energy = {}
state = {}
for key in coeff_dict.keys():
# ignore residuals
if isinstance(key, tuple) or not key.startswith('residual'):
energy[key], state[key] = rectangular_to_polar(coeff_dict[key])

Expand All @@ -56,7 +44,8 @@ def rectangular_to_polar_dict(coeff_dict, dim=-1, residuals=False):


def polar_to_rectangular_dict(energy, state, residuals=True):
"""Return the real and imaginary part tensor in a dictionary.
"""Wraps the polar to rectangular transform to act on all
the values in a matching pair of dictionaries.

Parameters
----------
Expand All @@ -65,10 +54,10 @@ def polar_to_rectangular_dict(energy, state, residuals=True):
modulus.
state : dict
The dictionary of torch.Tensors containing the local phase.
dim : int, optional
The dimension that contains the real and imaginary components.
residuals: bool, optional
An option to carry around residuals in the energy branch.
Note that the transformation is not applied to the residuals,
that is dictionary elements with a key starting in "residual".

Returns
-------
Expand All @@ -78,8 +67,6 @@ def polar_to_rectangular_dict(energy, state, residuals=True):

coeff_dict = {}
for key in energy.keys():
# ignore residuals

if isinstance(key, tuple) or not key.startswith('residual'):
coeff_dict[key] = polar_to_rectangular(energy[key], state[key])

Expand Down
1 change: 1 addition & 0 deletions plenoptic/simulate/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .frontend import *
from .naive import *
from .factorized_pyramid import *
89 changes: 89 additions & 0 deletions plenoptic/simulate/models/factorized_pyramid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch.nn as nn
from plenoptic.simulate.canonical_computations.non_linearities import (
local_gain_control, local_gain_control_dict, local_gain_release,
local_gain_release_dict, polar_to_rectangular_dict,
rectangular_to_polar_dict)
from plenoptic.simulate.canonical_computations.steerable_pyramid_freq import \
Steerable_Pyramid_Freq
from plenoptic.tools.signal import polar_to_rectangular, rectangular_to_polar


class Factorized_Pyramid(nn.Module):
"""
An non-linear transform which factorizes signal and is exactely invertible.

Loosely partitions things and stuff.

Analogous to Fourier amplitude and phase for a localized multiscale
and oriented transform.

Notes
-----
residuals are stored in amplitude

by default the not downsampled version also returns a tensor,
which allows easy further processing
eg. recursive Factorized Pyr
(analogous to the scattering transform)

TODO
----
flesh out the relationship btw real and complex cases

handle multi channel input
eg. from front end, or from recursive calls
hack: fold channels into batch dim and then back out

cross channel processing - thats next level
"""
def __init__(self, image_size, n_ori=4, n_scale='auto',
downsample_dict=True, is_complex=True):
super().__init__()

self.downsample_dict = downsample_dict
self.is_complex = is_complex

pyr = Steerable_Pyramid_Freq(image_size,
order=n_ori-1,
height=n_scale,
is_complex=is_complex,
downsample=downsample_dict)
self.n_ori = pyr.num_orientations
self.n_scale = pyr.num_scales

if downsample_dict:
self.pyramid_analysis = lambda x: pyr.forward(x)
self.pyramid_synthesis = lambda y: pyr.recon_pyr(y)
if is_complex:
self.decomposition = rectangular_to_polar_dict
self.recomposition = polar_to_rectangular_dict
else:
self.decomposition = local_gain_control_dict
self.recomposition = local_gain_release_dict
else:
def stash(y, info):
self.pyr_info = info
return y
self.pyramid_analysis = lambda x: stash(*pyr.convert_pyr_to_tensor(
pyr.forward(x)))
self.pyramid_synthesis = lambda y: pyr.recon_pyr(
pyr.convert_tensor_to_pyr(y, *self.pyr_info))
if is_complex:
self.decomposition = rectangular_to_polar
self.recomposition = polar_to_rectangular
else:
self.decomposition = local_gain_control
self.recomposition = local_gain_release

def analysis(self, x):
y = self.pyramid_analysis(x)
energy, state = self.decomposition(y)
return energy, state

def synthesis(self, energy, state):
y = self.recomposition(energy, state)
x = self.pyramid_synthesis(y)
return x

def forward(self, x):
return self.analysis(x)
1 change: 0 additions & 1 deletion plenoptic/tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def convert_float_to_int(im, dtype=np.uint8):
return (im * np.iinfo(dtype).max).astype(dtype)



def make_synthetic_stimuli(size=256, requires_grad=True):
r""" Make a set of basic stimuli, useful for developping and debugging models

Expand Down
Loading