From b94bd8d372d158921b82f9955e889d993fb5afe4 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Wed, 11 Dec 2024 11:53:10 +0800 Subject: [PATCH] Try to reuse existing chunks. Close #3793 --- api/apps/document_app.py | 28 +++++--- api/db/db_models.py | 14 ++++ api/db/services/document_service.py | 25 +++++++ api/db/services/task_service.py | 100 ++++++++++++++++++++++---- poetry.lock | 50 ++++++------- pyproject.toml | 1 + rag/svr/task_executor.py | 105 ++++++++++++++++++++-------- 7 files changed, 240 insertions(+), 83 deletions(-) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index fdab39ad442..4b1c3e69995 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -21,19 +21,25 @@ from flask import request from flask_login import login_required, current_user -from api.db.db_models import Task, File +from deepdoc.parser.html_parser import RAGFlowHtmlParser +from rag.nlp import search + +from api.db import FileType, TaskStatus, ParserType, FileSource +from api.db.db_models import File, Task from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from api.db.services.task_service import TaskService, queue_tasks +from api.db.services.task_service import queue_tasks from api.db.services.user_service import UserTenantService -from deepdoc.parser.html_parser import RAGFlowHtmlParser -from rag.nlp import search from api.db.services import duplicate_name from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request -from api.utils import get_uuid -from api.db import FileType, TaskStatus, ParserType, FileSource +from api.db.services.task_service import TaskService from api.db.services.document_service import DocumentService, doc_upload_and_parse +from api.utils.api_utils import ( + server_error_response, + get_data_error_result, + validate_request, +) +from api.utils import get_uuid from api import settings from api.utils.api_utils import get_json_result from rag.utils.storage_factory import STORAGE_IMPL @@ -316,6 +322,7 @@ def rm(): b, n = File2DocumentService.get_storage_address(doc_id=doc_id) + TaskService.filter_delete([Task.doc_id == doc_id]) if not DocumentService.remove_document(doc, tenant_id): return get_data_error_result( message="Database error (Document removal)!") @@ -361,11 +368,12 @@ def run(): e, doc = DocumentService.get_by_id(id) if not e: return get_data_error_result(message="Document not found!") - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) + if req.get("delete", False): + TaskService.filter_delete([Task.doc_id == id]) + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) if str(req["run"]) == TaskStatus.RUNNING.value: - TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() doc["tenant_id"] = tenant_id diff --git a/api/db/db_models.py b/api/db/db_models.py index 0c4d12c034c..24ad7f010a7 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -855,6 +855,8 @@ class Task(DataBaseModel): help_text="process message", default="") retry_count = IntegerField(default=0) + digest = TextField(null=True, help_text="task digest", default="") + chunk_ids = LongTextField(null=True, help_text="chunk ids", default="") class Dialog(DataBaseModel): @@ -1090,4 +1092,16 @@ def migrate_db(): ) except Exception: pass + try: + migrate( + migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")) + ) + except Exception: + pass + try: + migrate( + migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")) + ) + except Exception: + pass diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index f4b6f6874b8..cff5db583d6 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -282,6 +282,31 @@ def get_embd_id(cls, doc_id): return return docs[0]["embd_id"] + @classmethod + @DB.connection_context() + def get_chunking_config(cls, doc_id): + configs = ( + cls.model.select( + cls.model.id, + cls.model.kb_id, + cls.model.parser_id, + cls.model.parser_config, + Knowledgebase.language, + Knowledgebase.embd_id, + Tenant.id.alias("tenant_id"), + Tenant.img2txt_id, + Tenant.asr_id, + Tenant.llm_id, + ) + .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) + .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) + .where(cls.model.id == doc_id) + ) + configs = configs.dicts() + if not configs: + return None + return configs[0] + @classmethod @DB.connection_context() def get_doc_id_by_doc_name(cls, doc_name): diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 424a571ee57..5c53fe241bb 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -15,6 +15,8 @@ # import os import random +import xxhash +import bisect from api.db.db_utils import bulk_insert_into_db from deepdoc.parser import PdfParser @@ -29,7 +31,21 @@ from rag.settings import SVR_QUEUE_NAME from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.redis_conn import REDIS_CONN +from api import settings +from rag.nlp import search +def trim_header_by_lines(text: str, max_length) -> str: + if len(text) <= max_length: + return text + lines = text.split("\n") + total = 0 + idx = len(lines) - 1 + for i in range(len(lines)-1, -1, -1): + if total + len(lines[i]) > max_length: + break + idx = i + text2 = "\n".join(lines[idx:]) + return text2 class TaskService(CommonService): model = Task @@ -87,6 +103,30 @@ def get_task(cls, task_id): return docs[0] + @classmethod + @DB.connection_context() + def get_tasks(cls, doc_id: str): + fields = [ + cls.model.id, + cls.model.from_page, + cls.model.progress, + cls.model.digest, + cls.model.chunk_ids, + ] + tasks = ( + cls.model.select(*fields).order_by(cls.model.from_page.asc(), cls.model.create_time.desc()) + .where(cls.model.doc_id == doc_id) + ) + tasks = list(tasks.dicts()) + if not tasks: + return None + return tasks + + @classmethod + @DB.connection_context() + def update_chunk_ids(cls, id: str, chunk_ids: str): + cls.model.update(chunk_ids=chunk_ids).where(cls.model.id == id).execute() + @classmethod @DB.connection_context() def get_ongoing_doc_name(cls): @@ -133,22 +173,18 @@ def get_ongoing_doc_name(cls): @classmethod @DB.connection_context() def do_cancel(cls, id): - try: - task = cls.model.get_by_id(id) - _, doc = DocumentService.get_by_id(task.doc_id) - return doc.run == TaskStatus.CANCEL.value or doc.progress < 0 - except Exception: - pass - return False + task = cls.model.get_by_id(id) + _, doc = DocumentService.get_by_id(task.doc_id) + return doc.run == TaskStatus.CANCEL.value or doc.progress < 0 @classmethod @DB.connection_context() def update_progress(cls, id, info): if os.environ.get("MACOS"): if info["progress_msg"]: - cls.model.update( - progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"] - ).where(cls.model.id == id).execute() + task = cls.model.get_by_id(id) + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000) + cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: cls.model.update(progress=info["progress"]).where( cls.model.id == id @@ -157,9 +193,9 @@ def update_progress(cls, id, info): with DB.lock("update_progress", -1): if info["progress_msg"]: - cls.model.update( - progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"] - ).where(cls.model.id == id).execute() + task = cls.model.get_by_id(id) + progress_msg = trim_header_by_lines(task.progress_msg + "\n" + info["progress_msg"], 10000) + cls.model.update(progress_msg=progress_msg).where(cls.model.id == id).execute() if "progress" in info: cls.model.update(progress=info["progress"]).where( cls.model.id == id @@ -168,7 +204,7 @@ def update_progress(cls, id, info): def queue_tasks(doc: dict, bucket: str, name: str): def new_task(): - return {"id": get_uuid(), "doc_id": doc["id"]} + return {"id": get_uuid(), "doc_id": doc["id"], "progress": 0.0} tsks = [] @@ -203,10 +239,46 @@ def new_task(): else: tsks.append(new_task()) + chunking_config = DocumentService.get_chunking_config(doc["id"]) + for task in tsks: + hasher = xxhash.xxh64() + for field in sorted(chunking_config.keys()): + hasher.update(str(chunking_config[field]).encode("utf-8")) + for field in ["doc_id", "from_page", "to_page"]: + hasher.update(str(task.get(field, "")).encode("utf-8")) + task_digest = hasher.hexdigest() + task["digest"] = task_digest + task["progress"] = 0.0 + + prev_tasks = TaskService.get_tasks(doc["id"]) + if prev_tasks: + for task in tsks: + reuse_prev_task_chunks(task, prev_tasks, chunking_config) + TaskService.filter_delete([Task.doc_id == doc["id"]]) + chunk_ids = [] + for task in prev_tasks: + if task["chunk_ids"]: + chunk_ids.extend(task["chunk_ids"].split()) + if chunk_ids: + settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"]) + bulk_insert_into_db(Task, tsks, True) DocumentService.begin2parse(doc["id"]) + tsks = [task for task in tsks if task["progress"] < 1.0] for t in tsks: assert REDIS_CONN.queue_product( SVR_QUEUE_NAME, message=t ), "Can't access Redis. Please check the Redis' status." + +def reuse_prev_task_chunks(task: dict, prev_tasks: list[dict], chunking_config: dict): + idx = bisect.bisect_left(prev_tasks, task["from_page"], key=lambda x: x["from_page"]) + if idx >= len(prev_tasks): + return + prev_task = prev_tasks[idx] + if prev_task["progress"] < 1.0 or prev_task["digest"] != task["digest"] or not prev_task["chunk_ids"]: + return + task["chunk_ids"] = prev_task["chunk_ids"] + task["progress"] = 1.0 + task["progress_msg"] = f"Page({task['from_page']}~{task['to_page']}): reused previous task's chunks" + prev_task["chunk_ids"] = "" diff --git a/poetry.lock b/poetry.lock index 5dacfd02546..4f0a3068b12 100644 --- a/poetry.lock +++ b/poetry.lock @@ -145,7 +145,7 @@ name = "aiolimiter" version = "1.2.0" description = "asyncio rate limiter, a leaky bucket implementation" optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8,<4.0" files = [ {file = "aiolimiter-1.2.0-py3-none-any.whl", hash = "sha256:e3fc486a4506248cfdd1f3976920459945944518bbb1d1e6b2be1060232829e2"}, {file = "aiolimiter-1.2.0.tar.gz", hash = "sha256:761455d26df0d7a393f78bd39b022579e02ca5a65beb303a67bed2ded2f740ac"}, @@ -416,7 +416,7 @@ name = "aspose-slides" version = "24.12.0" description = "Aspose.Slides for Python via .NET is a presentation file formats processing library for working with Microsoft PowerPoint files without using Microsoft PowerPoint." optional = false -python-versions = "<3.14,>=3.5" +python-versions = ">=3.5,<3.14" files = [ {file = "Aspose.Slides-24.12.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:ccfaa61a863ed28cd37b221e31a0edf4a83802599d76fb50861c25149ac5e5e3"}, {file = "Aspose.Slides-24.12.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b050659129c5ca92e52fbcd7d5091caa244db731adb68fbea1fd0a8b9fd62a5a"}, @@ -568,7 +568,7 @@ name = "bce-python-sdk" version = "0.9.23" description = "BCE SDK for python" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4" files = [ {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"}, {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"}, @@ -1502,7 +1502,7 @@ name = "cryptography" version = "44.0.0" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." optional = false -python-versions = "!=3.9.0,!=3.9.1,>=3.7" +python-versions = ">=3.7, !=3.9.0, !=3.9.1" files = [ {file = "cryptography-44.0.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:84111ad4ff3f6253820e6d3e58be2cc2a00adb29335d4cacb5ab4d4d34f2a123"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b15492a11f9e1b62ba9d73c210e2416724633167de94607ec6069ef724fad092"}, @@ -1711,7 +1711,7 @@ name = "deprecated" version = "1.2.15" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, @@ -2042,7 +2042,7 @@ name = "fastembed" version = "0.3.6" description = "Fast, light, accurate library built for retrieval embedding generation" optional = false -python-versions = "<3.13,>=3.8.0" +python-versions = ">=3.8.0,<3.13" files = [ {file = "fastembed-0.3.6-py3-none-any.whl", hash = "sha256:2bf70edae28bb4ccd9e01617098c2075b0ba35b88025a3d22b0e1e85b2c488ce"}, {file = "fastembed-0.3.6.tar.gz", hash = "sha256:c93c8ec99b8c008c2d192d6297866b8d70ec7ac8f5696b34eb5ea91f85efd15f"}, @@ -2624,12 +2624,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" @@ -2965,7 +2965,7 @@ name = "graspologic" version = "3.4.1" description = "A set of Python modules for graph statistics" optional = false -python-versions = "<3.13,>=3.9" +python-versions = ">=3.9,<3.13" files = [ {file = "graspologic-3.4.1-py3-none-any.whl", hash = "sha256:c6563e087eda599bad1de831d4b7321c0daa7a82f4e85a7d7737ff67e07cdda2"}, {file = "graspologic-3.4.1.tar.gz", hash = "sha256:7561f0b852a2bccd351bff77e8db07d9892f9dfa35a420fdec01690e4fdc8075"}, @@ -3650,7 +3650,7 @@ name = "infinity-emb" version = "0.0.66" description = "Infinity is a high-throughput, low-latency REST API for serving text-embeddings, reranking models and clip." optional = false -python-versions = "<4,>=3.9" +python-versions = ">=3.9,<4" files = [ {file = "infinity_emb-0.0.66-py3-none-any.whl", hash = "sha256:1dc6ed9fa48e6cbe83650a7583dbbb4bc393900c39c326bb0aff2ddc090ac018"}, {file = "infinity_emb-0.0.66.tar.gz", hash = "sha256:9c9a361ccebf8e8f626c1f685286518d03d0c35e7d14179ae7c2500b4fc68b98"}, @@ -4098,7 +4098,7 @@ name = "litellm" version = "1.48.0" description = "Library to easily interface with LLM API providers" optional = false -python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" +python-versions = ">=3.8, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*" files = [ {file = "litellm-1.48.0-py3-none-any.whl", hash = "sha256:7765e8a92069778f5fc66aacfabd0e2f8ec8d74fb117f5e475567d89b0d376b9"}, {file = "litellm-1.48.0.tar.gz", hash = "sha256:31a9b8a25a9daf44c24ddc08bf74298da920f2c5cea44135e5061278d0aa6fc9"}, @@ -5416,9 +5416,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -5440,9 +5440,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -5657,8 +5657,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" @@ -6276,7 +6276,7 @@ name = "psutil" version = "6.1.0" description = "Cross-platform lib for process and system monitoring in Python." optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, @@ -6298,8 +6298,8 @@ files = [ ] [package.extras] -dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] -test = ["pytest", "pytest-xdist", "setuptools"] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx-rtd-theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["enum34", "futures", "ipaddress", "mock (==1.0.1)", "pytest (==4.6.11)", "pytest-xdist", "setuptools", "unittest2"] [[package]] name = "psycopg2-binary" @@ -7803,40 +7803,30 @@ files = [ {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd5415dded15c3822597455bc02bcd66e81ef8b7a48cb71a33628fc9fdde39df"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6"}, - {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da"}, {file = "ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d84318609196d6bd6da0edfa25cedfbabd8dbde5140a0a23af29ad4b8f91fb1e"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb43a269eb827806502c7c8efb7ae7e9e9d0573257a46e8e952f4d4caba4f31e"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642"}, - {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4"}, {file = "ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:20b0f8dc160ba83b6dcc0e256846e1a02d044e13f7ea74a3d1d56ede4e48c632"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:943f32bc9dedb3abff9879edc134901df92cfce2c3d5c9348f172f62eb2d771d"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c3829bb364fdb8e0332c9931ecf57d9be3519241323c5274bd82f709cebc0c"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31"}, - {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5"}, {file = "ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:4c8c5d82f50bb53986a5e02d1b3092b03622c02c2eb78e29bec33fd9593bae1a"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:e7e3736715fbf53e9be2a79eb4db68e4ed857017344d697e8b9749444ae57475"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7e75b4965e1d4690e93021adfcecccbca7d61c7bddd8e22406ef2ff20d74ef"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96777d473c05ee3e5e3c3e999f5d23c6f4ec5b0c38c098b3a5229085f74236c6"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:3bc2a80e6420ca8b7d3590791e2dfc709c88ab9152c00eeb511c9875ce5778bf"}, - {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e188d2699864c11c36cdfdada94d781fd5d6b0071cd9c427bceb08ad3d7c70e1"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win32.whl", hash = "sha256:6442cb36270b3afb1b4951f060eccca1ce49f3d087ca1ca4563a6eb479cb3de6"}, {file = "ruamel.yaml.clib-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:e5b8daf27af0b90da7bb903a876477a9e6d7270be6146906b276605997c7e9a3"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:fc4b630cd3fa2cf7fce38afa91d7cfe844a9f75d7f0f36393fa98815e911d987"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:bc5f1e1c28e966d61d2519f2a3d451ba989f9ea0f2307de7bc45baa526de9e45"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a0e060aace4c24dcaf71023bbd7d42674e3b230f7e7b97317baf1e953e5b519"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f1c3765db32be59d18ab3953f43ab62a761327aafc1594a2a1fbe038b8b8a7"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d85252669dc32f98ebcd5d36768f5d4faeaeaa2d655ac0473be490ecdae3c285"}, - {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e143ada795c341b56de9418c58d028989093ee611aa27ffb9b7f609c00d813ed"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win32.whl", hash = "sha256:beffaed67936fbbeffd10966a4eb53c402fafd3d6833770516bf7314bc6ffa12"}, {file = "ruamel.yaml.clib-0.2.12-cp39-cp39-win_amd64.whl", hash = "sha256:040ae85536960525ea62868b642bdb0c2cc6021c9f9d507810c0c604e66f5a7b"}, {file = "ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f"}, @@ -7847,7 +7837,7 @@ name = "s3transfer" version = "0.10.4" description = "An Amazon S3 Transfer Manager" optional = false -python-versions = ">=3.8" +python-versions = ">= 3.8" files = [ {file = "s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e"}, {file = "s3transfer-0.10.4.tar.gz", hash = "sha256:29edc09801743c21eb5ecbc617a152df41d3c287f67b615f73e5f750583666a7"}, @@ -8309,7 +8299,7 @@ name = "smart-open" version = "7.0.5" description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" optional = false -python-versions = "<4.0,>=3.7" +python-versions = ">=3.7,<4.0" files = [ {file = "smart_open-7.0.5-py3-none-any.whl", hash = "sha256:8523ed805c12dff3eaa50e9c903a6cb0ae78800626631c5fe7ea073439847b89"}, {file = "smart_open-7.0.5.tar.gz", hash = "sha256:d3672003b1dbc85e2013e4983b88eb9a5ccfd389b0d4e5015f39a9ee5620ec18"}, @@ -10119,4 +10109,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "6e401212ce8c6499bc8c687355a6cfe7d0fc2beecf21e7895546f82a2b4332cd" +content-hash = "69fbe11a30c649544196546b9384ca5972bfd17a923b7dc8ff340f790984b5df" diff --git a/pyproject.toml b/pyproject.toml index 01dcc7de995..3aea0db40be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ pyicu = "^2.13.1" flasgger = "^0.9.7.1" polars = { version = "^1.9.0", markers = "platform_machine == 'x86_64'" } polars-lts-cpu = { version = "^1.9.0", markers = "platform_machine == 'arm64'" } +xxhash = "^3.5.0" [tool.poetry.group.full] diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 6d971a5d794..10a2082c6a9 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -39,8 +39,9 @@ import tracemalloc import numpy as np +from peewee import DoesNotExist -from api.db import LLMType, ParserType +from api.db import LLMType, ParserType, TaskStatus from api.db.services.dialog_service import keyword_extraction, question_proposal from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle @@ -89,12 +90,23 @@ FAILED_TASKS = 0 CURRENT_TASK = None +class TaskCanceledException(Exception): + def __init__(self, msg): + self.msg = msg def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): global PAYLOAD if prog is not None and prog < 0: msg = "[ERROR]" + msg - cancel = TaskService.do_cancel(task_id) + try: + cancel = TaskService.do_cancel(task_id) + except DoesNotExist: + logging.warning(f"set_progress task {task_id} is unknown") + if PAYLOAD: + PAYLOAD.ack() + PAYLOAD = None + return + if cancel: msg += " [Canceled]" prog = -1 @@ -105,18 +117,22 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing... d = {"progress_msg": msg} if prog is not None: d["progress"] = prog + + logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}") try: - logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}") TaskService.update_progress(task_id, d) - except Exception: - logging.exception(f"set_progress({task_id}) got exception") - - close_connection() - if cancel: + except DoesNotExist: + logging.warning(f"set_progress task {task_id} is unknown") if PAYLOAD: PAYLOAD.ack() PAYLOAD = None - os._exit(0) + return + + close_connection() + if cancel and PAYLOAD: + PAYLOAD.ack() + PAYLOAD = None + raise TaskCanceledException(msg) def collect(): @@ -136,16 +152,22 @@ def collect(): if not msg: return None - if TaskService.do_cancel(msg["id"]): - with mt_lock: - DONE_TASKS += 1 - logging.info("Task {} has been canceled.".format(msg["id"])) - return None - task = TaskService.get_task(msg["id"]) - if not task: + task = None + canceled = False + try: + task = TaskService.get_task(msg["id"]) + if task: + _, doc = DocumentService.get_by_id(task["doc_id"]) + canceled = doc.run == TaskStatus.CANCEL.value or doc.progress < 0 + except DoesNotExist: + pass + except Exception: + logging.exception("collect get_task exception") + if not task or canceled: + state = "is unknown" if not task else "has been cancelled" with mt_lock: DONE_TASKS += 1 - logging.warning("{} empty task!".format(msg["id"])) + logging.info(f"collect task {msg['id']} {state}") return None if msg.get("type", "") == "raptor": @@ -186,6 +208,8 @@ def build_chunks(task, progress_callback): to_page=task["to_page"], lang=task["language"], callback=progress_callback, kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) + except TaskCanceledException: + raise except Exception as e: progress_callback(-1, "Internal server error while chunking: %s" % str(e).replace("'", "")) logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"])) @@ -358,6 +382,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): return res, tk_count, vector_size + + def do_handle_task(task): task_id = task["id"] task_from_page = task["from_page"] @@ -373,6 +399,16 @@ def do_handle_task(task): # prepare the progress callback function progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) + + try: + task_canceled = TaskService.do_cancel(task_id) + except DoesNotExist: + logging.warning(f"task {task_id} is unknown") + return + if task_canceled: + progress_callback(-1, msg="Task has been canceled.") + return + try: # bind embedding model embedding_model = LLMBundle(task_tenant_id, LLMType.EMBEDDING, llm_name=task_embedding_id, lang=task_language) @@ -390,6 +426,8 @@ def do_handle_task(task): # run RAPTOR chunks, token_count, vector_size = run_raptor(task, chat_model, embedding_model, progress_callback) + except TaskCanceledException: + raise except Exception as e: error_message = f'Fail to bind LLM used by RAPTOR: {str(e)}' progress_callback(-1, msg=error_message) @@ -420,6 +458,7 @@ def do_handle_task(task): progress_message = "Embedding chunks ({:.2f}s)".format(timer() - start_ts) logging.info(progress_message) progress_callback(msg=progress_message) + # logging.info(f"task_executor init_kb index {search.index_name(task_tenant_id)} embedding_model {embedding_model.llm_name} vector length {vector_size}") init_kb(task, vector_size) chunk_count = len(set([chunk["id"] for chunk in chunks])) @@ -430,23 +469,25 @@ def do_handle_task(task): doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id) if b % 128 == 0: progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="") - logging.info("Indexing {} elapsed: {:.2f}".format(task_document_name, timer() - start_ts)) - if doc_store_result: - error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" - progress_callback(-1, msg=error_message) - settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id) - logging.error(error_message) - raise Exception(error_message) - - if TaskService.do_cancel(task_id): - settings.docStoreConn.delete({"doc_id": task_doc_id}, search.index_name(task_tenant_id), task_dataset_id) - return + if doc_store_result: + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" + progress_callback(-1, msg=error_message) + raise Exception(error_message) + chunk_ids = [chunk["id"] for chunk in chunks[:b + es_bulk_size]] + chunk_ids_str = " ".join(chunk_ids) + try: + TaskService.update_chunk_ids(task["id"], chunk_ids_str) + except DoesNotExist: + logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.") + doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id) + return + logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts)) DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) time_cost = timer() - start_ts progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost)) - logging.info("Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(task_id, token_count, len(chunks), time_cost)) + logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost)) def handle_task(): @@ -462,6 +503,12 @@ def handle_task(): DONE_TASKS += 1 CURRENT_TASK = None logging.info(f"handle_task done for task {json.dumps(task)}") + except TaskCanceledException: + with mt_lock: + DONE_TASKS += 1 + CURRENT_TASK = None + logging.info(f"handle_task got TaskCanceledException for task {json.dumps(task)}") + logging.debug("handle_task got TaskCanceledException", exc_info=True) except Exception: with mt_lock: FAILED_TASKS += 1