From ed0c917051e1d41acbb8bae0ddf6ecc63b0f6c80 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 28 Mar 2024 11:44:44 +0800 Subject: [PATCH] refine citation --- api/apps/conversation_app.py | 7 ++++--- rag/app/paper.py | 2 +- rag/nlp/search.py | 29 ++++++++++++++++------------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 5c55d5dd7a9..5521cca8ba6 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -194,7 +194,8 @@ def chat(dialog, messages, **kwargs): # try to use sql if field mapping is good to go if field_map: chat_logger.info("Use SQL to retrieval:{}".format(questions[-1])) - return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl) + ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl) + if ans: return ans prompt_config = dialog.prompt_config for p in prompt_config["parameters"]: @@ -305,7 +306,7 @@ def get_table(): tbl, sql = get_table() if tbl is None: - return None, None + return None if tbl.get("error") and tried_times <= 2: user_promt = """ 表名:{}; @@ -333,7 +334,7 @@ def get_table(): chat_logger.info("GET table: {}".format(tbl)) print(tbl) if tbl.get("error") or len(tbl["rows"]) == 0: - return None, None + return None docid_idx = set([ii for ii, c in enumerate( tbl["columns"]) if c["name"] == "doc_id"]) diff --git a/rag/app/paper.py b/rag/app/paper.py index 87250545fbb..9a75bec7881 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -120,7 +120,7 @@ def _begin(txt): print(tbls) return { - "title": title if title else filename, + "title": title, "authors": " ".join(authors), "abstract": abstr, "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if diff --git a/rag/nlp/search.py b/rag/nlp/search.py index cc9f533efb1..ac92853dd7c 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -246,19 +246,22 @@ def insert_citations(self, answer, chunks, chunk_v, chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ") for ck in chunks] cites = {} - for i, a in enumerate(pieces_): - sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], - chunk_v, - huqie.qie( - self.qryr.rmWWW(pieces_[i])).split(" "), - chunks_tks, - tkweight, vtweight) - mx = np.max(sim) * 0.99 - es_logger.info("{} SIM: {}".format(pieces_[i], mx)) - if mx < 0.63: - continue - cites[idx[i]] = list( - set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] + thr = 0.63 + while len(cites.keys()) == 0 and pieces_ and chunks_tks: + for i, a in enumerate(pieces_): + sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], + chunk_v, + huqie.qie( + self.qryr.rmWWW(pieces_[i])).split(" "), + chunks_tks, + tkweight, vtweight) + mx = np.max(sim) * 0.99 + es_logger.info("{} SIM: {}".format(pieces_[i], mx)) + if mx < thr: + continue + cites[idx[i]] = list( + set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4] + thr *= 0.8 res = "" seted = set([])