Pooled texture model #284

Draft — wants to merge 37 commits into base: main
Commits (37)
59b9e9e
initialize masked
billbrod Aug 9, 2024
1528fa5
adds mask to model
billbrod Aug 9, 2024
966d8ad
changes compute_pixel_stats to weighted version
billbrod Aug 9, 2024
09a4e43
create mask per scale, updates pixel stats
billbrod Aug 13, 2024
ded8204
adds masked autocorr computation
billbrod Aug 13, 2024
ad79bcf
bugfix: need extra scale for reconstructed image autocorr
billbrod Aug 13, 2024
0d61ded
updates comments describing shapes
billbrod Aug 13, 2024
e5851a2
makes compute_skew_kurtosis_recon work with masks
billbrod Aug 13, 2024
dff122d
update name: m -> scale_mask to be consistent
billbrod Aug 14, 2024
2c7cfc9
Merge branch 'portilla_simoncelli_vars' of github.com:LabForComputati…
billbrod Aug 14, 2024
885a17c
sets _compute_cross_correlation to new version
billbrod Aug 14, 2024
230ecf1
updates docstring
billbrod Aug 14, 2024
fa9135f
update masked autocorr function for dim rearrangement
billbrod Aug 14, 2024
1f33f73
gets masked version working!
billbrod Aug 14, 2024
153000a
fix portilla_simoncelli.update_plot for GPU
billbrod Aug 15, 2024
69e1796
add masked model back to init
billbrod Aug 15, 2024
c8f6285
changes to ps_masked plotting
billbrod Aug 15, 2024
44592d3
gets working with GPU, correct across-scale issue
billbrod Aug 15, 2024
5793aed
compute_pixel_stats now expects single-scale mask
billbrod Aug 15, 2024
62e5f6d
normalize skew/kurtosis recon with variance, which brings them to sim…
billbrod Aug 15, 2024
9063395
bugfix: get blur_downsample working with other dtypes
billbrod Aug 16, 2024
f6d75ae
improve how masks are handled across scales
billbrod Aug 16, 2024
e8cb3dc
Make division/sqrt more stable, remove redundant autocorrs
billbrod Aug 16, 2024
53cb7f0
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Aug 16, 2024
888dc8c
makes PS.plot_representation defaults more reasonable
billbrod Aug 16, 2024
6255bb4
simplify how we handle stability
billbrod Aug 19, 2024
dc3a700
adds example notebook
billbrod Aug 19, 2024
46f5a1e
adds nblink file
billbrod Aug 19, 2024
43c4737
Merge branch 'main' of github.com:plenoptic-org/plenoptic into pooled…
billbrod Dec 10, 2024
873e513
Merge branch 'main' of github.com:plenoptic-org/plenoptic into pooled…
billbrod Dec 11, 2024
e2fbd7e
runs ruff, pre commit on files
billbrod Dec 12, 2024
ed9915e
change compute_pixel_stats to compute actual stats
billbrod Dec 12, 2024
172b03e
adds pixel_epsilon, reruns notebook
billbrod Dec 13, 2024
bd83bd3
adds weighted average module
billbrod Dec 16, 2024
b6677b3
some improvements to weighted average
billbrod Dec 16, 2024
c6df6b5
PS now accepts weights argument
billbrod Dec 20, 2024
15d0032
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
3 changes: 3 additions & 0 deletions docs/tutorials/models/pooled_texture_model.nblink
@@ -0,0 +1,3 @@
{
"path": "../../../examples/pooled_texture_model.ipynb"
}
16,491 changes: 16,491 additions & 0 deletions examples/pooled_texture_model.ipynb

Large diffs are not rendered by default.

@@ -5,3 +5,4 @@
from .laplacian_pyramid import LaplacianPyramid
from .non_linearities import *
from .steerable_pyramid_freq import SteerablePyramidFreq
from .weighted_average import *
304 changes: 304 additions & 0 deletions src/plenoptic/simulate/canonical_computations/weighted_average.py
@@ -0,0 +1,304 @@
#!/usr/bin/env python3

import torch
from torch import Tensor
import einops
import warnings
from ...tools.conv import blur_downsample


class SimpleAverage(torch.nn.Module):
"""Module to average over the last two dimensions of input"""

def __init__(self):
super().__init__()

    def forward(self, x: Tensor) -> Tensor:
"""Average over the last two dimensions of input."""
return torch.mean(x, dim=(-2, -1))


class WeightedAverage(torch.nn.Module):
"""Module to take a weighted average over last two dimensions of tensors.

    - Weights are set at initialization; they must be non-negative 3d Tensors (different
      weighting regions indexed along the first dimension, height and width on the last
      two).

    - If two weight tensors are set, they are multiplied together when taking the average
      (e.g., separable polar angle and eccentricity weights).

    - Weights are normalized at initialization so that each weighting region sums to 1 (as
      closely as possible). If any weighting region sums to near-zero, an exception is
      raised. If there is variation across the weighting region sums, a warning is raised.

    Parameters
    ----------
    weights_1, weights_2 :
        3d Tensors defining the weights for the average.

    Attributes
    ----------
    weights :
        List of one or two 3d Tensors (depending on whether ``weights_2`` was set at
        initialization) containing the normalized weights.
    image_shape :
        Height and width of the weights tensors (their last two dimensions).

    """

def __init__(self, weights_1: Tensor, weights_2: Tensor | None = None):
super().__init__()
self._validate_weights(weights_1, "_1")
self.register_buffer("_weights_1", weights_1)
self._n_weights = 1
input_einsum = "w1 h w"
output_einsum = "w1"
if weights_2 is not None:
self._validate_weights(weights_2, "_2")
if weights_1.shape[-2:] != weights_2.shape[-2:]:
raise ValueError(
"weights_1 and weights_2 must have same height and width!"
)
self._n_weights += 1
input_einsum += ", w2 h w"
output_einsum += " w2"
self.register_buffer("_weights_2", weights_2)
self.image_shape = weights_1.shape[-2:]
self._input_einsum = input_einsum
self._output_einsum = output_einsum
self._weight_einsum = f"{input_einsum} -> {output_einsum}"
self._forward_einsum = f"{input_einsum}, b c {{extra_dims}} h w -> b c {output_einsum} {{extra_dims}}"
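        # extra (non-batch/channel, non-spatial) dims for 4d, 5d, and 6d inputs,
        # respectively; substituted into _forward_einsum in forward()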
self._extra_dims = ["", "i1", "i1 i2"]
self._normalize_weights()

@property
def weights(self):
weights = [self._weights_1]
if self._n_weights > 1:
weights.append(self._weights_2)
return weights

def _normalize_weights(self):
"""Normalize weights.

Call sum_weights() to multiply and sum all weights, then:

- Check whether any weighting region sum is near-zero. If so, raise ValueError

- Check variance of weighting region sums and raise warning if that variance is
not near-zero. (Ideally, all weighting region sums would be the same value.)

- Take the modal weighting region sum value and divide all weighting regions by
that value to normalize them. If we have two weight tensors, divide each by
the sqrt of the mode instead.

"""
weight_sums = self.sum_weights()
if torch.isclose(weight_sums, torch.zeros_like(weight_sums)).any():
raise ValueError(
"Some of the weights sum to zero! This will not work out well."
)
var = weight_sums.var()
if not torch.isclose(var, torch.zeros_like(var)):
warnings.warn(
"Looks like there's some variation in the sums across your weights."
" That might be fine, but just wanted to make sure you knew..."
)
mode = torch.mode(weight_sums.flatten()).values
if not torch.isclose(mode, torch.ones_like(mode)):
warnings.warn("Weights don't sum to 1, normalizing...")
if self._n_weights == 1:
self._weights_1 = self._weights_1 / mode
else:
self._weights_1 = self._weights_1 / mode.sqrt()
self._weights_2 = self._weights_2 / mode.sqrt()

@staticmethod
    def _validate_weights(weights: Tensor, idx: str = "_1"):
if weights.ndim != 3:
raise ValueError(f"weights{idx} must be 3d!")
if weights.min() < 0:
raise ValueError(f"weights{idx} must be non-negative!")

def forward(self, image: Tensor) -> Tensor:
"""Take the weighted average over last two dimensions of input.

All other dimensions are preserved.

Parameters
----------
image :
4d to 6d Tensor.

Returns
-------
weighted_avg :
Weighted average. Dimensionality depends on both the input's dimensionality
and ``len(weights)``

"""
if image.ndim < 4:
raise ValueError("image must be a tensor of 4 to 6 dimensions!")
try:
extra_dims = self._extra_dims[image.ndim - 4]
except IndexError:
raise ValueError("image must be a tensor of 4 to 6 dimensions!")
einsum_str = self._forward_einsum.format(extra_dims=extra_dims)
        # flatten the weight dimensions (w1, or w1 and w2) into a single dimension;
        # with a single weight tensor this flatten is a no-op
        return einops.einsum(*self.weights, image, einsum_str).flatten(
            2, 1 + self._n_weights
        )

def einsum(self, einsum_str: str, *tensors: Tensor) -> Tensor:
"""More general version of forward.

This takes the input einsum_str and prepends self.weights to it and inserts the
weight dimensions into the output after "b c" (for batch, channel). Thus this
will be weird if there's no "b c" dimensions.

Parameters
----------
einsum_str :
String of einsum notation, which must contain "b c" in the output. Intended
use is that this string produces a single output tensor.
tensors :
Any number of tensors

Returns
-------
output :
The result of this einsum

"""
einsum_str = f"{self._input_einsum}, {einsum_str.split('->')[0]} -> b c {self._output_einsum} {einsum_str.split('->')[1].replace('b c', '')}"
        # as in forward, flatten the weight dimensions into a single dimension
        return einops.einsum(*self.weights, *tensors, einsum_str).flatten(
            2, 1 + self._n_weights
        )

def sum_weights(self) -> Tensor:
"""Sum weights, largely used for diagnostic purposes.

Returns
-------
sum :
1d or 2d tensor (depending on ``len(weights)``) containing the sum of all
weights.

"""
return einops.einsum(*self.weights, self._weight_einsum)


class WeightedAveragePyramid(torch.nn.Module):
"""Module to take weighted average across scales.

    This initializes one ``WeightedAverage`` per scale, down-sampling the weights by a
    factor of 2 between scales using ``blur_downsample`` (and normalizing each scale's
    weights independently).

As with ``WeightedAverage``:

    - Weights are set at initialization; they must be non-negative 3d Tensors (different
      weighting regions indexed along the first dimension, height and width on the last
      two).

    - If two weight tensors are set, they are multiplied together when taking the average
      (e.g., separable polar angle and eccentricity weights).

    - Weights are normalized at initialization so that each weighting region sums to 1 (as
      closely as possible). If any weighting region sums to near-zero, an exception is
      raised. If there is variation across the weighting region sums, a warning is raised.

Parameters
----------
weights_1, weights_2 :
3d Tensors defining the weights for the average.
n_scales :
Number of scales.

Attributes
----------
    weights :
        ModuleList containing one ``WeightedAverage`` per scale.

"""

def __init__(
self, weights_1: Tensor, weights_2: Tensor | None = None, n_scales: int = 4
):
super().__init__()
self._n_weights = 1 if weights_2 is None else 2
self.n_scales = n_scales
weights = []
for i in range(n_scales):
            if i != 0:
                # blur and downsample the weights by a factor of 2 for this scale; this
                # can introduce small negative values, which we clip back to zero
                weights_1 = blur_downsample(
                    weights_1.unsqueeze(0), 1, scale_filter=True
                )
                weights_1 = weights_1.squeeze(0).clip(min=0)
                if weights_2 is not None:
                    weights_2 = blur_downsample(
                        weights_2.unsqueeze(0), 1, scale_filter=True
                    )
                    weights_2 = weights_2.squeeze(0).clip(min=0)
            weights.append(WeightedAverage(weights_1, weights_2))
self.weights = torch.nn.ModuleList(weights)

def __getitem__(self, idx: int):
return self.weights[idx]

    def forward(self, image: list[Tensor]) -> Tensor:
"""Take the weighted average over last two dimensions of each input in list.

All other dimensions are preserved.

Parameters
----------
        image :
            List of 4d to 6d Tensors, one per scale, each downsampled by a factor of 2
            relative to the previous one.

Returns
-------
weighted_avg :
Weighted average. Dimensionality depends on both the input's dimensionality
and whether ``weights_2`` was set at initialization. Scales are stacked
along last dimension.

"""
return torch.stack([w(x) for x, w in zip(image, self.weights)], dim=-1)

    def einsum(self, einsum_str: str, *tensors: list[Tensor]) -> Tensor:
        """More general version of forward, applied at each scale.

        This takes the input ``einsum_str``, prepends each scale's weights to its inputs,
        and inserts the weight dimensions into the output immediately after "b c" (batch,
        channel). It will therefore misbehave if the output does not contain "b c"
        dimensions.

Parameters
----------
einsum_str :
String of einsum notation, which must contain "b c" in the output. Intended
use is that this string produces a single output tensor.
tensors :
            Any number of lists of tensors. The lists should all have the same number of
            elements, each element corresponding to a different scale (and thus
            downsampled by a factor of 2).

Returns
-------
output :
The result of this einsum. Scales are stacked along last dimension.

"""
return torch.stack(
[w.einsum(einsum_str, *x) for *x, w in zip(*tensors, self.weights)], dim=-1
)

def sum_weights(self) -> Tensor:
"""Sum weights, largely used for diagnostic purposes.

Returns
-------
sum :
2d or 3d tensor (depending on whether ``weights_2`` was set at
initialization) containing the sum of all weights on each scale.

"""
sums = []
for w in self.weights:
sums.append(w.sum_weights())
return einops.pack(sums, f"* {self.weights[0]._output_einsum}")[0]
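Not part of the diff, but for reviewers, a minimal usage sketch of the two new modules. It makes illustrative assumptions: random non-negative weights, a 64x64 image, two weight tensors (the separable polar-angle/eccentricity case the docstrings describe), and that the module import path mirrors the file path above.

import torch
from plenoptic.simulate.canonical_computations.weighted_average import (
    WeightedAverage,
    WeightedAveragePyramid,
)

# illustrative weights: 3 "eccentricity" bands and 4 "polar angle" wedges over a
# 64x64 image (random here, so the normalization warnings will fire)
weights_1 = torch.rand(3, 64, 64)
weights_2 = torch.rand(4, 64, 64)

avg = WeightedAverage(weights_1, weights_2)
image = torch.rand(1, 1, 64, 64)  # (batch, channel, height, width)
print(avg(image).shape)  # torch.Size([1, 1, 12]): the 3*4 weighting regions flattened together

# the pyramid version builds one WeightedAverage per scale, blurring and
# downsampling the weights by a factor of 2 between scales
pyr_avg = WeightedAveragePyramid(weights_1, weights_2, n_scales=2)
coeffs = [torch.rand(1, 1, 64, 64), torch.rand(1, 1, 32, 32)]  # one tensor per scale
print(pyr_avg(coeffs).shape)  # torch.Size([1, 1, 12, 2]): scales stacked along the last dimension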
1 change: 1 addition & 0 deletions src/plenoptic/simulate/models/__init__.py
@@ -3,3 +3,4 @@
from .frontend import *
from .naive import *
from .portilla_simoncelli import PortillaSimoncelli
from .portilla_simoncelli_masked import PortillaSimoncelliMasked