forked from evolutionaryscale/esm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
local_generate.py
171 lines (153 loc) · 6.17 KB
/
local_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from esm.models.esm3 import ESM3
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
GenerationConfig,
LogitsConfig,
LogitsOutput,
SamplingConfig,
SamplingTrackConfig,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation
def get_sample_protein() -> ESMProtein:
    """Build the sample protein shared by the examples in this file.

    Loads entry ``1utn`` through ``ProteinChain.from_rcsb``, converts the chain
    into an ``ESMProtein`` and attaches two function annotations to it.

    Returns:
        The annotated ``ESMProtein``.
    """
    chain = ProteinChain.from_rcsb("1utn")
    sample = ESMProtein.from_protein_chain(chain)
    # Peptidase S1A, chymotrypsin family: https://www.ebi.ac.uk/interpro/structure/PDB/1utn/
    peptidase = FunctionAnnotation(label="peptidase", start=100, end=114)
    chymotrypsin = FunctionAnnotation(label="chymotrypsin", start=190, end=202)
    sample.function_annotations = [peptidase, chymotrypsin]
    return sample
def main(client: ESM3InferenceClient):
    """Run every ESM3 usage example in this file against ``client``.

    The examples run in order: single-step decoding, generation from a partial
    sequence, folding, inverse folding, function prediction, logits,
    chain-of-thought generation, and batch generation (including a prompt that
    is expected to come back as an ``ESMProteinError``). Several examples write
    PDB files into the current working directory.

    Args:
        client: Inference client used for all encode / decode / generate /
            logits calls.

    Raises:
        AssertionError: If a client call returns an object of an unexpected
            type.
    """
    # Single step decoding
    protein = get_sample_protein()
    protein.function_annotations = None
    protein = client.encode(protein)
    single_step_protein = client.forward_and_sample(
        protein,
        SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2)),
    )
    # Carry the encoded input sequence over to the sampled tensor before decoding.
    single_step_protein.protein_tensor.sequence = protein.sequence
    single_step_protein = client.decode(single_step_protein.protein_tensor)

    # Generate with partial sequence.
    prompt = (
        "___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLK"
        "GGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIK"
        "TKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
    )
    protein = ESMProtein(sequence=prompt)
    protein = client.generate(
        protein,
        GenerationConfig(track="sequence", num_steps=8, temperature=0.7),
    )
    assert isinstance(protein, ESMProtein), f"ESMProtein was expected but got {protein}"

    # Folding
    protein = get_sample_protein()
    sequence_length = len(protein.sequence)  # type: ignore
    # One decoding step per 16 residues. Integer floor division avoids the
    # float round-trip, and the lower bound keeps the step count positive for
    # sequences shorter than 16 residues.
    num_steps = max(1, sequence_length // 16)
    protein.coordinates = None
    protein.function_annotations = None
    protein.sasa = None
    folded_protein = client.generate(
        protein,
        GenerationConfig(track="structure", schedule="cosine", num_steps=num_steps),
    )
    # Report the object that was actually checked, not the input prompt.
    assert isinstance(
        folded_protein, ESMProtein
    ), f"ESMProtein was expected but got {folded_protein}"
    folded_protein.to_pdb("./sample_folded.pdb")

    # Inverse folding
    protein = get_sample_protein()
    protein.sequence = None
    protein.sasa = None
    protein.function_annotations = None
    inv_folded_protein = client.generate(
        protein,
        GenerationConfig(track="sequence", schedule="cosine", num_steps=num_steps),
    )
    assert isinstance(
        inv_folded_protein, ESMProtein
    ), f"ESMProtein was expected but got {inv_folded_protein}"
    inv_folded_protein.to_pdb("./sample_inv_folded.pdb")

    # Function prediction
    protein = get_sample_protein()
    protein.function_annotations = None
    protein_with_function = client.generate(
        protein,
        GenerationConfig(track="function", schedule="cosine", num_steps=num_steps),
    )
    assert isinstance(
        protein_with_function, ESMProtein
    ), f"{protein_with_function} is not an ESMProtein"

    # Logits
    protein = get_sample_protein()
    protein.coordinates = None
    protein.function_annotations = None
    protein.sasa = None
    protein_tensor = client.encode(protein)
    logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True))
    assert isinstance(
        logits_output, LogitsOutput
    ), f"LogitsOutput was expected but got {logits_output}"
    assert (
        logits_output.logits is not None and logits_output.logits.sequence is not None
    )

    # Chain of Thought (Function -> Secondary Structure -> Structure -> Sequence)
    cot_protein = get_sample_protein()
    # Mask the whole sequence; the function annotations from the sample are kept.
    cot_protein.sequence = "_" * len(cot_protein.sequence)  # type: ignore
    cot_protein.coordinates = None
    cot_protein.sasa = None
    cot_protein_tensor = client.encode(cot_protein)
    # Each track is generated on top of the tensor produced by the previous one.
    for cot_track in ["secondary_structure", "structure", "sequence"]:
        cot_protein_tensor = client.generate(
            cot_protein_tensor,
            GenerationConfig(track=cot_track, schedule="cosine", num_steps=10),
        )
        assert isinstance(
            cot_protein_tensor, ESMProteinTensor
        ), f"ESMProteinTensor was expected but got {cot_protein_tensor}"
    cot_protein = client.decode(cot_protein_tensor)
    assert isinstance(
        cot_protein, ESMProtein
    ), f"ESMProtein was expected but got {cot_protein}"
    cot_protein.to_pdb("./sample_cot.pdb")

    # Batch examples.
    # Batch generation.
    prompts = [ESMProtein(sequence=("_" * (10 + 2 * i))) for i in range(5)]
    configs = [
        GenerationConfig(track="sequence", schedule="cosine", num_steps=(i + 1))
        for i in range(5)
    ]
    proteins = client.batch_generate(prompts, configs)

    # Batch folding.
    # Take the list of proteins batch generated from last step.
    configs = [
        GenerationConfig(track="structure", schedule="cosine", num_steps=(i + 1))
        for i in range(5)
    ]
    # Generate again for the structure track.
    proteins = client.batch_generate(proteins, configs)
    # Now write sequence and structure to PDB files.
    for i, p in enumerate(proteins):
        assert isinstance(p, ESMProtein), f"ESMProtein was expected but got {p}"
        p.to_pdb(f"./batch_gen_{i}.pdb")

    # Batch generation returns ESMProteinError for specific prompts that have issues.
    prompts = [ESMProtein(sequence=("_" * (10 + 2 * i))) for i in range(5)]
    # Mock error situation. The third prompt has no masks to be sampled.
    prompts[2].sequence = "ANTVPYQ"
    configs = [
        GenerationConfig(track="sequence", schedule="cosine", num_steps=(i + 1))
        for i in range(5)
    ]
    proteins = client.batch_generate(prompts, configs)
    # Should still get results. But third result is a ESMProteinError.
    for i, p in enumerate(proteins):
        if i == 2:
            assert isinstance(
                p, ESMProteinError
            ), f"ESMProteinError was expected but got {p}"
        else:
            assert isinstance(p, ESMProtein), f"ESMProtein was expected but got {p}"
if __name__ == "__main__":
    # Load the pretrained "esm3_sm_open_v1" model and run all examples with it.
    model = ESM3.from_pretrained("esm3_sm_open_v1")
    main(model)