forked from langgenius/dify
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added dataset recall testing API (langgenius#9300)
- Loading branch information
Showing
6 changed files
with
401 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,88 +1,24 @@ | ||
import logging | ||
from flask_restful import Resource | ||
|
||
from flask_login import current_user | ||
from flask_restful import Resource, marshal, reqparse | ||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | ||
|
||
import services | ||
from controllers.console import api | ||
from controllers.console.app.error import ( | ||
CompletionRequestError, | ||
ProviderModelCurrentlyNotSupportError, | ||
ProviderNotInitializeError, | ||
ProviderQuotaExceededError, | ||
) | ||
from controllers.console.datasets.error import DatasetNotInitializedError | ||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | ||
from controllers.console.setup import setup_required | ||
from controllers.console.wraps import account_initialization_required | ||
from core.errors.error import ( | ||
LLMBadRequestError, | ||
ModelCurrentlyNotSupportError, | ||
ProviderTokenNotInitError, | ||
QuotaExceededError, | ||
) | ||
from core.model_runtime.errors.invoke import InvokeError | ||
from fields.hit_testing_fields import hit_testing_record_fields | ||
from libs.login import login_required | ||
from services.dataset_service import DatasetService | ||
from services.hit_testing_service import HitTestingService | ||
|
||
|
||
class HitTestingApi(Resource): | ||
class HitTestingApi(Resource, DatasetsHitTestingBase): | ||
@setup_required | ||
@login_required | ||
@account_initialization_required | ||
def post(self, dataset_id): | ||
dataset_id_str = str(dataset_id) | ||
|
||
dataset = DatasetService.get_dataset(dataset_id_str) | ||
if dataset is None: | ||
raise NotFound("Dataset not found.") | ||
|
||
try: | ||
DatasetService.check_dataset_permission(dataset, current_user) | ||
except services.errors.account.NoPermissionError as e: | ||
raise Forbidden(str(e)) | ||
|
||
parser = reqparse.RequestParser() | ||
parser.add_argument("query", type=str, location="json") | ||
parser.add_argument("retrieval_model", type=dict, required=False, location="json") | ||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | ||
args = parser.parse_args() | ||
|
||
HitTestingService.hit_testing_args_check(args) | ||
|
||
try: | ||
response = HitTestingService.retrieve( | ||
dataset=dataset, | ||
query=args["query"], | ||
account=current_user, | ||
retrieval_model=args["retrieval_model"], | ||
external_retrieval_model=args["external_retrieval_model"], | ||
limit=10, | ||
) | ||
dataset = self.get_and_validate_dataset(dataset_id_str) | ||
args = self.parse_args() | ||
self.hit_testing_args_check(args) | ||
|
||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | ||
except services.errors.index.IndexNotInitializedError: | ||
raise DatasetNotInitializedError() | ||
except ProviderTokenNotInitError as ex: | ||
raise ProviderNotInitializeError(ex.description) | ||
except QuotaExceededError: | ||
raise ProviderQuotaExceededError() | ||
except ModelCurrentlyNotSupportError: | ||
raise ProviderModelCurrentlyNotSupportError() | ||
except LLMBadRequestError: | ||
raise ProviderNotInitializeError( | ||
"No Embedding Model or Reranking Model available. Please configure a valid provider " | ||
"in the Settings -> Model Provider." | ||
) | ||
except InvokeError as e: | ||
raise CompletionRequestError(e.description) | ||
except ValueError as e: | ||
raise ValueError(str(e)) | ||
except Exception as e: | ||
logging.exception("Hit testing failed.") | ||
raise InternalServerError(str(e)) | ||
return self.perform_hit_testing(dataset, args) | ||
|
||
|
||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import logging | ||
|
||
from flask_login import current_user | ||
from flask_restful import marshal, reqparse | ||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | ||
|
||
import services.dataset_service | ||
from controllers.console.app.error import ( | ||
CompletionRequestError, | ||
ProviderModelCurrentlyNotSupportError, | ||
ProviderNotInitializeError, | ||
ProviderQuotaExceededError, | ||
) | ||
from controllers.console.datasets.error import DatasetNotInitializedError | ||
from core.errors.error import ( | ||
LLMBadRequestError, | ||
ModelCurrentlyNotSupportError, | ||
ProviderTokenNotInitError, | ||
QuotaExceededError, | ||
) | ||
from core.model_runtime.errors.invoke import InvokeError | ||
from fields.hit_testing_fields import hit_testing_record_fields | ||
from services.dataset_service import DatasetService | ||
from services.hit_testing_service import HitTestingService | ||
|
||
|
||
class DatasetsHitTestingBase: | ||
@staticmethod | ||
def get_and_validate_dataset(dataset_id: str): | ||
dataset = DatasetService.get_dataset(dataset_id) | ||
if dataset is None: | ||
raise NotFound("Dataset not found.") | ||
|
||
try: | ||
DatasetService.check_dataset_permission(dataset, current_user) | ||
except services.errors.account.NoPermissionError as e: | ||
raise Forbidden(str(e)) | ||
|
||
return dataset | ||
|
||
@staticmethod | ||
def hit_testing_args_check(args): | ||
HitTestingService.hit_testing_args_check(args) | ||
|
||
@staticmethod | ||
def parse_args(): | ||
parser = reqparse.RequestParser() | ||
|
||
parser.add_argument("query", type=str, location="json") | ||
parser.add_argument("retrieval_model", type=dict, required=False, location="json") | ||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | ||
return parser.parse_args() | ||
|
||
@staticmethod | ||
def perform_hit_testing(dataset, args): | ||
try: | ||
response = HitTestingService.retrieve( | ||
dataset=dataset, | ||
query=args["query"], | ||
account=current_user, | ||
retrieval_model=args["retrieval_model"], | ||
external_retrieval_model=args["external_retrieval_model"], | ||
limit=10, | ||
) | ||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | ||
except services.errors.index.IndexNotInitializedError: | ||
raise DatasetNotInitializedError() | ||
except ProviderTokenNotInitError as ex: | ||
raise ProviderNotInitializeError(ex.description) | ||
except QuotaExceededError: | ||
raise ProviderQuotaExceededError() | ||
except ModelCurrentlyNotSupportError: | ||
raise ProviderModelCurrentlyNotSupportError() | ||
except LLMBadRequestError: | ||
raise ProviderNotInitializeError( | ||
"No Embedding Model or Reranking Model available. Please configure a valid provider " | ||
"in the Settings -> Model Provider." | ||
) | ||
except InvokeError as e: | ||
raise CompletionRequestError(e.description) | ||
except ValueError as e: | ||
raise ValueError(str(e)) | ||
except Exception as e: | ||
logging.exception("Hit testing failed.") | ||
raise InternalServerError(str(e)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | ||
from controllers.service_api import api | ||
from controllers.service_api.wraps import DatasetApiResource | ||
|
||
|
||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): | ||
def post(self, tenant_id, dataset_id): | ||
dataset_id_str = str(dataset_id) | ||
|
||
dataset = self.get_and_validate_dataset(dataset_id_str) | ||
args = self.parse_args() | ||
self.hit_testing_args_check(args) | ||
|
||
return self.perform_hit_testing(dataset, args) | ||
|
||
|
||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.