run_sweep.py
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
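"""Launch a single hyperparameter-sweep run of train.py under poprun.

Sweep hyperparameters are taken from the command line, the cluster
configuration (hosts, V-IPU server and partition) from environment
variables, and the resulting poprun + train.py command is printed and
executed.

Example invocation (placeholder values):

    python3 run_sweep.py --config <config> --num_epochs <N> --init_lr <lr> \
        --end_lr_ratio <r> --warmup_epochs <N> --bn_momentum <m> \
        --label_smoothing <s> --opt_momentum <m>
"""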
import argparse
import os
import sys
import subprocess


def add_poprun_arguments(parser):
    # shared application arguments
    parser.add_argument("--config", default=None, type=str, required=True)
    parser.add_argument("--num_epochs", default=None, type=int, required=True)
    parser.add_argument("--init_lr", default=None, type=float, required=True)
    parser.add_argument("--end_lr_ratio", default=None, type=float, required=True)
    parser.add_argument("--warmup_epochs", default=None, type=int, required=True)
    parser.add_argument("--bn_momentum", default=None, type=float, required=True)
    parser.add_argument("--label_smoothing", default=None, type=float, required=True)
    parser.add_argument("--opt_momentum", default=None, type=float, required=True)
    # SGD specific
    parser.add_argument("--l2", default=None, type=float, required=False)
    # LARS specific
    parser.add_argument("--lars_weight_decay", default=None, type=float, required=False)
    parser.add_argument("--lars_eeta", default=None, type=float, required=False)
    return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser = add_poprun_arguments(parser)
    args = parser.parse_args()

    hosts = os.environ.get("POPRUN_HOSTS")
    vipu_host = os.environ.get("IPUOF_VIPU_API_HOST")
    vipu_partition = os.environ.get("IPUOF_VIPU_API_PARTITION_ID")
    if None in [hosts, vipu_host, vipu_partition]:
        raise ValueError(
            f"The following environment variables must be defined: "
            f"POPRUN_HOSTS={hosts}\n"
            f"IPUOF_VIPU_API_HOST={vipu_host}\n"
            f"IPUOF_VIPU_API_PARTITION_ID={vipu_partition}"
        )

    user = os.environ["USER"]
    exec_cache = os.environ.get("TF_POPLAR_EXEC_CACHE") or os.path.join("/home", user, "exec_cache")
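
    # poprun launch configuration: 64 single-IPU replicas with one process
    # (instance) per replica, spread over the hosts listed in POPRUN_HOSTS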
    poprun_command = [
        "poprun",
        "-vv",
        "--host",
        hosts,
        "--only-output-from-instance",
        "0",
        "--mpi-global-args",
        "--mca oob_tcp_if_include eno1 --mca btl_tcp_if_include eno1",
        "--update-partition",
        "yes",
        "--reset-partition",
        "no",
        "--vipu-server-timeout",
        "600",
        "--vipu-server-host",
        vipu_host,
        "--vipu-partition",
        vipu_partition,
        "--executable-cache-path",
        exec_cache,
        "--num-instances",
        64,
        "--num-replicas",
        64,
        "--ipus-per-replica",
        1,
    ]
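
    # base train.py invocation; sweep-specific hyperparameters are appended below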
    training_command = [
        *poprun_command,
        "python3",
        "train.py",
        "--config",
        args.config,
        "--num-epochs",
        args.num_epochs,
        "--wandb",
        "True",
        "--target-accuracy",
        0.759,
        "--ckpts-per-epoch",
        1,
        "--first-ckpt-epoch",
        0,
        "--sweep",
        "True",
    ]
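
    # the remaining options are nested parameter groups passed to train.py as
    # dict-style strings (e.g. '{"momentum": 0.9}'), assembled by hand below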
    # label smoothing
    training_command += ["--label-smoothing", args.label_smoothing]

    # l2 regularization
    training_command += ["--l2-regularization", args.l2] if args.l2 is not None else []

    # norm layer
    training_command += ["--norm-layer", "{" + f'"name": "custom_batch_norm", "momentum": {args.bn_momentum}' + "}"]

    # lr schedule params
    lr_schedule_params = [
        f'"initial_learning_rate": {args.init_lr}',
        f'"end_learning_rate_ratio": {args.end_lr_ratio}',
        f'"epochs_to_total_decay": {args.num_epochs - args.warmup_epochs}',
        '"power": 2',
    ]
    training_command += ["--lr-schedule-params", "{" + ",".join(lr_schedule_params) + "}"]

    # warmup params
    lr_warmup_params = [
        '"warmup_mode": "shift"',
        f'"warmup_epochs": {args.warmup_epochs}',
    ]
    training_command += ["--lr-warmup-params", "{" + ",".join(lr_warmup_params) + "}"]

    # optimizer params
    optimizer_params = [f'"momentum": {args.opt_momentum}']
    optimizer_params += [f'"weight_decay": {args.lars_weight_decay}'] if args.lars_weight_decay is not None else []
    optimizer_params += [f'"eeta": {args.lars_eeta}', '"epsilon": 0'] if args.lars_eeta is not None else []
    training_command += ["--optimizer-params", "{" + ",".join(optimizer_params) + "}"]
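
    # some of the collected values are ints or floats; convert everything to
    # strings before handing the command to subprocess, and print it for the log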
    training_command = [str(command) for command in training_command]
    print(" ".join(training_command))

    # run training
    p = subprocess.Popen(training_command, stderr=sys.stderr, stdout=sys.stdout)
    p.wait()