From 5f6099fc20ddc905b6300717af9d432295898678 Mon Sep 17 00:00:00 2001 From: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Date: Mon, 14 Oct 2024 17:26:04 +0800 Subject: [PATCH] feat: added dataset recall testing API (#9300) --- .../console/datasets/hit_testing.py | 78 +--------- .../console/datasets/hit_testing_base.py | 85 ++++++++++ api/controllers/service_api/__init__.py | 3 +- .../service_api/dataset/hit_testing.py | 17 ++ .../datasets/template/template.en.mdx | 145 +++++++++++++++++ .../datasets/template/template.zh.mdx | 146 ++++++++++++++++++ 6 files changed, 401 insertions(+), 73 deletions(-) create mode 100644 api/controllers/console/datasets/hit_testing_base.py create mode 100644 api/controllers/service_api/dataset/hit_testing.py diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 6e6d8c0bd76f23..5c9bcef84c4f83 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,88 +1,24 @@ -import logging +from flask_restful import Resource -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound - -import services from controllers.console import api -from controllers.console.app.error import ( - CompletionRequestError, - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.datasets.error import DatasetNotInitializedError +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error import ( - LLMBadRequestError, - ModelCurrentlyNotSupportError, - ProviderTokenNotInitError, - QuotaExceededError, -) -from core.model_runtime.errors.invoke import InvokeError -from fields.hit_testing_fields import hit_testing_record_fields from libs.login import login_required -from services.dataset_service import DatasetService -from services.hit_testing_service import HitTestingService -class HitTestingApi(Resource): +class HitTestingApi(Resource, DatasetsHitTestingBase): @setup_required @login_required @account_initialization_required def post(self, dataset_id): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) - if dataset is None: - raise NotFound("Dataset not found.") - - try: - DatasetService.check_dataset_permission(dataset, current_user) - except services.errors.account.NoPermissionError as e: - raise Forbidden(str(e)) - - parser = reqparse.RequestParser() - parser.add_argument("query", type=str, location="json") - parser.add_argument("retrieval_model", type=dict, required=False, location="json") - parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") - args = parser.parse_args() - - HitTestingService.hit_testing_args_check(args) - - try: - response = HitTestingService.retrieve( - dataset=dataset, - query=args["query"], - account=current_user, - retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrieval_model"], - limit=10, - ) + dataset = self.get_and_validate_dataset(dataset_id_str) + args = self.parse_args() + self.hit_testing_args_check(args) - return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} - except services.errors.index.IndexNotInitializedError: - raise DatasetNotInitializedError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model or Reranking Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise ValueError(str(e)) - except Exception as e: - logging.exception("Hit testing failed.") - raise InternalServerError(str(e)) + return self.perform_hit_testing(dataset, args) api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py new file mode 100644 index 00000000000000..3b4c07686361d0 --- /dev/null +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -0,0 +1,85 @@ +import logging + +from flask_login import current_user +from flask_restful import marshal, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services.dataset_service +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import DatasetNotInitializedError +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from core.model_runtime.errors.invoke import InvokeError +from fields.hit_testing_fields import hit_testing_record_fields +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + + +class DatasetsHitTestingBase: + @staticmethod + def get_and_validate_dataset(dataset_id: str): + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + return dataset + + @staticmethod + def hit_testing_args_check(args): + HitTestingService.hit_testing_args_check(args) + + @staticmethod + def parse_args(): + parser = reqparse.RequestParser() + + parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + return parser.parse_args() + + @staticmethod + def perform_hit_testing(dataset, args): + try: + response = HitTestingService.retrieve( + dataset=dataset, + query=args["query"], + account=current_user, + retrieval_model=args["retrieval_model"], + external_retrieval_model=args["external_retrieval_model"], + limit=10, + ) + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} + except services.errors.index.IndexNotInitializedError: + raise DatasetNotInitializedError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model or Reranking Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise ValueError(str(e)) + except Exception as e: + logging.exception("Hit testing failed.") + raise InternalServerError(str(e)) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index ad39c160ac2669..d6ab96c329f335 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -5,7 +5,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) - from . import index from .app import app, audio, completion, conversation, file, message, workflow -from .dataset import dataset, document, segment +from .dataset import dataset, document, hit_testing, segment diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py new file mode 100644 index 00000000000000..9c9a4302c99a65 --- /dev/null +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -0,0 +1,17 @@ +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.service_api import api +from controllers.service_api.wraps import DatasetApiResource + + +class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + def post(self, tenant_id, dataset_id): + dataset_id_str = str(dataset_id) + + dataset = self.get_and_validate_dataset(dataset_id_str) + args = self.parse_args() + self.hit_testing_args_check(args) + + return self.perform_hit_testing(dataset, args) + + +api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 99fab1ed8b77ef..b846f6d9fbf500 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --- + + + + ### Path + + + Dataset ID + + + + ### Request Body + + + retrieval keywordc + + + retrieval keyword(Optional, if not filled, it will be recalled according to the default method) + - search_method (text) Search method: One of the following four keywords is required + - keyword_search Keyword search + - semantic_search Semantic search + - full_text_search Full-text search + - hybrid_search Hybrid search + - reranking_enable (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search + - reranking_mode (object) Rerank model configuration, optional, required if reranking is enabled + - reranking_provider_name (string) Rerank model provider + - reranking_model_name (string) Rerank model name + - weights (double) Semantic search weight setting in hybrid search mode + - top_k (integer) Number of results to return, optional + - score_threshold_enabled (bool) Whether to enable score threshold + - score_threshold (double) Score threshold + + + Unused field + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "query": "test", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 2, + "score_threshold_enabled": false, + "score_threshold": null + } + }' + ``` + + + ```json {{ title: 'Response' }} + { + "query": { + "content": "test" + }, + "records": [ + { + "segment": { + "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", + "position": 1, + "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "content": "Operation guide", + "answer": null, + "word_count": 847, + "tokens": 280, + "keywords": [ + "install", + "java", + "base", + "scripts", + "jdk", + "manual", + "internal", + "opens", + "add", + "vmoptions" + ], + "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", + "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", + "hit_count": 0, + "enabled": true, + "disabled_at": null, + "disabled_by": null, + "status": "completed", + "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", + "created_at": 1728734540, + "indexing_at": 1728734552, + "completed_at": 1728734584, + "error": null, + "stopped_at": null, + "document": { + "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "data_source_type": "upload_file", + "name": "readme.txt", + "doc_type": null + } + }, + "score": 3.730463140527718e-05, + "tsne_position": null + } + ] + } + ``` + + + + +--- + ### Error message diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 2f8c7e99dc8540..ece4d3b7717efd 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from +--- + + + + + ### Path + + + 知识库 ID + + + + ### Request Body + + + 召回关键词 + + + 召回参数(选填,如不填,按照默认方式召回) + - search_method (text) 检索方法:以下三个关键字之一,必填 + - keyword_search 关键字检索 + - semantic_search 语义检索 + - full_text_search 全文检索 + - hybrid_search 混合检索 + - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值 + - reranking_mode (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 + - reranking_provider_name (string) Rerank 模型提供商 + - reranking_model_name (string) Rerank 模型名称 + - weights (double) 混合检索模式下语意检索的权重设置 + - top_k (integer) 返回结果数量,非必填 + - score_threshold_enabled (bool) 是否开启Score阈值 + - score_threshold (double) Score阈值 + + + 未启用字段 + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "query": "test", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 2, + "score_threshold_enabled": false, + "score_threshold": null + } + }' + ``` + + + ```json {{ title: 'Response' }} + { + "query": { + "content": "test" + }, + "records": [ + { + "segment": { + "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", + "position": 1, + "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "content": "Operation guide", + "answer": null, + "word_count": 847, + "tokens": 280, + "keywords": [ + "install", + "java", + "base", + "scripts", + "jdk", + "manual", + "internal", + "opens", + "add", + "vmoptions" + ], + "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", + "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", + "hit_count": 0, + "enabled": true, + "disabled_at": null, + "disabled_by": null, + "status": "completed", + "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", + "created_at": 1728734540, + "indexing_at": 1728734552, + "completed_at": 1728734584, + "error": null, + "stopped_at": null, + "document": { + "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "data_source_type": "upload_file", + "name": "readme.txt", + "doc_type": null + } + }, + "score": 3.730463140527718e-05, + "tsne_position": null + } + ] + } + ``` + + + + + ---