From 8714754afc2406fb6118762e3fc8b59014005811 Mon Sep 17 00:00:00 2001
From: liuhua <10215101452@stu.ecnu.edu.cn>
Date: Wed, 23 Oct 2024 12:02:18 +0800
Subject: [PATCH] Fix some issues in API (#2982)

### What problem does this PR solve?

Fix some issues in the HTTP API and Python SDK:

- Rename the `knowledgebases` field to `datasets` across the chat, session, and SDK interfaces.
- Validate `model_name`, `rerank_model`, and `embedding_model` against the models configured for the tenant.
- Backfill default `parser_config` values per chunk method via a new `get_parser_config` helper.
- Enforce a 10 MB total size limit on document uploads and return the created documents' info.
- Map numeric document `run` states to readable names and switch the SDK `retrieve` call from GET to POST.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: liuhua <10215101452@stu.ecnu.edu.cn>
---
 api/apps/sdk/chat.py                     | 59 ++++++++++--------
 api/apps/sdk/dataset.py                  | 42 ++++++++-----
 api/apps/sdk/doc.py                      | 76 ++++++++++++++++++++---
 api/apps/sdk/session.py                  |  2 +-
 api/db/services/knowledgebase_service.py | 12 +++-
 api/utils/api_utils.py                   | 21 ++++++-
 sdk/python/ragflow/modules/chat.py       |  2 +-
 sdk/python/ragflow/modules/chunk.py      |  6 +-
 sdk/python/ragflow/modules/dataset.py    | 15 ++---
 sdk/python/ragflow/modules/document.py   | 20 +++---
 sdk/python/ragflow/modules/session.py    |  4 +-
 sdk/python/ragflow/ragflow.py            | 33 ++++------
 sdk/python/test/t_chat.py                |  6 +-
 sdk/python/test/t_session.py             | 10 +--
 14 files changed, 202 insertions(+), 106 deletions(-)

diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py
index 8357156a9b0..99def010f57 100644
--- a/api/apps/sdk/chat.py
+++ b/api/apps/sdk/chat.py
@@ -18,20 +18,21 @@
 from api.db import StatusEnum
 from api.db.services.dialog_service import DialogService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import TenantLLMService
+from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import TenantService
 from api.utils import get_uuid
 from api.utils.api_utils import get_error_data_result, token_required
 from api.utils.api_utils import get_result
 
+
 @manager.route('/chat', methods=['POST'])
 @token_required
 def create(tenant_id):
     req=request.json
-    ids= req.get("knowledgebases")
+    ids= req.get("datasets")
     if not ids:
-        return get_error_data_result(retmsg="`knowledgebases` is required")
+        return get_error_data_result(retmsg="`datasets` is required")
     for kb_id in ids:
         kbs = KnowledgebaseService.query(id=kb_id,tenant_id=tenant_id)
         if not kbs:
@@ -45,6 +46,8 @@ def create(tenant_id):
     if llm:
         if "model_name" in llm:
             req["llm_id"] = llm.pop("model_name")
+            if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"):
+                return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
         req["llm_setting"] = req.pop("llm")
     e, tenant = TenantService.get_by_id(tenant_id)
     if not e:
@@ -73,10 +76,10 @@ def create(tenant_id):
     req["top_n"] = req.get("top_n", 6)
     req["top_k"] = req.get("top_k", 1024)
     req["rerank_id"] = req.get("rerank_id", "")
-    if req.get("llm_id"):
-        if not TenantLLMService.query(llm_name=req["llm_id"]):
-            return get_error_data_result(retmsg="the model_name does not exist.")
-    else:
+    if req.get("rerank_id"):
+        if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_id"),model_type="rerank"):
+            return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist")
+    if not req.get("llm_id"):
         req["llm_id"] = tenant.llm_id
     if not req.get("name"):
         return get_error_data_result(retmsg="`name` is required.")
@@ -135,7 +138,7 @@ def create(tenant_id):
     res["llm"] = res.pop("llm_setting")
     res["llm"]["model_name"] = res.pop("llm_id")
     del res["kb_ids"]
-    res["knowledgebases"] = req["knowledgebases"]
+    res["datasets"] = req["datasets"]
     res["avatar"] = res.pop("icon")
     return get_result(data=res)
@@ -145,27 +148,32 @@ def update(tenant_id,chat_id):
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
         return get_error_data_result(retmsg='You do not own the chat')
     req =request.json
-    if "knowledgebases" in req:
-        if not req.get("knowledgebases"):
-            return get_error_data_result(retmsg="`knowledgebases` can't be empty value")
-        kb_list = []
-        for kb in req.get("knowledgebases"):
-            if not kb["id"]:
-                return get_error_data_result(retmsg="knowledgebase needs id")
-            if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id):
-                return get_error_data_result(retmsg="you do not own the knowledgebase")
-            # if not DocumentService.query(kb_id=kb["id"]):
-            #     return get_error_data_result(retmsg="There is a invalid knowledgebase")
-            kb_list.append(kb["id"])
-        req["kb_ids"] = kb_list
+    ids = req.get("datasets")
+    if "datasets" in req:
+        if not ids:
+            return get_error_data_result("`datasets` can't be empty")
+    if ids:
+        for kb_id in ids:
+            kbs = KnowledgebaseService.query(id=kb_id, tenant_id=tenant_id)
+            if not kbs:
+                return get_error_data_result(f"You don't own the dataset {kb_id}")
+            kb = kbs[0]
+            if kb.chunk_num == 0:
+                return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
+        req["kb_ids"] = ids
     llm = req.get("llm")
     if llm:
         if "model_name" in llm:
             req["llm_id"] = llm.pop("model_name")
+            if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req["llm_id"],model_type="chat"):
+                return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist")
         req["llm_setting"] = req.pop("llm")
     e, tenant = TenantService.get_by_id(tenant_id)
     if not e:
         return get_error_data_result(retmsg="Tenant not found!")
+    if req.get("rerank_model"):
+        if not TenantLLMService.query(tenant_id=tenant_id,llm_name=req.get("rerank_model"),model_type="rerank"):
+            return get_error_data_result(f"`rerank_model` {req.get('rerank_model')} doesn't exist")
     # prompt
     prompt = req.get("prompt")
     key_mapping = {"parameters": "variables",
@@ -185,9 +193,6 @@ def update(tenant_id,chat_id):
         req["prompt_config"] = req.pop("prompt")
     e, res = DialogService.get_by_id(chat_id)
     res = res.to_json()
-    if "llm_id" in req:
-        if not TenantLLMService.query(llm_name=req["llm_id"]):
-            return get_error_data_result(retmsg="The `model_name` does not exist.")
     if "name" in req:
         if not req.get("name"):
             return get_error_data_result(retmsg="`name` is not empty.")
@@ -209,8 +214,8 @@ def update(tenant_id,chat_id):
     # avatar
     if "avatar" in req:
         req["icon"] = req.pop("avatar")
-    if "knowledgebases" in req:
-        req.pop("knowledgebases")
+    if "datasets" in req:
+        req.pop("datasets")
     if not DialogService.update_by_id(chat_id, req):
         return get_error_data_result(retmsg="Chat not found!")
     return get_result()
@@ -279,7 +284,7 @@ def list_chat(tenant_id):
             return get_error_data_result(retmsg=f"Don't exist the kb {kb_id}")
         kb_list.append(kb[0].to_json())
     del res["kb_ids"]
-    res["knowledgebases"] = kb_list
+    res["datasets"] = kb_list
     res["avatar"] = res.pop("icon")
     list_assts.append(res)
     return get_result(data=list_assts)
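Editorial note, not part of the patch: after this rename, chat assistants are created and updated through a `datasets` field instead of `knowledgebases`, and every referenced dataset must already contain parsed chunks. A minimal usage sketch against the updated SDK; the API key, address, and names below are placeholders:

```python
from ragflow import RAGFlow

rag = RAGFlow("<API_KEY>", "http://<HOST_ADDRESS>")
ds = rag.create_dataset(name="demo_dataset")
# `datasets` replaces the old `knowledgebases` keyword; the server now rejects
# datasets whose documents have not yet been parsed into chunks.
chat = rag.create_chat("demo_chat", datasets=[ds])
```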
diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py
index eaca4e877e9..70d3e86276c 100644
--- a/api/apps/sdk/dataset.py
+++ b/api/apps/sdk/dataset.py
@@ -15,17 +15,17 @@
 #
 from flask import request
-
 from api.db import StatusEnum, FileSource
 from api.db.db_models import File
 from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import TenantService
 from api.settings import RetCode
 from api.utils import get_uuid
-from api.utils.api_utils import get_result, token_required, get_error_data_result, valid
+from api.utils.api_utils import get_result, token_required, get_error_data_result, valid,get_parser_config
 
 
 @manager.route('/dataset', methods=['POST'])
@@ -36,15 +36,17 @@ def create(tenant_id):
     permission = req.get("permission")
     language = req.get("language")
     chunk_method = req.get("chunk_method")
-    valid_permission = ("me", "team")
-    valid_language =("Chinese", "English")
-    valid_chunk_method = ("naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email")
+    parser_config = req.get("parser_config")
+    valid_permission = {"me", "team"}
+    valid_language ={"Chinese", "English"}
+    valid_chunk_method = {"naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email"}
     check_validation=valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method)
     if check_validation:
         return check_validation
-    if "tenant_id" in req or "embedding_model" in req:
+    req["parser_config"]=get_parser_config(chunk_method,parser_config)
+    if "tenant_id" in req:
         return get_error_data_result(
-            retmsg="`tenant_id` or `embedding_model` must not be provided")
+            retmsg="`tenant_id` must not be provided")
     chunk_count=req.get("chunk_count")
     document_count=req.get("document_count")
     if chunk_count or document_count:
@@ -59,9 +61,13 @@ def create(tenant_id):
             retmsg="`name` is not empty string!")
     if KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value):
         return get_error_data_result(
-            retmsg="Duplicated knowledgebase name in creating dataset.")
+            retmsg="Duplicated dataset name in creating dataset.")
     req["tenant_id"] = req['created_by'] = tenant_id
-    req['embedding_model'] = t.embd_id
+    if not req.get("embedding_model"):
+        req['embedding_model'] = t.embd_id
+    else:
+        if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")):
+            return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
     key_mapping = {
         "chunk_num": "chunk_count",
         "doc_num": "document_count",
@@ -116,10 +122,12 @@ def update(tenant_id,dataset_id):
     permission = req.get("permission")
     language = req.get("language")
     chunk_method = req.get("chunk_method")
-    valid_permission = ("me", "team")
-    valid_language =("Chinese", "English")
-    valid_chunk_method = ("naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email")
-    check_validation=valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method)
+    parser_config = req.get("parser_config")
+    valid_permission = {"me", "team"}
+    valid_language = {"Chinese", "English"}
+    valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
+                          "knowledge_graph", "email"}
+    check_validation = valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method)
     if check_validation:
         return check_validation
     if "tenant_id" in req:
@@ -142,10 +150,16 @@ def update(tenant_id,dataset_id):
             return get_error_data_result(
                 retmsg="If `chunk_count` is not 0, `chunk_method` is not changeable.")
         req['parser_id'] = req.pop('chunk_method')
+        if req['parser_id'] != kb.parser_id:
+            req["parser_config"] = get_parser_config(chunk_method, parser_config)
     if "embedding_model" in req:
         if kb.chunk_num != 0 and req['embedding_model'] != kb.embd_id:
             return get_error_data_result(
                 retmsg="If `chunk_count` is not 0, `embedding_method` is not changeable.")
+        if not req.get("embedding_model"):
+            return get_error_data_result("`embedding_model` can't be empty")
+        if not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model")):
+            return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
         req['embd_id'] = req.pop('embedding_model')
     if "name" in req:
         req["name"] = req["name"].strip()
@@ -153,7 +167,7 @@ def update(tenant_id,dataset_id):
                 and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id,
                                                    status=StatusEnum.VALID.value)) > 0:
             return get_error_data_result(
-                retmsg="Duplicated knowledgebase name in updating dataset.")
+                retmsg="Duplicated dataset name in updating dataset.")
     if not KnowledgebaseService.update_by_id(kb.id, req):
         return get_error_data_result(retmsg="Update dataset error.(Database error)")
     return get_result(retcode=RetCode.SUCCESS)
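Editorial note, not part of the patch: `embedding_model` can now be supplied at creation time instead of being rejected, and it is validated against the tenant's configured embedding models. A hedged HTTP-level sketch; the host, token, and model name are placeholders, and the `/api/v1` prefix assumes the default mounting of these routes:

```python
import requests

resp = requests.post(
    "http://<HOST_ADDRESS>/api/v1/dataset",
    headers={"Authorization": "Bearer <API_KEY>"},
    # An unknown model name now yields "`embedding_model` ... doesn't exist".
    json={"name": "papers", "embedding_model": "<embedding_model_name>"},
)
assert resp.json()["code"] == 0
```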
diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py
index b32eeafeb63..59fae924d72 100644
--- a/api/apps/sdk/doc.py
+++ b/api/apps/sdk/doc.py
@@ -39,7 +39,7 @@
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.settings import RetCode, retrievaler
-from api.utils.api_utils import construct_json_result
+from api.utils.api_utils import construct_json_result,get_parser_config
 from rag.nlp import search
 from rag.utils import rmSpace
 from rag.utils.es_conn import ELASTICSEARCH
@@ -61,14 +61,41 @@ def upload(dataset_id, tenant_id):
         if file_obj.filename == '':
             return get_result(
                 retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)
+    # total size
+    total_size = 0
+    for file_obj in file_objs:
+        file_obj.seek(0, os.SEEK_END)
+        total_size += file_obj.tell()
+        file_obj.seek(0)
+    MAX_TOTAL_FILE_SIZE=10*1024*1024
+    if total_size > MAX_TOTAL_FILE_SIZE:
+        return get_result(
+            retmsg=f'Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)',
+            retcode=RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(dataset_id)
     if not e:
-        raise LookupError(f"Can't find the knowledgebase with ID {dataset_id}!")
-    err, _ = FileService.upload_document(kb, file_objs, tenant_id)
+        raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
+    err, files= FileService.upload_document(kb, file_objs, tenant_id)
     if err:
         return get_result(
             retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR)
-    return get_result()
+    # rename key's name
+    renamed_doc_list = []
+    for file in files:
+        doc = file[0]
+        key_mapping = {
+            "chunk_num": "chunk_count",
+            "kb_id": "dataset_id",
+            "token_num": "token_count",
+            "parser_id": "chunk_method"
+        }
+        renamed_doc = {}
+        for key, value in doc.items():
+            new_key = key_mapping.get(key, key)
+            renamed_doc[new_key] = value
+        renamed_doc["run"] = "UNSTART"
+        renamed_doc_list.append(renamed_doc)
+    return get_result(data=renamed_doc_list)
 
 
 @manager.route('/dataset/<dataset_id>/info/<document_id>', methods=['PUT'])
@@ -97,7 +124,7 @@ def update_doc(tenant_id, dataset_id, document_id):
         for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
             if d.name == req["name"]:
                 return get_error_data_result(
-                    retmsg="Duplicated document name in the same knowledgebase.")
+                    retmsg="Duplicated document name in the same dataset.")
         if not DocumentService.update_by_id(
                 document_id, {"name": req["name"]}):
             return get_error_data_result(
@@ -110,6 +137,9 @@ def update_doc(tenant_id, dataset_id, document_id):
     if "parser_config" in req:
         DocumentService.update_parser_config(doc.id, req["parser_config"])
     if "chunk_method" in req:
+        valid_chunk_method = {"naive","manual","qa","table","paper","book","laws","presentation","picture","one","knowledge_graph","email"}
+        if req.get("chunk_method") not in valid_chunk_method:
+            return get_error_data_result(f"`chunk_method` {req['chunk_method']} doesn't exist")
         if doc.parser_id.lower() == req["chunk_method"].lower():
             return get_result()
@@ -122,6 +152,7 @@ def update_doc(tenant_id, dataset_id, document_id):
                                              "run": TaskStatus.UNSTART.value})
         if not e:
             return get_error_data_result(retmsg="Document not found!")
+        req["parser_config"] = get_parser_config(req["chunk_method"], req.get("parser_config"))
         if doc.token_num > 0:
             e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1,
                                                     doc.process_duation * -1)
@@ -182,12 +213,22 @@ def list_docs(dataset_id, tenant_id):
     for doc in docs:
         key_mapping = {
             "chunk_num": "chunk_count",
-            "kb_id": "knowledgebase_id",
+            "kb_id": "dataset_id",
             "token_num": "token_count",
             "parser_id": "chunk_method"
         }
+        run_mapping = {
+            "0": "UNSTART",
+            "1": "RUNNING",
+            "2": "CANCEL",
+            "3": "DONE",
+            "4": "FAIL"
+        }
         renamed_doc = {}
         for key, value in doc.items():
+            if key == "run":
+                renamed_doc["run"] = run_mapping.get(str(value))
+                continue
             new_key = key_mapping.get(key, key)
             renamed_doc[new_key] = value
         renamed_doc_list.append(renamed_doc)
@@ -353,9 +394,10 @@ def list_chunks(tenant_id,dataset_id,document_id):
     return get_result(data=res)
 
+
 @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk', methods=['POST'])
 @token_required
-def create(tenant_id,dataset_id,document_id):
+def add_chunk(tenant_id,dataset_id,document_id):
     if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
         return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.")
     doc = DocumentService.query(id=document_id, kb_id=dataset_id)
@@ -441,6 +483,7 @@ def rm_chunk(tenant_id,dataset_id,document_id):
     return get_result()
 
+
 @manager.route('/dataset/<dataset_id>/document/<document_id>/chunk/<chunk_id>', methods=['PUT'])
 @token_required
 def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
@@ -470,12 +513,12 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
         d["content_ltks"] = rag_tokenizer.tokenize(d["content_with_weight"])
         d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     if "important_keywords" in req:
-        if type(req["important_keywords"]) != list:
-            return get_error_data_result("`important_keywords` is required to be a list")
+        if not isinstance(req["important_keywords"],list):
+            return get_error_data_result("`important_keywords` should be a list")
         d["important_kwd"] = req.get("important_keywords")
         d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_keywords"]))
     if "available" in req:
-        d["available_int"] = req["available"]
+        d["available_int"] = int(req["available"])
     embd_id = DocumentService.get_embd_id(document_id)
     embd_mdl = TenantLLMService.model_instance(
         tenant_id, LLMType.EMBEDDING.value, embd_id)
@@ -498,6 +541,7 @@ def update_chunk(tenant_id,dataset_id,document_id,chunk_id):
     return get_result()
 
+
 @manager.route('/retrieval', methods=['POST'])
 @token_required
 def retrieval_test(tenant_id):
@@ -505,6 +549,8 @@ def retrieval_test(tenant_id):
     if not req.get("datasets"):
         return get_error_data_result("`datasets` is required.")
     kb_ids = req["datasets"]
+    if not isinstance(kb_ids,list):
+        return get_error_data_result("`datasets` should be a list")
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
     embd_nms = list(set([kb.embd_id for kb in kbs]))
     if len(embd_nms) != 1:
@@ -518,9 +564,15 @@ def retrieval_test(tenant_id):
     if "question" not in req:
         return get_error_data_result("`question` is required.")
     page = int(req.get("offset", 1))
-    size = int(req.get("limit", 30))
+    size = int(req.get("limit", 1024))
     question = req["question"]
     doc_ids = req.get("documents", [])
+    if not isinstance(doc_ids, list):
+        return get_error_data_result("`documents` should be a list")
+    doc_ids_list=KnowledgebaseService.list_documents_by_ids(kb_ids)
+    for doc_id in doc_ids:
+        if doc_id not in doc_ids_list:
+            return get_error_data_result(f"You don't own the document {doc_id}")
     similarity_threshold = float(req.get("similarity_threshold", 0.2))
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     top = int(req.get("top_k", 1024))
@@ -531,7 +583,7 @@ def retrieval_test(tenant_id):
     try:
         e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:
-            return get_error_data_result(retmsg="Knowledgebase not found!")
+            return get_error_data_result(retmsg="Dataset not found!")
         embd_mdl = TenantLLMService.model_instance(
             kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 814fc215732..fe0297a8c7d 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -199,7 +199,7 @@ def list(chat_id,tenant_id):
                 "content": chunk["content_with_weight"],
                 "document_id": chunk["doc_id"],
                 "document_name": chunk["docnm_kwd"],
-                "knowledgebase_id": chunk["kb_id"],
+                "dataset_id": chunk["kb_id"],
                 "image_id": chunk["img_id"],
                 "similarity": chunk["similarity"],
                 "vector_similarity": chunk["vector_similarity"],
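Editorial note, not part of the patch: the `/retrieval` endpoint now checks that `datasets` and `documents` are lists and that every document ID belongs to one of the given datasets, and the default `limit` rises from 30 to 1024. A hedged sketch of a valid request body; IDs, host, and token are placeholders, and the `/api/v1` prefix assumes the default route mounting:

```python
import requests

payload = {
    "question": "What is RAGFlow?",
    "datasets": ["<dataset_id>"],  # must be a list
    "documents": [],               # optional; IDs must belong to the datasets above
    "offset": 1,
    "limit": 1024,                 # new default, raised from 30
    "similarity_threshold": 0.2,
}
resp = requests.post("http://<HOST_ADDRESS>/api/v1/retrieval",
                     headers={"Authorization": "Bearer <API_KEY>"},
                     json=payload)
```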
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py
index 602d6097a57..2baba4eaf4e 100644
--- a/api/db/services/knowledgebase_service.py
+++ b/api/db/services/knowledgebase_service.py
@@ -14,13 +14,23 @@
 #  limitations under the License.
 #
 from api.db import StatusEnum, TenantPermission
-from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant
+from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant,Document
 from api.db.services.common_service import CommonService
 
 
 class KnowledgebaseService(CommonService):
     model = Knowledgebase
 
+    @classmethod
+    @DB.connection_context()
+    def list_documents_by_ids(cls,kb_ids):
+        doc_ids=cls.model.select(Document.id.alias("document_id")).join(Document,on=(cls.model.id == Document.kb_id)).where(
+            cls.model.id.in_(kb_ids)
+        )
+        doc_ids =list(doc_ids.dicts())
+        doc_ids = [doc["document_id"] for doc in doc_ids]
+        return doc_ids
+
     @classmethod
     @DB.connection_context()
     def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 95792d72e84..1279662b027 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -337,4 +337,23 @@ def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method):
 
 def valid_parameter(parameter,valid_values):
     if parameter and parameter not in valid_values:
-        return get_error_data_result(f"{parameter} not in {valid_values}")
\ No newline at end of file
+        return get_error_data_result(f"{parameter} not in {valid_values}")
+
+def get_parser_config(chunk_method,parser_config):
+    if parser_config:
+        return parser_config
+    if not chunk_method:
+        chunk_method = "naive"
+    key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}},
+                 "qa":{"raptor":{"use_raptor":False}},
+                 "resume":None,
+                 "manual":{"raptor":{"use_raptor":False}},
+                 "table":None,
+                 "paper":{"raptor":{"use_raptor":False}},
+                 "book":{"raptor":{"use_raptor":False}},
+                 "laws":{"raptor":{"use_raptor":False}},
+                 "presentation":{"raptor":{"use_raptor":False}},
+                 "one":None,
+                 "knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]}}
+    parser_config=key_mapping[chunk_method]
+    return parser_config
\ No newline at end of file
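Editorial note, not part of the patch: the new `get_parser_config` helper backfills a default `parser_config` per chunk method whenever the caller omits one. A small illustration of its contract:

```python
from api.utils.api_utils import get_parser_config

# A caller-supplied config is returned unchanged.
assert get_parser_config("qa", {"raptor": {"use_raptor": True}}) == {"raptor": {"use_raptor": True}}

# With no config, each chunk method maps to its defaults;
# a missing method falls back to "naive".
cfg = get_parser_config(None, None)
assert cfg["chunk_token_num"] == 128 and cfg["layout_recognize"] is True

# Methods like "one" and "table" deliberately have no default config.
assert get_parser_config("one", None) is None
```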
diff --git a/sdk/python/ragflow/modules/chat.py b/sdk/python/ragflow/modules/chat.py
index 04cb7ebc26c..596ba7b5cf9 100644
--- a/sdk/python/ragflow/modules/chat.py
+++ b/sdk/python/ragflow/modules/chat.py
@@ -9,7 +9,7 @@ def __init__(self, rag, res_dict):
         self.id = ""
         self.name = "assistant"
         self.avatar = "path/to/avatar"
-        self.knowledgebases = ["kb1"]
+        self.datasets = ["kb1"]
         self.llm = Chat.LLM(rag, {})
         self.prompt = Chat.Prompt(rag, {})
         super().__init__(rag, res_dict)
diff --git a/sdk/python/ragflow/modules/chunk.py b/sdk/python/ragflow/modules/chunk.py
index 49132af91f0..dd4763a302d 100644
--- a/sdk/python/ragflow/modules/chunk.py
+++ b/sdk/python/ragflow/modules/chunk.py
@@ -8,10 +8,10 @@ def __init__(self, rag, res_dict):
         self.important_keywords = []
         self.create_time = ""
         self.create_timestamp = 0.0
-        self.knowledgebase_id = None
+        self.dataset_id = None
         self.document_name = ""
         self.document_id = ""
-        self.available = 1
+        self.available = True
         for k in list(res_dict.keys()):
             if k not in self.__dict__:
                 res_dict.pop(k)
@@ -19,7 +19,7 @@ def __init__(self, rag, res_dict):
 
     def update(self,update_message:dict):
-        res = self.put(f"/dataset/{self.knowledgebase_id}/document/{self.document_id}/chunk/{self.id}",update_message)
+        res = self.put(f"/dataset/{self.dataset_id}/document/{self.document_id}/chunk/{self.id}",update_message)
         res = res.json()
         if res.get("code") != 0:
             raise Exception(res["message"])
diff --git a/sdk/python/ragflow/modules/dataset.py b/sdk/python/ragflow/modules/dataset.py
index 44028813b43..1f72c6cc98a 100644
--- a/sdk/python/ragflow/modules/dataset.py
+++ b/sdk/python/ragflow/modules/dataset.py
@@ -10,10 +10,6 @@ class DataSet(Base):
     class ParserConfig(Base):
         def __init__(self, rag, res_dict):
-            self.chunk_token_count = 128
-            self.layout_recognize = True
-            self.delimiter = '\n!?。;!?'
-            self.task_page_size = 12
             super().__init__(rag, res_dict)
 
     def __init__(self, rag, res_dict):
@@ -43,11 +39,16 @@ def update(self, update_message: dict):
 
     def upload_documents(self,document_list: List[dict]):
         url = f"/dataset/{self.id}/document"
-        files = [("file",(ele["name"],ele["blob"])) for ele in document_list]
+        files = [("file",(ele["displayed_name"],ele["blob"])) for ele in document_list]
         res = self.post(path=url,json=None,files=files)
         res = res.json()
-        if res.get("code") != 0:
-            raise Exception(res.get("message"))
+        if res.get("code") == 0:
+            doc_list=[]
+            for doc in res["data"]:
+                document = Document(self.rag,doc)
+                doc_list.append(document)
+            return doc_list
+        raise Exception(res.get("message"))
 
     def list_documents(self, id: str = None, keywords: str = None, offset: int =1, limit: int = 1024, orderby: str = "create_time", desc: bool = True):
         res = self.get(f"/dataset/{self.id}/info",params={"id": id,"keywords": keywords,"offset": offset,"limit": limit,"orderby": orderby,"desc": desc})
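Editorial note, not part of the patch: `upload_documents` now expects each entry to carry `displayed_name` rather than `name`, and it returns the created `Document` objects instead of `None`. A hedged sketch, reusing the `ds` dataset object from the earlier example; the file name is illustrative:

```python
with open("manual.pdf", "rb") as f:
    blob = f.read()
docs = ds.upload_documents([{"displayed_name": "manual.pdf", "blob": blob}])
# Each uploaded document comes back with renamed keys and run == "UNSTART".
print(docs[0].id, docs[0].run)
```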
diff --git a/sdk/python/ragflow/modules/document.py b/sdk/python/ragflow/modules/document.py
index fcf02115d4b..64ba8f9209f 100644
--- a/sdk/python/ragflow/modules/document.py
+++ b/sdk/python/ragflow/modules/document.py
@@ -5,12 +5,16 @@
 
 class Document(Base):
+    class ParserConfig(Base):
+        def __init__(self, rag, res_dict):
+            super().__init__(rag, res_dict)
+
     def __init__(self, rag, res_dict):
         self.id = ""
         self.name = ""
         self.thumbnail = None
-        self.knowledgebase_id = None
-        self.chunk_method = ""
+        self.dataset_id = None
+        self.chunk_method = "naive"
         self.parser_config = {"pages": [[1, 1000000]]}
         self.source_type = "local"
         self.type = ""
@@ -31,14 +35,14 @@ def __init__(self, rag, res_dict):
 
     def update(self, update_message: dict):
-        res = self.put(f'/dataset/{self.knowledgebase_id}/info/{self.id}',
+        res = self.put(f'/dataset/{self.dataset_id}/info/{self.id}',
                        update_message)
         res = res.json()
         if res.get("code") != 0:
             raise Exception(res["message"])
 
     def download(self):
-        res = self.get(f"/dataset/{self.knowledgebase_id}/document/{self.id}")
+        res = self.get(f"/dataset/{self.dataset_id}/document/{self.id}")
         try:
             res = res.json()
             raise Exception(res.get("message"))
@@ -48,7 +52,7 @@ def download(self):
 
     def list_chunks(self,offset=0, limit=30, keywords="", id:str=None):
         data={"document_id": self.id,"keywords": keywords,"offset":offset,"limit":limit,"id":id}
-        res = self.get(f'/dataset/{self.knowledgebase_id}/document/{self.id}/chunk', data)
+        res = self.get(f'/dataset/{self.dataset_id}/document/{self.id}/chunk', data)
         res = res.json()
         if res.get("code") == 0:
             chunks=[]
@@ -59,15 +63,15 @@ def list_chunks(self,offset=0, limit=30, keywords="", id:str=None):
             raise Exception(res.get("message"))
 
-    def add_chunk(self, content: str):
-        res = self.post(f'/dataset/{self.knowledgebase_id}/document/{self.id}/chunk', {"content":content})
+    def add_chunk(self, content: str,important_keywords:List[str]=[]):
+        res = self.post(f'/dataset/{self.dataset_id}/document/{self.id}/chunk', {"content":content,"important_keywords":important_keywords})
         res = res.json()
         if res.get("code") == 0:
             return Chunk(self.rag,res["data"].get("chunk"))
         raise Exception(res.get("message"))
 
     def delete_chunks(self,ids:List[str]):
-        res = self.rm(f"dataset/{self.knowledgebase_id}/document/{self.id}/chunk",{"ids":ids})
+        res = self.rm(f"dataset/{self.dataset_id}/document/{self.id}/chunk",{"ids":ids})
         res = res.json()
         if res.get("code")!=0:
             raise Exception(res.get("message"))
\ No newline at end of file
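Editorial note, not part of the patch: `add_chunk` gains an optional `important_keywords` list, and chunk updates now coerce `available` to an integer server-side, so booleans work as expected. A hedged sketch, assuming `doc` is one of the `Document` objects returned by `upload_documents` above:

```python
chunk = doc.add_chunk(
    content="RAGFlow is an open-source RAG engine.",
    important_keywords=["RAGFlow", "RAG"],  # new optional parameter
)
chunk.update({"available": True})  # booleans are accepted for `available` now
```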
diff --git a/sdk/python/ragflow/modules/session.py b/sdk/python/ragflow/modules/session.py
index e9805520d4e..cbd6faf1f92 100644
--- a/sdk/python/ragflow/modules/session.py
+++ b/sdk/python/ragflow/modules/session.py
@@ -40,7 +40,7 @@ def ask(self, question: str, stream: bool = False):
                         "content": chunk["content_with_weight"],
                         "document_id": chunk["doc_id"],
                         "document_name": chunk["docnm_kwd"],
-                        "knowledgebase_id": chunk["kb_id"],
+                        "dataset_id": chunk["kb_id"],
                         "image_id": chunk["img_id"],
                         "similarity": chunk["similarity"],
                         "vector_similarity": chunk["vector_similarity"],
@@ -75,7 +75,7 @@ def __init__(self, rag, res_dict):
         self.content = None
         self.document_id = ""
         self.document_name = ""
-        self.knowledgebase_id = ""
+        self.dataset_id = ""
         self.image_id = ""
         self.similarity = None
         self.vector_similarity = None
diff --git a/sdk/python/ragflow/ragflow.py b/sdk/python/ragflow/ragflow.py
index 78994ccbe7e..616473f6a52 100644
--- a/sdk/python/ragflow/ragflow.py
+++ b/sdk/python/ragflow/ragflow.py
@@ -49,17 +49,11 @@ def put(self, path, json):
         return res
 
     def create_dataset(self, name: str, avatar: str = "", description: str = "", language: str = "English",
-                       permission: str = "me",
-                       document_count: int = 0, chunk_count: int = 0, chunk_method: str = "naive",
+                       permission: str = "me",chunk_method: str = "naive",
                        parser_config: DataSet.ParserConfig = None) -> DataSet:
-        if parser_config is None:
-            parser_config = DataSet.ParserConfig(self, {"chunk_token_count": 128, "layout_recognize": True,
-                                                        "delimiter": "\n!?。;!?", "task_page_size": 12})
-        parser_config = parser_config.to_json()
         res = self.post("/dataset",
                         {"name": name, "avatar": avatar, "description": description, "language": language,
-                         "permission": permission,
-                         "document_count": document_count, "chunk_count": chunk_count, "chunk_method": chunk_method,
+                         "permission": permission, "chunk_method": chunk_method,
                          "parser_config": parser_config
                          }
                         )
@@ -93,11 +87,11 @@ def list_datasets(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
                 return result_list
         raise Exception(res["message"])
 
-    def create_chat(self, name: str, avatar: str = "", knowledgebases: List[DataSet] = [],
+    def create_chat(self, name: str, avatar: str = "", datasets: List[DataSet] = [],
                     llm: Chat.LLM = None, prompt: Chat.Prompt = None) -> Chat:
-        datasets = []
-        for dataset in knowledgebases:
-            datasets.append(dataset.to_json())
+        dataset_list = []
+        for dataset in datasets:
+            dataset_list.append(dataset.to_json())
         if llm is None:
             llm = Chat.LLM(self, {"model_name": None,
@@ -130,7 +124,7 @@ def create_chat(self, name: str, avatar: str = "", datasets: List[DataSet] = [],
 
         temp_dict = {"name": name,
                      "avatar": avatar,
-                     "knowledgebases": datasets,
+                     "datasets": dataset_list,
                      "llm": llm.to_json(),
                      "prompt": prompt.to_json()}
         res = self.post("/chat", temp_dict)
@@ -158,25 +152,22 @@ def list_chats(self, page: int = 1, page_size: int = 1024, orderby: str = "create_time", desc: bool = True,
         raise Exception(res["message"])
 
-    def retrieve(self, question="",datasets=None,documents=None, offset=1, limit=30, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,):
-        data_params = {
+    def retrieve(self, datasets,documents,question="", offset=1, limit=1024, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,):
+        data_json ={
             "offset": offset,
             "limit": limit,
             "similarity_threshold": similarity_threshold,
             "vector_similarity_weight": vector_similarity_weight,
             "top_k": top_k,
-            "knowledgebase_id": datasets,
-            "rerank_id":rerank_id,
-            "keyword":keyword
-        }
-        data_json ={
+            "rerank_id": rerank_id,
+            "keyword": keyword,
             "question": question,
             "datasets": datasets,
             "documents": documents
         }
         # Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
-        res = self.get(f'/retrieval', data_params,data_json)
+        res = self.post(f'/retrieval',json=data_json)
         res = res.json()
         if res.get("code") ==0:
             chunks=[]
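Editorial note, not part of the patch: `retrieve` now takes `datasets` and `documents` as its leading parameters, defaults `limit` to 1024, and sends a single JSON body via POST instead of the previous GET with a stray `knowledgebase_id` field. A hedged sketch, reusing `rag` and `ds` from the earlier examples:

```python
chunks = rag.retrieve(
    datasets=[ds.id],   # dataset IDs, matching the `/retrieval` payload
    documents=[],
    question="What is RAGFlow?",
)
for c in chunks:
    print(c.document_name, c.content)
```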
diff --git a/sdk/python/test/t_chat.py b/sdk/python/test/t_chat.py
index 93a55fa454a..994cbe86e51 100644
--- a/sdk/python/test/t_chat.py
+++ b/sdk/python/test/t_chat.py
@@ -11,7 +11,7 @@ def test_create_chat_with_success(self):
         """
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         kb = rag.create_dataset(name="test_create_chat")
-        chat = rag.create_chat("test_create", knowledgebases=[kb])
+        chat = rag.create_chat("test_create", datasets=[kb])
         if isinstance(chat, Chat):
             assert chat.name == "test_create", "Name does not match."
         else:
@@ -23,7 +23,7 @@ def test_update_chat_with_success(self):
         """
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         kb = rag.create_dataset(name="test_update_chat")
-        chat = rag.create_chat("test_update", knowledgebases=[kb])
+        chat = rag.create_chat("test_update", datasets=[kb])
         if isinstance(chat, Chat):
             assert chat.name == "test_update", "Name does not match."
             res=chat.update({"name":"new_chat"})
@@ -37,7 +37,7 @@ def test_delete_chats_with_success(self):
         """
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         kb = rag.create_dataset(name="test_delete_chat")
-        chat = rag.create_chat("test_delete", knowledgebases=[kb])
+        chat = rag.create_chat("test_delete", datasets=[kb])
         if isinstance(chat, Chat):
             assert chat.name == "test_delete", "Name does not match."
             res = rag.delete_chats(ids=[chat.id])
diff --git a/sdk/python/test/t_session.py b/sdk/python/test/t_session.py
index d00647943ce..938a186d52c 100644
--- a/sdk/python/test/t_session.py
+++ b/sdk/python/test/t_session.py
@@ -7,14 +7,14 @@ class TestSession:
     def test_create_session(self):
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         kb = rag.create_dataset(name="test_create_session")
-        assistant = rag.create_chat(name="test_create_session", knowledgebases=[kb])
+        assistant = rag.create_chat(name="test_create_session", datasets=[kb])
         session = assistant.create_session()
         assert isinstance(session,Session), "Failed to create a session."
 
     def test_create_chat_with_success(self):
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         kb = rag.create_dataset(name="test_create_chat")
-        assistant = rag.create_chat(name="test_create_chat", knowledgebases=[kb])
+        assistant = rag.create_chat(name="test_create_chat", datasets=[kb])
         session = assistant.create_session()
         question = "What is AI"
         for ans in session.ask(question, stream=True):
@@ -24,7 +24,7 @@ def test_delete_sessions_with_success(self):
         rag = RAGFlow(API_KEY, HOST_ADDRESS)
         kb = rag.create_dataset(name="test_delete_session")
-        assistant = rag.create_chat(name="test_delete_session",knowledgebases=[kb])
+        assistant = rag.create_chat(name="test_delete_session",datasets=[kb])
         session=assistant.create_session()
         res=assistant.delete_sessions(ids=[session.id])
         assert res is None, "Failed to delete the dataset."
@@ -32,7 +32,7 @@ def test_update_session_with_success(self):
         rag=RAGFlow(API_KEY,HOST_ADDRESS)
         kb=rag.create_dataset(name="test_update_session")
-        assistant = rag.create_chat(name="test_update_session",knowledgebases=[kb])
+        assistant = rag.create_chat(name="test_update_session",datasets=[kb])
         session=assistant.create_session(name="old session")
         res=session.update({"name":"new session"})
         assert res is None,"Failed to update the session"
@@ -41,7 +41,7 @@ def test_list_sessions_with_success(self):
         rag=RAGFlow(API_KEY,HOST_ADDRESS)
         kb=rag.create_dataset(name="test_list_session")
-        assistant=rag.create_chat(name="test_list_session",knowledgebases=[kb])
+        assistant=rag.create_chat(name="test_list_session",datasets=[kb])
         assistant.create_session("test_1")
         assistant.create_session("test_2")
         sessions=assistant.list_sessions()