-
Notifications
You must be signed in to change notification settings - Fork 28
/
kernels.py
1828 lines (1595 loc) · 69 KB
/
kernels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import objax
from jax import vmap
import jax.numpy as np
from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
from warnings import warn
from tensorflow_probability.substrates.jax.math import bessel_ive
FACTORIALS = np.array([1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800])  # 0! .. 10!


def factorial(i):
    """
    Return i! for a non-negative integer i.

    Arguments 0 <= i <= 10 are served from the FACTORIALS lookup table (this also keeps the function
    usable with JAX integer indices). Plain Python ints outside the table are computed exactly with
    math.factorial instead: JAX clamps out-of-bounds indices (and wraps negative ones), so FACTORIALS[11]
    would silently return 10! rather than fail, and from 13! on the value would not fit the table's
    integer dtype anyway (int32 under JAX's default 32-bit mode).

    :param i: non-negative integer (Python int, or a JAX integer index in [0, 10])
    :return: i! — a JAX scalar for table hits, an exact Python int otherwise
    :raises ValueError: if i is a negative Python int
    """
    if isinstance(i, int) and not 0 <= i < FACTORIALS.shape[0]:
        import math  # exact arbitrary-precision factorial for arguments beyond the table
        return math.factorial(i)
    return FACTORIALS[i]
def coeff(j, lengthscale, order):
    """
    Coefficient q_j² of the truncated sum-of-cosines expansion of the periodic kernel.

    Can be used to generate co-efficients for the quasi-periodic kernels that guarantee a valid covariance function:
        q2 = np.array([coeff(j, lengthscale, order) for j in range(order + 1)])
    See eq (26) of [1]. The sum below is the order-truncated power series of the exponentially scaled modified
    Bessel function ive(j, lengthscale ** -2); the extra factor of 2 for j > 0 that the Bessel-based kernels in
    this module apply is NOT included here.
    Not currently used (we use Bessel functions instead for clarity).
    [1] Solin & Sarkka (2014) "Explicit Link Between Periodic Covariance Functions and State Space Models".

    :param j: harmonic index (non-negative Python int)
    :param lengthscale: lengthscale l of the periodic kernel
    :param order: truncation order J of the expansion (Python int)
    :return: exp(-l⁻²) * Σ_{i=0}^{⌊(J-j)/2⌋} (2l²)^-(j+2i) / ((j+i)! i!)
    """
    import math

    # Exact Python-int factorials rather than the FACTORIALS lookup table: that table stops at 10!
    # (larger JAX indices are clamped, silently returning 10!) and products of its int32 entries
    # overflow from 9! * 9! upwards.
    two_l2 = 2 * lengthscale ** 2
    num_terms = int((order + 2 - j) / 2)  # i = 0, ..., floor((order - j) / 2)
    s = sum(
        two_l2 ** -(j + 2 * i) / float(math.factorial(j + i) * math.factorial(i))
        for i in range(num_terms)
    )
    return 1 / np.exp(lengthscale ** -2) * s
class Kernel(objax.Module):
    """
    Base class for all kernels in this module.

    Subclasses supply the covariance function K(X, X2) and, where available, the pieces of an
    equivalent state space (SDE) model: feedback_matrix(), measurement_model(), state_transition(), ...
    Calling a kernel instance is shorthand for evaluating K.
    """

    def __call__(self, X, X2):
        # kernel(X, X2) is an alias for kernel.K(X, X2)
        return self.K(X, X2)

    def K(self, X, X2):
        """Covariance between inputs X and X2; subclasses must override."""
        raise NotImplementedError('kernel function not implemented')

    def measurement_model(self):
        """State space measurement matrix H; subclasses must override."""
        raise NotImplementedError

    def inducing_precision(self):
        """Default for kernels without inducing inputs: both returned quantities are None."""
        return None, None

    def kernel_to_state_space(self, R=None):
        """Continuous-time state space matrices; subclasses must override."""
        raise NotImplementedError

    def spatial_conditional(self, R=None, predict=False):
        """
        Default for kernels without a spatial component: there is no spatial conditional,
        so (None, None) is returned.
        """
        return None, None

    def get_meanfield_block_index(self):
        """Overridden only by kernels that support a mean-field approximation."""
        raise Exception("Either the mean-field method is not applicable to this kernel, "
                        "or this kernel's get_meanfield_block_index() method has not been implemented")

    def feedback_matrix(self):
        """Continuous-time feedback matrix F; subclasses must override."""
        raise NotImplementedError

    def state_transition(self, dt):
        """
        Generic discrete-time transition matrix A = expm(F Δt), evaluated numerically from
        feedback_matrix(). Kernels with a closed form override this.

        :param dt: step size Δt
        :return: state transition matrix A
        """
        # TODO(32): fix prediction when using expm to compute the state transition.
        return expm(self.feedback_matrix() * dt)
