Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates steerpyr behavior #39

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/pyrtools/pyramids/SteerablePyramidFreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class SteerablePyramidFreq(SteerablePyramidBase):
The squared radial functions tile the Fourier plane with a raised-cosine
falloff. Angular functions are cos(theta- k*pi/order+1)^(order).

Note that reconstruction will not be exact if the image has an odd shape (due to
boundary-handling issues) or if the pyramid is complex with order=0.

Notes
-----
Transform described in [1]_, filter kernel design described in [2]_.
Expand All @@ -30,7 +33,7 @@ class SteerablePyramidFreq(SteerablePyramidBase):
2d image upon which to construct to the pyramid.
height : 'auto' or `int`.
The height of the pyramid. If 'auto', will automatically determine based on the size of
`image`.
`image`. If an int, must be non-negative. When height=0, only returns the residuals.
order : `int`.
The Gaussian derivative order used for the steerable filters. Default value is 3.
Note that to achieve steerability the minimum number of orientation is `order` + 1,
Expand Down Expand Up @@ -66,6 +69,7 @@ class SteerablePyramidFreq(SteerablePyramidBase):
Oct 1995.
.. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for Steerable Pyramid
Image Transforms", ICASSP, Atlanta, GA, May 1996.

"""
def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False):
# in the Fourier domain, there's only one choice for how do edge-handling: circular. to
Expand All @@ -78,24 +82,31 @@ def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False):
self.filters = {}
self.order = int(order)

if (image.shape[0] % 2 != 0) or (image.shape[1] % 2 != 0):
warnings.warn("Reconstruction will not be perfect with odd-sized images")

if self.order == 0 and self.is_complex:
warnings.warn("Reconstruction will not be perfect for a complex pyramid with order=0")

# we can't use the base class's _set_num_scales method because the max height is calculated
# slightly differently
max_ht = np.floor(np.log2(min(self.image.shape))) - 2
if height == 'auto' or height is None:
self.num_scales = int(max_ht)
elif height > max_ht:
raise Exception("Cannot build pyramid higher than %d levels." % (max_ht))
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
elif height < 0:
raise ValueError("Height must be a non-negative int.")
else:
self.num_scales = int(height)

if self.order > 15 or self.order < 0:
raise Exception("order must be an integer in the range [0,15]. Truncating.")
raise ValueError("order must be an integer in the range [0,15].")

self.num_orientations = int(order + 1)

if twidth <= 0:
warnings.warn("twidth must be positive. Setting to 1.")
twidth = 1
raise ValueError("twidth must be positive.")
twidth = int(twidth)

dims = np.array(self.image.shape)
Expand Down Expand Up @@ -220,8 +231,7 @@ def recon_pyr(self, levels='all', bands='all', twidth=1):

"""
if twidth <= 0:
warnings.warn("twidth must be positive. Setting to 1.")
twidth = 1
raise ValueError("twidth must be positive.")

recon_keys = self._recon_keys(levels, bands)

Expand Down
2 changes: 1 addition & 1 deletion src/pyrtools/pyramids/pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _set_num_scales(self, filter_name, height, extra_height=0):
if height == 'auto':
self.num_scales = max_ht
elif height > max_ht:
raise Exception("Cannot build pyramid higher than %d levels." % (max_ht))
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
else:
self.num_scales = int(height)

Expand Down
Loading