Skip to content

Commit 0a7ce08

Browse files
committed
update barycenter demo
1 parent e3c0d3e commit 0a7ce08

File tree

4 files changed

+366
-161
lines changed

4 files changed

+366
-161
lines changed

examples/Demo_1D_barycenter.ipynb

Lines changed: 283 additions & 154 deletions
Large diffs are not rendered by default.

examples/demo_OTDA_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858

5959
# True Group lasso regularization
6060
reg=1e-1
61-
eta=1e0
61+
eta=2e0
6262
da_l1l2=ot.da.OTDA_l1l2()
6363
da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
6464
xstgl=da_l1l2.interp()

examples/demo_barycenter_1D.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import matplotlib.pylab as pl
1010
import ot
1111

12-
12+
from mpl_toolkits.mplot3d import Axes3D
13+
from matplotlib.collections import PolyCollection
14+
from matplotlib.colors import colorConverter
1315

1416
#%% parameters
1517

@@ -19,8 +21,8 @@
1921
x=np.arange(n,dtype=np.float64)
2022

2123
# Gaussian distributions
22-
a1=ot.datasets.get_1D_gauss(n,m=20,s=20) # m= mean, s= std
23-
a2=ot.datasets.get_1D_gauss(n,m=60,s=60)
24+
a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
25+
a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
2426

2527
# creating matrix A containing all distributions
2628
A=np.vstack((a1,a2)).T
@@ -39,12 +41,15 @@
3941

4042
#%% barycenter computation
4143

44+
alpha=0.2 # 0<=alpha<=1
45+
weights=np.array([1-alpha,alpha])
46+
4247
# l2bary
43-
bary_l2=A.mean(1)
48+
bary_l2=A.dot(weights)
4449

4550
# wasserstein
4651
reg=1e-3
47-
bary_wass=ot.bregman.barycenter(A,M,reg)
52+
bary_wass=ot.bregman.barycenter(A,M,reg,weights)
4853

4954
pl.figure(2)
5055
pl.clf()
@@ -58,3 +63,74 @@
5863
pl.plot(x,bary_wass,'g',label='Wasserstein')
5964
pl.legend()
6065
pl.title('Barycenters')
66+
67+
68+
#%% barycenter interpolation
69+
70+
nbalpha=11
71+
alphalist=np.linspace(0,1,nbalpha)
72+
73+
74+
B_l2=np.zeros((n,nbalpha))
75+
76+
B_wass=np.copy(B_l2)
77+
78+
for i in range(0,nbalpha):
79+
alpha=alphalist[i]
80+
weights=np.array([1-alpha,alpha])
81+
B_l2[:,i]=A.dot(weights)
82+
B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
83+
84+
#%% plot interpolation
85+
86+
pl.figure(3,(10,5))
87+
88+
#pl.subplot(1,2,1)
89+
cmap=pl.cm.get_cmap('viridis')
90+
verts = []
91+
zs = alphalist
92+
for i,z in enumerate(zs):
93+
ys = B_l2[:,i]
94+
verts.append(list(zip(x, ys)))
95+
96+
ax = pl.gcf().gca(projection='3d')
97+
98+
poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
99+
poly.set_alpha(0.7)
100+
ax.add_collection3d(poly, zs=zs, zdir='y')
101+
102+
ax.set_xlabel('x')
103+
ax.set_xlim3d(0, n)
104+
ax.set_ylabel('$\\alpha$')
105+
ax.set_ylim3d(0,1)
106+
ax.set_zlabel('')
107+
ax.set_zlim3d(0, B_l2.max()*1.01)
108+
pl.title('Barycenter interpolation with l2')
109+
110+
pl.show()
111+
112+
pl.figure(4,(10,5))
113+
114+
#pl.subplot(1,2,1)
115+
cmap=pl.cm.get_cmap('viridis')
116+
verts = []
117+
zs = alphalist
118+
for i,z in enumerate(zs):
119+
ys = B_wass[:,i]
120+
verts.append(list(zip(x, ys)))
121+
122+
ax = pl.gcf().gca(projection='3d')
123+
124+
poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
125+
poly.set_alpha(0.7)
126+
ax.add_collection3d(poly, zs=zs, zdir='y')
127+
128+
ax.set_xlabel('x')
129+
ax.set_xlim3d(0, n)
130+
ax.set_ylabel('$\\alpha$')
131+
ax.set_ylim3d(0,1)
132+
ax.set_zlabel('')
133+
ax.set_zlim3d(0, B_l2.max()*1.01)
134+
pl.title('Barycenter interpolation with Wasserstein')
135+
136+
pl.show()

ot/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,4 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
128128
y=0
129129
print("unknown dataset")
130130

131-
return x,y.astype(int)
131+
return x,y

0 commit comments

Comments
 (0)