class StationaryKernel(Kernel):
    """
    Base class for stationary kernels: the covariance depends on the inputs only through the
    lengthscale-scaled Euclidean distance r between them. Subclasses implement K_r(r) and the
    state space matrices.

    :param variance: kernel variance σ² (stored through a softplus transform)
    :param lengthscale: lengthscale ℓ (stored through a softplus transform)
    :param fix_variance: if True the variance is held in a StateVar (not optimised) rather than a TrainVar
    :param fix_lengthscale: if True the lengthscale is held in a StateVar rather than a TrainVar
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale=1.0,
                 fix_variance=False,
                 fix_lengthscale=False):
        # fixed hyperparameters go in StateVars, trainable ones in TrainVars
        lengthscale_var_cls = objax.StateVar if fix_lengthscale else objax.TrainVar
        variance_var_cls = objax.StateVar if fix_variance else objax.TrainVar
        self.transformed_lengthscale = lengthscale_var_cls(softplus_inv(np.array(lengthscale)))
        self.transformed_variance = variance_var_cls(softplus_inv(np.array(variance)))

    @property
    def variance(self):
        """σ², mapped back to the positive reals."""
        return softplus(self.transformed_variance.value)

    @property
    def lengthscale(self):
        """ℓ, mapped back to the positive reals."""
        return softplus(self.transformed_lengthscale.value)

    def K(self, X, X2):
        """Evaluate the kernel via the squared scaled distance between X and X2."""
        return self.K_r2(scaled_squared_euclid_dist(X, X2, self.lengthscale))

    def K_r2(self, r2):
        """Evaluate the kernel from squared scaled distances r²."""
        # Floor r² at 1e-36 before the square root so r (and the derivative of sqrt)
        # stays finite when two inputs coincide.
        return self.K_r(np.sqrt(np.maximum(r2, 1e-36)))

    @staticmethod
    def K_r(r):
        """Kernel as a function of the scaled distance r; implemented by subclasses."""
        raise NotImplementedError('kernel not implemented')

    def kernel_to_state_space(self, R=None):
        raise NotImplementedError

    def measurement_model(self):
        raise NotImplementedError

    def stationary_covariance(self):
        raise NotImplementedError

    def feedback_matrix(self):
        raise NotImplementedError
class Matern12(StationaryKernel):
    """
    The Matern 1/2 kernel. Functions drawn from a GP with this kernel are not
    differentiable anywhere. The kernel equation is
    k(r) = σ² exp{-r}
    where:
    r is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ.
    σ² is the variance parameter
    Its state space form is a scalar Ornstein-Uhlenbeck process with F = -1/ℓ.
    """

    @property
    def state_dim(self):
        return 1

    def K_r(self, r):
        return self.variance * np.exp(-r)

    def kernel_to_state_space(self, R=None):
        """Return (F, L, Qc, H, Pinf), each a [1, 1] matrix."""
        F = self.feedback_matrix()
        L = np.array([[1.0]])
        Qc = np.array([[2.0 * self.variance / self.lengthscale]])
        H = self.measurement_model()
        Pinf = self.stationary_covariance()
        return F, L, Qc, H, Pinf

    def stationary_covariance(self):
        """Stationary state covariance Pinf = [[σ²]]."""
        return np.array([[self.variance]])

    def measurement_model(self):
        """Measurement matrix H = [[1]]: the state is the function value itself."""
        return np.array([[1.0]])

    def state_transition(self, dt):
        """
        Calculation of the discrete-time state transition matrix A = expm(FΔt) for the exponential prior.
        :param dt: step size(s), Δtₙ = tₙ - tₙ₋₁ [scalar]
        :return: state transition matrix A [1, 1]
        """
        decay = np.exp(-dt / self.lengthscale)
        return np.broadcast_to(decay, [1, 1])

    def feedback_matrix(self):
        """Feedback matrix F = [[-1/ℓ]]."""
        return np.array([[-1.0 / self.lengthscale]])


Exponential = Matern12
class Matern32(StationaryKernel):
    """
    The Matern 3/2 kernel. Functions drawn from a GP with this kernel are once
    differentiable. The kernel equation is
    k(r) = σ² (1 + √3r) exp{-√3 r}
    where:
    r is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ,
    σ² is the variance parameter.
    The state space form has a 2-dimensional state (function value and its derivative), with λ = √3/ℓ.
    """

    @property
    def state_dim(self):
        return 2

    def K_r(self, r):
        root3_r = np.sqrt(3.0) * r
        return self.variance * (1.0 + root3_r) * np.exp(-root3_r)

    def kernel_to_state_space(self, R=None):
        """Return the continuous-time model (F, L, Qc, H, Pinf)."""
        F = self.feedback_matrix()
        L = np.array([[0],
                      [1]])
        # spectral density of the driving white noise, 4 λ³ σ²
        Qc = np.array([[12.0 * 3.0 ** 0.5 / self.lengthscale ** 3.0 * self.variance]])
        H = self.measurement_model()
        Pinf = self.stationary_covariance()
        return F, L, Qc, H, Pinf

    def stationary_covariance(self):
        """Stationary state covariance diag(σ², 3σ²/ℓ²)."""
        var = self.variance
        return np.array([[var, 0.0],
                         [0.0, 3.0 * var / self.lengthscale ** 2.0]])

    def measurement_model(self):
        """H picks the first state component (the function value)."""
        return np.array([[1.0, 0.0]])

    def state_transition(self, dt):
        """
        Calculation of the discrete-time state transition matrix A = expm(FΔt) for the Matern-3/2 prior.
        :param dt: step size(s), Δtₙ = tₙ - tₙ₋₁ [scalar]
        :return: state transition matrix A [2, 2]
        """
        lam = np.sqrt(3.0) / self.lengthscale
        decay = np.exp(-dt * lam)
        poly = dt * np.array([[lam, 1.0], [-lam**2.0, -lam]]) + np.eye(2)
        return decay * poly

    def feedback_matrix(self):
        """Companion-form feedback matrix for the characteristic polynomial (s + λ)²."""
        lam = 3.0 ** 0.5 / self.lengthscale
        return np.array([[0.0, 1.0],
                         [-lam ** 2, -2 * lam]])
class Matern52(StationaryKernel):
    """
    The Matern 5/2 kernel. Functions drawn from a GP with this kernel are twice
    differentiable. The kernel equation is
    k(r) = σ² (1 + √5r + 5/3r²) exp{-√5 r}
    where:
    r is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ,
    σ² is the variance parameter.
    The state space form has a 3-dimensional state (value and first two derivatives), with λ = √5/ℓ.
    """

    @property
    def state_dim(self):
        return 3

    def K_r(self, r):
        root5_r = np.sqrt(5.0) * r
        poly = 1.0 + root5_r + 5.0 / 3.0 * np.square(r)
        return self.variance * poly * np.exp(-root5_r)

    def kernel_to_state_space(self, R=None):
        """Return the continuous-time model (F, L, Qc, H, Pinf)."""
        F = self.feedback_matrix()
        L = np.array([[0.0],
                      [0.0],
                      [1.0]])
        Qc = np.array([[self.variance * 400.0 * 5.0 ** 0.5 / 3.0 / self.lengthscale ** 5.0]])
        H = self.measurement_model()
        Pinf = self.stationary_covariance()
        return F, L, Qc, H, Pinf

    def measurement_model(self):
        """H picks the first state component (the function value)."""
        return np.array([[1.0, 0.0, 0.0]])

    def state_transition(self, dt):
        """
        Calculation of the discrete-time state transition matrix A = expm(FΔt) for the Matern-5/2 prior.
        :param dt: step size(s), Δtₙ = tₙ - tₙ₋₁ [scalar]
        :return: state transition matrix A [3, 3]
        """
        lam = np.sqrt(5.0) / self.lengthscale
        dtlam = dt * lam
        poly = np.array([[lam * (0.5 * dtlam + 1.0), dtlam + 1.0, 0.5 * dt],
                         [-0.5 * dtlam * lam ** 2, lam * (1.0 - dtlam), 1.0 - 0.5 * dtlam],
                         [lam ** 3 * (0.5 * dtlam - 1.0), lam ** 2 * (dtlam - 3), lam * (0.5 * dtlam - 2.0)]])
        return np.exp(-dtlam) * (dt * poly + np.eye(3))

    def stationary_covariance(self):
        """Stationary state covariance; κ = 5σ²/(3ℓ²) couples the value and second derivative."""
        var = self.variance
        kappa = 5.0 / 3.0 * var / self.lengthscale**2.0
        return np.array([[var, 0.0, -kappa],
                         [0.0, kappa, 0.0],
                         [-kappa, 0.0, 25.0*var / self.lengthscale**4.0]])

    def feedback_matrix(self):
        """Companion-form feedback matrix for the characteristic polynomial (s + λ)³."""
        lam = 5.0**0.5 / self.lengthscale
        return np.array([[0.0, 1.0, 0.0],
                         [0.0, 0.0, 1.0],
                         [-lam**3.0, -3.0*lam**2.0, -3.0*lam]])
class Matern72(StationaryKernel):
    """
    The Matern 7/2 kernel. Functions drawn from a GP with this kernel are three times differentiable.
    The kernel equation is
    k(r) = σ² (1 + √7r + 14/5 r² + 7√7/15 r³) exp{-√7 r}
    where:
    r is the Euclidean distance between the input points, scaled by the lengthscales parameter ℓ,
    σ² is the variance parameter.
    The state space form has a 4-dimensional state (value and first three derivatives), with λ = √7/ℓ.
    """

    @property
    def state_dim(self):
        return 4

    def K_r(self, r):
        sqrt7 = np.sqrt(7.0)
        return self.variance * (1. + sqrt7 * r + 14. / 5. * np.square(r) + 7. * sqrt7 / 15. * r**3) * np.exp(-sqrt7 * r)

    def kernel_to_state_space(self, R=None):
        """
        Return the continuous-time model (F, L, Qc, H, Pinf).
        F, H and Pinf come from the dedicated methods so the returned H has the same (float) dtype
        as measurement_model().
        """
        F = self.feedback_matrix()
        L = np.array([[0.0],
                      [0.0],
                      [0.0],
                      [1.0]])
        # spectral density of the driving white noise, 32/5 λ⁷ σ²
        Qc = np.array([[self.variance * 10976.0 * 7.0 ** 0.5 / 5.0 / self.lengthscale ** 7.0]])
        H = self.measurement_model()
        Pinf = self.stationary_covariance()
        return F, L, Qc, H, Pinf

    def measurement_model(self):
        """H picks the first state component (the function value)."""
        H = np.array([[1.0, 0.0, 0.0, 0.0]])
        return H

    def state_transition(self, dt):
        """
        Calculation of the discrete-time state transition matrix A = expm(FΔt) for the Matern-7/2 prior.
        :param dt: step size(s), Δtₙ = tₙ - tₙ₋₁ [scalar]
        :return: state transition matrix A [4, 4]
        """
        lam = np.sqrt(7.0) / self.lengthscale
        lam2 = lam * lam
        lam3 = lam2 * lam
        dtlam = dt * lam
        dtlam2 = dtlam ** 2
        A = np.exp(-dtlam) \
            * (dt * np.array([[lam * (1.0 + 0.5 * dtlam + dtlam2 / 6.0), 1.0 + dtlam + 0.5 * dtlam2,
                               0.5 * dt * (1.0 + dtlam), dt ** 2 / 6],
                              [-dtlam2 * lam ** 2.0 / 6.0, lam * (1.0 + 0.5 * dtlam - 0.5 * dtlam2),
                               1.0 + dtlam - 0.5 * dtlam2, dt * (0.5 - dtlam / 6.0)],
                              [lam3 * dtlam * (dtlam / 6.0 - 0.5), dtlam * lam2 * (0.5 * dtlam - 2.0),
                               lam * (1.0 - 2.5 * dtlam + 0.5 * dtlam2), 1.0 - dtlam + dtlam2 / 6.0],
                              [lam2 ** 2 * (dtlam - 1.0 - dtlam2 / 6.0), lam3 * (3.5 * dtlam - 4.0 - 0.5 * dtlam2),
                               lam2 * (4.0 * dtlam - 6.0 - 0.5 * dtlam2), lam * (1.5 * dtlam - 3.0 - dtlam2 / 6.0)]])
               + np.eye(4))
        return A

    def stationary_covariance(self):
        """Stationary state covariance; κ = 7σ²/(5ℓ²) and κ₂ = 9.8σ²/ℓ⁴ are the off-diagonal couplings."""
        kappa = 7.0 / 5.0 * self.variance / self.lengthscale ** 2.0
        kappa2 = 9.8 * self.variance / self.lengthscale ** 4.0
        Pinf = np.array([[self.variance, 0.0, -kappa, 0.0],
                         [0.0, kappa, 0.0, -kappa2],
                         [-kappa, 0.0, kappa2, 0.0],
                         [0.0, -kappa2, 0.0, 343.0 * self.variance / self.lengthscale ** 6.0]])
        return Pinf

    def feedback_matrix(self):
        """Companion-form feedback matrix for the characteristic polynomial (s + λ)⁴."""
        lam = 7.0 ** 0.5 / self.lengthscale
        F = np.array([[0.0, 1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0, 0.0],
                      [0.0, 0.0, 0.0, 1.0],
                      [-lam ** 4.0, -4.0 * lam ** 3.0, -6.0 * lam ** 2.0, -4.0 * lam]])
        return F
class SpatioTemporalKernel(Kernel):
    """
    The Spatio-Temporal GP class: a separable product of a temporal kernel (used in state space form)
    and a spatial kernel evaluated at M spatial inducing inputs z.
    :param temporal_kernel: the temporal kernel, used in state space form (e.g. Matern32)
    :param spatial_kernel: the kernel used for the spatial dimensions
    :param z: the initial spatial locations, shape [M] or [M, D]
    :param conditional: specifies which method to use for computing the covariance of the spatial conditional;
                        must be one of ['DTC', 'FIC', 'Full'] ('FITC' is accepted for 'FIC'; matching is
                        case-insensitive). Defaults to 'Full' when sparse, otherwise 'DTC'
    :param sparse: boolean specifying whether the model is sparse in space
    :param opt_z: boolean specifying whether to optimise the spatial input locations z
    :param spatial_dims: number of spatial input dimensions D. Required when z is None (a default grid on
                         [-3, 3] is built for D = 1 or D = 2); otherwise inferred from z and checked against it
    """

    def __init__(self,
                 temporal_kernel,
                 spatial_kernel,
                 z=None,
                 conditional=None,
                 sparse=True,
                 opt_z=False,
                 spatial_dims=None):
        self.temporal_kernel = temporal_kernel
        self.spatial_kernel = spatial_kernel
        if conditional is None:
            conditional = 'Full' if sparse else 'DTC'
        if opt_z and (not sparse):  # z should not be optimised if the model is not sparse
            warn("spatial inducing inputs z will not be optimised because sparse=False")
            opt_z = False
        self.sparse = sparse
        if z is None:  # initialise z
            # TODO: smart initialisation
            if spatial_dims == 1:
                z = np.linspace(-3., 3., num=15)
            elif spatial_dims == 2:
                z1 = np.linspace(-3., 3., num=5)
                zA, zB = np.meshgrid(z1, z1)  # 5 x 5 grid of inducing points
                z = np.hstack((zA.reshape(-1, 1), zB.reshape(-1, 1)))  # flatten the grid to shape [25, 2]
            else:
                raise NotImplementedError('please provide an initialisation for inducing inputs z')
        z = np.asarray(z)  # also accept Python lists / NumPy arrays
        if z.ndim < 2:
            z = z[:, np.newaxis]
        # z is [M, D], so the number of spatial dimensions is its trailing size
        # (z.ndim - 1 would always be 1 here, rejecting every 2-D spatial input).
        if spatial_dims is None:
            spatial_dims = z.shape[-1]
        assert spatial_dims == z.shape[-1], \
            f'spatial_dims={spatial_dims} does not match the trailing dimension of z (shape {z.shape})'
        self.M = z.shape[0]
        if opt_z:
            self.z = objax.TrainVar(np.array(z))
        else:
            self.z = objax.StateVar(np.array(z))
        conditional_key = str(conditional).lower()
        if conditional_key == 'dtc':
            self.conditional_covariance = self.deterministic_training_conditional
        elif conditional_key in ('fic', 'fitc'):
            self.conditional_covariance = self.fully_independent_conditional
        elif conditional_key == 'full':
            self.conditional_covariance = self.full_conditional
        else:
            raise NotImplementedError('conditional method not recognised')
        # compare the normalised name so that e.g. 'dtc' does not trigger a spurious warning
        if (not sparse) and (conditional_key != 'dtc'):
            warn("You chose a non-deterministic conditional, but \'DTC\' will be used because the model is not sparse")

    @property
    def variance(self):
        return self.temporal_kernel.variance

    @property
    def temporal_lengthscale(self):
        return self.temporal_kernel.lengthscale

    @property
    def spatial_lengthscale(self):
        return self.spatial_kernel.lengthscale

    @property
    def state_dim(self):
        return self.temporal_kernel.state_dim

    def K(self, X, X2):
        """Separable covariance: column 0 of X / X2 is time, the remaining columns are space."""
        T = X[:, :1]
        T2 = X2[:, :1]
        R = X[:, 1:]
        R2 = X2[:, 1:]
        return self.temporal_kernel(T, T2) * self.spatial_kernel(R, R2)

    @staticmethod
    def deterministic_training_conditional(X, R, Krz, K):
        """DTC: the spatial conditional is treated as noise-free (zero covariance)."""
        cov = np.array([[0.0]])
        return cov

    def fully_independent_conditional(self, X, R, Krz, K):
        """FIC: keep only the diagonal of the conditional spatial covariance Krr - K Krz^T."""
        Krr = self.spatial_kernel(R, R)
        X = X.reshape(-1, 1)
        cov = self.temporal_kernel.K(X, X) * (np.diag(np.diag(Krr - K @ Krz.T)))
        return cov

    def full_conditional(self, X, R, Krz, K):
        """Full: the complete conditional spatial covariance Krr - K Krz^T, scaled by the temporal kernel."""
        Krr = self.spatial_kernel(R, R)
        X = X.reshape(-1, 1)
        cov = self.temporal_kernel.K(X, X) * (Krr - K @ Krz.T)
        return cov

    def spatial_conditional(self, X=None, R=None, predict=False):
        """
        Compute the spatial conditional, i.e. the measurement model projecting the latent function u(t) to f(X,R)
        f(X,R) | u(t) ~ N(f(X,R) | B u(t), C)
        :return: (B, C); in the non-sparse training case B is the Cholesky factor of Kzz and C is zero
        """
        Qzz, Lzz = self.inducing_precision()  # pre-calculate inducing precision and its Cholesky factor
        if self.sparse or predict:
            # TODO: save compute if R is constant:
            # gridded_data = np.all(np.abs(np.diff(R, axis=0)) < 1e-10)
            # if gridded_data:
            #     R = R[:1]
            R = R.reshape((R.shape[0],) + (-1,) + self.z.value.shape[1:])
            Krz = vmap(self.spatial_kernel, [0, None])(R, self.z.value)
            K = Krz @ Qzz  # Krz / Kzz
            B = K @ Lzz
            C = vmap(self.conditional_covariance)(X, R, Krz, K)  # conditional covariance
        else:
            B = Lzz
            # conditional covariance (deterministic mapping is exact in non-sparse case)
            C = np.zeros([B.shape[0], B.shape[0]])
        return B, C

    def inducing_precision(self):
        """
        Compute the covariance and precision of the inducing spatial points to be used during filtering
        :return: (Qzz, Lzz) = (Kzz⁻¹, lower Cholesky factor of Kzz)
        """
        Kzz = self.spatial_kernel(self.z.value, self.z.value)
        Lzz, low = cho_factor(Kzz, lower=True)  # K_zz^(1/2)
        Qzz = cho_solve((Lzz, low), np.eye(self.M))  # K_zz^(-1)
        return Qzz, Lzz

    def stationary_covariance(self):
        """
        Compute the covariance of the stationary state distribution. Since the latent components are independent
        under the prior, this is a block-diagonal matrix
        """
        Pinf_time = self.temporal_kernel.stationary_covariance()
        Pinf = np.kron(np.eye(self.M), Pinf_time)
        return Pinf

    def stationary_covariance_meanfield(self):
        """
        Stationary covariance as a tensor of blocks, as required when using a mean-field assumption
        """
        Pinf_time = self.temporal_kernel.stationary_covariance()
        Pinf = np.tile(Pinf_time, [self.M, 1, 1])
        return Pinf

    def measurement_model(self):
        """
        Compute the spatial conditional, i.e. the measurement model projecting the state x(t) to function space
        f(t, R) = H x(t)
        """
        H_time = self.temporal_kernel.measurement_model()
        H = np.kron(np.eye(self.M), H_time)
        return H

    def state_transition(self, dt):
        """
        Calculation of the discrete-time state transition matrix A = expm(FΔt) for the spatio-temporal prior.
        :param dt: step size(s), Δtₙ = tₙ - tₙ₋₁ [scalar]
        :return: state transition matrix A
        """
        A_time = self.temporal_kernel.state_transition(dt)
        A = np.kron(np.eye(self.M), A_time)
        return A

    def state_transition_meanfield(self, dt):
        """
        State transition matrix in the form required for mean-field inference.
        :param dt: step size(s), Δtₙ = tₙ - tₙ₋₁ [scalar]
        :return: state transition matrix A
        """
        A_time = self.temporal_kernel.state_transition(dt)
        A = np.tile(A_time, [self.M, 1, 1])
        return A

    def kernel_to_state_space(self, R=None):
        F_t, L_t, Qc_t, H_t, Pinf_t = self.temporal_kernel.kernel_to_state_space()
        Kzz = self.spatial_kernel(self.z.value, self.z.value)
        F = np.kron(np.eye(self.M), F_t)
        Qc = None
        L = None
        H = self.measurement_model()
        Pinf = np.kron(Kzz, Pinf_t)
        return F, L, Qc, H, Pinf

    def get_meanfield_block_index(self):
        """
        Indices (as returned by np.where) of the entries of the full state covariance that lie inside
        one of the per-latent diagonal blocks.
        """
        Pinf = self.stationary_covariance_meanfield()
        num_latents = Pinf.shape[0]
        sub_state_dim = Pinf.shape[1]
        # block-diagonal 0/1 mask with one dense [sub_state_dim, sub_state_dim] block per latent,
        # built in a single kron rather than growing it with repeated block_diag calls
        state = np.kron(np.eye(num_latents), np.ones([sub_state_dim, sub_state_dim]))
        block_index = np.where(np.array(state, dtype=bool))
        return block_index

    def feedback_matrix(self):
        F_t = self.temporal_kernel.feedback_matrix()
        F = np.kron(np.eye(self.M), F_t)
        return F
class SpatioTemporalMatern12(SpatioTemporalKernel):
    """
    Spatio-Temporal Matern-1/2 kernel in SDE form.
    Hyperparameters:
        variance, σ²
        temporal lengthscale, lt
        spatial lengthscale, ls
    The spatial kernel's variance is fixed to 1, so σ² is carried by the temporal kernel.
    :param spatial_dims: forwarded to SpatioTemporalKernel, which uses it to build a default z when z is None
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale_time=1.0,
                 lengthscale_space=1.0,
                 z=None,
                 sparse=True,
                 opt_z=False,
                 conditional=None,
                 spatial_dims=None):
        super().__init__(temporal_kernel=Matern12(variance=variance, lengthscale=lengthscale_time),
                         spatial_kernel=Matern12(variance=1., lengthscale=lengthscale_space, fix_variance=True),
                         z=z,
                         conditional=conditional,
                         sparse=sparse,
                         opt_z=opt_z,
                         spatial_dims=spatial_dims)
        self.name = 'Spatio-Temporal Matern-1/2'
