Merge pull request #225 from LabForComputationalVision/ps_refactor
Refactor Portilla-Simoncelli model
billbrod authored Feb 29, 2024
2 parents 4ec6d5b + cad4606 commit 136527d
Showing 22 changed files with 24,837 additions and 26,830 deletions.
38 changes: 36 additions & 2 deletions docs/tips.rst
@@ -36,7 +36,7 @@ methods.
- For metamers, this means double-checking that the difference between the model
representation of the metamer and the target image is small enough. If your
model's representation is multi-scale, trying coarse-to-fine optimization may
help (see `notebook <tutorials/06_Metamer.html#Coarse-to-fine-optimization>`_
help (see `notebook <tutorials/intro/06_Metamer.html#Coarse-to-fine-optimization>`_
for details).
- For MAD competition, this means double-checking that the reference metric is
constant and that the optimized metric has converged at a lower or higher
@@ -59,6 +59,40 @@ Additionally, it may be helpful to visualize the progression of synthesis, using
each synthesis method's ``animate`` or ``plot_synthesis_status`` helper
functions (e.g., :func:`plenoptic.synthesize.metamer.plot_synthesis_status`).
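
For example, a minimal sketch of visualizing synthesis progress (a hypothetical
invocation assuming the Portilla-Simoncelli model and a 256x256 grayscale image;
exact argument names may differ between plenoptic versions)::

    import plenoptic as po

    img = po.data.einstein()                             # 256x256 grayscale image
    model = po.simul.PortillaSimoncelli(img.shape[-2:])  # texture statistics model
    met = po.synthesize.Metamer(img, model)              # set up metamer synthesis
    met.synthesize(max_iter=100)                         # run the optimization
    # plot the loss curve, representation error, and current metamer in one figure
    po.synthesize.metamer.plot_synthesis_status(met)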

Tweaking the model
------------------

You can also improve your chances of finding a good synthesis by tweaking the
model. For example, the loss function used for metamer synthesis by default is
mean-squared error. This implicitly weights all aspects of the model's
representation equally. Thus, if there are portions of the representation whose
magnitudes are significantly smaller than the others, they might not be matched
at the same rate as the others. You can address this using coarse-to-fine
synthesis or picking a more suitable loss function, but it's generally a good
idea for all of a model's representation to have roughly the same magnitude. You
can do this in a principled or empirical manner:

- Principled: compose your representation of statistics that you know lie within
the same range. For example, use correlations instead of covariances (see the
Portilla-Simoncelli model, and in particular `how plenoptic's implementation
differs from matlab
<tutorials/models/Metamer-Portilla-Simoncelli#7.-Notable-differences-between-Matlab-and-Python-Implementations>`_
for an example of this).
- Empirical: measure your model's representation on a dataset of relevant
  natural images and then use those measurements to z-score your model's
  representation on each pass (see [Ziemba2021]_ for an example; this is what
  the Van Hateren database is used for). A sketch of this approach appears
  after this list.
- In the middle: normalize statistics based on their value in the original
  (target) image. Note that you should not normalize based on the image the
  model is currently taking as input; doing so will likely make optimization
  very difficult.

If you are computing a multi-channel representation, you may have a similar
problem where one channel is larger or smaller than the others. Here, tweaking
the loss function might be more useful. For example, you could compute the loss
within each channel (e.g., with an L2 norm) and then combine across channels with
``logsumexp`` (the log of the sum of exponentials, a smooth approximation of the
maximum function), as sketched below.
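
A minimal sketch of such a loss (assuming the model's representation is a 4d
tensor of shape ``(batch, channel, height, width)`` and that your synthesis
method accepts a custom loss function; ``alpha`` is a hypothetical sharpness
parameter, with larger values approximating the per-channel maximum more
closely)::

    import torch

    def channelwise_logsumexp_loss(synth_rep, ref_rep, alpha=1.0):
        # L2 error within each channel: shape (batch, channel)
        per_channel = torch.linalg.vector_norm(synth_rep - ref_rep, ord=2,
                                               dim=(-2, -1))
        # smooth maximum over channels, then average over the batch
        return (torch.logsumexp(alpha * per_channel, dim=-1) / alpha).mean()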

None of the existing synthesis methods meet my needs
====================================================

@@ -79,4 +113,4 @@ methods.

If you extend a method successfully or would like help making it work, please
let us know by posting a `discussion!
<https://github.com/Flatiron-CCN/plenoptic/discussions>`_
<https://github.com/LabForComputationalVision/plenoptic/discussions>`_
15 changes: 2 additions & 13 deletions examples/02_Eigendistortions.ipynb
@@ -572,21 +572,10 @@
}
],
"source": [
"# a couple helper functions\n",
"\n",
"def center_crop(im, n):\n",
" \"\"\"Crop an nxn image from the center of im\"\"\"\n",
" im_height, im_width = im.shape[-2:]\n",
" assert n<im_height and n<im_width\n",
"\n",
" im_crop = im[..., im_height//2-n//2:im_height//2+n//2,\n",
" im_width//2-n//2:im_width//2+n//2]\n",
" return im_crop\n",
"\n",
"n = 128 # this will be the img_height and width of the input, you can change this to accommodate your machine\n",
"img = po.data.color_wheel()\n",
"# center crop the image to nxn\n",
"img = center_crop(img, n)\n",
"img = po.tools.center_crop(img, n)\n",
"po.imshow(img, as_rgb=True, zoom=3);"
]
},
@@ -975,7 +964,7 @@
"img = po.data.curie()\n",
"\n",
"# center crop the image to nxn\n",
"img = center_crop(img, n)\n",
"img = po.tools.center_crop(img, n)\n",
"# because this is a grayscale image but ResNet expects a color image, \n",
"# need to duplicate along the color dimension\n",
"img3 = torch.repeat_interleave(img, 3, dim=1)\n",
2 changes: 1 addition & 1 deletion examples/03_Steerable_Pyramid.ipynb
@@ -109,7 +109,7 @@
"for k in pyr_coeffs.keys():\n",
" # we ignore the residual_highpass and residual_lowpass, since we're focusing on the filters here\n",
" if isinstance(k, tuple):\n",
" reconList.append(pyr.recon_pyr(pyr_coeffs, k[0], k[1]))\n",
" reconList.append(pyr.recon_pyr(pyr_coeffs, [k[0]], [k[1]]))\n",
" \n",
"po.imshow(reconList, col_wrap=order+1, vrange='indep1', zoom=2);"
]
9 changes: 4 additions & 5 deletions examples/05_Geodesics.ipynb
@@ -60,7 +60,6 @@
" \" please install it in your plenoptic environment \"\n",
" \"and restart the notebook kernel\")\n",
"import torchvision.transforms as transforms\n",
"from torchvision.transforms.functional import center_crop\n",
"from torchvision import models\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@@ -111,7 +110,7 @@
"einstein = po.data.einstein()\n",
"einstein = po.tools.conv.blur_downsample(einstein, n_scales=2)\n",
"vid = po.tools.translation_sequence(einstein, n_steps=20)\n",
"vid = center_crop(vid, image_size // 2)\n",
"vid = po.tools.center_crop(vid, image_size // 2)\n",
"vid = po.tools.rescale(vid, 0, 1)\n",
"\n",
"imgA = vid[0:1]\n",
@@ -1066,9 +1065,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:plen_3.10]",
"display_name": "plenoptic",
"language": "python",
"name": "conda-env-plen_3.10-py"
"name": "plenoptic"
},
"language_info": {
"codemirror_mode": {
@@ -1080,7 +1079,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.10.13"
},
"toc-autonumbering": true,
"toc-showtags": true
27,533 changes: 12,942 additions & 14,591 deletions examples/06_Metamer.ipynb

Large diffs are not rendered by default.

19,129 changes: 9,125 additions & 10,004 deletions examples/Display.ipynb

Large diffs are not rendered by default.

1,246 changes: 409 additions & 837 deletions examples/Metamer-Portilla-Simoncelli.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
@@ -53,16 +53,15 @@ dev = [
"pytest>=5.1.2",
'pytest-cov',
'pytest-xdist',
"torchvision>=0.3",
"requests>=2.21",
"pooch>=1.2.0",
]

nb = [
'jupyter',
'ipywidgets',
"torchvision>=0.3",
'nbclient>=0.5.5',
"torchvision>=0.3",
"pooch>=1.2.0",
]

10 changes: 9 additions & 1 deletion src/plenoptic/data/fetch.py
@@ -20,6 +20,10 @@
'sample_images.tar.gz': '0ba6fe668a61e9f3cb52032da740fbcf32399ffcc142ddb14380a8e404409bf5',
'test_images.tar.gz': 'eaf35f5f6136e2d51e513f00202a11188a85cae8c6f44141fb9666de25ae9554',
'tid2013.tar.gz': 'bc486ac749b6cfca8dc5f5340b04b9bb01ab24149a5f3a712f13e9d0489dcde0',
'portilla_simoncelli_test_vectors_refactor.tar.gz': '2ca60f1a60b192668567eb3d94c0cdc8679b23bf94a98331890c41eb9406503a',
'portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz': '9525844b71cf81509b86ed9677172745353588c6bb54e4de8000d695598afa47',
'portilla_simoncelli_synthesize_gpu_ps-refactor.npz': '9fbb490f1548133f6aa49c54832130cf70f8dc6546af59688ead17f62ab94e61',
'portilla_simoncelli_scales_ps-refactor.npz': '1053790a37707627810482beb0dd059cbc193efd688b60c441b3548a75626fdf',
}

OSF_TEMPLATE = "https://osf.io/{}/download"
@@ -40,8 +44,12 @@
'sample_images.tar.gz': OSF_TEMPLATE.format('6drmy'),
'test_images.tar.gz': OSF_TEMPLATE.format('au3b8'),
'tid2013.tar.gz': OSF_TEMPLATE.format('uscgv'),
'portilla_simoncelli_test_vectors_refactor.tar.gz': OSF_TEMPLATE.format('ca7qt'),
'portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz': OSF_TEMPLATE.format('vmwzd'),
'portilla_simoncelli_synthesize_gpu_ps-refactor.npz': OSF_TEMPLATE.format('mqs6y'),
'portilla_simoncelli_scales_ps-refactor.npz': OSF_TEMPLATE.format('nvpr4'),
}
DOWNLOADABLE_FILES = list(REGISTRY.keys())
DOWNLOADABLE_FILES = list(REGISTRY_URLS.keys())

import pathlib
from typing import List
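
These registry entries pair each downloadable file with a SHA-256 checksum and an
OSF download URL. As a rough, standalone illustration (not plenoptic's own fetch
helper, and assuming the checksum dictionary above is the module's ``REGISTRY``),
one such entry could be retrieved and verified with ``pooch.retrieve``:

    import pooch
    from plenoptic.data.fetch import OSF_TEMPLATE, REGISTRY

    # hypothetical direct download of one of the new Portilla-Simoncelli files
    fname = "portilla_simoncelli_scales_ps-refactor.npz"
    path = pooch.retrieve(
        url=OSF_TEMPLATE.format("nvpr4"),  # resolves to https://osf.io/nvpr4/download
        known_hash=REGISTRY[fname],        # plain hex string, treated as sha256
        fname=fname,                       # filename to store the download under
    )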
@@ -38,7 +38,7 @@ def forward(self, x):
Parameters
----------
x: torch.Tensor of shape (B, C, H, W)
x: torch.Tensor of shape (batch, channel, height, width)
Image, or batch of images. If there are multiple channels,
the Laplacian is computed separately for each of them
@@ -71,7 +71,7 @@ def recon_pyr(self, y):
Returns
-------
x: torch.Tensor of shape (B, C, H, W)
x: torch.Tensor of shape (batch, channel, height, width)
Image, or batch of images
"""

@@ -79,7 +79,7 @@ def local_gain_control(x, epsilon=1e-8):
Parameters
----------
x : torch.Tensor
Tensor of shape (B,C,H,W)
Tensor of shape (batch, channel, height, width)
epsilon: float, optional
Small constant to avoid division by zero.
@@ -134,7 +134,7 @@ def local_gain_release(norm, direction, epsilon=1e-8):
Returns
-------
x : torch.Tensor
Tensor of shape (B,C,H,W)
Tensor of shape (batch, channel, height, width)
Notes
-----
@@ -163,7 +163,7 @@ def local_gain_control_dict(coeff_dict, residuals=True):
Parameters
----------
coeff_dict : dict
A dictionary containing tensors of shape (B,C,H,W)
A dictionary containing tensors of shape (batch, channel, height, width)
residuals: bool, optional
An option to carry around residuals in the energy dict.
Note that the transformation is not applied to the residuals,
@@ -219,7 +219,7 @@ def local_gain_release_dict(energy, state, residuals=True):
Returns
-------
coeff_dict : dict
A dictionary containing tensors of shape (B,C,H,W)
A dictionary containing tensors of shape (batch, channel, height, width)
Notes
-----