-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
60 changed files
with
4,968 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,8 @@ | ||
# MMD-reg-OT | ||
MMD regularized OT. | ||
# MMD-OT: | ||
#### Algorithms | ||
- [Code for solving the MMD-OT plan using Accelerated PGD](./ot_mmd/mmdot.py) | ||
- [Code for computing a batch of MMD-OT problems in parallel](./ot_mmd/b_mmdot.py) | ||
- [Code for solving the MMD-OT barycenter problem using Accelerated PGD](./ot_mmd/barycenter.py) | ||
#### [Examples](./examples) | ||
- [OT plan between Gaussians](./examples/synthetic/OTplan.ipynb) | ||
- [Barycenter between Gaussians](./examples/synthetic/barycenter_with_imq.ipynb) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from mmdot import * | ||
from utils import * |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from ot_mmd.utils import createLogHandler, get_t, get_dist, get_G | ||
import os | ||
import argparse | ||
import joblib | ||
import torch | ||
from kluot_bary import solve_md | ||
|
||
# Command-line interface of the KL-UOT baseline script.
#   --t_pred    time step whose distribution is predicted (required)
#   --best_lda  regularisation weight; when omitted a grid search is run
#   --best_hp   entropic coefficient used together with --best_lda
#   --save_as   stem of the CSV log file
parser = argparse.ArgumentParser(description="_")
parser.add_argument("--t_pred", required=True, type=int)
parser.add_argument("--best_lda", type=float, default=None)
parser.add_argument("--best_hp", type=float, default=None)
parser.add_argument("--save_as", default="")
args = parser.parse_args()

# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
t_predict = args.t_pred
max_itr = 1000

# The log handler is tagged with the process id.
logger = createLogHandler(f"{args.save_as}.csv", str(os.getpid()))
|
||
if args.best_lda is None:
    # No hyper-parameters were supplied: grid-search (lda, hp) on the time
    # steps other than the one being predicted.  A plain set difference is
    # used so that an out-of-range --t_pred can never be added to its own
    # validation set (which symmetric_difference would do).
    valt_predict = sorted({1, 2, 3} - {t_predict})
    best_score = torch.inf
    best_config = None
    val = {}
    for lda in [10, 1e-1, 1]:
        val[lda] = {}
        for hp in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
            val[lda][hp] = []
            for t in valt_predict:
                init_tstep = t-1
                final_tstep = t+1

                data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

                data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
                data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

                # The barycenter is supported on the union of both end points.
                data_all = torch.vstack([data_init, data_final])
                C = {1: get_dist(x=data_init, y=data_all, p=1),
                     2: get_dist(x=data_final, y=data_all, p=1)}

                # Uniform weights on the two end-point samples.
                a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
                b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

                bary, _ = solve_md({1: a, 2: b}, C, {1: lda, 2: lda}, max_itr, coeff_entr=hp)

                # Validation score: quadratic form vec^T G vec between the
                # uniform weights on the held-out sample and the barycenter
                # (the final evaluation logs the square root of this quantity).
                gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
                data_cat = torch.vstack([data_tpredict, data_all])
                G = get_G(ktype="rbf", x=data_cat, y=data_cat)
                vec = torch.cat([gt, -bary])
                val[lda][hp].append(torch.mv(G, vec).dot(vec).item())

            score = sum(val[lda][hp])
            logger.info(f", {lda}, {hp}, {score}")
            if score < best_score:
                best_score = score
                best_config = {"lda": lda, "hp": hp}

    if best_config is None:
        # Every score was NaN/inf, so no configuration ever beat the initial
        # best_score; fail with a clear message instead of a NameError.
        raise RuntimeError("hyper-parameter search produced no finite validation score")

    lda = best_config["lda"]
    hp = best_config["hp"]
else:
    # Hyper-parameters were given on the command line; skip the search.
    lda = args.best_lda
    hp = args.best_hp
|
||
# Final evaluation at the requested time step with the selected (lda, hp).
t = t_predict
init_tstep, final_tstep = t-1, t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)
data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

# Candidate support of the barycenter: both end-point samples stacked.
data_all = torch.vstack([data_init, data_final])
C = {
    1: get_dist(x=data_init, y=data_all, p=1),
    2: get_dist(x=data_final, y=data_all, p=1),
}