class SpatioTemporalMatern32(SpatioTemporalKernel):
    """
    Spatio-Temporal Matern-3/2 kernel in SDE form.
    Hyperparameters:
        variance, σ²
        temporal lengthscale, lt
        spatial lengthscale, ls
    The spatial kernel's variance is fixed to 1, so σ² is carried by the temporal kernel.
    :param spatial_dims: forwarded to SpatioTemporalKernel, which uses it to build a default z when z is None
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale_time=1.0,
                 lengthscale_space=1.0,
                 z=None,
                 sparse=True,
                 opt_z=False,
                 conditional=None,
                 spatial_dims=None):
        super().__init__(temporal_kernel=Matern32(variance=variance, lengthscale=lengthscale_time),
                         spatial_kernel=Matern32(variance=1., lengthscale=lengthscale_space, fix_variance=True),
                         z=z,
                         conditional=conditional,
                         sparse=sparse,
                         opt_z=opt_z,
                         spatial_dims=spatial_dims)
        self.name = 'Spatio-Temporal Matern-3/2'
class SpatioTemporalMatern52(SpatioTemporalKernel):
    """
    Spatio-Temporal Matern-5/2 kernel in SDE form.
    Hyperparameters:
        variance, σ²
        temporal lengthscale, lt
        spatial lengthscale, ls
    The spatial kernel's variance is fixed to 1, so σ² is carried by the temporal kernel.
    :param spatial_dims: forwarded to SpatioTemporalKernel, which uses it to build a default z when z is None
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale_time=1.0,
                 lengthscale_space=1.0,
                 z=None,
                 sparse=True,
                 opt_z=False,
                 conditional=None,
                 spatial_dims=None):
        super().__init__(temporal_kernel=Matern52(variance=variance, lengthscale=lengthscale_time),
                         spatial_kernel=Matern52(variance=1., lengthscale=lengthscale_space, fix_variance=True),
                         z=z,
                         conditional=conditional,
                         sparse=sparse,
                         opt_z=opt_z,
                         spatial_dims=spatial_dims)
        self.name = 'Spatio-Temporal Matern-5/2'
