Skip to content

Commit

Permalink
feat: added dataset recall testing API (langgenius#9300)
Browse files Browse the repository at this point in the history
  • Loading branch information
gubinjie authored and JunXu01 committed Nov 9, 2024
1 parent 3707c67 commit 5f6099f
Show file tree
Hide file tree
Showing 6 changed files with 401 additions and 73 deletions.
78 changes: 7 additions & 71 deletions api/controllers/console/datasets/hit_testing.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,24 @@
import logging
from flask_restful import Resource

from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class HitTestingApi(Resource):
class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)

dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")

try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()

HitTestingService.hit_testing_args_check(args)

try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)

return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
return self.perform_hit_testing(dataset, args)


api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
85 changes: 85 additions & 0 deletions api/controllers/console/datasets/hit_testing_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import logging

from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services.dataset_service
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class DatasetsHitTestingBase:
    """Shared building blocks for dataset hit-testing (recall testing) endpoints.

    Each step of the request flow is a static method so that different
    resource classes can mix this class in and call the steps in order:
    load/authorize the dataset, parse the JSON body, validate the arguments,
    then run the retrieval.
    """

    @staticmethod
    def get_and_validate_dataset(dataset_id: str):
        """Load a dataset and verify the current user is allowed to use it.

        Args:
            dataset_id: Dataset identifier as a string.

        Returns:
            The dataset returned by ``DatasetService.get_dataset``.

        Raises:
            NotFound: If ``DatasetService.get_dataset`` returns ``None``.
            Forbidden: If the permission check raises ``NoPermissionError``.
        """
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            # Chain explicitly so the original permission error stays attached
            # as the cause of the HTTP-level exception.
            raise Forbidden(str(e)) from e

        return dataset

    @staticmethod
    def hit_testing_args_check(args):
        """Validate parsed request arguments via ``HitTestingService``."""
        HitTestingService.hit_testing_args_check(args)

    @staticmethod
    def parse_args():
        """Parse the hit-testing arguments from the JSON request body.

        Returns:
            The parsed arguments with the keys ``query`` (str),
            ``retrieval_model`` (dict, optional) and
            ``external_retrieval_model`` (dict, optional).
        """
        parser = reqparse.RequestParser()

        parser.add_argument("query", type=str, location="json")
        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
        return parser.parse_args()

    @staticmethod
    def perform_hit_testing(dataset, args):
        """Run the retrieval and translate service errors into API errors.

        Args:
            dataset: Dataset to query, as returned by
                ``get_and_validate_dataset``.
            args: Parsed request arguments; must contain ``query``,
                ``retrieval_model`` and ``external_retrieval_model``.

        Returns:
            A dict with the ``query`` from the service response and the
            ``records`` marshalled with ``hit_testing_record_fields``.

        Raises:
            DatasetNotInitializedError: The index is not initialized.
            ProviderNotInitializeError: The provider token is missing, or the
                service raised ``LLMBadRequestError``.
            ProviderQuotaExceededError: The provider quota is exhausted.
            ProviderModelCurrentlyNotSupportError: The model is not supported.
            CompletionRequestError: The model invocation failed.
            ValueError: The service rejected the input.
            InternalServerError: Any other failure (also logged).
        """
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args["query"],
                account=current_user,
                retrieval_model=args["retrieval_model"],
                external_retrieval_model=args["external_retrieval_model"],
                limit=10,
            )
            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except LLMBadRequestError as e:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            ) from e
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError as e:
            # Re-raised as a plain ValueError (same type and message as before)
            # so it is not swallowed by the generic handler below.
            raise ValueError(str(e)) from e
        except Exception as e:
            logging.exception("Hit testing failed.")
            # NOTE(review): the exception text is sent back to the client;
            # confirm that exposing internal error details is intended.
            raise InternalServerError(str(e)) from e
3 changes: 1 addition & 2 deletions api/controllers/service_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)


from . import index
from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment
from .dataset import dataset, document, hit_testing, segment
17 changes: 17 additions & 0 deletions api/controllers/service_api/dataset/hit_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource


class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
    """Service-API resource that runs a hit test against a dataset."""

    def post(self, tenant_id, dataset_id):
        """Handle POST: validate the dataset and arguments, then run the hit test.

        Args:
            tenant_id: Accepted as part of the handler signature; not used
                in this method.
            dataset_id: Dataset identifier from the URL; converted to ``str``
                before lookup.

        Returns:
            Whatever ``perform_hit_testing`` returns for the parsed request.
        """
        target_dataset = self.get_and_validate_dataset(str(dataset_id))

        payload = self.parse_args()
        self.hit_testing_args_check(payload)

        return self.perform_hit_testing(target_dataset, payload)


# Register the hit-testing endpoint on the service API; note the hyphenated path segment.
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
145 changes: 145 additions & 0 deletions web/app/(commonLayout)/datasets/template/template.en.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from

---

<Heading
url='/datasets/{dataset_id}/hit-testing'
method='POST'
title='Dataset hit testing'
name='#dataset_hit_testing'
/>
<Row>
<Col>
### Path
<Properties>
<Property name='dataset_id' type='string' key='dataset_id'>
Dataset ID
</Property>
</Properties>

### Request Body
<Properties>
<Property name='query' type='string' key='query'>
Retrieval keyword
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
Retrieval parameters (optional; if omitted, the default retrieval method is used)
- <code>search_method</code> (text) Search method: One of the following four keywords is required
- <code>keyword_search</code> Keyword search
- <code>semantic_search</code> Semantic search
- <code>full_text_search</code> Full-text search
- <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
- <code>reranking_model</code> (object) Rerank model configuration, optional, required if reranking is enabled
- <code>reranking_provider_name</code> (string) Rerank model provider
- <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (double) Semantic search weight setting in hybrid search mode
- <code>top_k</code> (integer) Number of results to return, optional
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
- <code>score_threshold</code> (double) Score threshold
</Property>
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
Unused field
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
label="/datasets/{dataset_id}/hit-testing"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit-testing' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
}
}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit-testing' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 2,
"score_threshold_enabled": false,
"score_threshold": null
}
}'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"query": {
"content": "test"
},
"records": [
{
"segment": {
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
"position": 1,
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"content": "Operation guide",
"answer": null,
"word_count": 847,
"tokens": 280,
"keywords": [
"install",
"java",
"base",
"scripts",
"jdk",
"manual",
"internal",
"opens",
"add",
"vmoptions"
],
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
"hit_count": 0,
"enabled": true,
"disabled_at": null,
"disabled_by": null,
"status": "completed",
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
"created_at": 1728734540,
"indexing_at": 1728734552,
"completed_at": 1728734584,
"error": null,
"stopped_at": null,
"document": {
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"data_source_type": "upload_file",
"name": "readme.txt",
"doc_type": null
}
},
"score": 3.730463140527718e-05,
"tsne_position": null
}
]
}
```
</CodeGroup>
</Col>
</Row>

---

<Row>
<Col>
### Error message
Expand Down
Loading

0 comments on commit 5f6099f

Please sign in to comment.