-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
531 additions
and
607 deletions.
There are no files selected for viewing
81 changes: 45 additions & 36 deletions
81
examples/01_Linear_approximation_of_nonlinear_model.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .perceptual_distance import ssim, ms_ssim, nlpd, nspd, ssim_map | ||
from .perceptual_distance import ssim, nlpd, nspd, ssim_map | ||
from .model_metric import model_metric | ||
from .naive import mse | ||
from .classes import NLP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .frontend import * | ||
from .naive import * | ||
from .factorized_pyramid import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import torch.nn as nn | ||
from plenoptic.simulate.canonical_computations.non_linearities import ( | ||
local_gain_control, local_gain_control_dict, local_gain_release, | ||
local_gain_release_dict, polar_to_rectangular_dict, | ||
rectangular_to_polar_dict) | ||
from plenoptic.simulate.canonical_computations.steerable_pyramid_freq import \ | ||
Steerable_Pyramid_Freq | ||
from plenoptic.tools.signal import polar_to_rectangular, rectangular_to_polar | ||
|
||
|
||
class Factorized_Pyramid(nn.Module): | ||
""" | ||
An non-linear transform which factorizes signal and is exactely invertible. | ||
Loosely partitions things and stuff. | ||
Analogous to Fourier amplitude and phase for a localized multiscale | ||
and oriented transform. | ||
Notes | ||
----- | ||
residuals are stored in amplitude | ||
by default the not downsampled version also returns a tensor, | ||
which allows easy further processing | ||
eg. recursive Factorized Pyr | ||
(analogous to the scattering transform) | ||
TODO | ||
---- | ||
flesh out the relationship btw real and complex cases | ||
handle multi channel input | ||
eg. from front end, or from recursive calls | ||
hack: fold channels into batch dim and then back out | ||
cross channel processing - thats next level | ||
""" | ||
def __init__(self, image_size, n_ori=4, n_scale='auto', | ||
downsample_dict=True, is_complex=True): | ||
super().__init__() | ||
|
||
self.downsample_dict = downsample_dict | ||
self.is_complex = is_complex | ||
|
||
pyr = Steerable_Pyramid_Freq(image_size, | ||
order=n_ori-1, | ||
height=n_scale, | ||
is_complex=is_complex, | ||
downsample=downsample_dict) | ||
self.n_ori = pyr.num_orientations | ||
self.n_scale = pyr.num_scales | ||
|
||
if downsample_dict: | ||
self.pyramid_analysis = lambda x: pyr.forward(x) | ||
self.pyramid_synthesis = lambda y: pyr.recon_pyr(y) | ||
if is_complex: | ||
self.decomposition = rectangular_to_polar_dict | ||
self.recomposition = polar_to_rectangular_dict | ||
else: | ||
self.decomposition = local_gain_control_dict | ||
self.recomposition = local_gain_release_dict | ||
else: | ||
def stash(y, info): | ||
self.pyr_info = info | ||
return y | ||
self.pyramid_analysis = lambda x: stash(*pyr.convert_pyr_to_tensor( | ||
pyr.forward(x))) | ||
self.pyramid_synthesis = lambda y: pyr.recon_pyr( | ||
pyr.convert_tensor_to_pyr(y, *self.pyr_info)) | ||
if is_complex: | ||
self.decomposition = rectangular_to_polar | ||
self.recomposition = polar_to_rectangular | ||
else: | ||
self.decomposition = local_gain_control | ||
self.recomposition = local_gain_release | ||
|
||
def analysis(self, x): | ||
y = self.pyramid_analysis(x) | ||
energy, state = self.decomposition(y) | ||
return energy, state | ||
|
||
def synthesis(self, energy, state): | ||
y = self.recomposition(energy, state) | ||
x = self.pyramid_synthesis(y) | ||
return x | ||
|
||
def forward(self, x): | ||
return self.analysis(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.