class SpatialMatern12(SpatioTemporalKernel):
    """
    Spatial Matern-1/2 kernel in SDE form. Similar to the spatio-temporal kernel but the
    lengthscale is shared across dimensions.
    Hyperparameters:
        variance, σ²
        lengthscale, l
    :param spatial_dims: forwarded to SpatioTemporalKernel, which uses it to build a default z when z is None
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale=1.0,
                 z=None,
                 sparse=True,
                 opt_z=False,
                 conditional=None,
                 spatial_dims=None):
        super().__init__(temporal_kernel=Matern12(variance=variance, lengthscale=lengthscale),
                         spatial_kernel=Matern12(variance=1., lengthscale=lengthscale, fix_variance=True),
                         z=z,
                         conditional=conditional,
                         sparse=sparse,
                         opt_z=opt_z,
                         spatial_dims=spatial_dims)
        # --- couple the lengthscales: both kernels now read the same variable object ---
        self.spatial_kernel.transformed_lengthscale = self.temporal_kernel.transformed_lengthscale
        # -------------------------------
        self.name = 'Spatial Matern-1/2'
class SpatialMatern32(SpatioTemporalKernel):
    """
    Spatial Matern-3/2 kernel in SDE form. Similar to the spatio-temporal kernel but the
    lengthscale is shared across dimensions.
    Hyperparameters:
        variance, σ²
        lengthscale, l
    :param spatial_dims: forwarded to SpatioTemporalKernel, which uses it to build a default z when z is None
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale=1.0,
                 z=None,
                 sparse=True,
                 opt_z=False,
                 conditional=None,
                 spatial_dims=None):
        super().__init__(temporal_kernel=Matern32(variance=variance, lengthscale=lengthscale),
                         spatial_kernel=Matern32(variance=1., lengthscale=lengthscale, fix_variance=True),
                         z=z,
                         conditional=conditional,
                         sparse=sparse,
                         opt_z=opt_z,
                         spatial_dims=spatial_dims)
        # --- couple the lengthscales: both kernels now read the same variable object ---
        self.spatial_kernel.transformed_lengthscale = self.temporal_kernel.transformed_lengthscale
        # -------------------------------
        self.name = 'Spatial Matern-3/2'
