-
Notifications
You must be signed in to change notification settings - Fork 5
/
person_remover.py
67 lines (52 loc) · 2.59 KB
/
person_remover.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from detector.model_deeplab import Detector, AVAILABLE_MODELS
from inpainter.model import Inpainter
from libs.data_retriever import IMG_EXTENSIONS
import torch
from libs.utils import save_batch, crop_center, read_image
import os
import argparse
from tqdm import tqdm
# Command-line interface for the object remover.
# ArgumentDefaultsHelpFormatter appends each option's default value to its
# help text, so defaults no longer need to be spelled out by hand.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-i', '--image-input-path',
                    default='Datasets/remove_people/',
                    type=str,
                    help='The path to the directory where images are saved')
parser.add_argument('-o', '--image-output-path',
                    type=str,
                    default='./output/',
                    help='The path of the output photos')
parser.add_argument('-dm', '--detector-model',
                    type=str,
                    default='deeplab',
                    help=f'Detector model name. It must be one of {", ".join(AVAILABLE_MODELS)}')
parser.add_argument('-e', '--encoder',
                    type=str,
                    default='resnet50dilated',
                    help='Encoder name. Only used when the detector model is MITCSAIL')
parser.add_argument('-d', '--decoder',
                    type=str,
                    default='ppm_deepsup',
                    help='Decoder name. Only used when the detector model is MITCSAIL')
# Forwarded to the detector as `object_names`; one or more names may be given.
parser.add_argument('-ob', '--objects', nargs='+', type=str, default=['person'],
                    help='Names of the objects to detect and remove')
def main(FLAGS):
if FLAGS.image_input_path == FLAGS.image_output_path:
raise Exception('Input and output directories cannot be the same')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Prepare models
detector = Detector(FLAGS.detector_model, encoder=FLAGS.encoder, decoder=FLAGS.decoder, object_names=FLAGS.objects)
inpainter = Inpainter(mode='try', checkpoint_dir='inpainter/weights/')
for file in tqdm(os.listdir(FLAGS.image_input_path)):
if file.lower().endswith(IMG_EXTENSIONS):
input_file = FLAGS.image_input_path + file
image = read_image(input_file, device)
mask = detector(image)
torch.cuda.empty_cache()
image_inpaint = image * mask
output_inpaint = inpainter(image_inpaint, mask)
output_inpaint = crop_center(output_inpaint, image.shape[-1], image.shape[-2])
final_image = image_inpaint + (1 - mask) * output_inpaint
save_batch(final_image.detach().cpu().numpy(), [file], FLAGS.image_output_path)
del image_inpaint, image, mask, final_image
if __name__ == '__main__':
    import sys

    FLAGS, unparsed = parser.parse_known_args()
    if unparsed:
        # parse_known_args tolerates unknown options (e.g. a mistyped flag).
        # Report them rather than dropping them silently.
        print(f'Warning: ignoring unrecognized arguments: {" ".join(unparsed)}',
              file=sys.stderr)
    main(FLAGS)