Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 1 addition & 283 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, deprecated, BaseEstimator
from .utils import check_params, BaseEstimator
from .optim import cg
from .optim import gcg

Expand Down Expand Up @@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b


@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
            "removed in 0.5"
            "\n\tfor standard transport use class EMDTransport instead.")
class OTDA(object):

    """Class for domain adaptation with optimal transport as proposed in [5]


    References
    ----------

    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
        "Optimal Transport for Domain Adaptation," in IEEE Transactions on
        Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1

    """

    def __init__(self, metric='sqeuclidean', norm=None):
        """Store the cost settings and mark the model as not fitted.

        Parameters
        ----------
        metric :
            Forwarded as ``metric`` to ``dist`` when the cost matrix is
            built in ``fit``.
        norm :
            Forwarded to ``cost_normalization`` in ``fit``.
        """
        # Placeholders; real values are only available after ``fit``.
        self.xs = 0
        self.xt = 0
        self.G = 0
        self.metric = metric
        self.norm = norm
        self.computed = False

    def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
        """Fit domain adaptation between samples in xs and xt
        (with optional weights).

        Parameters
        ----------
        xs, xt :
            Source and target samples; ``shape[0]`` is used as the number
            of samples.
        ws, wt :
            Sample weights. When None, ``unif`` over the number of samples
            is used.
        max_iter :
            Passed as the fourth positional argument to ``emd``.
        """
        self.xs = xs
        self.xt = xt

        if wt is None:
            wt = unif(xt.shape[0])
        if ws is None:
            ws = unif(xs.shape[0])

        self.ws = ws
        self.wt = wt

        self.M = dist(xs, xt, metric=self.metric)
        self.M = cost_normalization(self.M, self.norm)
        self.G = emd(ws, wt, self.M, max_iter)
        self.computed = True

    def interp(self, direction=1):
        r"""Barycentric interpolation for the source (1) or target (-1) samples

        This Barycentric interpolation solves for each source (resp target)
        sample xs (resp xt) the following optimization problem:

        .. math::
            arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)

        where k is the index of the sample in xs

        For the moment only squared euclidean distance is provided but more
        metric could be used in the future.

        Returns
        -------
        The interpolated samples, or None (after printing a warning) when
        the model has not been fitted.
        """
        # The fitted check has to come first: ``ws``/``wt`` are only created
        # by ``fit``, so touching them on a fresh instance would raise
        # AttributeError before the warning could ever be printed.
        if not self.computed:
            print("Warning, model not fitted yet, returning None")
            return None

        if direction > 0:  # >0 then source to target
            G = self.G
            w = self.ws.reshape((self.xs.shape[0], 1))
            x = self.xt
        else:
            G = self.G.T
            w = self.wt.reshape((self.xt.shape[0], 1))
            x = self.xs

        if self.metric != 'sqeuclidean':
            print(
                "Warning, metric not handled yet, using weighted average")
        # Weighted mean: each row of the coupling is divided by its weight.
        return np.dot(G / w, x)

    def predict(self, x, direction=1):
        """ Out of sample mapping using the formulation from [6]

        For each sample x to map, it finds the nearest source sample xs and
        maps the sample x to the position xst+(x-xs) where xst is the
        barycentric interpolation of source sample xs.

        Returns None (``interp`` prints the warning) when the model has not
        been fitted.

        References
        ----------

        .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
            Regularized discrete optimal transport. SIAM Journal on Imaging
            Sciences, 7(3), 1853-1882.

        """
        xst = self.interp(direction)  # interpolated reference samples
        if xst is None:
            return None

        # Reference set the new samples are matched against.
        x0 = self.xs if direction > 0 else self.xt

        D0 = dist(x, x0)  # distances between new samples and reference set
        idx = np.argmin(D0, 1)  # closest reference sample for each row of x
        # Apply the displacement of the closest reference sample.
        return xst[idx, :] + x - x0[idx, :]


@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
            " removed in 0.5 \nUse class SinkhornTransport instead.")
class OTDA_sinkhorn(OTDA):

    """Domain adaptation with entropic regularized optimal transport.

    Only ``fit`` is redefined here; the remaining methods are inherited
    from ``OTDA``.
    """

    def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
        """Fit the regularized coupling between the samples xs and xt.

        Parameters
        ----------
        xs, xt :
            Source and target samples; ``shape[0]`` gives the sample count.
        reg :
            Passed positionally to ``sinkhorn`` after the cost matrix.
        ws, wt :
            Optional sample weights; ``unif`` is used when they are None.
        **kwargs :
            Forwarded unchanged to ``sinkhorn``.
        """
        self.xs, self.xt = xs, xt

        # Fall back to uniform weights for whichever side was not given.
        wt = unif(xt.shape[0]) if wt is None else wt
        ws = unif(xs.shape[0]) if ws is None else ws
        self.ws, self.wt = ws, wt

        # Cost matrix, then its normalized version.
        self.M = dist(xs, xt, metric=self.metric)
        self.M = cost_normalization(self.M, self.norm)

        self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
        self.computed = True


