Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 20, 2024
1 parent c6df6b5 commit 15d0032
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 76 deletions.
51 changes: 30 additions & 21 deletions src/plenoptic/simulate/canonical_computations/weighted_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@


class SimpleAverage(torch.nn.Module):
"""Module to average over the last two dimensions of input
"""
"""Module to average over the last two dimensions of input"""

def __init__(self):
super().__init__()

def forward(self, x: Tensor):
"""Average over the last two dimensions of input.
"""
"""Average over the last two dimensions of input."""
return torch.mean(x, dim=(-2, -1))


Expand Down Expand Up @@ -47,6 +46,7 @@ class WeightedAverage(torch.nn.Module):
initialization) containing the normalized weights.
"""

def __init__(self, weights_1: Tensor, weights_2: Tensor | None = None):
super().__init__()
self._validate_weights(weights_1, "_1")
Expand All @@ -57,7 +57,9 @@ def __init__(self, weights_1: Tensor, weights_2: Tensor | None = None):
if weights_2 is not None:
self._validate_weights(weights_2, "_2")
if weights_1.shape[-2:] != weights_2.shape[-2:]:
raise ValueError("weights_1 and weights_2 must have same height and width!")
raise ValueError(
"weights_1 and weights_2 must have same height and width!"
)
self._n_weights += 1
input_einsum += ", w2 h w"
output_einsum += " w2"
Expand Down Expand Up @@ -94,11 +96,15 @@ def _normalize_weights(self):
"""
weight_sums = self.sum_weights()
if torch.isclose(weight_sums, torch.zeros_like(weight_sums)).any():
raise ValueError("Some of the weights sum to zero! This will not work out well.")
raise ValueError(
"Some of the weights sum to zero! This will not work out well."
)
var = weight_sums.var()
if not torch.isclose(var, torch.zeros_like(var)):
warnings.warn("Looks like there's some variation in the sums across your weights."
" That might be fine, but just wanted to make sure you knew...")
warnings.warn(
"Looks like there's some variation in the sums across your weights."
" That might be fine, but just wanted to make sure you knew..."
)
mode = torch.mode(weight_sums.flatten()).values
if not torch.isclose(mode, torch.ones_like(mode)):
warnings.warn("Weights don't sum to 1, normalizing...")
Expand All @@ -109,7 +115,7 @@ def _normalize_weights(self):
self._weights_2 = self._weights_2 / mode.sqrt()

@staticmethod
def _validate_weights(weights: Tensor, idx: 'str' = '_1'):
def _validate_weights(weights: Tensor, idx: "str" = "_1"):
if weights.ndim != 3:
raise ValueError(f"weights{idx} must be 3d!")
if weights.min() < 0:
Expand All @@ -133,15 +139,11 @@ def forward(self, image: Tensor) -> Tensor:
"""
if image.ndim < 4:
raise ValueError(
"image must be a tensor of 4 to 6 dimensions!"
)
raise ValueError("image must be a tensor of 4 to 6 dimensions!")
try:
extra_dims = self._extra_dims[image.ndim - 4]
except IndexError:
raise ValueError(
"image must be a tensor of 4 to 6 dimensions!"
)
raise ValueError("image must be a tensor of 4 to 6 dimensions!")
einsum_str = self._forward_einsum.format(extra_dims=extra_dims)
return einops.einsum(*self.weights, image, einsum_str).flatten(2, 3)

Expand Down Expand Up @@ -213,19 +215,24 @@ class WeightedAveragePyramid(torch.nn.Module):
ModuleList of ``WeightedAverage`` at each scale
"""
def __init__(self, weights_1: Tensor,
weights_2: Tensor | None = None,
n_scales: int = 4):

def __init__(
self, weights_1: Tensor, weights_2: Tensor | None = None, n_scales: int = 4
):
super().__init__()
self._n_weights = 1 if weights_2 is None else 2
self.n_scales = n_scales
weights = []
for i in range(n_scales):
if i != 0:
weights_1 = blur_downsample(weights_1.unsqueeze(0), 1, scale_filter=True)
weights_1 = blur_downsample(
weights_1.unsqueeze(0), 1, scale_filter=True
)
weights_1 = weights_1.squeeze(0).clip(min=0)
if weights_2 is not None:
weights_2 = blur_downsample(weights_2.unsqueeze(0), 1, scale_filter=True)
weights_2 = blur_downsample(
weights_2.unsqueeze(0), 1, scale_filter=True
)
weights_2 = weights_2.squeeze(0).clip(min=0)
# it's possible negative values will get introduced by the downsampling
# above, in which case we remove them
Expand Down Expand Up @@ -277,7 +284,9 @@ def einsum(self, einsum_str: str, *tensors: list[Tensor]) -> list[Tensor]:
The result of this einsum. Scales are stacked along last dimension.
"""
return torch.stack([w.einsum(einsum_str, *x) for *x, w in zip(*tensors, self.weights)], dim=-1)
return torch.stack(
[w.einsum(einsum_str, *x) for *x, w in zip(*tensors, self.weights)], dim=-1
)

def sum_weights(self) -> Tensor:
"""Sum weights, largely used for diagnostic purposes.
Expand Down
Loading

0 comments on commit 15d0032

Please sign in to comment.