-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbaseline.py
41 lines (31 loc) · 1.22 KB
/
baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse

from fastai.vision import *

import models.cs_v2 as cs
import models.laterals
import pose
import utils
from utils import DataTime
# Lookup table from the ``--lateral`` command-line choice to the object in
# ``models.laterals`` that ``main`` forwards to ``cs.cs_learner(lateral=...)``.
# Its keys also serve as the argparse ``choices`` for ``--lateral``.
lateral_types = dict(
    add=models.laterals.conv_add_lateral,
    mul=models.laterals.conv_mul_lateral,
    attention=models.laterals.attention_lateral,
)
def main(args):
    """Build a ``cs.cs_learner`` on the LIP data and hand it to ``utils.fit_and_log``.

    The data directory is ``LIP`` two levels above this file. ``args.niter``
    is shared by the instructor, the loss and the PCKh callback. The monitored
    metric is ``'Total'`` for a single iteration, otherwise ``'Total_<niter-1>'``.

    Args:
        args: parsed command-line namespace; reads ``resnet``, ``niter``,
            ``size``, ``bs``, ``lateral`` and ``add_td_out`` and is also
            forwarded whole to ``utils.fit_and_log``.
    """
    print(args)
    backbone = pose.nets[args.resnet]
    iterations = args.niter
    instructor = cs.RecurrentInstructor(iterations)
    pckh_callback = partial(pose.Pckh, niter=iterations)
    criterion = pose.RecurrentLoss(iterations)

    data_dir = Path(__file__).resolve().parent.parent.joinpath('LIP')
    data = pose.get_data(data_dir, args.size, bs=args.bs)

    lateral_impl = lateral_types[args.lateral]
    learner_options = dict(
        td_c=16,
        pretrained=False,
        embedding=None,
        lateral=lateral_impl,
        add_td_out=args.add_td_out,
        loss_func=criterion,
        callback_fns=[pckh_callback, DataTime],
    )
    learner = cs.cs_learner(data, backbone, instructor, **learner_options)

    # With several iterations, track the metric of the last one (0-based suffix).
    if iterations > 1:
        monitored = f'Total_{iterations - 1}'
    else:
        monitored = 'Total'
    utils.fit_and_log(learner, args, monitored)
def positive_int(text):
    """Parse *text* as an ``int`` and reject values below 1.

    Used as the argparse ``type`` for ``--niter``. ``main`` passes that value
    to ``cs.RecurrentInstructor``, ``pose.RecurrentLoss`` and ``pose.Pckh``,
    and derives the monitored metric name from it. Zero or negative counts
    are therefore rejected at parse time instead of reaching those objects.

    Raises:
        ValueError: *text* is not an integer literal.
        argparse.ArgumentTypeError: the parsed value is less than 1.
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f'must be a positive integer, got {value}')
    return value


if __name__ == '__main__':
    parser = utils.basic_train_parser()
    # The int default (1) is not passed through ``type``; argparse only
    # converts string defaults, so it is used as-is.
    parser.add_argument('-n', '--niter', default=1, type=positive_int,
                        help='iteration count for the recurrent instructor/loss (>= 1)')
    parser.add_argument('--lateral', choices=lateral_types, default='add',
                        help='lateral connection type from models.laterals')
    parser.add_argument('--add-td-out', action='store_true',
                        help='forwarded to cs.cs_learner as add_td_out')
    main(parser.parse_args())