From e8ce40c98ea448abb86d96306fcbba10e68ed352 Mon Sep 17 00:00:00 2001
From: "William F. Broderick"
Date: Tue, 2 Jan 2024 17:32:14 -0500
Subject: [PATCH] switch to using register_buffers

---
 .../steerable_pyramid_freq.py                |  91 +++----
 src/plenoptic/simulate/models/naive.py       |   8 +-
 .../simulate/models/portilla_simoncelli.py   |  56 +-------
 tests/test_models.py                         | 123 ++++++++++++++++++
 tests/test_steerable_pyr.py                  |  40 ++++++
 5 files changed, 194 insertions(+), 124 deletions(-)

diff --git a/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py b/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py
index 524c9314..924b24f3 100644
--- a/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py
+++ b/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py
@@ -144,16 +144,9 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
         # create low and high masks
         lo0mask = interpolate1d(self.log_rad, self.YIrcos, self.Xrcos)
         hi0mask = interpolate1d(self.log_rad, self.Yrcos, self.Xrcos)
-        self.lo0mask = torch.tensor(lo0mask).unsqueeze(0)
-        self.hi0mask = torch.tensor(hi0mask).unsqueeze(0)
+        self.register_buffer('lo0mask', torch.tensor(lo0mask).unsqueeze(0))
+        self.register_buffer('hi0mask', torch.tensor(hi0mask).unsqueeze(0))
 
-        # pre-generate the angle, hi and lo masks, as well as the
-        # indices used for down-sampling
-        self._anglemasks = []
-        self._anglemasks_recon = []
-        self._himasks = []
-        self._lomasks = []
-        self._loindices = []
 
         # need a mock image to down-sample so that we correctly
         # construct the differently-sized masks
@@ -171,6 +164,9 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
         Xrcos = self.Xrcos.copy()
         angle = self.angle.copy()
         log_rad = self.log_rad.copy()
+        # pre-generate the angle, hi and lo masks, as well as the
+        # indices used for down-sampling.
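+        # the masks are registered as buffers (rather than appended to plain
+        # lists) so that nn.Module's .to(), .cuda(), and .cpu() move them
+        # automatically; only these down-sampling indices stay a regular
+        # attribute, since they are not tensors.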
+        self._loindices = []
         for i in range(self.num_scales):
             Xrcos -= np.log2(2)
             const = ((2 ** (2*self.order)) * (factorial(self.order, exact=True)**2) /
@@ -187,7 +183,7 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
                 Ycosn_recon = Ycosn_forward
 
             himask = interpolate1d(log_rad, self.Yrcos, Xrcos)
-            self._himasks.append(torch.tensor(himask).unsqueeze(0))
+            self.register_buffer(f'_himasks_scale_{i}', torch.tensor(himask).unsqueeze(0))
 
             anglemasks = []
             anglemasks_recon = []
@@ -199,11 +195,11 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
                 anglemasks.append(torch.tensor(anglemask).unsqueeze(0))
                 anglemasks_recon.append(torch.tensor(anglemask_recon).unsqueeze(0))
 
-            self._anglemasks.append(anglemasks)
-            self._anglemasks_recon.append(anglemasks_recon)
+            self.register_buffer(f'_anglemasks_scale_{i}', torch.cat(anglemasks))
+            self.register_buffer(f'_anglemasks_recon_scale_{i}', torch.cat(anglemasks_recon))
             if not self.downsample:
                 lomask = interpolate1d(log_rad, self.YIrcos, Xrcos)
-                self._lomasks.append(torch.tensor(lomask).unsqueeze(0))
+                self.register_buffer(f'_lomasks_scale_{i}', torch.tensor(lomask).unsqueeze(0))
                 self._loindices.append([np.array([0, 0]), dims])
                 lodft = lodft * lomask
 
@@ -222,62 +218,14 @@ def __init__(self, image_shape, height='auto', order=3, twidth=1, is_complex=Fal
                 angle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
                 lomask = interpolate1d(log_rad, self.YIrcos, Xrcos)
 
-                self._lomasks.append(torch.tensor(lomask).unsqueeze(0))
+                self.register_buffer(f'_lomasks_scale_{i}', torch.tensor(lomask).unsqueeze(0))
                 # subsampling
                 lodft = lodft[lostart[0]:loend[0], lostart[1]:loend[1]]
                 # convolution in spatial domain
                 lodft = lodft * lomask
 
         # reasonable default dtype
-        self = self.to(torch.float32)
-
-    def to(self, *args, **kwargs):
-        r"""Moves and/or casts the parameters and buffers.
-
-        This can be called as
-
-        .. function:: to(device=None, dtype=None, non_blocking=False)
-
-        .. function:: to(dtype, non_blocking=False)
-
-        .. function:: to(tensor, non_blocking=False)
-
-        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
-        floating point desired :attr:`dtype` s. In addition, this method will
-        only cast the floating point parameters and buffers to :attr:`dtype`
-        (if given). The integral parameters and buffers will be moved
-        :attr:`device`, if that is given, but with dtypes unchanged. When
-        :attr:`non_blocking` is set, it tries to convert/move asynchronously
-        with respect to the host if possible, e.g., moving CPU Tensors with
-        pinned memory to CUDA devices.
-
-        See below for examples.
-
-        .. note::
-            This method modifies the module in-place.
-        Args:
-            device (:class:`torch.device`): the desired device of the parameters
-                and buffers in this module
-            dtype (:class:`torch.dtype`): the desired floating point type of
-                the floating point parameters and buffers in this module
-            tensor (torch.Tensor): Tensor whose dtype and device are the desired
-                dtype and device for all parameters and buffers in this module
-
-        Returns:
-            Module: self
-        """
-        self.lo0mask = self.lo0mask.to(*args, **kwargs)
-        self.hi0mask = self.hi0mask.to(*args, **kwargs)
-        self._himasks = [m.to(*args, **kwargs) for m in self._himasks]
-        self._lomasks = [m.to(*args, **kwargs) for m in self._lomasks]
-        angles = []
-        angles_recon = []
-        for a, ar in zip(self._anglemasks, self._anglemasks_recon):
-            angles.append([m.to(*args, **kwargs) for m in a])
-            angles_recon.append([m.to(*args, **kwargs) for m in ar])
-        self._anglemasks = angles
-        self._anglemasks_recon = angles_recon
-        return self
+        self.to(torch.float32)
 
     def forward(self, x, scales=None):
         r"""Generate the steerable pyramid coefficients for an image
@@ -342,7 +290,7 @@ def forward(self, x, scales=None):
 
             if i in scales:
                 #high-pass mask is selected based on the current scale
-                himask = self._himasks[i]
+                himask = getattr(self, f'_himasks_scale_{i}')
 
                 #compute filter output at each orientation
                 for b in range(self.num_orientations):
@@ -352,8 +300,8 @@ def forward(self, x, scales=None):
 
                     # the complex_const variable comes from the Fourier transform of a gaussian derivative.
                     # Based on the order of the gaussian, this constant changes.
-
-                    anglemask = self._anglemasks[i][b]
+                    anglemask = getattr(self, f'_anglemasks_scale_{i}')[b]
+
                     complex_const = np.power(complex(0, -1), self.order)
                     banddft = complex_const * lodft * anglemask * himask
                     # fft output is then shifted to center frequencies
@@ -378,7 +326,7 @@ def forward(self, x, scales=None):
             if not self.downsample:
                 # no subsampling of angle and rad
                 # just use lo0mask
-                lomask = self._lomasks[i]
+                lomask = getattr(self, f'_lomasks_scale_{i}')
                 lodft = lodft * lomask
 
                 # because we don't subsample here, if we are not using orthonormalization that
@@ -397,7 +345,7 @@ def forward(self, x, scales=None):
                 # subsampling of the dft for next scale
                 lodft = lodft[:, :, lostart[0]:loend[0], lostart[1]:loend[1]]
                 # low-pass filter mask is selected
-                lomask = self._lomasks[i]
+                lomask = getattr(self, f'_lomasks_scale_{i}')
                 # again multiply dft by subsampled mask (convolution in spatial domain)
                 lodft = lodft * lomask
 
@@ -787,7 +735,7 @@ def _recon_levels(self, pyr_coeffs, recon_keys, scale):
 
         # Reconstruct from orientation bands
         # update himask
-        himask = self._himasks[scale]
+        himask = getattr(self, f'_himasks_scale_{scale}')
         if self.is_complex:
             tensor_type = torch.complex64
         else:
@@ -796,7 +744,7 @@ def _recon_levels(self, pyr_coeffs, recon_keys, scale):
 
         for b in range(self.num_orientations):
             if (scale, b) in recon_keys:
-                anglemask = self._anglemasks_recon[scale][b]
+                anglemask = getattr(self, f'_anglemasks_recon_scale_{scale}')[b]
                 coeffs = pyr_coeffs[(scale,b)]
                 if self.tight_frame and self.is_complex:
                     coeffs = coeffs*np.sqrt(2)
@@ -812,7 +760,8 @@ def _recon_levels(self, pyr_coeffs, recon_keys, scale):
         lostart, loend = self._loindices[scale]
 
         # create lowpass mask
-        lomask = self._lomasks[scale]
+        lomask = getattr(self, f'_lomasks_scale_{scale}')
+
         # Recursively reconstruct by going to the next scale
         reslevdft = self._recon_levels(pyr_coeffs, recon_keys, scale+1)
 
         #in not downsampled case, rescale the magnitudes of the reconstructed dft at each level by factor of 2 to account for the scaling in the forward
diff --git a/src/plenoptic/simulate/models/naive.py b/src/plenoptic/simulate/models/naive.py
index fa1bcac2..af14d64d 100644
--- a/src/plenoptic/simulate/models/naive.py
+++ b/src/plenoptic/simulate/models/naive.py
@@ -129,7 +129,7 @@ def __init__(
         self.out_channels = out_channels
 
         self.cache_filt = cache_filt
-        self._filt = None
+        self.register_buffer('_filt', None)
 
     @property
     def filt(self):
@@ -139,7 +139,7 @@ def filt(self):
             filt = circular_gaussian2d(self.kernel_size, self.std, self.out_channels)
 
             if self.cache_filt:
-                self._filt = filt
+                self.register_buffer('_filt', filt)
             return filt
 
     def forward(self, x: Tensor, **conv2d_kwargs) -> Tensor:
@@ -235,7 +235,7 @@ def __init__(
         self.pad_mode = pad_mode
 
         self.cache_filt = cache_filt
-        self._filt = None
+        self.register_buffer('_filt', None)
 
     @property
     def filt(self) -> Tensor:
@@ -256,7 +256,7 @@ def filt(self) -> Tensor:
             filt = on_amp * (sign * (filt_center - filt_surround))
 
             if self.cache_filt:
-                self._filt = filt
+                self.register_buffer('_filt', filt)
             return filt
 
     def _clamp_surround_std(self):
diff --git a/src/plenoptic/simulate/models/portilla_simoncelli.py b/src/plenoptic/simulate/models/portilla_simoncelli.py
index 657d8ee0..297293d7 100644
--- a/src/plenoptic/simulate/models/portilla_simoncelli.py
+++ b/src/plenoptic/simulate/models/portilla_simoncelli.py
@@ -84,20 +84,19 @@ def __init__(
             tight_frame=False,
         )
         self.filterPyr = SteerablePyramidFreq(
-            self.pyr._lomasks[-1].shape[-2:], height=0, order=1,
+            getattr(self.pyr, f'_lomasks_scale_{n_scales-1}').shape[-2:],
+            height=0, order=1,
             tight_frame=False
         )
-        self.unoriented_band_pyrs = [
-            SteerablePyramidFreq(
-                himask.shape[-2:],
+        for i in range(n_scales):
+            pyr = SteerablePyramidFreq(
+                getattr(self.pyr, f'_himasks_scale_{i}').shape[-2:],
                 height=1,
                 order=self.n_orientations - 1,
                 is_complex=False,
                 tight_frame=False,
             )
-            # want to go through these masks backwards
-            for himask in self.pyr._himasks
-        ]
+            setattr(self, f'unoriented_band_pyrs_scale_{i}', pyr)
 
         self.use_true_correlations = use_true_correlations
         self.scales = (
@@ -667,7 +666,7 @@ def _calculate_autocorrelation_skew_kurtosis(self):
             reconstructed_image = reconstructed_image.unsqueeze(0).unsqueeze(0)
 
             # reconstruct the unoriented band for this scale
-            unoriented_band_pyr = self.unoriented_band_pyrs[this_scale]
+            unoriented_band_pyr = getattr(self, f'unoriented_band_pyrs_scale_{this_scale}')
             unoriented_pyr_coeffs = unoriented_band_pyr.forward(reconstructed_image)
             for ii in range(self.n_orientations):
                 unoriented_pyr_coeffs[(0, ii)] = (
@@ -1177,44 +1176,3 @@ def update_plot(self, axes, batch_idx=0, data=None):
                 sc = update_stem(ax.containers[0], vals)
                 stem_artists.extend([sc.markerline, sc.stemlines])
         return stem_artists
-
-    def to(self, *args, **kwargs):
-        r"""Moves and/or casts the parameters and buffers.
-
-        This can be called as
-
-        .. function:: to(device=None, dtype=None, non_blocking=False)
-
-        .. function:: to(dtype, non_blocking=False)
-
-        .. function:: to(tensor, non_blocking=False)
-
-        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
-        floating point desired :attr:`dtype` s. In addition, this method will
-        only cast the floating point parameters and buffers to :attr:`dtype`
-        (if given). The integral parameters and buffers will be moved
-        :attr:`device`, if that is given, but with dtypes unchanged. When
-        :attr:`non_blocking` is set, it tries to convert/move asynchronously
-        with respect to the host if possible, e.g., moving CPU Tensors with
-        pinned memory to CUDA devices.
-
-        See below for examples.
-
-        .. note::
-            This method modifies the module in-place.
-        Args:
-            device (:class:`torch.device`): the desired device of the parameters
-                and buffers in this module
-            dtype (:class:`torch.dtype`): the desired floating point type of
-                the floating point parameters and buffers in this module
-            tensor (torch.Tensor): Tensor whose dtype and device are the desired
-                dtype and device for all parameters and buffers in this module
-
-        Returns:
-            Module: self
-        """
-        self.pyr = self.pyr.to(*args, **kwargs)
-        self.filterPyr = self.filterPyr.to(*args, **kwargs)
-        self.unoriented_band_pyrs = [pyr.to(*args, **kwargs) for pyr in
-                                     self.unoriented_band_pyrs]
-        return self
diff --git a/tests/test_models.py b/tests/test_models.py
index e8e20f6f..14890f2d 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -130,6 +130,34 @@ def test_match_pyrtools(self, curie_img, n_scales):
        # after upsampling up to one row/column. This causes inconsistency on the right and
        # bottom edges, so they are exluded in the comparison.
 
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda(self, einstein_img):
+        lpyr = po.simul.LaplacianPyramid()
+        lpyr.cuda()
+        lpyr(einstein_img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cpu_and_back(self, einstein_img):
+        lpyr = po.simul.LaplacianPyramid()
+        lpyr.cpu()
+        lpyr.cuda()
+        lpyr(einstein_img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda_and_back(self, einstein_img):
+        lpyr = po.simul.LaplacianPyramid()
+        lpyr.cuda()
+        lpyr.cpu()
+        lpyr(einstein_img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        einstein_img.to(DEVICE)
+
+    def test_cpu(self, einstein_img):
+        lpyr = po.simul.LaplacianPyramid()
+        lpyr.cpu()
+        lpyr(einstein_img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        einstein_img.to(DEVICE)
 
 
 class TestFrontEnd:
@@ -168,6 +196,40 @@ def test_frontend_display_filters(self, model):
         fig = model.display_filters()
         plt.close(fig)
 
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda(self, model, einstein_img):
+        model.cuda()
+        model(einstein_img)
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cpu_and_back(self, model, einstein_img):
+        model.cpu()
+        model.cuda()
+        model(einstein_img)
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda_and_back(self, model, einstein_img):
+        model.cuda()
+        model.cpu()
+        model(einstein_img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+        einstein_img.to(DEVICE)
+
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    def test_cpu(self, model, einstein_img):
+        model.cpu()
+        model(einstein_img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+        einstein_img.to(DEVICE)
 
 
 class TestNaive(object):
@@ -214,6 +276,40 @@ def test_linear(self, basic_stim):
         model = po.simul.Linear().to(DEVICE)
         assert model(basic_stim).requires_grad
 
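+    # device-movement tests: buffers registered in __init__ (e.g. the cached
+    # _filt) should follow plain .cuda()/.cpu() calls with no custom to()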
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda(self, model, einstein_img):
+        model.cuda()
+        model(einstein_img)
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cpu_and_back(self, model, einstein_img):
+        model.cpu()
+        model.cuda()
+        model(einstein_img)
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda_and_back(self, model, einstein_img):
+        model.cuda()
+        model.cpu()
+        model(einstein_img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+        einstein_img.to(DEVICE)
+
+    @pytest.mark.parametrize("model", all_models, indirect=True)
+    def test_cpu(self, model, einstein_img):
+        model.cpu()
+        model(einstein_img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        model.to(DEVICE)
+        einstein_img.to(DEVICE)
 
 
 class TestPortillaSimoncelli(object):
     @pytest.mark.parametrize("n_scales", [1, 2, 3, 4])
@@ -423,6 +519,33 @@ def test_ps_expand(self, im_shape):
 
         assert out_im.shape == (im_shape[0] * mult, im_shape[1] * mult)
 
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda(self, einstein_img):
+        ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
+        ps.cuda()
+        ps(einstein_img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cpu_and_back(self, einstein_img):
+        ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
+        ps.cpu()
+        ps.cuda()
+        ps(einstein_img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda_and_back(self, einstein_img):
+        ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
+        ps.cuda()
+        ps.cpu()
+        ps(einstein_img.cpu())
+        einstein_img.to(DEVICE)
+
+    def test_cpu(self, einstein_img):
+        ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
+        ps.cpu()
+        ps(einstein_img.cpu())
+        einstein_img.to(DEVICE)
 
 
 class TestFilters:
     @pytest.mark.parametrize("std", [5., torch.tensor(1., device=DEVICE), -1., 0.])
diff --git a/tests/test_steerable_pyr.py b/tests/test_steerable_pyr.py
index d5d12315..b8f2b5f3 100644
--- a/tests/test_steerable_pyr.py
+++ b/tests/test_steerable_pyr.py
@@ -290,3 +290,43 @@ def test_order_values(self, img, order):
         with expectation:
             pyr = po.simul.SteerablePyramidFreq(img.shape[-2:], order=order).to(DEVICE)
             pyr(img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda(self, img):
+        pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
+        pyr.cuda()
+        pyr(img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cpu_and_back(self, img):
+        pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
+        pyr.cpu()
+        pyr.cuda()
+        pyr(img)
+
+    @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
+    def test_cuda_and_back(self, img):
+        pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
+        pyr.cuda()
+        pyr.cpu()
+        pyr(img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        img.to(DEVICE)
+
+    def test_cpu(self, img):
+        pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
+        pyr.cpu()
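+        # .cpu() moves the registered mask buffers in-place, so the
+        # forward call below must succeed on a CPU tensor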
+        pyr(img.cpu())
+        # make sure it ends on the same device it started on, since it's a fixture
+        img.to(DEVICE)
+
+    @pytest.mark.parametrize('order', range(1, 16))
+    def test_buffers(self, order):
+        pyr = po.simul.SteerablePyramidFreq((256, 256), order=order)
+        buffers = [k for k, _ in pyr.named_buffers()]
+        names = ['lo0mask', 'hi0mask']
+        for s in range(pyr.num_scales):
+            names.extend([f'_himasks_scale_{s}', f'_lomasks_scale_{s}',
+                          f'_anglemasks_scale_{s}', f'_anglemasks_recon_scale_{s}'])
+        assert len(buffers) == len(names), "pyramid doesn't have the right number of buffers!"
+        assert set(buffers) == set(names), "pyramid doesn't have the right buffers!"
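
For reference, a minimal sketch of the pattern this patch adopts; MaskedModule
and its mask_scale_{i} buffer names are hypothetical stand-ins for the
per-scale masks above, and only torch is assumed. Tensors registered via
register_buffer are moved in-place by the inherited nn.Module.to(), .cuda(),
and .cpu() and appear in state_dict(), which is what makes the hand-written
to() overrides deleted above unnecessary.

import torch
import torch.nn as nn


class MaskedModule(nn.Module):
    """Toy module holding per-scale masks as buffers, not parameters."""

    def __init__(self, n_scales=2):
        super().__init__()
        self.n_scales = n_scales
        for i in range(n_scales):
            # buffers receive no gradients, but they follow the module
            # across devices and dtypes and are saved in state_dict()
            self.register_buffer(f'mask_scale_{i}', torch.rand(4, 4))

    def forward(self, x):
        # look the buffers up by name, as the patch does with getattr
        return [x * getattr(self, f'mask_scale_{i}')
                for i in range(self.n_scales)]


m = MaskedModule()
m.to(torch.float64)                      # casts every registered buffer in-place
assert m.mask_scale_0.dtype == torch.float64
assert 'mask_scale_1' in m.state_dict()  # buffers are serialized automatically
if torch.cuda.is_available():
    m.cuda()                             # ...and moved, with no custom to()
    assert m.mask_scale_0.is_cuda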