# Uniform weights on each end-point sample.
n_init = data_init.shape[0]
n_final = data_final.shape[0]
a = torch.ones(n_init).div(n_init).to(dtype).to(device)
b = torch.ones(n_final).div(n_final).to(dtype).to(device)

weights = {1: lda, 2: lda}
bary, _ = solve_md({1: a, 2: b}, C, weights, max_itr, coeff_entr=hp)

# Compare the barycenter with uniform weights on the held-out sample through
# the "rbf" Gram matrix: sqrt(vec^T G vec).
n_pred = data_tpredict.shape[0]
gt = torch.ones(n_pred).div(n_pred).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
quad_form = torch.dot(torch.mv(G, vec), vec)
val_chosen = quad_form.sqrt().item()
logger.info(f"KL-UOT, {t}, {val_chosen}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
from ot_mmd.utils import get_marginals | ||
|
||
|
||
def get_kl(v1, v2, case, eps=1e-10):
    """Return the eps-smoothed KL term between ``v1`` and ``v2``.

    Both inputs are shifted by ``eps`` and the result is
    ``sum(p * log(p / q))`` with ``p = v1 + eps`` and ``q = v2 + eps``
    (entries where ``p == 0`` contribute 0).  When ``case == "unb"`` the
    value ``-sum(p) + sum(q)`` is added.
    """
    p = v1 + eps
    q = v2 + eps
    pointwise = p*torch.log(p/q)
    total = torch.where(p != 0, pointwise, 0).sum()
    if case != "unb":
        return total
    return total-p.sum() + q.sum()
|
||
def get_entropy(alpha, case, eps=1e-10):
    """Return ``sum(p * log(p))`` for ``p = alpha + eps``.

    Entries where ``p == 0`` contribute 0.  When ``case == "unb"`` the
    total mass ``sum(p)`` is subtracted from the result.
    """
    p = alpha + eps
    plogp = torch.where(p != 0, p*torch.log(p), 0)
    total = plogp.sum()
    if case != "unb":
        return total
    return total - p.sum()
|
||
def get_obj(alpha, bary, v, C, lda, coeff_entr, rho={1: 0.5, 2: 0.5}, case="bal"):
    """Evaluate the objective for the current plans and barycenter.

    The value is the ``rho``-weighted sum, over the two plans ``alpha[1]`` and
    ``alpha[2]``, of

    * the transport cost ``tensordot(alpha[k], C[k])``,
    * ``lda[1]`` times ``get_kl`` between the first marginal and ``v[k]``,
    * ``lda[2]`` times ``get_kl`` between the second marginal and ``bary``,
    * ``coeff_entr`` times ``get_entropy(alpha[k])``.

    The marginals come from ``get_marginals`` (defined elsewhere; presumably
    the row and column sums of the plan -- confirm against its definition).
    """
    first_1, second_1 = get_marginals(alpha[1])
    first_2, second_2 = get_marginals(alpha[2])

    transport = rho[1]*torch.tensordot(alpha[1], C[1]) + rho[2]*torch.tensordot(alpha[2], C[2])
    kl_to_inputs = rho[1]*get_kl(first_1, v[1], case) + rho[2]*get_kl(first_2, v[2], case)
    kl_to_bary = rho[1]*get_kl(second_1, bary, case) + rho[2]*get_kl(second_2, bary, case)
    entropic = rho[1]*get_entropy(alpha[1], case)+rho[2]*get_entropy(alpha[2], case)

    total = transport + lda[1]*kl_to_inputs + lda[2]*kl_to_bary
    return total + coeff_entr*entropic
|
||
def get_grd(alpha, bary, v, C, lda, coeff_entr, rho={1: 0.5, 2: 0.5}, case="bal"):
    """Return the gradients ``(grd_1, grd_2, grd_bary)`` used by ``solve_md``.

    ``grd_1`` / ``grd_2`` are taken with respect to the plans ``alpha[1]`` /
    ``alpha[2]`` and ``grd_bary`` with respect to ``bary``.  A plan whose
    weight ``rho[k]`` is not positive only receives the (zero-weighted)
    entropic term.

    NOTE: the entries ``alpha[1]`` and ``alpha[2]`` of the *caller's* dict are
    rebound to their ``eps``-shifted versions; ``bary`` is only shifted
    locally.
    """
    eps = 1e-10
    balanced = case == "bal"

    # Shift everything away from zero before taking logs / dividing.
    for k in (1, 2):
        alpha[k] = alpha[k] + eps
    bary = bary + eps

    marginals = {1: get_marginals(alpha[1]), 2: get_marginals(alpha[2])}

    grd_bary = -lda[2]*(rho[1]*marginals[1][1] + rho[2]*marginals[2][1])/bary

    grads = {1: 0, 2: 0}
    for k in (1, 2):
        if rho[k] > 0:
            first_marg, second_marg = marginals[k]
            to_input = torch.log(first_marg)-torch.log(v[k])
            to_bary = torch.log(second_marg)-torch.log(bary)
            if balanced:
                to_input = to_input + 1
                to_bary = to_bary + 1
            # `to_input` is broadcast along the last axis, `to_bary` along the first.
            grads[k] = rho[k]*(C[k] + lda[1]*to_input[:, None] + lda[2]*to_bary)

        # Entropic term; the constant 1 is only present in the "bal" case.
        log_plan = torch.log(alpha[k])
        entr_grad = 1+log_plan if balanced else log_plan
        grads[k] = grads[k] + rho[k]*coeff_entr*entr_grad

    return grads[1], grads[2], grd_bary
|
||
|
||
def solve_md(v, C, lda, max_itr, coeff_entr, rho={1: 0.5, 2: 0.5}, case="bal"):
    """Run ``max_itr`` multiplicative (exponentiated-gradient) updates.

    Args:
        v: dict ``{1: ..., 2: ...}`` of the two input weight vectors.
        C: dict ``{1: ..., 2: ...}`` of 2-D cost matrices; the barycenter has
            ``C[1].shape[1]`` entries.
        lda: dict ``{1: ..., 2: ...}`` of KL weights passed to
            ``get_obj`` / ``get_grd``.
        max_itr: number of iterations.
        coeff_entr: coefficient of the entropic term.
        rho: weights of the two plans; a plan with ``rho[k] <= 0`` is never
            updated.
        case: ``"bal"`` renormalises every variable to unit mass after each
            update; it is also forwarded to ``get_obj`` and ``get_grd``.

    Returns:
        ``(bary_best, obj_itr)``: the barycenter at the iteration with the
        lowest recorded objective (``None`` when ``max_itr == 0``) and the
        list of objective values, one per iteration.
    """

    def update_vars(var, grd, case):
        # Step size is the reciprocal of the gradient's infinity norm.  The
        # norm is clamped to the smallest normal float so that an all-zero
        # gradient yields a no-op update instead of 0 * inf = NaN.
        grd_norm = torch.clamp(torch.norm(grd, torch.inf), min=torch.finfo(grd.dtype).tiny)
        s = 1/grd_norm
        var = var*torch.exp(-grd*s)
        if case == "bal":
            var = var/var.sum()
        return var

    def try_update(var, grd):
        # Best-effort update kept from the original code: on any failure the
        # error is printed and the variable is left unchanged.
        try:
            return update_vars(var, grd, case)
        except Exception as e:
            print(e)
            return var

    # Uniform initialisation of both plans and of the barycenter.
    alpha = {1: torch.ones_like(C[1])/C[1].numel(),
             2: torch.ones_like(C[2])/C[2].numel()}
    bary = (torch.ones(C[1].shape[1])/C[1].shape[1]).to(C[1].dtype).to(C[1].device)
    obj_itr = []
    bary_best = None
    best_itr = None

    for itr in range(max_itr):
        # `case` is forwarded so that objective, gradient and update all use
        # the same setting (previously the first two always used "bal").
        obj_itr.append(get_obj(alpha, bary, v, C, lda, coeff_entr, rho, case))
        if best_itr is None or obj_itr[best_itr] > obj_itr[-1]:
            best_itr = itr
            bary_best = bary.clone()

        # NOTE: get_grd rebinds alpha[1] / alpha[2] to eps-shifted tensors.
        grd_1, grd_2, grd_bary = get_grd(alpha, bary, v, C, lda, coeff_entr, rho, case)
        if rho[1] > 0:
            alpha[1] = try_update(alpha[1], grd_1)
        if rho[2] > 0:
            alpha[2] = try_update(alpha[2], grd_2)
        bary = try_update(bary, grd_bary)

    return bary_best, obj_itr
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from ot_mmd.utils import createLogHandler, get_t, get_G | ||
import os | ||
import argparse | ||
import joblib | ||
import torch | ||
|
||
# Command-line interface of the plain-MMD baseline script.
#   --t_pred   time step whose distribution is predicted (required)
#   --save_as  stem of the CSV log file
parser = argparse.ArgumentParser(description="_")
parser.add_argument("--t_pred", required=True, type=int)
parser.add_argument("--save_as", default="")
args = parser.parse_args()

# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
t_predict = args.t_pred

logger = createLogHandler(f"{args.save_as}.csv", str(os.getpid()))

t = t_predict

init_tstep = t-1
final_tstep = t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

data_all = torch.vstack([data_init, data_final])

# Uniform weights on the two neighbouring samples.
a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

# Baseline prediction: the equal-weight mixture of both neighbouring samples.
bary = torch.cat([a, b])/2

# Score against uniform weights on the held-out sample: sqrt(vec^T G vec)
# with the "rbf" Gram matrix on the stacked points.
gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
val_chosen = torch.sqrt(torch.mv(G, vec).dot(vec)).item()
logger.info("Method, tstep, MMD (lower is better)")
logger.info(f"MMD, {t}, {val_chosen}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from ot_mmd.barycenter import solve_apgd | ||
from ot_mmd.utils import createLogHandler, get_t, get_dist, get_G | ||
import os | ||
import argparse | ||
import joblib | ||
import torch | ||
|
||
# Command-line interface of the UOT-MMD (proposed method) script.
#   --t_pred    time step whose distribution is predicted (required)
#   --best_lda  regularisation weight; when omitted a grid search is run
#   --best_hp   kernel hyper-parameter used together with --best_lda
#               (left as None when not given)
#   --save_as   stem of the CSV log file
parser = argparse.ArgumentParser(description="_")
parser.add_argument("--t_pred", required=True, type=int)
parser.add_argument("--best_lda", type=float, default=None)
parser.add_argument("--best_hp", type=float, default=None)
parser.add_argument("--save_as", default="")
args = parser.parse_args()

# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
max_itr = 1000
ktype = "imq_v2"
t_predict = args.t_pred

# The log handler is tagged with the process id.
logger = createLogHandler(f"{args.save_as}.csv", str(os.getpid()))
|
||
if args.best_lda is None:
    # No hyper-parameters were supplied: grid-search (lda, khp) on the time
    # steps other than the one being predicted.  A plain set difference is
    # used so that an out-of-range --t_pred can never be added to its own
    # validation set (which symmetric_difference would do).
    valt_predict = sorted({1, 2, 3} - {t_predict})
    best_score = torch.inf
    best_config = None
    val = {}
    for lda in [10, 1e-1, 1]:
        val[lda] = {}
        # khp=None is passed through to get_G unchanged.
        for khp in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, None]:
            val[lda][khp] = []
            for t in valt_predict:
                init_tstep = t-1
                final_tstep = t+1

                data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

                data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
                data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

                # The barycenter is supported on the union of both end points.
                data_all = torch.vstack([data_init, data_final])
                C = {1: get_dist(x=data_init, y=data_all, p=1),
                     2: get_dist(x=data_final, y=data_all, p=1)}

                # Gram matrix on the union, plus its two diagonal blocks
                # (first m1 points / remaining points).
                G_all = get_G(ktype=ktype, khp=khp, x=data_all, y=data_all)
                m1 = data_init.shape[0]
                G = {1: G_all[:m1, :m1], 2: G_all[m1:, m1:], 'all': G_all}

                # Uniform weights on the two end-point samples.
                a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
                b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

                bary, _ = solve_apgd(C, G, {1: a, 2: b}, max_itr, {1: lda, 2: lda}, case="bal")

                # Validation score: quadratic form vec^T G vec with the "rbf"
                # Gram matrix (the final evaluation logs its square root).
                gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
                data_cat = torch.vstack([data_tpredict, data_all])
                G = get_G(ktype="rbf", x=data_cat, y=data_cat)
                vec = torch.cat([gt, -bary])
                val[lda][khp].append(torch.mv(G, vec).dot(vec).item())

            score = sum(val[lda][khp])
            logger.info(f", {lda}, {khp}, {score}")
            if score < best_score:
                best_score = score
                best_config = {"lda": lda, "khp": khp}

    if best_config is None:
        # Every score was NaN/inf, so no configuration ever beat the initial
        # best_score; fail with a clear message instead of a NameError.
        raise RuntimeError("hyper-parameter search produced no finite validation score")

    lda = best_config["lda"]
    khp = best_config["khp"]
else:
    # Hyper-parameters were given on the command line; skip the search.
    lda = args.best_lda
    khp = args.best_hp
|
||
# Final evaluation at the requested time step with the selected (lda, khp).
t = t_predict
init_tstep, final_tstep = t-1, t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)
data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

