switch to using register_buffers
billbrod committed Jan 2, 2024
1 parent d918a42 commit e8ce40c
Showing 5 changed files with 194 additions and 124 deletions.
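
The pattern this commit switches to: tensors stored as plain Python attributes or inside plain lists are invisible to nn.Module, so .to(), .cuda(), and state_dict() skip them, which is why the old code below carries a hand-written to() method. register_buffer makes the module machinery track them. A minimal sketch of the difference (illustrative toy classes, not code from this repository):

import torch
from torch import nn

class PlainAttribute(nn.Module):
    # mask stored as an ordinary attribute: invisible to nn.Module machinery
    def __init__(self):
        super().__init__()
        self.mask = torch.ones(4, 4)

class BufferAttribute(nn.Module):
    # mask registered as a buffer: tracked by .to(), .cuda(), and state_dict()
    def __init__(self):
        super().__init__()
        self.register_buffer('mask', torch.ones(4, 4))

plain = PlainAttribute().to(torch.float64)
buffered = BufferAttribute().to(torch.float64)
print(plain.mask.dtype)             # torch.float32 -- the cast skipped it
print(buffered.mask.dtype)          # torch.float64 -- cast together with the module
print(list(buffered.state_dict()))  # ['mask']; plain.state_dict() is empty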
@@ -144,16 +144,9 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
# create low and high masks
lo0mask = interpolate1d(self.log_rad, self.YIrcos, self.Xrcos)
hi0mask = interpolate1d(self.log_rad, self.Yrcos, self.Xrcos)
self.lo0mask = torch.tensor(lo0mask).unsqueeze(0)
self.hi0mask = torch.tensor(hi0mask).unsqueeze(0)
self.register_buffer('lo0mask', torch.tensor(lo0mask).unsqueeze(0))
self.register_buffer('hi0mask', torch.tensor(hi0mask).unsqueeze(0))

# pre-generate the angle, hi and lo masks, as well as the
# indices used for down-sampling
self._anglemasks = []
self._anglemasks_recon = []
self._himasks = []
self._lomasks = []
self._loindices = []

# need a mock image to down-sample so that we correctly
# construct the differently-sized masks
@@ -171,6 +164,9 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
Xrcos = self.Xrcos.copy()
angle = self.angle.copy()
log_rad = self.log_rad.copy()
# pre-generate the angle, hi and lo masks, as well as the
# indices used for down-sampling.
self._loindices = []
for i in range(self.num_scales):
Xrcos -= np.log2(2)
const = ((2 ** (2*self.order)) * (factorial(self.order, exact=True)**2) /
@@ -187,7 +183,7 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
Ycosn_recon = Ycosn_forward

himask = interpolate1d(log_rad, self.Yrcos, Xrcos)
self._himasks.append(torch.tensor(himask).unsqueeze(0))
self.register_buffer(f'_himasks_scale_{i}', torch.tensor(himask).unsqueeze(0))

anglemasks = []
anglemasks_recon = []
@@ -199,11 +195,11 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
anglemasks.append(torch.tensor(anglemask).unsqueeze(0))
anglemasks_recon.append(torch.tensor(anglemask_recon).unsqueeze(0))

self._anglemasks.append(anglemasks)
self._anglemasks_recon.append(anglemasks_recon)
self.register_buffer(f'_anglemasks_scale_{i}', torch.cat(anglemasks))
self.register_buffer(f'_anglemasks_recon_scale_{i}', torch.cat(anglemasks_recon))
if not self.downsample:
lomask = interpolate1d(log_rad, self.YIrcos, Xrcos)
self._lomasks.append(torch.tensor(lomask).unsqueeze(0))
self.register_buffer(f'_lomasks_scale_{i}', torch.tensor(lomask).unsqueeze(0))
self._loindices.append([np.array([0, 0]), dims])
lodft = lodft * lomask

@@ -222,62 +218,14 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
angle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]

lomask = interpolate1d(log_rad, self.YIrcos, Xrcos)
self._lomasks.append(torch.tensor(lomask).unsqueeze(0))
self.register_buffer(f'_lomasks_scale_{i}', torch.tensor(lomask).unsqueeze(0))
# subsampling
lodft = lodft[lostart[0]:loend[0], lostart[1]:loend[1]]
# convolution in spatial domain
lodft = lodft * lomask

# reasonable default dtype
self = self.to(torch.float32)

def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point type of
the floating point parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
Returns:
Module: self
"""
self.lo0mask = self.lo0mask.to(*args, **kwargs)
self.hi0mask = self.hi0mask.to(*args, **kwargs)
self._himasks = [m.to(*args, **kwargs) for m in self._himasks]
self._lomasks = [m.to(*args, **kwargs) for m in self._lomasks]
angles = []
angles_recon = []
for a, ar in zip(self._anglemasks, self._anglemasks_recon):
angles.append([m.to(*args, **kwargs) for m in a])
angles_recon.append([m.to(*args, **kwargs) for m in ar])
self._anglemasks = angles
self._anglemasks_recon = angles_recon
return self
self.to(torch.float32)

def forward(self, x, scales=None):
r"""Generate the steerable pyramid coefficients for an image
@@ -342,7 +290,7 @@ def forward(self, x, scales=None):

if i in scales:
#high-pass mask is selected based on the current scale
himask = self._himasks[i]
himask = getattr(self, f'_himasks_scale_{i}')
#compute filter output at each orientation
for b in range(self.num_orientations):

@@ -352,8 +300,8 @@ def forward(self, x, scales=None):
# the complex_const variable comes from the Fourier transform of a gaussian derivative.
# Based on the order of the gaussian, this constant changes.


anglemask = self._anglemasks[i][b]
anglemask = getattr(self, f'_anglemasks_scale_{i}')[b]

complex_const = np.power(complex(0, -1), self.order)
banddft = complex_const * lodft * anglemask * himask
# fft output is then shifted to center frequencies
@@ -378,7 +326,7 @@ def forward(self, x, scales=None):
if not self.downsample:
# no subsampling of angle and rad
# just use lo0mask
lomask = self._lomasks[i]
lomask = getattr(self, f'_lomasks_scale_{i}')
lodft = lodft * lomask

# because we don't subsample here, if we are not using orthonormalization that
Expand All @@ -397,7 +345,7 @@ def forward(self, x, scales=None):
# subsampling of the dft for next scale
lodft = lodft[:, :, lostart[0]:loend[0], lostart[1]:loend[1]]
# low-pass filter mask is selected
lomask = self._lomasks[i]
lomask = getattr(self, f'_lomasks_scale_{i}')
# again multiply dft by subsampled mask (convolution in spatial domain)

lodft = lodft * lomask
@@ -787,7 +735,7 @@ def _recon_levels(self, pyr_coeffs, recon_keys, scale):

# Reconstruct from orientation bands
# update himask
himask = self._himasks[scale]
himask = getattr(self, f'_himasks_scale_{scale}')
if self.is_complex:
tensor_type = torch.complex64
else:
@@ -796,7 +744,7 @@ def _recon_levels(self, pyr_coeffs, recon_keys, scale):

for b in range(self.num_orientations):
if (scale, b) in recon_keys:
anglemask = self._anglemasks_recon[scale][b]
anglemask = getattr(self, f'_anglemasks_recon_scale_{scale}')[b]
coeffs = pyr_coeffs[(scale,b)]
if self.tight_frame and self.is_complex:
coeffs = coeffs*np.sqrt(2)
@@ -812,7 +760,8 @@ def _recon_levels(self, pyr_coeffs, recon_keys, scale):
lostart, loend = self._loindices[scale]

# create lowpass mask
lomask = self._lomasks[scale]
lomask = getattr(self, f'_lomasks_scale_{scale}')

# Recursively reconstruct by going to the next scale
reslevdft = self._recon_levels(pyr_coeffs, recon_keys, scale+1)
# in the not-downsampled case, rescale the magnitudes of the reconstructed dft at each level by a factor of 2 to account for the scaling in the forward pass
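register_buffer only accepts a single tensor per name, so the per-scale lists from the old code are flattened above into individually named buffers (_himasks_scale_0, _lomasks_scale_1, and so on), with each scale's orientation masks concatenated into one tensor, and forward() reads them back with getattr. A small sketch of that pattern, using a hypothetical MaskBank class rather than the real pyramid:

import torch
from torch import nn

class MaskBank(nn.Module):
    # toy stand-in for the pyramid: one mask per scale, each its own buffer
    def __init__(self, num_scales=3, size=8):
        super().__init__()
        for i in range(num_scales):
            # a Python list of tensors would not be tracked, so register each
            # scale's mask under a unique name instead
            self.register_buffer(f'_masks_scale_{i}', torch.rand(1, size, size))

    def forward(self, x, scale):
        # retrieve the buffer for the requested scale by name
        mask = getattr(self, f'_masks_scale_{scale}')
        return x * mask

bank = MaskBank().to(torch.float64)
out = bank(torch.rand(1, 8, 8, dtype=torch.float64), scale=1)
print(getattr(bank, '_masks_scale_1').dtype)  # torch.float64 -- moved with the module
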
8 changes: 4 additions & 4 deletions src/plenoptic/simulate/models/naive.py
@@ -129,7 +129,7 @@ def __init__(
self.out_channels = out_channels

self.cache_filt = cache_filt
self._filt = None
self.register_buffer('_filt', None)

@property
def filt(self):
@@ -139,7 +139,7 @@ def filt(self):
filt = circular_gaussian2d(self.kernel_size, self.std, self.out_channels)

if self.cache_filt:
self._filt = filt
self.register_buffer('_filt', filt)
return filt

def forward(self, x: Tensor, **conv2d_kwargs) -> Tensor:
@@ -235,7 +235,7 @@ def __init__(
self.pad_mode = pad_mode

self.cache_filt = cache_filt
self._filt = None
self.register_buffer('_filt', None)

@property
def filt(self) -> Tensor:
@@ -256,7 +256,7 @@ def filt(self) -> Tensor:
filt = on_amp * (sign * (filt_center - filt_surround))

if self.cache_filt:
self._filt = filt
self.register_buffer('_filt', filt)
return filt

def _clamp_surround_std(self):
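The changes above keep the optional filter cache but move it into a buffer: registering None in __init__ reserves the name, and re-registering the computed filter later overwrites the placeholder, so a cached filter follows the module across devices and dtypes. A hedged sketch of that caching pattern (the filter computation here is a stand-in, not the library's):

import torch
from torch import nn

class CachedFilter(nn.Module):
    # sketch of the cache_filt pattern: the cached filter lives in a buffer
    def __init__(self, cache_filt=True):
        super().__init__()
        self.cache_filt = cache_filt
        # registering None reserves the buffer name; a None buffer is skipped
        # by state_dict() and .to() until a real tensor replaces it
        self.register_buffer('_filt', None)

    @property
    def filt(self):
        if self._filt is not None:
            return self._filt
        filt = torch.randn(1, 1, 5, 5)  # stand-in for the real filter computation
        if self.cache_filt:
            # re-registering under the same name swaps the tensor into the buffer
            self.register_buffer('_filt', filt)
        return filt

m = CachedFilter()
_ = m.filt               # first access builds and caches the filter
m = m.to(torch.float64)
print(m.filt.dtype)      # torch.float64 -- the cached buffer was cast
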
56 changes: 7 additions & 49 deletions src/plenoptic/simulate/models/portilla_simoncelli.py
@@ -84,20 +84,19 @@ def __init__(
tight_frame=False,
)
self.filterPyr = SteerablePyramidFreq(
self.pyr._lomasks[-1].shape[-2:], height=0, order=1,
getattr(self.pyr, f'_lomasks_scale_{n_scales-1}').shape[-2:],
height=0, order=1,
tight_frame=False
)
self.unoriented_band_pyrs = [
SteerablePyramidFreq(
himask.shape[-2:],
for i in range(n_scales):
pyr = SteerablePyramidFreq(
getattr(self.pyr, f'_himasks_scale_{i}').shape[-2:],
height=1,
order=self.n_orientations - 1,
is_complex=False,
tight_frame=False,
)
# want to go through these masks backwards
for himask in self.pyr._himasks
]
setattr(self, f'unoriented_band_pyrs_scale_{i}', pyr)

self.use_true_correlations = use_true_correlations
self.scales = (
@@ -667,7 +666,7 @@ def _calculate_autocorrelation_skew_kurtosis(self):
reconstructed_image = reconstructed_image.unsqueeze(0).unsqueeze(0)

# reconstruct the unoriented band for this scale
unoriented_band_pyr = self.unoriented_band_pyrs[this_scale]
unoriented_band_pyr = getattr(self, f'unoriented_band_pyrs_scale_{this_scale}')
unoriented_pyr_coeffs = unoriented_band_pyr.forward(reconstructed_image)
for ii in range(self.n_orientations):
unoriented_pyr_coeffs[(0, ii)] = (
@@ -1177,44 +1176,3 @@ def update_plot(self, axes, batch_idx=0, data=None):
sc = update_stem(ax.containers[0], vals)
stem_artists.extend([sc.markerline, sc.stemlines])
return stem_artists

def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`torch.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`torch.dtype`): the desired floating point type of
the floating point parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module
Returns:
Module: self
"""
self.pyr = self.pyr.to(*args, **kwargs)
self.filterPyr = self.filterPyr.to(*args, **kwargs)
self.unoriented_band_pyrs = [pyr.to(*args, **kwargs) for pyr in
self.unoriented_band_pyrs]
return self
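
Both hand-written to() overrides (this one and the pyramid's above) become unnecessary once everything the models own is either a registered buffer or a child module: submodules attached with setattr are recorded by nn.Module the same way as ordinary attribute assignment, so the inherited to() reaches them and their buffers. A toy sketch of that behaviour, not the library's classes:

import torch
from torch import nn

class Parent(nn.Module):
    # submodules attached with setattr are registered as children, so the
    # inherited nn.Module.to() reaches them (and their parameters and buffers)
    def __init__(self, n_scales=2):
        super().__init__()
        for i in range(n_scales):
            child = nn.Conv2d(1, 1, kernel_size=3)
            setattr(self, f'child_scale_{i}', child)

p = Parent().to(torch.float64)
print(getattr(p, 'child_scale_0').weight.dtype)  # torch.float64
print([name for name, _ in p.named_children()])  # ['child_scale_0', 'child_scale_1']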