-
Notifications
You must be signed in to change notification settings - Fork 90
/
Copy pathsample_selfie_segmentation.py
110 lines (82 loc) · 3.48 KB
/
sample_selfie_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import argparse
import cv2 as cv
import numpy as np
import mediapipe as mp
from utils import CvFpsCalc
def get_args(argv=None):
    """Parse the command-line options of the selfie-segmentation demo.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with the attributes ``device``, ``width``,
        ``height``, ``model_selection``, ``score_th`` and ``bg_path``.
    """
    parser = argparse.ArgumentParser()

    # Capture device index and requested capture resolution.
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--width", help='cap width', type=int, default=960)
    parser.add_argument("--height", help='cap height', type=int, default=540)

    # Value forwarded to the segmentation model constructor.
    parser.add_argument("--model_selection",
                        help='model_selection',
                        type=int,
                        default=0)
    # Mask values greater than or equal to this are kept as foreground.
    parser.add_argument("--score_th",
                        help='score threshold',
                        type=float,
                        default=0.1)
    # Optional replacement background; no image is loaded when omitted.
    parser.add_argument("--bg_path",
                        help='back ground image path',
                        type=str,
                        default=None)

    return parser.parse_args(argv)
def main():
    """Run the MediaPipe selfie-segmentation webcam demo.

    Reads frames from the selected capture device, mirrors them, runs
    MediaPipe Selfie Segmentation on each frame, and replaces every pixel
    whose mask value is below ``--score_th`` with the background (the image
    given by ``--bg_path`` resized to the frame, or a solid (0, 255, 0)
    fill). The FPS value is drawn on the result, which is shown in a window
    until ESC is pressed or the capture device stops returning frames.
    """
    import sys  # local import: only needed for the warning below

    # Argument parsing ########################################################
    args = get_args()

    cap_device = args.device
    cap_width = args.width
    cap_height = args.height

    model_selection = args.model_selection
    score_th = args.score_th

    bg_image = None
    if args.bg_path is not None:
        bg_image = cv.imread(args.bg_path)
        if bg_image is None:
            # cv.imread signals failure by returning None instead of raising;
            # tell the user instead of silently using the green background.
            print("WARNING: could not read background image '{}'; "
                  "using the green background instead.".format(args.bg_path),
                  file=sys.stderr)

    # Camera setup ############################################################
    cap = cv.VideoCapture(cap_device)

    try:
        cap.set(cv.CAP_PROP_FRAME_WIDTH, cap_width)
        cap.set(cv.CAP_PROP_FRAME_HEIGHT, cap_height)

        # Model load ##########################################################
        mp_selfie_segmentation = mp.solutions.selfie_segmentation
        # The context manager closes the solution when the loop ends,
        # including when it ends through an exception.
        with mp_selfie_segmentation.SelfieSegmentation(
                model_selection=model_selection) as selfie_segmentation:
            # FPS measurement module ##########################################
            cvFpsCalc = CvFpsCalc(buffer_len=10)

            # Background matching the current frame shape; built lazily and
            # reused, because it does not change from frame to frame.
            background = None

            while True:
                display_fps = cvFpsCalc.get()

                # Camera capture ##############################################
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv.flip(frame, 1)  # mirror display

                # Run segmentation ############################################
                # cvtColor returns a new array, so `frame` (BGR) stays intact
                # for compositing and no extra copy is needed.
                rgb_frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                results = selfie_segmentation.process(rgb_frame)

                # Drawing #####################################################
                # Trailing axis of length 1 lets np.where broadcast the mask
                # over the colour channels.
                mask = results.segmentation_mask[..., np.newaxis] >= score_th

                if background is None or background.shape != frame.shape:
                    if bg_image is None:
                        background = np.zeros(frame.shape, dtype=np.uint8)
                        background[:] = (0, 255, 0)
                    else:
                        background = cv.resize(
                            bg_image, (frame.shape[1], frame.shape[0]))

                debug_image = np.where(mask, frame, background)

                cv.putText(debug_image, "FPS:" + str(display_fps), (10, 30),
                           cv.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2,
                           cv.LINE_AA)

                # Key handling (ESC: quit) ####################################
                key = cv.waitKey(1)
                if key == 27:  # ESC
                    break

                # Update the window ###########################################
                cv.imshow('MediaPipe Selfie Segmentation Demo', debug_image)
    finally:
        # Release the camera and close the windows even if the loop raised.
        cap.release()
        cv.destroyAllWindows()
# Script entry point: start the demo only when this file is executed directly.
if __name__ == "__main__":
    main()