Skip to content

Commit

Permalink
fix Auto focal point crop for opencv >= 4.8.x
Browse files Browse the repository at this point in the history
  • Loading branch information
w-e-w committed Nov 27, 2023
1 parent f0f100e commit d4e254d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
27 changes: 15 additions & 12 deletions modules/textual_inversion/autocrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import numpy as np
from PIL import ImageDraw
from modules import paths_internal
from pkg_resources import parse_version

GREEN = "#0F0"
BLUE = "#00F"
Expand Down Expand Up @@ -294,22 +296,23 @@ def is_square(w, h):
return w == h


def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx'
if parse_version(cv2.__version__) >= parse_version('4.8'):
model_file_path = os.path.join(paths_internal.models_path, 'opencv', 'face_detection_yunet_2023mar.onnx')
model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
else:
model_file_path = os.path.join(paths_internal.models_path, 'opencv', 'face_detection_yunet.onnx')
model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'

os.makedirs(dirname, exist_ok=True)

cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url)
with open(cache_file, "wb") as f:
def download_and_cache_models():
if not os.path.exists(model_file_path):
print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
response = requests.get(model_url)
with open(model_file_path, "wb") as f:
f.write(response.content)

if os.path.exists(cache_file):
return cache_file
return None
if os.path.exists(model_file_path):
return model_file_path


class PointOfInterest:
Expand Down
4 changes: 2 additions & 2 deletions modules/textual_inversion/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import tqdm

from modules import paths, shared, images, deepbooru
from modules import shared, images, deepbooru
from modules.textual_inversion import autocrop


Expand Down Expand Up @@ -196,7 +196,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre

dnn_model_path = None
try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
dnn_model_path = autocrop.download_and_cache_models()
except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)

Expand Down

0 comments on commit d4e254d

Please sign in to comment.