|
31 | 31 |
|
32 | 32 | n_source = 7
|
33 | 33 | n_target = 4
|
34 |
| -eps = 1 |
35 |
| -nb_iter = 10000 |
| 34 | +reg = 1 |
| 35 | +numItermax = 10000 |
36 | 36 | lr = 0.1
|
37 | 37 |
|
38 |
| -a = (1./n_source) * np.ones(n_source) |
39 |
| -b = (1./n_target) * np.ones(n_target) |
40 |
| -X_source = np.arange(n_source) |
41 |
| -Y_target = np.arange(0, 2 * n_target, 2) |
42 |
| -M = np.abs(X_source[:, None] - Y_target[None, :]) |
| 38 | +a = ot.utils.unif(n_source) |
| 39 | +b = ot.utils.unif(n_target) |
| 40 | + |
| 41 | +rng = np.random.RandomState(0) |
| 42 | +X_source = rng.randn(n_source, 2) |
| 43 | +Y_target = rng.randn(n_target, 2) |
| 44 | +M = ot.dist(X_source, Y_target) |
43 | 45 |
|
44 | 46 | #############################################################################
|
45 | 47 | #
|
|
50 | 52 | # results.
|
51 | 53 |
|
52 | 54 | method = "SAG"
|
53 |
| -sag_pi = ot.stochastic.transportation_matrix_entropic(method, eps, a, b, M, |
54 |
| - n_source, n_target, |
55 |
| - nb_iter, lr) |
| 55 | +sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, |
| 56 | + numItermax, lr) |
56 | 57 | print(sag_pi)
|
57 | 58 |
|
58 | 59 | #############################################################################
|
|
66 | 67 |
|
67 | 68 | n_source = 7
|
68 | 69 | n_target = 4
|
69 |
| -eps = 1 |
70 |
| -nb_iter = 10000 |
71 |
| -lr = 0.1 |
| 70 | +reg = 1 |
| 71 | +numItermax = 300000 |
| 72 | +lr = 1 |
| 73 | + |
| 74 | +a = ot.utils.unif(n_source) |
| 75 | +b = ot.utils.unif(n_target) |
72 | 76 |
|
73 |
| -a = (1./n_source) * np.ones(n_source) |
74 |
| -b = (1./n_target) * np.ones(n_target) |
75 |
| -X_source = np.arange(n_source) |
76 |
| -Y_target = np.arange(0, 2 * n_target, 2) |
77 |
| -M = np.abs(X_source[:, None] - Y_target[None, :]) |
| 77 | +rng = np.random.RandomState(0) |
| 78 | +X_source = rng.randn(n_source, 2) |
| 79 | +Y_target = rng.randn(n_target, 2) |
| 80 | +M = ot.dist(X_source, Y_target) |
78 | 81 |
|
79 | 82 | #############################################################################
|
80 | 83 | #
|
|
86 | 89 | # results.
|
87 | 90 |
|
88 | 91 | method = "ASGD"
|
89 |
| -asgd_pi = ot.stochastic.transportation_matrix_entropic(method, eps, a, b, M, |
90 |
| - n_source, n_target, |
91 |
| - nb_iter, lr) |
| 92 | +asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, |
| 93 | + numItermax, lr) |
92 | 94 | print(asgd_pi)
|
93 | 95 |
|
94 | 96 | #############################################################################
|
|
0 commit comments