Skip to content

Commit

Permalink
First Commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
Piyushi-0 committed Jul 10, 2023
1 parent 5a2e607 commit b589b05
Show file tree
Hide file tree
Showing 60 changed files with 4,968 additions and 2 deletions.
Binary file added .DS_Store
Binary file not shown.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
# MMD-reg-OT
MMD regularized OT.
# MMD-OT:
#### Algorithms
- [Code for solving the MMD-OT plan using Accelerated PGD](./ot_mmd/mmdot.py)
- [Code for computing a batch of MMD-OT problems parallely](./ot_mmd/b_mmdot.py)
- [Code for solving the MMD-OT barycenter problem using Accelerated PGD](./ot_mmd/barycenter.py)
#### [Examples](./examples)
- [OT plan between Gaussians](./examples/synthetic/OTplan.ipynb)
- [Barycenter between Gaussians](./examples/synthetic/barycenter_with_imq.ipynb)
2 changes: 2 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from mmdot import *
from utils import *
Binary file added examples/.DS_Store
Binary file not shown.
Binary file added examples/barycenter_ScRNA/data/EB_t0.pickle
Binary file not shown.
Binary file added examples/barycenter_ScRNA/data/EB_t1.pickle
Binary file not shown.
Binary file added examples/barycenter_ScRNA/data/EB_t2.pickle
Binary file not shown.
Binary file added examples/barycenter_ScRNA/data/EB_t3.pickle
Binary file not shown.
Binary file added examples/barycenter_ScRNA/data/EB_t4.pickle
Binary file not shown.
89 changes: 89 additions & 0 deletions examples/barycenter_ScRNA/kluot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from ot_mmd.utils import createLogHandler, get_t, get_dist, get_G
import os
import argparse
import joblib
import torch
from kluot_bary import solve_md

parser = argparse.ArgumentParser(description="_")
parser.add_argument("--t_pred", required=True, type=int)
parser.add_argument("--best_lda", type=float, default=None)
parser.add_argument("--best_hp", type=float, default=None)
parser.add_argument("--save_as", default="")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available else "cpu")
dtype = torch.float64
t_predict = args.t_pred
max_itr = 1000

logger = createLogHandler(f"{args.save_as}.csv", str(os.getpid()))

if args.best_lda is None:
valt_predict = list(set([1, 2, 3]).symmetric_difference(set([t_predict])))
best_score = torch.inf
val = {}
for lda in [10, 1e-1, 1]:
val[lda] = {}
for hp in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]:
val[lda][hp] = []
for t in valt_predict:
init_tstep = t-1
final_tstep = t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

data_all = torch.vstack([data_init, data_final])
C = {1: get_dist(x=data_init, y=data_all, p=1),
2: get_dist(x=data_final, y=data_all, p=1)}

a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

bary, _ = solve_md({1: a, 2: b}, C, {1: lda, 2: lda}, max_itr, coeff_entr=hp)

gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
val[lda][hp].append(torch.mv(G, vec).dot(vec).item())

logger.info(f", {lda}, {hp}, {sum(val[lda][hp])}")
if sum(val[lda][hp]) < best_score:
best_score = sum(val[lda][hp])
best_config = {"lda": lda, "hp": hp}

lda = best_config["lda"]
hp = best_config["hp"]
else:
lda = args.best_lda
hp = args.best_hp

t = t_predict

init_tstep = t-1
final_tstep = t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

data_all = torch.vstack([data_init, data_final])
C = {1: get_dist(x=data_init, y=data_all, p=1),
2: get_dist(x=data_final, y=data_all, p=1)}

a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

bary, _ = solve_md({1: a, 2: b}, C, {1: lda, 2: lda}, max_itr, coeff_entr=hp)

gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
val_chosen = torch.sqrt(torch.mv(G, vec).dot(vec)).item()
logger.info(f"KL-UOT, {t}, {val_chosen}")
109 changes: 109 additions & 0 deletions examples/barycenter_ScRNA/kluot_bary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
from ot_mmd.utils import get_marginals


def get_kl(v1, v2, case, eps=1e-10):
v1 = v1 + eps
v2 = v2 + eps
kl = torch.sum(torch.where(v1 != 0, v1*torch.log(v1/v2), 0))
if case == "unb":
kl = kl-v1.sum() + v2.sum()
return kl

def get_entropy(alpha, case, eps=1e-10):
alpha = alpha + eps
entropy = torch.sum(torch.where(alpha != 0, alpha * torch.log(alpha), 0))
if case == "unb":
entropy = entropy - alpha.sum()
return entropy

def get_obj(alpha, bary, v, C, lda, coeff_entr, rho={1: 0.5, 2: 0.5}, case="bal"):
cost_part = rho[1]*torch.tensordot(alpha[1], C[1]) + rho[2]*torch.tensordot(alpha[2], C[2])

alpha1_1, alpha1_T1 = get_marginals(alpha[1])
alpha2_1, alpha2_T1 = get_marginals(alpha[2])

