9
9
import matplotlib .pylab as pl
10
10
import ot
11
11
12
-
12
+ from mpl_toolkits .mplot3d import Axes3D
13
+ from matplotlib .collections import PolyCollection
14
+ from matplotlib .colors import colorConverter
13
15
14
16
#%% parameters
15
17
19
21
x = np .arange (n ,dtype = np .float64 )
20
22
21
23
# Gaussian distributions
22
- a1 = ot .datasets .get_1D_gauss (n ,m = 20 ,s = 20 ) # m= mean, s= std
23
- a2 = ot .datasets .get_1D_gauss (n ,m = 60 ,s = 60 )
24
+ a1 = ot .datasets .get_1D_gauss (n ,m = 20 ,s = 5 ) # m= mean, s= std
25
+ a2 = ot .datasets .get_1D_gauss (n ,m = 60 ,s = 8 )
24
26
25
27
# creating matrix A containing all distributions
26
28
A = np .vstack ((a1 ,a2 )).T
39
41
40
42
#%% barycenter computation
41
43
44
+ alpha = 0.2 # 0<=alpha<=1
45
+ weights = np .array ([1 - alpha ,alpha ])
46
+
42
47
# l2bary
43
- bary_l2 = A .mean ( 1 )
48
+ bary_l2 = A .dot ( weights )
44
49
45
50
# wasserstein
46
51
reg = 1e-3
47
- bary_wass = ot .bregman .barycenter (A ,M ,reg )
52
+ bary_wass = ot .bregman .barycenter (A ,M ,reg , weights )
48
53
49
54
pl .figure (2 )
50
55
pl .clf ()
58
63
pl .plot (x ,bary_wass ,'g' ,label = 'Wasserstein' )
59
64
pl .legend ()
60
65
pl .title ('Barycenters' )
66
+
67
+
68
+ #%% barycenter interpolation
69
+
70
+ nbalpha = 11
71
+ alphalist = np .linspace (0 ,1 ,nbalpha )
72
+
73
+
74
+ B_l2 = np .zeros ((n ,nbalpha ))
75
+
76
+ B_wass = np .copy (B_l2 )
77
+
78
+ for i in range (0 ,nbalpha ):
79
+ alpha = alphalist [i ]
80
+ weights = np .array ([1 - alpha ,alpha ])
81
+ B_l2 [:,i ]= A .dot (weights )
82
+ B_wass [:,i ]= ot .bregman .barycenter (A ,M ,reg ,weights )
83
+
84
+ #%% plot interpolation
85
+
86
+ pl .figure (3 ,(10 ,5 ))
87
+
88
+ #pl.subplot(1,2,1)
89
+ cmap = pl .cm .get_cmap ('viridis' )
90
+ verts = []
91
+ zs = alphalist
92
+ for i ,z in enumerate (zs ):
93
+ ys = B_l2 [:,i ]
94
+ verts .append (list (zip (x , ys )))
95
+
96
+ ax = pl .gcf ().gca (projection = '3d' )
97
+
98
+ poly = PolyCollection (verts ,facecolors = [cmap (a ) for a in alphalist ])
99
+ poly .set_alpha (0.7 )
100
+ ax .add_collection3d (poly , zs = zs , zdir = 'y' )
101
+
102
+ ax .set_xlabel ('x' )
103
+ ax .set_xlim3d (0 , n )
104
+ ax .set_ylabel ('$\\ alpha$' )
105
+ ax .set_ylim3d (0 ,1 )
106
+ ax .set_zlabel ('' )
107
+ ax .set_zlim3d (0 , B_l2 .max ()* 1.01 )
108
+ pl .title ('Barycenter interpolation with l2' )
109
+
110
+ pl .show ()
111
+
112
+ pl .figure (4 ,(10 ,5 ))
113
+
114
+ #pl.subplot(1,2,1)
115
+ cmap = pl .cm .get_cmap ('viridis' )
116
+ verts = []
117
+ zs = alphalist
118
+ for i ,z in enumerate (zs ):
119
+ ys = B_wass [:,i ]
120
+ verts .append (list (zip (x , ys )))
121
+
122
+ ax = pl .gcf ().gca (projection = '3d' )
123
+
124
+ poly = PolyCollection (verts ,facecolors = [cmap (a ) for a in alphalist ])
125
+ poly .set_alpha (0.7 )
126
+ ax .add_collection3d (poly , zs = zs , zdir = 'y' )
127
+
128
+ ax .set_xlabel ('x' )
129
+ ax .set_xlim3d (0 , n )
130
+ ax .set_ylabel ('$\\ alpha$' )
131
+ ax .set_ylim3d (0 ,1 )
132
+ ax .set_zlabel ('' )
133
+ ax .set_zlim3d (0 , B_l2 .max ()* 1.01 )
134
+ pl .title ('Barycenter interpolation with Wasserstein' )
135
+
136
+ pl .show ()
0 commit comments