Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added dataset recall testing API #9300

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 7 additions & 71 deletions api/controllers/console/datasets/hit_testing.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,24 @@
import logging
from flask_restful import Resource

from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class HitTestingApi(Resource):
class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)

dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")

try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()

HitTestingService.hit_testing_args_check(args)

try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)

return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
return self.perform_hit_testing(dataset, args)


api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
85 changes: 85 additions & 0 deletions api/controllers/console/datasets/hit_testing_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging

from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services.dataset_service
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")

try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

return dataset

@staticmethod
def hit_testing_args_check(args):
HitTestingService.hit_testing_args_check(args)

@staticmethod
def parse_args():
parser = reqparse.RequestParser()

parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
return parser.parse_args()

@staticmethod
def perform_hit_testing(dataset, args):
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
3 changes: 1 addition & 2 deletions api/controllers/service_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)


from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment
from .dataset import dataset, document, hit_testing, segment
17 changes: 17 additions & 0 deletions api/controllers/service_api/dataset/hit_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource


class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)

dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)

return self.perform_hit_testing(dataset, args)


api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
145 changes: 145 additions & 0 deletions web/app/(commonLayout)/datasets/template/template.en.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from

---

<Heading
url='/datasets/{dataset_id}/hit_testing'
method='POST'
title='Dataset hit testing'
name='#dataset_hit_testing'
/>
<Row>
<Col>
### Path
<Properties>
<Property name='dataset_id' type='string' key='dataset_id'>
Dataset ID
</Property>
</Properties>

### Request Body
<Properties>
<Property name='query' type='string' key='query'>
retrieval keywordc
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
retrieval keyword(Optional, if not filled, it will be recalled according to the default method)
- <code>search_method</code> (text) Search method: One of the following four keywords is required
- <code>keyword_search</code> Keyword search
- <code>semantic_search</code> Semantic search
- <code>full_text_search</code> Full-text search
- <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
- <code>reranking_mode</code> (object) Rerank model configuration, optional, required if reranking is enabled
- <code>reranking_provider_name</code> (string) Rerank model provider
- <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (double) Semantic search weight setting in hybrid search mode
- <code>top_k</code> (integer) Number of results to return, optional
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
- <code>score_threshold</code> (double) Score threshold
</Property>
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
Unused field
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
label="/datasets/{dataset_id}/hit_testing"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
gubinjie marked this conversation as resolved.
Show resolved Hide resolved
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
}
}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 2,
"score_threshold_enabled": false,
"score_threshold": null
}
}'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"query": {
"content": "test"
},
"records": [
{
"segment": {
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
"position": 1,
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"content": "Operation guide",
"answer": null,
"word_count": 847,
"tokens": 280,
"keywords": [
"install",
"java",
"base",
"scripts",
"jdk",
"manual",
"internal",
"opens",
"add",
"vmoptions"
],
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
"hit_count": 0,
"enabled": true,
"disabled_at": null,
"disabled_by": null,
"status": "completed",
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
"created_at": 1728734540,
"indexing_at": 1728734552,
"completed_at": 1728734584,
"error": null,
"stopped_at": null,
"document": {
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"data_source_type": "upload_file",
"name": "readme.txt",
"doc_type": null
}
},
"score": 3.730463140527718e-05,
"tsne_position": null
}
]
}
```
</CodeGroup>
</Col>
</Row>

---

<Row>
<Col>
### Error message
Expand Down
Loading