@@ -89,8 +89,10 @@ class AmpAndLengthScaleFn(KernelLayer, ABC):
89
89
"""An ABC for kernels with amplitude and length scale parameters.
90
90
91
91
Attributes:
92
- _amplitude (tf.Tensor): The amplitude of the kernel.
93
- _length_scale (tf.Tensor): The length scale of the kernel.
92
+ _amplitude_basis (tf.Tensor): The basis for the kernel amplitude,
93
+ which is passed through a softplus to calculate the actual amplitude.
94
+ _length_scale_basis (tf.Tensor): The basis for the length scale of the kernel.
95
+ which is passed through a softplus to calculate the actual amplitude.
94
96
95
97
"""
96
98
@@ -99,30 +101,42 @@ def __init__(self, **kwargs):
99
101
super ().__init__ (** kwargs )
100
102
dtype = kwargs .get ("dtype" , tf .float64 )
101
103
102
- self ._amplitude = self .add_weight (
104
+ self ._amplitude_basis = self .add_weight (
103
105
initializer = tf .constant_initializer (0 ), dtype = dtype , name = "amplitude"
104
106
)
105
107
106
- self ._length_scale = self .add_weight (
108
+ self ._length_scale_basis = self .add_weight (
107
109
initializer = tf .constant_initializer (0 ), dtype = dtype , name = "length_scale"
108
110
)
111
+
112
+ @property
113
+ def amplitude (self ) -> float :
114
+ """Get the current kernel amplitude."""
115
+ return tf .nn .softplus (0.1 * self ._amplitude_basis ).numpy ().item ()
116
+
117
+ @property
118
+ def length_scale (self ) -> float :
119
+ """Get the current kernel length scale."""
120
+ return tf .nn .softplus (5.0 * self ._length_scale_basis ).numpy ().item ()
109
121
110
122
111
123
class RBFKernelFn (AmpAndLengthScaleFn ):
112
124
"""A radial basis function implementation that works with keras.
113
125
114
126
Attributes:
115
- _amplitude (tf.Tensor): The amplitude of the kernel.
116
- _length_scale (tf.Tensor): The length scale of the kernel.
127
+ _amplitude_basis (tf.Tensor): The basis for the kernel amplitude,
128
+ which is passed through a softplus to calculate the actual amplitude.
129
+ _length_scale_basis (tf.Tensor): The basis for the length scale of the kernel.
130
+ which is passed through a softplus to calculate the actual amplitude.
117
131
118
132
"""
119
133
120
134
@property
121
135
def kernel (self ) -> tfp .math .psd_kernels .PositiveSemidefiniteKernel :
122
136
"""Get a callable kernel."""
123
137
return tfp .math .psd_kernels .ExponentiatedQuadratic (
124
- amplitude = tf . nn . softplus ( 0.1 * self ._amplitude ) ,
125
- length_scale = tf . nn . softplus ( 5.0 * self ._length_scale ) ,
138
+ amplitude = self .amplitude ,
139
+ length_scale = self .length_scale ,
126
140
)
127
141
128
142
@property
@@ -134,17 +148,19 @@ class MaternOneHalfFn(AmpAndLengthScaleFn):
134
148
"""A Matern kernel with parameter 1/2 implementation that works with keras.
135
149
136
150
Attributes:
137
- _amplitude (tf.Tensor): The amplitude of the kernel.
138
- _length_scale (tf.Tensor): The length scale of the kernel.
151
+ _amplitude_basis (tf.Tensor): The basis for the kernel amplitude,
152
+ which is passed through a softplus to calculate the actual amplitude.
153
+ _length_scale_basis (tf.Tensor): The basis for the length scale of the kernel.
154
+ which is passed through a softplus to calculate the actual amplitude.
139
155
140
156
"""
141
157
142
158
@property
143
159
def kernel (self ) -> tfp .math .psd_kernels .PositiveSemidefiniteKernel :
144
160
"""Get a callable kernel."""
145
161
return tfp .math .psd_kernels .MaternOneHalf (
146
- amplitude = tf . nn . softplus ( 0.1 * self ._amplitude ) ,
147
- length_scale = tf . nn . softplus ( 5.0 * self ._length_scale ) ,
162
+ amplitude = self .amplitude ,
163
+ length_scale = self .length_scale ,
148
164
)
149
165
150
166
@property
0 commit comments