diff --git a/jenkins/Jenkinsfile b/jenkins/Jenkinsfile index 0cdd6300..18a5c667 100644 --- a/jenkins/Jenkinsfile +++ b/jenkins/Jenkinsfile @@ -31,7 +31,7 @@ pipeline { pip install -U pip pip install -U .[dev] --verbose python3.11 -c "import torch; print(torch.cuda.current_device())" - python3.11 -m pytest + python3.11 -m pytest -n 8 ''' } } diff --git a/src/plenoptic/data/__init__.py b/src/plenoptic/data/__init__.py index b6527ec8..a6a67048 100644 --- a/src/plenoptic/data/__init__.py +++ b/src/plenoptic/data/__init__.py @@ -1,28 +1,31 @@ -from . import data_utils from .fetch import fetch_data, DOWNLOADABLE_FILES import torch +from importlib import resources +from ..tools.data import load_images __all__ = ['einstein', 'curie', 'parrot', 'reptile_skin', 'color_wheel', 'fetch_data', 'DOWNLOADABLE_FILES'] def __dir__(): return __all__ +FILES = resources.files(__name__) -def einstein() -> torch.Tensor: - return data_utils.get('einstein') +def einstein(as_gray: bool = True) -> torch.Tensor: + return load_images(FILES / 'einstein.pgm', as_gray=as_gray) -def curie() -> torch.Tensor: - return data_utils.get('curie') + +def curie(as_gray: bool = True) -> torch.Tensor: + return load_images(FILES / 'curie.pgm', as_gray=as_gray) def parrot(as_gray: bool = False) -> torch.Tensor: - return data_utils.get('parrot', as_gray=as_gray) + return load_images(FILES / 'parrot.png', as_gray=as_gray) -def reptile_skin() -> torch.Tensor: - return data_utils.get('reptile_skin') +def reptile_skin(as_gray: bool = True) -> torch.Tensor: + return load_images(FILES / 'reptile_skin.pgm', as_gray=as_gray) def color_wheel(as_gray: bool = False) -> torch.Tensor: - return data_utils.get('color_wheel', as_gray=as_gray) + return load_images(FILES / 'color_wheel.jpg', as_gray=as_gray) diff --git a/src/plenoptic/data/data_utils.py b/src/plenoptic/data/data_utils.py deleted file mode 100644 index 037baffa..00000000 --- a/src/plenoptic/data/data_utils.py +++ /dev/null @@ -1,63 +0,0 @@ -from importlib import resources -from importlib.abc import Traversable -from typing import Union - - -from ..tools.data import load_images - - -def get_path(item_name: str) -> Traversable: - """ - Retrieve the filename that matches the given item name with any extension. - - Parameters - ---------- - item_name - The name of the item to find the file for, without specifying the file extension. - - Returns - ------- - : - The filename matching the `item_name` with its extension. - - Raises - ------ - AssertionError - If no files or more than one file match the `item_name`. - - Notes - ----- - This function uses glob to search for files in the current directory matching the `item_name`. - It is assumed that there is only one file matching the name regardless of its extension. - """ - fhs = [file for file in resources.files("plenoptic.data").iterdir() if file.stem == item_name] - assert len(fhs) == 1, f"Expected exactly one file for {item_name}, but found {len(fhs)}." - return fhs[0] - - -def get(*item_names: str, as_gray: Union[None, bool] = None): - """Load an image based on the item name from the package's data resources. - - Parameters - ---------- - item_names : - The names of the items to load, without specifying the file extension. - as_gray : - Whether to load in the image(s) as grayscale or not. If None, will make - best guess based on file extension. - - Returns - ------- - The loaded image object. The exact return type depends on the `load_images` function implementation. - - Notes - ----- - This function first retrieves the full filename using `get_filename` and then loads the image - using `load_images` from the `tools.data` module. It supports loading images as grayscale if - they have a `.pgm` extension. - - """ - paths = [get_path(name) for name in item_names] - if as_gray is None: - as_gray = all(path.suffix == ".pgm" for path in paths) - return load_images(paths, as_gray=as_gray) diff --git a/src/plenoptic/metric/perceptual_distance.py b/src/plenoptic/metric/perceptual_distance.py index c3ec3de2..374f5a88 100644 --- a/src/plenoptic/metric/perceptual_distance.py +++ b/src/plenoptic/metric/perceptual_distance.py @@ -1,15 +1,14 @@ import numpy as np import torch import torch.nn.functional as F +from importlib import resources import warnings from ..simulate.canonical_computations import LaplacianPyramid from ..simulate.canonical_computations.filters import circular_gaussian2d from ..tools.conv import same_padding -import os - -DIRNAME = os.path.dirname(__file__) +DIR = resources.files('plenoptic.metric') def _ssim_parts(img1, img2, pad=False): @@ -365,8 +364,8 @@ def normalized_laplacian_pyramid(img): (_, channel, height, width) = img.size() N_scales = 6 - spatialpooling_filters = np.load(os.path.join(DIRNAME, 'DN_filts.npy')) - sigmas = np.load(os.path.join(DIRNAME, 'DN_sigmas.npy')) + spatialpooling_filters = np.load(DIR / 'DN_filts.npy') + sigmas = np.load(DIR / 'DN_sigmas.npy') L = LaplacianPyramid(n_scales=N_scales, scale_filter=True) laplacian_activations = L.forward(img) diff --git a/src/plenoptic/tools/optim.py b/src/plenoptic/tools/optim.py index 439cc8c3..ae98cd2a 100644 --- a/src/plenoptic/tools/optim.py +++ b/src/plenoptic/tools/optim.py @@ -126,9 +126,8 @@ def penalize_range( penalty Penalty for values outside range """ - # the indexing should flatten it - below_min = synth_img[synth_img < allowed_range[0]] - below_min = torch.pow(below_min - allowed_range[0], 2) - above_max = synth_img[synth_img > allowed_range[1]] - above_max = torch.pow(above_max - allowed_range[1], 2) - return torch.sum(torch.cat([below_min, above_max])) + # Using clip like this is equivalent to using boolean indexing (e.g., + # synth_img[synth_img < allowed_range[0]]) but much faster + below_min = torch.clip(synth_img - allowed_range[0], max=0).pow(2).sum() + above_max = torch.clip(synth_img - allowed_range[1], min=0).pow(2).sum() + return below_min + above_max diff --git a/tests/test_data_get.py b/tests/test_data_get.py index 1a6b5ff3..2575e54b 100644 --- a/tests/test_data_get.py +++ b/tests/test_data_get.py @@ -1,55 +1,14 @@ -from contextlib import nullcontext as does_not_raise - -from importlib import resources -from importlib.abc import Traversable import pytest from torch import Tensor import plenoptic as po -@pytest.mark.parametrize( - "item_name, expectation", - [ - ("color_wheel", does_not_raise()), - ("xyz", pytest.raises(AssertionError, match="Expected exactly one file for xyz, but found 2")), - ("xyzw", pytest.raises(AssertionError, match=f"Expected exactly one file for xyzw, but found 0")) - ] -) -def test_data_get_path(item_name, expectation): - """Test the retrieval of file paths with varying expectations.""" - fh = resources.files("plenoptic.data") - # ensures that the files gets deleted, even if an exception gets hit in - # the try block. - try: - with expectation: - # Create files with specific extensions - for ext in ["abc", "abcd"]: - file_path = fh / f"xyz.{ext}" - file_path.touch(exist_ok=True) - po.data.data_utils.get_path(item_name) - - finally: - # Delete the files created for the test - for ext in ["abc", "abcd"]: - path = fh / f"xyz.{ext}" - if path.exists(): - path.unlink() # This deletes the file - - -@pytest.mark.parametrize("item_name", ["color_wheel", "parrot", "curie"]) -def test_data_get_path_type(item_name): - """Test that the returned path object is an instance of Traversable.""" - assert isinstance(po.data.data_utils.get_path(item_name), Traversable) - - -@pytest.mark.parametrize( - "item_name", ["color_wheel", "parrot", "curie"] -) -def test_data_get_type(item_name): - """Test that the retrieved data is of type Tensor.""" - img = po.data.data_utils.get(item_name) - assert isinstance(img, Tensor) +@pytest.mark.parametrize("item_name", [img for img in dir(po.data) + if img not in ['fetch_data', 'DOWNLOADABLE_FILES']]) +def test_data_module(item_name): + """Test that data module works.""" + assert isinstance(eval(f"po.data.{item_name}()"), Tensor) @pytest.mark.parametrize( @@ -57,17 +16,12 @@ def test_data_get_type(item_name): [ ("color_wheel", (1, 3, 600, 600)), ("parrot", (1, 3, 254, 266)), - ("curie", (1, 1, 256, 256)) + ("curie", (1, 1, 256, 256)), + ("einstein", (1, 1, 256, 256)), + ("reptile_skin", (1, 1, 256, 256)), ] ) def test_data_get_shape(item_name, img_shape): """Check if the shape of the retrieved image matches the expected dimensions.""" - img = po.data.data_utils.get(item_name) + img = eval(f"po.data.{item_name}()") assert all(shp == img_shape[i] for i, shp in enumerate(img.shape)) - - -@pytest.mark.parametrize("item_name", ["color_wheel", "parrot", "curie", - 'einstein', 'reptile_skin']) -def test_data_module(item_name): - """Test that data module works.""" - assert (eval(f"po.data.{item_name}()") == po.data.data_utils.get(item_name)).all() diff --git a/tests/test_tools.py b/tests/test_tools.py index 541bbf66..8147e6a1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -458,3 +458,16 @@ def test_validate_metric_identical(self): @pytest.mark.parametrize('model', ['frontend.OnOff.nograd'], indirect=True) def test_remove_grad(self, model): po.tools.validate.validate_model(model, device=DEVICE) + + +class TestOptim(object): + + def test_penalize_range_above(self): + img = .5 * torch.ones((1, 1, 4, 4)) + img[..., 0, :] = 2 + assert po.tools.optim.penalize_range(img).item() == 4 + + def test_penalize_range_below(self): + img = .5 * torch.ones((1, 1, 4, 4)) + img[..., 0, :] = -1 + assert po.tools.optim.penalize_range(img).item() == 4