Skip to content

Commit

Permalink
Revert "updating some deprecated imports, isinstance for union of typ…
Browse files Browse the repository at this point in the history
…es, unsorted imports, f-strings, replaced single quote with double quotes and deleted trailing whitespace"

This reverts commit c1fd8bc.
  • Loading branch information
hmd101 committed Aug 8, 2024
1 parent 7786cad commit 79e93d3
Show file tree
Hide file tree
Showing 49 changed files with 1,636 additions and 2,749 deletions.
12 changes: 4 additions & 8 deletions examples/00_quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"import plenoptic as po\n",
"\n",
"import torch\n",
"import pyrtools as pt\n",
"import matplotlib.pyplot as plt\n",
"# so that relative sizes of axes created by po.imshow and others look right\n",
"plt.rcParams['figure.dpi'] = 72\n",
"\n",
Expand Down Expand Up @@ -84,10 +83,7 @@
],
"source": [
"# this is a convenience function for creating a simple Gaussian kernel\n",
"from plenoptic.simulate.canonical_computations.filters import (\n",
" circular_gaussian2d,\n",
")\n",
"\n",
"from plenoptic.simulate.canonical_computations.filters import circular_gaussian2d\n",
"\n",
"# Simple rectified Gaussian convolutional model\n",
"class SimpleModel(torch.nn.Module):\n",
Expand Down
14 changes: 6 additions & 8 deletions examples/02_Eigendistortions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,11 @@
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# so that relative sizes of axes created by po.imshow and others look right\n",
"plt.rcParams['figure.dpi'] = 72\n",
"import torch\n",
"from torch import nn\n",
"\n",
"from plenoptic.synthesize.eigendistortion import Eigendistortion\n",
"\n",
"from torch import nn\n",
"# this notebook uses torchvision, which is an optional dependency.\n",
"# if this fails, install torchvision in your plenoptic environment \n",
"# and restart the notebook kernel.\n",
Expand All @@ -62,6 +59,7 @@
" raise ModuleNotFoundError(\"optional dependency torchvision not found!\"\n",
" \" please install it in your plenoptic environment \"\n",
" \"and restart the notebook kernel\")\n",
"import os.path as op\n",
"import plenoptic as po"
]
},
Expand Down Expand Up @@ -824,7 +822,7 @@
}
],
"source": [
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=3)\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=3);\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, -1, as_rgb=True, zoom=3);"
]
},
Expand Down Expand Up @@ -1027,10 +1025,10 @@
}
],
"source": [
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=2, title=\"top eigendist\")\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, -1, as_rgb=True, zoom=2, title=\"bottom eigendist\")\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, 0, as_rgb=True, zoom=2, title=\"top eigendist\");\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resneta, -1, as_rgb=True, zoom=2, title=\"bottom eigendist\");\n",
"\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resnetb, 0, as_rgb=True, zoom=2, title=\"top eigendist\")\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resnetb, 0, as_rgb=True, zoom=2, title=\"top eigendist\");\n",
"po.synth.eigendistortion.display_eigendistortion(ed_resnetb, -1, as_rgb=True, zoom=2, title=\"bottom eigendist\");"
]
},
Expand Down
18 changes: 9 additions & 9 deletions examples/03_Steerable_Pyramid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"# this notebook uses torchvision, which is an optional dependency.\n",
"# if this fails, install torchvision in your plenoptic environment \n",
"# and restart the notebook kernel.\n",
Expand All @@ -31,19 +30,20 @@
" raise ModuleNotFoundError(\"optional dependency torchvision not found!\"\n",
" \" please install it in your plenoptic environment \"\n",
" \"and restart the notebook kernel\")\n",
"import matplotlib.pyplot as plt\n",
"import torch.nn.functional as F\n",
"import torchvision.transforms as transforms\n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import pyrtools as pt\n",
"import plenoptic as po\n",
"from plenoptic.simulate import SteerablePyramidFreq\n",
"from plenoptic.synthesize import Eigendistortion\n",
"from plenoptic.tools.data import to_numpy\n",
"\n",
"dtype = torch.float32\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"import os\n",
"from tqdm.auto import tqdm\n",
"\n",
"%load_ext autoreload\n",
"\n",
"%autoreload 2\n",
Expand Down Expand Up @@ -218,7 +218,7 @@
],
"source": [
"print(pyr_coeffs.keys())\n",
"po.pyrshow(pyr_coeffs, zoom=0.5, batch_idx=0)\n",
"po.pyrshow(pyr_coeffs, zoom=0.5, batch_idx=0);\n",
"po.pyrshow(pyr_coeffs, zoom=0.5, batch_idx=1);"
]
},
Expand Down Expand Up @@ -267,7 +267,7 @@
"#get the 3rd scale\n",
"print(pyr.scales)\n",
"pyr_coeffs_scale0 = pyr(im_batch, scales=[2])\n",
"po.pyrshow(pyr_coeffs_scale0, zoom=2, batch_idx=0)\n",
"po.pyrshow(pyr_coeffs_scale0, zoom=2, batch_idx=0);\n",
"po.pyrshow(pyr_coeffs_scale0, zoom=2, batch_idx=1);"
]
},
Expand Down Expand Up @@ -323,7 +323,7 @@
],
"source": [
"# the same visualization machinery works for complex pyramids; what is shown is the magnitude of the coefficients\n",
"po.pyrshow(pyr_coeffs_complex, zoom=0.5, batch_idx=0)\n",
"po.pyrshow(pyr_coeffs_complex, zoom=0.5, batch_idx=0);\n",
"po.pyrshow(pyr_coeffs_complex, zoom=0.5, batch_idx=1);"
]
},
Expand Down Expand Up @@ -2310,7 +2310,7 @@
}
],
"source": [
"po.pyrshow(pyr_coeffs_complex, zoom=0.5)\n",
"po.pyrshow(pyr_coeffs_complex, zoom=0.5);\n",
"po.pyrshow(pyr_coeffs_fixed_1, zoom=0.5);"
]
},
Expand Down
15 changes: 6 additions & 9 deletions examples/04_Perceptual_distance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@
"outputs": [],
"source": [
"import os\n",
"\n",
"import io\n",
"import imageio\n",
"import matplotlib.pyplot as plt\n",
"import plenoptic as po\n",
"import numpy as np\n",
"import torch\n",
"from PIL import Image\n",
"from scipy.stats import pearsonr, spearmanr\n",
"\n",
"import plenoptic as po"
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from PIL import Image"
]
},
{
Expand Down Expand Up @@ -81,8 +80,6 @@
"outputs": [],
"source": [
"import tempfile\n",
"\n",
"\n",
"def add_jpeg_artifact(img, quality):\n",
" # need to convert this back to 2d 8-bit int for writing out as jpg\n",
" img = po.to_numpy(img.squeeze() * 255).astype(np.uint8)\n",
Expand Down Expand Up @@ -396,7 +393,7 @@
" folder / \"distorted_images\" / distorted_filename).convert(\"L\"))) / 255\n",
" distorted_images = distorted_images[:, [0] + list(range(2, 17)) + list(range(18, 24))] # Remove color distortions\n",
"\n",
" with open(folder/ \"mos.txt\", encoding=\"utf-8\") as g:\n",
" with open(folder/ \"mos.txt\", \"r\", encoding=\"utf-8\") as g:\n",
" mos_values = list(map(float, g.readlines()))\n",
" mos_values = np.array(mos_values).reshape([25, 24, 5])\n",
" mos_values = mos_values[:, [0] + list(range(2, 17)) + list(range(18, 24))] # Remove color distortions\n",
Expand Down
49 changes: 22 additions & 27 deletions examples/05_Geodesics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,20 @@
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"# so that relative sizes of axes created by po.imshow and others look right\n",
"plt.rcParams['figure.dpi'] = 72\n",
"%matplotlib inline\n",
"\n",
"import pyrtools as pt\n",
"\n",
"import plenoptic as po\n",
"from plenoptic.tools import to_numpy\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# this notebook uses torchvision, which is an optional dependency.\n",
"# if this fails, install torchvision in your plenoptic environment \n",
"# and restart the notebook kernel.\n",
Expand Down Expand Up @@ -146,8 +142,6 @@
"outputs": [],
"source": [
"import torch.fft\n",
"\n",
"\n",
"class Fourier(nn.Module):\n",
" def __init__(self, representation = 'amp'):\n",
" super().__init__()\n",
Expand Down Expand Up @@ -228,7 +222,7 @@
],
"source": [
"fig, axes = plt.subplots(2, 1, figsize=(5, 8))\n",
"po.synth.geodesic.plot_loss(moog, ax=axes[0])\n",
"po.synth.geodesic.plot_loss(moog, ax=axes[0]);\n",
"po.synth.geodesic.plot_deviation_from_line(moog, vid, ax=axes[1]);"
]
},
Expand All @@ -249,7 +243,7 @@
}
],
"source": [
"plt.plot(po.to_numpy(moog.step_energy), alpha=.2)\n",
"plt.plot(po.to_numpy(moog.step_energy), alpha=.2);\n",
"plt.plot(moog.step_energy.mean(1), 'r-', label='path energy')\n",
"plt.axhline(torch.linalg.vector_norm(moog.model(moog.image_a) - moog.model(moog.image_b), ord=2) ** 2 / moog.n_steps ** 2)\n",
"plt.legend()\n",
Expand Down Expand Up @@ -308,7 +302,7 @@
}
],
"source": [
"plt.plot(po.to_numpy(moog.dev_from_line[..., 1]))\n",
"plt.plot(po.to_numpy(moog.dev_from_line[..., 1]));\n",
"\n",
"plt.title('evolution of distance from representation line')\n",
"plt.ylabel('distance from representation line')\n",
Expand Down Expand Up @@ -367,7 +361,7 @@
"geodesic = to_numpy(moog.geodesic.squeeze())\n",
"fig = pt.imshow([video[5], pixelfade[5], geodesic[5]],\n",
" title=['video', 'pixelfade', 'geodesic'],\n",
" col_wrap=3, zoom=4)\n",
" col_wrap=3, zoom=4);\n",
"\n",
"size = geodesic.shape[-1]\n",
"h, m , l = (size//2 + size//4, size//2, size//2 - size//4)\n",
Expand All @@ -378,9 +372,9 @@
" a.axhline(line, lw=2)\n",
"\n",
"pt.imshow([video[:,l], pixelfade[:,l], geodesic[:,l]],\n",
" title=None, col_wrap=3, zoom=4)\n",
" title=None, col_wrap=3, zoom=4);\n",
"pt.imshow([video[:,m], pixelfade[:,m], geodesic[:,m]],\n",
" title=None, col_wrap=3, zoom=4)\n",
" title=None, col_wrap=3, zoom=4);\n",
"pt.imshow([video[:,h], pixelfade[:,h], geodesic[:,h]],\n",
" title=None, col_wrap=3, zoom=4);"
]
Expand Down Expand Up @@ -477,7 +471,7 @@
],
"source": [
"fig, axes = plt.subplots(2, 1, figsize=(5, 8))\n",
"po.synth.geodesic.plot_loss(moog, ax=axes[0])\n",
"po.synth.geodesic.plot_loss(moog, ax=axes[0]);\n",
"po.synth.geodesic.plot_deviation_from_line(moog, ax=axes[1]);"
]
},
Expand Down Expand Up @@ -524,7 +518,7 @@
}
],
"source": [
"plt.plot(po.to_numpy(moog.step_energy), alpha=.2)\n",
"plt.plot(po.to_numpy(moog.step_energy), alpha=.2);\n",
"plt.plot(moog.step_energy.mean(1), 'r-', label='path energy')\n",
"plt.axhline(torch.linalg.vector_norm(moog.model(moog.image_a) - moog.model(moog.image_b), ord=2) ** 2 / moog.n_steps ** 2)\n",
"plt.legend()\n",
Expand Down Expand Up @@ -636,9 +630,9 @@
],
"source": [
"print('geodesic')\n",
"pt.imshow(list(geodesic), vrange='auto1', title=None, zoom=4)\n",
"pt.imshow(list(geodesic), vrange='auto1', title=None, zoom=4);\n",
"print('diff')\n",
"pt.imshow(list(geodesic - pixelfade), vrange='auto1', title=None, zoom=4)\n",
"pt.imshow(list(geodesic - pixelfade), vrange='auto1', title=None, zoom=4);\n",
"print('pixelfade')\n",
"pt.imshow(list(pixelfade), vrange='auto1', title=None, zoom=4);"
]
Expand All @@ -663,7 +657,7 @@
"# checking that the range constraint is met\n",
"plt.hist(video.flatten(), histtype='step', density=True, label='video')\n",
"plt.hist(pixelfade.flatten(), histtype='step', density=True, label='pixelfade')\n",
"plt.hist(geodesic.flatten(), histtype='step', density=True, label='geodesic')\n",
"plt.hist(geodesic.flatten(), histtype='step', density=True, label='geodesic');\n",
"plt.title('signal value histogram')\n",
"plt.legend(loc=1)\n",
"plt.show()"
Expand Down Expand Up @@ -722,9 +716,9 @@
"l = 90\n",
"imgA = imgA[..., u:u+224, l:l+224]\n",
"imgB = imgB[..., u:u+224, l:l+224]\n",
"po.imshow([imgA, imgB], as_rgb=True)\n",
"po.imshow([imgA, imgB], as_rgb=True);\n",
"diff = imgA - imgB\n",
"po.imshow(diff)\n",
"po.imshow(diff);\n",
"pt.image_compare(po.to_numpy(imgA, True), po.to_numpy(imgB, True));"
]
},
Expand All @@ -745,6 +739,7 @@
}
],
"source": [
"from torchvision import models\n",
"# Create a class that takes the nth layer output of a given model\n",
"class NthLayer(torch.nn.Module):\n",
" \"\"\"Wrap any model to get the response of an intermediate layer\n",
Expand Down Expand Up @@ -825,7 +820,7 @@
"predA = po.to_numpy(models.vgg16(pretrained=True)(imgA))[0]\n",
"predB = po.to_numpy(models.vgg16(pretrained=True)(imgB))[0]\n",
"\n",
"plt.plot(predA)\n",
"plt.plot(predA);\n",
"plt.plot(predB);"
]
},
Expand Down Expand Up @@ -940,7 +935,7 @@
],
"source": [
"fig, axes = plt.subplots(2, 1, figsize=(5, 8))\n",
"po.synth.geodesic.plot_loss(moog, ax=axes[0])\n",
"po.synth.geodesic.plot_loss(moog, ax=axes[0]);\n",
"po.synth.geodesic.plot_deviation_from_line(moog, ax=axes[1]);"
]
},
Expand Down Expand Up @@ -1057,12 +1052,12 @@
}
],
"source": [
"po.imshow(moog.geodesic, as_rgb=True, zoom=2, title=None, vrange='auto0')\n",
"po.imshow(moog.pixelfade, as_rgb=True, zoom=2, title=None, vrange='auto0')\n",
"po.imshow(moog.geodesic, as_rgb=True, zoom=2, title=None, vrange='auto0');\n",
"po.imshow(moog.pixelfade, as_rgb=True, zoom=2, title=None, vrange='auto0');\n",
"# per channel difference\n",
"po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 0:1]], zoom=2, title=None, vrange='auto1')\n",
"po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 1:2]], zoom=2, title=None, vrange='auto1')\n",
"po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 2:]], zoom=2, title=None, vrange='auto1')\n",
"po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 0:1]], zoom=2, title=None, vrange='auto1');\n",
"po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 1:2]], zoom=2, title=None, vrange='auto1');\n",
"po.imshow([(moog.geodesic - moog.pixelfade)[1:-1, 2:]], zoom=2, title=None, vrange='auto1');\n",
"# exaggerated color difference\n",
"po.imshow([po.tools.rescale((moog.geodesic - moog.pixelfade)[1:-1])], as_rgb=True, zoom=2, title=None);"
]
Expand Down
8 changes: 4 additions & 4 deletions examples/06_Metamer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
"metadata": {},
"outputs": [],
"source": [
"import plenoptic as po\n",
"from plenoptic.tools import to_numpy\n",
"import imageio\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"import plenoptic as po\n",
"\n",
"import pyrtools as pt\n",
"import matplotlib.pyplot as plt\n",
"# so that relative sizes of axes created by po.imshow and others look right\n",
"plt.rcParams['figure.dpi'] = 72\n",
"# Animation-related settings\n",
Expand Down
Loading

0 comments on commit 79e93d3

Please sign in to comment.