lda1_part = rho[1]*get_kl(alpha1_1, v[1], case) + rho[2]*get_kl(alpha2_1, v[2], case)
lda2_part = rho[1]*get_kl(alpha1_T1, bary, case) + rho[2]*get_kl(alpha2_T1, bary, case)

obj = cost_part + lda[1]*lda1_part + lda[2]*lda2_part
obj += coeff_entr*(rho[1]*get_entropy(alpha[1], case)+rho[2]*get_entropy(alpha[2], case))
return obj

def get_grd(alpha, bary, v, C, lda, coeff_entr, rho={1: 0.5, 2: 0.5}, case="bal"):
eps = 1e-10

alpha[1] = alpha[1] + eps
alpha[2] = alpha[2] + eps
bary = bary + eps

alpha1_1, alpha1_T1 = get_marginals(alpha[1])
alpha2_1, alpha2_T1 = get_marginals(alpha[2])

grd_bary = -lda[2]*(rho[1]*alpha1_T1 + rho[2]*alpha2_T1)/bary
grd_1 = grd_2 = 0
if rho[1]>0:
term1 = torch.log(alpha1_1)-torch.log(v[1])
term2 = torch.log(alpha1_T1)-torch.log(bary)
if case == "bal":
term1 += 1
term2 += 1
grd_1 = rho[1]*(C[1] + lda[1]*term1[:, None] + lda[2]*term2)

if rho[2]>0:
term1 = torch.log(alpha2_1)-torch.log(v[2])
term2 = torch.log(alpha2_T1)-torch.log(bary)
if case == "bal":
term1 += 1
term2 += 1
grd_2 = rho[2]*(C[2] + lda[1]*term1[:, None] + lda[2]*term2)

grd_1 += rho[1]*coeff_entr*(1+torch.log(alpha[1])) if case == "bal" else rho[1]*coeff_entr*torch.log(alpha[1])
grd_2 += rho[2]*coeff_entr*(1+torch.log(alpha[2])) if case == "bal" else rho[2]*coeff_entr*torch.log(alpha[2])

return grd_1, grd_2, grd_bary


def solve_md(v, C, lda, max_itr, coeff_entr, rho={1: 0.5, 2: 0.5}, case="bal"):

def update_vars(var, grd, case):
s = 1/torch.norm(grd, torch.inf)
var = var*torch.exp(-grd*s)
if case == "bal":
var = var/var.sum()
return var

alpha = {1: torch.ones_like(C[1])/C[1].numel(),
2: torch.ones_like(C[2])/C[2].numel()}
bary = (torch.ones(C[1].shape[1])/C[1].shape[1]).to(C[1].dtype).to(C[1].device)
obj_itr = []
bary_best = None
best_itr = None

for itr in range(max_itr):
obj_itr.append(get_obj(alpha, bary, v, C, lda, coeff_entr, rho))
if best_itr is None or obj_itr[best_itr] > obj_itr[-1]:
best_itr = itr
bary_best = bary.clone()
grd_1, grd_2, grd_bary = get_grd(alpha, bary, v, C, lda, coeff_entr, rho)
if rho[1] > 0:
try: # error triggered when optimality has been reached
alpha[1] = update_vars(alpha[1], grd_1, case)
except Exception as e:
print(e)
pass

if rho[2] > 0:
try: # error triggered when optimality has been reached
alpha[2] = update_vars(alpha[2], grd_2, case)
except Exception as e:
print(e)
pass

try: # error triggered when optimality has been reached
bary = update_vars(bary, grd_bary, case)
except Exception as e:
print(e)
pass

return bary_best, obj_itr
41 changes: 41 additions & 0 deletions examples/barycenter_ScRNA/mmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from ot_mmd.utils import createLogHandler, get_t, get_G
import os
import argparse
import joblib
import torch

parser = argparse.ArgumentParser(description="_")
parser.add_argument("--t_pred", required=True, type=int)
parser.add_argument("--save_as", default="")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available else "cpu")
dtype = torch.float64
t_predict = args.t_pred

logger = createLogHandler(f"{args.save_as}.csv", str(os.getpid()))

t = t_predict

init_tstep = t-1
final_tstep = t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

data_all = torch.vstack([data_init, data_final])

a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

bary = torch.cat([a, b])/2

gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
val_chosen = torch.sqrt(torch.mv(G, vec).dot(vec)).item()
logger.info(f"Method, tstep, MMD (lower is better)")
logger.info(f"MMD, {t}, {val_chosen}")
98 changes: 98 additions & 0 deletions examples/barycenter_ScRNA/proposed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from ot_mmd.barycenter import solve_apgd
from ot_mmd.utils import createLogHandler, get_t, get_dist, get_G
import os
import argparse
import joblib
import torch

parser = argparse.ArgumentParser(description="_")
parser.add_argument("--t_pred", required=True, type=int)
parser.add_argument("--best_lda", type=float, default=None)
parser.add_argument("--best_hp", type=float, default=None)
parser.add_argument("--save_as", default="")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available else "cpu")
dtype = torch.float64
max_itr = 1000
ktype = "imq_v2"
t_predict = args.t_pred

