diff --git a/setup.py b/setup.py index df546c3..613e10f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ # Package metadata name = "suql" -version = "1.1.7a5" +version = "1.1.7a6" description = "Structured and Unstructured Query Language (SUQL) Python API" author = "Shicheng Liu" author_email = "shicheng@cs.stanford.edu" diff --git a/src/suql/faiss_embedding.py b/src/suql/faiss_embedding.py index d0a11fe..291b2c2 100644 --- a/src/suql/faiss_embedding.py +++ b/src/suql/faiss_embedding.py @@ -285,11 +285,12 @@ def dot_product(self, id_list, query, top, individual_id_list=[]): for sublist in map(lambda x: self.id2document[x], individual_id_list) for item in sublist ] - embedding_indices = [ + # remove potential duplicates here + embedding_indices = list(dict.fromkeys([ item for sublist in map(lambda x: self.document2embedding[x], document_indices) for item in sublist - ] + ])) query_embedding = embed_query(query) @@ -301,8 +302,8 @@ def dot_product(self, id_list, query, top, individual_id_list=[]): params=faiss.SearchParametersIVF(sel=sel), ) else: - if top > self.embeddings.ntotal: - top = self.embeddings.ntotal + if top > min(self.embeddings.ntotal, len(embedding_indices)): + top = min(self.embeddings.ntotal, len(embedding_indices)) D, I = self.embeddings.search( query_embedding, top, params=faiss.SearchParametersIVF(sel=sel) ) diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py index 8d34408..82e474f 100644 --- a/src/suql/sql_free_text_support/execute_free_text_sql.py +++ b/src/suql/sql_free_text_support/execute_free_text_sql.py @@ -723,10 +723,13 @@ def _retrieve_and_verify( enforce_ordering=True if node.sortClause is not None else False, ) else: - id_res = [] + id_res = set() for each_res in parsed_result: if _verify_single_res(each_res, field_query_list, llm_model_name): - id_res.append(each_res[0]) + if isinstance(each_res[0], list): + id_res.update(each_res[0]) + else: + id_res.add(each_res[0]) end_time = time.time() logging.info("retrieve + verification time {}s".format(end_time - start_time))