-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
30 lines (24 loc) · 984 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# -*- coding: utf-8 -*-
# @Time : 18-8-16 4:47 PM
# @Author : zhangmr
# @File : run.py
from generator import *
from SAE import SAE
from utils import *
# Default locations that used to be hard-coded inside main(). They are
# machine-specific absolute paths; override them through main()'s keyword
# arguments when running elsewhere.
_DEFAULT_EXTRACTOR_PATH = "/home/zhangmr/project/my/zsl/teamwork/resource/simple_feature_extractor.pb"
_DEFAULT_TRAIN_DIR = "/data/mydata/zeroshot/DatasetA_train_20180813"
_DEFAULT_TEST_DIR = "/data/mydata/zeroshot/DatasetA_test_20180813"
_DEFAULT_SUBMIT_PATH = "/home/zhangmr/project/my/zsl/teamwork/resource/submit.txt"
# Immutable default; converted to a list before being handed to SAE, matching
# the list the original (commented-out) training code built.
_DEFAULT_LAMBDAS = (0.9, 0.95, 1, 1.05, 1.1)


def main(
    *,
    extractor_path=_DEFAULT_EXTRACTOR_PATH,
    test_dir=_DEFAULT_TEST_DIR,
    submit_path=_DEFAULT_SUBMIT_PATH,
    train=False,
    train_dir=_DEFAULT_TRAIN_DIR,
    lambdas=_DEFAULT_LAMBDAS,
):
    """Run SAE inference on the test set and write a submission file.

    Called with no arguments this behaves exactly as the original script:
    it loads the feature extractor, builds a non-training ``ImgGenerator``
    over ``test_dir``, creates an ``SAE`` with ``load_weights=True``,
    predicts on the generator's ``xf`` and passes the generator's ``x``
    together with the predictions to ``submit``.

    Args:
        extractor_path: Path handed to ``FeatureExtractor``.
        test_dir: Directory handed to the test ``ImgGenerator``.
        submit_path: Output path handed to ``submit``.
        train: When true, first run the training step that was previously
            commented out (``SAE(..., load_weights=False).train()``).
        train_dir: Directory handed to the training ``ImgGenerator``
            (only used when ``train`` is true).
        lambdas: Values stored under ``params["lambdas"]`` for training
            (only used when ``train`` is true).
    """
    feature_extractor = FeatureExtractor(extractor_path)

    if train:
        # Opt-in version of the previously commented-out training code.
        # NOTE(review): presumably SAE.train() persists weights that the
        # load_weights=True model below reads back — not visible from this
        # file, confirm in SAE.
        train_generator = ImgGenerator(train_dir, feature_extractor)
        params = {"lambdas": list(lambdas)}
        trainer = SAE(train_generator, params, load_weights=False)
        trainer.train()

    # Inference on the test split.
    test_generator = ImgGenerator(test_dir, feature_extractor, is_train=False)
    imgs = test_generator.x  # presumably image identifiers — TODO confirm
    features = test_generator.xf  # input fed to SAE.predict
    model = SAE(test_generator, None, load_weights=True)
    # predict() returns a pair; only the first element is used here.
    predictions, _ = model.predict(features)
    submit(imgs, predictions, submit_path)


if __name__ == "__main__":
    main()