class SpatialMatern52(SpatioTemporalKernel):
    """
    Spatial Matern-5/2 kernel in SDE form. Similar to the spatio-temporal kernel but the
    lengthscale is shared across dimensions.
    Hyperparameters:
        variance, σ²
        lengthscale, l
    :param spatial_dims: forwarded to SpatioTemporalKernel, which uses it to build a default z when z is None
    """

    def __init__(self,
                 variance=1.0,
                 lengthscale=1.0,
                 z=None,
                 sparse=True,
                 opt_z=False,
                 conditional=None,
                 spatial_dims=None):
        super().__init__(temporal_kernel=Matern52(variance=variance, lengthscale=lengthscale),
                         spatial_kernel=Matern52(variance=1., lengthscale=lengthscale, fix_variance=True),
                         z=z,
                         conditional=conditional,
                         sparse=sparse,
                         opt_z=opt_z,
                         spatial_dims=spatial_dims)
        # --- couple the lengthscales: both kernels now read the same variable object ---
        self.spatial_kernel.transformed_lengthscale = self.temporal_kernel.transformed_lengthscale
        # -------------------------------
        self.name = 'Spatial Matern-5/2'
class Cosine(Kernel):
    """
    Cosine kernel in SDE form.
    Hyperparameters:
        radial frequency, ω
    The associated continuous-time state space model matrices are:
    F      = ( 0   -ω
               ω    0 )
    L      = N/A
    Qc     = N/A
    H      = ( 1    0 )
    Pinf   = ( 1    0
               0    1 )
    and the discrete-time transition matrix is (for step size Δt),
    A      = ( cos(ωΔt)   -sin(ωΔt)
               sin(ωΔt)    cos(ωΔt) )
    """

    def __init__(self, frequency=1.0):
        self.transformed_frequency = objax.TrainVar(np.array(softplus_inv(frequency)))
        super().__init__()
        self.name = 'Cosine'

    @property
    def frequency(self):
        """ω, mapped back to the positive reals."""
        return softplus(self.transformed_frequency.value)

    def kernel_to_state_space(self, R=None):
        """Return (F, L, Qc, H, Pinf); the model has no driving noise, so L and Qc are empty lists."""
        F = self.feedback_matrix()
        L = []
        Qc = []
        H = self.measurement_model()
        Pinf = self.stationary_covariance()
        return F, L, Qc, H, Pinf

    def stationary_covariance(self):
        return np.eye(2)

    def measurement_model(self):
        return np.array([[1.0, 0.0]])

    def state_transition(self, dt):
        """
        Calculation of the closed form discrete-time state
        transition matrix A = expm(FΔt) for the Cosine prior: a rotation by angle ωΔt.
        :param dt: step size Δt = tₙ - tₙ₋₁
        :return: state transition matrix A ([2, 2] for a scalar dt; batched shapes depend on rotation_matrix)
        """
        return rotation_matrix(dt, self.frequency)

    def feedback_matrix(self):
        omega = self.frequency
        return np.array([[0.0, -omega],
                         [omega, 0.0]])
