From 42b02b3a5f6c833ac5e09fd9667920a34ebed518 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 11 Oct 2024 19:21:03 +0800 Subject: [PATCH] Fix/agent external knowledge retrieval (#9241) --- api/configs/middleware/__init__.py | 17 ++ api/controllers/console/datasets/external.py | 24 ++ api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../dataset_retriever_tool.py | 218 +++++++++++------- .../knowledge_retrieval_node.py | 14 +- api/services/dataset_service.py | 1 + api/services/knowledge_service.py | 45 ++++ 7 files changed, 227 insertions(+), 94 deletions(-) create mode 100644 api/services/knowledge_service.py diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 8626236856ddbf..5fec991d6e2d97 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -191,6 +191,22 @@ def BROKER_USE_SSL(self) -> bool: return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False +class InternalTestConfig(BaseSettings): + """ + Configuration settings for Internal Test + """ + + AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + description="Internal test AWS secret access key", + default=None, + ) + + AWS_ACCESS_KEY_ID: Optional[str] = Field( + description="Internal test AWS access key ID", + default=None, + ) + + class MiddlewareConfig( # place the configs in alphabet order CeleryConfig, @@ -224,5 +240,6 @@ class MiddlewareConfig( TiDBVectorConfig, WeaviateConfig, ElasticsearchConfig, + InternalTestConfig, ): pass diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index dcef979360ba62..2dc054cfbdf4af 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -13,6 +13,7 @@ from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService +from services.knowledge_service import ExternalDatasetTestService def _validate_name(name): @@ -232,8 +233,31 @@ def post(self, dataset_id): raise InternalServerError(str(e)) +class BedrockRetrievalApi(Resource): + # this api is only for internal testing + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") + parser.add_argument( + "query", + nullable=False, + required=True, + type=str, + ) + parser.add_argument("knowledge_id", nullable=False, required=True, type=str) + args = parser.parse_args() + + # Call the knowledge retrieval service + result = ExternalDatasetTestService.knowledge_retrieval( + args["retrieval_setting"], args["query"], args["knowledge_id"] + ) + return result, 200 + + api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets//external-hit-testing") api.add_resource(ExternalDatasetCreateApi, "/datasets/external") api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/") api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api//use-check") +# this api is only for internal test +api.add_resource(BedrockRetrievalApi, "/test/retrieval") diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ae79b44158c848..633e41d5cf1ed6 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -539,7 +539,7 @@ def to_dataset_retriever_tool( continue # pass if dataset is not available - if dataset and dataset.available_document_count == 0: + if dataset and dataset.provider != "external" and dataset.available_document_count == 0: continue available_datasets.append(dataset) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 8dc60408c93b41..987f94a35046e9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,10 +1,12 @@ from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -53,97 +55,137 @@ def _run(self, query: str) -> str: for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) - - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + if dataset.provider == "external": + results = [] + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + query=query, + external_retrieval_parameters=dataset.retrieval_model, ) - return str("\n".join([document.page_content for document in documents])) + for external_document in external_documents: + document = RetrievalDocument( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) + # deal with external documents + context_list = [] + for position, item in enumerate(results, start=1): + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } + context_list.append(source) + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join([item.page_content for item in results])) else: - if self.top_k > 0: - # retrieval source + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model or default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query documents = RetrievalService.retrieve( - retrieval_method=retrieval_model.get("search_method", "semantic_search"), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k ) + return str("\n".join([document.page_content for document in documents])) else: - documents = [] - - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(documents) - document_score_list = {} - if dataset.indexing_technique != "economy": - for item in documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") - else: - document_context_list.append(segment.get_sign_content()) - if self.return_resource: - context_list = [] - resource_number = 1 + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + document_context_list = [] + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - context = {} - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() - if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - } - if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - context_list.append(source) - resource_number += 1 - - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) + if segment.answer: + document_context_list.append( + f"question:{segment.get_sign_content()} answer:{segment.answer}" + ) + else: + document_context_list.append(segment.get_sign_content()) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), + } + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0caf99a963822c..0b3e9bd6a888ec 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -79,8 +79,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: results = ( db.session.query(Dataset) - .join(subquery, Dataset.id == subquery.c.dataset_id) + .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) .all() ) @@ -121,10 +122,13 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4f77fdc18cb432..ede87640862199 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -234,6 +234,7 @@ def update_dataset(dataset_id, data, user): dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", "") external_knowledge_id = data.get("external_knowledge_id", None) + dataset.permission = data.get("permission") db.session.add(dataset) if not external_knowledge_id: raise ValueError("External knowledge id is required.") diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py new file mode 100644 index 00000000000000..02fe1d19bc42be --- /dev/null +++ b/api/services/knowledge_service.py @@ -0,0 +1,45 @@ +import boto3 + +from configs import dify_config + + +class ExternalDatasetTestService: + # this service is only for internal testing + @staticmethod + def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + # get bedrock client + client = boto3.client( + "bedrock-agent-runtime", + aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, + aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID, + # example: us-east-1 + region_name="us-east-1", + ) + # fetch external knowledge retrieval + response = client.retrieve( + knowledgeBaseId=knowledge_id, + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": retrieval_setting.get("top_k"), + "overrideSearchType": "HYBRID", + } + }, + retrievalQuery={"text": query}, + ) + # parse response + results = [] + if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200: + if response.get("retrievalResults"): + retrieval_results = response.get("retrievalResults") + for retrieval_result in retrieval_results: + # filter out results with score less than threshold + if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + continue + result = { + "metadata": retrieval_result.get("metadata"), + "score": retrieval_result.get("score"), + "title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"), + "content": retrieval_result.get("content").get("text"), + } + results.append(result) + return {"records": results}