Skip to content

Commit d6bf10d

Browse files
eloitanguyrflamary
andauthored
[WIP] Graphical tweaks for GWB + fixed seed method for the partial gromov test (#376)
* GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo * GWB graphical tweaks + better seed method for partial gromov test * fixed PR number * refixed seed issue * seed fix fix fix Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent c1ccfc4 commit d6bf10d

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

RELEASES.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44

55
#### New features
66

7-
- Added Generalized Wasserstein Barycenter solver + example (PR #372)
7+
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
88

99
#### Closed issues
1010

1111
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
1212
(Issue #371, PR #373)
1313

14-
1514
## 0.8.2
1615

1716
This releases introduces several new notable features. The less important

examples/barycenters/plot_generalized_free_support_barycenter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
# Input measures
3434
sub_sample_factor = 8
3535
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
36-
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
37-
I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
36+
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
37+
I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
3838

3939
sz = I1.shape[0]
4040
UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))
@@ -145,8 +145,11 @@ def _init():
145145

146146

147147
def _update_plot(i):
148-
ax.view_init(elev=i, azim=4 * i)
148+
if i < 45:
149+
ax.view_init(elev=0, azim=4 * i)
150+
else:
151+
ax.view_init(elev=i - 45, azim=4 * i)
149152
return fig,
150153

151154

152-
ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000)
155+
ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000)

test/test_partial.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_partial_wasserstein():
137137

138138

139139
def test_partial_gromov_wasserstein():
140-
np.random.seed(42)
140+
rng = np.random.RandomState(seed=42)
141141
n_samples = 20 # nb samples
142142
n_noise = 10 # nb of samples (noise)
143143

@@ -150,11 +150,11 @@ def test_partial_gromov_wasserstein():
150150
mu_t = np.array([0, 0, 0])
151151
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
152152

153-
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
154-
xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
153+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng)
154+
xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0)
155155
P = sp.linalg.sqrtm(cov_t)
156-
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
157-
xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
156+
xt = rng.randn(n_samples, 3).dot(P) + mu_t
157+
xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0)
158158
xt2 = xs[::-1].copy()
159159

160160
C1 = ot.dist(xs, xs)

0 commit comments

Comments
 (0)