-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpretrain.py
56 lines (42 loc) · 1.48 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import sys
import math
import pprint
import torch
import torch_geometric.data
from torchvision import datasets
from torchdrug import core, models
from torchdrug.utils import comm
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import util
from s3f import dataset, model, task, gvp
def train_and_validate(cfg, solver, num_checkpoints=50):
    """Train ``solver`` in rounds, validating after each, then restore the best checkpoint.

    Training is split into at most ``num_checkpoints`` rounds of
    ``ceil(cfg.train.num_epoch / num_checkpoints)`` epochs each (the last round
    may be shorter). After every round the solver is saved to
    ``model_epoch_<solver.epoch>.pth`` and scored on the ``"valid"`` split with
    the metric named by ``cfg.metric``. When all rounds finish, the checkpoint
    with the highest validation score is loaded back into ``solver``.

    Parameters:
        cfg: config object. Reads ``cfg.train.num_epoch``, ``cfg.train.copy()``
            (a mapping of keyword arguments forwarded to ``solver.train``, with
            ``num_epoch`` overridden per round) and ``cfg.metric`` (key into the
            mapping returned by ``solver.evaluate``).
        solver: engine exposing ``train(**kwargs)``, ``save(path)``,
            ``evaluate(split)``, ``load(path)`` and an ``epoch`` attribute.
        num_checkpoints (int, optional): maximum number of train/validate
            rounds. Defaults to 50.

    Returns:
        The same ``solver`` object, including when ``cfg.train.num_epoch`` is 0.

    Raises:
        ValueError: if ``num_checkpoints`` is less than 1.
    """
    if num_checkpoints < 1:
        raise ValueError("num_checkpoints must be >= 1, got %s" % num_checkpoints)

    total_epoch = cfg.train.num_epoch
    if total_epoch == 0:
        # Nothing to train. Return the solver so callers get a consistent
        # return value on every path.
        return solver

    step = math.ceil(total_epoch / num_checkpoints)
    best_result = float("-inf")
    best_epoch = -1

    for start in range(0, total_epoch, step):
        kwargs = cfg.train.copy()
        # The final round may be shorter, so the epochs sum exactly to total_epoch.
        kwargs["num_epoch"] = min(step, total_epoch - start)
        solver.train(**kwargs)
        solver.save("model_epoch_%d.pth" % solver.epoch)
        metric = solver.evaluate("valid")
        result = metric[cfg.metric]
        if result > best_result:
            best_result = result
            best_epoch = solver.epoch

    # If no round ever improved on -inf (e.g. the metric was NaN every time),
    # there is no best checkpoint. Keep the current weights rather than
    # loading the never-saved "model_epoch_-1.pth".
    if best_epoch >= 0:
        solver.load("model_epoch_%d.pth" % best_epoch)
    return solver
def main():
    """Run pretraining from the command line.

    Parses CLI arguments and loads the config file they name, substituting
    the CLI-provided variables. Seeds torch with ``args.seed`` offset by the
    process rank, then logs the config on rank 0. Finally it builds the
    dataset and solver from the config and hands them to
    ``train_and_validate``.
    """
    # Named ``cli_vars`` / ``train_dataset`` rather than ``vars`` / ``dataset``
    # so they shadow neither the ``vars`` builtin nor the ``dataset`` module
    # imported from ``s3f`` at the top of the file.
    args, cli_vars = util.parse_args()
    cfg = util.load_config(args.config, context=cli_vars)
    # The returned path was never used; the call is kept for whatever setup it
    # performs (NOTE(review): presumably creates/enters the run directory).
    util.create_working_directory(cfg)

    # Offset by rank so each process gets a distinct but reproducible seed.
    torch.manual_seed(args.seed + comm.get_rank())

    logger = util.get_root_logger()
    # Only the rank-0 process logs the config, to avoid duplicated output.
    if comm.get_rank() == 0:
        logger.warning("Config file: %s" % args.config)
        logger.warning(pprint.pformat(cfg))

    train_dataset = core.Configurable.load_config_dict(cfg.dataset)
    solver = util.build_solver(cfg, train_dataset)
    train_and_validate(cfg, solver)


if __name__ == "__main__":
    main()