Skip to content

Commit

Permalink
all models can be downloaded with one shot
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Oct 6, 2024
1 parent 5415407 commit 234d0db
Show file tree
Hide file tree
Showing 19 changed files with 156 additions and 93 deletions.
127 changes: 84 additions & 43 deletions deepface/commons/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,33 @@
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger

# weight urls as variables
from deepface.models.facial_recognition.VGGFace import WEIGHTS_URL as VGGFACE_WEIGHTS
from deepface.models.facial_recognition.Facenet import FACENET128_WEIGHTS, FACENET512_WEIGHTS
from deepface.models.facial_recognition.OpenFace import WEIGHTS_URL as OPENFACE_WEIGHTS
from deepface.models.facial_recognition.FbDeepFace import WEIGHTS_URL as FBDEEPFACE_WEIGHTS
from deepface.models.facial_recognition.ArcFace import WEIGHTS_URL as ARCFACE_WEIGHTS
from deepface.models.facial_recognition.DeepID import WEIGHTS_URL as DEEPID_WEIGHTS
from deepface.models.facial_recognition.SFace import WEIGHTS_URL as SFACE_WEIGHTS
from deepface.models.facial_recognition.GhostFaceNet import WEIGHTS_URL as GHOSTFACENET_WEIGHTS
from deepface.models.facial_recognition.Dlib import WEIGHT_URL as DLIB_FR_WEIGHTS
from deepface.models.demography.Age import WEIGHTS_URL as AGE_WEIGHTS
from deepface.models.demography.Gender import WEIGHTS_URL as GENDER_WEIGHTS
from deepface.models.demography.Race import WEIGHTS_URL as RACE_WEIGHTS
from deepface.models.demography.Emotion import WEIGHTS_URL as EMOTION_WEIGHTS
from deepface.models.spoofing.FasNet import (
FIRST_WEIGHTS_URL as FASNET_1ST_WEIGHTS,
SECOND_WEIGHTS_URL as FASNET_2ND_WEIGHTS,
)
from deepface.models.face_detection.Ssd import MODEL_URL as SSD_MODEL, WEIGHTS_URL as SSD_WEIGHTS
from deepface.models.face_detection.Yolo import (
WEIGHT_URL as YOLOV8_WEIGHTS,
WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
)
from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS
from deepface.models.face_detection.CenterFace import WEIGHTS_URL as CENTERFACE_WEIGHTS

tf_version = package_utils.get_tf_major_version()
if tf_version == 1:
from keras.models import Sequential
Expand All @@ -20,38 +47,40 @@
logger = Logger()