class Periodic(Kernel):
    """
    Periodic kernel in SDE form.
    Hyperparameters:
        variance, σ²
        lengthscale, l
        period, p
    The associated continuous-time state space model matrices are constructed via
    a sum of cosines: one 2-dimensional oscillator per harmonic j = 0, ..., order, at angular
    frequency j * 2π / p, weighted by q_j² = c_j σ² ive(j, l⁻²) with c_0 = 1 and c_j = 2 otherwise.
    """

    def __init__(self, variance=1.0, lengthscale=1.0, period=1.0, order=6, fix_variance=False):
        self.transformed_lengthscale = objax.TrainVar(np.array(softplus_inv(lengthscale)))
        variance_var_cls = objax.StateVar if fix_variance else objax.TrainVar
        self.transformed_variance = variance_var_cls(np.array(softplus_inv(variance)))
        self.transformed_period = objax.TrainVar(np.array(softplus_inv(period)))
        super().__init__()
        self.name = 'Periodic'
        self.order = order

    @property
    def variance(self):
        return softplus(self.transformed_variance.value)

    @property
    def lengthscale(self):
        return softplus(self.transformed_lengthscale.value)

    @property
    def period(self):
        return softplus(self.transformed_period.value)

    def _harmonic_variances(self):
        """Weights q_j² of the harmonics j = 0, ..., order."""
        multiplicity = np.array([1] + [2] * self.order)
        return multiplicity * self.variance * bessel_ive(list(range(self.order + 1)), self.lengthscale ** (-2))

    def kernel_to_state_space(self, R=None):
        """Return (F, L, Qc, H, Pinf); the oscillators are noise-free, so Qc is all zeros."""
        state_size = 2 * (self.order + 1)
        F = self.feedback_matrix()
        L = np.eye(state_size)
        Qc = np.zeros(state_size)
        Pinf = self.stationary_covariance()
        H = self.measurement_model()
        return F, L, Qc, H, Pinf

    def stationary_covariance(self):
        return np.kron(np.diag(self._harmonic_variances()), np.eye(2))

    def measurement_model(self):
        # observe the first (cosine) component of every oscillator and sum them
        return np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))

    def state_transition(self, dt):
        """
        Calculation of the closed form discrete-time state
        transition matrix A = expm(FΔt) for the Periodic prior: a block-diagonal stack of rotations,
        one per harmonic.
        :param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
        :return: state transition matrix A [2(N+1), 2(N+1)]
        """
        omega = 2 * np.pi / self.period  # fundamental angular frequency
        frequencies = np.arange(self.order + 1) * omega
        rotations = vmap(rotation_matrix, [None, 0])(dt, frequencies)
        return block_diag(*rotations)

    def feedback_matrix(self):
        omega = 2 * np.pi / self.period  # fundamental angular frequency
        generator = np.array([[0., -omega], [omega, 0.]])
        return np.kron(np.diag(np.arange(self.order + 1)), generator)
