forked from tae898/age-gender
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
92 lines (69 loc) · 2.45 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from flask import Flask, request
import jsonpickle
import logging
import numpy as np
import io
from model.model import ResMLP
import torch
from utils import read_json, forward_mc, enable_dropout
from tqdm import tqdm
# Root logger setup: emit DEBUG and above, with each record showing a
# millisecond-resolution timestamp, the level, and the module / function
# that produced the message.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
    level=logging.DEBUG,
)
# The Flask application object; routes are registered on it below.
app = Flask(__name__)

# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# One network per task, keyed by task name. The placeholders are replaced
# by the loaded models in the loop that follows.
models = {'age': None, 'gender': None}

for model_ in ('age', 'gender'):
    # The constructor arguments live under arch.args in the task's JSON config.
    arch_args = read_json(f'./models/{model_}.json')['arch']['args']
    model = ResMLP(**arch_args)

    # Load the saved weights onto the selected device.
    checkpoint = torch.load(
        f"models/{model_}.pth",
        map_location=torch.device(device),
    )
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    # eval() comes first, then enable_dropout(); presumably this turns the
    # dropout layers back on for Monte-Carlo sampling in forward_mc --
    # confirm against utils.enable_dropout.
    model.eval()
    enable_dropout(model)

    models[model_] = model
# One endpoint is enough.
@app.route("/", methods=["POST"])
def predict_age_gender():
    """Estimate age and gender for every face embedding in a POST request.

    The JSON body of the request is passed to ``jsonpickle.decode``; the
    decoded object must provide an ``'embeddings'`` entry holding the raw
    bytes of a serialized NumPy array. The array is reshaped to
    ``(n_faces, 512)`` and cast to float32, then each row is run through the
    gender model and the age model with ``forward_mc``.

    Returns:
        A jsonpickle-encoded string of ``{'ages': [...], 'genders': [...]}``
        with one entry per embedding, in input order. Each age entry is
        ``{'mean': ..., 'entropy': ...}`` and each gender entry is
        ``{'m': 1 - mean, 'f': mean, 'entropy': ...}``, where ``mean`` and
        ``entropy`` are the two values returned by ``forward_mc``.

    Raises:
        KeyError: if the decoded payload has no ``'embeddings'`` entry.
        ValueError: if the array size is not a multiple of 512.

    Security:
        Both ``jsonpickle.decode`` and ``np.load(..., allow_pickle=True)``
        can execute arbitrary code while deserializing. This endpoint must
        only be reachable by trusted clients until the wire format is
        replaced with a non-pickle one.
    """
    app.logger.debug("Receiving data ...")
    # SECURITY: jsonpickle.decode can instantiate arbitrary objects from the
    # request body. Kept as-is because the client protocol depends on it.
    payload = jsonpickle.decode(request.json)

    app.logger.debug("loading embeddings ...")
    # The client is expected to have serialized the embeddings array before
    # sending it; allow_pickle=True lets np.load fall back to pickle.
    # SECURITY: unpickling untrusted bytes is unsafe (see docstring).
    buffer = io.BytesIO(payload['embeddings'])
    embeddings = np.load(buffer, allow_pickle=True)
    # -1 lets NumPy infer the batch size, i.e. the number of faces.
    embeddings = embeddings.reshape(-1, 512).astype(np.float32)

    # Lazy %-style arguments: the message is only built if DEBUG is enabled.
    app.logger.debug(
        "extracting gender and age from %s faces ...", embeddings.shape[0])

    genders = []
    ages = []
    for embedding in tqdm(embeddings):
        # The models are fed one embedding at a time as a batch of size 1.
        embedding = embedding.reshape(1, 512)

        # Gender first, then age: the call order is kept so that any random
        # sampling inside forward_mc happens in the same sequence as before.
        gender_mean, gender_entropy = forward_mc(models['gender'], embedding)
        age_mean, age_entropy = forward_mc(models['age'], embedding)

        genders.append({
            'm': 1 - gender_mean,
            'f': gender_mean,
            'entropy': gender_entropy,
        })
        ages.append({
            'mean': age_mean,
            'entropy': age_entropy,
        })

    app.logger.debug("gender and age extracted!")

    response_pickled = jsonpickle.encode({'ages': ages, 'genders': genders})
    app.logger.info("json-pickle is done.")

    return response_pickled
# Start Flask's built-in server only when this file is executed directly.
# It listens on every network interface, on port 10003.
if __name__ == "__main__":
    app.run(port=10003, host="0.0.0.0")