diff --git a/rag/benchmark.py b/rag/benchmark.py
index 490c031f97c..aea4ef99c5f 100644
--- a/rag/benchmark.py
+++ b/rag/benchmark.py
@@ -16,11 +16,15 @@
 import json
 import os
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from copy import deepcopy
+
 from api.db import LLMType
 from api.db.services.llm_service import LLMBundle
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.settings import retrievaler
 from api.utils import get_uuid
+from api.utils.file_utils import get_project_base_directory
 from rag.nlp import tokenize, search
 from rag.utils.es_conn import ELASTICSEARCH
 from ranx import evaluate
@@ -63,14 +67,34 @@ def embedding(self, docs, batch_size=16):
             d["q_%d_vec" % len(v)] = v
         return docs
 
+    @staticmethod
+    def init_kb(index_name):  # drop any stale index, then recreate it from conf/mapping.json
+        idxnm = search.index_name(index_name)
+        if ELASTICSEARCH.indexExist(idxnm):
+            ELASTICSEARCH.deleteIdx(idxnm)
+
+        return ELASTICSEARCH.createIdx(idxnm, json.load(
+            open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
+
     def ms_marco_index(self, file_path, index_name):
         qrels = defaultdict(dict)
         texts = defaultdict(dict)
         docs = []
         filelist = os.listdir(file_path)
+        self.init_kb(index_name)
+
+        max_workers = int(os.environ.get('MAX_WORKERS', 3))
+        exe = ThreadPoolExecutor(max_workers=max_workers)
+        threads = []
+
+        def slow_actions(es_docs, idx_nm):  # embed a batch and bulk-index it off the main thread
+            es_docs = self.embedding(es_docs)
+            ELASTICSEARCH.bulk(es_docs, idx_nm)
+            return True
+
         for dir in filelist:
             data = pd.read_parquet(os.path.join(file_path, dir))
-            for i in tqdm(range(len(data)), colour="green", desc="Indexing:" + dir):
+            for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
                 query = data.iloc[i]['query']
                 for rel, text in zip(data.iloc[i]['passages']['is_selected'],
                                      data.iloc[i]['passages']['passage_text']):
@@ -82,12 +106,17 @@ def ms_marco_index(self, file_path, index_name):
                     texts[d["id"]] = text
                     qrels[query][d["id"]] = int(rel)
                 if len(docs) >= 32:
-                    docs = self.embedding(docs)
-                    ELASTICSEARCH.bulk(docs, search.index_name(index_name))
+                    threads.append(
+                        exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
                     docs = []
 
-        docs = self.embedding(docs)
-        ELASTICSEARCH.bulk(docs, search.index_name(index_name))
+        threads.append(
+            exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
+
+        for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
+            if not threads[i].result():  # slow_actions returns True on success
+                print("Indexing error...")
+
         return qrels, texts
 
     def trivia_qa_index(self, file_path, index_name):
diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py
index 1b22dc9e176..f0be0527ffb 100644
--- a/rag/nlp/term_weight.py
+++ b/rag/nlp/term_weight.py
@@ -227,7 +227,7 @@ def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
         idf2 = np.array([idf(df(t), 1000000000) for t in tks])
         wts = (0.3 * idf1 + 0.7 * idf2) * \
             np.array([ner(t) * postag(t) for t in tks])
-        tw = zip(tks, wts)
+        tw = list(zip(tks, wts))
     else:
         for tk in tks:
             tt = self.tokenMerge(self.pretoken(tk, True))
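
Note on the fan-out/fan-in path added to ms_marco_index: exe.submit returns a concurrent.futures.Future, and Future.result() yields whatever the submitted callable returned (here the bare True from slow_actions), so the fan-in loop only needs a truthiness check on result(). Below is a minimal standalone sketch of the same batch-and-flush pattern; index_batch, the batch size, and the worker count are illustrative stand-ins, not part of the patch.

    from concurrent.futures import ThreadPoolExecutor
    from copy import deepcopy

    def index_batch(batch):
        # Stand-in for slow_actions(): embed the batch, bulk-index it,
        # and report success.
        return True

    docs, futures = [], []
    with ThreadPoolExecutor(max_workers=3) as exe:
        for i in range(100):
            docs.append({"id": i})
            if len(docs) >= 32:
                # deepcopy hands the worker its own snapshot of the batch
                # before the producer rebinds docs to a fresh list.
                futures.append(exe.submit(index_batch, deepcopy(docs)))
                docs = []
        futures.append(exe.submit(index_batch, deepcopy(docs)))  # flush the tail
        if not all(f.result() for f in futures):  # fan-in: wait for every batch
            print("Indexing error...")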
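
The term_weight.py change is easy to miss but load-bearing: in Python 3, zip returns a single-pass iterator, so if tw is consumed more than once (for example, summed once to normalize and then iterated again), the second pass silently yields nothing. Wrapping it in list(...) makes tw reusable. A small illustration with made-up values:

    tks, wts = ["foo", "bar"], [0.4, 0.6]

    tw = zip(tks, wts)
    total = sum(w for _, w in tw)           # first pass exhausts the iterator
    print([(t, w / total) for t, w in tw])  # prints [] because nothing is left

    tw = list(zip(tks, wts))
    total = sum(w for _, w in tw)
    print([(t, w / total) for t, w in tw])  # prints [('foo', 0.4), ('bar', 0.6)]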