class QuasiPeriodicMatern12(Kernel):
    """
    TODO: implement a general 'Product' class to reduce code duplication
    Quasi-periodic kernel in SDE form (product of Periodic and Matern-1/2).
    Hyperparameters:
        variance, σ²
        lengthscale of Periodic, l_p
        period, p
        lengthscale of Matern, l_m
    The associated continuous-time state space model matrices are constructed via
    a sum of cosines times a Matern-1/2. The variance σ² lives in the Matern-1/2 factor;
    the periodic factor's harmonic weights are not scaled by it.
    """

    def __init__(self, variance=1.0, lengthscale_periodic=1.0, period=1.0, lengthscale_matern=1.0, order=6):
        self.transformed_lengthscale_periodic = objax.TrainVar(np.array(softplus_inv(lengthscale_periodic)))
        self.transformed_variance = objax.TrainVar(np.array(softplus_inv(variance)))
        self.transformed_period = objax.TrainVar(np.array(softplus_inv(period)))
        self.transformed_lengthscale_matern = objax.TrainVar(np.array(softplus_inv(lengthscale_matern)))
        super().__init__()
        self.name = 'Quasi-periodic Matern-1/2'
        self.order = order

    @property
    def variance(self):
        return softplus(self.transformed_variance.value)

    @property
    def lengthscale_periodic(self):
        return softplus(self.transformed_lengthscale_periodic.value)

    @property
    def lengthscale_matern(self):
        return softplus(self.transformed_lengthscale_matern.value)

    @property
    def period(self):
        return softplus(self.transformed_period.value)

    def K(self, X, X2):
        raise NotImplementedError

    def _periodic_stationary_covariance(self):
        """Stationary covariance of the (unit-variance) periodic factor: diag(q²) ⊗ I₂."""
        q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
        return np.kron(np.diag(q2), np.eye(2))

    def kernel_to_state_space(self, R=None):
        """
        Return (F, L, Qc, H, Pinf) of the product model, built from the Matern-1/2 (m) and
        periodic (p) factors via Kronecker products.
        """
        Pinf_p = self._periodic_stationary_covariance()
        L_p = np.eye(2 * (self.order + 1))
        L_m = np.array([[1.0]])
        Qc_m = np.array([[2.0 * self.variance / self.lengthscale_matern]])
        Pinf_m = np.array([[self.variance]])
        F = self.feedback_matrix()
        L = np.kron(L_m, L_p)
        # only the Matern factor is driven by noise; the periodic factor contributes its stationary covariance
        Qc = np.kron(Qc_m, Pinf_p)
        H = self.measurement_model()
        Pinf = np.kron(Pinf_m, Pinf_p)
        return F, L, Qc, H, Pinf

    def stationary_covariance(self):
        Pinf_m = np.array([[self.variance]])
        Pinf = np.kron(Pinf_m, self._periodic_stationary_covariance())
        return Pinf

    def measurement_model(self):
        H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
        H_m = np.array([[1.0]])
        H = np.kron(H_m, H_p)
        return H

    def state_transition(self, dt):
        """
        Calculation of the closed form discrete-time state
        transition matrix A = expm(FΔt) for the Quasi-Periodic Matern-1/2 prior: the Matern-1/2
        decay exp(-Δt / l_m) times the block-diagonal stack of per-harmonic rotations.
        :param dt: step size Δt = tₙ - tₙ₋₁
        :return: state transition matrix A [2(order+1), 2(order+1)] (for a scalar dt)
        """
        # The angular frequency
        omega = 2 * np.pi / self.period
        harmonics = np.arange(self.order + 1) * omega
        A = (
            np.exp(-dt / self.lengthscale_matern)
            * block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
        )
        return A

    def feedback_matrix(self):
        """Kronecker sum F_m ⊕ F_p = F_m ⊗ I + I ⊗ F_p of the two factors' feedback matrices."""
        # The angular frequency
        omega = 2 * np.pi / self.period
        F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
        F_m = np.array([[-1.0 / self.lengthscale_matern]])
        F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(1), F_p)
        return F
class QuasiPeriodicMatern32(Kernel):
"""
Quasi-periodic kernel in SDE form (product of Periodic and Matern-3/2).
Hyperparameters:
variance, σ²
lengthscale of Periodic, l_p
period, p
lengthscale of Matern, l_m
The associated continuous-time state space model matrices are constructed via
a sum of cosines times a Matern-3/2.
"""
def __init__(self, variance=1.0, lengthscale_periodic=1.0, period=1.0, lengthscale_matern=1.0, order=6):
self.transformed_lengthscale_periodic = objax.TrainVar(np.array(softplus_inv(lengthscale_periodic)))
self.transformed_variance = objax.TrainVar(np.array(softplus_inv(variance)))
self.transformed_period = objax.TrainVar(np.array(softplus_inv(period)))
self.transformed_lengthscale_matern = objax.TrainVar(np.array(softplus_inv(lengthscale_matern)))
super().__init__()
self.name = 'Quasi-periodic Matern-3/2'
self.order = order
@property
def variance(self):