-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_compressor.py
85 lines (72 loc) · 3.45 KB
/
train_compressor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Any, Dict, Mapping, Tuple
from torch import nn
from torchdistill.common.constant import def_logger
from torchdistill.datasets import util
from misc.eval import evaluate_accuracy, get_eval_metric
from misc.train_util import get_eval_metrics
from misc.util import calc_compression_module_sizes, load_model
from model.franken_net import CQHybridFrankenNet
from train import train, train_main
logger = def_logger.getChild(__name__)
def _load_models(models_config: Dict[str, Any], device: str, skip_ckpt: bool) -> Tuple[CQHybridFrankenNet, nn.Module]:
    """Load the student (compressor) and teacher models described in the config.

    Args:
        models_config: Mapping that must contain "student_model" and
            "teacher_model" sub-configs, each consumable by ``load_model``.
        device: Device identifier the models are placed on (e.g. "cuda", "cpu").
        skip_ckpt: If True, skip loading the student checkpoint. The teacher
            checkpoint is always loaded (``skip_ckpt=False`` below).

    Returns:
        Tuple of ``(student_model, teacher_model)``.

    Raises:
        ValueError: If either required model config key is missing.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn a bad config into an opaque KeyError.
    if "student_model" not in models_config or "teacher_model" not in models_config:
        raise ValueError("Invalid models config")
    student_model = load_model(models_config["student_model"],
                               device,
                               skip_ckpt=skip_ckpt)
    # Teacher checkpoint is always loaded; it is the fixed reference model
    # for distillation.
    teacher_model = load_model(models_config["teacher_model"],
                               device,
                               skip_ckpt=False)
    return student_model, teacher_model
def _train_compressor(config: Mapping[str, Any], args: Any):
    """Train and evaluate a compression student model via distillation.

    Loads the student/teacher pair, logs the student's compression-module
    sizes, runs distillation training (unless ``args.test_only``), then
    evaluates the student (and optionally the teacher) on the test set.

    Args:
        config: Full experiment config with "models", "datasets", "train",
            and "test" sections.
        args: Parsed CLI namespace; reads ``device``, ``skip_ckpt``,
            ``test_only``, and optionally ``eval_teacher``.
    """
    models_config = config["models"]
    device = args.device
    student, teacher = _load_models(models_config, device, args.skip_ckpt)
    ckpt_file_path = models_config["student_model"]["ckpt"]
    # Log parameter/size breakdown of the injected compression modules.
    # NOTE(review): input size is hard-coded to a 224x224 RGB image —
    # assumes ImageNet-style inputs; confirm against dataset config.
    summary_str, _, _ = calc_compression_module_sizes(
        bnet_injected_model=student,
        device=device,
        input_size=(1, 3, 224, 224))
    logger.info(summary_str)
    datasets_config = config['datasets']
    dataset_dict = util.get_all_datasets(datasets_config)
    if not args.test_only:
        eval_metrics = get_eval_metrics(config["train"])
        train(teacher_model=teacher,
              student_model=student,
              dataset_dict=dataset_dict,
              ckpt_file_path=ckpt_file_path,
              device=device,
              train_config=config["train"],
              eval_metrics=eval_metrics,
              args=args,
              apply_aux_loss=True)
    test_config = config['test']
    test_data_loader_config = test_config['test_data_loader']
    test_data_loader = util.build_data_loader(dataset_dict[test_data_loader_config['dataset_id']],
                                              test_data_loader_config,
                                              distributed=False)
    log_freq = test_config.get('log_freq', 1000)
    metrics = test_config.get("eval_metrics")
    # Evaluation was previously commented out, so test-only runs built the
    # data loader and metrics list but evaluated nothing. Restored here.
    # getattr guards against CLI namespaces that lack --eval-teacher.
    if getattr(args, "eval_teacher", False):
        evaluate_accuracy(teacher,
                          data_loader=test_data_loader,
                          device=device,
                          device_ids=None,
                          distributed=False,
                          log_freq=log_freq,
                          title="[Teacher: ]")
    # `metrics` may be absent from the test config; treat that as "no metrics".
    for metric in (metrics or []):
        get_eval_metric(metric).eval_func(student,
                                          data_loader=test_data_loader,
                                          device=device,
                                          device_ids=None,
                                          distributed=False,
                                          log_freq=log_freq,
                                          title="[Student: ]",
                                          test_mode=True,
                                          use_hnetwork=True)
if __name__ == "__main__":
    # Delegate CLI parsing/setup to the shared entry point in train.py,
    # which dispatches to _train_compressor with the parsed config and args.
    train_main(description="Train Compressor", task="train_compressor", train_func=_train_compressor)