-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_cld.py
55 lines (43 loc) · 1.6 KB
/
run_cld.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from torchpgm.model import *
from torchpgm.layers import *
from cld.postprocessing import *
from cld.criterion import *
from cld.walker import *
from .utils import *
from .config import *
# --- Run configuration ---------------------------------------------------------
# Device the model is moved to before its weights are loaded; lit_to_pam also
# places the tensors it builds on this device.
device = "cuda"
# Experiment folder; trained checkpoints are read from its "weights" subfolder.
folder = f"{DATA}/vink"

# Number of units in the Gaussian hidden layer.
Nh = 200
# Length that PAM strings are right-padded to (with "N") by lit_to_pam.
Npam = 5
# Epoch suffix of the checkpoint file to load.
best_epoch = 90

# Passed as `q` and `N` to the visible OneHotLayer.
q_pi = 21
N_pi = 736

# Model / checkpoint base name (the gamma value is baked into the trained name).
model_full_name = f"rbmssl2_pid_h{Nh}_npam{Npam}_gamma8.819763977946266"
def lit_to_pam(s):
    """Encode a PAM string as a batched float tensor on ``device``.

    The string is right-padded with ``"N"`` to ``Npam`` characters (strings
    that are already at least that long are used unchanged). Each character is
    looked up in ``NAd_idx`` and the per-character encodings are concatenated
    into a single flat list.

    Args:
        s: PAM literal, e.g. ``"NGG"``.

    Returns:
        A float tensor with a leading batch dimension of size 1, on ``device``.
        Assumes each ``NAd_idx`` entry is a flat sequence of numbers, giving a
        ``(1, L)`` tensor — TODO confirm against ``NAd_idx``.
    """
    padded = s.ljust(Npam, "N")
    # Concatenate the encodings of every character, in order.
    flat = [value for base in padded for value in NAd_idx[base]]
    return torch.tensor(flat).float()[None].to(device)
# --- Model construction --------------------------------------------------------
# Visible one-hot layer built from N_pi and q_pi (presumably N_pi positions with
# q_pi categories each — TODO confirm against OneHotLayer).
pi = OneHotLayer(None, N=N_pi, q=q_pi, name="pi")
# Gaussian hidden layer with Nh units.
h = GaussianLayer(N=Nh, name="hidden")
# Classifier sized (Nh, Npam * 4); presumably 4 input values per PAM nucleotide,
# matching the lit_to_pam encoding — confirm.
classifier = PAM_classifier(Nh, Npam * 4)
# Single visible-hidden edge, kept sorted (presumably the model expects a
# canonical edge order).
E = [(pi.name, h.name)]
E.sort()
model_rbm = PI_RBM_SSL(classifier, layers={pi.name: pi, h.name: h}, edges=E, name=model_full_name)
# Weights are loaded while the model sits on `device`; the model is then moved
# to CPU in eval mode before running AIS.
model_rbm = model_rbm.to(device)
model_rbm.load(f"{folder}/weights/{model_full_name}_{best_epoch}.h5")
model_rbm.eval()
model_rbm = model_rbm.to("cpu")
model_rbm.ais()
# NOTE(review): re-applied after ais(), presumably in case ais() changes the
# model's device — confirm, and drop if redundant.
model_rbm = model_rbm.to("cpu")

# --- Walk set-up ----------------------------------------------------------------
# Load onto CPU so these tensors share a device with model_rbm, `target` and `T`
# (all explicitly on CPU), whatever device they were saved from.
x_cas9 = torch.load(f"{DATA}/x_cas9.pt", map_location="cpu")
zero_idx = torch.load(f"{DATA}/zero_idx.pt", map_location="cpu")
kept_idx = torch.load(f"{DATA}/kept_idx.pt", map_location="cpu")
target = lit_to_pam("NGG")

# Objective and constraints each wrap a criterion with a ConstantPostprocessor;
# the meaning of its arguments (presumably lower/upper bounds) is defined in
# cld.postprocessing — confirm.
objective = RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(0))
# NOTE(review): `nnz_idx` is not defined in this script. It must come from one of
# the star imports (e.g. `.utils` / `.config`), otherwise this line raises
# NameError — confirm.
constraints = [
    SimCriterion(x_cas9, nnz_idx, postprocessing=ConstantPostprocessor(None, 80)),
    RbmCriterion(model_rbm, postprocessing=ConstantPostprocessor(-0.5, None)),
]
# Constant 0.3 for every kept index, shaped (1, 1, len(kept_idx)).
T = 0.3 * torch.ones(1, 1, len(kept_idx))
# The starting point is the Cas9 encoding reshaped to q_pi rows (previously the
# literal 21, which equals q_pi) and cloned so the walk cannot modify x_cas9.
walker = Walker(
    x_cas9.view(q_pi, -1).clone(),
    model_rbm,
    objective,
    constraints,
    zero_idx,
    gamma=1,
    n=1,
    a=1,
    c=1e-2,
    eps=1,
    target=target.cpu(),
    T=T,
    kept_idx=kept_idx,
)