This repository has been archived by the owner on Nov 25, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrecommend.py
108 lines (92 loc) · 4.63 KB
/
recommend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import os.path
import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from classify import classify
def recommend(
    ref_path: str, num_recommendations: int,
    data_path: str, clf_path: str, fe_path: str, clu_path: str,
) -> list:
    """
    Recommends similar images based on a reference image.

    The reference image is classified first, and only database entries with
    the same class are considered. Those entries are clustered, and the
    members of the reference image's cluster are ranked by cosine similarity
    to the reference feature vector. If that cluster alone holds fewer than
    ``num_recommendations`` entries, the whole same-class subset is ranked
    instead.

    :param ref_path: Path to the reference image.
    :param num_recommendations: Number of recommended images to return.
    :param data_path: Path to the .csv data file containing recommender database image feature vectors. This file must be generated using the same feature extractor specified in fe_path.
    :param clf_path: Path to the classifier model file.
    :param fe_path: Path to the feature extraction model file.
    :param clu_path: Path to the clustering model file.
    :return: List of paths to the recommended images, most similar first.
    :raises ValueError: If num_recommendations is smaller than 1, or if the
        database holds fewer than num_recommendations images of the reference
        image's class.
    """
    if num_recommendations < 1:
        raise ValueError('Number of recommendations cannot be smaller than 1.')

    df_rec = pd.read_csv(data_path)
    ref_processed, ref_class = classify(ref_path, classifier_path=clf_path, return_original=False, verbose=False)

    # Only database images of the same predicted class are candidates.
    candidates = df_rec[df_rec['Class'] == ref_class]
    # Validate before loading models or clustering. An empty or too-small
    # subset would otherwise surface as an unrelated clustering error.
    if num_recommendations > len(candidates):
        raise ValueError('Number of recommendations too large. Insufficient database size.')
    # Every column other than the path and the label is a feature component.
    candidate_features = candidates.drop(['ImgPath', 'Class'], axis='columns').values

    # Extract the reference image feature vector as a single (1, n_features) row.
    fe = tf.keras.models.load_model(fe_path)
    ref_processed = np.squeeze(ref_processed)
    ref_feature_vector = fe.predict(
        tf.expand_dims(ref_processed, axis=0),
        verbose=0
    )
    ref_feature_vector = ref_feature_vector.astype(float).reshape(1, -1)

    # Cluster the candidates. Derive the cluster count from the rows actually
    # being clustered, not from the whole database. Since
    # len(candidates) >= num_recommendations >= 1, sqrt(len / num) lies in
    # [1, len(candidates)], which is always a valid cluster count.
    # NOTE(review): assumes the loaded estimator accepts n_clusters and
    # exposes labels_ after fit (e.g. scikit-learn KMeans) - confirm.
    clu = joblib.load(clu_path)
    clu.set_params(n_clusters=int(np.sqrt(len(candidates) / num_recommendations)))
    clu.fit(candidate_features)
    ref_cluster = clu.predict(ref_feature_vector)[0]
    in_ref_cluster = clu.labels_ == ref_cluster
    # Restrict to the reference cluster only if it can supply enough results.
    # Otherwise keep the whole same-class subset instead of failing.
    if np.count_nonzero(in_ref_cluster) >= num_recommendations:
        candidates = candidates.iloc[np.flatnonzero(in_ref_cluster)]
        candidate_features = candidate_features[in_ref_cluster]

    # Rank by cosine similarity (descending) and keep the top entries.
    similarities = cosine_similarity(ref_feature_vector, candidate_features).ravel()
    top_indices = np.argsort(-similarities)[:num_recommendations]
    return candidates['ImgPath'].iloc[top_indices].tolist()
if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument('-f', '--file', required=True, help='reference image')
    ap.add_argument('-d', '--database', default='data/recommender-database', help='the database containing the images to be recommended, default: data/recommender-database')
    ap.add_argument('-c', '--classifier', default='models/clf-cnn', help='the machine learning model used for image classification, default: models/clf-cnn')
    ap.add_argument('-e', '--feature-extractor', default='models/fe-cnn', help='the machine learning model used for image feature extraction, default: models/fe-cnn')
    ap.add_argument('-k', '--clustering-model', default='models/clu-kmeans.model', help='the machine learning model used for image clustering, default: models/clu-kmeans.model')
    ap.add_argument('-n', '--num', required=False, type=int, default=10, help="number of recommendations, default: 10")
    args = vars(ap.parse_args())
    num = args['num']
    if num < 1:
        # Reject early: a non-positive count would otherwise fail inside
        # plt.subplots or recommend() with a less helpful error.
        ap.error('argument -n/--num: must be at least 1')

    # Row 0 shows the reference image in column 2. Recommendations fill the
    # following rows five per row, starting at flat index 5, so
    # ceil(num / 5) extra rows are needed to avoid indexing past `axes`.
    n_rows = 1 + (num + 4) // 5
    _, axes = plt.subplots(n_rows, 5, figsize=(16, 16), num='Flower Image Recommender')
    axes = axes.ravel()

    _, ref_class = classify(args['file'], classifier_path=args['classifier'], return_original=False, verbose=False)
    # Close the reference image file once matplotlib has taken its pixel data.
    with Image.open(args['file']) as ref:
        axes[2].imshow(ref)
    axes[2].set_title(
        f'Reference Image - "{ref_class}"',
        fontsize=10,
        weight='bold'
    )
    axes[2].text(
        0.5, -0.08, f'{os.path.relpath(args["file"])}',
        horizontalalignment='center',
        verticalalignment='center_baseline',
        transform=axes[2].transAxes,
        fontsize=8,
    )

    recommended_paths = recommend(
        args['file'], num,
        args['database'] + '.csv', args['classifier'], args['feature_extractor'], args['clustering_model']
    )
    for i, rec_path in enumerate(recommended_paths, start=5):
        # ImgPath entries are resolved relative to the database directory.
        with Image.open(os.path.join(args['database'], rec_path)) as rec:
            axes[i].imshow(rec)

    for ax in axes:
        ax.axis('off')
    plt.show()