@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
            " removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
class OTDA_lpl1(OTDA):

    """Domain adaptation with optimal transport using entropic and group
    regularization.

    Only ``fit`` is redefined here; the remaining methods are inherited
    from ``OTDA``.
    """

    def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
        """Fit the regularized coupling between the samples xs and xt.

        See ot.da.sinkhorn_lpl1_mm for the meaning of the fit parameters.

        Parameters
        ----------
        xs, xt :
            Source and target samples; ``shape[0]`` gives the sample count.
        ys :
            Passed as the second positional argument to
            ``sinkhorn_lpl1_mm``.
        reg, eta :
            Passed positionally to ``sinkhorn_lpl1_mm`` after the cost
            matrix.
        ws, wt :
            Optional sample weights; ``unif`` is used when they are None.
        **kwargs :
            Forwarded unchanged to ``sinkhorn_lpl1_mm``.
        """
        self.xs, self.xt = xs, xt

        # Fall back to uniform weights for whichever side was not given.
        wt = unif(xt.shape[0]) if wt is None else wt
        ws = unif(xs.shape[0]) if ws is None else ws
        self.ws, self.wt = ws, wt

        # Cost matrix, then its normalized version.
        self.M = dist(xs, xt, metric=self.metric)
        self.M = cost_normalization(self.M, self.norm)

        self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
        self.computed = True


@deprecated("The class OTDA_l1l2 is deprecated in 0.3.1 and will be"
            " removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
class OTDA_l1l2(OTDA):

    """Class for domain adaptation with optimal transport with entropic
    and group lasso regularization"""

    def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
        """Fit regularized domain adaptation between samples in xs and xt
        (with optional weights). See ot.da.sinkhorn_l1l2_gl for fit
        parameters.

        Parameters
        ----------
        xs, xt :
            Source and target samples; ``shape[0]`` gives the sample count.
        ys :
            Passed as the second positional argument to
            ``sinkhorn_l1l2_gl``.
        reg, eta :
            Passed positionally to ``sinkhorn_l1l2_gl`` after the cost
            matrix.
        ws, wt :
            Optional sample weights; ``unif`` is used when they are None.
        **kwargs :
            Forwarded unchanged to ``sinkhorn_l1l2_gl``.
        """
        self.xs = xs
        self.xt = xt

        if wt is None:
            wt = unif(xt.shape[0])
        if ws is None:
            ws = unif(xs.shape[0])

        self.ws = ws
        self.wt = wt

        self.M = dist(xs, xt, metric=self.metric)
        self.M = cost_normalization(self.M, self.norm)
        self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
        self.computed = True


@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
            " removed in 0.5 \nUse class MappingTransport instead.")
class OTDA_mapping_linear(OTDA):

    """Optimal transport with joint estimation of a linear mapping, as in
    [8]
    """

    def __init__(self):
        """Set every attribute to its unfitted placeholder value."""
        self.xs = 0
        self.xt = 0
        self.G = 0
        self.L = 0
        self.bias = False
        self.computed = False
        self.metric = 'sqeuclidean'

    def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
        """Jointly estimate the coupling and the linear map from xs to xt.

        Both sample sets get uniform weights. ``mu``, ``eta``, ``bias`` and
        any extra keyword arguments are forwarded to
        ``joint_OT_mapping_linear``, whose two return values are stored as
        ``G`` and ``L``.
        """
        self.xs, self.xt = xs, xt
        self.bias = bias

        self.ws = unif(xs.shape[0])
        self.wt = unif(xt.shape[0])

        coupling, linear_map = joint_OT_mapping_linear(
            xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
        self.G, self.L = coupling, linear_map
        self.computed = True

    def mapping(self):
        """Return a one-argument callable that delegates to ``predict``."""
        def _apply(x):
            return self.predict(x)
        return _apply

    def predict(self, x):
        """Map new samples with the transformation estimated by ``fit``.

        Prints a warning and returns None when ``fit`` has not been called.
        """
        if not self.computed:
            print("Warning, model not fitted yet, returning None")
            return None

        if self.bias:
            # Append a column of ones so the last row of L is added to
            # every mapped sample.
            ones_col = np.ones((x.shape[0], 1))
            x = np.hstack((x, ones_col))
        return x.dot(self.L)


@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
            " removed in 0.5 \nUse class MappingTransport instead.")