# Candidate support of the barycenter: both end-point samples stacked.
data_all = torch.vstack([data_init, data_final])
C = {
    1: get_dist(x=data_init, y=data_all, p=1),
    2: get_dist(x=data_final, y=data_all, p=1),
}

# Gram matrix on the union and its two diagonal blocks.
m1 = data_init.shape[0]
G_all = get_G(ktype=ktype, khp=khp, x=data_all, y=data_all)
G = {1: G_all[:m1, :m1], 2: G_all[m1:, m1:], 'all': G_all}

# Uniform weights on each end-point sample.
n_final = data_final.shape[0]
a = torch.ones(m1).div(m1).to(dtype).to(device)
b = torch.ones(n_final).div(n_final).to(dtype).to(device)

weights = {1: lda, 2: lda}
bary, _ = solve_apgd(C, G, {1: a, 2: b}, max_itr, weights, case="bal")

# Compare the barycenter with uniform weights on the held-out sample through
# the "rbf" Gram matrix: sqrt(vec^T G vec).
n_pred = data_tpredict.shape[0]
gt = torch.ones(n_pred).div(n_pred).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
quad_form = torch.dot(torch.mv(G, vec), vec)
val_chosen = quad_form.sqrt().item()
logger.info(f"UOT-MMD, {t}, {val_chosen}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
2023-06-27 22:20:17; , Method, tstep, MMD (lower is better) | ||
2023-06-27 22:20:17; , MMD, 1, 0.37495047828375055 | ||
2023-06-27 22:20:53; , KL-UOT, 1, 0.3906461023303901 | ||
2023-06-27 22:21:47; , UOT-MMD, 1, 0.33371626328939424 | ||
2023-06-27 22:21:53; , Method, tstep, MMD (lower is better) | ||
2023-06-27 22:21:53; , MMD, 2, 0.19001395396118076 | ||
2023-06-27 22:22:57; , KL-UOT, 2, 0.18436960477258085 | ||
2023-06-27 22:24:35; , UOT-MMD, 2, 0.17922341489550467 | ||
2023-06-27 22:24:42; , Method, tstep, MMD (lower is better) | ||
2023-06-27 22:24:42; , MMD, 3, 0.12121269628920549 | ||
2023-06-27 22:25:30; , KL-UOT, 3, 0.13796316453091137 | ||
2023-06-27 22:26:41; , UOT-MMD, 3, 0.1164323020279007 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
python mmd.py --t_pred 1 --save_as results | ||
python kluot.py --t_pred 1 --best_lda 10 --best_hp 0.01 --save_as results | ||
python proposed.py --t_pred 1 --best_lda 1 --best_hp 0.1 --save_as results | ||
|
||
python mmd.py --t_pred 2 --save_as results | ||
python kluot.py --t_pred 2 --best_lda 1 --best_hp 0.1 --save_as results | ||
python proposed.py --t_pred 2 --best_lda 1 --save_as results | ||
|
||
python mmd.py --t_pred 3 --save_as results | ||
python kluot.py --t_pred 3 --best_lda 1 --best_hp 0.1 --save_as results | ||
python proposed.py --t_pred 3 --best_lda 1 --save_as results |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.