Skip to content

Commit

Permalink
Update predict.py
Browse files Browse the repository at this point in the history
1, switched the location of width and height in line 35 because the first dimension in img.shape should be row(height). Also changed their order in dlib.resize_image(), line 41. This do not affect the result, just for clarity.
2, removed redundant line:  new_width, new_height = 628, int( 628 * old_height / old_width)
  • Loading branch information
Bernardo1998 authored Sep 20, 2020
1 parent e8db9ab commit 64d012b
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,23 @@ def rect_to_bb(rect):
# return a tuple of (x, y, w, h)
return (x, y, w, h)


def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800, size = 300, padding = 0.25):
def detect_face(image_paths, SAVE_DETECTED_AT, default_max_size=800,size = 300, padding = 0.25):
cnn_face_detector = dlib.cnn_face_detection_model_v1('dlib_models/mmod_human_face_detector.dat')
sp = dlib.shape_predictor('dlib_models/shape_predictor_5_face_landmarks.dat')
base = 2000 # largest width and height
for index, image_path in enumerate(image_paths):
if index % 1000 == 0:
print('---%d/%d---' %(index, len(image_paths)))
img = dlib.load_rgb_image(image_path)
old_width, old_height, _ = img.shape

old_height, old_width, _ = img.shape

if old_width > old_height:
new_width, new_height = default_max_size, int(default_max_size * old_height / old_width)
else:
new_width, new_height = int(default_max_size * old_height / old_width), default_max_size
new_width, new_height = 628, int( 628 * old_height / old_width)
img = dlib.resize_image(img, new_width, new_height)
new_width, new_height = int(default_max_size * old_width / old_height), default_max_size
img = dlib.resize_image(img, rows=new_height, cols=new_width)

dets = cnn_face_detector(img, 1)
num_faces = len(dets)
if num_faces == 0:
Expand All @@ -62,7 +63,7 @@ def predidct_age_gender_race(save_prediction_at, imgs_path = 'cropped_faces/'):

model_fair_7 = torchvision.models.resnet34(pretrained=True)
model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
model_fair_7.load_state_dict(torch.load('fair_face_models/res34_fair_align_multi_7_20190809.pt'))
model_fair_7.load_state_dict(torch.load('fair_face_models/fairface_alldata_20191111.pt'))
model_fair_7 = model_fair_7.to(device)
model_fair_7.eval()

Expand Down Expand Up @@ -200,14 +201,18 @@ def ensure_dir(directory):


if __name__ == "__main__":
#Please create a csv with one column 'img_path', contains the full paths of all images to be analyzed.
#Also please change working directory to this file.
parser = argparse.ArgumentParser()
parser.add_argument('--csv', dest='input_csv', action='store',
help='csv file of image path where col name for image path is "img_path')
dlib.DLIB_USE_CUDA = True
print("using CUDA?: %s" % dlib.DLIB_USE_CUDA)
args = parser.parse_args()
SAVE_DETECTED_AT = "detected_faces"
ensure_dir(SAVE_DETECTED_AT)
imgs = pd.read_csv(args.input_csv)['img_path']
detect_face(imgs, SAVE_DETECTED_AT)
print("detected faces are saved at ", SAVE_DETECTED_AT)
predidct_age_gender_race("test_outputs.csv", SAVE_DETECTED_AT)
#Please change test_outputs.csv to actual name of output csv.
predidct_age_gender_race("test_outputs.csv", SAVE_DETECTED_AT)

0 comments on commit 64d012b

Please sign in to comment.