-
Notifications
You must be signed in to change notification settings - Fork 0
/
sql_db.py
288 lines (243 loc) · 12.4 KB
/
sql_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import os
from datetime import datetime
import pandas as pd
from typing import Dict
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select, or_
from sqlalchemy.sql import func
from sqlmodel import SQLModel, create_engine, Session, select
import constants
from db.sql_models import User, Category, Product, SearchHistory, Recommendation, RecommendationFeedback
from db.qdrant_db import QdrantDatabase
from db.embedder import TextEmbedder
from db.tokenizer import TextTokenizer
class DB:
    """SQLite-backed data access layer.

    Stores users, categories, products, search history, recommendations and
    recommendation feedback through SQLModel. Every product is also embedded
    with ``TextEmbedder`` and recorded in a Qdrant collection, so searches can
    combine keyword matching with vector similarity.
    """

    def __init__(self, database_location: str = None):
        """Set up the helpers and open (creating if needed) the SQLite database.

        Parameters:
            database_location: path of the SQLite file; falls back to
                ``constants.SQL_DB_PATH`` when omitted.
        """
        self.database_location = constants.SQL_DB_PATH if database_location is None else database_location
        self.engine = None
        self.text_embedder = TextEmbedder()
        self.qdrant_product_database = QdrantDatabase(collection_name=constants.QDRANT_PRODUCT_COLLECTION_NAME)
        self.text_tokenizer = TextTokenizer()
        self.initialize_database()

    # ------------------------------------------------------------------ #
    # Setup
    # ------------------------------------------------------------------ #
    @staticmethod
    def _ensure_parent_directory(path: str) -> None:
        """Create the directory that will contain ``path`` if it is missing.

        A bare file name (or ``:memory:``) has no directory component. Calling
        ``os.makedirs("")`` would raise, so that case is skipped.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def initialize_database(self):
        """Create the engine and make sure every SQLModel table exists.

        Does nothing if the engine has already been created.
        """
        if self.engine is not None:
            return
        self._ensure_parent_directory(self.database_location)
        is_new_database = not os.path.exists(self.database_location)
        self.engine = create_engine(f"sqlite:///{self.database_location}", echo=True)
        if is_new_database:
            print("Database does not exist. Initializing...")
        # create_all only issues CREATE TABLE for tables that are missing, so
        # it is also safe (and useful) for a file that exists but lacks tables.
        SQLModel.metadata.create_all(self.engine)

    # ------------------------------------------------------------------ #
    # Users and categories
    # ------------------------------------------------------------------ #
    def record_user(self, email: str, profile_description: str) -> int:
        """Insert a new user and return its generated ``user_id``."""
        user = User(email=email, profile_description=profile_description)
        with Session(self.engine) as session:
            session.add(user)
            session.commit()
            session.refresh(user)
            return user.user_id

    def update_user(self, user_id: int, new_email: str = None, new_profile_description: str = None):
        """Update the given fields of user ``user_id`` and bump ``updated_at``.

        Arguments left as ``None`` are not changed. If no user has
        ``user_id``, nothing happens.
        """
        with Session(self.engine) as session:
            user = session.exec(select(User).where(User.user_id == user_id)).first()
            if user is None:
                return
            if new_email is not None:
                user.email = new_email
            if new_profile_description is not None:
                user.profile_description = new_profile_description
            user.updated_at = datetime.now()
            session.commit()
            session.refresh(user)

    def record_category(self, name: str, description: str) -> int:
        """Insert a new category and return its generated ``category_id``."""
        category = Category(name=name, description=description)
        with Session(self.engine) as session:
            session.add(category)
            session.commit()
            session.refresh(category)
            return category.category_id

    def update_category(self, category_id: int, new_name: str = None, new_description: str = None):
        """Update the given fields of category ``category_id`` and bump ``updated_at``.

        Arguments left as ``None`` are not changed. If no category has
        ``category_id``, nothing happens. Note that product embeddings that
        mention the old category name are not refreshed here.
        """
        with Session(self.engine) as session:
            category = session.exec(select(Category).where(Category.category_id == category_id)).first()
            if category is None:
                return
            if new_name is not None:
                category.name = new_name
            if new_description is not None:
                category.description = new_description
            category.updated_at = datetime.now()
            session.commit()
            session.refresh(category)

    # ------------------------------------------------------------------ #
    # Products
    # ------------------------------------------------------------------ #
    def _category_name(self, category_id: int) -> str:
        """Return the name of category ``category_id``.

        Raises:
            ValueError: if no such category exists.
        """
        category = self.get_data(table='Category', id=category_id)
        if category is None:
            raise ValueError(f"Category with id {category_id} does not exist")
        return category['name']

    def _index_product_embedding(self, product_id: int, name: str, category_name: str, description: str) -> None:
        """Embed a product's descriptive text and record it in Qdrant under ``product_id``."""
        product_text = f'Product name is {name} and category is {category_name}. Description: {description}'
        product_embedding = self.text_embedder.generate_embedding(product_text)
        self.qdrant_product_database.record_embedding_into_collection(
            id=product_id,
            embedding=product_embedding
        )

    def record_product(self, category_id: int, name: str, description: str, price: float, stock: int) -> int:
        """Insert a product, record its embedding, and return its ``product_id``.

        Raises:
            ValueError: if ``category_id`` does not refer to an existing category.
                The check runs before the insert so no orphan row is written.
        """
        category_name = self._category_name(category_id)
        product = Product(category_id=category_id, name=name, description=description, price=price, stock=stock)
        with Session(self.engine) as session:
            session.add(product)
            session.commit()
            session.refresh(product)
            product_id = product.product_id
        self._index_product_embedding(product_id, name, category_name, description)
        return product_id

    def update_product(self, product_id: int = None, new_category_id: int = None, new_name: str = None, new_description: str = None, new_price: float = None, new_stock: int = None):
        """Update the given fields of product ``product_id`` and bump ``updated_at``.

        Arguments left as ``None`` are not changed. If no product has
        ``product_id``, nothing happens. When the name, description or
        category changes, the product's embedding is rebuilt from the stored
        (post-update) values, so untouched fields keep contributing to it.

        Raises:
            ValueError: if ``new_category_id`` does not refer to an existing
                category (checked before anything is written), or if the
                product's stored category is missing when re-embedding.
        """
        category_name = self._category_name(new_category_id) if new_category_id is not None else None
        with Session(self.engine) as session:
            product = session.exec(select(Product).where(Product.product_id == product_id)).first()
            if product is None:
                return
            if new_category_id is not None:
                product.category_id = new_category_id
            if new_name is not None:
                product.name = new_name
            if new_description is not None:
                product.description = new_description
            if new_price is not None:
                product.price = new_price
            if new_stock is not None:
                product.stock = new_stock
            product.updated_at = datetime.now()
            session.commit()
            session.refresh(product)
            stored_id = product.product_id
            stored_category_id = product.category_id
            stored_name = product.name
            stored_description = product.description

        # Price/stock are not part of the embedded text; skip the costly re-embed.
        if new_category_id is None and new_name is None and new_description is None:
            return
        if category_name is None:
            category_name = self._category_name(stored_category_id)
        self._index_product_embedding(stored_id, stored_name, category_name, stored_description)

    # ------------------------------------------------------------------ #
    # Search history, recommendations and feedback
    # ------------------------------------------------------------------ #
    def record_search_history(self, user_id: int, query: str) -> int:
        """Store a search query made by ``user_id`` and return its ``search_id``."""
        entry = SearchHistory(user_id=user_id, query=query)
        with Session(self.engine) as session:
            session.add(entry)
            session.commit()
            session.refresh(entry)
            return entry.search_id

    def record_recommendation(self, user_id: int, product_id: int, score: float = 1) -> int:
        """Store a product recommendation for a user and return its ``recommendation_id``."""
        recommendation = Recommendation(user_id=user_id, product_id=product_id, score=score)
        with Session(self.engine) as session:
            session.add(recommendation)
            session.commit()
            session.refresh(recommendation)
            return recommendation.recommendation_id

    def record_recommendation_feedback(self, recommendation_id: int, user_id: int, rating: int) -> int:
        """Store a user's rating for a recommendation.

        Returns:
            The new ``recommendation_feedback_id``, or ``None`` when the insert
            hits a UNIQUE constraint (feedback for this recommendation already
            exists).

        Raises:
            IntegrityError: for any other integrity violation.
        """
        feedback = RecommendationFeedback(recommendation_id=recommendation_id, user_id=user_id, rating=rating)
        with Session(self.engine) as session:
            try:
                session.add(feedback)
                session.commit()
            except IntegrityError as e:
                # The message text is SQLite-specific.
                if "UNIQUE constraint failed" not in str(e.orig):
                    raise
                session.rollback()  # leave the session in a clean state
                return None
            session.refresh(feedback)
            return feedback.recommendation_feedback_id

    # ------------------------------------------------------------------ #
    # Generic lookup
    # ------------------------------------------------------------------ #
    def get_data(self, table: str, id: int, return_as_dict: bool = True) -> dict:
        '''
        Fetch one row by primary id.

        Parameters:
            table (str) options:
                - "User"
                - "Category"
                - "Product"
                - "SearchHistory"
                - "Recommendation"
                - "RecommendationFeedback"
            id (int): primary id of the chosen table
            return_as_dict (bool): return ``row.dict()`` when True, otherwise
                the model instance itself
        Returns:
            dict (or model instance) if data found, else None. Also None for
            an unknown table name.
        '''
        id_columns = {
            "User": (User, User.user_id),
            "Category": (Category, Category.category_id),
            "Product": (Product, Product.product_id),
            "SearchHistory": (SearchHistory, SearchHistory.search_id),
            "Recommendation": (Recommendation, Recommendation.recommendation_id),
            "RecommendationFeedback": (RecommendationFeedback, RecommendationFeedback.recommendation_feedback_id),
        }
        if table not in id_columns:
            return None
        model, id_column = id_columns[table]
        with Session(self.engine) as session:
            data = session.exec(select(model).where(id_column == id)).first()
            if data is None:
                return None
            return data.dict() if return_as_dict else data

    # ------------------------------------------------------------------ #
    # Product search
    # ------------------------------------------------------------------ #
    @staticmethod
    def _passes_stock_and_price_filters(product, min_price: float = None, max_price: float = None, min_stock: int = None) -> bool:
        """Return True if ``product`` satisfies the search filters.

        The product must be in stock (``stock > 0``). When given, ``min_stock``
        is an inclusive lower bound on stock, and ``min_price``/``max_price``
        are inclusive price bounds. This mirrors the SQL conditions built in
        ``_keyword_search_products``.
        """
        if product.stock <= 0:
            return False
        if min_stock is not None and product.stock < min_stock:
            return False
        if min_price is not None and product.price < min_price:
            return False
        if max_price is not None and product.price > max_price:
            return False
        return True

    def _keyword_search_products(self, search_keyword: str, min_price: float, max_price: float, min_stock: int) -> list:
        """Return products whose name or description contains any token of the query.

        When the tokenizer yields no tokens, no text condition is applied and
        every product passing the stock/price filters is returned.
        """
        tokens = self.text_tokenizer.clean_text(search_keyword)
        keyword_conditions = [
            column.ilike(f"%{token}%")
            for token in tokens
            for column in (Product.name, Product.description)
        ]
        # Same rules as _passes_stock_and_price_filters, expressed in SQL.
        conditions = [Product.stock > 0]
        if keyword_conditions:
            conditions.append(or_(*keyword_conditions))
        if min_stock is not None:
            conditions.append(Product.stock >= min_stock)
        if min_price is not None:
            conditions.append(Product.price >= min_price)
        if max_price is not None:
            conditions.append(Product.price <= max_price)
        with Session(self.engine) as session:
            return list(session.exec(select(Product).where(*conditions)).all())

    def _vector_search_products(self, search_keyword: str, similarity_threshold: float) -> list:
        """Return products whose stored embedding is similar to the query's embedding (unfiltered)."""
        query_embedding = self.text_embedder.generate_embedding(search_keyword)
        product_ids = self.qdrant_product_database.search_similar_products(query_embedding=query_embedding, similarity_threshold=similarity_threshold)
        if not product_ids:
            return []
        with Session(self.engine) as session:
            return list(session.exec(select(Product).where(Product.product_id.in_(product_ids))).all())

    def search_product(self, user_id: int = None, search_keyword: str = '', min_price: float = None, max_price: float = None, min_stock: int = None, return_as_dict: bool = True, use_vectore_search: bool = True, similarity_threshold: float=0.3):
        """Search products by keyword and, optionally, by embedding similarity.

        Parameters:
            user_id: when given, the query is stored in the search history.
            search_keyword: free-text query.
            min_price / max_price: inclusive price bounds (None = unbounded).
            min_stock: inclusive minimum stock. Products with no stock are
                always excluded.
            return_as_dict: return ``product.dict()`` items instead of models.
            use_vectore_search: also include Qdrant similarity matches.
            similarity_threshold: forwarded to the Qdrant search.
        Returns:
            Keyword matches first, then vector matches not already listed.
            Each product appears at most once (deduplicated by ``product_id``).
        """
        if user_id is not None:
            self.record_search_history(user_id, search_keyword)

        search_results = self._keyword_search_products(search_keyword, min_price, max_price, min_stock)

        if use_vectore_search:
            # Dedupe by id: instances loaded in different sessions are
            # separate objects, so `in`/== on the models is not reliable.
            seen_ids = {product.product_id for product in search_results}
            for product in self._vector_search_products(search_keyword, similarity_threshold):
                if product.product_id in seen_ids:
                    continue
                if not self._passes_stock_and_price_filters(product, min_price, max_price, min_stock):
                    continue
                seen_ids.add(product.product_id)
                search_results.append(product)

        if return_as_dict:
            return [product.dict() for product in search_results]
        return search_results

    # ------------------------------------------------------------------ #
    # Reporting
    # ------------------------------------------------------------------ #
    def summarize_recommendation_feedback_rating(self) -> dict:
        """Return a mapping from each rating value to the number of feedback rows with it."""
        with Session(self.engine) as session:
            statement = (
                select(RecommendationFeedback.rating, func.count(RecommendationFeedback.rating))
                .group_by(RecommendationFeedback.rating)
            )
            return {rating: count for rating, count in session.exec(statement).all()}