forked from Zeyi-Lin/HivisionIDPhotos
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
149 lines (130 loc) · 4.37 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import cv2
import argparse
import numpy as np
from hivision.error import FaceError
from hivision.utils import hex_to_rgb, resize_image_to_kb, add_background
from hivision import IDCreator
from hivision.creator.layout_calculator import (
generate_layout_photo,
generate_layout_image,
)
from hivision.creator.choose_handler import choose_handler
# Inference modes accepted by -t/--type.
INFERENCE_TYPE = ["idphoto", "human_matting", "add_background", "generate_layout_photos"]

# Matting (background removal) model names accepted by --matting_model.
MATTING_MODEL = [
    "hivision_modnet",
    "modnet_photographic_portrait_matting",
    "mnn_hivision_modnet",
    "rmbg-1.4",
    "birefnet-v1-lite",
]

# Face detection model names accepted by --face_detect_model.
FACE_DETECT_MODEL = ["mtcnn", "face_plusplus", "retinaface-resnet50"]

# Background render modes accepted by -r/--render:
# 0 = solid color, 1 = top-to-bottom gradient, 2 = center gradient.
RENDER = list(range(3))
# Command-line interface. Help strings are user-facing and intentionally kept as-is.
parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。")
# Which inference mode to run.
parser.add_argument(
    "-t",
    "--type",
    help="请求 API 的种类",
    choices=INFERENCE_TYPE,
    default="idphoto",
)
# Input image path and output image path.
parser.add_argument("-i", "--input_image_dir", help="输入图像路径", required=True)
parser.add_argument("-o", "--output_image_dir", help="保存图像路径", required=True)
# ID photo size in pixels. type=int makes argparse reject non-numeric values with a
# usage error up front, instead of a ValueError traceback deep in the pipeline.
parser.add_argument("--height", type=int, help="证件照尺寸-高", default=413)
parser.add_argument("--width", type=int, help="证件照尺寸-wide".replace("wide", "宽"), default=295)
# Background color as a hex string (no leading '#').
parser.add_argument("-c", "--color", help="证件照背景色", default="638cce")
# Optional target output size in KB; only used by the add_background and
# generate_layout_photos modes. Validated as an integer at parse time.
parser.add_argument(
    "-k",
    "--kb",
    type=int,
    help="输出照片的 KB 值,仅对换底和制作排版照生效",
    default=None,
)
# Matting model used by the creator.
parser.add_argument(
    "--matting_model",
    help="抠图模型权重",
    default="hivision_modnet",
    choices=MATTING_MODEL,
)
# Background render mode: 0 solid color, 1 top-to-bottom gradient, 2 center gradient.
parser.add_argument(
    "-r",
    "--render",
    type=int,
    help="底色合成的模式,有 0:纯色、1:上下渐变、2:中心渐变 可选",
    choices=RENDER,
    default=0,
)
# Face detection model used by the creator.
parser.add_argument(
    "--face_detect_model",
    help="人脸检测模型",
    default="mtcnn",
    choices=FACE_DETECT_MODEL,
)
args = parser.parse_args()
# ------------------- Load the input image, then select matting and face detection models -------------------
# IMREAD_UNCHANGED keeps any alpha channel the file carries.
input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED)
# cv2.imread does not raise on a missing or unreadable file; it returns None, which
# would otherwise fail later with an obscure error inside the pipeline. Checking here
# also means a bad path is reported before any model handler is configured.
if input_image is None:
    parser.error(f"无法读取输入图像: {args.input_image_dir}")

creator = IDCreator()
choose_handler(creator, args.matting_model, args.face_detect_model)
def _save_image(image, path, kb):
    """Save an OpenCV-pipeline image to ``path``, optionally targeting ``kb`` KB.

    Args:
        image: image array produced by the pipeline above.
        path: destination file path.
        kb: target size in KB; when falsy the image is written as-is with
            ``cv2.imwrite``.
    """
    if kb:
        # Swap channel order before hand-off. NOTE(review): presumably
        # resize_image_to_kb encodes RGB-ordered arrays — confirm in hivision.utils.
        swapped = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        # resize_image_to_kb receives the output path and is the only writer on
        # this branch, so its return value is intentionally not used.
        resize_image_to_kb(swapped, path, int(kb))
    else:
        cv2.imwrite(path, image)


# Mode: generate an ID photo
if args.type == "idphoto":
    # (height, width) size tuple for the creator.
    size = (int(args.height), int(args.width))
    try:
        result = creator(input_image, size=size)
    except FaceError:
        # Per the message below, raised when the image does not contain exactly one face.
        print("人脸数量不等于 1,请上传单张人脸的图像。")
    else:
        # Standard photo goes to the requested path.
        cv2.imwrite(args.output_image_dir, result.standard)
        # HD photo goes next to it, with "_hd" inserted before the extension.
        file_name, file_extension = os.path.splitext(args.output_image_dir)
        cv2.imwrite(file_name + "_hd" + file_extension, result.hd)

# Mode: human matting only (change_bg_only=True); the hd result is saved
elif args.type == "human_matting":
    result = creator(input_image, change_bg_only=True)
    cv2.imwrite(args.output_image_dir, result.hd)

# Mode: add a background to the input image
elif args.type == "add_background":
    # Indexed by --render (0, 1, 2).
    render_choice = ["pure_color", "updown_gradient", "center_gradient"]
    # hex_to_rgb presumably yields (R, G, B); reverse it for the `bgr` parameter.
    color = hex_to_rgb(args.color)
    color = (color[2], color[1], color[0])
    result_image = add_background(
        input_image, bgr=color, mode=render_choice[args.render]
    )
    result_image = result_image.astype(np.uint8)
    _save_image(result_image, args.output_image_dir, args.kb)

# Mode: generate a print layout of multiple photos
elif args.type == "generate_layout_photos":
    size = (int(args.height), int(args.width))
    typography_arr, typography_rotate = generate_layout_photo(
        input_height=size[0], input_width=size[1]
    )
    result_layout_image = generate_layout_image(
        input_image,
        typography_arr,
        typography_rotate,
        height=size[0],
        width=size[1],
    )
    _save_image(result_layout_image, args.output_image_dir, args.kb)