# pylint: disable=line-too-long
WEIGHTS = {
"facial_recognition": {
"VGG-Face": "https://github.com/serengil/deepface_models/releases/download/v1.0/vgg_face_weights.h5",
"Facenet": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet_weights.h5",
"Facenet512": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5",
"OpenFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/openface_weights.h5",
"FbDeepFace": "https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip",
"ArcFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5",
"DeepID": "https://github.com/serengil/deepface_models/releases/download/v1.0/deepid_keras_weights.h5",
"SFace": "https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx",
"GhostFaceNet": "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5",
"Dlib": "http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2",
},
"demography": {
"Age": "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5",
"Gender": "https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5",
"Emotion": "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5",
"Race": "https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5",
},
"detection": {
"ssd_model": "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt",
"ssd_weights": "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel",
"yolo": "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
"yunet": "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
"dlib": "http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2",
"centerface": "https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx",
WEIGHTS = [
# facial recognition
VGGFACE_WEIGHTS,
FACENET128_WEIGHTS,
FACENET512_WEIGHTS,
OPENFACE_WEIGHTS,
FBDEEPFACE_WEIGHTS,
ARCFACE_WEIGHTS,
DEEPID_WEIGHTS,
SFACE_WEIGHTS,
{
"filename": "ghostfacenet_v1.h5",
"url": GHOSTFACENET_WEIGHTS,
},
"spoofing": {
"MiniFASNetV2": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/2.7_80x80_MiniFASNetV2.pth",
"MiniFASNetV1SE": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/4_0_0_80x80_MiniFASNetV1SE.pth",
DLIB_FR_WEIGHTS,
# demography
AGE_WEIGHTS,
GENDER_WEIGHTS,
RACE_WEIGHTS,
EMOTION_WEIGHTS,
# spoofing
FASNET_1ST_WEIGHTS,
FASNET_2ND_WEIGHTS,
# face detection
SSD_MODEL,
SSD_WEIGHTS,
{
"filename": YOLOV8_WEIGHT_NAME,
"url": YOLOV8_WEIGHTS,
},
}
YUNET_WEIGHTS,
DLIB_FD_WEIGHTS,
CENTERFACE_WEIGHTS,
]

ALLOWED_COMPRESS_TYPES = ["zip", "bz2"]

Expand Down Expand Up @@ -131,18 +160,30 @@ def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
return model


def retrieve_model_source(model_name: str, task: str) -> str:
def download_all_models_in_one_shot() -> None:
"""
Find the source url of a given model name
Args:
model_name (str): given model name
Returns:
weight_url (str): source url of the given model
Download all model weights in one shot
"""
if task not in ["facial_recognition", "detection", "demography", "spoofing"]:
raise ValueError(f"unimplemented task - {task}")

source_url = WEIGHTS.get(task, {}).get(model_name)
if source_url is None:
raise ValueError(f"Source url cannot be found for given model {task}-{model_name}")
return source_url
for i in WEIGHTS:
if isinstance(i, str):
url = i
filename = i.split("/")[-1]
compress_type = None
# if compressed file will be downloaded, get rid of its extension
if filename.endswith(tuple(ALLOWED_COMPRESS_TYPES)):
for ext in ALLOWED_COMPRESS_TYPES:
compress_type = ext
if filename.endswith(f".{ext}"):
filename = filename[: -(len(ext) + 1)]
break
elif isinstance(i, dict):
filename = i["filename"]
url = i["url"]
else:
raise ValueError("unimplemented scenario")
logger.info(
f"Downloading {url} to ~/.deepface/weights/{filename} with {compress_type} compression"
)
download_weights_if_necessary(
file_name=filename, source_url=url, compress_type=compress_type
)
11 changes: 7 additions & 4 deletions deepface/models/demography/Age.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

# ----------------------------------------

WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
)

# pylint: disable=too-few-public-methods
class ApparentAgeClient(Demography):
"""
Expand All @@ -41,7 +45,7 @@ def predict(self, img: np.ndarray) -> np.float64:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct age model, download its weights and load
Expand Down Expand Up @@ -70,12 +74,11 @@ def load_model(
file_name="age_model_weights.h5", source_url=url
)

age_model = weight_utils.load_model_weights(
model=age_model, weight_file=weight_file
)
age_model = weight_utils.load_model_weights(model=age_model, weight_file=weight_file)

return age_model


def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
"""
Find apparent age prediction from a given probas of ages
Expand Down
20 changes: 9 additions & 11 deletions deepface/models/demography/Emotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
from deepface.models.Demography import Demography
from deepface.commons.logger import Logger

logger = Logger()

# -------------------------------------------
# pylint: disable=line-too-long
# -------------------------------------------
# dependency configuration
tf_version = package_utils.get_tf_major_version()

Expand All @@ -28,12 +23,17 @@
Dense,
Dropout,
)
# -------------------------------------------

# Labels for the emotions that can be detected by the model.
labels = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]

# pylint: disable=too-few-public-methods
logger = Logger()

# pylint: disable=line-too-long, disable=too-few-public-methods

WEIGHTS_URL = "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5"


class EmotionClient(Demography):
"""
Emotion model class
Expand All @@ -56,7 +56,7 @@ def predict(self, img: np.ndarray) -> np.ndarray:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5",
url=WEIGHTS_URL,
) -> Sequential:
"""
Consruct emotion model, download and load weights
Expand Down Expand Up @@ -96,8 +96,6 @@ def load_model(
file_name="facial_expression_model_weights.h5", source_url=url
)

model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)

return model
5 changes: 3 additions & 2 deletions deepface/models/demography/Gender.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# -------------------------------------

WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5"

# Labels for the genders that can be detected by the model.
labels = ["Woman", "Man"]
Expand All @@ -43,7 +44,7 @@ def predict(self, img: np.ndarray) -> np.ndarray:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct gender model, download its weights and load
Expand Down
18 changes: 9 additions & 9 deletions deepface/models/demography/Race.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
from deepface.models.Demography import Demography
from deepface.commons.logger import Logger

logger = Logger()

# --------------------------
# pylint: disable=line-too-long
# --------------------------

# dependency configurations
tf_version = package_utils.get_tf_major_version()

Expand All @@ -21,10 +18,15 @@
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# --------------------------

WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5"
)
# Labels for the ethnic phenotypes that can be detected by the model.
labels = ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"]

logger = Logger()

# pylint: disable=too-few-public-methods
class RaceClient(Demography):
"""
Expand All @@ -42,7 +44,7 @@ def predict(self, img: np.ndarray) -> np.ndarray:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct race model, download its weights and load
Expand All @@ -69,8 +71,6 @@ def load_model(
file_name="race_model_single_batch.h5", source_url=url
)

race_model = weight_utils.load_model_weights(
model=race_model, weight_file=weight_file
)
race_model = weight_utils.load_model_weights(model=race_model, weight_file=weight_file)

return race_model
3 changes: 2 additions & 1 deletion deepface/models/face_detection/Dlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

logger = Logger()

WEIGHTS_URL="http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2"

class DlibClient(Detector):
def __init__(self):
Expand All @@ -34,7 +35,7 @@ def build_model(self) -> dict:
# check required file exists in the home/.deepface/weights folder
weight_file = weight_utils.download_weights_if_necessary(
file_name="shape_predictor_5_face_landmarks.dat",
source_url="http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2",
source_url=WEIGHTS_URL,
compress_type="bz2",
)

Expand Down
7 changes: 5 additions & 2 deletions deepface/models/face_detection/Ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

# pylint: disable=line-too-long, c-extension-no-member

MODEL_URL = "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt"
WEIGHTS_URL = "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel"


class SsdClient(Detector):
def __init__(self):
Expand All @@ -31,13 +34,13 @@ def build_model(self) -> dict:
# model structure
output_model = weight_utils.download_weights_if_necessary(
file_name="deploy.prototxt",
source_url="https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt",
source_url=MODEL_URL,
)

# pre-trained weights
output_weights = weight_utils.download_weights_if_necessary(
file_name="res10_300x300_ssd_iter_140000.caffemodel",
source_url="https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel",
source_url=WEIGHTS_URL,
)

try:
Expand Down
4 changes: 2 additions & 2 deletions deepface/models/face_detection/Yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = Logger()

# Model's weights paths
PATH = ".deepface/weights/yolov8n-face.pt"
WEIGHT_NAME = "yolov8n-face.pt"

# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
Expand All @@ -39,7 +39,7 @@ def build_model(self) -> Any:
) from e

weight_file = weight_utils.download_weights_if_necessary(
file_name="yolov8n-face.pt", source_url=WEIGHT_URL
file_name=WEIGHT_NAME, source_url=WEIGHT_URL
)

# Return face_detector
Expand Down
4 changes: 3 additions & 1 deletion deepface/models/face_detection/YuNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

logger = Logger()

WEIGHTS_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"


class YuNetClient(Detector):
def __init__(self):
Expand Down Expand Up @@ -41,7 +43,7 @@ def build_model(self) -> Any:
# pylint: disable=C0301
weight_file = weight_utils.download_weights_if_necessary(
file_name="face_detection_yunet_2023mar.onnx",
source_url="https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
source_url=WEIGHTS_URL,
)

try:
Expand Down
4 changes: 3 additions & 1 deletion deepface/models/facial_recognition/ArcFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
Dense,
)

WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5"

# pylint: disable=too-few-public-methods
class ArcFaceClient(FacialRecognition):
"""
Expand All @@ -56,7 +58,7 @@ def __init__(self):


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct ArcFace model, download its weights and load
Expand Down
Loading

0 comments on commit 234d0db

Please sign in to comment.