Skip to content

Commit

Permalink
Merge pull request #291 from plenoptic-org/frontend_params
Browse files Browse the repository at this point in the history
Improvements to FrontEnd models
  • Loading branch information
billbrod authored Sep 5, 2024
2 parents 1e0377b + d193bf4 commit a51cd75
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 90 deletions.
14 changes: 10 additions & 4 deletions src/plenoptic/simulate/canonical_computations/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__all__ = ["gaussian1d", "circular_gaussian2d"]


def gaussian1d(kernel_size: int = 11, std: Union[float, Tensor] = 1.5) -> Tensor:
def gaussian1d(kernel_size: int = 11, std: Union[int, float, Tensor] = 1.5) -> Tensor:
"""Normalized 1D Gaussian.
1d Gaussian of size `kernel_size`, centered half-way, with variable std
Expand All @@ -27,9 +27,15 @@ def gaussian1d(kernel_size: int = 11, std: Union[float, Tensor] = 1.5) -> Tensor
filt:
1d Gaussian with `Size([kernel_size])`.
"""
assert std > 0.0, "std must be positive"
if isinstance(std, float):
std = torch.as_tensor(std)
try:
dtype = std.dtype
except AttributeError:
dtype = torch.float32
std = torch.as_tensor(std, dtype=dtype)
if std.numel() != 1:
raise ValueError("std must have only one element!")
if std <= 0:
raise ValueError("std must be positive!")
device = std.device

x = torch.arange(kernel_size).to(device)
Expand Down
Loading

0 comments on commit a51cd75

Please sign in to comment.