Skip to content

Commit

Permalink
Update file parsing progress info (infiniflow#3780)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Refine the file parsing progress info

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
  • Loading branch information
JinHai-CN authored Dec 1, 2024
1 parent b5f6436 commit ea84cc2
Showing 1 changed file with 57 additions and 43 deletions.
100 changes: 57 additions & 43 deletions rag/svr/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,72 +370,86 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
return res, tk_count, vector_size


def do_handle_task(r):
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]
task_to_page = task["to_page"]
task_tenant_id = task["tenant_id"]
task_embedding_id = task["embd_id"]
task_language = task["language"]
task_llm_id = task["llm_id"]
task_dataset_id = task["kb_id"]
task_doc_id = task["doc_id"]
task_document_name = task["name"]
task_parser_config = task["parser_config"]

# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
try:
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
# bind embedding model
embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
except Exception as e:
callback(-1, msg=str(e))
progress_callback(-1, msg=f'Fail to bind embedding model: {str(e)}')
raise
if r.get("task_type", "") == "raptor":

# Either using RAPTOR or Standard chunking methods
if task.get("task_type", "") == "raptor":
try:
chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
# bind LLM for raptor
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)

# run RAPTOR
chunks, tk_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback)
except Exception as e:
callback(-1, msg=str(e))
progress_callback(-1, msg=f'Fail to bind LLM used by RAPTOR: {str(e)}')
raise
else:
st = timer()
cks = build(r)
logging.info("Build chunks({}): {}".format(r["name"], timer() - st))
if cks is None:
# Standard chunking methods
start_ts = timer()
chunks = build(task)
logging.info("Build document {}: {:.2f}s".format(task_document_name, timer() - start_ts))
if chunks is None:
return
if not cks:
callback(1., "No chunk! Done!")
if not chunks:
progress_callback(1., msg=f"No chunk built from {task_document_name}")
return
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
callback(
msg="Generate {} chunks ({:.2f}s). Embedding chunks.".format(len(cks), timer() - st)
)
st = timer()
## set_progress(task["did"], -1, "ERROR: ")
progress_callback(msg="Generate {} chunks".format(len(chunks)))
start_ts = timer()
try:
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
tk_count, vector_size = embedding(chunks, embedding_model, task_parser_config, progress_callback)
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
logging.exception("run_rembedding got exception")
progress_callback(-1, "Generate embedding error:{}".format(str(e)))
logging.exception("run_embedding got exception")
tk_count = 0
raise
logging.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
callback(msg="Finished embedding ({:.2f}s)!".format(timer() - st))
# logging.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
init_kb(r, vector_size)
chunk_count = len(set([c["id"] for c in cks]))
st = timer()
logging.info("Embedding {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
progress_callback(msg="Embedding chunks ({:.2f}s)".format(timer() - start_ts))
# logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
init_kb(task, vector_size)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
start_ts = timer()
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
for b in range(0, len(chunks), es_bulk_size):
es_r = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
if es_r:
callback(-1,
"Insert chunk error, detail info please check log file. Please also check Elasticsearch/Infinity status!")
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
progress_callback(-1, "Insert chunk error, detail info please check log file. Please also check Elasticsearch/Infinity status!")
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
logging.error('Insert chunk error: ' + str(es_r))
raise Exception('Insert chunk error: ' + str(es_r))

if TaskService.do_cancel(r["id"]):
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
if TaskService.do_cancel(task_id):
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
return

callback(1., msg="Index cost {:.2f}s.".format(timer() - st))
DocumentService.increment_chunk_num(
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
logging.info(
"Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
r["id"], tk_count, len(cks), timer() - st))
progress_callback(1., msg="Finish Index ({:.2f}s)".format(timer() - start_ts))
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, tk_count, chunk_count, 0)
logging.info("Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(task_id, tk_count, len(chunks), timer() - start_ts))


def handle_task():
Expand Down

0 comments on commit ea84cc2

Please sign in to comment.