logger = createLogHandler(f"{args.save_as}.csv", str(os.getpid()))

if args.best_lda is None:
valt_predict = list(set([1, 2, 3]).symmetric_difference(set([t_predict])))
best_score = torch.inf
val = {}
for lda in [10, 1e-1, 1]:
val[lda] = {}
for khp in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, None]:
val[lda][khp] = []
for t in valt_predict:
init_tstep = t-1
final_tstep = t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

data_all = torch.vstack([data_init, data_final])
C = {1: get_dist(x=data_init, y=data_all, p=1),
2: get_dist(x=data_final, y=data_all, p=1)}

G_all = get_G(ktype=ktype, khp=khp, x=data_all, y=data_all)
m1 = data_init.shape[0]
G = {1: G_all[:m1, :m1], 2: G_all[m1:, m1:], 'all': G_all}

a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

bary, _ = solve_apgd(C, G, {1: a, 2: b}, max_itr, {1: lda, 2: lda}, case="bal")

gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
val[lda][khp].append(torch.mv(G, vec).dot(vec).item())

logger.info(f", {lda}, {khp}, {sum(val[lda][khp])}")
if sum(val[lda][khp]) < best_score:
best_score = sum(val[lda][khp])
best_config = {"lda": lda, "khp": khp}

lda = best_config["lda"]
khp = best_config["khp"]
else:
lda = args.best_lda
khp = args.best_hp

t = t_predict

init_tstep = t-1
final_tstep = t+1

data_tpredict = get_t(joblib.load(f"data/EB_t{t}.pickle"), device=device)

data_init = get_t(joblib.load(f"data/EB_t{init_tstep}.pickle"), device=device)
data_final = get_t(joblib.load(f"data/EB_t{final_tstep}.pickle"), device=device)

data_all = torch.vstack([data_init, data_final])
C = {1: get_dist(x=data_init, y=data_all, p=1),
2: get_dist(x=data_final, y=data_all, p=1)}

G_all = get_G(ktype=ktype, khp=khp, x=data_all, y=data_all)
m1 = data_init.shape[0]
G = {1: G_all[:m1, :m1], 2: G_all[m1:, m1:], 'all': G_all}

a = (torch.ones(data_init.shape[0])/data_init.shape[0]).to(dtype).to(device)
b = (torch.ones(data_final.shape[0])/data_final.shape[0]).to(dtype).to(device)

bary, _ = solve_apgd(C, G, {1: a, 2: b}, max_itr, {1: lda, 2: lda}, case="bal")

gt = (torch.ones(data_tpredict.shape[0])/data_tpredict.shape[0]).to(dtype).to(device)
data_cat = torch.vstack([data_tpredict, data_all])
G = get_G(ktype="rbf", x=data_cat, y=data_cat)
vec = torch.cat([gt, -bary])
val_chosen = torch.sqrt(torch.mv(G, vec).dot(vec)).item()
logger.info(f"UOT-MMD, {t}, {val_chosen}")
12 changes: 12 additions & 0 deletions examples/barycenter_ScRNA/results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
2023-06-27 22:20:17; , Method, tstep, MMD (lower is better)
2023-06-27 22:20:17; , MMD, 1, 0.37495047828375055
2023-06-27 22:20:53; , KL-UOT, 1, 0.3906461023303901
2023-06-27 22:21:47; , UOT-MMD, 1, 0.33371626328939424
2023-06-27 22:21:53; , Method, tstep, MMD (lower is better)
2023-06-27 22:21:53; , MMD, 2, 0.19001395396118076
2023-06-27 22:22:57; , KL-UOT, 2, 0.18436960477258085
2023-06-27 22:24:35; , UOT-MMD, 2, 0.17922341489550467
2023-06-27 22:24:42; , Method, tstep, MMD (lower is better)
2023-06-27 22:24:42; , MMD, 3, 0.12121269628920549
2023-06-27 22:25:30; , KL-UOT, 3, 0.13796316453091137
2023-06-27 22:26:41; , UOT-MMD, 3, 0.1164323020279007
11 changes: 11 additions & 0 deletions examples/barycenter_ScRNA/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
python mmd.py --t_pred 1 --save_as results
python kluot.py --t_pred 1 --best_lda 10 --best_hp 0.01 --save_as results
python proposed.py --t_pred 1 --best_lda 1 --best_hp 0.1 --save_as results

python mmd.py --t_pred 2 --save_as results
python kluot.py --t_pred 2 --best_lda 1 --best_hp 0.1 --save_as results
python proposed.py --t_pred 2 --best_lda 1 --save_as results

python mmd.py --t_pred 3 --save_as results
python kluot.py --t_pred 3 --best_lda 1 --best_hp 0.1 --save_as results
python proposed.py --t_pred 3 --best_lda 1 --save_as results
Binary file added examples/jumbot/.DS_Store
Binary file not shown.
Binary file added examples/jumbot/digits/.DS_Store
Binary file not shown.
Binary file added examples/jumbot/digits/kluot/.DS_Store
Binary file not shown.
Loading

0 comments on commit b589b05

Please sign in to comment.