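"""Evaluate the best run of a wandb sweep.

Downloads the best model checkpoint of the given sweep, generates molecules
for the requested dataset, and evaluates them.

Example:
    python evaluate_sweep.py bwgidbfw pdbbind_filtered
"""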
import argparse
import os
from pathlib import Path

import torch
import wandb

from _util import get_datamodule
from diffusion_hopping.analysis.evaluate import Evaluator
from diffusion_hopping.model import DiffusionHoppingModel
from diffusion_hopping.util import disable_obabel_and_rdkit_logging
from evaluate_model import evaluate_molecules, generate_molecules


def setup_model_and_data_module(
    checkpoint, dataset_name, device="cpu", output_path=None
):
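    """Download the checkpoint artifact (if not cached) and load model and data module.

    `output_path` is accepted but not used.
    """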
    checkpoint_folder = Path("artifacts") / checkpoint.name
    if not checkpoint_folder.exists():
        checkpoint_folder = checkpoint.download()
    else:
        print("Checkpoint already downloaded")
    checkpoint_path = Path(checkpoint_folder) / "model.ckpt"
    model = DiffusionHoppingModel.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    data_module = get_datamodule(dataset_name, batch_size=32)
    return model, data_module


def main():
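    """Parse arguments, pick the best checkpoint of the sweep, then generate and evaluate."""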
    parser = argparse.ArgumentParser(
        prog="evaluate_sweep.py",
        description="Evaluate the best run of a sweep",
        epilog="Example: python evaluate_sweep.py bwgidbfw pdbbind_filtered",
    )
    parser.add_argument(
        "sweep_id",
        type=str,
        help="ID of the sweep to evaluate",
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Dataset to evaluate on",
        # choices=get_data_module_choices(),
    )
    parser.add_argument(
        "--mode",
        type=str,
        help="Mode to evaluate",
        choices=["ground_truth", "ligand_generation", "inpaint_generation", "all"],
        default="all",
    )
    parser.add_argument(
        "--only_generation",
        action="store_true",
        help="Only generate molecules, do not evaluate them",
    )
    parser.add_argument(
        "--only_evaluation",
        action="store_true",
        help="Only evaluate molecules, do not generate them",
    )
    parser.add_argument(
        "--r",
        type=int,
        help="Number of resampling steps when using inpainting",
        default=10,
    )
    parser.add_argument(
        "--j",
        type=int,
        help="Jump length when using inpainting",
        default=10,
    )
    parser.add_argument(
        "--limit_samples",
        type=int,
        help="Limit the number of samples to evaluate",
        default=None,
    )
    parser.add_argument(
        "--molecules_per_pocket",
        type=int,
        help="Number of molecules to generate per pocket",
        default=3,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size for generation",
        default=32,
    )
    args = parser.parse_args()
    mode = args.mode
    do_generation = not args.only_evaluation
    do_evaluation = not args.only_generation
    r = args.r
    j = args.j
    limit_samples = args.limit_samples
    molecules_per_pocket = args.molecules_per_pocket
    batch_size = args.batch_size
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset_name = args.dataset
    sweep_id = args.sweep_id
    api = wandb.Api()
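    # The sweep is looked up via the WANDB_PROJECT environment variable, which
    # must be set before running this script (e.g. "<entity>/<project>").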
    sweep = api.sweep(f"{os.environ['WANDB_PROJECT']}/{sweep_id}")
    output_path = Path("evaluation") / sweep.name / dataset_name
    output_path.mkdir(parents=True, exist_ok=True)
    best_run = sweep.best_run()
    artifacts = best_run.logged_artifacts()
    models = [artifact for artifact in artifacts if artifact.type == "model"]
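    # Keep checkpoints whose ModelCheckpoint callback monitored the validation
    # loss, then pick the one with the lowest score (lower loss is better).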
    loss_models = [
        model
        for model in models
        if model.metadata["ModelCheckpoint"]["monitor"] == "loss/val"
    ]
    loss_models = sorted(loss_models, key=lambda x: x.metadata["score"])
    artifact = loss_models[0]
    print(f"Using best model: {artifact.name} with score {artifact.metadata['score']}")
    disable_obabel_and_rdkit_logging()
    print("Running on artifact:", artifact.name)
    best_config = best_run.config
    print("Best config:")
    for key, value in best_config.items():
        print(f"> {key}: {value}")
    model, data_module = setup_model_and_data_module(
        artifact, dataset_name, device=device
    )
    evaluator = Evaluator(output_path)
    evaluator.load_data_module(data_module)
    evaluator.load_model(model)
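    # Generation and evaluation are kept separate so either step can be run on
    # its own via --only_generation / --only_evaluation; both use output_path.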
    if do_generation:
        generate_molecules(
            evaluator,
            output_path,
            mode=mode,
            r=r,
            j=j,
            limit_samples=limit_samples,
            molecules_per_pocket=molecules_per_pocket,
            batch_size=batch_size,
        )
    if do_evaluation:
        evaluate_molecules(evaluator, output_path, mode=mode)


if __name__ == "__main__":
    main()