forked from BioinfoMachineLearning/EnQA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
78 lines (74 loc) · 3.73 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import torch
import argparse
import numpy as np
from feature import create_feature
from network.resEGNN import resEGNN, resEGNN_with_mask, resEGNN_with_ne
# Width of the 2D feature map fed to the network for each AlphaFold-derived
# distance feature type (same values as the original if/elif chain).
_DIM2D_BY_DISTO_TYPE = {
    'disto': 25 + 64 * 5,
    'cov25': 25 + 25,
    'cov64': 25 + 64,
    'esto9': 25 + 9 * 5,
}


def _build_parser():
    """Return the command-line argument parser for this prediction script."""
    parser = argparse.ArgumentParser(description='Predict.')
    parser.add_argument('--input', type=str, required=True)
    parser.add_argument('--output', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--disto_type', type=str, required=True)
    parser.add_argument('--model_type', type=str, required=False, default='egnn')
    parser.add_argument('--alphafold_prediction', type=str, required=False, default='')
    parser.add_argument('--alphafold_feature_cache', type=str, required=False, default='')
    parser.add_argument('--af2_pdb', type=str, required=False, default='',
                        help='Optional. PDBs from AlphaFold2 prediction for index correction with input pdb')
    return parser


def _build_masked_model(model_type, dim2d):
    """Instantiate the network used when AlphaFold-derived features are present.

    Args:
        model_type: one of 'egnn', 'egnn_ne' or 'se3'.
        dim2d: number of channels in the 2D feature map.

    Returns:
        An untrained model instance (weights are loaded separately).

    Raises:
        NotImplementedError: if ``model_type`` is not recognised.
    """
    if model_type == 'egnn':
        return resEGNN_with_mask(dim2d=dim2d, dim1d=33)
    if model_type == 'egnn_ne':
        return resEGNN_with_ne(dim2d=dim2d, dim1d=33)
    if model_type == 'se3':
        # Imported lazily: this module is only needed for the 'se3' model type.
        from network.se3_model import se3_model
        return se3_model(dim2d=dim2d, dim1d=33)
    raise NotImplementedError(
        f"Unsupported --model_type {model_type!r}; expected 'egnn', 'egnn_ne' or 'se3'.")


def _load_weights(model, model_path, device):
    """Load ``state['model']`` from the checkpoint at ``model_path`` into ``model``.

    The model is then moved to ``device`` and switched to eval mode.

    Returns:
        The same ``model`` object, for convenient chaining.
    """
    # Map tensors to CPU on load so the checkpoint opens on hosts without a GPU;
    # the model is moved to the target device afterwards.
    state = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])
    model.to(device)
    model.eval()
    return model


def _prepare_inputs(f1d, f2d, pos, el, device):
    """Convert feature arrays returned by ``create_feature`` into tensors on ``device``.

    ``f1d`` and ``f2d`` gain a leading batch dimension of size 1; ``pos`` and
    every element of ``el`` are converted without reshaping.

    Returns:
        Tuple ``(f1d, f2d, pos, el)`` of tensors (``el`` is a list of tensors).
    """
    f1d = torch.tensor(f1d).unsqueeze(0).to(device)
    f2d = torch.tensor(f2d).unsqueeze(0).to(device)
    pos = torch.tensor(pos).to(device)
    el = [torch.tensor(edges).to(device) for edges in el]
    return f1d, f2d, pos, el


def _to_lddt_array(pred_lddt):
    """Return the predicted lDDT tensor as a float16 NumPy array with values > 1 set to 1.

    Only the upper bound is clipped; lower values are left unchanged.
    """
    out = pred_lddt.cpu().detach().numpy().astype(np.float16)
    out[out > 1] = 1
    return out


def main(argv=None):
    """Run lDDT prediction for one input model and save the result.

    Features are generated by ``create_feature`` (with ``args.output`` as the
    feature output path). The network chosen by ``--disto_type`` /
    ``--model_type`` is loaded from ``--model_path`` and run. The per-residue
    prediction is saved with ``np.save`` to ``<output>/<basename(input)>``
    (``np.save`` appends ``.npy``).

    Args:
        argv: optional argument list; ``None`` parses ``sys.argv``.

    Raises:
        NotImplementedError: for an unknown ``--disto_type`` or ``--model_type``.
    """
    args = _build_parser().parse_args(argv)
    # An empty string on the command line means "no feature cache".
    if args.alphafold_feature_cache == '':
        args.alphafold_feature_cache = None
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    if args.disto_type == 'base':
        f1d, f2d, pos, el = create_feature(input_model_path=args.input, output_feature_path=args.output,
                                           disto_type=args.disto_type)
        model = _load_weights(resEGNN(dim2d=41, dim1d=23), args.model_path, device)
        with torch.no_grad():
            f1d, f2d, pos, el = _prepare_inputs(f1d, f2d, pos, el, device)
            _, _, pred_lddt = model(f1d, f2d, pos, el)
    elif args.disto_type in _DIM2D_BY_DISTO_TYPE:
        f1d, f2d, pos, el, cmap = create_feature(input_model_path=args.input, output_feature_path=args.output,
                                                 disto_type=args.disto_type,
                                                 alphafold_prediction_path=args.alphafold_prediction,
                                                 alphafold_prediction_cache=args.alphafold_feature_cache,
                                                 af2_pdb=args.af2_pdb)
        model = _build_masked_model(args.model_type, _DIM2D_BY_DISTO_TYPE[args.disto_type])
        model = _load_weights(model, args.model_path, device)
        with torch.no_grad():
            f1d, f2d, pos, el = _prepare_inputs(f1d, f2d, pos, el, device)
            cmap = torch.tensor(cmap).to(device)
            _, _, pred_lddt = model(f1d, f2d, pos, el, cmap)
    else:
        supported = ', '.join(['base'] + sorted(_DIM2D_BY_DISTO_TYPE))
        raise NotImplementedError(
            f"Unsupported --disto_type {args.disto_type!r}; expected one of: {supported}.")

    out = _to_lddt_array(pred_lddt)
    np.save(os.path.join(args.output, os.path.basename(args.input)), out)


if __name__ == '__main__':
    main()