class OTDA_mapping_kernel(OTDA_mapping_linear):

    """Class for optimal transport with joint nonlinear mapping
    estimation as in [8]"""

    def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
            sigma=1, **kwargs):
        """ Fit domain adaptation between samples in xs and xt

        Parameters
        ----------
        xs, xt :
            Source and target samples; both get uniform weights.
        mu, eta, bias :
            Forwarded to ``joint_OT_mapping_kernel``.
        kerneltype, sigma :
            Kernel settings. They are stored for ``predict`` and also
            forwarded to ``joint_OT_mapping_kernel`` so that the mapping is
            trained with the same kernel that is later used for prediction.
        **kwargs :
            Forwarded to ``joint_OT_mapping_kernel`` and kept for the
            ``kernel`` call in ``predict``.
        """
        self.xs = xs
        self.xt = xt
        self.bias = bias

        self.ws = unif(xs.shape[0])
        self.wt = unif(xt.shape[0])
        self.kernel = kerneltype
        self.sigma = sigma
        self.kwargs = kwargs

        # kerneltype/sigma must reach the solver: otherwise it silently
        # trains with its own defaults while predict() builds the kernel
        # matrix from the values stored above.
        self.G, self.L = joint_OT_mapping_kernel(
            xs, xt, mu=mu, eta=eta, kerneltype=kerneltype, sigma=sigma,
            bias=bias, **kwargs)
        self.computed = True

    def predict(self, x):
        """ Out of sample mapping estimated during the call to fit

        Prints a warning and returns None when ``fit`` has not been called.
        """

        if self.computed:
            # Kernel between the new samples and the training source samples.
            # NOTE(review): the solver kwargs given to fit are also passed to
            # ``kernel`` here - confirm that ``kernel`` tolerates them.
            K = kernel(
                x, self.xs, method=self.kernel, sigma=self.sigma,
                **self.kwargs)
            if self.bias:
                # Extra column of ones pairs with the bias row of L.
                K = np.hstack((K, np.ones((x.shape[0], 1))))
            return K.dot(self.L)
        else:
            print("Warning, model not fitted yet, returning None")
            return None


def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X

Expand Down
63 changes: 0 additions & 63 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,66 +484,3 @@ def test_linear_mapping_class():
Cst = np.cov(Xst.T)

np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)


def test_otda():
    """Smoke-test every deprecated OTDA class: fit, interpolate, predict,
    and (for the coupling-based classes) check the marginals of ``G``."""

    n_samples = 150  # nb samples
    np.random.seed(0)

    xs, ys = ot.datasets.make_data_classif('3gauss', n_samples)
    xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples)

    a, b = ot.unif(n_samples), ot.unif(n_samples)

    def _check_marginals(plan, **tol):
        # Row sums must match the source weights, column sums the target
        # weights, within the given tolerances.
        np.testing.assert_allclose(a, np.sum(plan, 1), **tol)
        np.testing.assert_allclose(b, np.sum(plan, 0), **tol)

    loose = dict(rtol=1e-3, atol=1e-3)

    # LP problem (checked with the default tolerances)
    model = ot.da.OTDA()
    model.fit(xs, xt)
    model.interp()
    model.predict(xs)
    _check_marginals(model.G)

    # sinkhorn regularization
    lambd = 1e-1
    model = ot.da.OTDA_sinkhorn()
    model.fit(xs, xt, reg=lambd)
    model.interp()
    model.predict(xs)
    _check_marginals(model.G, **loose)

    # non-convex Group lasso regularization
    reg, eta = 1e-1, 1e0
    model = ot.da.OTDA_lpl1()
    model.fit(xs, ys, xt, reg=reg, eta=eta)
    model.interp()
    model.predict(xs)
    _check_marginals(model.G, **loose)

    # True Group lasso regularization
    reg, eta = 1e-1, 2e0
    model = ot.da.OTDA_l1l2()
    model.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True)
    model.interp()
    model.predict(xs)
    _check_marginals(model.G, **loose)

    # linear then nonlinear mapping: only fit and out-of-sample prediction
    for mapping_cls in (ot.da.OTDA_mapping_linear,
                        ot.da.OTDA_mapping_kernel):
        model = mapping_cls()
        model.fit(xs, xt, numItermax=10)
        model.predict(xs)