-
Notifications
You must be signed in to change notification settings - Fork 17
/
office_run.py
78 lines (69 loc) · 2.31 KB
/
office_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from dataset import * # init_dataset
from model import *
from init_config import *
from easydict import EasyDict as edict
import sys
import trainer
import time, datetime
import copy
import numpy as np
import random
import importlib
import resource
def nofile_soft_target(soft, hard, minimum=2048):
    """Return the RLIMIT_NOFILE soft limit to request so at least ``minimum`` files may be open.

    Args:
        soft: current soft limit (may be ``resource.RLIM_INFINITY``).
        hard: current hard limit (may be ``resource.RLIM_INFINITY``).
        minimum: desired lower bound for the soft limit.

    Returns:
        ``soft`` unchanged when it is already unlimited or at least ``minimum``;
        otherwise ``minimum`` capped at ``hard``, because an unprivileged process
        may not set its soft limit above the hard limit.
    """
    unlimited = resource.RLIM_INFINITY
    # Check "unlimited" first: on some platforms RLIM_INFINITY is a negative
    # sentinel, so a plain max() would wrongly replace it with a finite value.
    if soft == unlimited:
        return soft
    ceiling = minimum if hard == unlimited else min(minimum, hard)
    return max(soft, ceiling)


# Raise (never lower) the per-process open-file soft limit to at least 2048.
# NOTE(review): presumably guards against "Too many open files" from data-loader
# workers / shared tensors — confirm. `rlimit` keeps the (soft, hard) limits as
# they were before this adjustment.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
_nofile_soft = nofile_soft_target(rlimit[0], rlimit[1])
if _nofile_soft != rlimit[0]:
    resource.setrlimit(resource.RLIMIT_NOFILE, (_nofile_soft, rlimit[1]))
# Domain names available for each task; main() pairs up every two distinct
# entries of domain_list[config.task] when config.transfer_all is set.
domain_list = {
    'office': ['amazon', 'webcam', 'dslr'],
}
# Class-count layout per adaptation setting:
# (shared classes, source-private classes, total classes).
_CLASS_SPLITS = {
    'uda': (10, 10, 31),
    'osda': (10, 10, 31),
    'pda': (10, 21, 31),
}


def _apply_class_split(config):
    """Write the class-index layout for ``config.setting`` onto ``config``.

    For a setting listed in ``_CLASS_SPLITS``, sets ``cls_share``, ``cls_src``
    and ``cls_total``; an unrecognised setting leaves whatever counts ``config``
    already carries. Then derives ``num_classes``/``uk_index`` and the
    ``share_classes``/``source_classes``/``target_classes`` index lists, where
    indices run shared first, then source-private, then target-private.
    """
    split = _CLASS_SPLITS.get(config.setting)
    if split is not None:
        config.cls_share, config.cls_src, config.cls_total = split

    n_share = config.cls_share
    n_src = config.cls_src
    n_tgt_private = config.cls_total - n_share - n_src
    # NOTE(review): for 'osda' these still count the source-private classes
    # even though source_classes drops them below — confirm this is intended.
    config.num_classes = n_share + n_src
    config.uk_index = n_share + n_src

    shared = list(range(n_share))
    if config.setting == 'osda':
        # Open-set: the source domain has no private classes.
        src_private = []
    else:
        src_private = [n_share + k for k in range(n_src)]
    tgt_private = [n_share + n_src + k for k in range(n_tgt_private)]

    config.share_classes = shared
    config.source_classes = shared + src_private
    config.target_classes = shared + tgt_private


def main():
    """Configure cuDNN and the class layout, then train one or all transfer pairs.

    Loads the config from ``config/office.yml`` (``sys.argv`` is passed through
    to ``init_config``) and imports the module
    ``trainer.<config.trainer><config.version>_trainer``. When
    ``config.transfer_all`` is false, trains once with the config as given;
    otherwise trains on every ordered (source, target) pair of distinct domains
    in ``domain_list[config.task]``.
    """
    cudnn.enabled = True
    # Let cuDNN autotune its convolution algorithms.
    cudnn.benchmark = True

    config, writer = init_config("config/office.yml", sys.argv)
    trainer_module = importlib.import_module(
        'trainer.{}{}_trainer'.format(config.trainer, config.version)
    )
    _apply_class_split(config)

    if not config.transfer_all:
        runner = trainer_module.Trainer(config, writer)
        runner.train()
        return

    domains = domain_list[config.task]
    transfer_list = [(src, tgt) for src in domains for tgt in domains if src != tgt]
    print(transfer_list)
    for src, tgt in transfer_list:
        print('{}-->{}'.format(src, tgt))
        # The same config object (and writer) is reused for every run; only the
        # source/target domains are rewritten before each Trainer is built.
        config.source = src
        config.target = tgt
        runner = trainer_module.Trainer(config, writer)
        runner.train()


if __name__ == "__main__":
    main()