Skip to content

Commit

Permalink
Merge pull request #948 from serengil/feat-task-0701-vgg-descriptor
Browse files Browse the repository at this point in the history
VGG-Face descriptor with new structure
  • Loading branch information
serengil authored Jan 8, 2024
2 parents 20edf2c + d35833e commit 1b40870
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ Pipfile.lock
.idea/
deepface.egg-info/
tests/dataset/*.pkl
tests/sandbox.ipynb
tests/*.ipynb
tests/*.csv
*.pyc
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ Face recognition models basically represent facial images as multi-dimensional v
embedding_objs = DeepFace.represent(img_path = "img.jpg")
```

This function returns an array as embedding. The size of the embedding array would be different based on the model name. For instance, VGG-Face is the default model and it represents facial images as 2622 dimensional vectors.
This function returns an array as embedding. The size of the embedding array would be different based on the model name. For instance, VGG-Face is the default model and it represents facial images as 4096 dimensional vectors.

```python
embedding = embedding_objs[0]["embedding"]
assert isinstance(embedding, list)
assert model_name = "VGG-Face" and len(embedding) == 2622
assert model_name = "VGG-Face" and len(embedding) == 4096
```

Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 2622 slots horizontally. Each slot is corresponding to a dimension value in the embedding vector and dimension value is explained in the colorbar on the right. Similar to 2D barcodes, vertical dimension stores no information in the illustration.
Here, embedding is also [plotted](https://sefiks.com/2020/05/01/a-gentle-introduction-to-face-recognition-in-deep-learning/) with 4096 slots horizontally. Each slot is corresponding to a dimension value in the embedding vector and dimension value is explained in the colorbar on the right. Similar to 2D barcodes, vertical dimension stores no information in the illustration.

<p align="center"><img src="https://raw.githubusercontent.com/serengil/deepface/master/icon/embedding.jpg" width="95%" height="95%"></p>

Expand Down
10 changes: 10 additions & 0 deletions deepface/DeepFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,15 @@ def find(
for index, instance in df.iterrows():
source_representation = instance[f"{model_name}_representation"]

target_dims = len(list(target_representation))
source_dims = len(list(source_representation))
if target_dims != source_dims:
raise ValueError(
"Source and target embeddings must have same dimensions but "
+ f"{target_dims}:{source_dims}. Model structure may change"
+ " after pickle created. Delete the {file_name} and re-run."
)

if distance_metric == "cosine":
distance = dst.findCosineDistance(source_representation, target_representation)
elif distance_metric == "euclidean":
Expand All @@ -636,6 +645,7 @@ def find(

threshold = dst.findThreshold(model_name, distance_metric)
result_df = result_df.drop(columns=[f"{model_name}_representation"])
# pylint: disable=unsubscriptable-object
result_df = result_df[result_df[f"{model_name}_{distance_metric}"] <= threshold]
result_df = result_df.sort_values(
by=[f"{model_name}_{distance_metric}"], ascending=True
Expand Down
18 changes: 17 additions & 1 deletion deepface/basemodels/VGGFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
Flatten,
Dropout,
Activation,
Lambda,
)
from keras import backend as K
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
Expand All @@ -29,7 +31,9 @@
Flatten,
Dropout,
Activation,
Lambda,
)
from tensorflow.keras import backend as K

# ---------------------------------------

Expand Down Expand Up @@ -98,6 +102,18 @@ def loadModel(

model.load_weights(output)

vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
# 2622d dimensional model
# vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)

# 4096 dimensional model offers 6% to 14% increasement on accuracy!
# - softmax causes underfitting
# - added normalization layer to avoid underfitting with euclidean
# as described here: https://github.com/serengil/deepface/issues/944
base_model_output = Sequential()
base_model_output = Flatten()(model.layers[-5].output)
base_model_output = Lambda(lambda x: K.l2_normalize(x, axis=1), name="norm_layer")(
base_model_output
)
vgg_face_descriptor = Model(inputs=model.input, outputs=base_model_output)

return vgg_face_descriptor
7 changes: 6 additions & 1 deletion deepface/commons/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def findThreshold(model_name: str, distance_metric: str) -> float:
base_threshold = {"cosine": 0.40, "euclidean": 0.55, "euclidean_l2": 0.75}

thresholds = {
"VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86},
# "VGG-Face": {"cosine": 0.40, "euclidean": 0.60, "euclidean_l2": 0.86}, # 2622d
"VGG-Face": {
"cosine": 0.68,
"euclidean": 1.17,
"euclidean_l2": 1.17,
}, # 4096d - tuned with LFW
"Facenet": {"cosine": 0.40, "euclidean": 10, "euclidean_l2": 0.80},
"Facenet512": {"cosine": 0.30, "euclidean": 23.56, "euclidean_l2": 1.04},
"ArcFace": {"cosine": 0.68, "euclidean": 4.15, "euclidean_l2": 1.13},
Expand Down
2 changes: 1 addition & 1 deletion tests/test_enforce_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_disabled_enforce_detection_for_non_facial_input_on_represent():
assert "w" in objs[0]["facial_area"].keys()
assert "h" in objs[0]["facial_area"].keys()
assert isinstance(objs[0]["embedding"], list)
assert len(objs[0]["embedding"]) == 2622 # embedding of VGG-Face
assert len(objs[0]["embedding"]) == 4096 # embedding of VGG-Face

logger.info("✅ disabled enforce detection with non facial input test for represent tests done")

Expand Down
29 changes: 26 additions & 3 deletions tests/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,42 @@


def test_find_with_exact_path():
dfs = DeepFace.find(img_path="dataset/img1.jpg", db_path="dataset", silent=True)
img_path = "dataset/img1.jpg"
dfs = DeepFace.find(img_path=img_path, db_path="dataset", silent=True)
assert len(dfs) > 0
for df in dfs:
assert isinstance(df, pd.DataFrame)

# one is img1.jpg itself
identity_df = df[df["identity"] == img_path]
assert identity_df.shape[0] > 0

# validate reproducability
assert identity_df["VGG-Face_cosine"].values[0] == 0

df = df[df["identity"] != img_path]
logger.debug(df.head())
assert df.shape[0] > 0
logger.info("✅ test find for exact path done")


def test_find_with_array_input():
img1 = cv2.imread("dataset/img1.jpg")
img_path = "dataset/img1.jpg"
img1 = cv2.imread(img_path)
dfs = DeepFace.find(img1, db_path="dataset", silent=True)

assert len(dfs) > 0
for df in dfs:
assert isinstance(df, pd.DataFrame)

# one is img1.jpg itself
identity_df = df[df["identity"] == img_path]
assert identity_df.shape[0] > 0

# validate reproducability
assert identity_df["VGG-Face_cosine"].values[0] == 0


df = df[df["identity"] != img_path]
logger.debug(df.head())
assert df.shape[0] > 0

Expand Down
2 changes: 1 addition & 1 deletion tests/test_represent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_standard_represent():
for embedding_obj in embedding_objs:
embedding = embedding_obj["embedding"]
logger.debug(f"Function returned {len(embedding)} dimensional vector")
assert len(embedding) == 2622
assert len(embedding) == 4096
logger.info("✅ test standard represent function done")


Expand Down

0 comments on commit 1b40870

Please sign in to comment.