Add files via upload #20

Open · wants to merge 1 commit into base: master
396 changes: 198 additions & 198 deletions test_background-matting_image.py
@@ -1,198 +1,198 @@
from __future__ import print_function


import os, glob, time, argparse, pdb, cv2
#import matplotlib.pyplot as plt
import numpy as np
from skimage.measure import label


import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn

from functions import *
from networks import ResnetConditionHR

torch.set_num_threads(1)
#os.environ["CUDA_VISIBLE_DEVICES"]="4"
print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])


"""Parses arguments."""
parser = argparse.ArgumentParser(description='Background Matting.')
parser.add_argument('-m', '--trained_model', type=str, default='real-fixed-cam',choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'],help='Trained background matting model')
parser.add_argument('-o', '--output_dir', type=str, required=True,help='Directory to save the output results. (required)')
parser.add_argument('-i', '--input_dir', type=str, required=True,help='Directory to load input images. (required)')
parser.add_argument('-tb', '--target_back', type=str,help='Directory to load the target background.')
parser.add_argument('-b', '--back', type=str,default=None,help='Captured background image. (only use for inference on videos with fixed camera)')


args=parser.parse_args()

#input model
model_main_dir='Models/' + args.trained_model + '/';
#input data path
data_path=args.input_dir

if os.path.isdir(args.target_back):
    args.video=True
    print('Using video mode')
else:
    args.video=False
    print('Using image mode')
    #target background path
    back_img10=cv2.imread(args.target_back); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
    #Green-screen background
    back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;



#initialize network
fo=glob.glob(model_main_dir + 'netG_epoch_*')
model_name1=fo[0]
netM=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=7,n_blocks2=3)
netM=nn.DataParallel(netM)
netM.load_state_dict(torch.load(model_name1))
netM.cuda(); netM.eval()
cudnn.benchmark=True
reso=(512,512) #input resolution to the network

#load captured background for video mode, fixed camera
if args.back is not None:
    bg_im0=cv2.imread(args.back); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB);


#Create a list of test images
test_imgs = [f for f in os.listdir(data_path) if
             os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')]
test_imgs.sort()

#output directory
result_path=args.output_dir

if not os.path.exists(result_path):
    os.makedirs(result_path)

for i in range(0,len(test_imgs)):
    filename = test_imgs[i]
    #original image
    bgr_img = cv2.imread(os.path.join(data_path, filename)); bgr_img=cv2.cvtColor(bgr_img,cv2.COLOR_BGR2RGB);

    if args.back is None:
        #captured background image
        bg_im0=cv2.imread(os.path.join(data_path, filename.replace('_img','_back'))); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB);

    #segmentation mask
    rcnn = cv2.imread(os.path.join(data_path, filename.replace('_img','_masksDL')),0);

    if args.video: #if video mode, load target background frames
        #target background path
        back_img10=cv2.imread(os.path.join(args.target_back,filename.replace('_img.png','.png'))); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
        #Green-screen background
        back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;

        #create multiple frames with adjoining frames
        gap=20
        multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4))
        idx=[i-2*gap,i-gap,i+gap,i+2*gap]
        for t in range(0,4):
            if idx[t]<0:
                idx[t]=len(test_imgs)+idx[t]
            elif idx[t]>=len(test_imgs):
                idx[t]=idx[t]-len(test_imgs)

            file_tmp=test_imgs[idx[t]]
            bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp));
            multi_fr_w[...,t]=cv2.cvtColor(bgr_img_mul,cv2.COLOR_BGR2GRAY);

    else:
        ## create the multi-frame
        multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4))
        multi_fr_w[...,0] = cv2.cvtColor(bgr_img,cv2.COLOR_BGR2GRAY);
        multi_fr_w[...,1] = multi_fr_w[...,0]
        multi_fr_w[...,2] = multi_fr_w[...,0]
        multi_fr_w[...,3] = multi_fr_w[...,0]

    #crop tightly
    bgr_img0=bgr_img;
    bbox=get_bbox(rcnn,R=bgr_img0.shape[0],C=bgr_img0.shape[1])

    crop_list=[bgr_img,bg_im0,rcnn,back_img10,back_img20,multi_fr_w]
    crop_list=crop_images(crop_list,reso,bbox)
    bgr_img=crop_list[0]; bg_im=crop_list[1]; rcnn=crop_list[2]; back_img1=crop_list[3]; back_img2=crop_list[4]; multi_fr=crop_list[5]

    #process segmentation mask
    kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    rcnn=rcnn.astype(np.float32)/255; rcnn[rcnn>0.2]=1;
    K=25

    zero_id=np.nonzero(np.sum(rcnn,axis=1)==0)
    del_id=zero_id[0][zero_id[0]>250]
    if len(del_id)>0:
        del_id=[del_id[0]-2,del_id[0]-1,*del_id]
        rcnn=np.delete(rcnn,del_id,0)
    rcnn = cv2.copyMakeBorder( rcnn, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE)

    rcnn = cv2.erode(rcnn, kernel_er, iterations=10)
    rcnn = cv2.dilate(rcnn, kernel_dil, iterations=5)
    rcnn=cv2.GaussianBlur(rcnn.astype(np.float32),(31,31),0)
    rcnn=(255*rcnn).astype(np.uint8)
    rcnn=np.delete(rcnn, range(reso[0],reso[0]+K), 0)

    #convert to torch
    img=torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0); img=2*img.float().div(255)-1
    bg=torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0); bg=2*bg.float().div(255)-1
    rcnn_al=torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0); rcnn_al=2*rcnn_al.float().div(255)-1
    multi_fr=torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0); multi_fr=2*multi_fr.float().div(255)-1

    with torch.no_grad():
        img,bg,rcnn_al, multi_fr =Variable(img.cuda()), Variable(bg.cuda()), Variable(rcnn_al.cuda()), Variable(multi_fr.cuda())
        input_im=torch.cat([img,bg,rcnn_al,multi_fr],dim=1)

        alpha_pred,fg_pred_tmp=netM(img,bg,rcnn_al,multi_fr)

        al_mask=(alpha_pred>0.95).type(torch.cuda.FloatTensor)

        # for regions with alpha>0.95, simply use the image as fg
        fg_pred=img*al_mask + fg_pred_tmp*(1-al_mask)

        alpha_out=to_image(alpha_pred[0,...]);

        #refine alpha with connected component
        labels=label((alpha_out>0.05).astype(int))
        try:
            assert( labels.max() != 0 )
        except:
            continue
        largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
        alpha_out=alpha_out*largestCC

        alpha_out=(255*alpha_out[...,0]).astype(np.uint8)

        fg_out=to_image(fg_pred[0,...]); fg_out=fg_out*np.expand_dims((alpha_out.astype(float)/255>0.01).astype(float),axis=2); fg_out=(255*fg_out).astype(np.uint8)

        #Uncrop
        R0=bgr_img0.shape[0];C0=bgr_img0.shape[1]
        alpha_out0=uncrop(alpha_out,bbox,R0,C0)
        fg_out0=uncrop(fg_out,bbox,R0,C0)

    #compose
    back_img10=cv2.resize(back_img10,(C0,R0)); back_img20=cv2.resize(back_img20,(C0,R0))
    comp_im_tr1=composite4(fg_out0,back_img10,alpha_out0)
    comp_im_tr2=composite4(fg_out0,back_img20,alpha_out0)

    cv2.imwrite(result_path+'/'+filename.replace('_img','_out'), alpha_out0)
    cv2.imwrite(result_path+'/'+filename.replace('_img','_fg'), cv2.cvtColor(fg_out0,cv2.COLOR_BGR2RGB))
    cv2.imwrite(result_path+'/'+filename.replace('_img','_compose'), cv2.cvtColor(comp_im_tr1,cv2.COLOR_BGR2RGB))
    cv2.imwrite(result_path+'/'+filename.replace('_img','_matte').format(i), cv2.cvtColor(comp_im_tr2,cv2.COLOR_BGR2RGB))

    print('Done: ' + str(i+1) + '/' + str(len(test_imgs)))

from __future__ import print_function
import os, glob, time, argparse, pdb, cv2
#import matplotlib.pyplot as plt
import numpy as np
from skimage.measure import label
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from functions import *
from networks import ResnetConditionHR
torch.set_num_threads(1)
print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
"""Parses arguments."""
parser = argparse.ArgumentParser(description='Background Matting.')
parser.add_argument('-m', '--trained_model', type=str, default='real-fixed-cam',choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'],help='Trained background matting model')
parser.add_argument('-o', '--output_dir', type=str, required=True,help='Directory to save the output results. (required)')
parser.add_argument('-i', '--input_dir', type=str, required=True,help='Directory to load input images. (required)')
parser.add_argument('-tb', '--target_back', type=str,help='Directory to load the target background.')
parser.add_argument('-b', '--back', type=str,default=None,help='Captured background image. (only use for inference on videos with fixed camera)')
args=parser.parse_args()
#input model
model_main_dir='Models1/' + args.trained_model + '/';
#input data path
data_path=args.input_dir
if os.path.isdir(args.target_back):
    args.video=True
    print('Using video mode')
else:
    args.video=False
    print('Using image mode')
    #target background path
    back_img10=cv2.imread(args.target_back); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
    #Green-screen background
    back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;
#initialize network
fo=glob.glob(model_main_dir + 'netG_epoch_*')
model_name1=fo[0]
netM=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=7,n_blocks2=3)
netM=nn.DataParallel(netM)
netM.load_state_dict(torch.load(model_name1))
netM.cuda(); netM.eval()
cudnn.benchmark=True
reso=(512,512) #input resolution to the network
#load captured background for video mode, fixed camera
if args.back is not None:
    bg_im0=cv2.imread(args.back); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB);
#Create a list of test images
test_imgs = [f for f in os.listdir(data_path) if
             os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')]
test_imgs.sort()
#output directory
result_path=args.output_dir
if not os.path.exists(result_path):
    os.makedirs(result_path)
for i in range(0,len(test_imgs)):
    filename = test_imgs[i]
    #original image
    bgr_img = cv2.imread(os.path.join(data_path, filename)); bgr_img=cv2.cvtColor(bgr_img,cv2.COLOR_BGR2RGB);
    if args.back is None:
        #captured background image
        bg_im0=cv2.imread(os.path.join(data_path, filename.replace('_img','_back'))); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB);
    #segmentation mask
    rcnn = cv2.imread(os.path.join(data_path, filename.replace('_img','_masksDL')),0);
    if args.video: #if video mode, load target background frames
        #target background path
        back_img10=cv2.imread(os.path.join(args.target_back,filename.replace('_img.png','.png'))); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
        #Green-screen background
        back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;
        #create multiple frames with adjoining frames
        gap=20
        multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4))
        idx=[i-2*gap,i-gap,i+gap,i+2*gap]
        for t in range(0,4):
            if idx[t]<0:
                idx[t]=len(test_imgs)+idx[t]
            elif idx[t]>=len(test_imgs):
                idx[t]=idx[t]-len(test_imgs)
            file_tmp=test_imgs[idx[t]]
            bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp));
            multi_fr_w[...,t]=cv2.cvtColor(bgr_img_mul,cv2.COLOR_BGR2GRAY);
    else:
        ## create the multi-frame
        multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4))
        multi_fr_w[...,0] = cv2.cvtColor(bgr_img,cv2.COLOR_BGR2GRAY);
        multi_fr_w[...,1] = multi_fr_w[...,0]
        multi_fr_w[...,2] = multi_fr_w[...,0]
        multi_fr_w[...,3] = multi_fr_w[...,0]
    #crop tightly
    bgr_img0=bgr_img;
    bbox=get_bbox(rcnn,R=bgr_img0.shape[0],C=bgr_img0.shape[1])
    crop_list=[bgr_img,bg_im0,rcnn,back_img10,back_img20,multi_fr_w]
    crop_list=crop_images(crop_list,reso,bbox)
    bgr_img=crop_list[0]; bg_im=crop_list[1]; rcnn=crop_list[2]; back_img1=crop_list[3]; back_img2=crop_list[4]; multi_fr=crop_list[5]
    #process segmentation mask
    kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    rcnn=rcnn.astype(np.float32)/255; rcnn[rcnn>0.2]=1;
    K=25
    zero_id=np.nonzero(np.sum(rcnn,axis=1)==0)
    del_id=zero_id[0][zero_id[0]>250]
    if len(del_id)>0:
        del_id=[del_id[0]-2,del_id[0]-1,*del_id]
        rcnn=np.delete(rcnn,del_id,0)
    rcnn = cv2.copyMakeBorder( rcnn, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE)
    rcnn = cv2.erode(rcnn, kernel_er, iterations=10)
    rcnn = cv2.dilate(rcnn, kernel_dil, iterations=5)
    rcnn=cv2.GaussianBlur(rcnn.astype(np.float32),(31,31),0)
    rcnn=(255*rcnn).astype(np.uint8)
    rcnn=np.delete(rcnn, range(reso[0],reso[0]+K), 0)
    #convert to torch
    img=torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0); img=2*img.float().div(255)-1
    bg=torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0); bg=2*bg.float().div(255)-1
    rcnn_al=torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0); rcnn_al=2*rcnn_al.float().div(255)-1
    multi_fr=torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0); multi_fr=2*multi_fr.float().div(255)-1
    with torch.no_grad():
        img,bg,rcnn_al, multi_fr =Variable(img.cuda()), Variable(bg.cuda()), Variable(rcnn_al.cuda()), Variable(multi_fr.cuda())
        input_im=torch.cat([img,bg,rcnn_al,multi_fr],dim=1)
        alpha_pred,fg_pred_tmp=netM(img,bg,rcnn_al,multi_fr)
        al_mask=(alpha_pred>0.95).type(torch.cuda.FloatTensor)
        # for regions with alpha>0.95, simply use the image as fg
        fg_pred=img*al_mask + fg_pred_tmp*(1-al_mask)
        alpha_out=to_image(alpha_pred[0,...]);
        #refine alpha with connected component
        labels=label((alpha_out>0.05).astype(int))
        try:
            assert( labels.max() != 0 )
        except:
            continue
        largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
        alpha_out=alpha_out*largestCC
        alpha_out=(255*alpha_out[...,0]).astype(np.uint8)
        fg_out=to_image(fg_pred[0,...]); fg_out=fg_out*np.expand_dims((alpha_out.astype(float)/255>0.01).astype(float),axis=2); fg_out=(255*fg_out).astype(np.uint8)
        #Uncrop
        R0=bgr_img0.shape[0];C0=bgr_img0.shape[1]
        alpha_out0=uncrop(alpha_out,bbox,R0,C0)
        fg_out0=uncrop(fg_out,bbox,R0,C0)
    #compose
    back_img10=cv2.resize(back_img10,(C0,R0)); back_img20=cv2.resize(back_img20,(C0,R0))
    comp_im_tr1=composite4(fg_out0,back_img10,alpha_out0)
    comp_im_tr2=composite4(fg_out0,back_img20,alpha_out0)
    cv2.imwrite(result_path+'/'+filename.replace('_img','_out'), alpha_out0)
    cv2.imwrite(result_path+'/'+filename.replace('_img','_fg'), cv2.cvtColor(fg_out0,cv2.COLOR_BGR2RGB))
    cv2.imwrite(result_path+'/'+filename.replace('_img','_compose'), cv2.cvtColor(comp_im_tr1,cv2.COLOR_BGR2RGB))
    cv2.imwrite(result_path+'/'+filename.replace('_img','_matte').format(i), cv2.cvtColor(comp_im_tr2,cv2.COLOR_BGR2RGB))
    print('Done: ' + str(i+1) + '/' + str(len(test_imgs)))
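
For context on the post-processing in the script above: after inference it keeps only the largest connected component of the predicted alpha matte (the np.bincount expression) and then composites the predicted foreground over the target and green-screen backgrounds via composite4 from functions.py. Below is a minimal standalone sketch of those two steps. The helper names refine_alpha and composite_over are hypothetical, and composite_over only approximates what composite4 is assumed to compute (straight alpha compositing); the actual implementation lives in functions.py.

import numpy as np
from skimage.measure import label

def refine_alpha(alpha):
    # alpha: float matte in [0, 1]; keep only the largest connected foreground region
    labels = label((alpha > 0.05).astype(int))
    if labels.max() == 0:
        # no foreground region found; return the matte unchanged
        return alpha
    # np.bincount counts pixels per label; [1:] skips the background label 0,
    # so argmax(...) + 1 is the label of the largest connected component
    largest = labels == (np.argmax(np.bincount(labels.flat)[1:]) + 1)
    return alpha * largest

def composite_over(fg, bg, alpha):
    # straight alpha compositing: comp = alpha*fg + (1 - alpha)*bg
    # fg, bg: uint8 HxWx3 images of equal size; alpha: uint8 HxW matte in [0, 255]
    a = (alpha.astype(np.float32) / 255.0)[..., None]
    comp = a * fg.astype(np.float32) + (1.0 - a) * bg.astype(np.float32)
    return comp.astype(np.uint8)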