Try to reuse existing chunks. Close infiniflow#3793
yuzhichang committed Dec 11, 2024
1 parent 409acf0 commit 5aeef10
Showing 4 changed files with 109 additions and 28 deletions.
10 changes: 5 additions & 5 deletions api/apps/document_app.py
@@ -21,10 +21,10 @@
from flask import request
from flask_login import login_required, current_user

-from api.db.db_models import Task, File
+from api.db.db_models import File
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
-from api.db.services.task_service import TaskService, queue_tasks
+from api.db.services.task_service import queue_tasks
from api.db.services.user_service import UserTenantService
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search
@@ -361,11 +361,11 @@ def run():
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
-            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+            if req.get("delete", False):
+                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                    settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

if str(req["run"]) == TaskStatus.RUNNING.value:
-                TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
doc["tenant_id"] = tenant_id
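
Net effect of this hunk: re-running ingestion no longer wipes a document's indexed chunks up front; they are deleted only when the request explicitly asks for it, which is what lets unchanged tasks be reused downstream. A minimal client-side sketch of the two modes; the host, route, auth header, and the "run": "1" value (assumed to map to TaskStatus.RUNNING) are illustrative assumptions, not taken from this diff:

    import requests

    # Re-run parsing but keep existing chunks, so an identical task can be reused.
    requests.post("http://localhost:9380/v1/document/run",            # assumed endpoint
                  headers={"Authorization": "Bearer <api_key>"},      # assumed auth scheme
                  json={"doc_ids": ["<doc_id>"], "run": "1", "delete": False})

    # Force a clean re-parse: purge the document's indexed chunks first.
    requests.post("http://localhost:9380/v1/document/run",
                  headers={"Authorization": "Bearer <api_key>"},
                  json={"doc_ids": ["<doc_id>"], "run": "1", "delete": True})
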
2 changes: 2 additions & 0 deletions api/db/db_models.py
@@ -855,6 +855,8 @@ class Task(DataBaseModel):
help_text="process message",
default="")
retry_count = IntegerField(default=0)
+    digest = TextField(null=True, help_text="task digest", default="")
+    chunk_ids = TextField(null=True, help_text="chunk ids", default="")


class Dialog(DataBaseModel):
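
The two new Task columns are deliberately plain TextFields: digest stores an md5 hex digest of the task's parsing-relevant settings, and chunk_ids stores a space-separated id list rather than JSON. A small round-trip sketch of that encoding, mirroring how the other files in this commit read and write it (the ids are made up):

    # Write side, as in TaskService.update_chunk_ids below:
    chunk_ids_str = " ".join(["ck_01", "ck_02", "ck_03"])

    # Read side, as in reuse_prev_task_chunks in task_executor.py:
    chunk_ids = [x for x in chunk_ids_str.split() if x]
    assert chunk_ids == ["ck_01", "ck_02", "ck_03"]
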
52 changes: 46 additions & 6 deletions api/db/services/task_service.py
@@ -30,6 +30,18 @@
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN

+def trim_header_by_lines(text: str, max_length) -> str:
+    # Drop the oldest lines so the surviving tail fits within max_length.
+    if len(text) <= max_length:
+        return text
+    lines = text.split("\n")
+    total = 0
+    idx = len(lines) - 1
+    for i in range(len(lines) - 1, -1, -1):
+        if total + len(lines[i]) > max_length:
+            break
+        idx = i
+        total += len(lines[i])
+    return "\n".join(lines[idx:])

class TaskService(CommonService):
model = Task
@@ -87,6 +99,34 @@ def get_task(cls, task_id):

return docs[0]

+    @classmethod
+    @DB.connection_context()
+    def get_task2(cls, doc_id: str, from_page: int, to_page: int):
+        fields = [
+            cls.model.id,
+            cls.model.progress,
+            cls.model.digest,
+            cls.model.chunk_ids,
+        ]
+        tasks = (
+            cls.model.select(*fields)
+            .where(cls.model.doc_id == doc_id, cls.model.from_page == from_page, cls.model.to_page == to_page)
+        )
+        tasks = list(tasks.dicts())
+        if not tasks:
+            return None
+        return tasks[0]

+    @classmethod
+    @DB.connection_context()
+    def update_digest(cls, id: str, digest: str):
+        cls.model.update(digest=digest).where(cls.model.id == id).execute()
+
+    @classmethod
+    @DB.connection_context()
+    def update_chunk_ids(cls, id: str, chunk_ids: str):
+        cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute()

@classmethod
@DB.connection_context()
def get_ongoing_doc_name(cls):
@@ -146,9 +186,9 @@ def do_cancel(cls, id):
def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
-                cls.model.update(
-                    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
-                ).where(cls.model.id == id).execute()
+                task = cls.model.get_by_id(id)
+                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
+                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
@@ -157,9 +197,9 @@ def update_progress(cls, id, info):

with DB.lock("update_progress", -1):
if info["progress_msg"]:
-                cls.model.update(
-                    progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
-                ).where(cls.model.id == id).execute()
+                task = cls.model.get_by_id(id)
+                progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000)
+                cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id
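
trim_header_by_lines keeps progress_msg from growing without bound: update_progress now re-reads the stored message, appends the new line, and trims the oldest lines until at most 10,000 characters remain. A worked example of that behavior (strings invented for illustration):

    log = "step 1: parsed\nstep 2: embedded\nstep 3: indexed"
    # Under the cap, the text comes back unchanged.
    assert trim_header_by_lines(log, 100) == log
    # Over the cap, only the newest lines whose combined length fits survive.
    assert trim_header_by_lines(log, 35) == "step 2: embedded\nstep 3: indexed"
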
73 changes: 56 additions & 17 deletions rag/svr/task_executor.py
@@ -57,6 +57,7 @@
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
+from rag.utils.doc_store_conn import OrderByExpr

BATCH_SIZE = 64

@@ -89,6 +90,9 @@
FAILED_TASKS = 0
CURRENT_TASK = None

+class TaskCanceledException(Exception):
+    def __init__(self, msg):
+        self.msg = msg

def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
@@ -112,11 +116,10 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
logging.exception(f"set_progress({task_id}) got exception")

close_connection()
-    if cancel:
-        if PAYLOAD:
-            PAYLOAD.ack()
-            PAYLOAD = None
-        os._exit(0)
+    if cancel and PAYLOAD:
+        PAYLOAD.ack()
+        PAYLOAD = None
+        raise TaskCanceledException(msg)


def collect():
@@ -358,6 +361,37 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
return res, tk_count, vector_size


+def reuse_prev_task_chunks(task: dict) -> bool:
+    # Hash every setting that affects chunking output; store it on the task.
+    md5 = hashlib.md5()
+    for field in ["task_type", "tenant_id", "kb_id", "doc_id", "name", "from_page", "to_page", "parser_config", "embd_id",
+                  "language", "llm_id"]:
+        md5.update(str(task.get(field, "")).encode("utf-8"))
+    task_digest = md5.hexdigest()
+    TaskService.update_digest(task["id"], task_digest)
+
+    prev_task = TaskService.get_task2(task["doc_id"], task["from_page"], task["to_page"])
+    if prev_task is None:
+        return False
+    chunk_ids = prev_task["chunk_ids"]
+    chunk_ids = [x for x in chunk_ids.split() if x]
+    reusable = False
+    if prev_task["progress"] == 1.0 and prev_task["digest"] == task_digest and chunk_ids:
+        # Reuse only if every previous chunk is still present in the doc store.
+        tenant_id = task["tenant_id"]
+        kb_ids = [task["kb_id"]]
+        res = settings.docStoreConn.search(["id"], [], {"id": chunk_ids}, [], OrderByExpr(), 0, len(chunk_ids), search.index_name(tenant_id), kb_ids)
+        dict_chunks = settings.docStoreConn.getFields(res, ["id"])
+        if len(chunk_ids) == len(dict_chunks):
+            reusable = True
+    if reusable:
+        TaskService.update_chunk_ids(task["id"], " ".join(chunk_ids))
+        return True
+
+    # Stale or partial leftovers: purge them so the re-parse starts clean.
+    if chunk_ids:
+        settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task["tenant_id"]), [task["kb_id"]])
+    return False


def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]
@@ -373,6 +407,16 @@ def do_handle_task(task):

# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)

+    task_canceled = TaskService.do_cancel(task_id)
+    if task_canceled:
+        progress_callback(1.0, msg="Task has been canceled.")
+        return
+    reused = reuse_prev_task_chunks(task)
+    if reused:
+        progress_callback(1.0, msg="Chunks of task already exist, skip.")
+        return

try:
# bind embedding model
embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language)
@@ -420,6 +464,7 @@
progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts)
logging.info(progress_message)
progress_callback(msg=progress_message)

# logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}")
init_kb(task, vector_size)
chunk_count = len(set([chunk["id"] for chunk in chunks]))
@@ -430,23 +475,17 @@
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
if b % 128 == 0:
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts))
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
logging.error(error_message)
raise Exception(error_message)

if TaskService.do_cancel(task_id):
settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id)
return
if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
progress_callback(-1, msg=error_message)
raise Exception(error_message)
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))

DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)

time_cost = timer() - start_ts
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
logging.info("Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(task_id, token_count, len(chunks), time_cost))
logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))


def handle_task():
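
Cancellation also changes shape here: set_progress now raises TaskCanceledException instead of calling os._exit(0), so a canceled task unwinds through normal exception handling (the catch site is outside the visible hunks) and the worker process survives to take the next task. The reuse decision itself hinges on the task digest: an md5 over every setting that influences chunking output, so changing the parser config, embedding model, page range, or LLM invalidates reuse, while an identical re-run hits the fast path. Even on a digest match, chunks are adopted only if every previous id is still present in the doc store; otherwise the leftovers are purged and the task re-parses from scratch. A self-contained sketch of the digest property (the helper name and field values are illustrative, mirroring the loop above):

    import hashlib

    def task_digest(task: dict) -> str:
        md5 = hashlib.md5()
        for field in ["task_type", "tenant_id", "kb_id", "doc_id", "name", "from_page",
                      "to_page", "parser_config", "embd_id", "language", "llm_id"]:
            md5.update(str(task.get(field, "")).encode("utf-8"))
        return md5.hexdigest()

    base = {"doc_id": "d1", "from_page": 0, "to_page": 12, "embd_id": "bge-m3"}
    assert task_digest(base) == task_digest(dict(base))                 # unchanged settings: reusable
    assert task_digest(base) != task_digest({**base, "embd_id": "e5"})  # any change: re-parse
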
