10
10
"""
11
11
from attr import attrs , attrib
12
12
import numpy as np
13
+ from six import add_metaclass
14
+ from abc import ABCMeta , abstractmethod
13
15
14
16
15
- @attrs
16
- class Triangle (object ):
17
+ @add_metaclass (ABCMeta )
18
+ class ConvFuncBase (object ):
19
+ """
20
+ Implements truncation (via __call__), numpy array reshaping.
21
+
22
+ Always returns 0 outside truncation radius, i.e.::
23
+
24
+ if np.fabs(x) > trunc:
25
+ conv_func(x)==0 # True
26
+
27
+ Args:
28
+ trunc: truncation radius.
29
+ """
30
+
31
+ def __init__ (self , trunc ):
32
+ self .trunc = trunc
33
+
34
+ @abstractmethod
35
+ def f (self , radius ):
36
+ """The convolution function to be evaluated and truncated"""
37
+ pass
38
+
39
+ def __call__ (self , radius_in_pix ):
40
+ radius_in_pix = np .atleast_1d (radius_in_pix )
41
+ output = np .zeros_like (radius_in_pix )
42
+ inside_trunc_radius = np .fabs (radius_in_pix ) < self .trunc
43
+ output [inside_trunc_radius ] = self .f (radius_in_pix [inside_trunc_radius ])
44
+ return output
45
+
46
+
47
+ class Triangle (ConvFuncBase ):
17
48
"""
18
49
Linearly declines from 1.0 at origin to 0.0 at **half_base_width**, zero thereafter.
19
50
"
@@ -28,17 +59,19 @@ class Triangle(object):
28
59
half_base_width (float): Half-base width of the triangle.
29
60
30
61
"""
31
- half_base_width = attrib ()
32
62
33
- def __call__ (self , radius_in_pix ):
63
+ def __init__ (self , half_base_width ):
64
+ self .half_base_width = half_base_width
65
+ super (Triangle , self ).__init__ (half_base_width )
66
+
67
+ def f (self , radius_in_pix ):
34
68
return np .maximum (
35
69
1.0 - np .fabs (radius_in_pix ) / self .half_base_width ,
36
70
np .zeros_like (radius_in_pix )
37
71
)
38
72
39
73
40
- @attrs
41
- class Pillbox (object ):
74
+ class Pillbox (ConvFuncBase ):
42
75
"""
43
76
Valued 1.0 from origin to **half_base_width**, zero thereafter.
44
77
@@ -53,7 +86,44 @@ class Pillbox(object):
53
86
Attributes:
54
87
half_base_width (float): Half-base width pillbox.
55
88
"""
56
- half_base_width = attrib ()
57
89
58
- def __call__ (self , radius_in_pix ):
90
+ def __init__ (self , half_base_width ):
91
+ self .half_base_width = half_base_width
92
+ super (Pillbox , self ).__init__ (half_base_width )
93
+
94
+ def f (self , radius_in_pix ):
59
95
return np .where (np .fabs (radius_in_pix ) < self .half_base_width , 1.0 , 0.0 )
96
+
97
+
98
+ class Sinc (ConvFuncBase ):
99
+ """
100
+ Sinc function, truncated beyond **trunc** pixels from centre.
101
+
102
+
103
+ Attributes:
104
+ trunc (float): Truncation radius
105
+ """
106
+ trunc = attrib (default = 3.0 )
107
+
108
+ def __init__ (self , trunc ):
109
+ super (Sinc , self ).__init__ (trunc )
110
+
111
+ def f (self , radius_in_pix ):
112
+ return np .sinc (radius_in_pix )
113
+
114
+
115
+ class Sinc (ConvFuncBase ):
116
+ """
117
+ Sinc function, truncated beyond **trunc** pixels from centre.
118
+
119
+
120
+ Attributes:
121
+ trunc (float): Truncation radius
122
+ """
123
+ trunc = attrib (default = 3.0 )
124
+
125
+ def __init__ (self , trunc ):
126
+ super (Sinc , self ).__init__ (trunc )
127
+
128
+ def f (self , radius_in_pix ):
129
+ return np .sinc (radius_in_pix )
0 commit comments