-
Notifications
You must be signed in to change notification settings - Fork 19
/
cache.py
48 lines (36 loc) · 1.02 KB
/
cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from tqdm import tqdm
import sidechainnet as scn
from ddpm_proteins.utils import get_msa_attention_embeddings, get_msa_transformer
# sidechainnet data
# Options are named once and unpacked into scn.load; the resulting `data`
# object is indexed with 'train' by the caching loop further down.
_SCN_LOAD_KWARGS = {
    'casp_version': 12,
    'thinning': 30,
    'with_pytorch': 'dataloaders',
    'batch_size': 1,
    'dynamic_batching': False,
}
data = scn.load(**_SCN_LOAD_KWARGS)

# constants
# Batches whose `seqs.shape[1]` exceeds this value are skipped when caching.
LENGTH_THRES = 256
# function for fetching MSAs, fill-in depending on your setup
def fetch_msas_fn(aa_str: str) -> list:
    """
    Return the MSAs for a protein, as a list of strings.

    This is a placeholder: fill it in for your own setup so that, given a
    protein as an amino acid string, it returns that protein's MSAs.

    By default it returns an empty list, in which case only the primary
    sequence is fed into MSA Transformer.

    Args:
        aa_str: the protein, as an amino acid string.

    Returns:
        A list of MSA strings (empty in this default implementation).
    """
    # A fresh list on every call, so callers may mutate the result safely.
    return []
# caching loop
# Walks the training split and calls get_msa_attention_embeddings on every
# batch that passes the length filter. The return value is discarded, so the
# call is made only for whatever side effect the helper has
# (presumably writing the embeddings to a cache - confirm in ddpm_proteins.utils).
model, batch_converter = get_msa_transformer()

for batch in tqdm(data['train']):
    # Skip batches whose second `seqs` dimension exceeds the threshold
    # (presumably the sequence length - TODO confirm the tensor layout).
    if batch.seqs.shape[1] > LENGTH_THRES:
        continue

    pids = batch.pids
    # Collapse the last dimension to indices
    # (assumes `seqs` is one-hot over amino acids - TODO confirm).
    seqs = batch.seqs.argmax(dim = -1)

    _ = get_msa_attention_embeddings(
        model,
        batch_converter,
        seqs,
        pids,  # reuse the local instead of re-reading batch.pids
        fetch_msas_fn
    )

print('caching complete')