Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added dataset recall testing API #9300

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions api/controllers/service_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)


from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment
from .dataset import dataset, document, hit_testing, segment
85 changes: 85 additions & 0 deletions api/controllers/service_api/dataset/hit_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging

from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services.dataset_service
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class HitTestingApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
logging.error(f"{dataset_id_str}")
gubinjie marked this conversation as resolved.
Show resolved Hide resolved

dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")

try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

parser = reqparse.RequestParser()

parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()

HitTestingService.hit_testing_args_check(args)

try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)

return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))


api.add_resource(HitTestingApi, "/datasets2/<uuid:dataset_id>/hit-testing")
145 changes: 145 additions & 0 deletions web/app/(commonLayout)/datasets/template/template.en.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from

---

<Heading
url='/datasets/{dataset_id}/hit_testing'
method='POST'
title='Dataset hit testing'
name='#dataset_hit_testing'
/>
<Row>
<Col>
### Path
<Properties>
<Property name='dataset_id' type='string' key='dataset_id'>
Dataset ID
</Property>
</Properties>

### Request Body
<Properties>
<Property name='query' type='string' key='query'>
retrieval keywordc
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
retrieval keyword(Optional, if not filled, it will be recalled according to the default method)
- <code>search_method</code> (text) Search method: One of the following four keywords is required
- <code>keyword_search</code> Keyword search
- <code>semantic_search</code> Semantic search
- <code>full_text_search</code> Full-text search
- <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
- <code>reranking_mode</code> (object) Rerank model configuration, optional, required if reranking is enabled
- <code>reranking_provider_name</code> (string) Rerank model provider
- <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (double) Semantic search weight setting in hybrid search mode
- <code>top_k</code> (integer) Number of results to return, optional
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
- <code>score_threshold</code> (double) Score threshold
</Property>
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
Unused field
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="GET"
label="/datasets/{dataset_id}/hit_testing"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
gubinjie marked this conversation as resolved.
Show resolved Hide resolved
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
}
}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 2,
"score_threshold_enabled": false,
"score_threshold": null
}
}'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"query": {
"content": "test"
},
"records": [
{
"segment": {
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
"position": 1,
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"content": "Operation guide",
"answer": null,
"word_count": 847,
"tokens": 280,
"keywords": [
"install",
"java",
"base",
"scripts",
"jdk",
"manual",
"internal",
"opens",
"add",
"vmoptions"
],
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
"hit_count": 0,
"enabled": true,
"disabled_at": null,
"disabled_by": null,
"status": "completed",
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
"created_at": 1728734540,
"indexing_at": 1728734552,
"completed_at": 1728734584,
"error": null,
"stopped_at": null,
"document": {
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"data_source_type": "upload_file",
"name": "readme.txt",
"doc_type": null
}
},
"score": 3.730463140527718e-05,
"tsne_position": null
}
]
}
```
</CodeGroup>
</Col>
</Row>

---

<Row>
<Col>
### Error message
Expand Down
146 changes: 146 additions & 0 deletions web/app/(commonLayout)/datasets/template/template.zh.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Col>
</Row>

---

<Heading
url='/datasets/{dataset_id}/hit_testing'
method='POST'
title='知识库召回测试'
name='#dataset_hit_testing'
/>
<Row>
<Col>
### Path
<Properties>
<Property name='dataset_id' type='string' key='dataset_id'>
知识库 ID
</Property>
</Properties>

### Request Body
<Properties>
<Property name='query' type='string' key='query'>
召回关键词
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
召回参数(选填,如不填,按照默认方式召回)
- <code>search_method</code> (text) 检索方法:以下三个关键字之一,必填
- <code>keyword_search</code> 关键字检索
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索
- <code>hybrid_search</code> 混合检索
- <code>reranking_enable</code> (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值
- <code>reranking_mode</code> (object) Rerank模型配置,非必填,如果启用了 reranking 则传值
- <code>reranking_provider_name</code> (string) Rerank 模型提供商
- <code>reranking_model_name</code> (string) Rerank 模型名称
- <code>weights</code> (double) 混合检索模式下语意检索的权重设置
- <code>top_k</code> (integer) 返回结果数量,非必填
- <code>score_threshold_enabled</code> (bool) 是否开启Score阈值
- <code>score_threshold</code> (double) Score阈值
</Property>
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
未启用字段
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="GET"
label="/datasets/{dataset_id}/hit_testing"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
}
}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 2,
"score_threshold_enabled": false,
"score_threshold": null
}
}'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"query": {
"content": "test"
},
"records": [
{
"segment": {
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
"position": 1,
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"content": "Operation guide",
"answer": null,
"word_count": 847,
"tokens": 280,
"keywords": [
"install",
"java",
"base",
"scripts",
"jdk",
"manual",
"internal",
"opens",
"add",
"vmoptions"
],
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
"hit_count": 0,
"enabled": true,
"disabled_at": null,
"disabled_by": null,
"status": "completed",
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
"created_at": 1728734540,
"indexing_at": 1728734552,
"completed_at": 1728734584,
"error": null,
"stopped_at": null,
"document": {
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"data_source_type": "upload_file",
"name": "readme.txt",
"doc_type": null
}
},
"score": 3.730463140527718e-05,
"tsne_position": null
}
]
}
```
</CodeGroup>
</Col>
</Row>


---

<Row>
Expand Down