From 79e93d3f02d4b121173f5d1836875fb4fbc13e09 Mon Sep 17 00:00:00 2001 From: Hanna Dettki Date: Thu, 8 Aug 2024 11:17:48 -0400 Subject: [PATCH] Revert "updating some deprecated imports, isinstance for union of types, unsorted imports, f-strings, replaced single quote with double quotes and deleted trailing whitespace" This reverts commit c1fd8bcb131387240356fe85392b66b5ae2bb6b4. --- examples/00_quickstart.ipynb | 12 +- examples/02_Eigendistortions.ipynb | 14 +- examples/03_Steerable_Pyramid.ipynb | 18 +- examples/04_Perceptual_distance.ipynb | 15 +- examples/05_Geodesics.ipynb | 49 +- examples/06_Metamer.ipynb | 8 +- examples/07_Simple_MAD.ipynb | 17 +- examples/08_MAD_Competition.ipynb | 8 +- examples/09_Original_MAD.ipynb | 9 +- examples/Demo_Eigendistortion.ipynb | 4 +- examples/Display.ipynb | 6 +- examples/Metamer-Portilla-Simoncelli.ipynb | 32 +- examples/Synthesis_extensions.ipynb | 22 +- noxfile.py | 2 - src/plenoptic/__init__.py | 10 +- src/plenoptic/data/__init__.py | 28 +- src/plenoptic/data/data_utils.py | 14 +- src/plenoptic/data/fetch.py | 110 +-- src/plenoptic/metric/__init__.py | 4 +- src/plenoptic/metric/classes.py | 12 +- src/plenoptic/metric/perceptual_distance.py | 165 ++-- src/plenoptic/simulate/__init__.py | 2 +- .../canonical_computations/__init__.py | 4 +- .../canonical_computations/filters.py | 27 +- .../laplacian_pyramid.py | 3 +- .../canonical_computations/non_linearities.py | 29 +- .../steerable_pyramid_freq.py | 221 ++--- src/plenoptic/simulate/models/frontend.py | 109 +-- src/plenoptic/simulate/models/naive.py | 80 +- .../simulate/models/portilla_simoncelli.py | 171 ++-- src/plenoptic/synthesize/__init__.py | 2 +- src/plenoptic/synthesize/autodiff.py | 7 +- src/plenoptic/synthesize/eigendistortion.py | 129 +-- src/plenoptic/synthesize/geodesic.py | 281 ++---- src/plenoptic/synthesize/mad_competition.py | 763 ++++++--------- src/plenoptic/synthesize/metamer.py | 873 +++++++----------- src/plenoptic/synthesize/simple_metamer.py | 50 +- src/plenoptic/synthesize/synthesis.py | 179 ++-- src/plenoptic/tools/__init__.py | 12 +- src/plenoptic/tools/conv.py | 75 +- src/plenoptic/tools/convergence.py | 37 +- src/plenoptic/tools/data.py | 42 +- src/plenoptic/tools/display.py | 342 +++---- src/plenoptic/tools/external.py | 128 +-- src/plenoptic/tools/optim.py | 15 +- src/plenoptic/tools/signal.py | 90 +- src/plenoptic/tools/stats.py | 26 +- src/plenoptic/tools/straightness.py | 48 +- src/plenoptic/tools/validate.py | 81 +- 49 files changed, 1636 insertions(+), 2749 deletions(-) diff --git a/examples/00_quickstart.ipynb b/examples/00_quickstart.ipynb index 0526e39a..faf80c8b 100644 --- a/examples/00_quickstart.ipynb +++ b/examples/00_quickstart.ipynb @@ -15,11 +15,10 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "import torch\n", - "\n", "import plenoptic as po\n", - "\n", + "import torch\n", + "import pyrtools as pt\n", + "import matplotlib.pyplot as plt\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", "\n", @@ -84,10 +83,7 @@ ], "source": [ "# this is a convenience function for creating a simple Gaussian kernel\n", - "from plenoptic.simulate.canonical_computations.filters import (\n", - " circular_gaussian2d,\n", - ")\n", - "\n", + "from plenoptic.simulate.canonical_computations.filters import circular_gaussian2d\n", "\n", "# Simple rectified Gaussian convolutional model\n", "class SimpleModel(torch.nn.Module):\n", diff --git 
a/examples/02_Eigendistortions.ipynb b/examples/02_Eigendistortions.ipynb index f75c9602..8b85fc29 100644 --- a/examples/02_Eigendistortions.ipynb +++ b/examples/02_Eigendistortions.ipynb @@ -45,14 +45,11 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", - "\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", "import torch\n", - "from torch import nn\n", - "\n", "from plenoptic.synthesize.eigendistortion import Eigendistortion\n", - "\n", + "from torch import nn\n", "# this notebook uses torchvision, which is an optional dependency.\n", "# if this fails, install torchvision in your plenoptic environment \n", "# and restart the notebook kernel.\n", @@ -62,6 +59,7 @@ " raise ModuleNotFoundError(\"optional dependency torchvision not found!\"\n", " \" please install it in your plenoptic environment \"\n", " \"and restart the notebook kernel\")\n", + "import os.path as op\n", "import plenoptic as po" ] }, @@ -824,7 +822,7 @@ } ], "source": [ - "po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=3)\n", + "po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=3);\n", "po.synth.eigendistortion.display_eigendistortion(ed_resneta, -1, as_rgb=True, zoom=3);" ] }, @@ -1027,10 +1025,10 @@ } ], "source": [ - "po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=2, title=\"top eigendist\")\n", - "po.synth.eigendistortion.display_eigendistortion(ed_resneta, -1, as_rgb=True, zoom=2, title=\"bottom eigendist\")\n", + "po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=2, title=\"top eigendist\");\n", + "po.synth.eigendistortion.display_eigendistortion(ed_resneta, -1, as_rgb=True, zoom=2, title=\"bottom eigendist\");\n", "\n", - "po.synth.eigendistortion.display_eigendistortion(ed_resnetb, 0, as_rgb=True, zoom=2, title=\"top eigendist\")\n", + "po.synth.eigendistortion.display_eigendistortion(ed_resnetb, 0, as_rgb=True, zoom=2, title=\"top eigendist\");\n", "po.synth.eigendistortion.display_eigendistortion(ed_resnetb, -1, as_rgb=True, zoom=2, title=\"bottom eigendist\");" ] }, diff --git a/examples/03_Steerable_Pyramid.ipynb b/examples/03_Steerable_Pyramid.ipynb index 2b82cddf..a1030fba 100644 --- a/examples/03_Steerable_Pyramid.ipynb +++ b/examples/03_Steerable_Pyramid.ipynb @@ -21,7 +21,6 @@ "source": [ "import numpy as np\n", "import torch\n", - "\n", "# this notebook uses torchvision, which is an optional dependency.\n", "# if this fails, install torchvision in your plenoptic environment \n", "# and restart the notebook kernel.\n", @@ -31,19 +30,20 @@ " raise ModuleNotFoundError(\"optional dependency torchvision not found!\"\n", " \" please install it in your plenoptic environment \"\n", " \"and restart the notebook kernel\")\n", - "import matplotlib.pyplot as plt\n", - "import torch.nn.functional as F\n", "import torchvision.transforms as transforms\n", + "import torch.nn.functional as F\n", "from torch import nn\n", + "import matplotlib.pyplot as plt\n", "\n", + "import pyrtools as pt\n", "import plenoptic as po\n", "from plenoptic.simulate import SteerablePyramidFreq\n", + "from plenoptic.synthesize import Eigendistortion\n", "from plenoptic.tools.data import to_numpy\n", - "\n", "dtype = torch.float32\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "import os\n", "from tqdm.auto import tqdm\n", - "\n", "%load_ext autoreload\n", "\n", "%autoreload 2\n", @@ -218,7 
+218,7 @@ ], "source": [ "print(pyr_coeffs.keys())\n", - "po.pyrshow(pyr_coeffs, zoom=0.5, batch_idx=0)\n", + "po.pyrshow(pyr_coeffs, zoom=0.5, batch_idx=0);\n", "po.pyrshow(pyr_coeffs, zoom=0.5, batch_idx=1);" ] }, @@ -267,7 +267,7 @@ "#get the 3rd scale\n", "print(pyr.scales)\n", "pyr_coeffs_scale0 = pyr(im_batch, scales=[2])\n", - "po.pyrshow(pyr_coeffs_scale0, zoom=2, batch_idx=0)\n", + "po.pyrshow(pyr_coeffs_scale0, zoom=2, batch_idx=0);\n", "po.pyrshow(pyr_coeffs_scale0, zoom=2, batch_idx=1);" ] }, @@ -323,7 +323,7 @@ ], "source": [ "# the same visualization machinery works for complex pyramids; what is shown is the magnitude of the coefficients\n", - "po.pyrshow(pyr_coeffs_complex, zoom=0.5, batch_idx=0)\n", + "po.pyrshow(pyr_coeffs_complex, zoom=0.5, batch_idx=0);\n", "po.pyrshow(pyr_coeffs_complex, zoom=0.5, batch_idx=1);" ] }, @@ -2310,7 +2310,7 @@ } ], "source": [ - "po.pyrshow(pyr_coeffs_complex, zoom=0.5)\n", + "po.pyrshow(pyr_coeffs_complex, zoom=0.5);\n", "po.pyrshow(pyr_coeffs_fixed_1, zoom=0.5);" ] }, diff --git a/examples/04_Perceptual_distance.ipynb b/examples/04_Perceptual_distance.ipynb index 93a1c869..46bd12f0 100644 --- a/examples/04_Perceptual_distance.ipynb +++ b/examples/04_Perceptual_distance.ipynb @@ -28,15 +28,14 @@ "outputs": [], "source": [ "import os\n", - "\n", + "import io\n", "import imageio\n", - "import matplotlib.pyplot as plt\n", + "import plenoptic as po\n", "import numpy as np\n", - "import torch\n", - "from PIL import Image\n", "from scipy.stats import pearsonr, spearmanr\n", - "\n", - "import plenoptic as po" + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from PIL import Image" ] }, { @@ -81,8 +80,6 @@ "outputs": [], "source": [ "import tempfile\n", - "\n", - "\n", "def add_jpeg_artifact(img, quality):\n", " # need to convert this back to 2d 8-bit int for writing out as jpg\n", " img = po.to_numpy(img.squeeze() * 255).astype(np.uint8)\n", @@ -396,7 +393,7 @@ " folder / \"distorted_images\" / distorted_filename).convert(\"L\"))) / 255\n", " distorted_images = distorted_images[:, [0] + list(range(2, 17)) + list(range(18, 24))] # Remove color distortions\n", "\n", - " with open(folder/ \"mos.txt\", encoding=\"utf-8\") as g:\n", + " with open(folder/ \"mos.txt\", \"r\", encoding=\"utf-8\") as g:\n", " mos_values = list(map(float, g.readlines()))\n", " mos_values = np.array(mos_values).reshape([25, 24, 5])\n", " mos_values = mos_values[:, [0] + list(range(2, 17)) + list(range(18, 24))] # Remove color distortions\n", diff --git a/examples/05_Geodesics.ipynb b/examples/05_Geodesics.ipynb index 73f32e30..a6fc4a13 100644 --- a/examples/05_Geodesics.ipynb +++ b/examples/05_Geodesics.ipynb @@ -36,24 +36,20 @@ } ], "source": [ - "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "\n", + "import matplotlib.pyplot as plt\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", "%matplotlib inline\n", "\n", "import pyrtools as pt\n", - "\n", "import plenoptic as po\n", "from plenoptic.tools import to_numpy\n", - "\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import torch\n", "import torch.nn as nn\n", - "\n", "# this notebook uses torchvision, which is an optional dependency.\n", "# if this fails, install torchvision in your plenoptic environment \n", "# and restart the notebook kernel.\n", @@ -146,8 +142,6 @@ "outputs": [], "source": [ "import torch.fft\n", - "\n", - "\n", "class Fourier(nn.Module):\n", " def __init__(self, representation = 'amp'):\n", " 
super().__init__()\n", @@ -228,7 +222,7 @@ ], "source": [ "fig, axes = plt.subplots(2, 1, figsize=(5, 8))\n", - "po.synth.geodesic.plot_loss(moog, ax=axes[0])\n", + "po.synth.geodesic.plot_loss(moog, ax=axes[0]);\n", "po.synth.geodesic.plot_deviation_from_line(moog, vid, ax=axes[1]);" ] }, @@ -249,7 +243,7 @@ } ], "source": [ - "plt.plot(po.to_numpy(moog.step_energy), alpha=.2)\n", + "plt.plot(po.to_numpy(moog.step_energy), alpha=.2);\n", "plt.plot(moog.step_energy.mean(1), 'r-', label='path energy')\n", "plt.axhline(torch.linalg.vector_norm(moog.model(moog.image_a) - moog.model(moog.image_b), ord=2) ** 2 / moog.n_steps ** 2)\n", "plt.legend()\n", @@ -308,7 +302,7 @@ } ], "source": [ - "plt.plot(po.to_numpy(moog.dev_from_line[..., 1]))\n", + "plt.plot(po.to_numpy(moog.dev_from_line[..., 1]));\n", "\n", "plt.title('evolution of distance from representation line')\n", "plt.ylabel('distance from representation line')\n", @@ -367,7 +361,7 @@ "geodesic = to_numpy(moog.geodesic.squeeze())\n", "fig = pt.imshow([video[5], pixelfade[5], geodesic[5]],\n", " title=['video', 'pixelfade', 'geodesic'],\n", - " col_wrap=3, zoom=4)\n", + " col_wrap=3, zoom=4);\n", "\n", "size = geodesic.shape[-1]\n", "h, m , l = (size//2 + size//4, size//2, size//2 - size//4)\n", @@ -378,9 +372,9 @@ " a.axhline(line, lw=2)\n", "\n", "pt.imshow([video[:,l], pixelfade[:,l], geodesic[:,l]],\n", - " title=None, col_wrap=3, zoom=4)\n", + " title=None, col_wrap=3, zoom=4);\n", "pt.imshow([video[:,m], pixelfade[:,m], geodesic[:,m]],\n", - " title=None, col_wrap=3, zoom=4)\n", + " title=None, col_wrap=3, zoom=4);\n", "pt.imshow([video[:,h], pixelfade[:,h], geodesic[:,h]],\n", " title=None, col_wrap=3, zoom=4);" ] @@ -477,7 +471,7 @@ ], "source": [ "fig, axes = plt.subplots(2, 1, figsize=(5, 8))\n", - "po.synth.geodesic.plot_loss(moog, ax=axes[0])\n", + "po.synth.geodesic.plot_loss(moog, ax=axes[0]);\n", "po.synth.geodesic.plot_deviation_from_line(moog, ax=axes[1]);" ] }, @@ -524,7 +518,7 @@ } ], "source": [ - "plt.plot(po.to_numpy(moog.step_energy), alpha=.2)\n", + "plt.plot(po.to_numpy(moog.step_energy), alpha=.2);\n", "plt.plot(moog.step_energy.mean(1), 'r-', label='path energy')\n", "plt.axhline(torch.linalg.vector_norm(moog.model(moog.image_a) - moog.model(moog.image_b), ord=2) ** 2 / moog.n_steps ** 2)\n", "plt.legend()\n", @@ -636,9 +630,9 @@ ], "source": [ "print('geodesic')\n", - "pt.imshow(list(geodesic), vrange='auto1', title=None, zoom=4)\n", + "pt.imshow(list(geodesic), vrange='auto1', title=None, zoom=4);\n", "print('diff')\n", - "pt.imshow(list(geodesic - pixelfade), vrange='auto1', title=None, zoom=4)\n", + "pt.imshow(list(geodesic - pixelfade), vrange='auto1', title=None, zoom=4);\n", "print('pixelfade')\n", "pt.imshow(list(pixelfade), vrange='auto1', title=None, zoom=4);" ] @@ -663,7 +657,7 @@ "# checking that the range constraint is met\n", "plt.hist(video.flatten(), histtype='step', density=True, label='video')\n", "plt.hist(pixelfade.flatten(), histtype='step', density=True, label='pixelfade')\n", - "plt.hist(geodesic.flatten(), histtype='step', density=True, label='geodesic')\n", + "plt.hist(geodesic.flatten(), histtype='step', density=True, label='geodesic');\n", "plt.title('signal value histogram')\n", "plt.legend(loc=1)\n", "plt.show()" @@ -722,9 +716,9 @@ "l = 90\n", "imgA = imgA[..., u:u+224, l:l+224]\n", "imgB = imgB[..., u:u+224, l:l+224]\n", - "po.imshow([imgA, imgB], as_rgb=True)\n", + "po.imshow([imgA, imgB], as_rgb=True);\n", "diff = imgA - imgB\n", - "po.imshow(diff)\n", + "po.imshow(diff);\n", 
"pt.image_compare(po.to_numpy(imgA, True), po.to_numpy(imgB, True));" ] }, @@ -745,6 +739,7 @@ } ], "source": [ + "from torchvision import models\n", "# Create a class that takes the nth layer output of a given model\n", "class NthLayer(torch.nn.Module):\n", " \"\"\"Wrap any model to get the response of an intermediate layer\n", @@ -825,7 +820,7 @@ "predA = po.to_numpy(models.vgg16(pretrained=True)(imgA))[0]\n", "predB = po.to_numpy(models.vgg16(pretrained=True)(imgB))[0]\n", "\n", - "plt.plot(predA)\n", + "plt.plot(predA);\n", "plt.plot(predB);" ] }, @@ -940,7 +935,7 @@ ], "source": [ "fig, axes = plt.subplots(2, 1, figsize=(5, 8))\n", - "po.synth.geodesic.plot_loss(moog, ax=axes[0])\n", + "po.synth.geodesic.plot_loss(moog, ax=axes[0]);\n", "po.synth.geodesic.plot_deviation_from_line(moog, ax=axes[1]);" ] }, @@ -1057,12 +1052,12 @@ } ], "source": [ - "po.imshow(moog.geodesic, as_rgb=True, zoom=2, title=None, vrange='auto0')\n", - "po.imshow(moog.pixelfade, as_rgb=True, zoom=2, title=None, vrange='auto0')\n", + "po.imshow(moog.geodesic, as_rgb=True, zoom=2, title=None, vrange='auto0');\n", + "po.imshow(moog.pixelfade, as_rgb=True, zoom=2, title=None, vrange='auto0');\n", "# per channel difference\n", - "po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 0:1]], zoom=2, title=None, vrange='auto1')\n", - "po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 1:2]], zoom=2, title=None, vrange='auto1')\n", - "po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 2:]], zoom=2, title=None, vrange='auto1')\n", + "po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 0:1]], zoom=2, title=None, vrange='auto1');\n", + "po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 1:2]], zoom=2, title=None, vrange='auto1');\n", + "po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 2:]], zoom=2, title=None, vrange='auto1');\n", "# exaggerated color difference\n", "po.imshow([po.tools.rescale((moog.geodesic - moog.pixelfade)[1:-1])], as_rgb=True, zoom=2, title=None);" ] diff --git a/examples/06_Metamer.ipynb b/examples/06_Metamer.ipynb index a35c4644..16f5cc68 100644 --- a/examples/06_Metamer.ipynb +++ b/examples/06_Metamer.ipynb @@ -21,12 +21,12 @@ "metadata": {}, "outputs": [], "source": [ + "import plenoptic as po\n", + "from plenoptic.tools import to_numpy\n", "import imageio\n", - "import matplotlib.pyplot as plt\n", "import torch\n", - "\n", - "import plenoptic as po\n", - "\n", + "import pyrtools as pt\n", + "import matplotlib.pyplot as plt\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", "# Animation-related settings\n", diff --git a/examples/07_Simple_MAD.ipynb b/examples/07_Simple_MAD.ipynb index 52b177b9..964594a6 100644 --- a/examples/07_Simple_MAD.ipynb +++ b/examples/07_Simple_MAD.ipynb @@ -24,18 +24,15 @@ } ], "source": [ - "import matplotlib.pyplot as plt\n", - "import pyrtools as pt\n", - "import torch\n", - "\n", "import plenoptic as po\n", "from plenoptic.tools import to_numpy\n", - "\n", + "import torch\n", + "import pyrtools as pt\n", + "import matplotlib.pyplot as plt\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", - "import itertools\n", - "\n", "import numpy as np\n", + "import itertools\n", "\n", "%load_ext autoreload\n", "%autoreload 2" @@ -120,7 +117,7 @@ "all_mad = {}\n", "\n", "# this gets us all four possibilities\n", - "for t, (m1, m2) in itertools.product(['min', 'max'], zip(metrics, metrics[::-1], strict=False)):\n", + "for t, (m1, m2) in 
itertools.product(['min', 'max'], zip(metrics, metrics[::-1])):\n", " name = f'{m1.__name__}_{t}'\n", " # we set the seed like this to ensure that all four MADCompetition instances have the same initial_signal. Try different seed values! \n", " po.tools.set_seed(10)\n", @@ -171,7 +168,7 @@ "source": [ "fig, axes = plt.subplots(2, 2, figsize=(10, 10))\n", "pal = {'l1_norm': 'C0', 'l2_norm': 'C1'}\n", - "for ax, (k, mad) in zip(axes.flatten(), all_mad.items(), strict=False):\n", + "for ax, (k, mad) in zip(axes.flatten(), all_mad.items()):\n", " ax.plot(mad.optimized_metric_loss, pal[mad.optimized_metric.__name__], label=mad.optimized_metric.__name__)\n", " ax.plot(mad.reference_metric_loss, pal[mad.reference_metric.__name__], label=mad.reference_metric.__name__)\n", " ax.set(title=k.capitalize().replace('_', ' '), xlabel='Iteration', ylabel='Loss')\n", @@ -409,7 +406,7 @@ "all_mad = {}\n", "\n", "# this gets us all four possibilities\n", - "for t, (m1, m2) in itertools.product(['min', 'max'], zip(metrics, metrics[::-1], strict=False)):\n", + "for t, (m1, m2) in itertools.product(['min', 'max'], zip(metrics, metrics[::-1])):\n", " name = f'{m1.__name__}_{t}'\n", " # we set the seed like this to ensure that all four MADCompetition instances have the same initial_signal. Try different seed values! \n", " po.tools.set_seed(0)\n", diff --git a/examples/08_MAD_Competition.ipynb b/examples/08_MAD_Competition.ipynb index 9b16f3df..5688609c 100644 --- a/examples/08_MAD_Competition.ipynb +++ b/examples/08_MAD_Competition.ipynb @@ -35,12 +35,14 @@ } ], "source": [ - "import matplotlib.pyplot as plt\n", - "\n", "import plenoptic as po\n", - "\n", + "import imageio\n", + "import torch\n", + "import pyrtools as pt\n", + "import matplotlib.pyplot as plt\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", + "import numpy as np\n", "import warnings\n", "\n", "%load_ext autoreload\n", diff --git a/examples/09_Original_MAD.ipynb b/examples/09_Original_MAD.ipynb index d731dc7e..7c02a123 100644 --- a/examples/09_Original_MAD.ipynb +++ b/examples/09_Original_MAD.ipynb @@ -17,8 +17,15 @@ "metadata": {}, "outputs": [], "source": [ + "import imageio\n", + "import torch\n", + "import scipy.io as sio\n", + "import pyrtools as pt\n", + "from scipy.io import loadmat\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", "import plenoptic as po\n", - "\n", + "import os.path as op\n", "%matplotlib inline\n", "\n", "%load_ext autoreload\n", diff --git a/examples/Demo_Eigendistortion.ipynb b/examples/Demo_Eigendistortion.ipynb index c811a5dc..558c0ad6 100644 --- a/examples/Demo_Eigendistortion.ipynb +++ b/examples/Demo_Eigendistortion.ipynb @@ -44,9 +44,8 @@ } ], "source": [ - "from plenoptic.simulate.models import OnOff\n", "from plenoptic.synthesize import Eigendistortion\n", - "\n", + "from plenoptic.simulate.models import OnOff\n", "# this notebook uses torchvision, which is an optional dependency.\n", "# if this fails, install torchvision in your plenoptic environment \n", "# and restart the notebook kernel.\n", @@ -58,7 +57,6 @@ " \"and restart the notebook kernel\")\n", "import torch\n", "from torch import nn\n", - "\n", "import plenoptic as po\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", diff --git a/examples/Display.ipynb b/examples/Display.ipynb index f3dbf6c8..a62db0da 100644 --- a/examples/Display.ipynb +++ b/examples/Display.ipynb @@ -18,10 +18,8 @@ "metadata": {}, "outputs": [], 
"source": [ - "import matplotlib.pyplot as plt\n", - "\n", "import plenoptic as po\n", - "\n", + "import matplotlib.pyplot as plt\n", "# so that relativfe sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", "# Animation-related settings\n", @@ -30,8 +28,8 @@ "plt.rcParams['animation.writer'] = 'ffmpeg'\n", "plt.rcParams['animation.ffmpeg_args'] = ['-threads', '1']\n", "\n", - "import numpy as np\n", "import torch\n", + "import numpy as np\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", diff --git a/examples/Metamer-Portilla-Simoncelli.ipynb b/examples/Metamer-Portilla-Simoncelli.ipynb index 4772e233..8e0e1816 100644 --- a/examples/Metamer-Portilla-Simoncelli.ipynb +++ b/examples/Metamer-Portilla-Simoncelli.ipynb @@ -15,13 +15,20 @@ } ], "source": [ - "\n", - "import einops\n", + "import numpy as np\n", + "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import torch\n", - "\n", "import plenoptic as po\n", - "\n", + "import scipy.io as sio\n", + "import os\n", + "import os.path as op\n", + "import einops\n", + "import glob\n", + "import math\n", + "import pyrtools as pt\n", + "from tqdm import tqdm\n", + "from PIL import Image\n", "%load_ext autoreload\n", "%autoreload \n", "\n", @@ -368,7 +375,7 @@ "# send image and PS model to GPU, if available. then im_init and Metamer will also use GPU\n", "img = img.to(DEVICE)\n", "model = po.simul.PortillaSimoncelli(img.shape[-2:]).to(DEVICE)\n", - "im_init = (torch.rand_like(img)-.5) * .1 + img.mean()\n", + "im_init = (torch.rand_like(img)-.5) * .1 + img.mean();\n", "\n", "met = po.synth.MetamerCTF(img, model, loss_function=po.tools.optim.l2_norm, initial_image=im_init,\n", " coarse_to_fine='together')\n", @@ -519,8 +526,6 @@ "# Be sure to run this cell.\n", "\n", "from collections import OrderedDict\n", - "\n", - "\n", "class PortillaSimoncelliRemove(po.simul.PortillaSimoncelli):\n", " r\"\"\"Model for measuring a subset of texture statistics reported by PortillaSimoncelli\n", "\n", @@ -665,7 +670,7 @@ "source": [ "# visualize results\n", "fig = po.imshow([metamer.image, metamer.metamer, metamer_remove.metamer], \n", - " title=['Target image', 'Full Statistics', 'Without Marginal Statistics'], vrange='auto1')\n", + " title=['Target image', 'Full Statistics', 'Without Marginal Statistics'], vrange='auto1');\n", "# add plots showing the different pixel intensity histograms\n", "fig.add_axes([.33, -1, .33, .9])\n", "fig.add_axes([.67, -1, .33, .9])\n", @@ -1372,8 +1377,8 @@ " target=None\n", " ):\n", " super().__init__(im_shape, n_scales=4, n_orientations=4, spatial_corr_width=9)\n", - " self.mask = mask\n", - " self.target = target\n", + " self.mask = mask;\n", + " self.target = target;\n", " \n", " def forward(self, image, scales=None):\n", " r\"\"\"Generate Texture Statistics representation of an image using the target for the masked portion\n", @@ -1434,7 +1439,7 @@ "source": [ "img_file = DATA_PATH / 'fig14b.jpg'\n", "img = po.tools.load_images(img_file).to(DEVICE)\n", - "im_init = (torch.rand_like(img)-.5) * .1 + img.mean()\n", + "im_init = (torch.rand_like(img)-.5) * .1 + img.mean();\n", "\n", "mask = torch.zeros(1,1,256,256).bool().to(DEVICE)\n", "ctr_dim = (img.shape[-2]//4, img.shape[-1]//4)\n", @@ -1990,6 +1995,7 @@ "metadata": {}, "outputs": [], "source": [ + "from collections import OrderedDict\n", "\n", "class PortillaSimoncelliMagMeans(po.simul.PortillaSimoncelli):\n", " r\"\"\"Include the magnitude means in the PS texture representation.\n", @@ -2137,11 +2143,11 @@ ], 
"source": [ "fig, axes = plt.subplots(2, 2, figsize=(21, 11), gridspec_kw={'width_ratios': [1, 3.1]})\n", - "for ax, im, info in zip(axes[:, 0], [met.metamer, met_mag_means.metamer], ['with', 'without'], strict=False):\n", + "for ax, im, info in zip(axes[:, 0], [met.metamer, met_mag_means.metamer], ['with', 'without']):\n", " po.imshow(im, ax=ax, title=f\"Metamer {info} magnitude means\")\n", " ax.xaxis.set_visible(False)\n", " ax.yaxis.set_visible(False)\n", - "model_mag_means.plot_representation(model_mag_means(met.metamer)-model_mag_means(img), ylim=(-.06, .06), ax=axes[0,1])\n", + "model_mag_means.plot_representation(model_mag_means(met.metamer)-model_mag_means(img), ylim=(-.06, .06), ax=axes[0,1]);\n", "model_mag_means.plot_representation(model_mag_means(met_mag_means.metamer)-model_mag_means(img), ylim=(-.06, .06), ax=axes[1,1]);" ] }, diff --git a/examples/Synthesis_extensions.ipynb b/examples/Synthesis_extensions.ipynb index 0e49b31c..d0d1efe1 100644 --- a/examples/Synthesis_extensions.ipynb +++ b/examples/Synthesis_extensions.ipynb @@ -21,15 +21,13 @@ }, "outputs": [], "source": [ - "import warnings\n", - "from collections.abc import Callable\n", - "from typing import Literal\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "from torch import Tensor\n", - "\n", "import plenoptic as po\n", + "from torch import Tensor\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import warnings\n", + "from typing import Union, Callable, Tuple, Optional\n", + "from typing_extensions import Literal\n", "\n", "# so that relative sizes of axes created by po.imshow and others look right\n", "plt.rcParams['figure.dpi'] = 72\n", @@ -48,13 +46,13 @@ "class MADCompetitionVariant(po.synth.MADCompetition):\n", " \"\"\"Initialize MADCompetition with an image instead!\"\"\"\n", " def __init__(self, image: Tensor,\n", - " optimized_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],\n", - " reference_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],\n", + " optimized_metric: Union[torch.nn.Module, Callable[[Tensor, Tensor], Tensor]],\n", + " reference_metric: Union[torch.nn.Module, Callable[[Tensor, Tensor], Tensor]],\n", " minmax: Literal['min', 'max'],\n", " initial_image: Tensor = None,\n", - " metric_tradeoff_lambda: float | None = None,\n", + " metric_tradeoff_lambda: Optional[float] = None,\n", " range_penalty_lambda: float = .1,\n", - " allowed_range: tuple[float, float] = (0, 1)):\n", + " allowed_range: Tuple[float, float] = (0, 1)):\n", " if initial_image is None:\n", " initial_image = torch.rand_like(image)\n", " super().__init__(image, optimized_metric, reference_metric,\n", diff --git a/noxfile.py b/noxfile.py index 111564db..58bc0d91 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,13 +1,11 @@ import nox - @nox.session(name="lint") def lint(session): # run linters session.install("ruff") session.run("ruff", "check", "--ignore", "D") - @nox.session(name="tests", python=["3.10", "3.11", "3.12"]) def tests(session): # run tests diff --git a/src/plenoptic/__init__.py b/src/plenoptic/__init__.py index 1b7f4621..a62bb3da 100644 --- a/src/plenoptic/__init__.py +++ b/src/plenoptic/__init__.py @@ -1,6 +1,10 @@ -from . import data, metric, tools from . import simulate as simul from . import synthesize as synth -from .tools.data import load_images, to_numpy -from .tools.display import animshow, imshow, pyrshow +from . import metric +from . import tools +from . 
import data + +from .tools.display import imshow, animshow, pyrshow +from .tools.data import to_numpy, load_images + from .version import version as __version__ diff --git a/src/plenoptic/data/__init__.py b/src/plenoptic/data/__init__.py index fd974a06..b6527ec8 100644 --- a/src/plenoptic/data/__init__.py +++ b/src/plenoptic/data/__init__.py @@ -1,38 +1,28 @@ -import torch - from . import data_utils -from .fetch import DOWNLOADABLE_FILES, fetch_data - -__all__ = [ - "einstein", - "curie", - "parrot", - "reptile_skin", - "color_wheel", - "fetch_data", - "DOWNLOADABLE_FILES", -] - +from .fetch import fetch_data, DOWNLOADABLE_FILES +import torch +__all__ = ['einstein', 'curie', 'parrot', 'reptile_skin', + 'color_wheel', 'fetch_data', 'DOWNLOADABLE_FILES'] def __dir__(): return __all__ def einstein() -> torch.Tensor: - return data_utils.get("einstein") + return data_utils.get('einstein') def curie() -> torch.Tensor: - return data_utils.get("curie") + return data_utils.get('curie') def parrot(as_gray: bool = False) -> torch.Tensor: - return data_utils.get("parrot", as_gray=as_gray) + return data_utils.get('parrot', as_gray=as_gray) def reptile_skin() -> torch.Tensor: - return data_utils.get("reptile_skin") + return data_utils.get('reptile_skin') def color_wheel(as_gray: bool = False) -> torch.Tensor: - return data_utils.get("color_wheel", as_gray=as_gray) + return data_utils.get('color_wheel', as_gray=as_gray) diff --git a/src/plenoptic/data/data_utils.py b/src/plenoptic/data/data_utils.py index cfce7003..037baffa 100644 --- a/src/plenoptic/data/data_utils.py +++ b/src/plenoptic/data/data_utils.py @@ -1,5 +1,7 @@ from importlib import resources from importlib.abc import Traversable +from typing import Union + from ..tools.data import load_images @@ -28,18 +30,12 @@ def get_path(item_name: str) -> Traversable: This function uses glob to search for files in the current directory matching the `item_name`. It is assumed that there is only one file matching the name regardless of its extension. """ - fhs = [ - file - for file in resources.files("plenoptic.data").iterdir() - if file.stem == item_name - ] - assert ( - len(fhs) == 1 - ), f"Expected exactly one file for {item_name}, but found {len(fhs)}." + fhs = [file for file in resources.files("plenoptic.data").iterdir() if file.stem == item_name] + assert len(fhs) == 1, f"Expected exactly one file for {item_name}, but found {len(fhs)}." return fhs[0] -def get(*item_names: str, as_gray: None | bool = None): +def get(*item_names: str, as_gray: Union[None, bool] = None): """Load an image based on the item name from the package's data resources. 
Parameters diff --git a/src/plenoptic/data/fetch.py b/src/plenoptic/data/fetch.py index 905f99a6..3606f644 100644 --- a/src/plenoptic/data/fetch.py +++ b/src/plenoptic/data/fetch.py @@ -5,64 +5,54 @@ """ REGISTRY = { - "plenoptic-test-files.tar.gz": "a6b8e03ecc8d7e40c505c88e6c767af5da670478d3bebb4e13a9d08ee4f39ae8", - "ssim_images.tar.gz": "19c1955921a3c37d30c88724fd5a13bdbc9620c9e7dfaeaa3ff835283d2bb42e", - "ssim_analysis.mat": "921d324783f06d1a2e6f1ce154c7ba9204f91c569772936991311ff299597f24", - "msssim_images.tar.gz": "a01273c95c231ba9e860dfc48f2ac8044ac3db13ad7061739c29ea5f9f20382c", - "MAD_results.tar.gz": "29794ed7dc14626f115b9e4173bff88884cb356378a1d4f1f6cd940dd5b31dbe", - "portilla_simoncelli_matlab_test_vectors.tar.gz": "83087d4d9808a3935b8eb4197624bbae19007189cd0d786527084c98b0b0ab81", - "portilla_simoncelli_test_vectors.tar.gz": "d67787620a0cf13addbe4588ec05b276105ff1fad46e72f8c58d99f184386dfb", - "portilla_simoncelli_images.tar.gz": "4d3228fbb51de45b4fc81eba590d20f5861a54d9e46766c8431ab08326e80827", - "portilla_simoncelli_synthesize.npz": "9c304580cd60e0275a2ef663187eccb71f2e6a31883c88acf4c6a699f4854c80", - "portilla_simoncelli_synthesize_torch_v1.12.0.npz": "5a76ef223bac641c9d48a0b7f49b3ce0a05c12a48e96cd309866b1e7d5e4473f", - "portilla_simoncelli_synthesize_gpu.npz": "324efc2a6c54382aae414d361c099394227b56cd24460eebab2532f70728c3ee", - "portilla_simoncelli_scales.npz": "eae2db6bd5db7d37c28d8f8320c4dd4fa5ab38294f5be22f8cf69e5cd5e4936a", - "sample_images.tar.gz": "0ba6fe668a61e9f3cb52032da740fbcf32399ffcc142ddb14380a8e404409bf5", - "test_images.tar.gz": "eaf35f5f6136e2d51e513f00202a11188a85cae8c6f44141fb9666de25ae9554", - "tid2013.tar.gz": "bc486ac749b6cfca8dc5f5340b04b9bb01ab24149a5f3a712f13e9d0489dcde0", - "portilla_simoncelli_test_vectors_refactor.tar.gz": "2ca60f1a60b192668567eb3d94c0cdc8679b23bf94a98331890c41eb9406503a", - "portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz": "9525844b71cf81509b86ed9677172745353588c6bb54e4de8000d695598afa47", - "portilla_simoncelli_synthesize_gpu_ps-refactor.npz": "9fbb490f1548133f6aa49c54832130cf70f8dc6546af59688ead17f62ab94e61", - "portilla_simoncelli_scales_ps-refactor.npz": "1053790a37707627810482beb0dd059cbc193efd688b60c441b3548a75626fdf", + 'plenoptic-test-files.tar.gz': 'a6b8e03ecc8d7e40c505c88e6c767af5da670478d3bebb4e13a9d08ee4f39ae8', + 'ssim_images.tar.gz': '19c1955921a3c37d30c88724fd5a13bdbc9620c9e7dfaeaa3ff835283d2bb42e', + 'ssim_analysis.mat': '921d324783f06d1a2e6f1ce154c7ba9204f91c569772936991311ff299597f24', + 'msssim_images.tar.gz': 'a01273c95c231ba9e860dfc48f2ac8044ac3db13ad7061739c29ea5f9f20382c', + 'MAD_results.tar.gz': '29794ed7dc14626f115b9e4173bff88884cb356378a1d4f1f6cd940dd5b31dbe', + 'portilla_simoncelli_matlab_test_vectors.tar.gz': '83087d4d9808a3935b8eb4197624bbae19007189cd0d786527084c98b0b0ab81', + 'portilla_simoncelli_test_vectors.tar.gz': 'd67787620a0cf13addbe4588ec05b276105ff1fad46e72f8c58d99f184386dfb', + 'portilla_simoncelli_images.tar.gz': '4d3228fbb51de45b4fc81eba590d20f5861a54d9e46766c8431ab08326e80827', + 'portilla_simoncelli_synthesize.npz': '9c304580cd60e0275a2ef663187eccb71f2e6a31883c88acf4c6a699f4854c80', + 'portilla_simoncelli_synthesize_torch_v1.12.0.npz': '5a76ef223bac641c9d48a0b7f49b3ce0a05c12a48e96cd309866b1e7d5e4473f', + 'portilla_simoncelli_synthesize_gpu.npz': '324efc2a6c54382aae414d361c099394227b56cd24460eebab2532f70728c3ee', + 'portilla_simoncelli_scales.npz': 'eae2db6bd5db7d37c28d8f8320c4dd4fa5ab38294f5be22f8cf69e5cd5e4936a', + 'sample_images.tar.gz': 
'0ba6fe668a61e9f3cb52032da740fbcf32399ffcc142ddb14380a8e404409bf5', + 'test_images.tar.gz': 'eaf35f5f6136e2d51e513f00202a11188a85cae8c6f44141fb9666de25ae9554', + 'tid2013.tar.gz': 'bc486ac749b6cfca8dc5f5340b04b9bb01ab24149a5f3a712f13e9d0489dcde0', + 'portilla_simoncelli_test_vectors_refactor.tar.gz': '2ca60f1a60b192668567eb3d94c0cdc8679b23bf94a98331890c41eb9406503a', + 'portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz': '9525844b71cf81509b86ed9677172745353588c6bb54e4de8000d695598afa47', + 'portilla_simoncelli_synthesize_gpu_ps-refactor.npz': '9fbb490f1548133f6aa49c54832130cf70f8dc6546af59688ead17f62ab94e61', + 'portilla_simoncelli_scales_ps-refactor.npz': '1053790a37707627810482beb0dd059cbc193efd688b60c441b3548a75626fdf', } OSF_TEMPLATE = "https://osf.io/{}/download" # these are all from the OSF project at https://osf.io/ts37w/. REGISTRY_URLS = { - "plenoptic-test-files.tar.gz": OSF_TEMPLATE.format("q9kn8"), - "ssim_images.tar.gz": OSF_TEMPLATE.format("j65tw"), - "ssim_analysis.mat": OSF_TEMPLATE.format("ndtc7"), - "msssim_images.tar.gz": OSF_TEMPLATE.format("5fuba"), - "MAD_results.tar.gz": OSF_TEMPLATE.format("jwcsr"), - "portilla_simoncelli_matlab_test_vectors.tar.gz": OSF_TEMPLATE.format( - "qtn5y" - ), - "portilla_simoncelli_test_vectors.tar.gz": OSF_TEMPLATE.format("8r2gq"), - "portilla_simoncelli_images.tar.gz": OSF_TEMPLATE.format("eqr3t"), - "portilla_simoncelli_synthesize.npz": OSF_TEMPLATE.format("a7p9r"), - "portilla_simoncelli_synthesize_torch_v1.12.0.npz": OSF_TEMPLATE.format( - "gbv8e" - ), - "portilla_simoncelli_synthesize_gpu.npz": OSF_TEMPLATE.format("tn4y8"), - "portilla_simoncelli_scales.npz": OSF_TEMPLATE.format("xhwv3"), - "sample_images.tar.gz": OSF_TEMPLATE.format("6drmy"), - "test_images.tar.gz": OSF_TEMPLATE.format("au3b8"), - "tid2013.tar.gz": OSF_TEMPLATE.format("uscgv"), - "portilla_simoncelli_test_vectors_refactor.tar.gz": OSF_TEMPLATE.format( - "ca7qt" - ), - "portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz": OSF_TEMPLATE.format( - "vmwzd" - ), - "portilla_simoncelli_synthesize_gpu_ps-refactor.npz": OSF_TEMPLATE.format( - "mqs6y" - ), - "portilla_simoncelli_scales_ps-refactor.npz": OSF_TEMPLATE.format("nvpr4"), + 'plenoptic-test-files.tar.gz': OSF_TEMPLATE.format('q9kn8'), + 'ssim_images.tar.gz': OSF_TEMPLATE.format('j65tw'), + 'ssim_analysis.mat': OSF_TEMPLATE.format('ndtc7'), + 'msssim_images.tar.gz': OSF_TEMPLATE.format('5fuba'), + 'MAD_results.tar.gz': OSF_TEMPLATE.format('jwcsr'), + 'portilla_simoncelli_matlab_test_vectors.tar.gz': OSF_TEMPLATE.format('qtn5y'), + 'portilla_simoncelli_test_vectors.tar.gz': OSF_TEMPLATE.format('8r2gq'), + 'portilla_simoncelli_images.tar.gz': OSF_TEMPLATE.format('eqr3t'), + 'portilla_simoncelli_synthesize.npz': OSF_TEMPLATE.format('a7p9r'), + 'portilla_simoncelli_synthesize_torch_v1.12.0.npz': OSF_TEMPLATE.format('gbv8e'), + 'portilla_simoncelli_synthesize_gpu.npz': OSF_TEMPLATE.format('tn4y8'), + 'portilla_simoncelli_scales.npz': OSF_TEMPLATE.format('xhwv3'), + 'sample_images.tar.gz': OSF_TEMPLATE.format('6drmy'), + 'test_images.tar.gz': OSF_TEMPLATE.format('au3b8'), + 'tid2013.tar.gz': OSF_TEMPLATE.format('uscgv'), + 'portilla_simoncelli_test_vectors_refactor.tar.gz': OSF_TEMPLATE.format('ca7qt'), + 'portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz': OSF_TEMPLATE.format('vmwzd'), + 'portilla_simoncelli_synthesize_gpu_ps-refactor.npz': OSF_TEMPLATE.format('mqs6y'), + 'portilla_simoncelli_scales_ps-refactor.npz': OSF_TEMPLATE.format('nvpr4'), } DOWNLOADABLE_FILES = 
list(REGISTRY_URLS.keys()) import pathlib - +from typing import List try: import pooch except ImportError: @@ -73,7 +63,7 @@ # Use the default cache folder for the operating system # Pooch uses appdirs (https://github.com/ActiveState/appdirs) to # select an appropriate directory for the cache on each platform. - path=pooch.os_cache("plenoptic"), + path=pooch.os_cache('plenoptic'), base_url="", urls=REGISTRY_URLS, registry=REGISTRY, @@ -82,7 +72,7 @@ ) -def find_shared_directory(paths: list[pathlib.Path]) -> pathlib.Path: +def find_shared_directory(paths: List[pathlib.Path]) -> pathlib.Path: """Find directory shared by all paths.""" for dir in paths[0].parents: if all([dir in p.parents for p in paths]): @@ -102,19 +92,17 @@ def fetch_data(dataset_name: str) -> pathlib.Path: """ if retriever is None: - raise ImportError( - "Missing optional dependency 'pooch'." - " Please use pip or " - "conda to install 'pooch'." - ) - if dataset_name.endswith(".tar.gz"): + raise ImportError("Missing optional dependency 'pooch'." + " Please use pip or " + "conda to install 'pooch'.") + if dataset_name.endswith('.tar.gz'): processor = pooch.Untar() else: processor = None - fname = retriever.fetch( - dataset_name, progressbar=True, processor=processor - ) - if dataset_name.endswith(".tar.gz"): + fname = retriever.fetch(dataset_name, + progressbar=True, + processor=processor) + if dataset_name.endswith('.tar.gz'): fname = find_shared_directory([pathlib.Path(f) for f in fname]) else: fname = pathlib.Path(fname) diff --git a/src/plenoptic/metric/__init__.py b/src/plenoptic/metric/__init__.py index 5e4c47e4..6f4e6f5e 100644 --- a/src/plenoptic/metric/__init__.py +++ b/src/plenoptic/metric/__init__.py @@ -1,4 +1,4 @@ -from .classes import NLP +from .perceptual_distance import ssim, ms_ssim, nlpd, ssim_map from .model_metric import model_metric from .naive import mse -from .perceptual_distance import ms_ssim, nlpd, ssim, ssim_map +from .classes import NLP diff --git a/src/plenoptic/metric/classes.py b/src/plenoptic/metric/classes.py index 52206cde..6bc83860 100644 --- a/src/plenoptic/metric/classes.py +++ b/src/plenoptic/metric/classes.py @@ -1,5 +1,4 @@ import torch - from .perceptual_distance import normalized_laplacian_pyramid @@ -16,7 +15,6 @@ class NLP(torch.nn.Module): ``torch.sqrt(torch.mean(x-y)**2))`` as the distance metric between representations. """ - def __init__(self): super().__init__() @@ -38,16 +36,10 @@ def forward(self, image): """ if image.shape[0] > 1 or image.shape[1] > 1: - raise Exception( - "For now, this only supports batch and channel size 1" - ) + raise Exception("For now, this only supports batch and channel size 1") activations = normalized_laplacian_pyramid(image) # activations is a list of tensors, each at a different scale # (down-sampled by factors of 2). 
To combine these into one # vector, we need to flatten each of them and then unsqueeze so # it is 3d - return ( - torch.cat([i.flatten() for i in activations]) - .unsqueeze(0) - .unsqueeze(0) - ) + return torch.cat([i.flatten() for i in activations]).unsqueeze(0).unsqueeze(0) diff --git a/src/plenoptic/metric/perceptual_distance.py b/src/plenoptic/metric/perceptual_distance.py index efeb9515..f70fd003 100644 --- a/src/plenoptic/metric/perceptual_distance.py +++ b/src/plenoptic/metric/perceptual_distance.py @@ -1,14 +1,15 @@ -import os -import warnings - import numpy as np import torch import torch.nn.functional as F +import warnings from ..simulate.canonical_computations import LaplacianPyramid from ..simulate.canonical_computations.filters import circular_gaussian2d from ..tools.conv import same_padding +import os +import pickle + DIRNAME = os.path.dirname(__file__) @@ -36,39 +37,25 @@ def _ssim_parts(img1, img2, pad=False): these work. """ - img_ranges = torch.as_tensor( - [[img1.min(), img1.max()], [img2.min(), img2.max()]] - ) + img_ranges = torch.as_tensor([[img1.min(), img1.max()], [img2.min(), img2.max()]]) if (img_ranges > 1).any() or (img_ranges < 0).any(): - warnings.warn( - "Image range falls outside [0, 1]." - f" img1: {img_ranges[0]}, img2: {img_ranges[1]}. " - "Continuing anyway..." - ) + warnings.warn("Image range falls outside [0, 1]." + f" img1: {img_ranges[0]}, img2: {img_ranges[1]}. " + "Continuing anyway...") if not img1.ndim == img2.ndim == 4: - raise Exception( - "Input images should have four dimensions: (batch, channel, height, width)" - ) + raise Exception("Input images should have four dimensions: (batch, channel, height, width)") if img1.shape[-2:] != img2.shape[-2:]: raise Exception("img1 and img2 must have the same height and width!") for i in range(2): - if ( - img1.shape[i] != img2.shape[i] - and img1.shape[i] != 1 - and img2.shape[i] != 1 - ): - raise Exception( - "Either img1 and img2 should have the same number of " - "elements in each dimension, or one of " - "them should be 1! But got shapes " - f"{img1.shape}, {img2.shape} instead" - ) + if img1.shape[i] != img2.shape[i] and img1.shape[i] != 1 and img2.shape[i] != 1: + raise Exception("Either img1 and img2 should have the same number of " + "elements in each dimension, or one of " + "them should be 1! But got shapes " + f"{img1.shape}, {img2.shape} instead") if img1.shape[1] > 1 or img2.shape[1] > 1: - warnings.warn( - "SSIM was designed for grayscale images and here it will be computed separately for each " - "channel (so channels are treated in the same way as batches)." 
- ) + warnings.warn("SSIM was designed for grayscale images and here it will be computed separately for each " + "channel (so channels are treated in the same way as batches).") if img1.dtype != img2.dtype: raise ValueError("Input images must have same dtype!") @@ -92,13 +79,9 @@ def _ssim_parts(img1, img2, pad=False): def windowed_average(img): padd = 0 (n_batches, n_channels, _, _) = img.shape - img = img.reshape( - n_batches * n_channels, 1, img.shape[2], img.shape[3] - ) + img = img.reshape(n_batches * n_channels, 1, img.shape[2], img.shape[3]) img_average = F.conv2d(img, window, padding=padd) - img_average = img_average.reshape( - n_batches, n_channels, img_average.shape[2], img_average.shape[3] - ) + img_average = img_average.reshape(n_batches, n_channels, img_average.shape[2], img_average.shape[3]) return img_average mu1 = windowed_average(img1) @@ -112,20 +95,18 @@ def windowed_average(img): sigma2_sq = windowed_average(img2 * img2) - mu2_sq sigma12 = windowed_average(img1 * img2) - mu1_mu2 - C1 = 0.01**2 - C2 = 0.03**2 + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 # SSIM is the product of a luminance component, a contrast component, and a # structure component. The contrast-structure component has to be separated # when computing MS-SSIM. luminance_map = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1) - contrast_structure_map = (2.0 * sigma12 + C2) / ( - sigma1_sq + sigma2_sq + C2 - ) + contrast_structure_map = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) map_ssim = luminance_map * contrast_structure_map # the weight used for stability - weight = torch.log((1 + sigma1_sq / C2) * (1 + sigma2_sq / C2)) + weight = torch.log((1 + sigma1_sq/C2) * (1 + sigma2_sq/C2)) return map_ssim, contrast_structure_map, weight @@ -209,14 +190,12 @@ def ssim(img1, img2, weighted=False, pad=False): if not weighted: mssim = map_ssim.mean((-1, -2)) else: - mssim = (map_ssim * weight).sum((-1, -2)) / weight.sum((-1, -2)) + mssim = (map_ssim*weight).sum((-1, -2)) / weight.sum((-1, -2)) if min(img1.shape[2], img1.shape[3]) < 11: - warnings.warn( - "SSIM uses 11x11 convolutional kernel, but the height and/or " - "the width of the input image is smaller than 11, so the " - "kernel size is set to be the minimum of these two numbers." - ) + warnings.warn("SSIM uses 11x11 convolutional kernel, but the height and/or " + "the width of the input image is smaller than 11, so the " + "kernel size is set to be the minimum of these two numbers.") return mssim @@ -278,11 +257,9 @@ def ssim_map(img1, img2): """ if min(img1.shape[2], img1.shape[3]) < 11: - warnings.warn( - "SSIM uses 11x11 convolutional kernel, but the height and/or " - "the width of the input image is smaller than 11, so the " - "kernel size is set to be the minimum of these two numbers." 
- ) + warnings.warn("SSIM uses 11x11 convolutional kernel, but the height and/or " + "the width of the input image is smaller than 11, so the " + "kernel size is set to be the minimum of these two numbers.") return _ssim_parts(img1, img2)[0] @@ -349,30 +326,24 @@ def ms_ssim(img1, img2, power_factors=None): power_factors = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] def downsample(img): - img = F.pad( - img, (0, img.shape[3] % 2, 0, img.shape[2] % 2), mode="replicate" - ) + img = F.pad(img, (0, img.shape[3] % 2, 0, img.shape[2] % 2), mode="replicate") img = F.avg_pool2d(img, kernel_size=2) return img msssim = 1 for i in range(len(power_factors) - 1): _, contrast_structure_map, _ = _ssim_parts(img1, img2) - msssim *= F.relu(contrast_structure_map.mean((-1, -2))).pow( - power_factors[i] - ) + msssim *= F.relu(contrast_structure_map.mean((-1, -2))).pow(power_factors[i]) img1 = downsample(img1) img2 = downsample(img2) map_ssim, _, _ = _ssim_parts(img1, img2) msssim *= F.relu(map_ssim.mean((-1, -2))).pow(power_factors[-1]) if min(img1.shape[2], img1.shape[3]) < 11: - warnings.warn( - "SSIM uses 11x11 convolutional kernel, but for some scales " - "of the input image, the height and/or the width is smaller " - "than 11, so the kernel size in SSIM is set to be the " - "minimum of these two numbers for these scales." - ) + warnings.warn("SSIM uses 11x11 convolutional kernel, but for some scales " + "of the input image, the height and/or the width is smaller " + "than 11, so the kernel size in SSIM is set to be the " + "minimum of these two numbers for these scales.") return msssim @@ -395,8 +366,8 @@ def normalized_laplacian_pyramid(img): (_, channel, height, width) = img.size() N_scales = 6 - spatialpooling_filters = np.load(os.path.join(DIRNAME, "DN_filts.npy")) - sigmas = np.load(os.path.join(DIRNAME, "DN_sigmas.npy")) + spatialpooling_filters = np.load(os.path.join(DIRNAME, 'DN_filts.npy')) + sigmas = np.load(os.path.join(DIRNAME, 'DN_sigmas.npy')) L = LaplacianPyramid(n_scales=N_scales, scale_filter=True) laplacian_activations = L.forward(img) @@ -404,18 +375,10 @@ def normalized_laplacian_pyramid(img): padd = 2 normalized_laplacian_activations = [] for N_b in range(0, N_scales): - filt = torch.as_tensor( - spatialpooling_filters[N_b], dtype=torch.float32, device=img.device - ).repeat(channel, 1, 1, 1) - filtered_activations = F.conv2d( - torch.abs(laplacian_activations[N_b]), - filt, - padding=padd, - groups=channel, - ) - normalized_laplacian_activations.append( - laplacian_activations[N_b] / (sigmas[N_b] + filtered_activations) - ) + filt = torch.as_tensor(spatialpooling_filters[N_b], dtype=torch.float32, + device=img.device).repeat(channel, 1, 1, 1) + filtered_activations = F.conv2d(torch.abs(laplacian_activations[N_b]), filt, padding=padd, groups=channel) + normalized_laplacian_activations.append(laplacian_activations[N_b] / (sigmas[N_b] + filtered_activations)) return normalized_laplacian_activations @@ -462,47 +425,31 @@ def nlpd(img1, img2): """ if not img1.ndim == img2.ndim == 4: - raise Exception( - "Input images should have four dimensions: (batch, channel, height, width)" - ) + raise Exception("Input images should have four dimensions: (batch, channel, height, width)") if img1.shape[-2:] != img2.shape[-2:]: raise Exception("img1 and img2 must have the same height and width!") for i in range(2): - if ( - img1.shape[i] != img2.shape[i] - and img1.shape[i] != 1 - and img2.shape[i] != 1 - ): - raise Exception( - "Either img1 and img2 should have the same number of " - "elements in each 
dimension, or one of " - "them should be 1! But got shapes " - f"{img1.shape}, {img2.shape} instead" - ) + if img1.shape[i] != img2.shape[i] and img1.shape[i] != 1 and img2.shape[i] != 1: + raise Exception("Either img1 and img2 should have the same number of " + "elements in each dimension, or one of " + "them should be 1! But got shapes " + f"{img1.shape}, {img2.shape} instead") if img1.shape[1] > 1 or img2.shape[1] > 1: - warnings.warn( - "NLPD was designed for grayscale images and here it will be computed separately for each " - "channel (so channels are treated in the same way as batches)." - ) - - img_ranges = torch.as_tensor( - [[img1.min(), img1.max()], [img2.min(), img2.max()]] - ) + warnings.warn("NLPD was designed for grayscale images and here it will be computed separately for each " + "channel (so channels are treated in the same way as batches).") + + img_ranges = torch.as_tensor([[img1.min(), img1.max()], [img2.min(), img2.max()]]) if (img_ranges > 1).any() or (img_ranges < 0).any(): - warnings.warn( - "Image range falls outside [0, 1]." - f" img1: {img_ranges[0]}, img2: {img_ranges[1]}. " - "Continuing anyway..." - ) - + warnings.warn("Image range falls outside [0, 1]." + f" img1: {img_ranges[0]}, img2: {img_ranges[1]}. " + "Continuing anyway...") + y1 = normalized_laplacian_pyramid(img1) y2 = normalized_laplacian_pyramid(img2) epsilon = 1e-10 # for optimization purpose (stabilizing the gradient around zero) dist = [] for i in range(6): - dist.append( - torch.sqrt(torch.mean((y1[i] - y2[i]) ** 2, dim=(2, 3)) + epsilon) - ) + dist.append(torch.sqrt(torch.mean((y1[i] - y2[i]) ** 2, dim=(2, 3)) + epsilon)) return torch.stack(dist).mean(dim=0) diff --git a/src/plenoptic/simulate/__init__.py b/src/plenoptic/simulate/__init__.py index c82eb526..9659b0ce 100644 --- a/src/plenoptic/simulate/__init__.py +++ b/src/plenoptic/simulate/__init__.py @@ -1,2 +1,2 @@ -from .canonical_computations import * from .models import * +from .canonical_computations import * diff --git a/src/plenoptic/simulate/canonical_computations/__init__.py b/src/plenoptic/simulate/canonical_computations/__init__.py index 49d69cc4..b51ca84b 100644 --- a/src/plenoptic/simulate/canonical_computations/__init__.py +++ b/src/plenoptic/simulate/canonical_computations/__init__.py @@ -1,4 +1,4 @@ -from .filters import * from .laplacian_pyramid import LaplacianPyramid -from .non_linearities import * from .steerable_pyramid_freq import SteerablePyramidFreq +from .non_linearities import * +from .filters import * diff --git a/src/plenoptic/simulate/canonical_computations/filters.py b/src/plenoptic/simulate/canonical_computations/filters.py index d45c4568..098d7a79 100644 --- a/src/plenoptic/simulate/canonical_computations/filters.py +++ b/src/plenoptic/simulate/canonical_computations/filters.py @@ -1,10 +1,13 @@ +from typing import Union, Tuple + import torch from torch import Tensor +from warnings import warn __all__ = ["gaussian1d", "circular_gaussian2d"] -def gaussian1d(kernel_size: int = 11, std: float | Tensor = 1.5) -> Tensor: +def gaussian1d(kernel_size: int = 11, std: Union[float, Tensor] = 1.5) -> Tensor: """Normalized 1D Gaussian. 
1d Gaussian of size `kernel_size`, centered half-way, with variable std @@ -32,14 +35,14 @@ def gaussian1d(kernel_size: int = 11, std: float | Tensor = 1.5) -> Tensor: x = torch.arange(kernel_size).to(device) mu = kernel_size // 2 - gauss = torch.exp(-((x - mu) ** 2) / (2 * std**2)) + gauss = torch.exp(-((x - mu) ** 2) / (2 * std ** 2)) filt = gauss / gauss.sum() # normalize return filt def circular_gaussian2d( - kernel_size: int | tuple[int, int], - std: float | Tensor, + kernel_size: Union[int, Tuple[int, int]], + std: Union[float, Tensor], out_channels: int = 1, ) -> Tensor: """Creates normalized, centered circular 2D gaussian tensor with which to convolve. @@ -72,23 +75,17 @@ def circular_gaussian2d( assert out_channels >= 1, "number of filters must be positive integer" assert torch.all(std > 0.0), "stdev must be positive" assert len(std) == out_channels, "Number of stds must equal out_channels" - origin = torch.as_tensor( - ((kernel_size[0] + 1) / 2.0, (kernel_size[1] + 1) / 2.0) - ) + origin = torch.as_tensor(((kernel_size[0] + 1) / 2.0, (kernel_size[1] + 1) / 2.0)) origin = origin.to(device) - shift_y = ( - torch.arange(1, kernel_size[0] + 1, device=device) - origin[0] - ) # height - shift_x = ( - torch.arange(1, kernel_size[1] + 1, device=device) - origin[1] - ) # width + shift_y = torch.arange(1, kernel_size[0] + 1, device=device) - origin[0] # height + shift_x = torch.arange(1, kernel_size[1] + 1, device=device) - origin[1] # width (xramp, yramp) = torch.meshgrid(shift_y, shift_x) - log_filt = (xramp**2) + (yramp**2) + log_filt = ((xramp ** 2) + (yramp ** 2)) log_filt = log_filt.repeat(out_channels, 1, 1, 1) # 4D - log_filt = log_filt / (-2.0 * std**2).view(out_channels, 1, 1, 1) + log_filt = log_filt / (-2. * std ** 2).view(out_channels, 1, 1, 1) filt = torch.exp(log_filt) filt = filt / torch.sum(filt, dim=[1, 2, 3], keepdim=True) # normalize diff --git a/src/plenoptic/simulate/canonical_computations/laplacian_pyramid.py b/src/plenoptic/simulate/canonical_computations/laplacian_pyramid.py index 53fac227..d51e3955 100644 --- a/src/plenoptic/simulate/canonical_computations/laplacian_pyramid.py +++ b/src/plenoptic/simulate/canonical_computations/laplacian_pyramid.py @@ -1,12 +1,11 @@ import torch import torch.nn as nn - from ...tools.conv import blur_downsample, upsample_blur class LaplacianPyramid(nn.Module): """Laplacian Pyramid in Torch. - + The Laplacian pyramid [1]_ is a multiscale image representation. 
It decomposes the image by computing the local mean using Gaussian blurring filters and substracting it from the image and repeating this operation on diff --git a/src/plenoptic/simulate/canonical_computations/non_linearities.py b/src/plenoptic/simulate/canonical_computations/non_linearities.py index 839918c7..fec6a59c 100644 --- a/src/plenoptic/simulate/canonical_computations/non_linearities.py +++ b/src/plenoptic/simulate/canonical_computations/non_linearities.py @@ -1,7 +1,6 @@ import torch - from ...tools.conv import blur_downsample, upsample_blur -from ...tools.signal import polar_to_rectangular, rectangular_to_polar +from ...tools.signal import rectangular_to_polar, polar_to_rectangular def rectangular_to_polar_dict(coeff_dict, residuals=False): @@ -29,12 +28,12 @@ def rectangular_to_polar_dict(coeff_dict, residuals=False): state = {} for key in coeff_dict.keys(): # ignore residuals - if isinstance(key, tuple) or not key.startswith("residual"): + if isinstance(key, tuple) or not key.startswith('residual'): energy[key], state[key] = rectangular_to_polar(coeff_dict[key]) if residuals: - energy["residual_lowpass"] = coeff_dict["residual_lowpass"] - energy["residual_highpass"] = coeff_dict["residual_highpass"] + energy['residual_lowpass'] = coeff_dict['residual_lowpass'] + energy['residual_highpass'] = coeff_dict['residual_highpass'] return energy, state @@ -64,12 +63,12 @@ def polar_to_rectangular_dict(energy, state, residuals=True): for key in energy.keys(): # ignore residuals - if isinstance(key, tuple) or not key.startswith("residual"): + if isinstance(key, tuple) or not key.startswith('residual'): coeff_dict[key] = polar_to_rectangular(energy[key], state[key]) if residuals: - coeff_dict["residual_lowpass"] = energy["residual_lowpass"] - coeff_dict["residual_highpass"] = energy["residual_highpass"] + coeff_dict['residual_lowpass'] = energy['residual_lowpass'] + coeff_dict['residual_highpass'] = energy['residual_highpass'] return coeff_dict @@ -112,7 +111,7 @@ def local_gain_control(x, epsilon=1e-8): # these could be parameters, but no use case so far p = 2.0 - norm = blur_downsample(torch.abs(x**p)).pow(1 / p) + norm = blur_downsample(torch.abs(x ** p)).pow(1 / p) odd = torch.as_tensor(x.shape)[2:4] % 2 direction = x / (upsample_blur(norm, odd) + epsilon) @@ -191,12 +190,12 @@ def local_gain_control_dict(coeff_dict, residuals=True): state = {} for key in coeff_dict.keys(): - if isinstance(key, tuple) or not key.startswith("residual"): + if isinstance(key, tuple) or not key.startswith('residual'): energy[key], state[key] = local_gain_control(coeff_dict[key]) if residuals: - energy["residual_lowpass"] = coeff_dict["residual_lowpass"] - energy["residual_highpass"] = coeff_dict["residual_highpass"] + energy['residual_lowpass'] = coeff_dict['residual_lowpass'] + energy['residual_highpass'] = coeff_dict['residual_highpass'] return energy, state @@ -231,11 +230,11 @@ def local_gain_release_dict(energy, state, residuals=True): coeff_dict = {} for key in energy.keys(): - if isinstance(key, tuple) or not key.startswith("residual"): + if isinstance(key, tuple) or not key.startswith('residual'): coeff_dict[key] = local_gain_release(energy[key], state[key]) if residuals: - coeff_dict["residual_lowpass"] = energy["residual_lowpass"] - coeff_dict["residual_highpass"] = energy["residual_highpass"] + coeff_dict['residual_lowpass'] = energy['residual_lowpass'] + coeff_dict['residual_highpass'] = energy['residual_highpass'] return coeff_dict diff --git 
a/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py b/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py index 4b8fc189..5a6cf090 100644 --- a/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py +++ b/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py @@ -5,24 +5,23 @@ """ import warnings from collections import OrderedDict -from typing import Literal, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.fft as fft import torch.nn as nn from einops import rearrange -from numpy.typing import NDArray from scipy.special import factorial from torch import Tensor +from typing_extensions import Literal +from numpy.typing import NDArray from ...tools.signal import interpolate1d, raised_cosine, steer complex_types = [torch.cdouble, torch.cfloat] SCALES_TYPE = Union[int, Literal["residual_lowpass", "residual_highpass"]] -KEYS_TYPE = Union[ - tuple[int, int], Literal["residual_lowpass", "residual_highpass"] -] +KEYS_TYPE = Union[Tuple[int, int], Literal["residual_lowpass", "residual_highpass"]] class SteerablePyramidFreq(nn.Module): @@ -96,14 +95,15 @@ class SteerablePyramidFreq(nn.Module): def __init__( self, - image_shape: tuple[int, int], - height: Literal["auto"] | int = "auto", + image_shape: Tuple[int, int], + height: Union[Literal["auto"], int] = "auto", order: int = 3, twidth: int = 1, is_complex: bool = False, downsample: bool = True, tight_frame: bool = False, ): + super().__init__() self.pyr_size = OrderedDict() @@ -111,9 +111,7 @@ def __init__( self.image_shape = image_shape if (self.image_shape[0] % 2 != 0) or (self.image_shape[1] % 2 != 0): - warnings.warn( - "Reconstruction will not be perfect with odd-sized images" - ) + warnings.warn("Reconstruction will not be perfect with odd-sized images") self.is_complex = is_complex self.downsample = downsample @@ -131,16 +129,11 @@ def __init__( ) self.alpha = (self.Xcosn + np.pi) % (2 * np.pi) - np.pi - max_ht = ( - np.floor(np.log2(min(self.image_shape[0], self.image_shape[1]))) - - 2 - ) + max_ht = np.floor(np.log2(min(self.image_shape[0], self.image_shape[1]))) - 2 if height == "auto": self.num_scales = int(max_ht) elif height > max_ht: - raise ValueError( - "Cannot build pyramid higher than %d levels." % (max_ht) - ) + raise ValueError("Cannot build pyramid higher than %d levels." 
% (max_ht)) else: self.num_scales = int(height) @@ -158,8 +151,7 @@ def __init__( ctr = np.ceil((np.array(dims) + 0.5) / 2).astype(int) (xramp, yramp) = np.meshgrid( - np.linspace(-1, 1, dims[1] + 1)[:-1], - np.linspace(-1, 1, dims[0] + 1)[:-1], + np.linspace(-1, 1, dims[1] + 1)[:-1], np.linspace(-1, 1, dims[0] + 1)[:-1] ) self.angle = np.arctan2(yramp, xramp) @@ -168,9 +160,7 @@ def __init__( self.log_rad = np.log2(log_rad) # radial transition function (a raised cosine in log-frequency): - self.Xrcos, Yrcos = raised_cosine( - twidth, (-twidth / 2.0), np.array([0, 1]) - ) + self.Xrcos, Yrcos = raised_cosine(twidth, (-twidth / 2.0), np.array([0, 1])) self.Yrcos = np.sqrt(Yrcos) self.YIrcos = np.sqrt(1.0 - self.Yrcos**2) @@ -178,8 +168,9 @@ def __init__( # create low and high masks lo0mask = interpolate1d(self.log_rad, self.YIrcos, self.Xrcos) hi0mask = interpolate1d(self.log_rad, self.Yrcos, self.Xrcos) - self.register_buffer("lo0mask", torch.as_tensor(lo0mask).unsqueeze(0)) - self.register_buffer("hi0mask", torch.as_tensor(hi0mask).unsqueeze(0)) + self.register_buffer('lo0mask', torch.as_tensor(lo0mask).unsqueeze(0)) + self.register_buffer('hi0mask', torch.as_tensor(hi0mask).unsqueeze(0)) + # need a mock image to down-sample so that we correctly # construct the differently-sized masks @@ -208,10 +199,7 @@ def __init__( const = ( (2 ** (2 * self.order)) * (factorial(self.order, exact=True) ** 2) - / float( - self.num_orientations - * factorial(2 * self.order, exact=True) - ) + / float(self.num_orientations * factorial(2 * self.order, exact=True)) ) if self.is_complex: @@ -221,50 +209,32 @@ def __init__( * (np.cos(self.Xcosn) ** self.order) * (np.abs(self.alpha) < np.pi / 2.0).astype(int) ) - Ycosn_recon = ( - np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order - ) + Ycosn_recon = np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order else: - Ycosn_forward = ( - np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order - ) + Ycosn_forward = np.sqrt(const) * (np.cos(self.Xcosn)) ** self.order Ycosn_recon = Ycosn_forward himask = interpolate1d(log_rad, self.Yrcos, Xrcos) - self.register_buffer( - f"_himasks_scale_{i}", torch.as_tensor(himask).unsqueeze(0) - ) + self.register_buffer(f'_himasks_scale_{i}', torch.as_tensor(himask).unsqueeze(0)) anglemasks = [] anglemasks_recon = [] for b in range(self.num_orientations): anglemask = interpolate1d( - angle, - Ycosn_forward, - self.Xcosn + np.pi * b / self.num_orientations, + angle, Ycosn_forward, self.Xcosn + np.pi * b / self.num_orientations ) anglemask_recon = interpolate1d( - angle, - Ycosn_recon, - self.Xcosn + np.pi * b / self.num_orientations, + angle, Ycosn_recon, self.Xcosn + np.pi * b / self.num_orientations ) anglemasks.append(torch.as_tensor(anglemask).unsqueeze(0)) - anglemasks_recon.append( - torch.as_tensor(anglemask_recon).unsqueeze(0) - ) + anglemasks_recon.append(torch.as_tensor(anglemask_recon).unsqueeze(0)) - self.register_buffer( - f"_anglemasks_scale_{i}", torch.cat(anglemasks) - ) - self.register_buffer( - f"_anglemasks_recon_scale_{i}", torch.cat(anglemasks_recon) - ) + self.register_buffer(f'_anglemasks_scale_{i}', torch.cat(anglemasks)) + self.register_buffer(f'_anglemasks_recon_scale_{i}', torch.cat(anglemasks_recon)) if not self.downsample: lomask = interpolate1d(log_rad, self.YIrcos, Xrcos) - self.register_buffer( - f"_lomasks_scale_{i}", torch.as_tensor(lomask).unsqueeze(0) - ) + self.register_buffer(f'_lomasks_scale_{i}', torch.as_tensor(lomask).unsqueeze(0)) self._loindices.append([np.array([0, 0]), dims]) lodft = lodft * 
lomask @@ -283,9 +253,7 @@ def __init__( angle = angle[lostart[0] : loend[0], lostart[1] : loend[1]] lomask = interpolate1d(log_rad, self.YIrcos, Xrcos) - self.register_buffer( - f"_lomasks_scale_{i}", torch.as_tensor(lomask).unsqueeze(0) - ) + self.register_buffer(f'_lomasks_scale_{i}', torch.as_tensor(lomask).unsqueeze(0)) # subsampling lodft = lodft[lostart[0] : loend[0], lostart[1] : loend[1]] # convolution in spatial domain @@ -297,7 +265,7 @@ def __init__( def forward( self, x: Tensor, - scales: list[SCALES_TYPE] | None = None, + scales: Optional[List[SCALES_TYPE]] = None, ) -> OrderedDict: r"""Generate the steerable pyramid coefficients for an image @@ -337,9 +305,7 @@ def forward( # x is a torch tensor batch of images of size (batch, channel, height, # width) - assert ( - len(x.shape) == 4 - ), "Input must be batch of images of shape BxCxHxW" + assert len(x.shape) == 4, "Input must be batch of images of shape BxCxHxW" imdft = fft.fft2(x, dim=(-2, -1), norm=self.fft_norm) imdft = fft.fftshift(imdft) @@ -356,18 +322,20 @@ def forward( lodft = imdft * lo0mask for i in range(self.num_scales): + if i in scales: # high-pass mask is selected based on the current scale - himask = getattr(self, f"_himasks_scale_{i}") + himask = getattr(self, f'_himasks_scale_{i}') # compute filter output at each orientation for b in range(self.num_orientations): + # band pass filtering is done in the fourier space as multiplying by the fft of a gaussian derivative. # The oriented dft is computed as a product of the fft of the low-passed component, # the precomputed anglemask (specifies orientation), and the precomputed hipass mask (creating a bandpass filter) # the complex_const variable comes from the Fourier transform of a gaussian derivative. # Based on the order of the gaussian, this constant changes. - anglemask = getattr(self, f"_anglemasks_scale_{i}")[b] + anglemask = getattr(self, f'_anglemasks_scale_{i}')[b] complex_const = np.power(complex(0, -1), self.order) banddft = complex_const * lodft * anglemask * himask @@ -380,6 +348,7 @@ def forward( if not self.is_complex: pyr_coeffs[(i, b)] = band.real else: + # Because the input signal is real, to maintain a tight frame # if the complex pyramid is used, magnitudes need to be divided by sqrt(2) # because energy is doubled. @@ -392,7 +361,7 @@ def forward( if not self.downsample: # no subsampling of angle and rad # just use lo0mask - lomask = getattr(self, f"_lomasks_scale_{i}") + lomask = getattr(self, f'_lomasks_scale_{i}') lodft = lodft * lomask # because we don't subsample here, if we are not using orthonormalization that @@ -409,11 +378,9 @@ def forward( angle = angle[lostart[0] : loend[0], lostart[1] : loend[1]] # subsampling of the dft for next scale - lodft = lodft[ - :, :, lostart[0] : loend[0], lostart[1] : loend[1] - ] + lodft = lodft[:, :, lostart[0] : loend[0], lostart[1] : loend[1]] # low-pass filter mask is selected - lomask = getattr(self, f"_lomasks_scale_{i}") + lomask = getattr(self, f'_lomasks_scale_{i}') # again multiply dft by subsampled mask (convolution in spatial domain) lodft = lodft * lomask @@ -430,7 +397,7 @@ def forward( @staticmethod def convert_pyr_to_tensor( pyr_coeffs: OrderedDict, split_complex: bool = False - ) -> tuple[Tensor, tuple[int, bool, list[KEYS_TYPE]]]: + ) -> Tuple[Tensor, Tuple[int, bool, List[KEYS_TYPE]]]: r"""Convert coefficient dictionary to a tensor. 
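A sketch of the round trip these conversion helpers support, using only signatures visible in this file (image size and pyramid options are arbitrary; the import path follows the steerable pyramid notebook):

    import torch
    from plenoptic.simulate import SteerablePyramidFreq

    img = torch.rand(1, 1, 256, 256)  # (batch, channel, height, width)
    # downsample=False keeps every scale at the input resolution, which is
    # what lets convert_pyr_to_tensor concatenate coefficients channel-wise
    pyr = SteerablePyramidFreq(img.shape[-2:], height="auto", order=3,
                               downsample=False)
    coeffs = pyr(img)  # OrderedDict: (scale, orientation) keys plus residuals
    pyr_tensor, pyr_info = SteerablePyramidFreq.convert_pyr_to_tensor(coeffs)
    coeffs_again = SteerablePyramidFreq.convert_tensor_to_pyr(pyr_tensor, *pyr_info)
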
The output tensor has shape (batch, channel, height, width) and is @@ -506,10 +473,10 @@ def convert_pyr_to_tensor( try: pyr_tensor = torch.cat(coeff_list, dim=1) pyr_info = tuple([num_channels, split_complex, pyr_keys]) - except RuntimeError: + except RuntimeError as e: raise Exception( - """feature maps could not be concatenated into tensor. - Check that you are using coefficients that are not downsampled across scales. + """feature maps could not be concatenated into tensor. + Check that you are using coefficients that are not downsampled across scales. This is done with the 'downsample=False' argument for the pyramid""" ) @@ -520,7 +487,7 @@ def convert_tensor_to_pyr( pyr_tensor: Tensor, num_channels: int, split_complex: bool, - pyr_keys: list[KEYS_TYPE], + pyr_keys: List[KEYS_TYPE], ) -> OrderedDict: r"""Convert pyramid coefficient tensor to dictionary format. @@ -571,8 +538,7 @@ def convert_tensor_to_pyr( if split_complex: band = torch.view_as_complex( rearrange( - pyr_tensor[:, i : i + 2, ...], - "b c h w -> b h w c", + pyr_tensor[:, i : i + 2, ...], "b c h w -> b h w c" ) .unsqueeze(1) .contiguous() @@ -589,8 +555,8 @@ def convert_tensor_to_pyr( return pyr_coeffs def _recon_levels_check( - self, levels: Literal["all"] | list[SCALES_TYPE] - ) -> list[SCALES_TYPE]: + self, levels: Union[Literal["all"], List[SCALES_TYPE]] + ) -> List[SCALES_TYPE]: r"""Check whether levels arg is valid for reconstruction and return valid version When reconstructing the input image (i.e., when calling `recon_pyr()`), @@ -615,9 +581,7 @@ def _recon_levels_check( """ if isinstance(levels, str): if levels != "all": - raise TypeError( - f"levels must be a list of levels or the string 'all' but got {levels}" - ) + raise TypeError(f"levels must be a list of levels or the string 'all' but got {levels}") levels = ( ["residual_highpass"] + list(range(self.num_scales)) @@ -625,18 +589,15 @@ def _recon_levels_check( ) else: if not hasattr(levels, "__iter__"): - raise TypeError( - f"levels must be a list of levels or the string 'all' but got {levels}" - ) + raise TypeError(f"levels must be a list of levels or the string 'all' but got {levels}") levs_nums = np.array( [int(i) for i in levels if isinstance(i, int)] ) + assert (levs_nums >= 0).all(), "Level numbers must be non-negative." assert ( - levs_nums >= 0 - ).all(), "Level numbers must be non-negative." 
- assert (levs_nums < self.num_scales).all(), ( - "Level numbers must be in the range [0, %d]" - % (self.num_scales - 1) + levs_nums < self.num_scales + ).all(), "Level numbers must be in the range [0, %d]" % ( + self.num_scales - 1 ) levs_tmp = list(np.sort(levs_nums)) # we want smallest first if "residual_highpass" in levels: @@ -659,8 +620,8 @@ def _recon_levels_check( return levels def _recon_bands_check( - self, bands: Literal["all"] | list[int] - ) -> list[int]: + self, bands: Union[Literal["all"], List[int]] + ) -> List[int]: """Check whether bands arg is valid for reconstruction and return valid version When reconstructing the input image (i.e., when calling `recon_pyr()`), the user specifies @@ -683,31 +644,26 @@ def _recon_bands_check( """ if isinstance(bands, str): if bands != "all": - raise TypeError( - f"bands must be a list of ints or the string 'all' but got {bands}" - ) + raise TypeError(f"bands must be a list of ints or the string 'all' but got {bands}") bands = np.arange(self.num_orientations) else: if not hasattr(bands, "__iter__"): - raise TypeError( - f"bands must be a list of ints or the string 'all' but got {bands}" - ) + raise TypeError(f"bands must be a list of ints or the string 'all' but got {bands}") bands: NDArray = np.array(bands, ndmin=1) + assert (bands >= 0).all(), "Error: band numbers must be larger than 0." assert ( - bands >= 0 - ).all(), "Error: band numbers must be larger than 0." - assert (bands < self.num_orientations).all(), ( - "Error: band numbers must be in the range [0, %d]" - % (self.num_orientations - 1) + bands < self.num_orientations + ).all(), "Error: band numbers must be in the range [0, %d]" % ( + self.num_orientations - 1 ) return list(bands) def _recon_keys( self, - levels: Literal["all"] | list[SCALES_TYPE], - bands: Literal["all"] | list[int], - max_orientations: int | None = None, - ) -> list[KEYS_TYPE]: + levels: Union[Literal["all"], List[SCALES_TYPE]], + bands: Union[Literal["all"], List[int]], + max_orientations: Optional[int] = None, + ) -> List[KEYS_TYPE]: """Make a list of all the relevant keys from `pyr_coeffs` to use in pyramid reconstruction When reconstructing the input image (i.e., when calling `recon_pyr()`), @@ -745,9 +701,11 @@ def _recon_keys( for i in bands: if i >= max_orientations: warnings.warn( - "You wanted band %d in the reconstruction but max_orientation" - " is %d, so we're ignoring that band" - % (i, max_orientations) + ( + "You wanted band %d in the reconstruction but max_orientation" + " is %d, so we're ignoring that band" + % (i, max_orientations) + ) ) bands = [i for i in bands if i < max_orientations] recon_keys = [] @@ -764,8 +722,8 @@ def _recon_keys( def recon_pyr( self, pyr_coeffs: OrderedDict, - levels: Literal["all"] | list[SCALES_TYPE] = "all", - bands: Literal["all"] | list[int] = "all", + levels: Union[Literal["all"], List[SCALES_TYPE]] = "all", + bands: Union[Literal["all"], List[int]] = "all", ) -> Tensor: """Reconstruct the image or batch of images, optionally using subset of pyramid coefficients. 
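A sketch of partial reconstruction with recon_pyr, whose signature appears just above (all values arbitrary):

    import torch
    from plenoptic.simulate import SteerablePyramidFreq

    img = torch.rand(1, 1, 256, 256)
    pyr = SteerablePyramidFreq(img.shape[-2:])
    coeffs = pyr(img)
    recon_full = pyr.recon_pyr(coeffs)  # levels="all", bands="all" by default
    # rebuild from the two finest scales and the first orientation only
    recon_part = pyr.recon_pyr(coeffs, levels=[0, 1], bands=[0])
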
@@ -830,9 +788,7 @@ def recon_pyr( # generate highpass residual Reconstruction if "residual_highpass" in recon_keys: hidft = fft.fft2( - pyr_coeffs["residual_highpass"], - dim=(-2, -1), - norm=self.fft_norm, + pyr_coeffs["residual_highpass"], dim=(-2, -1), norm=self.fft_norm ) hidft = fft.fftshift(hidft) @@ -845,9 +801,7 @@ def recon_pyr( # get output reconstruction by inverting the fft reconstruction = fft.ifftshift(outdft) - reconstruction = fft.ifft2( - reconstruction, dim=(-2, -1), norm=self.fft_norm - ) + reconstruction = fft.ifft2(reconstruction, dim=(-2, -1), norm=self.fft_norm) # get real part of reconstruction (if complex) reconstruction = reconstruction.real @@ -855,7 +809,7 @@ def recon_pyr( return reconstruction def _recon_levels( - self, pyr_coeffs: OrderedDict, recon_keys: list[KEYS_TYPE], scale: int + self, pyr_coeffs: OrderedDict, recon_keys: List[KEYS_TYPE], scale: int ) -> Tensor: """Recursive function used to build the reconstruction. Called by recon_pyr @@ -884,14 +838,14 @@ def _recon_levels( if scale == self.num_scales: if "residual_lowpass" in recon_keys: lodft = fft.fft2( - pyr_coeffs["residual_lowpass"], - dim=(-2, -1), - norm=self.fft_norm, + pyr_coeffs["residual_lowpass"], dim=(-2, -1), norm=self.fft_norm ) lodft = fft.fftshift(lodft) else: lodft = fft.fft2( - torch.zeros_like(pyr_coeffs["residual_lowpass"]), + torch.zeros_like( + pyr_coeffs["residual_lowpass"] + ), dim=(-2, -1), norm=self.fft_norm, ) @@ -900,14 +854,12 @@ def _recon_levels( # Reconstruct from orientation bands # update himask - himask = getattr(self, f"_himasks_scale_{scale}") + himask = getattr(self, f'_himasks_scale_{scale}') orientdft = torch.zeros_like(pyr_coeffs[(scale, 0)]) for b in range(self.num_orientations): if (scale, b) in recon_keys: - anglemask = getattr(self, f"_anglemasks_recon_scale_{scale}")[ - b - ] + anglemask = getattr(self, f'_anglemasks_recon_scale_{scale}')[b] coeffs = pyr_coeffs[(scale, b)] if self.tight_frame and self.is_complex: coeffs = coeffs * np.sqrt(2) @@ -923,7 +875,7 @@ def _recon_levels( lostart, loend = self._loindices[scale] # create lowpass mask - lomask = getattr(self, f"_lomasks_scale_{scale}") + lomask = getattr(self, f'_lomasks_scale_{scale}') # Recursively reconstruct by going to the next scale reslevdft = self._recon_levels(pyr_coeffs, recon_keys, scale + 1) @@ -931,24 +883,17 @@ def _recon_levels( if (not self.tight_frame) and (not self.downsample): reslevdft = reslevdft / 2 # create output for reconstruction result - resdft = torch.zeros_like( - pyr_coeffs[(scale, 0)], dtype=torch.complex64 - ) + resdft = torch.zeros_like(pyr_coeffs[(scale, 0)], dtype=torch.complex64) # place upsample and convolve lowpass component - resdft[:, :, lostart[0] : loend[0], lostart[1] : loend[1]] = ( - reslevdft * lomask - ) + resdft[:, :, lostart[0] : loend[0], lostart[1] : loend[1]] = reslevdft * lomask recondft = resdft + orientdft # add orientation interpolated and added images to the lowpass image return recondft def steer_coeffs( - self, - pyr_coeffs: OrderedDict, - angles: list[float], - even_phase: bool = True, - ) -> tuple[dict, dict]: + self, pyr_coeffs: OrderedDict, angles: List[float], even_phase: bool = True + ) -> Tuple[dict, dict]: """Steer pyramid coefficients to the specified angles This allows you to have filters that have the Gaussian derivative order specified in diff --git a/src/plenoptic/simulate/models/frontend.py b/src/plenoptic/simulate/models/frontend.py index 802de615..7d1050dc 100644 --- a/src/plenoptic/simulate/models/frontend.py +++ 
b/src/plenoptic/simulate/models/frontend.py @@ -10,25 +10,22 @@ .. [2] http://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html """ -from collections import OrderedDict -from collections.abc import Callable -from warnings import warn +from typing import Tuple, Union, Callable import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from .naive import Gaussian, CenterSurround from ...tools.display import imshow from ...tools.signal import make_disk -from .naive import CenterSurround, Gaussian +from collections import OrderedDict +from warnings import warn + -__all__ = [ - "LinearNonlinear", - "LuminanceGainControl", - "LuminanceContrastGainControl", - "OnOff", -] +__all__ = ["LinearNonlinear", "LuminanceGainControl", + "LuminanceContrastGainControl", "OnOff"] class LinearNonlinear(nn.Module): @@ -69,11 +66,12 @@ class LinearNonlinear(nn.Module): def __init__( self, - kernel_size: int | tuple[int, int], + kernel_size: Union[int, Tuple[int, int]], on_center: bool = True, width_ratio_limit: float = 4.0, amplitude_ratio: float = 1.25, pad_mode: str = "reflect", + activation: Callable[[Tensor], Tensor] = F.softplus, ): super().__init__() @@ -114,7 +112,7 @@ def display_filters(self, zoom=5.0, **kwargs): class LuminanceGainControl(nn.Module): - """Linear center-surround followed by luminance gain control and activation. + """ Linear center-surround followed by luminance gain control and activation. Model is described in [1]_ and [2]_. Parameters @@ -152,14 +150,14 @@ class LuminanceGainControl(nn.Module): representations, NeurIPS 2017; https://arxiv.org/abs/1710.02266 .. [2] http://www.cns.nyu.edu/~lcv/eigendistortions/ModelsIQA.html """ - def __init__( self, - kernel_size: int | tuple[int, int], + kernel_size: Union[int, Tuple[int, int]], on_center: bool = True, width_ratio_limit: float = 4.0, amplitude_ratio: float = 1.25, pad_mode: str = "reflect", + activation: Callable[[Tensor], Tensor] = F.softplus, ): super().__init__() @@ -203,25 +201,17 @@ def display_filters(self, zoom=5.0, **kwargs): dim=0, ).detach() - title = [ - "linear filt", - "luminance filt", - ] + title = ["linear filt", "luminance filt",] fig = imshow( - weights, - title=title, - col_wrap=2, - zoom=zoom, - vrange="indep0", - **kwargs, + weights, title=title, col_wrap=2, zoom=zoom, vrange="indep0", **kwargs ) return fig class LuminanceContrastGainControl(nn.Module): - """Linear center-surround followed by luminance and contrast gain control, + """ Linear center-surround followed by luminance and contrast gain control, and activation function. Model is described in [1]_ and [2]_. 
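For orientation, a sketch instantiating two of the frontend models whose constructors appear in the hunks above (kernel size and image are arbitrary; the module path is inferred from the file path in this diff):

    import torch
    from plenoptic.simulate.models.frontend import (
        LinearNonlinear,
        LuminanceGainControl,
    )

    img = torch.rand(1, 1, 64, 64)
    ln = LinearNonlinear(kernel_size=(31, 31), on_center=True)
    lgc = LuminanceGainControl(kernel_size=(31, 31))
    print(ln(img).shape, lgc(img).shape)
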
Parameters @@ -265,11 +255,12 @@ class LuminanceContrastGainControl(nn.Module): def __init__( self, - kernel_size: int | tuple[int, int], + kernel_size: Union[int, Tuple[int, int]], on_center: bool = True, width_ratio_limit: float = 4.0, amplitude_ratio: float = 1.25, pad_mode: str = "reflect", + activation: Callable[[Tensor], Tensor] = F.softplus, ): super().__init__() @@ -294,9 +285,7 @@ def forward(self, x: Tensor) -> Tensor: lum = self.luminance(x) lum_normed = linear / (1 + self.luminance_scalar * lum) - con = ( - self.contrast(lum_normed.pow(2)).sqrt() + 1e-6 - ) # avoid div by zero + con = self.contrast(lum_normed.pow(2)).sqrt() + 1E-6 # avoid div by zero con_normed = lum_normed / (1 + self.contrast_scalar * con) y = self.activation(con_normed) return y @@ -327,12 +316,7 @@ def display_filters(self, zoom=5.0, **kwargs): title = ["linear filt", "luminance filt", "contrast filt"] fig = imshow( - weights, - title=title, - col_wrap=3, - zoom=zoom, - vrange="indep0", - **kwargs, + weights, title=title, col_wrap=3, zoom=zoom, vrange="indep0", **kwargs ) return fig @@ -385,7 +369,7 @@ class OnOff(nn.Module): def __init__( self, - kernel_size: int | tuple[int, int], + kernel_size: Union[int, Tuple[int, int]], width_ratio_limit: float = 4.0, amplitude_ratio: float = 1.25, pad_mode: str = "reflect", @@ -393,20 +377,16 @@ def __init__( activation: Callable[[Tensor], Tensor] = F.softplus, apply_mask: bool = False, cache_filt: bool = False, + ): super().__init__() if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) if pretrained: - assert kernel_size == ( - 31, - 31, - ), "pretrained model has kernel_size (31, 31)" + assert kernel_size == (31, 31), "pretrained model has kernel_size (31, 31)" if cache_filt is False: - warn( - "pretrained is True but cache_filt is False. Set cache_filt to " - "True for efficiency unless you are fine-tuning." - ) + warn("pretrained is True but cache_filt is False. 
Set cache_filt to " + "True for efficiency unless you are fine-tuning.") self.center_surround = CenterSurround( kernel_size=kernel_size, @@ -419,17 +399,17 @@ def __init__( ) self.luminance = Gaussian( - kernel_size=kernel_size, - out_channels=2, - pad_mode=pad_mode, - cache_filt=cache_filt, + kernel_size=kernel_size, + out_channels=2, + pad_mode=pad_mode, + cache_filt=cache_filt, ) self.contrast = Gaussian( - kernel_size=kernel_size, - out_channels=2, - pad_mode=pad_mode, - cache_filt=cache_filt, + kernel_size=kernel_size, + out_channels=2, + pad_mode=pad_mode, + cache_filt=cache_filt, ) # init scalar values around fitted parameters found in Berardino et al 2017 @@ -446,23 +426,15 @@ def __init__( def forward(self, x: Tensor) -> Tensor: linear = self.center_surround(x) lum = self.luminance(x) - lum_normed = linear / ( - 1 + self.luminance_scalar.view(1, 2, 1, 1) * lum - ) + lum_normed = linear / (1 + self.luminance_scalar.view(1, 2, 1, 1) * lum) - con = ( - self.contrast(lum_normed.pow(2), groups=2).sqrt() + 1e-6 - ) # avoid div by 0 - con_normed = lum_normed / ( - 1 + self.contrast_scalar.view(1, 2, 1, 1) * con - ) + con = self.contrast(lum_normed.pow(2), groups=2).sqrt() + 1E-6 # avoid div by 0 + con_normed = lum_normed / (1 + self.contrast_scalar.view(1, 2, 1, 1) * con) y = self.activation(con_normed) if self.apply_mask: im_shape = x.shape[-2:] - if ( - self._disk is None or self._disk.shape != im_shape - ): # cache new mask + if self._disk is None or self._disk.shape != im_shape: # cache new mask self._disk = make_disk(im_shape).to(x.device) if self._disk.device != x.device: self._disk = self._disk.to(x.device) @@ -471,6 +443,7 @@ def forward(self, x: Tensor) -> Tensor: return y + def display_filters(self, zoom=5.0, **kwargs): """Displays convolutional filters of model @@ -504,12 +477,7 @@ def display_filters(self, zoom=5.0, **kwargs): ] fig = imshow( - weights, - title=title, - col_wrap=2, - zoom=zoom, - vrange="indep0", - **kwargs, + weights, title=title, col_wrap=2, zoom=zoom, vrange="indep0", **kwargs ) return fig @@ -526,6 +494,7 @@ def _pretrained_state_dict() -> OrderedDict: ("center_surround.amplitude_ratio", torch.as_tensor([1.25])), ("luminance.std", torch.as_tensor([8.7366, 1.4751])), ("contrast.std", torch.as_tensor([2.7353, 1.5583])), + ] ) return state_dict diff --git a/src/plenoptic/simulate/models/naive.py b/src/plenoptic/simulate/models/naive.py index 9b8a7035..16263abe 100644 --- a/src/plenoptic/simulate/models/naive.py +++ b/src/plenoptic/simulate/models/naive.py @@ -1,5 +1,8 @@ +from typing import Union, Tuple, List import torch -from torch import Tensor, nn +from torch import nn, nn as nn, Tensor +from torch import Tensor +import numpy as np from torch.nn import functional as F from ...tools.conv import same_padding @@ -55,7 +58,7 @@ class Linear(nn.Module): def __init__( self, - kernel_size: int | tuple[int, int] = (3, 3), + kernel_size: Union[int, Tuple[int, int]] = (3, 3), pad_mode: str = "circular", default_filters: bool = True, ): @@ -70,10 +73,10 @@ def __init__( self.conv = nn.Conv2d(1, 2, kernel_size, bias=False) if default_filters: - var = torch.as_tensor(3.0) + var = torch.as_tensor(3.) 
f1 = circular_gaussian2d(kernel_size, std=torch.sqrt(var)) - f2 = circular_gaussian2d(kernel_size, std=torch.sqrt(var / 3)) + f2 = circular_gaussian2d(kernel_size, std=torch.sqrt(var/3)) f2 = f2 - f1 f2 = f2 / f2.sum() @@ -107,8 +110,8 @@ class Gaussian(nn.Module): def __init__( self, - kernel_size: int | tuple[int, int], - std: float | Tensor = 3.0, + kernel_size: Union[int, Tuple[int, int]], + std: Union[float, Tensor] = 3.0, pad_mode: str = "reflect", out_channels: int = 1, cache_filt: bool = False, @@ -126,19 +129,17 @@ def __init__( self.out_channels = out_channels self.cache_filt = cache_filt - self.register_buffer("_filt", None) + self.register_buffer('_filt', None) @property def filt(self): if self._filt is not None: # use old filter return self._filt else: # create new filter, optionally cache it - filt = circular_gaussian2d( - self.kernel_size, self.std, self.out_channels - ) + filt = circular_gaussian2d(self.kernel_size, self.std, self.out_channels) if self.cache_filt: - self.register_buffer("_filt", filt) + self.register_buffer('_filt', filt) return filt def forward(self, x: Tensor, **conv2d_kwargs) -> Tensor: @@ -195,12 +196,12 @@ class CenterSurround(nn.Module): def __init__( self, - kernel_size: int | tuple[int, int], - on_center: bool | list[bool,] = True, + kernel_size: Union[int, Tuple[int, int]], + on_center: Union[bool, List[bool, ]] = True, width_ratio_limit: float = 2.0, amplitude_ratio: float = 1.25, - center_std: float | Tensor = 1.0, - surround_std: float | Tensor = 4.0, + center_std: Union[float, Tensor] = 1.0, + surround_std: Union[float, Tensor] = 4.0, out_channels: int = 1, pad_mode: str = "reflect", cache_filt: bool = False, @@ -210,46 +211,31 @@ def __init__( # make sure each channel is on-off or off-on if isinstance(on_center, bool): on_center = [on_center] * out_channels - assert ( - len(on_center) == out_channels - ), "len(on_center) must match out_channels" + assert len(on_center) == out_channels, "len(on_center) must match out_channels" # make sure each channel has a center and surround std if isinstance(center_std, float) or center_std.shape == torch.Size([]): center_std = torch.ones(out_channels) * center_std - if isinstance(surround_std, float) or surround_std.shape == torch.Size( - [] - ): + if isinstance(surround_std, float) or surround_std.shape == torch.Size([]): surround_std = torch.ones(out_channels) * surround_std - assert ( - len(center_std) == out_channels - and len(surround_std) == out_channels - ), "stds must correspond to each out_channel" - assert ( - width_ratio_limit > 1.0 - ), "stdev of surround must be greater than center" - assert ( - amplitude_ratio >= 1.0 - ), "ratio of amplitudes must at least be 1." + assert len(center_std) == out_channels and len(surround_std) == out_channels, "stds must correspond to each out_channel" + assert width_ratio_limit > 1.0, "stdev of surround must be greater than center" + assert amplitude_ratio >= 1.0, "ratio of amplitudes must at least be 1." 
self.on_center = on_center self.kernel_size = kernel_size self.width_ratio_limit = width_ratio_limit - self.register_buffer( - "amplitude_ratio", torch.as_tensor(amplitude_ratio) - ) + self.register_buffer("amplitude_ratio", torch.as_tensor(amplitude_ratio)) self.center_std = nn.Parameter(torch.ones(out_channels) * center_std) - self.surround_std = nn.Parameter( - torch.ones(out_channels) * surround_std - ) + self.surround_std = nn.Parameter(torch.ones(out_channels) * surround_std) self.out_channels = out_channels self.pad_mode = pad_mode self.cache_filt = cache_filt - self.register_buffer("_filt", None) + self.register_buffer('_filt', None) @property def filt(self) -> Tensor: @@ -260,32 +246,24 @@ def filt(self) -> Tensor: on_amp = self.amplitude_ratio device = on_amp.device - filt_center = circular_gaussian2d( - self.kernel_size, self.center_std, self.out_channels - ) - filt_surround = circular_gaussian2d( - self.kernel_size, self.surround_std, self.out_channels - ) + filt_center = circular_gaussian2d(self.kernel_size, self.center_std, self.out_channels) + filt_surround = circular_gaussian2d(self.kernel_size, self.surround_std, self.out_channels) # sign is + or - depending on center is on or off - sign = torch.as_tensor( - [1.0 if x else -1.0 for x in self.on_center] - ).to(device) + sign = torch.as_tensor([1. if x else -1. for x in self.on_center]).to(device) sign = sign.view(self.out_channels, 1, 1, 1) filt = on_amp * (sign * (filt_center - filt_surround)) if self.cache_filt: - self.register_buffer("_filt", filt) + self.register_buffer('_filt', filt) return filt def _clamp_surround_std(self): """Clamps surround standard deviation to ratio_limit times center_std""" lower_bound = self.width_ratio_limit * self.center_std for i, lb in enumerate(lower_bound): - self.surround_std[i].data = self.surround_std[i].data.clamp( - min=float(lb) - ) + self.surround_std[i].data = self.surround_std[i].data.clamp(min=float(lb)) def forward(self, x: Tensor) -> Tensor: x = same_padding(x, self.kernel_size, pad_mode=self.pad_mode) diff --git a/src/plenoptic/simulate/models/portilla_simoncelli.py b/src/plenoptic/simulate/models/portilla_simoncelli.py index edc7d3d0..81545620 100644 --- a/src/plenoptic/simulate/models/portilla_simoncelli.py +++ b/src/plenoptic/simulate/models/portilla_simoncelli.py @@ -7,7 +7,7 @@ consider them as members of the same family of textures. 
""" from collections import OrderedDict -from typing import Literal, Union +from typing import List, Optional, Tuple, Union import einops import matplotlib as mpl @@ -17,17 +17,16 @@ import torch.fft import torch.nn as nn from torch import Tensor +from typing_extensions import Literal from ...tools import signal, stats from ...tools.data import to_numpy from ...tools.display import clean_stem_plot, clean_up_axes, update_stem from ...tools.validate import validate_input +from ..canonical_computations.steerable_pyramid_freq import SteerablePyramidFreq from ..canonical_computations.steerable_pyramid_freq import ( SCALES_TYPE as PYR_SCALES_TYPE, ) -from ..canonical_computations.steerable_pyramid_freq import ( - SteerablePyramidFreq, -) SCALES_TYPE = Union[Literal["pixel_statistics"], PYR_SCALES_TYPE] @@ -81,7 +80,7 @@ class PortillaSimoncelli(nn.Module): def __init__( self, - image_shape: tuple[int, int], + image_shape: Tuple[int, int], n_scales: int = 4, n_orientations: int = 4, spatial_corr_width: int = 9, @@ -147,6 +146,8 @@ def __init__( ] def _create_scales_shape_dict(self) -> OrderedDict: + + """Create dictionary defining scales and shape of each stat. This dictionary functions as metadata which is used for two main @@ -220,11 +221,7 @@ def _create_scales_shape_dict(self) -> OrderedDict: shape_dict["kurtosis_reconstructed"] = scales_with_lowpass auto_corr = np.ones( - ( - self.spatial_corr_width, - self.spatial_corr_width, - self.n_scales + 1, - ), + (self.spatial_corr_width, self.spatial_corr_width, self.n_scales + 1), dtype=object, ) auto_corr *= einops.rearrange(scales_with_lowpass, "s -> 1 1 s") @@ -233,8 +230,7 @@ def _create_scales_shape_dict(self) -> OrderedDict: shape_dict["std_reconstructed"] = scales_with_lowpass cross_orientation_corr_mag = np.ones( - (self.n_orientations, self.n_orientations, self.n_scales), - dtype=int, + (self.n_orientations, self.n_orientations, self.n_scales), dtype=int ) cross_orientation_corr_mag *= einops.rearrange(scales, "s -> 1 1 s") shape_dict[ @@ -246,21 +242,15 @@ def _create_scales_shape_dict(self) -> OrderedDict: shape_dict["magnitude_std"] = mags_std cross_scale_corr_mag = np.ones( - (self.n_orientations, self.n_orientations, self.n_scales - 1), - dtype=int, - ) - cross_scale_corr_mag *= einops.rearrange( - scales_without_coarsest, "s -> 1 1 s" + (self.n_orientations, self.n_orientations, self.n_scales - 1), dtype=int ) + cross_scale_corr_mag *= einops.rearrange(scales_without_coarsest, "s -> 1 1 s") shape_dict["cross_scale_correlation_magnitude"] = cross_scale_corr_mag cross_scale_corr_real = np.ones( - (self.n_orientations, 2 * self.n_orientations, self.n_scales - 1), - dtype=int, - ) - cross_scale_corr_real *= einops.rearrange( - scales_without_coarsest, "s -> 1 1 s" + (self.n_orientations, 2 * self.n_orientations, self.n_scales - 1), dtype=int ) + cross_scale_corr_real *= einops.rearrange(scales_without_coarsest, "s -> 1 1 s") shape_dict["cross_scale_correlation_real"] = cross_scale_corr_real shape_dict["var_highpass_residual"] = np.array(["residual_highpass"]) @@ -297,9 +287,7 @@ def _create_necessary_stats_dict( mask_dict = scales_shape_dict.copy() # Pre-compute some necessary indices. # Lower triangular indices (including diagonal), for auto correlations - tril_inds = torch.tril_indices( - self.spatial_corr_width, self.spatial_corr_width - ) + tril_inds = torch.tril_indices(self.spatial_corr_width, self.spatial_corr_width) # Get the second half of the diagonal, i.e., everything from the center # element on. 
These are all repeated for the auto correlations. (As # these are autocorrelations (rather than auto-covariance) matrices, @@ -312,14 +300,9 @@ def _create_necessary_stats_dict( # for cross_orientation_correlation_magnitude (because we've normalized # this matrix to be true cross-correlations, the diagonals are all 1, # like for the auto-correlations) - triu_inds = torch.triu_indices( - self.n_orientations, self.n_orientations - ) + triu_inds = torch.triu_indices(self.n_orientations, self.n_orientations) for k, v in mask_dict.items(): - if k in [ - "auto_correlation_magnitude", - "auto_correlation_reconstructed", - ]: + if k in ["auto_correlation_magnitude", "auto_correlation_reconstructed"]: # Symmetry M_{i,j} = M_{n-i+1, n-j+1} # Start with all False, then place True in necessary stats. mask = torch.zeros(v.shape, dtype=torch.bool) @@ -341,7 +324,7 @@ def _create_necessary_stats_dict( return mask_dict def forward( - self, image: Tensor, scales: list[SCALES_TYPE] | None = None + self, image: Tensor, scales: Optional[List[SCALES_TYPE]] = None ) -> Tensor: r"""Generate Texture Statistics representation of an image. @@ -389,17 +372,14 @@ def forward( # real_pyr_coeffs, which contain the demeaned magnitude of the pyramid # coefficients and the real part of the pyramid coefficients # respectively. - ( - mag_pyr_coeffs, - real_pyr_coeffs, - ) = self._compute_intermediate_representations(pyr_coeffs) + mag_pyr_coeffs, real_pyr_coeffs = self._compute_intermediate_representations( + pyr_coeffs + ) # Then, the reconstructed lowpass image at each scale. (this is a list # of length n_scales+1 containing tensors of shape (batch, channel, # height, width)) - reconstructed_images = self._reconstruct_lowpass_at_each_scale( - pyr_dict - ) + reconstructed_images = self._reconstruct_lowpass_at_each_scale(pyr_dict) # the reconstructed_images list goes from coarse-to-fine, but we want # each of the stats computed from it to go from fine-to-coarse, so we # reverse its direction. @@ -421,9 +401,7 @@ def forward( # tensor of shape (batch, channel, spatial_corr_width, # spatial_corr_width, n_scales+1), and var_recon is a tensor of shape # (batch, channel, n_scales+1) - autocorr_recon, var_recon = self._compute_autocorr( - reconstructed_images - ) + autocorr_recon, var_recon = self._compute_autocorr(reconstructed_images) # Compute the standard deviation, skew, and kurtosis of each # reconstructed lowpass image. std_recon, skew_recon, and # kurtosis_recon will all end up as tensors of shape (batch, channel, @@ -449,28 +427,23 @@ def forward( if self.n_scales != 1: # First, double the phase the coefficients, so we can correctly # compute correlations across scales. - ( - phase_doubled_mags, - phase_doubled_sep, - ) = self._double_phase_pyr_coeffs(pyr_coeffs) + phase_doubled_mags, phase_doubled_sep = self._double_phase_pyr_coeffs( + pyr_coeffs + ) # Compute the cross-scale correlations between the magnitude # coefficients. For each coefficient, we're correlating it with the # coefficients at the next-coarsest scale. this will be a tensor of # shape (batch, channel, n_orientations, n_orientations, # n_scales-1) cross_scale_corr_mags, _ = self._compute_cross_correlation( - mag_pyr_coeffs[:-1], - phase_doubled_mags, - tensors_are_identical=False, + mag_pyr_coeffs[:-1], phase_doubled_mags, tensors_are_identical=False ) # Compute the cross-scale correlations between the real # coefficients and the real and imaginary coefficients at the next # coarsest scale. 
this will be a tensor of shape (batch, channel, # n_orientations, 2*n_orientations, n_scales-1) cross_scale_corr_real, _ = self._compute_cross_correlation( - real_pyr_coeffs[:-1], - phase_doubled_sep, - tensors_are_identical=False, + real_pyr_coeffs[:-1], phase_doubled_sep, tensors_are_identical=False ) # Compute the variance of the highpass residual @@ -507,14 +480,12 @@ def forward( # Return the subset of stats corresponding to the specified scale. if scales is not None: - representation_tensor = self.remove_scales( - representation_tensor, scales - ) + representation_tensor = self.remove_scales(representation_tensor, scales) return representation_tensor def remove_scales( - self, representation_tensor: Tensor, scales_to_keep: list[SCALES_TYPE] + self, representation_tensor: Tensor, scales_to_keep: List[SCALES_TYPE] ) -> Tensor: """Remove statistics not associated with scales. @@ -619,9 +590,7 @@ def convert_to_dict(self, representation_tensor: Tensor) -> OrderedDict: device=representation_tensor.device, ) # v.sum() gives the number of necessary elements from this stat - this_stat_vec = representation_tensor[ - ..., n_filled : n_filled + v.sum() - ] + this_stat_vec = representation_tensor[..., n_filled : n_filled + v.sum()] # use boolean indexing to put the values from new_stat_vec in the # appropriate place new_v[..., v] = this_stat_vec @@ -631,7 +600,7 @@ def convert_to_dict(self, representation_tensor: Tensor) -> OrderedDict: def _compute_pyr_coeffs( self, image: Tensor - ) -> tuple[OrderedDict, list[Tensor], Tensor, Tensor]: + ) -> Tuple[OrderedDict, List[Tensor], Tensor, Tensor]: """Compute pyramid coefficients of image. Note that the residual lowpass has been demeaned independently for each @@ -673,9 +642,7 @@ def _compute_pyr_coeffs( # of shape (batch, channel, n_orientations, height, width) (note that # height and width halves on each scale) coeffs_list = [ - torch.stack( - [pyr_coeffs[(i, j)] for j in range(self.n_orientations)], 2 - ) + torch.stack([pyr_coeffs[(i, j)] for j in range(self.n_orientations)], 2) for i in range(self.n_scales) ] return pyr_coeffs, coeffs_list, highpass, lowpass @@ -712,14 +679,12 @@ def _compute_pixel_stats(image: Tensor) -> Tensor: # mean needed to be unflattened to be used by skew and kurtosis # correctly, but we'll want it to be flattened like this in the final # representation tensor - return einops.pack( - [mean, var, skew, kurtosis, img_min, img_max], "b c *" - )[0] + return einops.pack([mean, var, skew, kurtosis, img_min, img_max], "b c *")[0] @staticmethod def _compute_intermediate_representations( pyr_coeffs: Tensor - ) -> tuple[list[Tensor], list[Tensor]]: + ) -> Tuple[List[Tensor], List[Tensor]]: """Compute useful intermediate representations. These representations are: @@ -754,17 +719,14 @@ def _compute_intermediate_representations( mag.mean((-2, -1), keepdim=True) for mag in magnitude_pyr_coeffs ] magnitude_pyr_coeffs = [ - mag - mn - for mag, mn in zip( - magnitude_pyr_coeffs, magnitude_means, strict=False - ) + mag - mn for mag, mn in zip(magnitude_pyr_coeffs, magnitude_means) ] real_pyr_coeffs = [coeff.real for coeff in pyr_coeffs] return magnitude_pyr_coeffs, real_pyr_coeffs def _reconstruct_lowpass_at_each_scale( self, pyr_coeffs_dict: OrderedDict - ) -> list[Tensor]: + ) -> List[Tensor]: """Reconstruct the lowpass unoriented image at each scale. The autocorrelation, standard deviation, skew, and kurtosis of each of @@ -799,15 +761,12 @@ def _reconstruct_lowpass_at_each_scale( # values across scales. 
This could also be handled by making the # pyramid tight frame reconstructed_images[:-1] = [ - signal.shrink(r, 2 ** (self.n_scales - i)) - * 4 ** (self.n_scales - i) + signal.shrink(r, 2 ** (self.n_scales - i)) * 4 ** (self.n_scales - i) for i, r in enumerate(reconstructed_images[:-1]) ] return reconstructed_images - def _compute_autocorr( - self, coeffs_list: list[Tensor] - ) -> tuple[Tensor, Tensor]: + def _compute_autocorr(self, coeffs_list: List[Tensor]) -> Tuple[Tensor, Tensor]: """Compute the autocorrelation of some statistics. Parameters @@ -843,18 +802,16 @@ def _compute_autocorr( ) acs = [signal.autocorrelation(coeff) for coeff in coeffs_list] var = [signal.center_crop(ac, 1) for ac in acs] - acs = [ac / v for ac, v in zip(acs, var, strict=False)] + acs = [ac / v for ac, v in zip(acs, var)] var = einops.pack(var, "b c *")[0] acs = [signal.center_crop(ac, self.spatial_corr_width) for ac in acs] acs = torch.stack(acs, 2) - return einops.rearrange( - acs, f"b c {dims} a1 a2 -> b c a1 a2 {dims}" - ), var + return einops.rearrange(acs, f"b c {dims} a1 a2 -> b c a1 a2 {dims}"), var @staticmethod def _compute_skew_kurtosis_recon( - reconstructed_images: list[Tensor], var_recon: Tensor, img_var: Tensor - ) -> tuple[Tensor, Tensor]: + reconstructed_images: List[Tensor], var_recon: Tensor, img_var: Tensor + ) -> Tuple[Tensor, Tensor]: """Compute the skew and kurtosis of each lowpass reconstructed image. For each scale, if the ratio of its variance to the original image's @@ -902,17 +859,15 @@ def _compute_skew_kurtosis_recon( res = torch.finfo(img_var.dtype).resolution unstable_locs = var_recon / img_var.unsqueeze(-1) < res skew_recon = torch.where(unstable_locs, skew_default, skew_recon) - kurtosis_recon = torch.where( - unstable_locs, kurtosis_default, kurtosis_recon - ) + kurtosis_recon = torch.where(unstable_locs, kurtosis_default, kurtosis_recon) return skew_recon, kurtosis_recon def _compute_cross_correlation( self, - coeffs_tensor: list[Tensor], - coeffs_tensor_other: list[Tensor], + coeffs_tensor: List[Tensor], + coeffs_tensor_other: List[Tensor], tensors_are_identical: bool = False, - ) -> tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor]: """Compute cross-correlations. Parameters @@ -939,9 +894,7 @@ def _compute_cross_correlation( """ covars = [] coeffs_var = [] - for coeff, coeff_other in zip( - coeffs_tensor, coeffs_tensor_other, strict=False - ): + for coeff, coeff_other in zip(coeffs_tensor, coeffs_tensor_other): # precompute this, which we'll use for normalization numel = torch.mul(*coeff.shape[-2:]) # compute the covariance @@ -955,18 +908,14 @@ def _compute_cross_correlation( # First, compute the variances of each coeff (if coeff and # coeff_other are identical, this is equivalent to the diagonal of # the above covar matrix, but re-computing it is actually faster) - coeff_var = einops.einsum( - coeff, coeff, "b c o1 h w, b c o1 h w -> b c o1" - ) + coeff_var = einops.einsum(coeff, coeff, "b c o1 h w, b c o1 h w -> b c o1") coeff_var = coeff_var / numel coeffs_var.append(coeff_var) if tensors_are_identical: coeff_other_var = coeff_var else: coeff_other_var = einops.einsum( - coeff_other, - coeff_other, - "b c o2 h w, b c o2 h w -> b c o2", + coeff_other, coeff_other, "b c o2 h w, b c o2 h w -> b c o2" ) coeff_other_var = coeff_other_var / numel # Then compute the outer product of those variances. 
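Putting the pieces of this model together, a sketch of the forward pass and the dict conversion defined above (image and parameter values arbitrary; the import path is inferred from the file path in this diff):

    import torch
    from plenoptic.simulate.models.portilla_simoncelli import PortillaSimoncelli

    img = torch.rand(1, 1, 256, 256)
    ps = PortillaSimoncelli(img.shape[-2:], n_scales=4, n_orientations=4,
                            spatial_corr_width=9)
    stats = ps(img)  # (batch, channel, n_stats) representation tensor
    stats_dict = ps.convert_to_dict(stats)  # OrderedDict keyed by stat name
    # restrict the representation to a subset of scales
    coarse = ps(img, scales=["pixel_statistics", 3])
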
@@ -980,8 +929,8 @@ def _compute_cross_correlation( @staticmethod def _double_phase_pyr_coeffs( - pyr_coeffs: list[Tensor] - ) -> tuple[list[Tensor], list[Tensor]]: + pyr_coeffs: List[Tensor] + ) -> Tuple[List[Tensor], List[Tensor]]: """Upsample and double the phase of pyramid coefficients. Parameters @@ -1022,21 +971,19 @@ def _double_phase_pyr_coeffs( ) doubled_phase_mags.append(doubled_phase_mag) doubled_phase_sep.append( - einops.pack( - [doubled_phase.real, doubled_phase.imag], "b c * h w" - )[0] + einops.pack([doubled_phase.real, doubled_phase.imag], "b c * h w")[0] ) return doubled_phase_mags, doubled_phase_sep def plot_representation( self, data: Tensor, - ax: plt.Axes | None = None, - figsize: tuple[float, float] = (15, 15), - ylim: tuple[float, float] | Literal[False] | None = None, + ax: Optional[plt.Axes] = None, + figsize: Tuple[float, float] = (15, 15), + ylim: Optional[Union[Tuple[float, float], Literal[False]]] = None, batch_idx: int = 0, - title: str | None = None, - ) -> tuple[plt.Figure, list[plt.Axes]]: + title: Optional[str] = None, + ) -> Tuple[plt.Figure, List[plt.Axes]]: r"""Plot the representation in a human viewable format -- stem plots with data separated out by statistic type. @@ -1199,10 +1146,10 @@ def _representation_for_plotting(self, rep: OrderedDict) -> OrderedDict: def update_plot( self, - axes: list[plt.Axes], + axes: List[plt.Axes], data: Tensor, batch_idx: int = 0, - ) -> list[plt.Artist]: + ) -> List[plt.Artist]: r"""Update the information in our representation plot. This is used for creating an animation of the representation @@ -1255,7 +1202,7 @@ def update_plot( # of the first two dims rep = {k: v[0, 0] for k, v in self.convert_to_dict(data).items()} rep = self._representation_for_plotting(rep) - for ax, d in zip(axes, rep.values(), strict=False): + for ax, d in zip(axes, rep.values()): if isinstance(d, dict): vals = np.array([dd.detach() for dd in d.values()]) else: diff --git a/src/plenoptic/synthesize/__init__.py b/src/plenoptic/synthesize/__init__.py index 7eb36795..f9d7e0f3 100644 --- a/src/plenoptic/synthesize/__init__.py +++ b/src/plenoptic/synthesize/__init__.py @@ -1,5 +1,5 @@ from .eigendistortion import Eigendistortion +from .metamer import Metamer, MetamerCTF from .geodesic import Geodesic from .mad_competition import MADCompetition -from .metamer import Metamer, MetamerCTF from .simple_metamer import SimpleMetamer diff --git a/src/plenoptic/synthesize/autodiff.py b/src/plenoptic/synthesize/autodiff.py index 84c7724f..8be6e00c 100755 --- a/src/plenoptic/synthesize/autodiff.py +++ b/src/plenoptic/synthesize/autodiff.py @@ -1,7 +1,6 @@ -import warnings - import torch from torch import Tensor +import warnings def jacobian(y: Tensor, x: Tensor) -> Tensor: @@ -41,9 +40,7 @@ def jacobian(y: Tensor, x: Tensor) -> Tensor: .t() ) - if ( - y.shape[0] == 1 - ): # need to return a 2D tensor even if y dimensionality is 1 + if y.shape[0] == 1: # need to return a 2D tensor even if y dimensionality is 1 J = J.unsqueeze(0) return J.detach() diff --git a/src/plenoptic/synthesize/eigendistortion.py b/src/plenoptic/synthesize/eigendistortion.py index 2dd67037..3f4061c4 100755 --- a/src/plenoptic/synthesize/eigendistortion.py +++ b/src/plenoptic/synthesize/eigendistortion.py @@ -1,22 +1,18 @@ +from typing import Tuple, List, Callable, Union, Optional import warnings -from collections.abc import Callable -from typing import Literal +from typing_extensions import Literal import matplotlib.pyplot +from matplotlib.figure import Figure import numpy as np import 
torch -from matplotlib.figure import Figure from torch import Tensor from tqdm.auto import tqdm +from .synthesis import Synthesis +from .autodiff import jacobian, vector_jacobian_product, jacobian_vector_product from ..tools.display import imshow from ..tools.validate import validate_input, validate_model -from .autodiff import ( - jacobian, - jacobian_vector_product, - vector_jacobian_product, -) -from .synthesis import Synthesis def fisher_info_matrix_vector_product( @@ -53,7 +49,7 @@ def fisher_info_matrix_vector_product( def fisher_info_matrix_eigenvalue( - y: Tensor, x: Tensor, v: Tensor, dummy_vec: Tensor | None = None + y: Tensor, x: Tensor, v: Tensor, dummy_vec: Optional[Tensor] = None ) -> Tensor: r"""Compute the eigenvalues of the Fisher Information Matrix corresponding to eigenvectors in v :math:`\lambda= v^T F v` @@ -64,7 +60,7 @@ def fisher_info_matrix_eigenvalue( Fv = fisher_info_matrix_vector_product(y, x, v, dummy_vec) # compute eigenvalues for all vectors in v - lmbda = torch.stack([a.dot(b) for a, b in zip(v.T, Fv.T, strict=False)]) + lmbda = torch.stack([a.dot(b) for a, b in zip(v.T, Fv.T)]) return lmbda @@ -121,12 +117,8 @@ class Eigendistortion(Synthesis): def __init__(self, image: Tensor, model: torch.nn.Module): validate_input(image, no_batch=True) - validate_model( - model, - image_shape=image.shape, - image_dtype=image.dtype, - device=image.device, - ) + validate_model(model, image_shape=image.shape, + image_dtype=image.dtype, device=image.device) ( self.batch_size, @@ -151,7 +143,7 @@ def __init__(self, image: Tensor, model: torch.nn.Module): self._eigenindex = None def _init_representation(self, image): - """Set self._representation_flat, based on model and image""" + """Set self._representation_flat, based on model and image """ self._image = self._image_flat.view(*image.shape) image_representation = self.model(self.image) @@ -201,29 +193,24 @@ def synthesize( """ allowed_methods = ["power", "exact", "randomized_svd"] - assert ( - method in allowed_methods - ), f"method must be in {allowed_methods}" + assert method in allowed_methods, f"method must be in {allowed_methods}" if ( method == "exact" - and self._representation_flat.size(0) * self._image_flat.size(0) - > 1e6 + and self._representation_flat.size(0) * self._image_flat.size(0) > 1e6 ): warnings.warn( "Jacobian > 1e6 elements and may cause out-of-memory. Use method = {'power', 'randomized_svd'}." 
) if method == "exact": # compute exact Jacobian - print("Computing all eigendistortions") + print(f"Computing all eigendistortions") eig_vals, eig_vecs = self._synthesize_exact() eig_vecs = self._vector_to_image(eig_vecs.detach()) eig_vecs_ind = torch.arange(len(eig_vecs)) elif method == "randomized_svd": - print( - f"Estimating top k={k} eigendistortions using randomized SVD" - ) + print(f"Estimating top k={k} eigendistortions using randomized SVD") lmbda_new, v_new, error_approx = self._synthesize_randomized_svd( k=k, p=p, q=q ) @@ -237,6 +224,7 @@ def synthesize( ) else: # method == 'power' + assert max_iter > 0, "max_iter must be greater than zero" lmbda_max, v_max = self._synthesize_power( @@ -247,20 +235,16 @@ def synthesize( ) n = v_max.shape[0] - eig_vecs = self._vector_to_image( - torch.cat((v_max, v_min), dim=1).detach() - ) + eig_vecs = self._vector_to_image(torch.cat((v_max, v_min), dim=1).detach()) eig_vals = torch.cat([lmbda_max, lmbda_min]).squeeze() eig_vecs_ind = torch.cat((torch.arange(k), torch.arange(n - k, n))) # reshape to (n x num_chans x h x w) - self._eigendistortions = ( - torch.stack(eig_vecs, 0) if len(eig_vecs) != 0 else [] - ) + self._eigendistortions = torch.stack(eig_vecs, 0) if len(eig_vecs) != 0 else [] self._eigenvalues = torch.abs(eig_vals.detach()) self._eigenindex = eig_vecs_ind - def _synthesize_exact(self) -> tuple[Tensor, Tensor]: + def _synthesize_exact(self) -> Tuple[Tensor, Tensor]: r"""Eigendecomposition of explicitly computed Fisher Information Matrix. To be used when the input is small (e.g. less than 70x70 image on cluster or 30x30 on your own machine). This @@ -300,8 +284,8 @@ def compute_jacobian(self) -> Tensor: return J def _synthesize_power( - self, k: int, shift: Tensor | float, tol: float, max_iter: int - ) -> tuple[Tensor, Tensor]: + self, k: int, shift: Union[Tensor, float], tol: float, max_iter: int + ) -> Tuple[Tensor, Tensor]: r"""Use power method (or orthogonal iteration when k>1) to obtain largest (smallest) eigenvalue/vector pairs. Apply the algorithm to approximate the extremal eigenvalues and eigenvectors of the Fisher @@ -342,9 +326,7 @@ def _synthesize_power( v = torch.randn(len(x), k, device=x.device, dtype=x.dtype) v = v / torch.linalg.vector_norm(v, dim=0, keepdim=True, ord=2) - _dummy_vec = torch.ones_like( - y, requires_grad=True - ) # cache a dummy vec for jvp + _dummy_vec = torch.ones_like(y, requires_grad=True) # cache a dummy vec for jvp Fv = fisher_info_matrix_vector_product(y, x, v, _dummy_vec) v = Fv / torch.linalg.vector_norm(Fv, dim=0, keepdim=True, ord=2) lmbda = fisher_info_matrix_eigenvalue(y, x, v, _dummy_vec) @@ -366,15 +348,11 @@ def _synthesize_power( Fv = fisher_info_matrix_vector_product(y, x, v, _dummy_vec) Fv = Fv - shift * v # optionally shift: (F - shift*I)v - v_new, _ = torch.linalg.qr( - Fv, "reduced" - ) # (ortho)normalize vector(s) + v_new, _ = torch.linalg.qr(Fv, "reduced") # (ortho)normalize vector(s) lmbda_new = fisher_info_matrix_eigenvalue(y, x, v_new, _dummy_vec) - d_lambda = torch.linalg.vector_norm( - lmbda - lmbda_new, ord=2 - ) # stability of eigenspace + d_lambda = torch.linalg.vector_norm(lmbda - lmbda_new, ord=2) # stability of eigenspace v = v_new lmbda = lmbda_new @@ -384,7 +362,7 @@ def _synthesize_power( def _synthesize_randomized_svd( self, k: int, p: int, q: int - ) -> tuple[Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor]: r"""Synthesize eigendistortions using randomized truncated SVD. 
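Stepping back from the individual solvers, a sketch of the public entry point; the toy Blur model below is a hypothetical stand-in (parameter-free, so there are no model gradients to freeze), and keyword names follow the signatures in the hunks above:

    import torch
    from plenoptic.synthesize import Eigendistortion

    class Blur(torch.nn.Module):
        """Hypothetical stand-in model: parameter-free local averaging."""
        def forward(self, x):
            return torch.nn.functional.avg_pool2d(x, kernel_size=2)

    img = torch.rand(1, 1, 64, 64)
    ed = Eigendistortion(img, Blur())
    # power iteration: returns the top-k and bottom-k eigenpairs
    ed.synthesize(method="power", k=1, max_iter=100)
    print(ed.eigendistortions.shape)  # (2, 1, 64, 64): top and bottom
    print(ed.eigenvalues, ed.eigenindex)
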
This method approximates the column space of the Fisher Info Matrix, projects the FIM into that column space, @@ -443,13 +421,11 @@ def _synthesize_randomized_svd( y, x, torch.randn(n, 20).to(x.device), _dummy_vec ) error_approx = omega - (Q @ Q.T @ omega) - error_approx = torch.linalg.vector_norm( - error_approx, dim=0, ord=2 - ).mean() + error_approx = torch.linalg.vector_norm(error_approx, dim=0, ord=2).mean() return S[:k].clone(), V[:, :k].clone(), error_approx # truncate - def _vector_to_image(self, vecs: Tensor) -> list[Tensor]: + def _vector_to_image(self, vecs: Tensor) -> List[Tensor]: r"""Reshapes eigenvectors back into correct image dimensions. Parameters @@ -465,9 +441,7 @@ def _vector_to_image(self, vecs: Tensor) -> list[Tensor]: """ imgs = [ - vecs[:, i].reshape( - (self.n_channels, self.im_height, self.im_width) - ) + vecs[:, i].reshape((self.n_channels, self.im_height, self.im_width)) for i in range(vecs.shape[1]) ] return imgs @@ -479,9 +453,7 @@ def _indexer(self, idx: int) -> int: i = idx_range[idx] all_idx = self.eigenindex - assert ( - i in all_idx - ), "eigenindex must be the index of one of the vectors" + assert i in all_idx, "eigenindex must be the index of one of the vectors" assert ( all_idx is not None and len(all_idx) != 0 ), "No eigendistortions synthesized" @@ -534,24 +506,14 @@ def to(self, *args, **kwargs): dtype and device for all parameters and buffers in this module """ - attrs = [ - "_jacobian", - "_eigendistortions", - "_eigenvalues", - "_eigenindex", - "_model", - "_image", - "_image_flat", - "_representation_flat", - ] + attrs = ["_jacobian", "_eigendistortions", "_eigenvalues", + "_eigenindex", "_model", "_image", "_image_flat", + "_representation_flat"] super().to(*args, attrs=attrs, **kwargs) - def load( - self, - file_path: str, - map_location: str | None = None, - **pickle_load_args, - ): + def load(self, file_path: str, + map_location: Union[str, None] = None, + **pickle_load_args): r"""Load all relevant stuff from a .pt file. This should be called by an initialized ``Eigendistortion`` object -- @@ -585,15 +547,12 @@ def load( *then* load. """ - check_attributes = ["_image", "_representation_flat"] + check_attributes = ['_image', '_representation_flat'] check_loss_functions = [] - super().load( - file_path, - map_location=map_location, - check_attributes=check_attributes, - check_loss_functions=check_loss_functions, - **pickle_load_args, - ) + super().load(file_path, map_location=map_location, + check_attributes=check_attributes, + check_loss_functions=check_loss_functions, + **pickle_load_args) # make these require a grad again self._image_flat.requires_grad_() # we need _representation_flat and _image_flat to be connected in the @@ -611,22 +570,22 @@ def image(self): @property def jacobian(self): - """Is only set when :func:`synthesize` is run with ``method='exact'``. Default to ``None``.""" + """Is only set when :func:`synthesize` is run with ``method='exact'``. Default to ``None``. """ return self._jacobian @property def eigendistortions(self): - """Tensor of eigendistortions (eigenvectors of Fisher matrix), ordered by eigenvalue.""" + """Tensor of eigendistortions (eigenvectors of Fisher matrix), ordered by eigenvalue. """ return self._eigendistortions @property def eigenvalues(self): - """Tensor of eigenvalues corresponding to each eigendistortion, listed in decreasing order.""" + """Tensor of eigenvalues corresponding to each eigendistortion, listed in decreasing order. 
""" return self._eigenvalues @property def eigenindex(self): - """Index of each eigenvector/eigenvalue.""" + """Index of each eigenvector/eigenvalue. """ return self._eigenindex @@ -635,7 +594,7 @@ def display_eigendistortion( eigenindex: int = 0, alpha: float = 5.0, process_image: Callable[[Tensor], Tensor] = lambda x: x, - ax: matplotlib.pyplot.axis | None = None, + ax: Optional[matplotlib.pyplot.axis] = None, plot_complex: str = "rectangular", **kwargs, ) -> Figure: diff --git a/src/plenoptic/synthesize/geodesic.py b/src/plenoptic/synthesize/geodesic.py index 56fd81b8..9e4f6a14 100644 --- a/src/plenoptic/synthesize/geodesic.py +++ b/src/plenoptic/synthesize/geodesic.py @@ -1,24 +1,21 @@ -import warnings from collections import OrderedDict -from typing import Literal - -import matplotlib as mpl +import warnings import matplotlib.pyplot as plt +import matplotlib as mpl import torch import torch.autograd as autograd from torch import Tensor from tqdm.auto import tqdm +from typing import Union, Tuple, Optional +from typing_extensions import Literal -from ..tools.convergence import pixel_change_convergence +from .synthesis import OptimizedSynthesis from ..tools.data import to_numpy from ..tools.optim import penalize_range -from ..tools.straightness import ( - deviation_from_line, - make_straight_line, - sample_brownian_bridge, -) from ..tools.validate import validate_input, validate_model -from .synthesis import OptimizedSynthesis +from ..tools.convergence import pixel_change_convergence +from ..tools.straightness import (deviation_from_line, make_straight_line, + sample_brownian_bridge) class Geodesic(OptimizedSynthesis): @@ -99,26 +96,16 @@ class Geodesic(OptimizedSynthesis): http://www.cns.nyu.edu/~lcv/pubs/makeAbs.php?loc=Henaff16b """ - - def __init__( - self, - image_a: Tensor, - image_b: Tensor, - model: torch.nn.Module, - n_steps: int = 10, - initial_sequence: Literal["straight", "bridge"] = "straight", - range_penalty_lambda: float = 0.1, - allowed_range: tuple[float, float] = (0, 1), - ): + def __init__(self, image_a: Tensor, image_b: Tensor, + model: torch.nn.Module, n_steps: int = 10, + initial_sequence: Literal['straight', 'bridge'] = 'straight', + range_penalty_lambda: float = .1, + allowed_range: Tuple[float, float] = (0, 1)): super().__init__(range_penalty_lambda, allowed_range) validate_input(image_a, no_batch=True, allowed_range=allowed_range) validate_input(image_b, no_batch=True, allowed_range=allowed_range) - validate_model( - model, - image_shape=image_a.shape, - image_dtype=image_a.dtype, - device=image_a.device, - ) + validate_model(model, image_shape=image_a.shape, image_dtype=image_a.dtype, + device=image_a.device) self.n_steps = n_steps self._model = model @@ -139,27 +126,22 @@ def _initialize(self, initial_sequence, start, stop, n_steps): (``'straight'``), or with a brownian bridge between the two anchors (``'bridge'``). 
""" - if initial_sequence == "bridge": + if initial_sequence == 'bridge': geodesic = sample_brownian_bridge(start, stop, n_steps) - elif initial_sequence == "straight": + elif initial_sequence == 'straight': geodesic = make_straight_line(start, stop, n_steps) else: - raise ValueError( - f"Don't know how to handle initial_sequence={initial_sequence}" - ) - _, geodesic, _ = torch.split(geodesic, [1, n_steps - 1, 1]) + raise ValueError(f"Don't know how to handle initial_sequence={initial_sequence}") + _, geodesic, _ = torch.split(geodesic, [1, n_steps-1, 1]) self._initial_sequence = initial_sequence geodesic.requires_grad_() self._geodesic = geodesic - def synthesize( - self, - max_iter: int = 1000, - optimizer: torch.optim.Optimizer | None = None, - store_progress: bool | int = False, - stop_criterion: float | None = None, - stop_iters_to_check: int = 50, - ): + def synthesize(self, max_iter: int = 1000, + optimizer: Optional[torch.optim.Optimizer] = None, + store_progress: Union[bool, int] = False, + stop_criterion: Optional[float] = None, + stop_iters_to_check: int = 50): """Synthesize a geodesic via optimization. Parameters @@ -191,17 +173,10 @@ def synthesize( """ if stop_criterion is None: # semi arbitrary default choice of tolerance - stop_criterion = ( - torch.linalg.vector_norm(self.pixelfade, ord=2) - / 1e4 - * (1 + 5**0.5) - / 2 - ) - print( - f"\n Stop criterion for pixel_change_norm = {stop_criterion:.5e}" - ) - - self._initialize_optimizer(optimizer, "_geodesic", 0.001) + stop_criterion = torch.linalg.vector_norm(self.pixelfade, ord=2) / 1e4 * (1 + 5 ** .5) / 2 + print(f"\n Stop criterion for pixel_change_norm = {stop_criterion:.5e}") + + self._initialize_optimizer(optimizer, '_geodesic', .001) # get ready to store progress self.store_progress = store_progress @@ -216,14 +191,12 @@ def synthesize( raise ValueError("Found a NaN in loss during optimization.") if self._check_convergence(stop_criterion, stop_iters_to_check): - warnings.warn( - "Pixel change norm has converged, stopping synthesis" - ) + warnings.warn("Pixel change norm has converged, stopping synthesis") break pbar.close() - def objective_function(self, geodesic: Tensor | None = None) -> Tensor: + def objective_function(self, geodesic: Optional[Tensor] = None) -> Tensor: """Compute geodesic synthesis loss. This is the path energy (i.e., squared L2 norm of each step) of the @@ -251,19 +224,16 @@ def objective_function(self, geodesic: Tensor | None = None) -> Tensor: if geodesic is None: geodesic = self.geodesic self._geodesic_representation = self.model(geodesic) - self._most_recent_step_energy = self._calculate_step_energy( - self._geodesic_representation - ) + self._most_recent_step_energy = self._calculate_step_energy(self._geodesic_representation) loss = self._most_recent_step_energy.mean() range_penalty = penalize_range(self.geodesic, self.allowed_range) return loss + self.range_penalty_lambda * range_penalty def _calculate_step_energy(self, z): - """calculate the energy (i.e. squared l2 norm) of each step in `z`.""" + """calculate the energy (i.e. squared l2 norm) of each step in `z`. 
+ """ velocity = torch.diff(z, dim=0) - step_energy = ( - torch.linalg.vector_norm(velocity, ord=2, dim=[1, 2, 3]) ** 2 - ) + step_energy = torch.linalg.vector_norm(velocity, ord=2, dim=[1, 2, 3]) ** 2 return step_energy def _optimizer_step(self, pbar): @@ -284,30 +254,21 @@ def _optimizer_step(self, pbar): loss = self.optimizer.step(self._closure) self._losses.append(loss.item()) - grad_norm = torch.linalg.vector_norm( - self._geodesic.grad.data, ord=2, dim=None - ) + grad_norm = torch.linalg.vector_norm(self._geodesic.grad.data, + ord=2, dim=None) self._gradient_norm.append(grad_norm) - pixel_change_norm = torch.linalg.vector_norm( - self._geodesic - last_iter_geodesic, ord=2, dim=None - ) + pixel_change_norm = torch.linalg.vector_norm(self._geodesic - last_iter_geodesic, + ord=2, dim=None) self._pixel_change_norm.append(pixel_change_norm) # displaying some information - pbar.set_postfix( - OrderedDict( - [ - ("loss", f"{loss.item():.4e}"), - ("gradient norm", f"{grad_norm.item():.4e}"), - ("pixel change norm", f"{pixel_change_norm.item():.5e}"), - ] - ) - ) + pbar.set_postfix(OrderedDict([('loss', f'{loss.item():.4e}'), + ('gradient norm', f'{grad_norm.item():.4e}'), + ('pixel change norm', f"{pixel_change_norm.item():.5e}")])) return loss - def _check_convergence( - self, stop_criterion: float, stop_iters_to_check: int - ) -> bool: + def _check_convergence(self, stop_criterion: float, + stop_iters_to_check: int) -> bool: """Check whether the pixel change norm has stabilized and, if so, return True. Have we been synthesizing for ``stop_iters_to_check`` iterations? @@ -336,11 +297,9 @@ def _check_convergence( Whether the pixel change norm has stabilized or not. """ - return pixel_change_convergence( - self, stop_criterion, stop_iters_to_check - ) + return pixel_change_convergence(self, stop_criterion, stop_iters_to_check) - def calculate_jerkiness(self, geodesic: Tensor | None = None) -> Tensor: + def calculate_jerkiness(self, geodesic: Optional[Tensor] = None) -> Tensor: """Compute the alignment of representation's acceleration to model local curvature. This is the first order optimality condition for a geodesic, and can be @@ -362,19 +321,15 @@ def calculate_jerkiness(self, geodesic: Tensor | None = None) -> Tensor: geodesic_representation = self.model(geodesic) velocity = torch.diff(geodesic_representation, dim=0) acceleration = torch.diff(velocity, dim=0) - acc_magnitude = torch.linalg.vector_norm( - acceleration, ord=2, dim=[1, 2, 3], keepdim=True - ) + acc_magnitude = torch.linalg.vector_norm(acceleration, ord=2, dim=[1,2,3], + keepdim=True) acc_direction = torch.div(acceleration, acc_magnitude) # we slice the output of the VJP, rather than slicing geodesic, because # slicing interferes with the gradient computation: # https://stackoverflow.com/a/54767100 - accJac = self._vector_jacobian_product( - geodesic_representation[1:-1], geodesic, acc_direction - )[1:-1] - step_jerkiness = ( - torch.linalg.vector_norm(accJac, dim=[1, 2, 3], ord=2) ** 2 - ) + accJac = self._vector_jacobian_product(geodesic_representation[1:-1], + geodesic, acc_direction)[1:-1] + step_jerkiness = torch.linalg.vector_norm(accJac, dim=[1,2,3], ord=2) ** 2 return step_jerkiness def _vector_jacobian_product(self, y, x, a): @@ -382,9 +337,9 @@ def _vector_jacobian_product(self, y, x, a): and allow for further gradient computations by retaining, and creating the graph. 
""" - accJac = autograd.grad(y, x, a, retain_graph=True, create_graph=True)[ - 0 - ] + accJac = autograd.grad(y, x, a, + retain_graph=True, + create_graph=True)[0] return accJac def _store(self, i: int) -> bool: @@ -407,29 +362,15 @@ def _store(self, i: int) -> bool: if self.store_progress and (i % self.store_progress == 0): # want these to always be on cpu, to reduce memory use for GPUs try: - self._step_energy.append( - self._most_recent_step_energy.detach().to("cpu") - ) - self._dev_from_line.append( - torch.stack( - deviation_from_line( - self._geodesic_representation.detach().to("cpu") - ) - ).T - ) + self._step_energy.append(self._most_recent_step_energy.detach().to('cpu')) + self._dev_from_line.append(torch.stack(deviation_from_line(self._geodesic_representation.detach().to('cpu'))).T) except AttributeError: # the first time _store is called (i.e., before optimizer is # stepped for first time) those attributes won't be # initialized geod_rep = self.model(self.geodesic) - self._step_energy.append( - self._calculate_step_energy(geod_rep).detach().to("cpu") - ) - self._dev_from_line.append( - torch.stack( - deviation_from_line(geod_rep.detach().to("cpu")) - ).T - ) + self._step_energy.append(self._calculate_step_energy(geod_rep).detach().to('cpu')) + self._dev_from_line.append(torch.stack(deviation_from_line(geod_rep.detach().to('cpu'))).T) stored = True else: stored = False @@ -486,23 +427,13 @@ def to(self, *args, **kwargs): dtype and device for all parameters and buffers in this module """ - attrs = [ - "_image_a", - "_image_b", - "_geodesic", - "_model", - "_step_energy", - "_dev_from_line", - "pixelfade", - ] + attrs = ['_image_a', '_image_b', '_geodesic', '_model', + '_step_energy', '_dev_from_line', 'pixelfade'] super().to(*args, attrs=attrs, **kwargs) - def load( - self, - file_path: str, - map_location: str | None = None, - **pickle_load_args, - ): + def load(self, file_path: str, + map_location: Union[str, None] = None, + **pickle_load_args): r"""Load all relevant stuff from a .pt file. This should be called by an initialized ``Geodesic`` object -- we will @@ -538,47 +469,28 @@ def load( *then* load. """ - check_attributes = [ - "_image_a", - "_image_b", - "n_steps", - "_initial_sequence", - "_range_penalty_lambda", - "_allowed_range", - "pixelfade", - ] + check_attributes = ['_image_a', '_image_b', 'n_steps', + '_initial_sequence', '_range_penalty_lambda', + '_allowed_range', 'pixelfade'] check_loss_functions = [] new_loss = self.objective_function(self.pixelfade) - super().load( - file_path, - map_location=map_location, - check_attributes=check_attributes, - check_loss_functions=check_loss_functions, - **pickle_load_args, - ) - old_loss = self.__dict__.pop("_save_check") + super().load(file_path, map_location=map_location, + check_attributes=check_attributes, + check_loss_functions=check_loss_functions, + **pickle_load_args) + old_loss = self.__dict__.pop('_save_check') if not torch.allclose(new_loss, old_loss, rtol=1e-2): - raise ValueError( - "objective_function on pixelfade of saved and initialized Geodesic object are different! Do they use the same model?" - f" Self: {new_loss}, Saved: {old_loss}" - ) + raise ValueError("objective_function on pixelfade of saved and initialized Geodesic object are different! Do they use the same model?" 
+ f" Self: {new_loss}, Saved: {old_loss}") # make this require a grad again self._geodesic.requires_grad_() # these are always supposed to be on cpu, but may get copied over to # gpu on load (which can cause problems when resuming synthesis), so # fix that. - if ( - len(self._dev_from_line) - and self._dev_from_line[0].device.type != "cpu" - ): - self._dev_from_line = [ - dev.to("cpu") for dev in self._dev_from_line - ] - if ( - len(self._step_energy) - and self._step_energy[0].device.type != "cpu" - ): - self._step_energy = [step.to("cpu") for step in self._step_energy] + if len(self._dev_from_line) and self._dev_from_line[0].device.type != 'cpu': + self._dev_from_line = [dev.to('cpu') for dev in self._dev_from_line] + if len(self._step_energy) and self._step_energy[0].device.type != 'cpu': + self._step_energy = [step.to('cpu') for step in self._step_energy] @property def model(self): @@ -623,9 +535,9 @@ def dev_from_line(self): return torch.stack(self._dev_from_line) -def plot_loss( - geodesic: Geodesic, ax: mpl.axes.Axes | None = None, **kwargs -) -> mpl.axes.Axes: +def plot_loss(geodesic: Geodesic, + ax: Union[mpl.axes.Axes, None] = None, + **kwargs) -> mpl.axes.Axes: """Plot synthesis loss. Parameters @@ -647,15 +559,14 @@ def plot_loss( if ax is None: ax = plt.gca() ax.semilogy(geodesic.losses, **kwargs) - ax.set(xlabel="Synthesis iteration", ylabel="Loss") + ax.set(xlabel='Synthesis iteration', + ylabel='Loss') return ax - -def plot_deviation_from_line( - geodesic: Geodesic, - natural_video: Tensor | None = None, - ax: mpl.axes.Axes | None = None, -) -> mpl.axes.Axes: +def plot_deviation_from_line(geodesic: Geodesic, + natural_video: Union[Tensor, None] = None, + ax: Union[mpl.axes.Axes, None] = None + ) -> mpl.axes.Axes: """Visual diagnostic of geodesic linearity in representation space. 
This plot illustrates the deviation from the straight line connecting @@ -698,24 +609,18 @@ def plot_deviation_from_line( ax = plt.gca() pixelfade_dev = deviation_from_line(geodesic.model(geodesic.pixelfade)) - ax.plot(*[to_numpy(d) for d in pixelfade_dev], "g-o", label="pixelfade") + ax.plot(*[to_numpy(d) for d in pixelfade_dev], 'g-o', label='pixelfade') - geodesic_dev = deviation_from_line( - geodesic.model(geodesic.geodesic).detach() - ) - ax.plot(*[to_numpy(d) for d in geodesic_dev], "r-o", label="geodesic") + geodesic_dev = deviation_from_line(geodesic.model(geodesic.geodesic).detach()) + ax.plot(*[to_numpy(d) for d in geodesic_dev], 'r-o', label='geodesic') if natural_video is not None: video_dev = deviation_from_line(geodesic.model(natural_video)) - ax.plot( - *[to_numpy(d) for d in video_dev], "b-o", label="natural video" - ) - - ax.set( - xlabel="Distance along representation line", - ylabel="Distance from representation line", - title="Deviation from the straight line", - ) + ax.plot(*[to_numpy(d) for d in video_dev], 'b-o', label='natural video') + + ax.set(xlabel='Distance along representation line', + ylabel='Distance from representation line', + title='Deviation from the straight line') ax.legend(loc=1) return ax diff --git a/src/plenoptic/synthesize/mad_competition.py b/src/plenoptic/synthesize/mad_competition.py index 4baf6dd0..b3e61330 100644 --- a/src/plenoptic/synthesize/mad_competition.py +++ b/src/plenoptic/synthesize/mad_competition.py @@ -1,21 +1,19 @@ """Run MAD Competition.""" +import torch +import numpy as np +from torch import Tensor +from tqdm.auto import tqdm +from ..tools import optim, display, data +from typing import Union, Tuple, Callable, List, Dict, Optional +from typing_extensions import Literal +from .synthesis import OptimizedSynthesis import warnings -from collections import OrderedDict -from collections.abc import Callable -from typing import Literal - import matplotlib as mpl import matplotlib.pyplot as plt -import numpy as np -import torch +from collections import OrderedDict from pyrtools.tools.display import make_figure as pt_make_figure -from torch import Tensor -from tqdm.auto import tqdm - -from ..tools import data, display, optim -from ..tools.convergence import loss_convergence from ..tools.validate import validate_input, validate_metric -from .synthesis import OptimizedSynthesis +from ..tools.convergence import loss_convergence class MADCompetition(OptimizedSynthesis): @@ -99,32 +97,20 @@ class MADCompetition(OptimizedSynthesis): http://dx.doi.org/10.1167/8.12.8 """ - - def __init__( - self, - image: Tensor, - optimized_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor], - reference_metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor], - minmax: Literal["min", "max"], - initial_noise: float = 0.1, - metric_tradeoff_lambda: float | None = None, - range_penalty_lambda: float = 0.1, - allowed_range: tuple[float, float] = (0, 1), - ): + def __init__(self, image: Tensor, + optimized_metric: Union[torch.nn.Module, Callable[[Tensor, Tensor], Tensor]], + reference_metric: Union[torch.nn.Module, Callable[[Tensor, Tensor], Tensor]], + minmax: Literal['min', 'max'], + initial_noise: float = .1, + metric_tradeoff_lambda: Optional[float] = None, + range_penalty_lambda: float = .1, + allowed_range: Tuple[float, float] = (0, 1)): super().__init__(range_penalty_lambda, allowed_range) validate_input(image, allowed_range=allowed_range) - validate_metric( - optimized_metric, - image_shape=image.shape, - image_dtype=image.dtype, - 
device=image.device, - ) - validate_metric( - reference_metric, - image_shape=image.shape, - image_dtype=image.dtype, - device=image.device, - ) + validate_metric(optimized_metric, image_shape=image.shape, image_dtype=image.dtype, + device=image.device) + validate_metric(reference_metric, image_shape=image.shape, image_dtype=image.dtype, + device=image.device) self._optimized_metric = optimized_metric self._reference_metric = reference_metric self._image = image.detach() @@ -132,33 +118,25 @@ def __init__( self.scheduler = None self._optimized_metric_loss = [] self._reference_metric_loss = [] - if minmax not in ["min", "max"]: - raise ValueError( - "synthesis_target must be one of {'min', 'max'}, but got " - f"value {minmax} instead!" - ) + if minmax not in ['min', 'max']: + raise ValueError("synthesis_target must be one of {'min', 'max'}, but got " + f"value {minmax} instead!") self._minmax = minmax self._initialize(initial_noise) # If no metric_tradeoff_lambda is specified, pick one that gets them to # approximately the same magnitude if metric_tradeoff_lambda is None: - loss_ratio = torch.as_tensor( - self.optimized_metric_loss[-1] - / self.reference_metric_loss[-1], - dtype=torch.float32, - ) - metric_tradeoff_lambda = torch.pow( - torch.as_tensor(10), torch.round(torch.log10(loss_ratio)) - ).item() - warnings.warn( - "Since metric_tradeoff_lamda was None, automatically set" - f" to {metric_tradeoff_lambda} to roughly balance metrics." - ) + loss_ratio = torch.as_tensor(self.optimized_metric_loss[-1] / self.reference_metric_loss[-1], + dtype=torch.float32) + metric_tradeoff_lambda = torch.pow(torch.as_tensor(10), + torch.round(torch.log10(loss_ratio))).item() + warnings.warn("Since metric_tradeoff_lamda was None, automatically set" + f" to {metric_tradeoff_lambda} to roughly balance metrics.") self._metric_tradeoff_lambda = metric_tradeoff_lambda self._store_progress = None self._saved_mad_image = [] - def _initialize(self, initial_noise: float = 0.1): + def _initialize(self, initial_noise: float = .1): """Initialize the synthesized image. Initialize ``self.mad_image`` attribute to be ``image`` plus @@ -171,28 +149,24 @@ def _initialize(self, initial_noise: float = 0.1): ``mad_image`` from ``image``. 
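# The automatic balancing above just rounds the loss ratio to the nearest
# power of ten; with toy numbers:
import torch

optimized_loss, reference_loss = 2.3e-2, 4.0e-4
loss_ratio = torch.as_tensor(optimized_loss / reference_loss, dtype=torch.float32)  # ~57.5
tradeoff = torch.pow(torch.as_tensor(10.0), torch.round(torch.log10(loss_ratio))).item()
print(tradeoff)  # 100.0 -- scales the squared reference-metric term so the
                 # two metrics enter the objective at comparable magnitude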
""" - mad_image = self.image + initial_noise * torch.randn_like(self.image) + mad_image = (self.image + initial_noise * + torch.randn_like(self.image)) mad_image = mad_image.clamp(*self.allowed_range) self._initial_image = mad_image.clone() mad_image.requires_grad_() self._mad_image = mad_image - self._reference_metric_target = self.reference_metric( - self.image, self.mad_image - ).item() + self._reference_metric_target = self.reference_metric(self.image, + self.mad_image).item() self._reference_metric_loss.append(self._reference_metric_target) - self._optimized_metric_loss.append( - self.optimized_metric(self.image, self.mad_image).item() - ) - - def synthesize( - self, - max_iter: int = 100, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler._LRScheduler | None = None, - store_progress: bool | int = False, - stop_criterion: float = 1e-4, - stop_iters_to_check: int = 50, - ): + self._optimized_metric_loss.append(self.optimized_metric(self.image, + self.mad_image).item()) + + def synthesize(self, max_iter: int = 100, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + store_progress: Union[bool, int] = False, + stop_criterion: float = 1e-4, stop_iters_to_check: int = 50 + ): r"""Synthesize a MAD image. Update the pixels of ``initial_image`` to maximize or minimize @@ -254,9 +228,9 @@ def synthesize( pbar.close() - def objective_function( - self, mad_image: Tensor | None = None, image: Tensor | None = None - ) -> Tensor: + def objective_function(self, + mad_image: Optional[Tensor] = None, + image: Optional[Tensor] = None) -> Tensor: r"""Compute the MADCompetition synthesis loss. This computes: @@ -294,18 +268,15 @@ def objective_function( image = self.image if mad_image is None: mad_image = self.mad_image - synth_target = {"min": 1, "max": -1}[self.minmax] + synth_target = {'min': 1, 'max': -1}[self.minmax] synthesis_loss = self.optimized_metric(image, mad_image) - fixed_loss = ( - self._reference_metric_target - - self.reference_metric(image, mad_image) - ).pow(2) - range_penalty = optim.penalize_range(mad_image, self.allowed_range) - return ( - synth_target * synthesis_loss - + self.metric_tradeoff_lambda * fixed_loss - + self.range_penalty_lambda * range_penalty - ) + fixed_loss = (self._reference_metric_target - + self.reference_metric(image, mad_image)).pow(2) + range_penalty = optim.penalize_range(mad_image, + self.allowed_range) + return (synth_target * synthesis_loss + + self.metric_tradeoff_lambda * fixed_loss + + self.range_penalty_lambda * range_penalty) def _optimizer_step(self, pbar: tqdm) -> Tensor: r"""Compute and propagate gradients, then step the optimizer to update mad_image. 
@@ -327,9 +298,8 @@ def _optimizer_step(self, pbar: tqdm) -> Tensor: last_iter_mad_image = self.mad_image.clone() loss = self.optimizer.step(self._closure) self._losses.append(loss.item()) - grad_norm = torch.linalg.vector_norm( - self.mad_image.grad.data, ord=2, dim=None - ) + grad_norm = torch.linalg.vector_norm(self.mad_image.grad.data, + ord=2, dim=None) self._gradient_norm.append(grad_norm.item()) fm = self.reference_metric(self.image, self.mad_image) @@ -341,22 +311,18 @@ def _optimizer_step(self, pbar: tqdm) -> Tensor: if self.scheduler is not None: self.scheduler.step(loss.item()) - pixel_change_norm = torch.linalg.vector_norm( - self.mad_image - last_iter_mad_image, ord=2, dim=None - ) + pixel_change_norm = torch.linalg.vector_norm(self.mad_image - last_iter_mad_image, + ord=2, dim=None) self._pixel_change_norm.append(pixel_change_norm.item()) # add extra info here if you want it to show up in progress bar pbar.set_postfix( - OrderedDict( - loss=f"{loss.item():.04e}", - learning_rate=self.optimizer.param_groups[0]["lr"], - gradient_norm=f"{grad_norm.item():.04e}", - pixel_change_norm=f"{pixel_change_norm.item():.04e}", - reference_metric=f"{fm.item():.04e}", - optimized_metric=f"{sm.item():.04e}", - ) - ) + OrderedDict(loss=f"{loss.item():.04e}", + learning_rate=self.optimizer.param_groups[0]['lr'], + gradient_norm=f"{grad_norm.item():.04e}", + pixel_change_norm=f"{pixel_change_norm.item():.04e}", + reference_metric=f'{fm.item():.04e}', + optimized_metric=f'{sm.item():.04e}')) return loss def _check_convergence(self, stop_criterion, stop_iters_to_check): @@ -392,7 +358,7 @@ def _check_convergence(self, stop_criterion, stop_iters_to_check): def _initialize_optimizer(self, optimizer, scheduler): """Initialize optimizer and scheduler.""" - super()._initialize_optimizer(optimizer, "mad_image") + super()._initialize_optimizer(optimizer, 'mad_image') self.scheduler = scheduler def _store(self, i: int) -> bool: @@ -413,7 +379,7 @@ def _store(self, i: int) -> bool: """ if self.store_progress and (i % self.store_progress == 0): # want these to always be on cpu, to reduce memory use for GPUs - self._saved_mad_image.append(self.mad_image.clone().to("cpu")) + self._saved_mad_image.append(self.mad_image.clone().to('cpu')) stored = True else: stored = False @@ -439,9 +405,9 @@ def save(self, file_path: str): # if the metrics are Modules, then we don't want to save them. If # they're functions then saving them is fine. if isinstance(self.optimized_metric, torch.nn.Module): - attrs.pop("_optimized_metric") + attrs.pop('_optimized_metric') if isinstance(self.reference_metric, torch.nn.Module): - attrs.pop("_reference_metric") + attrs.pop('_reference_metric') super().save(file_path, attrs=attrs) def to(self, *args, **kwargs): @@ -478,7 +444,8 @@ def to(self, *args, **kwargs): dtype and device for all parameters and buffers in this module """ - attrs = ["_initial_image", "_image", "_mad_image", "_saved_mad_image"] + attrs = ['_initial_image', '_image', '_mad_image', + '_saved_mad_image'] super().to(*args, attrs=attrs, **kwargs) # if the metrics are Modules, then we should pass them as well. If # they're functions then nothing needs to be done. @@ -491,12 +458,9 @@ def to(self, *args, **kwargs): except AttributeError: pass - def load( - self, - file_path: str, - map_location: None | None = None, - **pickle_load_args, - ): + def load(self, file_path: str, + map_location: Optional[None] = None, + **pickle_load_args): r"""Load all relevant stuff from a .pt file. 
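# The bookkeeping pattern used by ``_optimizer_step`` in miniature: step
# through a closure, then log gradient and pixel-change norms (a toy
# quadratic in place of a perceptual metric):
import torch

x = torch.rand(1, 1, 8, 8, requires_grad=True)
opt = torch.optim.Adam([x], lr=0.01)

def closure():
    opt.zero_grad()
    loss = (x - 0.5).pow(2).sum()
    loss.backward()
    return loss

last = x.detach().clone()
loss = opt.step(closure)  # step() evaluates the closure and returns its loss
grad_norm = torch.linalg.vector_norm(x.grad.detach(), ord=2)
pixel_change_norm = torch.linalg.vector_norm(x.detach() - last, ord=2)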
This should be called by an initialized ``MADCompetition`` object -- we @@ -533,33 +497,21 @@ def load( *then* load. """ - check_attributes = [ - "_image", - "_metric_tradeoff_lambda", - "_range_penalty_lambda", - "_allowed_range", - "_minmax", - ] - check_loss_functions = ["_reference_metric", "_optimized_metric"] - super().load( - file_path, - map_location=map_location, - check_attributes=check_attributes, - check_loss_functions=check_loss_functions, - **pickle_load_args, - ) + check_attributes = ['_image', '_metric_tradeoff_lambda', + '_range_penalty_lambda', '_allowed_range', + '_minmax'] + check_loss_functions = ['_reference_metric', '_optimized_metric'] + super().load(file_path, map_location=map_location, + check_attributes=check_attributes, + check_loss_functions=check_loss_functions, + **pickle_load_args) # make this require a grad again self.mad_image.requires_grad_() # these are always supposed to be on cpu, but may get copied over to # gpu on load (which can cause problems when resuming synthesis), so # fix that. - if ( - len(self._saved_mad_image) - and self._saved_mad_image[0].device.type != "cpu" - ): - self._saved_mad_image = [ - mad.to("cpu") for mad in self._saved_mad_image - ] + if len(self._saved_mad_image) and self._saved_mad_image[0].device.type != 'cpu': + self._saved_mad_image = [mad.to('cpu') for mad in self._saved_mad_image] @property def mad_image(self): @@ -602,12 +554,10 @@ def saved_mad_image(self): return torch.stack(self._saved_mad_image) -def plot_loss( - mad: MADCompetition, - iteration: int | None = None, - axes: list[mpl.axes.Axes] | mpl.axes.Axes | None = None, - **kwargs, -) -> mpl.axes.Axes: +def plot_loss(mad: MADCompetition, + iteration: Optional[int] = None, + axes: Union[List[mpl.axes.Axes], mpl.axes.Axes, None] = None, + **kwargs) -> mpl.axes.Axes: """Plot metric losses. Plots ``mad.optimized_metric_loss`` and ``mad.reference_metric_loss`` on two @@ -652,32 +602,30 @@ def plot_loss( loss_idx = iteration if axes is None: axes = plt.gca() - if not hasattr(axes, "__iter__"): - axes = display.clean_up_axes( - axes, False, ["top", "right", "bottom", "left"], ["x", "y"] - ) + if not hasattr(axes, '__iter__'): + axes = display.clean_up_axes(axes, False, + ['top', 'right', 'bottom', 'left'], + ['x', 'y']) gs = axes.get_subplotspec().subgridspec(1, 2) fig = axes.figure axes = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])] losses = [mad.reference_metric_loss, mad.optimized_metric_loss] - names = ["Reference metric loss", "Optimized metric loss"] - for ax, loss, name in zip(axes, losses, names, strict=False): + names = ['Reference metric loss', 'Optimized metric loss'] + for ax, loss, name in zip(axes, losses, names): ax.plot(loss, **kwargs) - ax.scatter(loss_idx, loss[loss_idx], c="r") - ax.set(xlabel="Synthesis iteration", ylabel=name) + ax.scatter(loss_idx, loss[loss_idx], c='r') + ax.set(xlabel='Synthesis iteration', ylabel=name) return ax -def display_mad_image( - mad: MADCompetition, - batch_idx: int = 0, - channel_idx: int | None = None, - zoom: float | None = None, - iteration: int | None = None, - ax: mpl.axes.Axes | None = None, - title: str = "MADCompetition", - **kwargs, -) -> mpl.axes.Axes: +def display_mad_image(mad: MADCompetition, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + zoom: Optional[float] = None, + iteration: Optional[int] = None, + ax: Optional[mpl.axes.Axes] = None, + title: str = 'MADCompetition', + **kwargs) -> mpl.axes.Axes: """Display MAD image. 
You can specify what iteration to view by using the ``iteration`` arg. @@ -732,30 +680,21 @@ def display_mad_image( as_rgb = False if ax is None: ax = plt.gca() - display.imshow( - image, - ax=ax, - title=title, - zoom=zoom, - batch_idx=batch_idx, - channel_idx=channel_idx, - as_rgb=as_rgb, - **kwargs, - ) + display.imshow(image, ax=ax, title=title, zoom=zoom, + batch_idx=batch_idx, channel_idx=channel_idx, + as_rgb=as_rgb, **kwargs) ax.xaxis.set_visible(False) ax.yaxis.set_visible(False) return ax -def plot_pixel_values( - mad: MADCompetition, - batch_idx: int = 0, - channel_idx: int | None = None, - iteration: int | None = None, - ylim: tuple[float] | Literal[False] = False, - ax: mpl.axes.Axes | None = None, - **kwargs, -) -> mpl.axes.Axes: +def plot_pixel_values(mad: MADCompetition, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + iteration: Optional[int] = None, + ylim: Union[Tuple[float], Literal[False]] = False, + ax: Optional[mpl.axes.Axes] = None, + **kwargs) -> mpl.axes.Axes: r"""Plot histogram of pixel values of reference and MAD images. As a way to check the distributions of pixel intensities and see @@ -787,12 +726,11 @@ def plot_pixel_values( Creates axes. """ - def _freedman_diaconis_bins(a): """Calculate number of hist bins using Freedman-Diaconis rule. copied from seaborn.""" # From https://stats.stackexchange.com/questions/798/ a = np.asarray(a) - iqr = np.diff(np.percentile(a, [0.25, 0.75]))[0] + iqr = np.diff(np.percentile(a, [.25, .75]))[0] if len(a) < 2: return 1 h = 2 * iqr / (len(a) ** (1 / 3)) @@ -802,7 +740,7 @@ def _freedman_diaconis_bins(a): else: return int(np.ceil((a.max() - a.min()) / h)) - kwargs.setdefault("alpha", 0.4) + kwargs.setdefault('alpha', .4) if iteration is None: mad_image = mad.mad_image[batch_idx] else: @@ -815,18 +753,10 @@ def _freedman_diaconis_bins(a): ax = plt.gca() image = data.to_numpy(image).flatten() mad_image = data.to_numpy(mad_image).flatten() - ax.hist( - image, - bins=min(_freedman_diaconis_bins(image), 50), - label="Reference image", - **kwargs, - ) - ax.hist( - mad_image, - bins=min(_freedman_diaconis_bins(image), 50), - label="MAD image", - **kwargs, - ) + ax.hist(image, bins=min(_freedman_diaconis_bins(image), 50), + label='Reference image', **kwargs) + ax.hist(mad_image, bins=min(_freedman_diaconis_bins(image), 50), + label='MAD image', **kwargs) ax.legend() if ylim: ax.set_ylim(ylim) @@ -834,9 +764,8 @@ def _freedman_diaconis_bins(a): return ax -def _check_included_plots( - to_check: list[str] | dict[str, int], to_check_name: str -): +def _check_included_plots(to_check: Union[List[str], Dict[str, int]], + to_check_name: str): """Check whether the user wanted us to create plots that we can't. Helper function for plot_synthesis_status and animate. @@ -853,37 +782,26 @@ def _check_included_plots( Name of the `to_check` variable, used in the error message. """ - allowed_vals = [ - "display_mad_image", - "plot_loss", - "plot_pixel_values", - "misc", - ] + allowed_vals = ['display_mad_image', 'plot_loss', 'plot_pixel_values', 'misc'] try: vals = to_check.keys() except AttributeError: vals = to_check not_allowed = [v for v in vals if v not in allowed_vals] if not_allowed: - raise ValueError( - f"{to_check_name} contained value(s) {not_allowed}! " - f"Only {allowed_vals} are permissible!" 
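# Freedman-Diaconis in isolation: bin width h = 2*IQR / n**(1/3). Note
# that ``np.percentile`` expects percentages on a 0-100 scale, so the
# interquartile range comes from the 25th and 75th percentiles; a
# self-contained version of the rule:
import numpy as np

def fd_bins(a):
    a = np.asarray(a).ravel()
    if len(a) < 2:
        return 1
    iqr = np.diff(np.percentile(a, [25, 75]))[0]
    h = 2 * iqr / len(a) ** (1 / 3)  # bin width
    if h == 0:  # degenerate data: fall back to the square-root rule
        return int(np.sqrt(len(a)))
    return int(np.ceil((a.max() - a.min()) / h))

print(fd_bins(np.random.randn(10_000)))  # roughly 60 bins for a standard normal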
- ) - - -def _setup_synthesis_fig( - fig: mpl.figure.Figure | None = None, - axes_idx: dict[str, int] = {}, - figsize: tuple[float] | None = None, - included_plots: list[str] = [ - "display_mad_image", - "plot_loss", - "plot_pixel_values", - ], - display_mad_image_width: float = 1, - plot_loss_width: float = 2, - plot_pixel_values_width: float = 1, -) -> tuple[mpl.figure.Figure, list[mpl.axes.Axes], dict[str, int]]: + raise ValueError(f'{to_check_name} contained value(s) {not_allowed}! ' + f'Only {allowed_vals} are permissible!') + + +def _setup_synthesis_fig(fig: Optional[mpl.figure.Figure] = None, + axes_idx: Dict[str, int] = {}, + figsize: Optional[Tuple[float]] = None, + included_plots: List[str] = ['display_mad_image', + 'plot_loss', + 'plot_pixel_values'], + display_mad_image_width: float = 1, + plot_loss_width: float = 2, + plot_pixel_values_width: float = 1) -> Tuple[mpl.figure.Figure, List[mpl.axes.Axes], Dict[str, int]]: """Set up figure for plot_synthesis_status. Creates figure with enough axes for the all the plots you want. Will @@ -934,75 +852,64 @@ def _setup_synthesis_fig( n_subplots = 0 axes_idx = axes_idx.copy() width_ratios = [] - if "display_mad_image" in included_plots: + if 'display_mad_image' in included_plots: n_subplots += 1 width_ratios.append(display_mad_image_width) - if "display_mad_image" not in axes_idx.keys(): - axes_idx["display_mad_image"] = data._find_min_int( - axes_idx.values() - ) - if "plot_loss" in included_plots: + if 'display_mad_image' not in axes_idx.keys(): + axes_idx['display_mad_image'] = data._find_min_int(axes_idx.values()) + if 'plot_loss' in included_plots: n_subplots += 1 width_ratios.append(plot_loss_width) - if "plot_loss" not in axes_idx.keys(): - axes_idx["plot_loss"] = data._find_min_int(axes_idx.values()) - if "plot_pixel_values" in included_plots: + if 'plot_loss' not in axes_idx.keys(): + axes_idx['plot_loss'] = data._find_min_int(axes_idx.values()) + if 'plot_pixel_values' in included_plots: n_subplots += 1 width_ratios.append(plot_pixel_values_width) - if "plot_pixel_values" not in axes_idx.keys(): - axes_idx["plot_pixel_values"] = data._find_min_int( - axes_idx.values() - ) + if 'plot_pixel_values' not in axes_idx.keys(): + axes_idx['plot_pixel_values'] = data._find_min_int(axes_idx.values()) if fig is None: width_ratios = np.array(width_ratios) if figsize is None: # we want (5, 5) for each subplot, with a bit of room between # each subplot - figsize = ((width_ratios * 5).sum() + width_ratios.sum() - 1, 5) + figsize = ((width_ratios*5).sum() + width_ratios.sum()-1, 5) width_ratios = width_ratios / width_ratios.sum() - fig, axes = plt.subplots( - 1, - n_subplots, - figsize=figsize, - gridspec_kw={"width_ratios": width_ratios}, - ) + fig, axes = plt.subplots(1, n_subplots, figsize=figsize, + gridspec_kw={'width_ratios': width_ratios}) if n_subplots == 1: axes = [axes] else: axes = fig.axes # make sure misc contains all the empty axes - misc_axes = axes_idx.get("misc", []) - if not hasattr(misc_axes, "__iter__"): + misc_axes = axes_idx.get('misc', []) + if not hasattr(misc_axes, '__iter__'): misc_axes = [misc_axes] all_axes = [] for i in axes_idx.values(): # so if it's a list of ints - if hasattr(i, "__iter__"): + if hasattr(i, '__iter__'): all_axes.extend(i) else: all_axes.append(i) misc_axes += [i for i, _ in enumerate(fig.axes) if i not in all_axes] - axes_idx["misc"] = misc_axes + axes_idx['misc'] = misc_axes return fig, axes, axes_idx -def plot_synthesis_status( - mad: MADCompetition, - batch_idx: int = 0, - channel_idx: int 
| None = None, - iteration: int | None = None, - vrange: tuple[float] | str = "indep1", - zoom: float | None = None, - fig: mpl.figure.Figure | None = None, - axes_idx: dict[str, int] = {}, - figsize: tuple[float] | None = None, - included_plots: list[str] = [ - "display_mad_image", - "plot_loss", - "plot_pixel_values", - ], - width_ratios: dict[str, float] = {}, -) -> tuple[mpl.figure.Figure, dict[str, int]]: +def plot_synthesis_status(mad: MADCompetition, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + iteration: Optional[int] = None, + vrange: Union[Tuple[float], str] = 'indep1', + zoom: Optional[float] = None, + fig: Optional[mpl.figure.Figure] = None, + axes_idx: Dict[str, int] = {}, + figsize: Optional[Tuple[float]] = None, + included_plots: List[str] = ['display_mad_image', + 'plot_loss', + 'plot_pixel_values'], + width_ratios: Dict[str, float] = {}, + ) -> Tuple[mpl.figure.Figure, Dict[str, int]]: r"""Make a plot showing synthesis status. We create several subplots to analyze this. By default, we create two @@ -1070,75 +977,62 @@ def plot_synthesis_status( """ if iteration is not None and not mad.store_progress: - raise ValueError( - "synthesis() was run with store_progress=False, " - "cannot specify which iteration to plot (only" - " last one, with iteration=None)" - ) + raise ValueError("synthesis() was run with store_progress=False, " + "cannot specify which iteration to plot (only" + " last one, with iteration=None)") if mad.mad_image.ndim not in [3, 4]: - raise ValueError( - "plot_synthesis_status() expects 3 or 4d data;" - "unexpected behavior will result otherwise!" - ) - _check_included_plots(included_plots, "included_plots") - _check_included_plots(width_ratios, "width_ratios") - _check_included_plots(axes_idx, "axes_idx") - width_ratios = {f"{k}_width": v for k, v in width_ratios.items()} - fig, axes, axes_idx = _setup_synthesis_fig( - fig, axes_idx, figsize, included_plots, **width_ratios - ) - - if "display_mad_image" in included_plots: - display_mad_image( - mad, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=iteration, - ax=axes[axes_idx["display_mad_image"]], - zoom=zoom, - vrange=vrange, - ) - if "plot_loss" in included_plots: - plot_loss(mad, iteration=iteration, axes=axes[axes_idx["plot_loss"]]) + raise ValueError("plot_synthesis_status() expects 3 or 4d data;" + "unexpected behavior will result otherwise!") + _check_included_plots(included_plots, 'included_plots') + _check_included_plots(width_ratios, 'width_ratios') + _check_included_plots(axes_idx, 'axes_idx') + width_ratios = {f'{k}_width': v for k, v in width_ratios.items()} + fig, axes, axes_idx = _setup_synthesis_fig(fig, axes_idx, figsize, + included_plots, + **width_ratios) + + if 'display_mad_image' in included_plots: + display_mad_image(mad, batch_idx=batch_idx, + channel_idx=channel_idx, + iteration=iteration, + ax=axes[axes_idx['display_mad_image']], + zoom=zoom, vrange=vrange) + if 'plot_loss' in included_plots: + plot_loss(mad, iteration=iteration, axes=axes[axes_idx['plot_loss']]) # this function creates a single axis for loss, which plot_loss then # split into two. 
this makes sure the right two axes are present in the # dict all_axes = [] for i in axes_idx.values(): # so if it's a list of ints - if hasattr(i, "__iter__"): + if hasattr(i, '__iter__'): all_axes.extend(i) else: all_axes.append(i) - new_axes = [i for i, _ in enumerate(fig.axes) if i not in all_axes] - axes_idx["plot_loss"] = new_axes - if "plot_pixel_values" in included_plots: - plot_pixel_values( - mad, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=iteration, - ax=axes[axes_idx["plot_pixel_values"]], - ) + new_axes = [i for i, _ in enumerate(fig.axes) + if i not in all_axes] + axes_idx['plot_loss'] = new_axes + if 'plot_pixel_values' in included_plots: + plot_pixel_values(mad, batch_idx=batch_idx, + channel_idx=channel_idx, + iteration=iteration, + ax=axes[axes_idx['plot_pixel_values']]) return fig, axes_idx -def animate( - mad: MADCompetition, - framerate: int = 10, - batch_idx: int = 0, - channel_idx: int | None = None, - zoom: float | None = None, - fig: mpl.figure.Figure | None = None, - axes_idx: dict[str, int] = {}, - figsize: tuple[float] | None = None, - included_plots: list[str] = [ - "display_mad_image", - "plot_loss", - "plot_pixel_values", - ], - width_ratios: dict[str, float] = {}, -) -> mpl.animation.FuncAnimation: +def animate(mad: MADCompetition, + framerate: int = 10, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + zoom: Optional[float] = None, + fig: Optional[mpl.figure.Figure] = None, + axes_idx: Dict[str, int] = {}, + figsize: Optional[Tuple[float]] = None, + included_plots: List[str] = ['display_mad_image', + 'plot_loss', + 'plot_pixel_values'], + width_ratios: Dict[str, float] = {}, + ) -> mpl.animation.FuncAnimation: r"""Animate synthesis progress. This is essentially the figure produced by @@ -1211,67 +1105,51 @@ def animate( """ if not mad.store_progress: - raise ValueError( - "synthesize() was run with store_progress=False," - " cannot animate!" - ) + raise ValueError("synthesize() was run with store_progress=False," + " cannot animate!") if mad.mad_image.ndim not in [3, 4]: - raise ValueError( - "animate() expects 3 or 4d data; unexpected" - " behavior will result otherwise!" 
- ) - _check_included_plots(included_plots, "included_plots") - _check_included_plots(width_ratios, "width_ratios") - _check_included_plots(axes_idx, "axes_idx") + raise ValueError("animate() expects 3 or 4d data; unexpected" + " behavior will result otherwise!") + _check_included_plots(included_plots, 'included_plots') + _check_included_plots(width_ratios, 'width_ratios') + _check_included_plots(axes_idx, 'axes_idx') # we run plot_synthesis_status to initialize the figure if either fig is # None or if there are no titles on any axes, which we assume means that # it's an empty figure if fig is None or not any([ax.get_title() for ax in fig.axes]): - fig, axes_idx = plot_synthesis_status( - mad=mad, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=0, - figsize=figsize, - zoom=zoom, - fig=fig, - included_plots=included_plots, - axes_idx=axes_idx, - width_ratios=width_ratios, - ) + fig, axes_idx = plot_synthesis_status(mad=mad, + batch_idx=batch_idx, + channel_idx=channel_idx, + iteration=0, figsize=figsize, + zoom=zoom, fig=fig, + included_plots=included_plots, + axes_idx=axes_idx, + width_ratios=width_ratios) # grab the artist for the second plot (we don't need to do this for the # MAD image plot, because we use the update_plot function for that) - if "plot_loss" in included_plots: - scat = [fig.axes[i].collections[0] for i in axes_idx["plot_loss"]] + if 'plot_loss' in included_plots: + scat = [fig.axes[i].collections[0] for i in axes_idx['plot_loss']] # can also have multiple plots def movie_plot(i): artists = [] - if "display_mad_image" in included_plots: - artists.extend( - display.update_plot( - fig.axes[axes_idx["display_mad_image"]], - data=mad.saved_mad_image[i], - batch_idx=batch_idx, - ) - ) - if "plot_pixel_values" in included_plots: + if 'display_mad_image' in included_plots: + artists.extend(display.update_plot(fig.axes[axes_idx['display_mad_image']], + data=mad.saved_mad_image[i], + batch_idx=batch_idx)) + if 'plot_pixel_values' in included_plots: # this is the dumbest way to do this, but it's simple -- # clearing the axes can cause problems if the user has, for # example, changed the tick locator or formatter. not sure how # to handle this best right now - fig.axes[axes_idx["plot_pixel_values"]].clear() - plot_pixel_values( - mad, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=i, - ax=fig.axes[axes_idx["plot_pixel_values"]], - ) - if "plot_loss" in included_plots: + fig.axes[axes_idx['plot_pixel_values']].clear() + plot_pixel_values(mad, batch_idx=batch_idx, + channel_idx=channel_idx, iteration=i, + ax=fig.axes[axes_idx['plot_pixel_values']]) + if 'plot_loss' in included_plots: # loss always contains values from every iteration, but everything # else will be subsampled. 
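# The animation machinery above reduces to FuncAnimation plus artists
# updated per frame; a minimal self-contained analogue of the loss-marker
# scatter:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

losses = np.exp(-np.linspace(0, 5, 100))  # stand-in loss curve
fig, ax = plt.subplots()
ax.semilogy(losses)
scat = ax.scatter(0, losses[0], c="r")  # marker tracking the current iteration

def movie_plot(i):
    scat.set_offsets((i, losses[i]))  # move the marker; the curve is static
    return [scat]  # blitting wants the changed artists back

anim = animation.FuncAnimation(fig, movie_plot, frames=len(losses),
                               blit=True, interval=1000. / 10, repeat=False)
plt.close(fig)  # as above: close so only the animation object remains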
- x_val = i * mad.store_progress + x_val = i*mad.store_progress scat[0].set_offsets((x_val, mad.reference_metric_loss[x_val])) scat[1].set_offsets((x_val, mad.optimized_metric_loss[x_val])) artists.extend(scat) @@ -1279,28 +1157,22 @@ def movie_plot(i): return artists # don't need an init_func, since we handle initialization ourselves - anim = mpl.animation.FuncAnimation( - fig, - movie_plot, - frames=len(mad.saved_mad_image), - blit=True, - interval=1000.0 / framerate, - repeat=False, - ) + anim = mpl.animation.FuncAnimation(fig, movie_plot, + frames=len(mad.saved_mad_image), + blit=True, interval=1000./framerate, + repeat=False) plt.close(fig) return anim -def display_mad_image_all( - mad_metric1_min: MADCompetition, - mad_metric2_min: MADCompetition, - mad_metric1_max: MADCompetition, - mad_metric2_max: MADCompetition, - metric1_name: str | None = None, - metric2_name: str | None = None, - zoom: int | float = 1, - **kwargs, -) -> mpl.figure.Figure: +def display_mad_image_all(mad_metric1_min: MADCompetition, + mad_metric2_min: MADCompetition, + mad_metric1_max: MADCompetition, + mad_metric2_max: MADCompetition, + metric1_name: Optional[str] = None, + metric2_name: Optional[str] = None, + zoom: Union[int, float] = 1, + **kwargs) -> mpl.figure.Figure: """Display all MAD Competition images. To generate a full set of MAD Competition images, you need four instances: @@ -1344,74 +1216,49 @@ def display_mad_image_all( # this is a bit of a hack right now, because they don't all have same # initial image if not torch.allclose(mad_metric1_min.image, mad_metric2_min.image): - raise ValueError( - "All four instances of MADCompetition must have same image!" - ) + raise ValueError("All four instances of MADCompetition must have same image!") if not torch.allclose(mad_metric1_min.image, mad_metric1_max.image): - raise ValueError( - "All four instances of MADCompetition must have same image!" - ) + raise ValueError("All four instances of MADCompetition must have same image!") if not torch.allclose(mad_metric1_min.image, mad_metric2_max.image): - raise ValueError( - "All four instances of MADCompetition must have same image!" 
- ) + raise ValueError("All four instances of MADCompetition must have same image!") if metric1_name is None: metric1_name = mad_metric1_min.optimized_metric.__name__ if metric2_name is None: metric2_name = mad_metric2_min.optimized_metric.__name__ - fig = pt_make_figure( - 3, 2, [zoom * i for i in mad_metric1_min.image.shape[-2:]] - ) + fig = pt_make_figure(3, 2, [zoom * i for i in + mad_metric1_min.image.shape[-2:]]) mads = [mad_metric1_min, mad_metric1_max, mad_metric2_min, mad_metric2_max] - titles = [ - f"Minimize {metric1_name}", - f"Maximize {metric1_name}", - f"Minimize {metric2_name}", - f"Maximize {metric2_name}", - ] + titles = [f'Minimize {metric1_name}', f'Maximize {metric1_name}', + f'Minimize {metric2_name}', f'Maximize {metric2_name}'] # we're only plotting one image here, so if the user wants multiple # channels, they must be RGB - if ( - kwargs.get("channel_idx", None) is None - and mad_metric1_min.initial_image.shape[1] > 1 - ): + if kwargs.get('channel_idx', None) is None and mad_metric1_min.initial_image.shape[1] > 1: as_rgb = True else: as_rgb = False - display.imshow( - mad_metric1_min.image, - ax=fig.axes[0], - title="Reference image", - zoom=zoom, - as_rgb=as_rgb, - **kwargs, - ) - display.imshow( - mad_metric1_min.initial_image, - ax=fig.axes[1], - title="Initial (noisy) image", - zoom=zoom, - as_rgb=as_rgb, - **kwargs, - ) - for ax, mad, title in zip(fig.axes[2:], mads, titles, strict=False): - display_mad_image(mad, zoom=zoom, ax=ax, title=title, **kwargs) + display.imshow(mad_metric1_min.image, ax=fig.axes[0], + title='Reference image', zoom=zoom, as_rgb=as_rgb, + **kwargs) + display.imshow(mad_metric1_min.initial_image, ax=fig.axes[1], + title='Initial (noisy) image', zoom=zoom, as_rgb=as_rgb, + **kwargs) + for ax, mad, title in zip(fig.axes[2:], mads, titles): + display_mad_image(mad, zoom=zoom, ax=ax, title=title, + **kwargs) return fig -def plot_loss_all( - mad_metric1_min: MADCompetition, - mad_metric2_min: MADCompetition, - mad_metric1_max: MADCompetition, - mad_metric2_max: MADCompetition, - metric1_name: str | None = None, - metric2_name: str | None = None, - metric1_kwargs: dict = {"c": "C0"}, - metric2_kwargs: dict = {"c": "C1"}, - min_kwargs: dict = {"linestyle": "--"}, - max_kwargs: dict = {"linestyle": "-"}, - figsize=(10, 5), -) -> mpl.figure.Figure: +def plot_loss_all(mad_metric1_min: MADCompetition, + mad_metric2_min: MADCompetition, + mad_metric1_max: MADCompetition, + mad_metric2_max: MADCompetition, + metric1_name: Optional[str] = None, + metric2_name: Optional[str] = None, + metric1_kwargs: Dict = {'c': 'C0'}, + metric2_kwargs: Dict = {'c': 'C1'}, + min_kwargs: Dict = {'linestyle': '--'}, + max_kwargs: Dict = {'linestyle': '-'}, + figsize=(10, 5)) -> mpl.figure.Figure: """Plot loss for full set of MAD Competiton instances. To generate a full set of MAD Competition images, you need four instances: @@ -1459,52 +1306,26 @@ def plot_loss_all( """ if not torch.allclose(mad_metric1_min.image, mad_metric2_min.image): - raise ValueError( - "All four instances of MADCompetition must have same image!" - ) + raise ValueError("All four instances of MADCompetition must have same image!") if not torch.allclose(mad_metric1_min.image, mad_metric1_max.image): - raise ValueError( - "All four instances of MADCompetition must have same image!" 
- ) + raise ValueError("All four instances of MADCompetition must have same image!") if not torch.allclose(mad_metric1_min.image, mad_metric2_max.image): - raise ValueError( - "All four instances of MADCompetition must have same image!" - ) + raise ValueError("All four instances of MADCompetition must have same image!") if metric1_name is None: metric1_name = mad_metric1_min.optimized_metric.__name__ if metric2_name is None: metric2_name = mad_metric2_min.optimized_metric.__name__ fig, axes = plt.subplots(1, 2, figsize=figsize) - plot_loss( - mad_metric1_min, - axes=axes, - label=f"Minimize {metric1_name}", - **metric1_kwargs, - **min_kwargs, - ) - plot_loss( - mad_metric1_max, - axes=axes, - label=f"Maximize {metric1_name}", - **metric1_kwargs, - **max_kwargs, - ) + plot_loss(mad_metric1_min, axes=axes, label=f'Minimize {metric1_name}', + **metric1_kwargs, **min_kwargs) + plot_loss(mad_metric1_max, axes=axes, label=f'Maximize {metric1_name}', + **metric1_kwargs, **max_kwargs) # we pass the axes backwards here because the fixed and synthesis metrics are the opposite as they are in the instances above. - plot_loss( - mad_metric2_min, - axes=axes[::-1], - label=f"Minimize {metric2_name}", - **metric2_kwargs, - **min_kwargs, - ) - plot_loss( - mad_metric2_max, - axes=axes[::-1], - label=f"Maximize {metric2_name}", - **metric2_kwargs, - **max_kwargs, - ) - axes[0].set(ylabel="Loss", title=metric2_name) - axes[1].set(ylabel="Loss", title=metric1_name) - axes[1].legend(loc="center left", bbox_to_anchor=(1.1, 0.5)) + plot_loss(mad_metric2_min, axes=axes[::-1], label=f'Minimize {metric2_name}', + **metric2_kwargs, **min_kwargs) + plot_loss(mad_metric2_max, axes=axes[::-1], label=f'Maximize {metric2_name}', + **metric2_kwargs, **max_kwargs) + axes[0].set(ylabel='Loss', title=metric2_name) + axes[1].set(ylabel='Loss', title=metric1_name) + axes[1].legend(loc='center left', bbox_to_anchor=(1.1, .5)) return fig diff --git a/src/plenoptic/synthesize/metamer.py b/src/plenoptic/synthesize/metamer.py index d2027ea7..616bdb20 100644 --- a/src/plenoptic/synthesize/metamer.py +++ b/src/plenoptic/synthesize/metamer.py @@ -1,25 +1,20 @@ """Synthesize model metamers.""" +import torch import re -import warnings -from collections import OrderedDict -from collections.abc import Callable -from typing import Literal - -import matplotlib as mpl -import matplotlib.pyplot as plt import numpy as np -import torch from torch import Tensor from tqdm.auto import tqdm -from ..tools import data, display, optim, signal +from ..tools import optim, display, signal, data +from ..tools.validate import validate_input, validate_model, validate_coarse_to_fine from ..tools.convergence import coarse_to_fine_enough, loss_convergence -from ..tools.validate import ( - validate_coarse_to_fine, - validate_input, - validate_model, -) +from typing import Union, Tuple, Callable, List, Dict, Optional +from typing_extensions import Literal from .synthesis import OptimizedSynthesis +import warnings +import matplotlib as mpl +import matplotlib.pyplot as plt +from collections import OrderedDict class Metamer(OptimizedSynthesis): @@ -87,24 +82,15 @@ class Metamer(OptimizedSynthesis): http://www.cns.nyu.edu/~lcv/texture/ """ - - def __init__( - self, - image: Tensor, - model: torch.nn.Module, - loss_function: Callable[[Tensor, Tensor], Tensor] = optim.mse, - range_penalty_lambda: float = 0.1, - allowed_range: tuple[float, float] = (0, 1), - initial_image: Tensor | None = None, - ): + def __init__(self, image: Tensor, model: torch.nn.Module, + 
loss_function: Callable[[Tensor, Tensor], Tensor] = optim.mse, + range_penalty_lambda: float = .1, + allowed_range: Tuple[float, float] = (0, 1), + initial_image: Optional[Tensor] = None): super().__init__(range_penalty_lambda, allowed_range) validate_input(image, allowed_range=allowed_range) - validate_model( - model, - image_shape=image.shape, - image_dtype=image.dtype, - device=image.device, - ) + validate_model(model, image_shape=image.shape, image_dtype=image.dtype, + device=image.device) self._model = model self._image = image self._image_shape = image.shape @@ -115,7 +101,7 @@ def __init__( self._saved_metamer = [] self._store_progress = None - def _initialize(self, initial_image: Tensor | None = None): + def _initialize(self, initial_image: Optional[Tensor] = None): """Initialize the metamer. Set the ``self.metamer`` attribute to be an attribute with the @@ -137,29 +123,22 @@ def _initialize(self, initial_image: Tensor | None = None): metamer.requires_grad_() else: if initial_image.ndimension() < 4: - raise ValueError( - "initial_image must be torch.Size([n_batch" - ", n_channels, im_height, im_width]) but got " - f"{initial_image.size()}" - ) + raise ValueError("initial_image must be torch.Size([n_batch" + ", n_channels, im_height, im_width]) but got " + f"{initial_image.size()}") if initial_image.size() != self.image.size(): raise ValueError("initial_image and image must be same size!") metamer = initial_image.clone().detach() - metamer = metamer.to( - dtype=self.image.dtype, device=self.image.device - ) + metamer = metamer.to(dtype=self.image.dtype, device=self.image.device) metamer.requires_grad_() self._metamer = metamer - def synthesize( - self, - max_iter: int = 100, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler._LRScheduler | None = None, - store_progress: bool | int = False, - stop_criterion: float = 1e-4, - stop_iters_to_check: int = 50, - ): + def synthesize(self, max_iter: int = 100, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + store_progress: Union[bool, int] = False, + stop_criterion: float = 1e-4, stop_iters_to_check: int = 50, + ): r"""Synthesize a metamer. Update the pixels of ``initial_image`` until its representation matches @@ -218,11 +197,8 @@ def synthesize( pbar.close() - def objective_function( - self, - metamer_representation: Tensor | None = None, - target_representation: Tensor | None = None, - ) -> Tensor: + def objective_function(self, metamer_representation: Optional[Tensor] = None, + target_representation: Optional[Tensor] = None) -> Tensor: """Compute the metamer synthesis loss. 
This calls self.loss_function on ``metamer_representation`` and @@ -246,10 +222,10 @@ def objective_function( metamer_representation = self.model(self.metamer) if target_representation is None: target_representation = self.target_representation - loss = self.loss_function( - metamer_representation, target_representation - ) - range_penalty = optim.penalize_range(self.metamer, self.allowed_range) + loss = self.loss_function(metamer_representation, + target_representation) + range_penalty = optim.penalize_range(self.metamer, + self.allowed_range) return loss + self.range_penalty_lambda * range_penalty def _optimizer_step(self, pbar: tqdm) -> Tensor: @@ -273,28 +249,23 @@ def _optimizer_step(self, pbar: tqdm) -> Tensor: loss = self.optimizer.step(self._closure) self._losses.append(loss.item()) - grad_norm = torch.linalg.vector_norm( - self.metamer.grad.data, ord=2, dim=None - ) + grad_norm = torch.linalg.vector_norm(self.metamer.grad.data, ord=2, + dim=None) self._gradient_norm.append(grad_norm.item()) # optionally step the scheduler if self.scheduler is not None: self.scheduler.step(loss.item()) - pixel_change_norm = torch.linalg.vector_norm( - self.metamer - last_iter_metamer, ord=2, dim=None - ) + pixel_change_norm = torch.linalg.vector_norm(self.metamer - last_iter_metamer, + ord=2, dim=None) self._pixel_change_norm.append(pixel_change_norm.item()) # add extra info here if you want it to show up in progress bar pbar.set_postfix( - OrderedDict( - loss=f"{loss.item():.04e}", - learning_rate=self.optimizer.param_groups[0]["lr"], - gradient_norm=f"{grad_norm.item():.04e}", - pixel_change_norm=f"{pixel_change_norm.item():.04e}", - ) - ) + OrderedDict(loss=f"{loss.item():.04e}", + learning_rate=self.optimizer.param_groups[0]['lr'], + gradient_norm=f"{grad_norm.item():.04e}", + pixel_change_norm=f"{pixel_change_norm.item():.04e}")) return loss def _check_convergence(self, stop_criterion, stop_iters_to_check): @@ -328,20 +299,18 @@ def _check_convergence(self, stop_criterion, stop_iters_to_check): """ return loss_convergence(self, stop_criterion, stop_iters_to_check) - def _initialize_optimizer( - self, - optimizer: torch.optim.Optimizer | None, - scheduler: torch.optim.lr_scheduler._LRScheduler | None, - ): + def _initialize_optimizer(self, + optimizer: Optional[torch.optim.Optimizer], + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]): """Initialize optimizer and scheduler.""" # this uses the OptimizedSynthesis setter - super()._initialize_optimizer(optimizer, "metamer") + super()._initialize_optimizer(optimizer, 'metamer') self.scheduler = scheduler for pg in self.optimizer.param_groups: # initialize initial_lr if it's not here. Scheduler should add it # if it's not None. - if "initial_lr" not in pg: - pg["initial_lr"] = pg["lr"] + if 'initial_lr' not in pg: + pg['initial_lr'] = pg['lr'] def _store(self, i: int) -> bool: """Store metamer, if appropriate. 
@@ -361,7 +330,7 @@ def _store(self, i: int) -> bool: """ if self.store_progress and (i % self.store_progress == 0): # want these to always be on cpu, to reduce memory use for GPUs - self._saved_metamer.append(self.metamer.clone().to("cpu")) + self._saved_metamer.append(self.metamer.clone().to('cpu')) stored = True else: stored = False @@ -417,21 +386,13 @@ def to(self, *args, **kwargs): dtype and device for all parameters and buffers in this module """ - attrs = [ - "_image", - "_target_representation", - "_metamer", - "_model", - "_saved_metamer", - ] + attrs = ['_image', '_target_representation', + '_metamer', '_model', '_saved_metamer'] super().to(*args, attrs=attrs, **kwargs) - def load( - self, - file_path: str, - map_location: str | None = None, - **pickle_load_args, - ): + def load(self, file_path: str, + map_location: Optional[str] = None, + **pickle_load_args): r"""Load all relevant stuff from a .pt file. This should be called by an initialized ``Metamer`` object -- we will @@ -468,48 +429,33 @@ def load( """ self._load(file_path, map_location, **pickle_load_args) - def _load( - self, - file_path: str, - map_location: str | None = None, - additional_check_attributes: list[str] = [], - additional_check_loss_functions: list[str] = [], - **pickle_load_args, - ): + def _load(self, file_path: str, + map_location: Optional[str] = None, + additional_check_attributes: List[str] = [], + additional_check_loss_functions: List[str] = [], + **pickle_load_args): r"""Helper function for loading. Users interact with ``load`` (without the underscore), this is to allow subclasses to specify additional attributes or loss functions to check. """ - check_attributes = [ - "_image", - "_target_representation", - "_range_penalty_lambda", - "_allowed_range", - ] + check_attributes = ['_image', '_target_representation', + '_range_penalty_lambda', '_allowed_range'] check_attributes += additional_check_attributes - check_loss_functions = ["loss_function"] + check_loss_functions = ['loss_function'] check_loss_functions += additional_check_loss_functions - super().load( - file_path, - map_location=map_location, - check_attributes=check_attributes, - check_loss_functions=check_loss_functions, - **pickle_load_args, - ) + super().load(file_path, map_location=map_location, + check_attributes=check_attributes, + check_loss_functions=check_loss_functions, + **pickle_load_args) # make this require a grad again self.metamer.requires_grad_() # these are always supposed to be on cpu, but may get copied over to # gpu on load (which can cause problems when resuming synthesis), so # fix that. - if ( - len(self._saved_metamer) - and self._saved_metamer[0].device.type != "cpu" - ): - self._saved_metamer = [ - met.to("cpu") for met in self._saved_metamer - ] + if len(self._saved_metamer) and self._saved_metamer[0].device.type != 'cpu': + self._saved_metamer = [met.to('cpu') for met in self._saved_metamer] @property def model(self): @@ -573,7 +519,7 @@ class MetamerCTF(Metamer): scale separately (ignoring the others), then with respect to all of them at the end. (see ``Metamer`` tutorial for more details). - + Attributes ---------- target_representation : torch.Tensor @@ -603,63 +549,46 @@ class MetamerCTF(Metamer): scales_finished : list or None List of scales that we've finished optimizing. 
""" - - def __init__( - self, - image: Tensor, - model: torch.nn.Module, - loss_function: Callable[[Tensor, Tensor], Tensor] = optim.mse, - range_penalty_lambda: float = 0.1, - allowed_range: tuple[float, float] = (0, 1), - initial_image: Tensor | None = None, - coarse_to_fine: Literal["together", "separate"] = "together", - ): - super().__init__( - image, - model, - loss_function, - range_penalty_lambda, - allowed_range, - initial_image, - ) + def __init__(self, image: Tensor, model: torch.nn.Module, + loss_function: Callable[[Tensor, Tensor], Tensor] = optim.mse, + range_penalty_lambda: float = .1, + allowed_range: Tuple[float, float] = (0, 1), + initial_image: Optional[Tensor] = None, + coarse_to_fine: Literal['together', 'separate'] = 'together'): + super().__init__(image, model, loss_function, range_penalty_lambda, + allowed_range, initial_image) self._init_ctf(coarse_to_fine) - def _init_ctf(self, coarse_to_fine: Literal["together", "separate"]): + def _init_ctf(self, coarse_to_fine: Literal['together', 'separate']): """Initialize stuff related to coarse-to-fine.""" # this will hold the reduced representation of the target image. - if coarse_to_fine not in ["separate", "together"]: - raise ValueError( - f"Don't know how to handle value {coarse_to_fine}!" - " Must be one of: 'separate', 'together'" - ) + if coarse_to_fine not in ['separate', 'together']: + raise ValueError(f"Don't know how to handle value {coarse_to_fine}!" + " Must be one of: 'separate', 'together'") self._ctf_target_representation = None - validate_coarse_to_fine( - self.model, image_shape=self.image.shape, device=self.image.device - ) + validate_coarse_to_fine(self.model, image_shape=self.image.shape, + device=self.image.device) # if self.scales is not None, we're continuing a previous version # and want to continue. this list comprehension creates a new # object, so we don't modify model.scales self._scales = [i for i in self.model.scales[:-1]] - if coarse_to_fine == "separate": + if coarse_to_fine == 'separate': self._scales += [self.model.scales[-1]] - self._scales += ["all"] + self._scales += ['all'] self._scales_timing = dict((k, []) for k in self.scales) self._scales_timing[self.scales[0]].append(0) self._scales_loss = [] self._scales_finished = [] self._coarse_to_fine = coarse_to_fine - def synthesize( - self, - max_iter: int = 100, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler._LRScheduler | None = None, - store_progress: bool | int = False, - stop_criterion: float = 1e-4, - stop_iters_to_check: int = 50, - change_scale_criterion: float | None = 1e-2, - ctf_iters_to_check: int = 50, - ): + def synthesize(self, max_iter: int = 100, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + store_progress: Union[bool, int] = False, + stop_criterion: float = 1e-4, stop_iters_to_check: int = 50, + change_scale_criterion: Optional[float] = 1e-2, + ctf_iters_to_check: int = 50, + ): r"""Synthesize a metamer. Update the pixels of ``initial_image`` until its representation matches @@ -704,13 +633,9 @@ def synthesize( switch scales. """ - if (change_scale_criterion is not None) and ( - stop_criterion >= change_scale_criterion - ): - raise ValueError( - "stop_criterion must be strictly less than " - "change_scale_criterion, or things get weird!" 
- ) + if (change_scale_criterion is not None) and (stop_criterion >= change_scale_criterion): + raise ValueError("stop_criterion must be strictly less than " + "change_scale_criterion, or things get weird!") # initialize the optimizer and scheduler self._initialize_optimizer(optimizer, scheduler) @@ -718,6 +643,7 @@ def synthesize( # get ready to store progress self.store_progress = store_progress + pbar = tqdm(range(max_iter)) for i in pbar: @@ -725,27 +651,22 @@ def synthesize( # iterations and will be correct across calls to `synthesize` self._store(len(self.losses)) - loss = self._optimizer_step( - pbar, change_scale_criterion, ctf_iters_to_check - ) + loss = self._optimizer_step(pbar, change_scale_criterion, ctf_iters_to_check) if not torch.isfinite(loss): raise ValueError("Found a NaN in loss during optimization.") - if self._check_convergence( - i, stop_criterion, stop_iters_to_check, ctf_iters_to_check - ): + if self._check_convergence(i, stop_criterion, stop_iters_to_check, + ctf_iters_to_check): warnings.warn("Loss has converged, stopping synthesis") break pbar.close() - def _optimizer_step( - self, - pbar: tqdm, - change_scale_criterion: float, - ctf_iters_to_check: int, - ) -> Tensor: + def _optimizer_step(self, pbar: tqdm, + change_scale_criterion: float, + ctf_iters_to_check: int + ) -> Tensor: r"""Compute and propagate gradients, then step the optimizer to update metamer. Parameters @@ -774,31 +695,19 @@ def _optimizer_step( # has stopped declining and, if so, switch to the next scale. Then # we're checking if self.scales_loss is long enough to check # ctf_iters_to_check back. - if ( - len(self.scales) > 1 - and len(self.scales_loss) >= ctf_iters_to_check - ): + if len(self.scales) > 1 and len(self.scales_loss) >= ctf_iters_to_check: # Now we check whether loss has decreased less than # change_scale_criterion - if (change_scale_criterion is None) or abs( - self.scales_loss[-1] - self.scales_loss[-ctf_iters_to_check] - ) < change_scale_criterion: + if ((change_scale_criterion is None) or abs(self.scales_loss[-1] - self.scales_loss[-ctf_iters_to_check]) < change_scale_criterion): # and finally we check whether we've been optimizing this # scale for ctf_iters_to_check - if ( - len(self.losses) - self.scales_timing[self.scales[0]][0] - >= ctf_iters_to_check - ): - self._scales_timing[self.scales[0]].append( - len(self.losses) - 1 - ) + if len(self.losses) - self.scales_timing[self.scales[0]][0] >= ctf_iters_to_check: + self._scales_timing[self.scales[0]].append(len(self.losses)-1) self._scales_finished.append(self._scales.pop(0)) - self._scales_timing[self.scales[0]].append( - len(self.losses) - ) + self._scales_timing[self.scales[0]].append(len(self.losses)) # reset optimizer's lr. 
for pg in self.optimizer.param_groups: - pg["lr"] = pg["initial_lr"] + pg['lr'] = pg['initial_lr'] # reset ctf target representation, so we update it on # next pass self._ctf_target_representation = None @@ -806,33 +715,28 @@ def _optimizer_step( self._scales_loss.append(loss.item()) self._losses.append(overall_loss.item()) - grad_norm = torch.linalg.vector_norm( - self.metamer.grad.data, ord=2, dim=None - ) + grad_norm = torch.linalg.vector_norm(self.metamer.grad.data, ord=2, + dim=None) self._gradient_norm.append(grad_norm.item()) # optionally step the scheduler if self.scheduler is not None: self.scheduler.step(loss.item()) - pixel_change_norm = torch.linalg.vector_norm( - self.metamer - last_iter_metamer, ord=2, dim=None - ) + pixel_change_norm = torch.linalg.vector_norm(self.metamer - last_iter_metamer, + ord=2, dim=None) self._pixel_change_norm.append(pixel_change_norm.item()) # add extra info here if you want it to show up in progress bar pbar.set_postfix( - OrderedDict( - loss=f"{overall_loss.item():.04e}", - learning_rate=self.optimizer.param_groups[0]["lr"], - gradient_norm=f"{grad_norm.item():.04e}", - pixel_change_norm=f"{pixel_change_norm.item():.04e}", - current_scale=self.scales[0], - current_scale_loss=f"{loss.item():.04e}", - ) - ) + OrderedDict(loss=f"{overall_loss.item():.04e}", + learning_rate=self.optimizer.param_groups[0]['lr'], + gradient_norm=f"{grad_norm.item():.04e}", + pixel_change_norm=f"{pixel_change_norm.item():.04e}", + current_scale=self.scales[0], + current_scale_loss=f'{loss.item():.04e}')) return overall_loss - def _closure(self) -> tuple[Tensor, Tensor]: + def _closure(self) -> Tuple[Tensor, Tensor]: r"""An abstraction of the gradient calculation, before the optimization step. This enables optimization algorithms that perform several evaluations @@ -859,12 +763,12 @@ def _closure(self) -> tuple[Tensor, Tensor]: self.optimizer.zero_grad() analyze_kwargs = {} # if we've reached 'all', we use the full model - if self.scales[0] != "all": - analyze_kwargs["scales"] = [self.scales[0]] + if self.scales[0] != 'all': + analyze_kwargs['scales'] = [self.scales[0]] # if 'together', then we also want all the coarser # scales - if self.coarse_to_fine == "together": - analyze_kwargs["scales"] += self.scales_finished + if self.coarse_to_fine == 'together': + analyze_kwargs['scales'] += self.scales_finished metamer_representation = self.model(self.metamer, **analyze_kwargs) # if analyze_kwargs is empty, we can just compare # metamer_representation against our cached target_representation @@ -888,13 +792,9 @@ def _closure(self) -> tuple[Tensor, Tensor]: return loss, overall_loss - def _check_convergence( - self, - i: int, - stop_criterion: float, - stop_iters_to_check: int, - ctf_iters_to_check: int, - ) -> bool: + def _check_convergence(self, i: int, stop_criterion: float, + stop_iters_to_check: int, + ctf_iters_to_check: int) -> bool: r"""Check whether the loss has stabilized and whether we've synthesized all scales. Have we been synthesizing for ``stop_iters_to_check`` iterations? @@ -937,12 +837,9 @@ def _check_convergence( loss_conv = loss_convergence(self, stop_criterion, stop_iters_to_check) return loss_conv and coarse_to_fine_enough(self, i, ctf_iters_to_check) - def load( - self, - file_path: str, - map_location: str | None = None, - **pickle_load_args, - ): + def load(self, file_path: str, + map_location: Optional[str] = None, + **pickle_load_args): r"""Load all relevant stuff from a .pt file. 
This should be called by an initialized ``Metamer`` object -- we will @@ -977,9 +874,8 @@ def load( *then* load. """ - super()._load( - file_path, map_location, ["_coarse_to_fine"], **pickle_load_args - ) + super()._load(file_path, map_location, ['_coarse_to_fine'], + **pickle_load_args) @property def coarse_to_fine(self): @@ -1002,12 +898,10 @@ def scales_finished(self): return tuple(self._scales_finished) -def plot_loss( - metamer: Metamer, - iteration: int | None = None, - ax: mpl.axes.Axes | None = None, - **kwargs, -) -> mpl.axes.Axes: +def plot_loss(metamer: Metamer, + iteration: Optional[int] = None, + ax: Optional[mpl.axes.Axes] = None, + **kwargs) -> mpl.axes.Axes: """Plot synthesis loss with log-scaled y axis. Plots ``metamer.losses`` over all iterations. Also plots a red dot at @@ -1045,23 +939,21 @@ def plot_loss( ax = plt.gca() ax.semilogy(metamer.losses, **kwargs) try: - ax.scatter(loss_idx, metamer.losses[loss_idx], c="r") + ax.scatter(loss_idx, metamer.losses[loss_idx], c='r') except IndexError: # then there's no loss here pass - ax.set(xlabel="Synthesis iteration", ylabel="Loss") + ax.set(xlabel='Synthesis iteration', ylabel='Loss') return ax -def display_metamer( - metamer: Metamer, - batch_idx: int = 0, - channel_idx: int | None = None, - zoom: float | None = None, - iteration: int | None = None, - ax: mpl.axes.Axes | None = None, - **kwargs, -) -> mpl.axes.Axes: +def display_metamer(metamer: Metamer, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + zoom: Optional[float] = None, + iteration: Optional[int] = None, + ax: Optional[mpl.axes.Axes] = None, + **kwargs) -> mpl.axes.Axes: """Display metamer. You can specify what iteration to view by using the ``iteration`` arg. @@ -1114,24 +1006,17 @@ def display_metamer( as_rgb = False if ax is None: ax = plt.gca() - display.imshow( - image, - ax=ax, - title="Metamer", - zoom=zoom, - batch_idx=batch_idx, - channel_idx=channel_idx, - as_rgb=as_rgb, - **kwargs, - ) + display.imshow(image, ax=ax, title='Metamer', zoom=zoom, + batch_idx=batch_idx, channel_idx=channel_idx, + as_rgb=as_rgb, **kwargs) ax.xaxis.set_visible(False) ax.yaxis.set_visible(False) return ax -def _representation_error( - metamer: Metamer, iteration: int | None = None, **kwargs -) -> Tensor: +def _representation_error(metamer: Metamer, + iteration: Optional[int] = None, + **kwargs) -> Tensor: r"""Get the representation error. This is ``metamer.model(metamer) - target_representation)``. If @@ -1154,25 +1039,19 @@ def _representation_error( """ if iteration is not None: - metamer_rep = metamer.model( - metamer.saved_metamer[iteration].to( - metamer.target_representation.device - ) - ) + metamer_rep = metamer.model(metamer.saved_metamer[iteration].to(metamer.target_representation.device)) else: metamer_rep = metamer.model(metamer.metamer, **kwargs) return metamer_rep - metamer.target_representation -def plot_representation_error( - metamer: Metamer, - batch_idx: int = 0, - iteration: int | None = None, - ylim: tuple[float, float] | None | Literal[False] = None, - ax: mpl.axes.Axes | None = None, - as_rgb: bool = False, - **kwargs, -) -> list[mpl.axes.Axes]: +def plot_representation_error(metamer: Metamer, + batch_idx: int = 0, + iteration: Optional[int] = None, + ylim: Union[Tuple[float, float], None, Literal[False]] = None, + ax: Optional[mpl.axes.Axes] = None, + as_rgb: bool = False, + **kwargs) -> List[mpl.axes.Axes]: r"""Plot distance ratio showing how close we are to convergence. We plot ``_representation_error(metamer, iteration)``. 
For more details, see @@ -1209,31 +1088,22 @@ def plot_representation_error( List of created axes """ - representation_error = _representation_error( - metamer=metamer, iteration=iteration, **kwargs - ) + representation_error = _representation_error(metamer=metamer, + iteration=iteration, **kwargs) if ax is None: ax = plt.gca() - return display.plot_representation( - metamer.model, - representation_error, - ax, - title="Representation error", - ylim=ylim, - batch_idx=batch_idx, - as_rgb=as_rgb, - ) - - -def plot_pixel_values( - metamer: Metamer, - batch_idx: int = 0, - channel_idx: int | None = None, - iteration: int | None = None, - ylim: tuple[float, float] | Literal[False] = False, - ax: mpl.axes.Axes | None = None, - **kwargs, -) -> mpl.axes.Axes: + return display.plot_representation(metamer.model, representation_error, ax, + title="Representation error", ylim=ylim, + batch_idx=batch_idx, as_rgb=as_rgb) + + +def plot_pixel_values(metamer: Metamer, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + iteration: Optional[int] = None, + ylim: Union[Tuple[float, float], Literal[False]] = False, + ax: Optional[mpl.axes.Axes] = None, + **kwargs) -> mpl.axes.Axes: r"""Plot histogram of pixel values of target image and its metamer. As a way to check the distributions of pixel intensities and see @@ -1265,12 +1135,11 @@ def plot_pixel_values( Created axes. """ - def _freedman_diaconis_bins(a): """Calculate number of hist bins using Freedman-Diaconis rule. copied from seaborn.""" # From https://stats.stackexchange.com/questions/798/ a = np.asarray(a) - iqr = np.diff(np.percentile(a, [0.25, 0.75]))[0] + iqr = np.diff(np.percentile(a, [.25, .75]))[0] if len(a) < 2: return 1 h = 2 * iqr / (len(a) ** (1 / 3)) @@ -1280,7 +1149,7 @@ def _freedman_diaconis_bins(a): else: return int(np.ceil((a.max() - a.min()) / h)) - kwargs.setdefault("alpha", 0.4) + kwargs.setdefault('alpha', .4) if iteration is None: met = metamer.metamer[batch_idx] else: @@ -1293,18 +1162,10 @@ def _freedman_diaconis_bins(a): ax = plt.gca() image = data.to_numpy(image).flatten() met = data.to_numpy(met).flatten() - ax.hist( - met, - bins=min(_freedman_diaconis_bins(image), 50), - label="metamer", - **kwargs, - ) - ax.hist( - image, - bins=min(_freedman_diaconis_bins(image), 50), - label="target image", - **kwargs, - ) + ax.hist(met, bins=min(_freedman_diaconis_bins(image), 50), + label='metamer', **kwargs) + ax.hist(image, bins=min(_freedman_diaconis_bins(image), 50), + label='target image', **kwargs) ax.legend() if ylim: ax.set_ylim(ylim) @@ -1312,9 +1173,8 @@ def _freedman_diaconis_bins(a): return ax -def _check_included_plots( - to_check: list[str] | dict[str, float], to_check_name: str -): +def _check_included_plots(to_check: Union[List[str], Dict[str, float]], + to_check_name: str): """Check whether the user wanted us to create plots that we can't. Helper function for plot_synthesis_status and animate. @@ -1331,39 +1191,28 @@ def _check_included_plots( Name of the `to_check` variable, used in the error message. """ - allowed_vals = [ - "display_metamer", - "plot_loss", - "plot_representation_error", - "plot_pixel_values", - "misc", - ] + allowed_vals = ['display_metamer', 'plot_loss', 'plot_representation_error', + 'plot_pixel_values', 'misc'] try: vals = to_check.keys() except AttributeError: vals = to_check not_allowed = [v for v in vals if v not in allowed_vals] if not_allowed: - raise ValueError( - f"{to_check_name} contained value(s) {not_allowed}! " - f"Only {allowed_vals} are permissible!" 
- ) - - -def _setup_synthesis_fig( - fig: mpl.figure.Figure | None = None, - axes_idx: dict[str, int] = {}, - figsize: tuple[float, float] | None = None, - included_plots: list[str] = [ - "display_metamer", - "plot_loss", - "plot_representation_error", - ], - display_metamer_width: float = 1, - plot_loss_width: float = 1, - plot_representation_error_width: float = 1, - plot_pixel_values_width: float = 1, -) -> tuple[mpl.figure.Figure, list[mpl.axes.Axes], dict[str, int]]: + raise ValueError(f'{to_check_name} contained value(s) {not_allowed}! ' + f'Only {allowed_vals} are permissible!') + + +def _setup_synthesis_fig(fig: Optional[mpl.figure.Figure] = None, + axes_idx: Dict[str, int] = {}, + figsize: Optional[Tuple[float, float]] = None, + included_plots: List[str] = ['display_metamer', + 'plot_loss', + 'plot_representation_error'], + display_metamer_width: float = 1, + plot_loss_width: float = 1, + plot_representation_error_width: float = 1, + plot_pixel_values_width: float = 1) -> Tuple[mpl.figure.Figure, List[mpl.axes.Axes], Dict[str, int]]: """Set up figure for plot_synthesis_status. Creates figure with enough axes for the all the plots you want. Will @@ -1420,79 +1269,68 @@ def _setup_synthesis_fig( if "display_metamer" in included_plots: n_subplots += 1 width_ratios.append(display_metamer_width) - if "display_metamer" not in axes_idx.keys(): - axes_idx["display_metamer"] = data._find_min_int(axes_idx.values()) + if 'display_metamer' not in axes_idx.keys(): + axes_idx['display_metamer'] = data._find_min_int(axes_idx.values()) if "plot_loss" in included_plots: n_subplots += 1 width_ratios.append(plot_loss_width) - if "plot_loss" not in axes_idx.keys(): - axes_idx["plot_loss"] = data._find_min_int(axes_idx.values()) + if 'plot_loss' not in axes_idx.keys(): + axes_idx['plot_loss'] = data._find_min_int(axes_idx.values()) if "plot_representation_error" in included_plots: n_subplots += 1 width_ratios.append(plot_representation_error_width) - if "plot_representation_error" not in axes_idx.keys(): - axes_idx["plot_representation_error"] = data._find_min_int( - axes_idx.values() - ) + if 'plot_representation_error' not in axes_idx.keys(): + axes_idx['plot_representation_error'] = data._find_min_int(axes_idx.values()) if "plot_pixel_values" in included_plots: n_subplots += 1 width_ratios.append(plot_pixel_values_width) - if "plot_pixel_values" not in axes_idx.keys(): - axes_idx["plot_pixel_values"] = data._find_min_int( - axes_idx.values() - ) + if 'plot_pixel_values' not in axes_idx.keys(): + axes_idx['plot_pixel_values'] = data._find_min_int(axes_idx.values()) if fig is None: width_ratios = np.array(width_ratios) if figsize is None: # we want (5, 5) for each subplot, with a bit of room between # each subplot - figsize = ((width_ratios * 5).sum() + width_ratios.sum() - 1, 5) + figsize = ((width_ratios*5).sum() + width_ratios.sum()-1, 5) width_ratios = width_ratios / width_ratios.sum() - fig, axes = plt.subplots( - 1, - n_subplots, - figsize=figsize, - gridspec_kw={"width_ratios": width_ratios}, - ) + fig, axes = plt.subplots(1, n_subplots, figsize=figsize, + gridspec_kw={'width_ratios': width_ratios}) if n_subplots == 1: axes = [axes] else: axes = fig.axes # make sure misc contains all the empty axes - misc_axes = axes_idx.get("misc", []) - if not hasattr(misc_axes, "__iter__"): + misc_axes = axes_idx.get('misc', []) + if not hasattr(misc_axes, '__iter__'): misc_axes = [misc_axes] all_axes = [] for i in axes_idx.values(): # so if it's a list of ints - if hasattr(i, "__iter__"): + if hasattr(i, 
'__iter__'): all_axes.extend(i) else: all_axes.append(i) misc_axes += [i for i, _ in enumerate(fig.axes) if i not in all_axes] - axes_idx["misc"] = misc_axes + axes_idx['misc'] = misc_axes return fig, axes, axes_idx -def plot_synthesis_status( - metamer: Metamer, - batch_idx: int = 0, - channel_idx: int | None = None, - iteration: int | None = None, - ylim: tuple[float, float] | None | Literal[False] = None, - vrange: tuple[float, float] | str = "indep1", - zoom: float | None = None, - plot_representation_error_as_rgb: bool = False, - fig: mpl.figure.Figure | None = None, - axes_idx: dict[str, int] = {}, - figsize: tuple[float, float] | None = None, - included_plots: list[str] = [ - "display_metamer", - "plot_loss", - "plot_representation_error", - ], - width_ratios: dict[str, float] = {}, -) -> tuple[mpl.figure.Figure, dict[str, int]]: +def plot_synthesis_status(metamer: Metamer, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + iteration: Optional[int] = None, + ylim: Union[Tuple[float, float], None, Literal[False]] = None, + vrange: Union[Tuple[float, float], str] = 'indep1', + zoom: Optional[float] = None, + plot_representation_error_as_rgb: bool = False, + fig: Optional[mpl.figure.Figure] = None, + axes_idx: Dict[str, int] = {}, + figsize: Optional[Tuple[float, float]] = None, + included_plots: List[str] = ['display_metamer', + 'plot_loss', + 'plot_representation_error'], + width_ratios: Dict[str, float] = {}, + ) -> Tuple[mpl.figure.Figure, Dict[str, int]]: r"""Make a plot showing synthesis status. We create several subplots to analyze this. By default, we create three @@ -1572,23 +1410,19 @@ def plot_synthesis_status( """ if iteration is not None and not metamer.store_progress: - raise ValueError( - "synthesis() was run with store_progress=False, " - "cannot specify which iteration to plot (only" - " last one, with iteration=None)" - ) + raise ValueError("synthesis() was run with store_progress=False, " + "cannot specify which iteration to plot (only" + " last one, with iteration=None)") if metamer.metamer.ndim not in [3, 4]: - raise ValueError( - "plot_synthesis_status() expects 3 or 4d data;" - "unexpected behavior will result otherwise!" 
- ) - _check_included_plots(included_plots, "included_plots") - _check_included_plots(width_ratios, "width_ratios") - _check_included_plots(axes_idx, "axes_idx") - width_ratios = {f"{k}_width": v for k, v in width_ratios.items()} - fig, axes, axes_idx = _setup_synthesis_fig( - fig, axes_idx, figsize, included_plots, **width_ratios - ) + raise ValueError("plot_synthesis_status() expects 3 or 4d data;" + "unexpected behavior will result otherwise!") + _check_included_plots(included_plots, 'included_plots') + _check_included_plots(width_ratios, 'width_ratios') + _check_included_plots(axes_idx, 'axes_idx') + width_ratios = {f'{k}_width': v for k, v in width_ratios.items()} + fig, axes, axes_idx = _setup_synthesis_fig(fig, axes_idx, figsize, + included_plots, + **width_ratios) def check_iterables(i, vals): for j in vals: @@ -1602,64 +1436,48 @@ def check_iterables(i, vals): return True if "display_metamer" in included_plots: - display_metamer( - metamer, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=iteration, - ax=axes[axes_idx["display_metamer"]], - zoom=zoom, - vrange=vrange, - ) + display_metamer(metamer, batch_idx=batch_idx, + channel_idx=channel_idx, + iteration=iteration, + ax=axes[axes_idx['display_metamer']], + zoom=zoom, vrange=vrange) if "plot_loss" in included_plots: - plot_loss(metamer, iteration=iteration, ax=axes[axes_idx["plot_loss"]]) + plot_loss(metamer, iteration=iteration, ax=axes[axes_idx['plot_loss']]) if "plot_representation_error" in included_plots: - plot_representation_error( - metamer, - batch_idx=batch_idx, - iteration=iteration, - ax=axes[axes_idx["plot_representation_error"]], - ylim=ylim, - as_rgb=plot_representation_error_as_rgb, - ) + plot_representation_error(metamer, batch_idx=batch_idx, + iteration=iteration, + ax=axes[axes_idx['plot_representation_error']], + ylim=ylim, + as_rgb=plot_representation_error_as_rgb) # this can add a bunch of axes, so this will try and figure # them out - new_axes = [ - i - for i, _ in enumerate(fig.axes) - if not check_iterables(i, axes_idx.values()) - ] + [axes_idx["plot_representation_error"]] - axes_idx["plot_representation_error"] = new_axes + new_axes = [i for i, _ in enumerate(fig.axes) if not + check_iterables(i, axes_idx.values())] + [axes_idx['plot_representation_error']] + axes_idx['plot_representation_error'] = new_axes if "plot_pixel_values" in included_plots: - plot_pixel_values( - metamer, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=iteration, - ax=axes[axes_idx["plot_pixel_values"]], - ) + plot_pixel_values(metamer, batch_idx=batch_idx, + channel_idx=channel_idx, + iteration=iteration, + ax=axes[axes_idx['plot_pixel_values']]) return fig, axes_idx -def animate( - metamer: Metamer, - framerate: int = 10, - batch_idx: int = 0, - channel_idx: int | None = None, - ylim: str | None | tuple[float, float] | Literal[False] = None, - vrange: tuple[float, float] | str = (0, 1), - zoom: float | None = None, - plot_representation_error_as_rgb: bool = False, - fig: mpl.figure.Figure | None = None, - axes_idx: dict[str, int] = {}, - figsize: tuple[float, float] | None = None, - included_plots: list[str] = [ - "display_metamer", - "plot_loss", - "plot_representation_error", - ], - width_ratios: dict[str, float] = {}, -) -> mpl.animation.FuncAnimation: +def animate(metamer: Metamer, + framerate: int = 10, + batch_idx: int = 0, + channel_idx: Optional[int] = None, + ylim: Union[str, None, Tuple[float, float], Literal[False]] = None, + vrange: Union[Tuple[float, float], str] = (0, 1), + zoom: 
Optional[float] = None, + plot_representation_error_as_rgb: bool = False, + fig: Optional[mpl.figure.Figure] = None, + axes_idx: Dict[str, int] = {}, + figsize: Optional[Tuple[float, float]] = None, + included_plots: List[str] = ['display_metamer', + 'plot_loss', + 'plot_representation_error'], + width_ratios: Dict[str, float] = {}, + ) -> mpl.animation.FuncAnimation: r"""Animate synthesis progress. This is essentially the figure produced by @@ -1765,150 +1583,119 @@ def animate( """ if not metamer.store_progress: - raise ValueError( - "synthesize() was run with store_progress=False," - " cannot animate!" - ) + raise ValueError("synthesize() was run with store_progress=False," + " cannot animate!") if metamer.metamer.ndim not in [3, 4]: - raise ValueError( - "animate() expects 3 or 4d data; unexpected" - " behavior will result otherwise!" - ) - _check_included_plots(included_plots, "included_plots") - _check_included_plots(width_ratios, "width_ratios") - _check_included_plots(axes_idx, "axes_idx") + raise ValueError("animate() expects 3 or 4d data; unexpected" + " behavior will result otherwise!") + _check_included_plots(included_plots, 'included_plots') + _check_included_plots(width_ratios, 'width_ratios') + _check_included_plots(axes_idx, 'axes_idx') if metamer.target_representation.ndimension() == 4: # we have to do this here so that we set the # ylim_rescale_interval such that we never rescale ylim # (rescaling ylim messes up an image axis) ylim = False try: - if ylim.startswith("rescale"): + if ylim.startswith('rescale'): try: - ylim_rescale_interval = int(ylim.replace("rescale", "")) + ylim_rescale_interval = int(ylim.replace('rescale', '')) except ValueError: # then there's nothing we can convert to an int there - ylim_rescale_interval = int( - (metamer.saved_metamer.shape[0] - 1) // 10 - ) + ylim_rescale_interval = int((metamer.saved_metamer.shape[0] - 1) // 10) if ylim_rescale_interval == 0: - ylim_rescale_interval = int( - metamer.saved_metamer.shape[0] - 1 - ) + ylim_rescale_interval = int(metamer.saved_metamer.shape[0] - 1) ylim = None else: raise ValueError("Don't know how to handle ylim %s!" 
% ylim) except AttributeError: # this way we'll never rescale - ylim_rescale_interval = len(metamer.saved_metamer) + 1 + ylim_rescale_interval = len(metamer.saved_metamer)+1 # we run plot_synthesis_status to initialize the figure if either fig is # None or if there are no titles on any axes, which we assume means that # it's an empty figure if fig is None or not any([ax.get_title() for ax in fig.axes]): - fig, axes_idx = plot_synthesis_status( - metamer=metamer, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=0, - figsize=figsize, - ylim=ylim, - vrange=vrange, - zoom=zoom, - fig=fig, - axes_idx=axes_idx, - included_plots=included_plots, - plot_representation_error_as_rgb=plot_representation_error_as_rgb, - width_ratios=width_ratios, - ) + fig, axes_idx = plot_synthesis_status(metamer=metamer, + batch_idx=batch_idx, + channel_idx=channel_idx, + iteration=0, figsize=figsize, + ylim=ylim, vrange=vrange, + zoom=zoom, fig=fig, + axes_idx=axes_idx, + included_plots=included_plots, + plot_representation_error_as_rgb=plot_representation_error_as_rgb, + width_ratios=width_ratios) # grab the artist for the second plot (we don't need to do this for the # metamer or representation plot, because we use the update_plot # function for that) - if "plot_loss" in included_plots: - scat = fig.axes[axes_idx["plot_loss"]].collections[0] + if 'plot_loss' in included_plots: + scat = fig.axes[axes_idx['plot_loss']].collections[0] # can have multiple plots - if "plot_representation_error" in included_plots: + if 'plot_representation_error' in included_plots: try: - rep_error_axes = [ - fig.axes[i] for i in axes_idx["plot_representation_error"] - ] + rep_error_axes = [fig.axes[i] for i in axes_idx['plot_representation_error']] except TypeError: # in this case, axes_idx['plot_representation_error'] is not iterable and so is # a single value - rep_error_axes = [fig.axes[axes_idx["plot_representation_error"]]] + rep_error_axes = [fig.axes[axes_idx['plot_representation_error']]] else: rep_error_axes = [] # can also have multiple plots if metamer.target_representation.ndimension() == 4: - if "plot_representation_error" in included_plots: - warnings.warn( - "Looks like representation is image-like, haven't fully thought out how" - " to best handle rescaling color ranges yet!" - ) + if 'plot_representation_error' in included_plots: + warnings.warn("Looks like representation is image-like, haven't fully thought out how" + " to best handle rescaling color ranges yet!") # replace the bit of the title that specifies the range, # since we don't make any promises about that. 
we have to do # this here because we need the figure to have been created for ax in rep_error_axes: - ax.set_title(re.sub(r"\n range: .* \n", "\n\n", ax.get_title())) + ax.set_title(re.sub(r'\n range: .* \n', '\n\n', ax.get_title())) def movie_plot(i): artists = [] - if "display_metamer" in included_plots: - artists.extend( - display.update_plot( - fig.axes[axes_idx["display_metamer"]], - data=metamer.saved_metamer[i], - batch_idx=batch_idx, - ) - ) - if "plot_representation_error" in included_plots: - rep_error = _representation_error(metamer, iteration=i) + if 'display_metamer' in included_plots: + artists.extend(display.update_plot(fig.axes[axes_idx['display_metamer']], + data=metamer.saved_metamer[i], + batch_idx=batch_idx)) + if 'plot_representation_error' in included_plots: + rep_error = _representation_error(metamer, + iteration=i) # we pass rep_error_axes to update, and we've grabbed # the right things above - artists.extend( - display.update_plot( - rep_error_axes, - batch_idx=batch_idx, - model=metamer.model, - data=rep_error, - ) - ) + artists.extend(display.update_plot(rep_error_axes, + batch_idx=batch_idx, + model=metamer.model, + data=rep_error)) # again, we know that rep_error_axes contains all the axes # with the representation ratio info - if ((i + 1) % ylim_rescale_interval) == 0: + if ((i+1) % ylim_rescale_interval) == 0: if metamer.target_representation.ndimension() == 3: - display.rescale_ylim(rep_error_axes, rep_error) + display.rescale_ylim(rep_error_axes, + rep_error) - if "plot_pixel_values" in included_plots: + if 'plot_pixel_values' in included_plots: # this is the dumbest way to do this, but it's simple -- # clearing the axes can cause problems if the user has, for # example, changed the tick locator or formatter. not sure how # to handle this best right now - fig.axes[axes_idx["plot_pixel_values"]].clear() - plot_pixel_values( - metamer, - batch_idx=batch_idx, - channel_idx=channel_idx, - iteration=i, - ax=fig.axes[axes_idx["plot_pixel_values"]], - ) - if "plot_loss" in included_plots: + fig.axes[axes_idx['plot_pixel_values']].clear() + plot_pixel_values(metamer, batch_idx=batch_idx, + channel_idx=channel_idx, iteration=i, + ax=fig.axes[axes_idx['plot_pixel_values']]) + if 'plot_loss' in included_plots: # loss always contains values from every iteration, but everything # else will be subsampled.
- x_val = i * metamer.store_progress + x_val = i*metamer.store_progress scat.set_offsets((x_val, metamer.losses[x_val])) artists.append(scat) # as long as blitting is True, need to return a sequence of artists return artists # don't need an init_func, since we handle initialization ourselves - anim = mpl.animation.FuncAnimation( - fig, - movie_plot, - frames=len(metamer.saved_metamer), - blit=True, - interval=1000.0 / framerate, - repeat=False, - ) + anim = mpl.animation.FuncAnimation(fig, movie_plot, + frames=len(metamer.saved_metamer), + blit=True, interval=1000./framerate, + repeat=False) plt.close(fig) return anim diff --git a/src/plenoptic/synthesize/simple_metamer.py b/src/plenoptic/synthesize/simple_metamer.py index db857b3a..fd6b8f8a 100644 --- a/src/plenoptic/synthesize/simple_metamer.py +++ b/src/plenoptic/synthesize/simple_metamer.py @@ -1,12 +1,11 @@ """Simple Metamer Class """ - import torch from tqdm.auto import tqdm - -from ..tools import optim -from ..tools.validate import validate_input, validate_model from .synthesis import Synthesis +from ..tools.validate import validate_input, validate_model +from ..tools import optim +from typing import Union class SimpleMetamer(Synthesis): @@ -30,12 +29,8 @@ class SimpleMetamer(Synthesis): """ def __init__(self, image: torch.Tensor, model: torch.nn.Module): - validate_model( - model, - image_shape=image.shape, - image_dtype=image.dtype, - device=image.device, - ) + validate_model(model, image_shape=image.shape, image_dtype=image.dtype, + device=image.device) self.model = model validate_input(image) self.image = image @@ -44,11 +39,8 @@ def __init__(self, image: torch.Tensor, model: torch.nn.Module): self.optimizer = None self.losses = [] - def synthesize( - self, - max_iter: int = 100, - optimizer: None | torch.optim.Optimizer = None, - ) -> torch.Tensor: + def synthesize(self, max_iter: int = 100, + optimizer: Union[None, torch.optim.Optimizer] = None) -> torch.Tensor: """Synthesize a simple metamer. If called multiple times, will continue where we left off. @@ -70,9 +62,8 @@ def synthesize( """ if optimizer is None: if self.optimizer is None: - self.optimizer = torch.optim.Adam( - [self.metamer], lr=0.01, amsgrad=True - ) + self.optimizer = torch.optim.Adam([self.metamer], + lr=.01, amsgrad=True) else: self.optimizer = optimizer @@ -87,10 +78,10 @@ def closure(): # function. You could theoretically also just clamp metamer on # each step of the iteration, but the penalty in the loss seems # to work better in practice - loss = optim.mse( - metamer_representation, self.target_representation - ) - loss = loss + 0.1 * optim.penalize_range(self.metamer, (0, 1)) + loss = optim.mse(metamer_representation, + self.target_representation) + loss = loss + .1 * optim.penalize_range(self.metamer, + (0, 1)) self.losses.append(loss.item()) loss.backward(retain_graph=False) pbar.set_postfix(loss=loss.item()) @@ -109,7 +100,8 @@ def save(self, file_path: str): """ super().save(file_path, attrs=None) - def load(self, file_path: str, map_location: str | None = None): + def load(self, file_path: str, + map_location: Union[str, None] = None): r"""Load all relevant attributes from a .pt file. Note this operates in place and so doesn't return anything. 
@@ -119,12 +111,9 @@ def load(self, file_path: str, map_location: str | None = None): file_path The path to load the synthesis object from """ - check_attributes = ["target_representation", "image"] - super().load( - file_path, - check_attributes=check_attributes, - map_location=map_location, - ) + check_attributes = ['target_representation', 'image'] + super().load(file_path, check_attributes=check_attributes, + map_location=map_location) def to(self, *args, **kwargs): r"""Move and/or cast the parameters and buffers. @@ -157,6 +146,7 @@ def to(self, *args, **kwargs): Returns: Module: self """ - attrs = ["model", "image", "target_representation", "metamer"] + attrs = ['model', 'image', 'target_representation', + 'metamer'] super().to(*args, attrs=attrs, **kwargs) return self diff --git a/src/plenoptic/synthesize/synthesis.py b/src/plenoptic/synthesize/synthesis.py index cc18555c..8c52dd8c 100644 --- a/src/plenoptic/synthesize/synthesis.py +++ b/src/plenoptic/synthesize/synthesis.py @@ -1,8 +1,8 @@ """abstract synthesis super-class.""" import abc import warnings - import torch +from typing import Optional, List, Tuple, Union class Synthesis(abc.ABC): @@ -20,7 +20,7 @@ def synthesize(self): r"""Synthesize something.""" pass - def save(self, file_path: str, attrs: list[str] | None = None): + def save(self, file_path: str, attrs: Optional[List[str]] = None): r"""Save all relevant (non-model) variables in .pt file. If you leave attrs as None, we grab vars(self) and exclude 'model'. @@ -40,16 +40,14 @@ def save(self, file_path: str, attrs: list[str] | None = None): # this copies the attributes dict so we don't actually remove the # model attribute in the next line attrs = {k: v for k, v in vars(self).items()} - attrs.pop("_model", None) + attrs.pop('_model', None) save_dict = {} for k in attrs: - if k == "_model": - warnings.warn( - "Models can be quite large and they don't change" - " over synthesis. Please be sure that you " - "actually want to save the model." - ) + if k == '_model': + warnings.warn("Models can be quite large and they don't change" + " over synthesis. Please be sure that you " + "actually want to save the model.") attr = getattr(self, k) # detaching the tensors avoids some headaches like the # tensors having extra hooks or the like @@ -58,14 +56,11 @@ def save(self, file_path: str, attrs: list[str] | None = None): save_dict[k] = attr torch.save(save_dict, file_path) - def load( - self, - file_path: str, - map_location: str | None = None, - check_attributes: list[str] = [], - check_loss_functions: list[str] = [], - **pickle_load_args, - ): + def load(self, file_path: str, + map_location: Optional[str] = None, + check_attributes: List[str] = [], + check_loss_functions: List[str] = [], + **pickle_load_args): r"""Load all relevant attributes from a .pt file. This should be called by an initialized ``Synthesis`` object -- we will @@ -103,9 +98,9 @@ def load( ``torch.load``, see that function's docstring for details. """ - tmp_dict = torch.load( - file_path, map_location=map_location, **pickle_load_args - ) + tmp_dict = torch.load(file_path, + map_location=map_location, + **pickle_load_args) if map_location is not None: device = map_location else: @@ -121,60 +116,47 @@ def load( # the initial underscore. This is because this function # needs to be able to set the attribute, which can only be # done with the hidden version. 
- if k.startswith("_"): + if k.startswith('_'): display_k = k[1:] else: display_k = k if not hasattr(self, k): - raise AttributeError( - "All values of `check_attributes` should be " - "attributes set at initialization, but got " - f"attr {display_k}!" - ) + raise AttributeError("All values of `check_attributes` should be " + "attributes set at initialization, but got " + f"attr {display_k}!") if isinstance(getattr(self, k), torch.Tensor): # there are two ways this can fail -- the first is if they're # the same shape but different values and the second (in the # except block) are if they're different shapes. try: - if not torch.allclose( - getattr(self, k).to(tmp_dict[k].device), - tmp_dict[k], - rtol=5e-2, - ): - raise ValueError( - f"Saved and initialized {display_k} are " - f"different! Initialized: {getattr(self, k)}" - f", Saved: {tmp_dict[k]}, difference: " - f"{getattr(self, k) - tmp_dict[k]}" - ) + if not torch.allclose(getattr(self, k).to(tmp_dict[k].device), + tmp_dict[k], rtol=5e-2): + raise ValueError(f"Saved and initialized {display_k} are " + f"different! Initialized: {getattr(self, k)}" + f", Saved: {tmp_dict[k]}, difference: " + f"{getattr(self, k) - tmp_dict[k]}") except RuntimeError as e: # we end up here if dtype or shape don't match - if "The size of tensor a" in e.args[0]: - raise RuntimeError( - f"Attribute {display_k} have different shapes in" - " saved and initialized versions! Initialized" - f": {getattr(self, k).shape}, Saved: " - f"{tmp_dict[k].shape}" - ) - elif "did not match" in e.args[0]: - raise RuntimeError( - f"Attribute {display_k} has different dtype in " - "saved and initialized versions! Initialized" - f": {getattr(self, k).dtype}, Saved: " - f"{tmp_dict[k].dtype}" - ) + if 'The size of tensor a' in e.args[0]: + raise RuntimeError(f"Attribute {display_k} have different shapes in" + " saved and initialized versions! Initialized" + f": {getattr(self, k).shape}, Saved: " + f"{tmp_dict[k].shape}") + elif 'did not match' in e.args[0]: + raise RuntimeError(f"Attribute {display_k} has different dtype in " + "saved and initialized versions! Initialized" + f": {getattr(self, k).dtype}, Saved: " + f"{tmp_dict[k].dtype}") else: raise e else: if getattr(self, k) != tmp_dict[k]: - raise ValueError( - f"Saved and initialized {display_k} are different!" - f" Self: {getattr(self, k)}, " - f"Saved: {tmp_dict[k]}" - ) + raise ValueError(f"Saved and initialized {display_k} are different!" + f" Self: {getattr(self, k)}, " + f"Saved: {tmp_dict[k]}") for k in check_loss_functions: # same as above - if k.startswith("_"): + if k.startswith('_'): display_k = k[1:] else: display_k = k @@ -183,22 +165,20 @@ def load( saved_loss = tmp_dict[k](tensor_a, tensor_b) init_loss = getattr(self, k)(tensor_a, tensor_b) if not torch.allclose(saved_loss, init_loss, rtol=1e-2): - raise ValueError( - f"Saved and initialized {display_k} are " - "different! On two random tensors: " - f"Initialized: {init_loss}, Saved: " - f"{saved_loss}, difference: " - f"{init_loss-saved_loss}" - ) + raise ValueError(f"Saved and initialized {display_k} are " + "different! On two random tensors: " + f"Initialized: {init_loss}, Saved: " + f"{saved_loss}, difference: " + f"{init_loss-saved_loss}") for k, v in tmp_dict.items(): setattr(self, k, v) @abc.abstractmethod - def to(self, *args, attrs: list[str] = [], **kwargs): + def to(self, *args, attrs: List[str] = [], **kwargs): r"""Moves and/or casts the parameters and buffers. 
Similar to ``save``, this is an abstract method only because you need to define the attributes to call to on. - + This can be called as .. function:: to(device=None, dtype=None, non_blocking=False) .. function:: to(dtype, non_blocking=False) @@ -230,19 +210,13 @@ def to(self, *args, attrs: list[str] = [], **kwargs): except AttributeError: warnings.warn("model has no `to` method, so we leave it as is...") - device, dtype, non_blocking, memory_format = torch._C._nn._parse_to( - *args, **kwargs - ) + device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs) def move(a, k): move_device = None if k.startswith("saved_") else device if memory_format is not None and a.dim() == 4: - return a.to( - move_device, - dtype, - non_blocking, - memory_format=memory_format, - ) + return a.to(move_device, dtype, non_blocking, + memory_format=memory_format) else: return a.to(move_device, dtype, non_blocking) @@ -265,12 +239,10 @@ class OptimizedSynthesis(Synthesis): these will use an optimizer object to iteratively update their output. """ - - def __init__( - self, - range_penalty_lambda: float = 0.1, - allowed_range: tuple[float, float] = (0, 1), - ): + def __init__(self, + range_penalty_lambda: float = .1, + allowed_range: Tuple[float, float] = (0, 1), + ): """Initialize the properties of OptimizedSynthesis.""" self._losses = [] self._gradient_norm = [] @@ -324,12 +296,10 @@ def _closure(self) -> torch.Tensor: loss.backward(retain_graph=False) return loss - def _initialize_optimizer( - self, - optimizer: torch.optim.Optimizer | None, - synth_name: str, - learning_rate: float = 0.01, - ): + def _initialize_optimizer(self, + optimizer: Optional[torch.optim.Optimizer], + synth_name: str, + learning_rate: float = .01): """Initialize optimizer. First time this is called, optimizer can be: @@ -349,20 +319,15 @@ def _initialize_optimizer( synth_attr = getattr(self, synth_name) if optimizer is None: if self.optimizer is None: - self._optimizer = torch.optim.Adam( - [synth_attr], lr=learning_rate, amsgrad=True - ) + self._optimizer = torch.optim.Adam([synth_attr], + lr=learning_rate, amsgrad=True) else: if self.optimizer is not None: - raise TypeError( - "When resuming synthesis, optimizer arg must be None!" - ) - params = optimizer.param_groups[0]["params"] + raise TypeError("When resuming synthesis, optimizer arg must be None!") + params = optimizer.param_groups[0]['params'] if len(params) != 1 or not torch.equal(params[0], synth_attr): - raise ValueError( - f"For {synth_name} synthesis, optimizer must have one " - f"parameter, the {synth_name} we're synthesizing." - ) + raise ValueError(f"For {synth_name} synthesis, optimizer must have one " + f"parameter, the {synth_name} we're synthesizing.") self._optimizer = optimizer @property @@ -393,7 +358,7 @@ def store_progress(self): return self._store_progress @store_progress.setter - def store_progress(self, store_progress: bool | int): + def store_progress(self, store_progress: Union[bool, int]): """Initialize store_progress. 
Sets the ``self.store_progress`` attribute, as well as changing the @@ -413,23 +378,19 @@ def store_progress(self, store_progress: bool | int): if store_progress: if store_progress is True: store_progress = 1 - if ( - self.store_progress is not None - and store_progress != self.store_progress - ): + if self.store_progress is not None and store_progress != self.store_progress: # we require store_progress to be the same because otherwise the # subsampling relationship between attrs that are stored every # iteration (loss, gradient, etc) and those that are stored every # store_progress iteration (e.g., saved_metamer) changes partway # through and that's annoying - raise Exception( - "If you've already run synthesize() before, must " - "re-run it with same store_progress arg. You " - f"passed {store_progress} instead of " - f"{self.store_progress} (True is equivalent to 1)" - ) + raise Exception("If you've already run synthesize() before, must " + "re-run it with same store_progress arg. You " + f"passed {store_progress} instead of " + f"{self.store_progress} (True is equivalent to 1)") self._store_progress = store_progress @property def optimizer(self): return self._optimizer + diff --git a/src/plenoptic/tools/__init__.py b/src/plenoptic/tools/__init__.py index e02d1c9c..2c815b31 100644 --- a/src/plenoptic/tools/__init__.py +++ b/src/plenoptic/tools/__init__.py @@ -1,10 +1,12 @@ -from . import validate -from .conv import * from .data import * -from .display import * -from .external import * -from .optim import * +from .conv import * from .signal import * from .stats import * +from .display import * from .straightness import * + +from .optim import * +from .external import * from .validate import remove_grad + +from . import validate diff --git a/src/plenoptic/tools/conv.py b/src/plenoptic/tools/conv.py index cc4ae6eb..70832efd 100644 --- a/src/plenoptic/tools/conv.py +++ b/src/plenoptic/tools/conv.py @@ -1,10 +1,10 @@ -import math - import numpy as np -import pyrtools as pt import torch -import torch.nn.functional as F from torch import Tensor +import torch.nn.functional as F +import pyrtools as pt +from typing import Union, Tuple +import math def correlate_downsample(image, filt, padding_mode="reflect"): @@ -24,15 +24,8 @@ def correlate_downsample(image, filt, padding_mode="reflect"): assert isinstance(image, torch.Tensor) and isinstance(filt, torch.Tensor) assert image.ndim == 4 and filt.ndim == 2 n_channels = image.shape[1] - image_padded = same_padding( - image, kernel_size=filt.shape, pad_mode=padding_mode - ) - return F.conv2d( - image_padded, - filt.repeat(n_channels, 1, 1, 1), - stride=2, - groups=n_channels, - ) + image_padded = same_padding(image, kernel_size=filt.shape, pad_mode=padding_mode) + return F.conv2d(image_padded, filt.repeat(n_channels, 1, 1, 1), stride=2, groups=n_channels) def upsample_convolve(image, odd, filt, padding_mode="reflect"): @@ -61,18 +54,10 @@ def upsample_convolve(image, odd, filt, padding_mode="reflect"): pad_end = np.array(filt.shape) - np.array(odd) - pad_start pad = np.array([pad_start[1], pad_end[1], pad_start[0], pad_end[0]]) image_prepad = F.pad(image, tuple(pad // 2), mode=padding_mode) - image_upsample = F.conv_transpose2d( - image_prepad, - weight=torch.ones( - (n_channels, 1, 1, 1), device=image.device, dtype=image.dtype - ), - stride=2, - groups=n_channels, - ) + image_upsample = F.conv_transpose2d(image_prepad, + weight=torch.ones((n_channels, 1, 1, 1), device=image.device, dtype=image.dtype), stride=2, groups=n_channels) image_postpad = 
F.pad(image_upsample, tuple(pad % 2)) - return F.conv2d( - image_postpad, filt.repeat(n_channels, 1, 1, 1), groups=n_channels - ) + return F.conv2d(image_postpad, filt.repeat(n_channels, 1, 1, 1), groups=n_channels) def blur_downsample(x, n_scales=1, filtname="binom5", scale_filter=True): @@ -92,9 +77,7 @@ def blur_downsample(x, n_scales=1, filtname="binom5", scale_filter=True): """ f = pt.named_filter(filtname) - filt = torch.as_tensor( - np.outer(f, f), dtype=torch.float32, device=x.device - ) + filt = torch.as_tensor(np.outer(f, f), dtype=torch.float32, device=x.device) if scale_filter: filt = filt / 2 for _ in range(n_scales): @@ -120,46 +103,38 @@ def upsample_blur(x, odd, filtname="binom5", scale_filter=True): """ f = pt.named_filter(filtname) - filt = torch.as_tensor( - np.outer(f, f), dtype=torch.float32, device=x.device - ) + filt = torch.as_tensor(np.outer(f, f), dtype=torch.float32, device=x.device) if scale_filter: filt = filt * 2 return upsample_convolve(x, odd, filt) def _get_same_padding( - x: int, kernel_size: int, stride: int, dilation: int + x: int, + kernel_size: int, + stride: int, + dilation: int ) -> int: """Helper function to determine integer padding for F.pad() given img and kernel""" - pad = ( - (math.ceil(x / stride) - 1) * stride - + (kernel_size - 1) * dilation - + 1 - - x - ) + pad = (math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x pad = max(pad, 0) return pad def same_padding( - x: Tensor, - kernel_size: int | tuple[int, int], - stride: int | tuple[int, int] = (1, 1), - dilation: int | tuple[int, int] = (1, 1), - pad_mode: str = "circular", + x: Tensor, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = (1, 1), + dilation: Union[int, Tuple[int, int]] = (1, 1), + pad_mode: str = "circular", ) -> Tensor: """Pad a tensor so that 2D convolution will result in output with same dims.""" - assert ( - len(x.shape) > 2 - ), "Input must be tensor whose last dims are height x width" + assert len(x.shape) > 2, "Input must be tensor whose last dims are height x width" ih, iw = x.shape[-2:] pad_h = _get_same_padding(ih, kernel_size[0], stride[0], dilation[0]) pad_w = _get_same_padding(iw, kernel_size[1], stride[1], dilation[1]) if pad_h > 0 or pad_w > 0: - x = F.pad( - x, - [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], - mode=pad_mode, - ) + x = F.pad(x, + [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], + mode=pad_mode) return x diff --git a/src/plenoptic/tools/convergence.py b/src/plenoptic/tools/convergence.py index bba4b2d1..8a658ea1 100644 --- a/src/plenoptic/tools/convergence.py +++ b/src/plenoptic/tools/convergence.py @@ -20,17 +20,14 @@ # to avoid circular import error: # https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ from typing import TYPE_CHECKING - if TYPE_CHECKING: - from ..synthesize.metamer import Metamer from ..synthesize.synthesis import OptimizedSynthesis + from ..synthesize.metamer import Metamer -def loss_convergence( - synth: "OptimizedSynthesis", - stop_criterion: float, - stop_iters_to_check: int, -) -> bool: +def loss_convergence(synth: "OptimizedSynthesis", + stop_criterion: float, + stop_iters_to_check: int) -> bool: r"""Check whether the loss has stabilized and, if so, return True. Have we been synthesizing for ``stop_iters_to_check`` iterations? 
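Since only the line-wrapping of ``loss_convergence`` changes in the hunk below, its criterion is worth restating: it is a windowed difference check on the loss history. A self-contained sketch (``losses`` stands in for ``synth.losses``; the defaults mirror ``synthesize()``):

    def loss_has_converged(losses, stop_criterion=1e-4, stop_iters_to_check=50):
        # need more than stop_iters_to_check iterations of history first
        if len(losses) <= stop_iters_to_check:
            return False
        # converged when loss moved less than stop_criterion over the window
        return abs(losses[-stop_iters_to_check] - losses[-1]) < stop_criterion
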
@@ -62,17 +59,13 @@ def loss_convergence( """ if len(synth.losses) > stop_iters_to_check: - if ( - abs(synth.losses[-stop_iters_to_check] - synth.losses[-1]) - < stop_criterion - ): + if abs(synth.losses[-stop_iters_to_check] - synth.losses[-1]) < stop_criterion: return True return False -def coarse_to_fine_enough( - synth: "Metamer", i: int, ctf_iters_to_check: int -) -> bool: +def coarse_to_fine_enough(synth: "Metamer", i: int, + ctf_iters_to_check: int) -> bool: r"""Check whether we've synthesized all scales and done so for at least ctf_iters_to_check iterations This is meant to be paired with another convergence check, such as ``loss_convergence``. @@ -93,20 +86,18 @@ def coarse_to_fine_enough( Whether we've been doing coarse to fine synthesis for long enough. """ - all_scales = synth.scales[0] == "all" + all_scales = synth.scales[0] == 'all' # synth.scales_timing['all'] will only be a non-empty list if all_scales is # True, so we only check it then. This is equivalent to checking if both conditions are trued if all_scales: - return (i - synth.scales_timing["all"][0]) > ctf_iters_to_check + return (i - synth.scales_timing['all'][0]) > ctf_iters_to_check else: return False -def pixel_change_convergence( - synth: "OptimizedSynthesis", - stop_criterion: float, - stop_iters_to_check: int, -) -> bool: +def pixel_change_convergence(synth: "OptimizedSynthesis", + stop_criterion: float, + stop_iters_to_check: int) -> bool: """Check whether the pixel change norm has stabilized and, if so, return True. Have we been synthesizing for ``stop_iters_to_check`` iterations? @@ -138,8 +129,6 @@ def pixel_change_convergence( """ if len(synth.pixel_change_norm) > stop_iters_to_check: - if ( - synth.pixel_change_norm[-stop_iters_to_check:] < stop_criterion - ).all(): + if (synth.pixel_change_norm[-stop_iters_to_check:] < stop_criterion).all(): return True return False diff --git a/src/plenoptic/tools/data.py b/src/plenoptic/tools/data.py index 5f462842..415defa5 100644 --- a/src/plenoptic/tools/data.py +++ b/src/plenoptic/tools/data.py @@ -1,12 +1,13 @@ -import os.path as op import pathlib +from typing import List, Optional, Union, Tuple import warnings import imageio import numpy as np -import torch +import os.path as op from pyrtools import synthetic_images from skimage import color +import torch from torch import Tensor from .signal import rescale @@ -27,12 +28,10 @@ np.complex128: torch.complex128, } -TORCH_TO_NUMPY_TYPES = { - value: key for (key, value) in NUMPY_TO_TORCH_TYPES.items() -} +TORCH_TO_NUMPY_TYPES = {value: key for (key, value) in NUMPY_TO_TORCH_TYPES.items()} -def to_numpy(x: Tensor | np.ndarray, squeeze: bool = False) -> np.ndarray: +def to_numpy(x: Union[Tensor, np.ndarray], squeeze: bool = False) -> np.ndarray: r"""cast tensor to numpy in the most conservative way possible Parameters @@ -58,7 +57,7 @@ def to_numpy(x: Tensor | np.ndarray, squeeze: bool = False) -> np.ndarray: return x -def load_images(paths: str | list[str], as_gray: bool = True) -> Tensor: +def load_images(paths: Union[str, List[str]], as_gray: bool = True) -> Tensor: r"""Correctly load in images Our models and synthesis methods expect their inputs to be 4d @@ -139,10 +138,8 @@ def load_images(paths: str | list[str], as_gray: bool = True) -> Tensor: im = np.expand_dims(im, 0).repeat(3, 0) images.append(im) if len(set([i.shape for i in images])) > 1: - raise ValueError( - "All images must be the same shape but got the following: " - f"{[i.shape for i in images]}" - ) + raise ValueError("All images must be the same 
shape but got the following: " + f"{[i.shape for i in images]}") images = torch.as_tensor(np.array(images), dtype=torch.float32) if as_gray: if images.ndimension() != 3: @@ -197,9 +194,7 @@ def convert_float_to_int(im: np.ndarray, dtype=np.uint8) -> np.ndarray: return (im * np.iinfo(dtype).max).astype(dtype) -def make_synthetic_stimuli( - size: int = 256, requires_grad: bool = True -) -> Tensor: +def make_synthetic_stimuli(size: int = 256, requires_grad: bool = True) -> Tensor: r"""Make a set of basic stimuli, useful for developping and debugging models Parameters @@ -228,13 +223,10 @@ def make_synthetic_stimuli( bar = np.zeros((size, size)) bar[ - size // 2 - size // 10 : size // 2 + size // 10, - size // 2 - 1 : size // 2 + 1, + size // 2 - size // 10 : size // 2 + size // 10, size // 2 - 1 : size // 2 + 1 ] = 1 - curv_edge = synthetic_images.disk( - size=size, radius=size / 1.2, origin=(size, size) - ) + curv_edge = synthetic_images.disk(size=size, radius=size / 1.2, origin=(size, size)) sine_grating = synthetic_images.sine(size) * synthetic_images.gaussian( size, covariance=size @@ -283,10 +275,10 @@ def make_synthetic_stimuli( def polar_radius( - size: int | tuple[int, int], + size: Union[int, Tuple[int, int]], exponent: float = 1.0, - origin: int | tuple[int, int] | None = None, - device: str | torch.device | None = None, + origin: Optional[Union[int, Tuple[int, int]]] = None, + device: Optional[Union[str, torch.device]] = None, ) -> Tensor: """Make distance-from-origin (r) matrix @@ -344,10 +336,10 @@ def polar_radius( def polar_angle( - size: int | tuple[int, int], + size: Union[int, Tuple[int, int]], phase: float = 0.0, - origin: int | tuple[float, float] | None = None, - device: torch.device | None = None, + origin: Optional[Union[int, Tuple[float, float]]] = None, + device: Optional[torch.device] = None, ) -> Tensor: """Make polar angle matrix (in radians). diff --git a/src/plenoptic/tools/display.py b/src/plenoptic/tools/display.py index d903e22f..97350074 100644 --- a/src/plenoptic/tools/display.py +++ b/src/plenoptic/tools/display.py @@ -1,34 +1,20 @@ """various helpful utilities for plotting or displaying information """ import warnings - -import matplotlib.pyplot as plt +import torch import numpy as np import pyrtools as pt -import torch - +import matplotlib.pyplot as plt from .data import to_numpy - try: from IPython.display import HTML except ImportError: warnings.warn("Unable to import IPython.display.HTML") -def imshow( - image, - vrange="indep1", - zoom=None, - title="", - col_wrap=None, - ax=None, - cmap=None, - plot_complex="rectangular", - batch_idx=None, - channel_idx=None, - as_rgb=False, - **kwargs, -): +def imshow(image, vrange='indep1', zoom=None, title='', col_wrap=None, ax=None, + cmap=None, plot_complex='rectangular', batch_idx=None, + channel_idx=None, as_rgb=False, **kwargs): """Show image(s) correctly. This function shows images correctly, making sure that each element in the @@ -132,26 +118,22 @@ def imshow( im = to_numpy(im) if im.shape[0] > 1 and batch_idx is not None: # this preserves the number of dimensions - im = im[batch_idx : batch_idx + 1] + im = im[batch_idx:batch_idx+1] if channel_idx is not None: # this preserves the number of dimensions - im = im[:, channel_idx : channel_idx + 1] + im = im[:, channel_idx:channel_idx+1] # allow RGB and RGBA if as_rgb: if im.shape[1] not in [3, 4]: - raise Exception( - "If as_rgb is True, then channel must have 3 " - "or 4 elements!" 
- ) + raise Exception("If as_rgb is True, then channel must have 3 " + "or 4 elements!") im = im.transpose(0, 2, 3, 1) # want to insert a fake "channel" dimension here, so our putting it # into a list below works as expected im = im.reshape((im.shape[0], 1, *im.shape[1:])) elif im.shape[1] > 1 and im.shape[0] > 1: - raise Exception( - "Don't know how to plot images with more than one channel and batch!" - " Use batch_idx / channel_idx to choose a subset for plotting" - ) + raise Exception("Don't know how to plot images with more than one channel and batch!" + " Use batch_idx / channel_idx to choose a subset for plotting") # by iterating through it twice, we make sure to peel apart the batch # and channel dimensions so that they each show up as a separate image. # because of how we've handled everything above, we know that im will @@ -170,8 +152,7 @@ def find_zoom(x, limit): divisors = [i for i in range(2, x) if not x % i] # find the largest zoom (equivalently, smallest divisor) such that the # zoomed in image is smaller than the limit - return 1 / min([i for i in divisors if x / i <= limit]) - + return 1 / min([i for i in divisors if x/i <= limit]) if ax is not None and zoom is None: if ax.bbox.height > max(heights): zoom = ax.bbox.height // max(heights) @@ -183,35 +164,15 @@ def find_zoom(x, limit): zoom = find_zoom(max(widths), ax.bbox.width) elif zoom is None: zoom = 1 - return pt.imshow( - images_to_plot, - vrange=vrange, - zoom=zoom, - title=title, - col_wrap=col_wrap, - ax=ax, - cmap=cmap, - plot_complex=plot_complex, - **kwargs, - ) - - -def animshow( - video, - framerate=2.0, - repeat=False, - vrange="indep1", - zoom=1, - title="", - col_wrap=None, - ax=None, - cmap=None, - plot_complex="rectangular", - batch_idx=None, - channel_idx=None, - as_rgb=False, - **kwargs, -): + return pt.imshow(images_to_plot, vrange=vrange, zoom=zoom, title=title, + col_wrap=col_wrap, ax=ax, cmap=cmap, plot_complex=plot_complex, + **kwargs) + + +def animshow(video, framerate=2., repeat=False, vrange='indep1', zoom=1, + title='', col_wrap=None, ax=None, cmap=None, + plot_complex='rectangular', batch_idx=None, channel_idx=None, + as_rgb=False, **kwargs): """Animate video(s) correctly. This function animates videos correctly, making sure that each element in @@ -340,59 +301,37 @@ def animshow( vid = to_numpy(vid) if vid.shape[0] > 1 and batch_idx is not None: # this preserves the number of dimensions - vid = vid[batch_idx : batch_idx + 1] + vid = vid[batch_idx:batch_idx+1] if channel_idx is not None: # this preserves the number of dimensions - vid = vid[:, channel_idx : channel_idx + 1] + vid = vid[:, channel_idx:channel_idx+1] # allow RGB and RGBA if as_rgb: if vid.shape[1] not in [3, 4]: - raise Exception( - "If as_rgb is True, then channel must have 3 " - "or 4 elements!" - ) + raise Exception("If as_rgb is True, then channel must have 3 " + "or 4 elements!") vid = vid.transpose(0, 2, 3, 4, 1) # want to insert a fake "channel" dimension here, so our putting it # into a list below works as expected vid = vid.reshape((vid.shape[0], 1, *vid.shape[1:])) elif vid.shape[1] > 1 and vid.shape[0] > 1: - raise Exception( - "Don't know how to plot images with more than one channel and batch!" - " Use batch_idx / channel_idx to choose a subset for plotting" - ) + raise Exception("Don't know how to plot images with more than one channel and batch!" 
+ " Use batch_idx / channel_idx to choose a subset for plotting") # by iterating through it twice, we make sure to peel apart the batch # and channel dimensions so that they each show up as a separate video. # because of how we've handled everything above, we know that vid will # be (b,c,t,h,w) or (b,c,t,h,w,r) where r is the RGB(A) values for v in vid: videos_to_show.extend([v_.squeeze() for v_ in v]) - return pt.animshow( - videos_to_show, - framerate=framerate, - as_html5=False, - repeat=repeat, - vrange=vrange, - zoom=zoom, - title=title, - col_wrap=col_wrap, - ax=ax, - cmap=cmap, - plot_complex=plot_complex, - **kwargs, - ) - - -def pyrshow( - pyr_coeffs, - vrange="indep1", - zoom=1, - show_residuals=True, - cmap=None, - plot_complex="rectangular", - batch_idx=0, - channel_idx=0, - **kwargs, -): + return pt.animshow(videos_to_show, framerate=framerate, as_html5=False, + repeat=repeat, vrange=vrange, zoom=zoom, title=title, + col_wrap=col_wrap, ax=ax, cmap=cmap, + plot_complex=plot_complex, **kwargs) + + +def pyrshow(pyr_coeffs, vrange='indep1', zoom=1, show_residuals=True, + cmap=None, plot_complex='rectangular', batch_idx=0, channel_idx=0, + **kwargs): r"""Display steerable pyramid coefficients in orderly fashion. This function uses ``imshow`` to show the coefficients of the steeable @@ -469,31 +408,20 @@ def pyrshow( if np.iscomplex(im).any(): is_complex = True # this removes only the first (batch) dimension - im = im[batch_idx : batch_idx + 1].squeeze(0) + im = im[batch_idx:batch_idx+1].squeeze(0) # this removes only the first (now channel) dimension - im = im[channel_idx : channel_idx + 1].squeeze(0) + im = im[channel_idx:channel_idx+1].squeeze(0) # because of how we've handled everything above, we know that im will # be (h,w). pyr_coeffvis[k] = im - return pt.pyrshow( - pyr_coeffvis, - is_complex=is_complex, - vrange=vrange, - zoom=zoom, - cmap=cmap, - plot_complex=plot_complex, - show_residuals=show_residuals, - **kwargs, - ) - - -def clean_up_axes( - ax, - ylim=None, - spines_to_remove=["top", "right", "bottom"], - axes_to_remove=["x"], -): + return pt.pyrshow(pyr_coeffvis, is_complex=is_complex, vrange=vrange, + zoom=zoom, cmap=cmap, plot_complex=plot_complex, + show_residuals=show_residuals, **kwargs) + + +def clean_up_axes(ax, ylim=None, spines_to_remove=['top', 'right', 'bottom'], + axes_to_remove=['x']): r"""Clean up an axis, as desired when making a stem plot of the representation Parameters @@ -517,18 +445,18 @@ def clean_up_axes( """ if spines_to_remove is None: - spines_to_remove = ["top", "right", "bottom"] + spines_to_remove = ['top', 'right', 'bottom'] if axes_to_remove is None: - axes_to_remove = ["x"] + axes_to_remove = ['x'] if ylim is not None: if ylim: ax.set_ylim(ylim) else: ax.set_ylim((0, ax.get_ylim()[1])) - if "x" in axes_to_remove: + if 'x' in axes_to_remove: ax.xaxis.set_visible(False) - if "y" in axes_to_remove: + if 'y' in axes_to_remove: ax.yaxis.set_visible(False) for s in spines_to_remove: ax.spines[s].set_visible(False) @@ -563,7 +491,7 @@ def update_stem(stem_container, ydata): """ stem_container.markerline.set_ydata(ydata) segments = stem_container.stemlines.get_segments().copy() - for s, y in zip(segments, ydata, strict=False): + for s, y in zip(segments, ydata): try: s[1, 1] = y except IndexError: @@ -589,7 +517,6 @@ def rescale_ylim(axes, data): values) """ data = data.cpu() - def find_ymax(data): try: return np.abs(data).max() @@ -597,7 +524,6 @@ def find_ymax(data): # then we need to call to_numpy on it because it needs to be # detached and 
converted to an array return np.abs(to_numpy(data)).max() - try: y_max = find_ymax(data) except TypeError: @@ -607,7 +533,7 @@ def find_ymax(data): ax.set_ylim((-y_max, y_max)) -def clean_stem_plot(data, ax=None, title="", ylim=None, xvals=None, **kwargs): +def clean_stem_plot(data, ax=None, title='', ylim=None, xvals=None, **kwargs): r"""convenience wrapper for plotting stem plots This plots the data, baseline, cleans up the axis, and sets the @@ -691,15 +617,14 @@ def clean_stem_plot(data, ax=None, title="", ylim=None, xvals=None, **kwargs): if ax is None: ax = plt.gca() if xvals is not None: - basefmt = " " - ax.hlines( - len(xvals[0]) * [0], xvals[0], xvals[1], colors="C3", zorder=10 - ) + basefmt = ' ' + ax.hlines(len(xvals[0])*[0], xvals[0], xvals[1], colors='C3', + zorder=10) else: # this is the default basefmt value basefmt = None ax.stem(data, basefmt=basefmt, **kwargs) - ax = clean_up_axes(ax, ylim, ["top", "right", "bottom"]) + ax = clean_up_axes(ax, ylim, ['top', 'right', 'bottom']) if title is not None: ax.set_title(title) return ax @@ -727,7 +652,7 @@ def _get_artists_from_axes(axes, data): use, keys are the corresponding keys for data """ - if not hasattr(axes, "__iter__"): + if not hasattr(axes, '__iter__'): # then we only have one axis, so we may be able to update more than one # data element. if len(axes.containers) > 0: @@ -747,25 +672,17 @@ def _get_artists_from_axes(axes, data): artists = {ax.get_label(): ax for ax in artists} else: if data_check == 1 and data.shape[1] != len(artists): - raise Exception( - f"data has {data.shape[1]} things to plot, but " - f"your axis contains {len(artists)} plotting artists, " - "so unsure how to continue! Pass data as a dictionary" - " with keys corresponding to the labels of the artists" - " to update to resolve this." - ) - elif ( - data_check == 2 - and data.ndim > 2 - and data.shape[-3] != len(artists) - ): - raise Exception( - f"data has {data.shape[-3]} things to plot, but " - f"your axis contains {len(artists)} plotting artists, " - "so unsure how to continue! Pass data as a dictionary" - " with keys corresponding to the labels of the artists" - " to update to resolve this." - ) + raise Exception(f"data has {data.shape[1]} things to plot, but " + f"your axis contains {len(artists)} plotting artists, " + "so unsure how to continue! Pass data as a dictionary" + " with keys corresponding to the labels of the artists" + " to update to resolve this.") + elif data_check == 2 and data.ndim > 2 and data.shape[-3] != len(artists): + raise Exception(f"data has {data.shape[-3]} things to plot, but " + f"your axis contains {len(artists)} plotting artists, " + "so unsure how to continue! Pass data as a dictionary" + " with keys corresponding to the labels of the artists" + " to update to resolve this.") else: # then we have multiple axes, so we are only updating one data element # per plot @@ -786,31 +703,19 @@ def _get_artists_from_axes(axes, data): data_check = 2 if isinstance(data, dict): if len(data.keys()) != len(artists): - raise Exception( - f"data has {len(data.keys())} things to plot, but " - f"you passed {len(axes)} axes , so unsure how " - "to continue!" 
- ) - artists = { - k: a for k, a in zip(data.keys(), artists, strict=False) - } + raise Exception(f"data has {len(data.keys())} things to plot, but " + f"you passed {len(axes)} axes , so unsure how " + "to continue!") + artists = {k: a for k, a in zip(data.keys(), artists)} else: if data_check == 1 and data.shape[1] != len(artists): - raise Exception( - f"data has {data.shape[1]} things to plot, but " - f"you passed {len(axes)} axes , so unsure how " - "to continue!" - ) - if ( - data_check == 2 - and data.ndim > 2 - and data.shape[-3] != len(artists) - ): - raise Exception( - f"data has {data.shape[-3]} things to plot, but " - f"you passed {len(axes)} axes , so unsure how " - "to continue!" - ) + raise Exception(f"data has {data.shape[1]} things to plot, but " + f"you passed {len(axes)} axes , so unsure how " + "to continue!") + if data_check == 2 and data.ndim > 2 and data.shape[-3] != len(artists): + raise Exception(f"data has {data.shape[-3]} things to plot, but " + f"you passed {len(axes)} axes , so unsure how " + "to continue!") if not isinstance(artists, dict): artists = {f"{i:02d}": a for i, a in enumerate(artists)} return artists @@ -882,18 +787,14 @@ def update_plot(axes, data, model=None, batch_idx=0): if isinstance(data, dict): for v in data.values(): if v.ndim not in [3, 4]: - raise ValueError( - "update_plot expects 3 or 4 dimensional data" - "; unexpected behavior will result otherwise!" - f" Got data of shape {v.shape}" - ) + raise ValueError("update_plot expects 3 or 4 dimensional data" + "; unexpected behavior will result otherwise!" + f" Got data of shape {v.shape}") else: if data.ndim not in [3, 4]: - raise ValueError( - "update_plot expects 3 or 4 dimensional data" - "; unexpected behavior will result otherwise!" - f" Got data of shape {data.shape}" - ) + raise ValueError("update_plot expects 3 or 4 dimensional data" + "; unexpected behavior will result otherwise!" + f" Got data of shape {data.shape}") try: artists = model.update_plot(axes=axes, batch_idx=batch_idx, data=data) except AttributeError: @@ -907,24 +808,19 @@ def update_plot(axes, data, model=None, batch_idx=0): # instead, as suggested # https://stackoverflow.com/questions/43629270/how-to-get-single-value-from-dict-with-single-entry try: - if ( - next(iter(ax_artists.values())).get_array().data.ndim - > 1 - ): + if next(iter(ax_artists.values())).get_array().data.ndim > 1: # then this is an RGBA image - data_dict = {"00": data} + data_dict = {'00': data} except Exception as e: - raise Exception( - "Thought this was an RGB(A) image based on the number of " - "artists and data shape, but something is off! " - f"Original exception: {e}" - ) + raise Exception("Thought this was an RGB(A) image based on the number of " + "artists and data shape, but something is off! 
" + f"Original exception: {e}") else: for i, d in enumerate(data.unbind(1)): # need to keep the shape the same because of how we # check for shape below (unbinding removes a dimension, # so we add it back) - data_dict[f"{i:02d}"] = d.unsqueeze(1) + data_dict[f'{i:02d}'] = d.unsqueeze(1) data = data_dict for k, d in data.items(): try: @@ -965,16 +861,8 @@ def update_plot(axes, data, model=None, batch_idx=0): return artists -def plot_representation( - model=None, - data=None, - ax=None, - figsize=(5, 5), - ylim=False, - batch_idx=0, - title="", - as_rgb=False, -): +def plot_representation(model=None, data=None, ax=None, figsize=(5, 5), + ylim=False, batch_idx=0, title='', as_rgb=False): r"""Helper function for plotting model representation We are trying to plot ``data`` on ``ax``, using @@ -1045,15 +933,15 @@ def plot_representation( try: # no point in passing figsize, because we've already created # and are passing an axis or are passing the user-specified one - fig, axes = model.plot_representation( - ylim=ylim, ax=ax, title=title, batch_idx=batch_idx, data=data - ) + fig, axes = model.plot_representation(ylim=ylim, ax=ax, title=title, + batch_idx=batch_idx, + data=data) except AttributeError: if data is None: data = model.representation if not isinstance(data, dict): if title is None: - title = "Representation" + title = 'Representation' data_dict = {} if not as_rgb: # then we peel apart the channels @@ -1061,22 +949,20 @@ def plot_representation( # need to keep the shape the same because of how we # check for shape below (unbinding removes a dimension, # so we add it back) - data_dict[title + "_%02d" % i] = d.unsqueeze(1) + data_dict[title+'_%02d' % i] = d.unsqueeze(1) else: data_dict[title] = data data = data_dict else: warnings.warn("data has keys, so we're ignoring title!") # want to make sure the axis we're taking over is basically invisible. 
- ax = clean_up_axes( - ax, False, ["top", "right", "bottom", "left"], ["x", "y"] - ) + ax = clean_up_axes(ax, False, + ['top', 'right', 'bottom', 'left'], ['x', 'y']) axes = [] if len(list(data.values())[0].shape) == 3: # then this is 'vector-like' - gs = ax.get_subplotspec().subgridspec( - min(4, len(data)), int(np.ceil(len(data) / 4)) - ) + gs = ax.get_subplotspec().subgridspec(min(4, len(data)), + int(np.ceil(len(data) / 4))) for i, (k, v) in enumerate(data.items()): ax = fig.add_subplot(gs[i % 4, i // 4]) # only plot the specified batch, but plot each channel @@ -1088,31 +974,23 @@ def plot_representation( axes.append(ax) elif len(list(data.values())[0].shape) == 4: # then this is 'image-like' - gs = ax.get_subplotspec().subgridspec( - int(np.ceil(len(data) / 4)), min(4, len(data)) - ) + gs = ax.get_subplotspec().subgridspec(int(np.ceil(len(data) / 4)), + min(4, len(data))) for i, (k, v) in enumerate(data.items()): ax = fig.add_subplot(gs[i // 4, i % 4]) - ax = clean_up_axes( - ax, False, ["top", "right", "bottom", "left"], ["x", "y"] - ) + ax = clean_up_axes(ax, + False, ['top', 'right', 'bottom', 'left'], + ['x', 'y']) # only plot the specified batch - imshow( - v, - batch_idx=batch_idx, - title=k, - ax=ax, - vrange="indep0", - as_rgb=as_rgb, - ) + imshow(v, batch_idx=batch_idx, title=k, ax=ax, + vrange='indep0', as_rgb=as_rgb) axes.append(ax) # because we're plotting image data, don't want to change # ylim at all ylim = False else: - raise Exception( - "Don't know what to do with data of shape" f" {data.shape}" - ) + raise Exception("Don't know what to do with data of shape" + f" {data.shape}") if ylim is None: if isinstance(data, dict): data = torch.cat(list(data.values()), dim=2) diff --git a/src/plenoptic/tools/external.py b/src/plenoptic/tools/external.py index c6ddefba..310f684d 100644 --- a/src/plenoptic/tools/external.py +++ b/src/plenoptic/tools/external.py @@ -10,19 +10,13 @@ import numpy as np import pyrtools as pt import scipy.io as sio - from ..data import fetch_data -def plot_MAD_results( - original_image, - noise_levels=None, - results_dir=None, - ssim_images_dir=None, - zoom=3, - vrange="indep1", - **kwargs, -): +def plot_MAD_results(original_image, noise_levels=None, + results_dir=None, + ssim_images_dir=None, + zoom=3, vrange='indep1', **kwargs): r"""plot original MAD results, provided by Zhou Wang Plot the results of original MAD Competition, as provided in .mat @@ -77,9 +71,9 @@ def plot_MAD_results( """ if results_dir is None: - results_dir = str(fetch_data("MAD_results.tar.gz")) + results_dir = str(fetch_data('MAD_results.tar.gz')) if ssim_images_dir is None: - ssim_images_dir = str(fetch_data("ssim_images.tar.gz")) + ssim_images_dir = str(fetch_data('ssim_images.tar.gz')) img_path = op.join(op.expanduser(ssim_images_dir), f"{original_image}.tif") orig_img = imageio.imread(img_path) blanks = np.ones((*orig_img.shape, 4)) @@ -87,107 +81,63 @@ def plot_MAD_results( noise_levels = [2**i for i in range(1, 11)] results = {} images = np.dstack([orig_img, blanks]) - titles = ["Original image"] + 4 * [None] - super_titles = 5 * [None] - keys = [ - "im_init", - "im_fixmse_maxssim", - "im_fixmse_minssim", - "im_fixssim_minmse", - "im_fixssim_maxmse", - ] + titles = ['Original image'] + 4*[None] + super_titles = 5*[None] + keys = ['im_init', 'im_fixmse_maxssim', 'im_fixmse_minssim', 'im_fixssim_minmse', + 'im_fixssim_maxmse'] for l in noise_levels: - mat = sio.loadmat( - op.join( - op.expanduser(results_dir), - f"{original_image}_L{l}_results.mat", - ), - squeeze_me=True, 
- ) + mat = sio.loadmat(op.join(op.expanduser(results_dir), + f"{original_image}_L{l}_results.mat"), squeeze_me=True) # remove these metadata keys - [mat.pop(k) for k in ["__header__", "__version__", "__globals__"]] - key_titles = [ - f"Noise level: {l}", - f"Best SSIM: {mat['maxssim']:.05f}", - f"Worst SSIM: {mat['minssim']:.05f}", - f"Best MSE: {mat['minmse']:.05f}", - f"Worst MSE: {mat['maxmse']:.05f}", - ] - key_super_titles = [ - None, - f"Fix MSE: {mat['FIX_MSE']:.0f}", - None, - f"Fix SSIM: {mat['FIX_SSIM']:.05f}", - None, - ] - for k, t, s in zip(keys, key_titles, key_super_titles, strict=False): + [mat.pop(k) for k in ['__header__', '__version__', '__globals__']] + key_titles = [f'Noise level: {l}', f"Best SSIM: {mat['maxssim']:.05f}", + f"Worst SSIM: {mat['minssim']:.05f}", + f"Best MSE: {mat['minmse']:.05f}", + f"Worst MSE: {mat['maxmse']:.05f}"] + key_super_titles = [None, f"Fix MSE: {mat['FIX_MSE']:.0f}", None, + f"Fix SSIM: {mat['FIX_SSIM']:.05f}", None] + for k, t, s in zip(keys, key_titles, key_super_titles): images = np.dstack([images, mat.pop(k)]) titles.append(t) super_titles.append(s) # this then just contains the loss information - mat.update({"noise_level": l, "original_image": original_image}) - results[f"L{l}"] = mat + mat.update({'noise_level': l, 'original_image': original_image}) + results[f'L{l}'] = mat images = images.transpose((2, 0, 1)) - if vrange.startswith("row"): + if vrange.startswith('row'): vrange_list = [] - for i in range(len(images) // 5): - vr, cmap = pt.tools.display.colormap_range( - images[5 * i : 5 * (i + 1)], vrange.replace("row", "auto") - ) + for i in range(len(images)//5): + vr, cmap = pt.tools.display.colormap_range(images[5*i:5*(i+1)], + vrange.replace('row', 'auto')) vrange_list.extend(vr) else: vrange_list, cmap = pt.tools.display.colormap_range(images, vrange) # this is a bit of hack to do the same thing imshow does, but with # slightly more space dedicated to the title - fig = pt.tools.display.make_figure( - len(images) // 5, - 5, - [zoom * i + 1 for i in images.shape[-2:]], - vert_pct=0.75, - ) - for img, ax, t, vr, s in zip( - images, fig.axes, titles, vrange_list, super_titles, strict=False - ): + fig = pt.tools.display.make_figure(len(images)//5, 5, [zoom*i+1 for i in images.shape[-2:]], + vert_pct=.75) + for img, ax, t, vr, s in zip(images, fig.axes, titles, vrange_list, super_titles): # these are the blanks if (img == 1).all(): continue - pt.imshow( - img, ax=ax, title=t, zoom=zoom, vrange=vr, cmap=cmap, **kwargs - ) + pt.imshow(img, ax=ax, title=t, zoom=zoom, vrange=vr, cmap=cmap, **kwargs) if s is not None: - font = { - k.replace("_", ""): v - for k, v in ax.title.get_font_properties().__dict__.items() - } + font = {k.replace('_', ''): v for k, v in + ax.title.get_font_properties().__dict__.items()} # these are the acceptable keys for the fontdict below - font = { - k: v - for k, v in font.items() - if k in ["family", "color", "weight", "size", "style"] - } + font = {k: v for k, v in font.items() if k in ['family', 'color', 'weight', 'size', + 'style']} # for some reason, this (with passing the transform) is # different (and looks better) than using ax.text. 
We also # slightly adjust the placement of the text to account for # different zoom levels (we also have 10 pixels between the # rows and columns, which correspond to a different) img_size = ax.bbox.size - fig.text( - 1 + (5 / img_size[0]), - (1 / 0.75), - s, - fontdict=font, - transform=ax.transAxes, - ha="center", - va="top", - ) + fig.text(1+(5/img_size[0]), (1/.75), s, fontdict=font, + transform=ax.transAxes, ha='center', va='top') # linewidth of 1.5 looks good with bbox of 192, 192 - linewidth = np.max([1.5 * np.mean(img_size / 192), 1]) - line = lines.Line2D( - 2 * [0 - ((5 + linewidth / 2) / img_size[0])], - [0, (1 / 0.75)], - transform=ax.transAxes, - figure=fig, - linewidth=linewidth, - ) + linewidth = np.max([1.5 * np.mean(img_size/192), 1]) + line = lines.Line2D(2*[0-((5+linewidth/2)/img_size[0])], [0, (1/.75)], + transform=ax.transAxes, figure=fig, linewidth=linewidth) fig.lines.append(line) return fig, results diff --git a/src/plenoptic/tools/optim.py b/src/plenoptic/tools/optim.py index 4dcf339e..439cc8c3 100644 --- a/src/plenoptic/tools/optim.py +++ b/src/plenoptic/tools/optim.py @@ -1,12 +1,12 @@ """Tools related to optimization such as more objective functions. """ - -import numpy as np import torch from torch import Tensor +from typing import Optional, Tuple +import numpy as np -def set_seed(seed: int | None = None) -> None: +def set_seed(seed: Optional[int] = None) -> None: """Set the seed. We call both ``torch.manual_seed()`` and ``np.random.seed()``. @@ -99,16 +99,11 @@ def relative_MSE(synth_rep: Tensor, ref_rep: Tensor, **kwargs) -> Tensor: Ratio of the squared l2-norm of the difference between ``ref_rep`` and ``synth_rep`` to the squared l2-norm of ``ref_rep`` """ - return ( - torch.linalg.vector_norm(ref_rep - synth_rep, ord=2) ** 2 - / torch.linalg.vector_norm(ref_rep, ord=2) ** 2 - ) + return torch.linalg.vector_norm(ref_rep - synth_rep, ord=2) ** 2 / torch.linalg.vector_norm(ref_rep, ord=2) ** 2 def penalize_range( - synth_img: Tensor, - allowed_range: tuple[float, float] = (0.0, 1.0), - **kwargs, + synth_img: Tensor, allowed_range: Tuple[float, float] = (0.0, 1.0), **kwargs ) -> Tensor: r"""penalize values outside of allowed_range diff --git a/src/plenoptic/tools/signal.py b/src/plenoptic/tools/signal.py index 90f4e939..33841d7c 100644 --- a/src/plenoptic/tools/signal.py +++ b/src/plenoptic/tools/signal.py @@ -1,11 +1,14 @@ +from typing import List, Optional, Tuple, Union + import numpy as np import torch -from pyrtools.pyramids.steer import steer_to_harmonics_mtx from torch import Tensor +import torch.fft as fft +from pyrtools.pyramids.steer import steer_to_harmonics_mtx def minimum( - x: Tensor, dim: list[int] | None = None, keepdim: bool = False + x: Tensor, dim: Optional[List[int]] = None, keepdim: bool = False ) -> Tensor: r"""Compute minimum in torch over any axis or combination of axes in tensor. @@ -13,14 +16,14 @@ def minimum( ---------- x Input tensor. - dim + dim Dimensions over which you would like to compute the minimum. - keepdim + keepdim Keep original dimensions of tensor when returning result. Returns ------- - min_x + min_x Minimum value of x. """ if dim is None: @@ -33,7 +36,7 @@ def minimum( def maximum( - x: Tensor, dim: list[int] | None = None, keepdim: bool = False + x: Tensor, dim: Optional[List[int]] = None, keepdim: bool = False ) -> Tensor: r"""Compute maximum in torch over any dim or combination of axes in tensor. 
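`relative_MSE` above normalizes the squared error by the energy of the reference representation, so the value is unchanged when both representations are rescaled together. A small sketch of that property; the tensor names are made up for illustration:

import torch

ref = torch.randn(1, 10)
synth = ref + 0.1 * torch.randn(1, 10)

# ratio of squared l2-norms, as in relative_MSE above
rel_mse = (torch.linalg.vector_norm(ref - synth, ord=2) ** 2
           / torch.linalg.vector_norm(ref, ord=2) ** 2)
# scaling both inputs by the same constant leaves the ratio unchanged
scaled = (torch.linalg.vector_norm(5 * ref - 5 * synth, ord=2) ** 2
          / torch.linalg.vector_norm(5 * ref, ord=2) ** 2)
assert torch.allclose(rel_mse, scaled)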
@@ -70,8 +73,8 @@ def rescale(x: Tensor, a: float = 0.0, b: float = 1.0) -> Tensor: def raised_cosine( - width: float = 1, position: float = 0, values: tuple[float, float] = (0, 1) -) -> tuple[np.ndarray, np.ndarray]: + width: float = 1, position: float = 0, values: Tuple[float, float] = (0, 1) +) -> Tuple[np.ndarray, np.ndarray]: """Return a lookup table containing a "raised cosine" soft threshold function. Y = VALUES(1) @@ -113,7 +116,7 @@ def raised_cosine( def interpolate1d( - x_new: Tensor, Y: Tensor | np.ndarray, X: Tensor | np.ndarray + x_new: Tensor, Y: Union[Tensor, np.ndarray], X: Union[Tensor, np.ndarray] ) -> Tensor: r"""One-dimensional linear interpolation. @@ -142,7 +145,7 @@ def interpolate1d( return np.reshape(out, x_new.shape) -def rectangular_to_polar(x: Tensor) -> tuple[Tensor, Tensor]: +def rectangular_to_polar(x: Tensor) -> Tuple[Tensor, Tensor]: r"""Rectangular to polar coordinate transform Parameters @@ -187,9 +190,9 @@ def polar_to_rectangular(amplitude: Tensor, phase: Tensor) -> Tensor: def steer( basis: Tensor, - angle: np.ndarray | Tensor | float, - harmonics: list[int] | None = None, - steermtx: Tensor | np.ndarray | None = None, + angle: Union[np.ndarray, Tensor, float], + harmonics: Optional[List[int]] = None, + steermtx: Optional[Union[Tensor, np.ndarray]] = None, return_weights: bool = False, even_phase: bool = True, ): @@ -283,9 +286,9 @@ def steer( def make_disk( - img_size: int | tuple[int, int] | torch.Size, - outer_radius: float | None = None, - inner_radius: float | None = None, + img_size: Union[int, Tuple[int, int], torch.Size], + outer_radius: Optional[float] = None, + inner_radius: Optional[float] = None, ) -> Tensor: r"""Create a circular mask with softened edges to an image. @@ -324,6 +327,7 @@ def make_disk( for i in range(img_size[0]): # height for j in range(img_size[1]): # width + r = np.sqrt((i - i0) ** 2 + (j - j0) ** 2) if r > outer_radius: @@ -331,15 +335,13 @@ def make_disk( elif r < inner_radius: mask[i][j] = 1 else: - radial_decay = (r - inner_radius) / ( - outer_radius - inner_radius - ) + radial_decay = (r - inner_radius) / (outer_radius - inner_radius) mask[i][j] = (1 + np.cos(np.pi * radial_decay)) / 2 return mask -def add_noise(img: Tensor, noise_mse: float | list[float]) -> Tensor: +def add_noise(img: Tensor, noise_mse: Union[float, List[float]]) -> Tensor: """Add normally distributed noise to an image This adds normally-distributed noise to an image so that the resulting @@ -366,9 +368,7 @@ def add_noise(img: Tensor, noise_mse: float | list[float]) -> Tensor: ).unsqueeze(0) noise_mse = noise_mse.view(noise_mse.nelement(), 1, 1, 1) noise = 200 * torch.randn( - max(noise_mse.shape[0], img.shape[0]), - *img.shape[1:], - device=img.device, + max(noise_mse.shape[0], img.shape[0]), *img.shape[1:], device=img.device ) noise = noise - noise.mean() noise = noise * torch.sqrt( @@ -377,7 +377,7 @@ def add_noise(img: Tensor, noise_mse: float | list[float]) -> Tensor: return img + noise -def modulate_phase(x: Tensor, phase_factor: float = 2.0) -> Tensor: +def modulate_phase(x: Tensor, phase_factor: float = 2.) -> Tensor: """Modulate the phase of a complex signal. 
Doubling the phase of a complex signal allows you to, for example, take the @@ -471,11 +471,8 @@ def center_crop(x: Tensor, output_size: int) -> Tensor: """ h, w = x.shape[-2:] - return x[ - ..., - (h // 2 - output_size // 2) : (h // 2 + (output_size + 1) // 2), - (w // 2 - output_size // 2) : (w // 2 + (output_size + 1) // 2), - ] + return x[..., (h//2 - output_size//2) : (h//2 + (output_size+1)//2), + (w//2 - output_size//2) : (w//2 + (output_size+1)//2)] def expand(x: Tensor, factor: float) -> Tensor: @@ -510,13 +507,9 @@ def expand(x: Tensor, factor: float) -> Tensor: mx = factor * im_x my = factor * im_y if int(mx) != mx: - raise ValueError( - f"factor * x.shape[-1] must be an integer but got {mx} instead!" - ) + raise ValueError(f"factor * x.shape[-1] must be an integer but got {mx} instead!") if int(my) != my: - raise ValueError( - f"factor * x.shape[-2] must be an integer but got {my} instead!" - ) + raise ValueError(f"factor * x.shape[-2] must be an integer but got {my} instead!") mx = int(mx) my = int(my) @@ -595,20 +588,14 @@ def shrink(x: Tensor, factor: int) -> Tensor: my = im_y / factor if int(mx) != mx: - raise ValueError( - f"x.shape[-1]/factor must be an integer but got {mx} instead!" - ) + raise ValueError(f"x.shape[-1]/factor must be an integer but got {mx} instead!") if int(my) != my: - raise ValueError( - f"x.shape[-2]/factor must be an integer but got {my} instead!" - ) + raise ValueError(f"x.shape[-2]/factor must be an integer but got {my} instead!") mx = int(mx) my = int(my) - fourier = ( - 1 / factor**2 * torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1)) - ) + fourier = 1/factor**2 * torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1)) fourier_small = torch.zeros( *x.shape[:-2], my, @@ -630,18 +617,9 @@ def shrink(x: Tensor, factor: int) -> Tensor: # This line is equivalent to fourier_small[..., 1:, 1:] = fourier[..., y1:y2, x1:x2] - fourier_small[..., 0, 1:] = ( - fourier[..., y1 - 1, x1:x2] + fourier[..., y2, x1:x2] - ) / 2 - fourier_small[..., 1:, 0] = ( - fourier[..., y1:y2, x1 - 1] + fourier[..., y1:y2, x2] - ) / 2 - fourier_small[..., 0, 0] = ( - fourier[..., y1 - 1, x1 - 1] - + fourier[..., y1 - 1, x2] - + fourier[..., y2, x1 - 1] - + fourier[..., y2, x2] - ) / 4 + fourier_small[..., 0, 1:] = (fourier[..., y1-1, x1:x2] + fourier[..., y2, x1:x2])/ 2 + fourier_small[..., 1:, 0] = (fourier[..., y1:y2, x1-1] + fourier[..., y1:y2, x2])/ 2 + fourier_small[..., 0, 0] = (fourier[..., y1-1, x1-1] + fourier[..., y1-1, x2] + fourier[..., y2, x1-1] + fourier[..., y2, x2]) / 4 fourier_small = torch.fft.ifftshift(fourier_small, dim=(-2, -1)) im_small = torch.fft.ifft2(fourier_small) diff --git a/src/plenoptic/tools/stats.py b/src/plenoptic/tools/stats.py index f862ea0d..ecabf1c8 100644 --- a/src/plenoptic/tools/stats.py +++ b/src/plenoptic/tools/stats.py @@ -1,11 +1,13 @@ +from typing import List, Optional, Union + import torch from torch import Tensor def variance( x: Tensor, - mean: float | Tensor | None = None, - dim: int | list[int] | None = None, + mean: Optional[Union[float, Tensor]] = None, + dim: Optional[Union[int, List[int]]] = None, keepdim: bool = False, ) -> Tensor: r"""Calculate sample variance. 
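The slicing in `center_crop` above keeps the central window of the last two dimensions, with the extra sample going to the bottom/right half when input and output parities differ. A quick check of that convention; the function and variable names here are illustrative only, not part of plenoptic:

import torch


def center_crop_sketch(x: torch.Tensor, output_size: int) -> torch.Tensor:
    # same indexing as center_crop above: floor half before, ceil half after
    h, w = x.shape[-2:]
    return x[..., h // 2 - output_size // 2 : h // 2 + (output_size + 1) // 2,
             w // 2 - output_size // 2 : w // 2 + (output_size + 1) // 2]


img = torch.arange(36.0).reshape(1, 1, 6, 6)
assert center_crop_sketch(img, 2).shape[-2:] == (2, 2)
# odd crop of an even image: rows/cols 2, 3, 4 survive (extra one at bottom/right)
assert center_crop_sketch(img, 3).shape[-2:] == (3, 3)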
@@ -39,9 +41,9 @@ def variance( def skew( x: Tensor, - mean: float | Tensor | None = None, - var: float | Tensor | None = None, - dim: int | list[int] | None = None, + mean: Optional[Union[float, Tensor]] = None, + var: Optional[Union[float, Tensor]] = None, + dim: Optional[Union[int, List[int]]] = None, keepdim: bool = False, ) -> Tensor: r"""Sample estimate of `x` *asymmetry* about its mean @@ -70,16 +72,14 @@ def skew( mean = torch.mean(x, dim=dim, keepdim=True) if var is None: var = variance(x, mean=mean, dim=dim, keepdim=keepdim) - return torch.mean((x - mean).pow(3), dim=dim, keepdim=keepdim) / var.pow( - 1.5 - ) + return torch.mean((x - mean).pow(3), dim=dim, keepdim=keepdim) / var.pow(1.5) def kurtosis( x: Tensor, - mean: float | Tensor | None = None, - var: float | Tensor | None = None, - dim: int | list[int] | None = None, + mean: Optional[Union[float, Tensor]] = None, + var: Optional[Union[float, Tensor]] = None, + dim: Optional[Union[int, List[int]]] = None, keepdim: bool = False, ) -> Tensor: r"""sample estimate of `x` *tailedness* (presence of outliers) @@ -114,6 +114,4 @@ def kurtosis( mean = torch.mean(x, dim=dim, keepdim=True) if var is None: var = variance(x, mean=mean, dim=dim, keepdim=keepdim) - return torch.mean( - torch.abs(x - mean).pow(4), dim=dim, keepdim=keepdim - ) / var.pow(2) + return torch.mean(torch.abs(x - mean).pow(4), dim=dim, keepdim=keepdim) / var.pow(2) diff --git a/src/plenoptic/tools/straightness.py b/src/plenoptic/tools/straightness.py index 4ee0301b..e90e651a 100644 --- a/src/plenoptic/tools/straightness.py +++ b/src/plenoptic/tools/straightness.py @@ -1,6 +1,6 @@ import torch from torch import Tensor - +from typing import Tuple from .validate import validate_input @@ -26,9 +26,7 @@ def make_straight_line(start: Tensor, stop: Tensor, n_steps: int) -> Tensor: validate_input(start, no_batch=True) validate_input(stop, no_batch=True) if start.shape != stop.shape: - raise ValueError( - f"start and stop must be same shape, but got {start.shape} and {stop.shape}!" - ) + raise ValueError(f"start and stop must be same shape, but got {start.shape} and {stop.shape}!") if n_steps <= 0: raise ValueError(f"n_steps must be positive, but got {n_steps}") shape = start.shape[1:] @@ -36,17 +34,15 @@ def make_straight_line(start: Tensor, stop: Tensor, n_steps: int) -> Tensor: device = start.device start = start.reshape(1, -1) stop = stop.reshape(1, -1) - tt = torch.linspace(0, 1, steps=n_steps + 1, device=device).view( - n_steps + 1, 1 - ) + tt = torch.linspace(0, 1, steps=n_steps+1, device=device + ).view(n_steps+1, 1) straight = (1 - tt) * start + tt * stop - return straight.reshape((n_steps + 1, *shape)) + return straight.reshape((n_steps+1, *shape)) -def sample_brownian_bridge( - start: Tensor, stop: Tensor, n_steps: int, max_norm: float = 1 -) -> Tensor: +def sample_brownian_bridge(start: Tensor, stop: Tensor, + n_steps: int, max_norm: float = 1) -> Tensor: """Sample a brownian bridge between `start` and `stop` made up of `n_steps` Parameters @@ -74,9 +70,7 @@ def sample_brownian_bridge( validate_input(start, no_batch=True) validate_input(stop, no_batch=True) if start.shape != stop.shape: - raise ValueError( - f"start and stop must be same shape, but got {start.shape} and {stop.shape}!" 
- ) + raise ValueError(f"start and stop must be same shape, but got {start.shape} and {stop.shape}!") if n_steps <= 0: raise ValueError(f"n_steps must be positive, but got {n_steps}") if max_norm < 0: @@ -87,22 +81,21 @@ def sample_brownian_bridge( start = start.reshape(1, -1) stop = stop.reshape(1, -1) D = start.shape[1] - dt = torch.as_tensor(1 / n_steps) - tt = torch.linspace(0, 1, steps=n_steps + 1, device=device)[:, None] + dt = torch.as_tensor(1/n_steps) + tt = torch.linspace(0, 1, steps=n_steps+1, device=device)[:, None] - sigma = torch.sqrt(dt / D) * 2.0 * max_norm - dW = sigma * torch.randn(n_steps + 1, D, device=device) + sigma = torch.sqrt(dt / D) * 2. * max_norm + dW = sigma * torch.randn(n_steps+1, D, device=device) dW[0] = start.flatten() W = torch.cumsum(dW, dim=0) bridge = W - tt * (W[-1:] - stop) - return bridge.reshape((n_steps + 1, *shape)) + return bridge.reshape((n_steps+1, *shape)) -def deviation_from_line( - sequence: Tensor, normalize: bool = True -) -> tuple[Tensor, Tensor]: +def deviation_from_line(sequence: Tensor, + normalize: bool = True) -> Tuple[Tensor, Tensor]: """Compute the deviation of `sequence` to the straight line between its endpoints. Project each point of the path `sequence` onto the line defined by @@ -133,15 +126,14 @@ def deviation_from_line( y0 = y[0].view(1, D) y1 = y[-1].view(1, D) - line = y1 - y0 + line = (y1 - y0) line_length = torch.linalg.vector_norm(line, ord=2) line = line / line_length y_centered = y - y0 dist_along_line = y_centered @ line[0] projection = dist_along_line.view(T, 1) * line - dist_from_line = torch.linalg.vector_norm( - y_centered - projection, dim=1, ord=2 - ) + dist_from_line = torch.linalg.vector_norm(y_centered - projection, dim=1, + ord=2) if normalize: dist_along_line /= line_length @@ -170,9 +162,9 @@ def translation_sequence(image: Tensor, n_steps: int = 10) -> Tensor: validate_input(image, no_batch=True) if n_steps <= 0: raise ValueError(f"n_steps must be positive, but got {n_steps}") - sequence = torch.empty(n_steps + 1, *image.shape[1:]).to(image.device) + sequence = torch.empty(n_steps+1, *image.shape[1:]).to(image.device) - for shift in range(n_steps + 1): + for shift in range(n_steps+1): sequence[shift] = torch.roll(image, shift, [-1]) return sequence diff --git a/src/plenoptic/tools/validate.py b/src/plenoptic/tools/validate.py index c1a5028d..c062c70f 100644 --- a/src/plenoptic/tools/validate.py +++ b/src/plenoptic/tools/validate.py @@ -1,16 +1,16 @@ """Functions to validate synthesis inputs. """ -import itertools -import warnings -from collections.abc import Callable - import torch +import warnings +import itertools +from typing import Tuple, Optional, Callable, Union from torch import Tensor +import warnings def validate_input( input_tensor: Tensor, no_batch: bool = False, - allowed_range: tuple[float, float] | None = None, + allowed_range: Optional[Tuple[float, float]] = None, ): """Determine whether input_tensor tensor can be used for synthesis. 
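`sample_brownian_bridge` above builds the bridge by accumulating scaled Gaussian increments and then subtracting a linear correction, which pins both endpoints exactly. A flat-coordinate sketch of that construction; the sizes and tensors are invented for illustration:

import torch

n_steps, D = 10, 4
start, stop = torch.zeros(D), torch.ones(D)
max_norm = 1.0

dt = torch.as_tensor(1 / n_steps)
tt = torch.linspace(0, 1, steps=n_steps + 1)[:, None]
sigma = torch.sqrt(dt / D) * 2.0 * max_norm
dW = sigma * torch.randn(n_steps + 1, D)
dW[0] = start  # the first "increment" is the starting point itself
W = torch.cumsum(dW, dim=0)
# subtracting tt * (W[-1] - stop) leaves W[0] untouched and forces W[-1] onto stop
bridge = W - tt * (W[-1:] - stop)

assert torch.allclose(bridge[0], start)
assert torch.allclose(bridge[-1], stop)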
@@ -39,17 +39,10 @@
 
     """
     # validate dtype
-    if input_tensor.dtype not in [
-        torch.float16,
-        torch.complex32,
-        torch.float32,
-        torch.complex64,
-        torch.float64,
-        torch.complex128,
-    ]:
-        raise TypeError(
-            f"Only float or complex dtypes are allowed but got type {input_tensor.dtype}"
-        )
+    if input_tensor.dtype not in [torch.float16, torch.complex32,
+                                  torch.float32, torch.complex64,
+                                  torch.float64, torch.complex128]:
+        raise TypeError(f"Only float or complex dtypes are allowed but got type {input_tensor.dtype}")
     if input_tensor.ndimension() != 4:
         if no_batch:
             n_batch = 1
@@ -64,29 +57,24 @@
     if no_batch and input_tensor.shape[0] != 1:
         # numpy raises ValueError when operands cannot be broadcast together,
         # so it seems reasonable here
         raise ValueError("input_tensor batch dimension must be 1.")
     if allowed_range is not None:
         if allowed_range[0] >= allowed_range[1]:
             raise ValueError(
                 "allowed_range[0] must be strictly less than"
                 f" allowed_range[1], but got {allowed_range}"
             )
-        if (
-            input_tensor.min() < allowed_range[0]
-            or input_tensor.max() > allowed_range[1]
-        ):
+        if input_tensor.min() < allowed_range[0] or input_tensor.max() > allowed_range[1]:
             raise ValueError(
                 f"input_tensor range must lie within {allowed_range}, but got"
                 f" {(input_tensor.min().item(), input_tensor.max().item())}"
             )
 
 
-def validate_model(
-    model: torch.nn.Module,
-    image_shape: tuple[int, int, int, int] | None = None,
-    image_dtype: torch.dtype = torch.float32,
-    device: str | torch.device = "cpu",
-):
+def validate_model(model: torch.nn.Module,
+                   image_shape: Optional[Tuple[int, int, int, int]] = None,
+                   image_dtype: torch.dtype = torch.float32,
+                   device: Union[str, torch.device] = 'cpu'):
     """Determine whether model can be used for synthesis.
 
     In particular, this function checks the following (with their associated
@@ -138,9 +126,8 @@
     """
     if image_shape is None:
         image_shape = (1, 1, 16, 16)
-    test_img = torch.rand(
-        image_shape, dtype=image_dtype, requires_grad=False, device=device
-    )
+    test_img = torch.rand(image_shape, dtype=image_dtype, requires_grad=False,
+                          device=device)
     try:
         if model(test_img).requires_grad:
             raise ValueError(
@@ -176,14 +163,12 @@
     elif image_dtype in [torch.float64, torch.complex128]:
         allowed_dtypes = [torch.float64, torch.complex128]
     else:
-        raise TypeError(
-            f"Only float or complex dtypes are allowed but got type {image_dtype}"
-        )
+        raise TypeError(f"Only float or complex dtypes are allowed but got type {image_dtype}")
     if model(test_img).dtype not in allowed_dtypes:
         raise TypeError("model changes precision of input, don't do that!")
     if model(test_img).ndimension() not in [3, 4]:
         raise ValueError(
             "When given a 4d input, model output must be three- or four-"
             f"dimensional but had {model(test_img).ndimension()} dimensions instead!"
         )
     if model(test_img).device != test_img.device:
@@ -196,11 +181,9 @@
     )
 
 
-def validate_coarse_to_fine(
-    model: torch.nn.Module,
-    image_shape: tuple[int, int, int, int] | None = None,
-    device: str | torch.device = "cpu",
-):
+def validate_coarse_to_fine(model: torch.nn.Module,
+                            image_shape: Optional[Tuple[int, int, int, int]] = None,
+                            device: Union[str, torch.device] = 'cpu'):
     """Determine whether a model can be used for coarse-to-fine synthesis.
In particular, this function checks the following (with associated errors):
@@ -225,9 +208,7 @@
         Which device to place the test image on.
 
     """
-    warnings.warn(
-        "Validating whether model can work with coarse-to-fine synthesis -- this can take a while!"
-    )
+    warnings.warn("Validating whether model can work with coarse-to-fine synthesis -- this can take a while!")
     msg = "and therefore we cannot do coarse-to-fine synthesis"
     if not hasattr(model, "scales"):
         raise AttributeError(f"model has no scales attribute {msg}")
@@ -240,7 +221,7 @@
     try:
         if model_output_shape == model(test_img, scales=sc).shape:
             raise ValueError(
                 "Output of model forward method doesn't change"
                 f" shape when scales keyword arg is set to {sc} {msg}"
             )
     except TypeError:
         raise TypeError(
         )
@@ -249,12 +230,10 @@
-def validate_metric(
-    metric: torch.nn.Module | Callable[[Tensor, Tensor], Tensor],
-    image_shape: tuple[int, int, int, int] | None = None,
-    image_dtype: torch.dtype = torch.float32,
-    device: str | torch.device = "cpu",
-):
+def validate_metric(metric: Union[torch.nn.Module, Callable[[Tensor, Tensor], Tensor]],
+                    image_shape: Optional[Tuple[int, int, int, int]] = None,
+                    image_dtype: torch.dtype = torch.float32,
+                    device: Union[str, torch.device] = 'cpu'):
     """Determines whether a metric can be used for MADCompetition synthesis.
 
     In particular, this function checks the following (with associated
@@ -291,9 +270,7 @@
     try:
         same_val = metric(test_img, test_img).item()
     except TypeError:
-        raise TypeError(
-            "metric should be callable and accept two 4d tensors as input"
-        )
+        raise TypeError("metric should be callable and accept two 4d tensors as input")
     # as of torch 2.0.0, this is a RuntimeError (a Tensor with X elements
     # cannot be converted to Scalar); previously it was a ValueError (only one
     # element tensors can be converted to Python scalars)
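`validate_metric` above only needs the metric to be callable on two 4d tensors and to report (near-)zero distance for identical inputs. A sketch of that smoke test with a toy MSE metric; `mse_metric` is a stand-in for illustration, not a plenoptic function:

import torch


def mse_metric(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # toy metric: mean squared error between two images
    return torch.pow(x - y, 2).mean()


test_img = torch.rand(1, 1, 16, 16)
try:
    same_val = mse_metric(test_img, test_img).item()
except TypeError:
    raise TypeError("metric should be callable and accept two 4d tensors as input")
# identical inputs should yield (near-)zero distance
assert abs(same_val) < 1e-5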