-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
153 lines (114 loc) · 5.53 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import io
import json
import base64
import logging
import numpy as np
from PIL import Image
from flask import Flask, request, jsonify, abort, make_response
import argparse
from UGATIT import UGATIT
from utils import *
# Flask application exposing the model over HTTP. The app logger is set to
# DEBUG so DEBUG-level records and above are emitted during development.
app = Flask(__name__)
app.logger.setLevel(level=logging.DEBUG)
def parse_args():
    """Build the U-GAT-IT command-line parser and return the validated arguments.

    Every option has a default, so running with no flags is valid. The parsed
    namespace is passed through ``check_args`` before being returned.

    Returns:
        The ``argparse.Namespace`` returned by ``check_args``.
    """
    parser = argparse.ArgumentParser(description="Tensorflow implementation of U-GAT-IT")
    add = parser.add_argument

    # Run mode, model variant and dataset.
    add('--phase', type=str, default='runner', help='[train / test / web / runner]')
    add('--light', type=str2bool, default=False,
        help='[U-GAT-IT full version / U-GAT-IT light version]')
    add('--dataset', type=str, default='selfie2anime', help='dataset_name')

    # Training schedule and optimiser.
    add('--epoch', type=int, default=100, help='The number of epochs to run')
    add('--iteration', type=int, default=10000, help='The number of training iterations')
    add('--batch_size', type=int, default=1, help='The size of batch size')
    add('--print_freq', type=int, default=1000, help='The number of image_print_freq')
    add('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
    add('--decay_flag', type=str2bool, default=True, help='The decay_flag')
    add('--decay_epoch', type=int, default=50, help='decay epoch')
    add('--lr', type=float, default=0.0001, help='The learning rate')

    # Loss weights and GAN objective.
    add('--GP_ld', type=int, default=10, help='The gradient penalty lambda')
    add('--adv_weight', type=int, default=1, help='Weight about GAN')
    add('--cycle_weight', type=int, default=10, help='Weight about Cycle')
    add('--identity_weight', type=int, default=10, help='Weight about Identity')
    add('--cam_weight', type=int, default=1000, help='Weight about CAM')
    add('--gan_type', type=str, default='lsgan',
        help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')

    # Generator / discriminator configuration.
    add('--smoothing', type=str2bool, default=True, help='AdaLIN smoothing effect')
    add('--ch', type=int, default=64, help='base channel number per layer')
    add('--n_res', type=int, default=4, help='The number of resblock')
    add('--n_dis', type=int, default=6, help='The number of discriminator layer')
    add('--n_critic', type=int, default=1, help='The number of critic')
    add('--sn', type=str2bool, default=True, help='using spectral norm')

    # Image format and augmentation.
    add('--img_size', type=int, default=256, help='The size of image')
    add('--img_ch', type=int, default=3, help='The size of image channel')
    add('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')

    # Output directories.
    add('--checkpoint_dir', type=str, default='checkpoint',
        help='Directory name to save the checkpoints')
    add('--result_dir', type=str, default='results',
        help='Directory name to save the generated images')
    add('--log_dir', type=str, default='logs',
        help='Directory name to save training logs')
    add('--sample_dir', type=str, default='samples',
        help='Directory name to save the samples on training')

    namespace = parser.parse_args()
    return check_args(namespace)
def check_args(args):
    """Prepare the output directories and sanity-check numeric arguments.

    Passes each output directory named in ``args`` to ``check_folder``
    (presumably creating it when missing -- see utils), then prints a warning,
    without aborting, when ``epoch`` or ``batch_size`` is below one.

    Args:
        args: the ``argparse.Namespace`` produced by ``parse_args``.

    Returns:
        The same ``args`` object, unchanged.
    """
    # --checkpoint_dir, --result_dir, --log_dir, --sample_dir
    for folder in (args.checkpoint_dir, args.result_dir, args.log_dir, args.sample_dir):
        check_folder(folder)

    # Plain comparisons instead of `assert` inside a bare `try/except`:
    # asserts are stripped under `python -O` (so the warnings would vanish),
    # and the bare except also hid unrelated errors. Invalid values are still
    # only reported, not fatal, to keep the existing contract.
    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
    return args
# Model initialisation, run once at import time: `sess` and `gan` are module
# globals, and selfie2anime() calls gan.test_endpoint() on every request, so
# all requests share this one session and model instance.
# NOTE(review): `tf` is not imported explicitly in this file; it presumably
# arrives through `from utils import *` -- confirm. tf.Session/tf.ConfigProto
# are TensorFlow 1.x APIs (tf.compat.v1 in TF 2.x).
# device_count={'GPU': 0} asks TensorFlow to create no GPU devices for this
# session, so the model runs on CPU.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, device_count = {'GPU': 0}))
# parse_args() reads sys.argv of whichever process imports this module.
# NOTE(review): under a WSGI server (e.g. gunicorn) that is the server's own
# argv, which argparse may reject as unrecognized -- confirm before deploying.
args = parse_args()
gan = UGATIT(sess, args)
# build graph
gan.build_model()
# show network architecture
show_all_variables()
# Presumably restores trained weights / prepares state used by
# gan.test_endpoint() -- confirm in UGATIT.test_endpoint_init.
gan.test_endpoint_init()
@app.route("/selfie2anime", methods=['POST'])
def selfie2anime():
    """Translate an uploaded selfie into an anime-style image.

    Expects a ``multipart/form-data`` POST whose ``file`` field holds the
    encoded input image. The bytes are decoded with OpenCV, passed to the
    module-level ``gan.test_endpoint``, and the result is returned as PNG.

    Returns:
        A response whose body is the PNG-encoded generated image, with
        ``Content-Type: image/png``.

    Errors (raised via ``abort``):
        400 -- the ``file`` field is missing, empty, or not a decodable image.
        500 -- the generated image could not be encoded as PNG.
    """
    upload = request.files.get('file')
    if upload is None:
        abort(400, description="missing 'file' field")
    data = upload.read()
    if not data:
        abort(400, description="uploaded file is empty")

    # np.frombuffer over the read bytes works whatever stream Werkzeug used to
    # buffer the upload; np.fromfile would need an OS-level file descriptor,
    # which in-memory streams such as BytesIO do not have.
    nparr = np.frombuffer(data, np.uint8)
    # imdecode returns None (instead of raising) when the bytes are not an image.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        abort(400, description="uploaded file is not a decodable image")

    # The session and model are built once at module import (the global `gan`),
    # so nothing is re-parsed or rebuilt per request.
    fake_img = gan.test_endpoint(img)

    ok, buffer = cv2.imencode('.png', fake_img)
    if not ok:
        abort(500, description="failed to encode the generated image as PNG")
    response = make_response(buffer.tobytes())
    response.headers['Content-Type'] = 'image/png'
    return response
def run_server_api(host='0.0.0.0', port=8080):
    """Serve ``app`` with Flask's built-in development server.

    Args:
        host: address to bind. The default '0.0.0.0' listens on every
            interface, exposing the endpoint to the network.
        port: TCP port to listen on (default 8080).
    """
    app.run(host=host, port=port)


if __name__ == "__main__":
    run_server_api()