Skip to content

Commit 4298d05

Browse files
committed
Clarify amplitude and length_scale (#32)
1 parent 46b0911 commit 4298d05

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

tests/test_kernels.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,22 @@ def test_custom_kernel_reloading(tmp_path: Path):
4141
@pytest.mark.parametrize("kernel_type", [RBFKernelFn, MaternOneHalfFn])
4242
def test_reload(tmp_path: Path, kernel_type: Type[AmpAndLengthScaleFn]):
4343
"""Test saving and reloading a builtin kernel to/from disk."""
44-
# Example layer weights: [_amplitude, _length_scale]
44+
# Example layer weights: [_amplitude_basis, _length_scale_basis]
4545
example_weights = [np.array(2.0), np.array(3.0)]
4646
save_dir = tmp_path / "kernel"
4747
orig_kernel = kernel_type()
4848
orig_kernel.set_weights(example_weights)
4949
orig_kernel.save(save_dir)
5050

51-
loaded_kernel = load_kernel(save_dir)
51+
orig_amplitude = orig_kernel.amplitude
52+
orig_length_scale = orig_kernel.length_scale
53+
54+
loaded_kernel: AmpAndLengthScaleFn = load_kernel(save_dir)
5255
assert isinstance(loaded_kernel, kernel_type)
5356
loaded_weights = loaded_kernel.get_weights()
5457
assert loaded_weights == example_weights
58+
59+
loaded_amp = loaded_kernel.amplitude
60+
loaded_ls = loaded_kernel.length_scale
61+
assert loaded_amp == pytest.approx(orig_amplitude)
62+
assert loaded_ls == pytest.approx(orig_length_scale)

unlocknn/kernel_layers.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ class AmpAndLengthScaleFn(KernelLayer, ABC):
8989
"""An ABC for kernels with amplitude and length scale parameters.
9090
9191
Attributes:
92-
_amplitude (tf.Tensor): The amplitude of the kernel.
93-
_length_scale (tf.Tensor): The length scale of the kernel.
92+
_amplitude_basis (tf.Tensor): The basis for the kernel amplitude,
93+
which is passed through a softplus to calculate the actual amplitude.
94+
_length_scale_basis (tf.Tensor): The basis for the length scale of the kernel.
95+
which is passed through a softplus to calculate the actual amplitude.
9496
9597
"""
9698

@@ -99,30 +101,42 @@ def __init__(self, **kwargs):
99101
super().__init__(**kwargs)
100102
dtype = kwargs.get("dtype", tf.float64)
101103

102-
self._amplitude = self.add_weight(
104+
self._amplitude_basis = self.add_weight(
103105
initializer=tf.constant_initializer(0), dtype=dtype, name="amplitude"
104106
)
105107

106-
self._length_scale = self.add_weight(
108+
self._length_scale_basis = self.add_weight(
107109
initializer=tf.constant_initializer(0), dtype=dtype, name="length_scale"
108110
)
111+
112+
@property
113+
def amplitude(self) -> float:
114+
"""Get the current kernel amplitude."""
115+
return tf.nn.softplus(0.1 * self._amplitude_basis).numpy().item()
116+
117+
@property
118+
def length_scale(self) -> float:
119+
"""Get the current kernel length scale."""
120+
return tf.nn.softplus(5.0 * self._length_scale_basis).numpy().item()
109121

110122

111123
class RBFKernelFn(AmpAndLengthScaleFn):
112124
"""A radial basis function implementation that works with keras.
113125
114126
Attributes:
115-
_amplitude (tf.Tensor): The amplitude of the kernel.
116-
_length_scale (tf.Tensor): The length scale of the kernel.
127+
_amplitude_basis (tf.Tensor): The basis for the kernel amplitude,
128+
which is passed through a softplus to calculate the actual amplitude.
129+
_length_scale_basis (tf.Tensor): The basis for the length scale of the kernel.
130+
which is passed through a softplus to calculate the actual amplitude.
117131
118132
"""
119133

120134
@property
121135
def kernel(self) -> tfp.math.psd_kernels.PositiveSemidefiniteKernel:
122136
"""Get a callable kernel."""
123137
return tfp.math.psd_kernels.ExponentiatedQuadratic(
124-
amplitude=tf.nn.softplus(0.1 * self._amplitude),
125-
length_scale=tf.nn.softplus(5.0 * self._length_scale),
138+
amplitude=self.amplitude,
139+
length_scale=self.length_scale,
126140
)
127141

128142
@property
@@ -134,17 +148,19 @@ class MaternOneHalfFn(AmpAndLengthScaleFn):
134148
"""A Matern kernel with parameter 1/2 implementation that works with keras.
135149
136150
Attributes:
137-
_amplitude (tf.Tensor): The amplitude of the kernel.
138-
_length_scale (tf.Tensor): The length scale of the kernel.
151+
_amplitude_basis (tf.Tensor): The basis for the kernel amplitude,
152+
which is passed through a softplus to calculate the actual amplitude.
153+
_length_scale_basis (tf.Tensor): The basis for the length scale of the kernel.
154+
which is passed through a softplus to calculate the actual amplitude.
139155
140156
"""
141157

142158
@property
143159
def kernel(self) -> tfp.math.psd_kernels.PositiveSemidefiniteKernel:
144160
"""Get a callable kernel."""
145161
return tfp.math.psd_kernels.MaternOneHalf(
146-
amplitude=tf.nn.softplus(0.1 * self._amplitude),
147-
length_scale=tf.nn.softplus(5.0 * self._length_scale),
162+
amplitude=self.amplitude,
163+
length_scale=self.length_scale,
148164
)
149165

150166
@property

0 commit comments

Comments
 (0)