-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathslice_profiles.py
152 lines (119 loc) · 4.48 KB
/
slice_profiles.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Slice Profiles."""
import abc
from collections.abc import Sequence
from math import log
import numpy as np
import torch
from torch import Tensor
# Public names exported by `from <this module> import *`.
__all__ = [
    'SliceProfileBase',
    'SliceGaussian',
    'SliceSmoothedRectangular',
    'SliceInterpolate',
]
class SliceProfileBase(abc.ABC, torch.nn.Module):
    """Abstract base class shared by all slice profiles.

    Concrete profiles are `torch.nn.Module`s and must implement `forward`,
    which maps position(s) to profile value(s).
    """

    @abc.abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Return the value of the slice profile at position(s) `x`.

        Parameters
        ----------
        x
            Position(s) at which the profile is evaluated.

        Returns
        -------
            Profile value / intensity at `x`.
        """
        raise NotImplementedError

    def random_sample(self, size: Sequence[int]) -> Tensor:
        """Draw random positions, treating the profile as a probability density.

        The base class provides no implementation; calling this method
        raises `NotImplementedError`.

        Parameters
        ----------
        size
            Shape of the returned tensor of sampled positions.

        Returns
        -------
            Sampled positions; the result has shape `size`.
        """
        raise NotImplementedError
class SliceGaussian(SliceProfileBase):
    """Gaussian slice profile with a peak value of 1 at x = 0."""

    def __init__(self, fwhm: float | Tensor):
        """Initialize the Gaussian slice profile.

        Parameters
        ----------
        fwhm
            Full width at half maximum of the Gaussian.
            Registered as a module buffer.
        """
        super().__init__()
        self.register_buffer('fwhm', torch.as_tensor(fwhm))

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the Gaussian slice profile at a position.

        Parameters
        ----------
        x
            Position at which to evaluate the profile

        Returns
        -------
            Value of the profile / intensity at the given position.
            Equals 1 at x = 0 and 0.5 at |x| = fwhm / 2.
        """
        # exp(-4 ln2 * (x / fwhm)^2) is the Gaussian whose full width at half
        # maximum is exactly `fwhm`. A rounded constant such as 0.36 in place of
        # 1 / (4 ln 2) ~= 0.3607 would shift the half-maximum point slightly.
        # This form matches the pure-Gaussian branch of SliceSmoothedRectangular.
        return torch.exp(-4 * log(2) * (x / self.fwhm) ** 2)
class SliceSmoothedRectangular(SliceProfileBase):
    """Rectangular slice profile with smoothed flanks.

    Implemented as the convolution of a rectangular profile with a Gaussian,
    scaled so that the profile equals 1 at x = 0.
    """

    def __init__(self, fwhm_rect: float | Tensor, fwhm_gauss: float | Tensor):
        """Initialize the smoothed rectangular slice profile.

        Parameters
        ----------
        fwhm_rect
            Full width at half maximum of the rectangular profile.
            Set to zero to obtain a pure Gaussian profile.
        fwhm_gauss
            Full width at half maximum of the Gaussian profile.
            Set to zero to disable smoothing.
        """
        super().__init__()
        self.register_buffer('fwhm_rect', torch.as_tensor(fwhm_rect))
        self.register_buffer('fwhm_gauss', torch.as_tensor(fwhm_gauss))

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the smoothed rectangular slice profile at a position.

        Parameters
        ----------
        x
            Position at which to evaluate the profile

        Returns
        -------
            Value of the profile / intensity at the given position.
            The result is floating point in every branch.

        Raises
        ------
        ValueError
            If neither `fwhm_rect` nor `fwhm_gauss` is greater than zero.
        """
        rect_width = self.fwhm_rect
        gauss_width = self.fwhm_gauss

        if rect_width > 0:
            # Position relative to the half width of the rectangle:
            # |scaled| <= 1 lies inside the rectangle.
            scaled = x * 2 / rect_width
            if gauss_width > 0:
                # Rectangle convolved with a Gaussian gives a sum of two erfs.
                # n = sqrt(ln 2) * fwhm_rect / fwhm_gauss.
                n = (log(2) ** 0.5) * rect_width / gauss_width
                # Scale so that the profile is exactly 1 at x = 0.
                norm = 1 / (2 * torch.erf(n))
                return (torch.erf(n * (1 - scaled)) + torch.erf(n * (1 + scaled))) * norm
            # No smoothing: hard rectangle. Keep the floating dtype of `scaled`
            # (as the other branches do) instead of forcing float32.
            return (scaled.abs() <= 1).to(scaled.dtype)

        if gauss_width > 0:
            # Degenerate rectangle: pure Gaussian with the given FWHM.
            return torch.exp(-4 * log(2) * (x / gauss_width) ** 2)

        raise ValueError('At least one of the widths has to be greater zero.')
class SliceInterpolate(SliceProfileBase):
    """Slice profile based on linear interpolation of a measured profile.

    Outside the range of the measured positions the profile is zero.
    """

    def __init__(self, positions: Tensor, values: Tensor):
        """Initialize the interpolated slice profile.

        Parameters
        ----------
        positions
            Positions of the measured profile (1-D, as required by
            `numpy.interp`); they do not need to be sorted.
        values
            Intensities of the measured profile at `positions`.

        Raises
        ------
        ValueError
            If `positions` and `values` do not have the same shape.
        """
        super().__init__()
        xs = positions.detach().cpu().float().numpy()
        weights = values.detach().cpu().float().numpy()
        if xs.shape != weights.shape:
            raise ValueError('positions and values must have the same shape.')
        # np.interp does not check that the sample positions are increasing and
        # silently returns wrong values otherwise, so sort the samples once here.
        order = np.argsort(xs, kind='stable')
        self._xs = xs[order]
        self._weights = weights[order]

    def forward(self, x: Tensor) -> Tensor:
        """Evaluate the interpolated slice profile at a position.

        Parameters
        ----------
        x
            Position at which to evaluate the profile

        Returns
        -------
            Linearly interpolated value of the profile / intensity at the given
            position, 0 outside the measured range. The result is on the device
            of `x` and has the dtype of `x` if `x` is floating point
            (float64 otherwise).

        Raises
        ------
        NotImplementedError
            If `x` requires gradients; the interpolation is done in NumPy.
        """
        if x.requires_grad:
            raise NotImplementedError('Interpolated profile does not support gradients.')
        x_np = x.detach().cpu().numpy()
        y = torch.as_tensor(np.interp(x_np, self._xs, self._weights, left=0, right=0))
        # np.interp always yields float64. Return the caller's floating dtype so
        # the result is consistent with the other slice profiles.
        out_dtype = x.dtype if x.is_floating_point() else y.dtype
        return y.to(device=x.device, dtype=out_dtype)