diff --git a/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py b/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py index 9293f1cc..a50ce586 100644 --- a/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py +++ b/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py @@ -47,9 +47,10 @@ class SteerablePyramidFreq(nn.Module): image_shape : `list or tuple` shape of input image height : 'auto' or `int` - The height of the pyramid. If 'auto', will automatically determine - based on the size of `image`. If an int, must be positive and less than - log2(min(image_shape[1], image_shape[1]))-2. + The height of the pyramid. If 'auto', will automatically determine based on the + size of `image`. If an int, must be non-negative and less than + log2(min(image_shape[1], image_shape[1]))-2. If height=0, this only returns the + residuals. order : `int`. The Gaussian derivative order used for the steerable filters, in [1, 15]. Note that to achieve steerability the minimum number of @@ -135,8 +136,8 @@ def __init__( self.num_scales = int(max_ht) elif height > max_ht: raise ValueError(f"Cannot build pyramid higher than {max_ht:.0f} levels.") - elif height < 1: - raise ValueError("Height must be a positive int.") + elif height < 0: + raise ValueError("Height must be a non-negative integer.") else: self.num_scales = int(height) diff --git a/tests/test_steerable_pyr.py b/tests/test_steerable_pyr.py index 4eb4e11a..0a51566f 100644 --- a/tests/test_steerable_pyr.py +++ b/tests/test_steerable_pyr.py @@ -407,9 +407,9 @@ def test_scales_arg(self, img, spyr, scales): @pytest.mark.parametrize("height", range(-1, 8)) def test_height_values(self, img, height): - if height < 1: + if height < 0: expectation = pytest.raises( - ValueError, match="Height must be a positive int" + ValueError, match="Height must be a non-negative int" ) elif height > 6: expectation = pytest.raises(