Commit a49170e

Add notebook tools
1 parent 10872d0 commit a49170e

2 files changed: +71 -3 lines changed

bootstrap/run.py

Lines changed: 4 additions & 3 deletions
@@ -53,7 +53,7 @@ def init_logs_options_files(exp_dir, resume=None):
    Logger(exp_dir, name=logs_name)


-def run(path_opts=None):
+def run(path_opts=None, train_engine=True, eval_engine=True):
    # the first call to Options() loads the options yaml file from the --path_opts command line argument if path_opts=None
    Options(path_opts)

@@ -106,19 +106,20 @@ def run(path_opts=None):

        # if no training split, evaluate the model on the evaluation split
        # (example: $ python main.py --dataset.train_split --dataset.eval_split test)
-        if not Options()['dataset']['train_split']:
+        if eval_engine and not Options()['dataset']['train_split']:
            engine.eval()

        # optimize the model on the training split for several epochs
        # (example: $ python main.py --dataset.train_split train)
        # if there is an evaluation split, evaluate the model after each epoch
        # (example: $ python main.py --dataset.train_split train --dataset.eval_split val)
-        if Options()['dataset']['train_split']:
+        if train_engine and Options()['dataset']['train_split']:
            engine.train()

    finally:
        # write profiling results, if enabled
        process_profiler(profiler)
+    return engine


def activate_debugger():
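
With these flags, run() can be called programmatically to build an engine without entering either loop. A minimal sketch, assuming an existing experiment directory (the path logs/myexp/ below is hypothetical, and Options may still read overrides from the command line):

from bootstrap.run import run

# load options from the yaml file and build the engine (dataset, model, optimizer),
# but skip both the training and the evaluation loops
engine = run(path_opts="logs/myexp/options.yaml",
             train_engine=False,
             eval_engine=False)

# run() now returns the engine, so it can be inspected directly
print(engine.model)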

bootstrap/tools.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+import os
+import sys
+import torch
+from bootstrap.lib.logger import Logger
+from bootstrap.lib.options import Options
+from bootstrap.run import run
+
+
+def reset_instance():
+    Options._Options__instance = None  # clear the name-mangled Options singleton
+    Options.__instance = None
+    Logger._Logger__instance = None  # assumes Logger also keeps a name-mangled __instance singleton
+    Logger.perf_memory = {}
+    sys.argv = [sys.argv[0]]  # reset command line args
+
+
+def get_engine(
+    path_experiment, weights="best_eval_epoch.accuracy_top1", logs_name="tools",
+):
+    reset_instance()  # start from a clean Options/Logger state before loading the experiment
+    path_yaml = os.path.join(path_experiment, "options.yaml")
+    opt = Options(path_yaml)
+    if weights is not None:
+        opt["exp.resume"] = weights
+    opt["exp.dir"] = path_experiment
+    opt["misc.logs_name"] = logs_name
+    engine = run(train_engine=False, eval_engine=False)  # build the engine without training or evaluating
+    return engine
+
+
+def item_to_batch(engine, split, item, prepare_batch=True):
+    batch = engine.dataset[split].collate_fn([item])  # collate a single item into a batch of size 1
+    if prepare_batch:
+        batch = engine.model.prepare_batch(batch)
+    return batch
+
+
+def apply_item(engine, item, split="eval"):
+    # item = engine.dataset[split][idx]
+    engine.model.eval()
+    batch = item_to_batch(engine, split, item)
+    with torch.no_grad():
+        out = engine.model.network(batch)
+    return out
+
+
+def load_model_state(engine, path):
+    """
+    engine: bootstrap Engine
+    path: path to model weights
+    """
+    model_state = torch.load(path)
+    engine.model.load_state_dict(model_state)
+
+
+def load_epoch(
+    engine, epoch, exp_dir,
+):
+    path = os.path.join(exp_dir, f"ckpt_epoch_{epoch}_model.pth.tar")
+    print(path)
+    load_model_state(engine, path)
+
+
+def load_last(engine, exp_dir):
+    path = os.path.join(exp_dir, "ckpt_last_model.pth.tar")
+    load_model_state(engine, path)
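
A minimal notebook-style sketch of how these helpers fit together (the experiment directory logs/myexp/ and the epoch number are hypothetical, and it assumes the experiment defines an "eval" split):

from bootstrap.tools import get_engine, apply_item, load_epoch

# rebuild the engine of a finished experiment without training or evaluating
engine = get_engine("logs/myexp")

# run a single evaluation item through the network
item = engine.dataset["eval"][0]
out = apply_item(engine, item)

# swap in the weights saved at a given epoch
load_epoch(engine, 10, "logs/myexp")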

0 commit comments
