-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
63 lines (45 loc) · 2.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
""" Minimal Training Script for Associative Domain Adaptation
"""
import os
import os.path as osp
import sys
import argparse
import torch
import solver, models, data
def build_parser():
    """Build the command-line argument parser for associative domain adaptation.

    Returns:
        argparse.ArgumentParser: parser exposing general setup options
        (GPU/CPU, log dir, epochs, checkpoint, learning rate), domain
        adaptation options (source/target datasets and batch sizes), and
        associative-DA loss weights (visit, walker).
    """
    parser = argparse.ArgumentParser(description='Associative Domain Adaptation')
    # General setup
    parser.add_argument('--gpu', default=0, help='Specify GPU', type=int)
    parser.add_argument('--cpu', action='store_true', help="Use CPU Training")
    parser.add_argument('--log', default="./log/log2", help="Log directory. Will be created if non-existing")
    # FIX: default was the string "1000". argparse re-parses string defaults
    # through `type`, so it worked by accident — an int default is the correct,
    # robust form and matches the other numeric arguments below.
    parser.add_argument('--epochs', default=1000, help="Number of Epochs (Full passes through the unsupervised training set)", type=int)
    parser.add_argument('--checkpoint', default="", help="Checkpoint path")
    parser.add_argument('--learningrate', default=3e-4, type=float, help="Learning rate for Adam. Defaults to Karpathy's constant ;-)")
    # Domain Adaptation Args
    parser.add_argument('--source', default="svhn", choices=['mnist', 'svhn'], help="Source Dataset. Choose mnist or svhn")
    parser.add_argument('--target', default="mnist", choices=['mnist', 'svhn'], help="Target Dataset. Choose mnist or svhn")
    parser.add_argument('--sourcebatch', default=100, type=int, help="Batch size of Source")
    parser.add_argument('--targetbatch', default=1000, type=int, help="Batch size of Target")
    # Associative DA Hyperparams
    parser.add_argument('--visit', default=0.1, type=float, help="Visit weight")
    parser.add_argument('--walker', default=1.0, type=float, help="Walker weight")
    return parser
if __name__ == '__main__':
    args = build_parser().parse_args()

    # Model: resume from the given checkpoint file when it exists,
    # otherwise start training from a freshly initialized FrenchModel.
    if osp.exists(args.checkpoint):
        print("Resume from checkpoint file at {}".format(args.checkpoint))
        model = torch.load(args.checkpoint)
    else:
        model = models.FrenchModel()

    # Adam optimizer with amsgrad enabled; beta1 = 0.5 per the original setup.
    optim = torch.optim.Adam(model.parameters(), lr=args.learningrate,
                             betas=(0.5, 0.999), amsgrad=True)

    # Data: one loader over the source domain (training) and one over the
    # target domain (validation); both shuffled, 4 worker processes each.
    datasets = data.load_dataset(path="data", train=True)

    def _make_loader(name, batch_size):
        # Small helper so both loaders are built with identical settings.
        return torch.utils.data.DataLoader(datasets[name], batch_size=batch_size,
                                           shuffle=True, num_workers=4)

    train_loader = _make_loader(args.source, args.sourcebatch)
    val_loader = _make_loader(args.target, args.targetbatch)

    # Ensure the log directory exists, then hand off training to the solver.
    os.makedirs(args.log, exist_ok=True)
    solver.fit(model, optim, (train_loader, val_loader), n_epochs=args.epochs,
               savedir=args.log, visit_weight=args.visit, walker_weight=args.walker,
               cuda=None if args.cpu else args.gpu)