From 0e813479f03df8c08112eb54a9e336538c85b6c5 Mon Sep 17 00:00:00 2001
From: liaojianxing
Date: Thu, 25 Jul 2024 11:15:42 +0800
Subject: [PATCH 01/12] Added support for HyDE

---
 config/config2.example.yaml              |  4 +++
 config/config2.yaml                      | 12 ++++---
 metagpt/config2.py                       |  4 +++
 metagpt/configs/query_analysis_config.py |  5 +++
 metagpt/rag/engines/simple.py            |  1 -
 metagpt/rag/query_analysis/HyDE.py       | 43 ++++++++++++++++++++++++
 metagpt/rag/query_analysis/__init__.py   |  0
 metagpt/rag/schema.py                    |  4 +++
 8 files changed, 68 insertions(+), 5 deletions(-)
 create mode 100644 metagpt/configs/query_analysis_config.py
 create mode 100644 metagpt/rag/query_analysis/HyDE.py
 create mode 100644 metagpt/rag/query_analysis/__init__.py

diff --git a/config/config2.example.yaml b/config/config2.example.yaml
index 0fe11df4e..3b277f573 100644
--- a/config/config2.example.yaml
+++ b/config/config2.example.yaml
@@ -20,6 +20,10 @@ embedding:
   embed_batch_size: 100
   dimensions: # output dimension of embedding model

+
+hyde:
+  include_original: true
+
 repair_llm_output: true # when the output is not a valid json, try to repair it

 proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc.

diff --git a/config/config2.yaml b/config/config2.yaml
index b3f24539c..0895da664 100644
--- a/config/config2.yaml
+++ b/config/config2.yaml
@@ -2,7 +2,11 @@
 # Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
 # Config Docs: https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
 llm:
-  api_type: "openai" # or azure / ollama / groq etc.
-  model: "gpt-4-turbo" # or gpt-3.5-turbo
-  base_url: "https://api.openai.com/v1" # or forward url / other llm url
-  api_key: "YOUR_API_KEY"
\ No newline at end of file
+  api_type: 'azure'
+  base_url: 'https://deepwisdomai03.openai.azure.com/'
+  api_key: 'YOUR_API_KEY'
+  api_version: '2024-05-01-preview'
+  model: 'gpt-4o'
+
+hyde:
+  include_original: true
\ No newline at end of file

diff --git a/metagpt/config2.py b/metagpt/config2.py
index 58a99c920..6ddb4b207 100644
--- a/metagpt/config2.py
+++ b/metagpt/config2.py
@@ -15,6 +15,7 @@
 from metagpt.configs.embedding_config import EmbeddingConfig
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.configs.mermaid_config import MermaidConfig
+from metagpt.configs.query_analysis_config import HydeConfig
 from metagpt.configs.redis_config import RedisConfig
 from metagpt.configs.s3_config import S3Config
 from metagpt.configs.search_config import SearchConfig
@@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel):
     # RAG Embedding
     embedding: EmbeddingConfig = EmbeddingConfig()

+    # RAG Analysis
+    hyde: HydeConfig = HydeConfig()
+
     # Global Proxy.
Will be used if llm.proxy is not set proxy: str = "" diff --git a/metagpt/configs/query_analysis_config.py b/metagpt/configs/query_analysis_config.py new file mode 100644 index 000000000..e8f9dc562 --- /dev/null +++ b/metagpt/configs/query_analysis_config.py @@ -0,0 +1,5 @@ +from metagpt.utils.yaml_model import YamlModel + + +class HydeConfig(YamlModel): + include_original: bool = True diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index c237dcf69..f4bdef1be 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -105,7 +105,6 @@ def from_docs( transformations = transformations or cls._default_transformations() nodes = run_transformations(documents, transformations=transformations) - return cls._from_nodes( nodes=nodes, transformations=transformations, diff --git a/metagpt/rag/query_analysis/HyDE.py b/metagpt/rag/query_analysis/HyDE.py new file mode 100644 index 000000000..89ee2ec8f --- /dev/null +++ b/metagpt/rag/query_analysis/HyDE.py @@ -0,0 +1,43 @@ +from typing import Dict + +from llama_index.core.indices.query.query_transform import HyDEQueryTransform +from llama_index.core.llms import LLM +from llama_index.core.schema import QueryBundle + +from metagpt.config2 import config +from metagpt.logs import logger +from metagpt.rag.factories import get_rag_llm +from metagpt.rag.factories.base import ConfigBasedFactory + + +class HyDEQuery(HyDEQueryTransform): + def _run(self, query_bundle: QueryBundle, metadata: Dict) -> QueryBundle: + logger.info(f"{'#' * 5} running HyDEQuery... {'#' * 5}") + query_str = query_bundle.query_str + + hypothetical_doc = self._llm.predict(self._hyde_prompt, context_str=query_str) + + embedding_strs = [hypothetical_doc] + + if self._include_original: + embedding_strs.extend(query_bundle.embedding_strs) + logger.info(f" Hypothetical doc:{embedding_strs} ") + + return QueryBundle( + query_str=query_str, + custom_embedding_strs=embedding_strs, + ) + + +class HyDEQueryTransformFactory(ConfigBasedFactory): + """Factory for creating HyDEQueryTransform instances with LLM configuration.""" + + llm: LLM = None + hyde_config: dict = None + + def __init__(self): + self.hyde_config = config.hyde + self.llm = get_rag_llm() + + def create_hyde_query_transform(self) -> HyDEQuery: + return HyDEQuery(include_original=self.hyde_config.include_original, llm=self.llm) diff --git a/metagpt/rag/query_analysis/__init__.py b/metagpt/rag/query_analysis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 618880a22..7a4b670e4 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -189,6 +189,10 @@ class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig): _no_embedding: bool = PrivateAttr(default=True) +class HydeConfig(BaseRetrieverConfig): + """Config for HyDe""" + + class ObjectNodeMetadata(BaseModel): """Metadata of ObjectNode.""" From 2cc6d60db0eb899d51ad57643109220b73bd06f7 Mon Sep 17 00:00:00 2001 From: liaojianxing Date: Thu, 25 Jul 2024 11:35:55 +0800 Subject: [PATCH 02/12] fix config2.yaml --- config/config2.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/config/config2.yaml b/config/config2.yaml index 0895da664..1ab8e5181 100644 --- a/config/config2.yaml +++ b/config/config2.yaml @@ -1,12 +1,6 @@ # Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml # Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py # Config Docs: 
https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
-llm:
-  api_type: 'azure'
-  base_url: 'https://deepwisdomai03.openai.azure.com/'
-  api_key: 'YOUR_API_KEY'
-  api_version: '2024-05-01-preview'
-  model: 'gpt-4o'

 hyde:
   include_original: true
\ No newline at end of file

From b4fa4685d6cfea7a9e10e911568615b511b01929 Mon Sep 17 00:00:00 2001
From: liaojianxing
Date: Thu, 25 Jul 2024 19:30:35 +0800
Subject: [PATCH 03/12] Add the HyDE example to the rag_pipeline

---
 config/config2.example.yaml |  4 ++--
 examples/rag_pipeline.py    | 19 +++++++++++++++++++
 metagpt/rag/schema.py       |  4 ----
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/config/config2.example.yaml b/config/config2.example.yaml
index 3b277f573..69f902aa2 100644
--- a/config/config2.example.yaml
+++ b/config/config2.example.yaml
@@ -20,9 +20,9 @@ embedding:
   embed_batch_size: 100
   dimensions: # output dimension of embedding model

-
+# RAG Analysis
 hyde:
-  include_original: true
+  include_original: true # whether to include the original question in the rewritten query

 repair_llm_output: true # when the output is not a valid json, try to repair it

diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py
index 5b716ce03..be56c32bb 100644
--- a/examples/rag_pipeline.py
+++ b/examples/rag_pipeline.py
@@ -2,11 +2,13 @@

 import asyncio

+from llama_index.core.query_engine import TransformQueryEngine
 from pydantic import BaseModel

 from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
 from metagpt.logs import logger
 from metagpt.rag.engines import SimpleEngine
+from metagpt.rag.query_analysis.HyDE import HyDEQueryTransformFactory
 from metagpt.rag.schema import (
     ChromaIndexConfig,
     ChromaRetrieverConfig,
@@ -212,6 +214,22 @@ async def init_and_query_es(self):
         answer = await engine.aquery(TRAVEL_QUESTION)
         self._print_query_result(answer)

+    async def use_HyDe(self):
+        """This example shows how to use HyDE: HyDE enhances search results by generating a hypothetical
+        document (a virtual article); for more details, please refer to the paper: http://arxiv.org/abs/2212.10496
+        Query Result:
+        Bob likes traveling.
+        """
+        self._print_title("Use HyDE to analyze the query")
+        # 1. save docs
+        engine = SimpleEngine.from_docs(input_files=[TRAVEL_DOC_PATH])
+        # 2. create HyDE query engine
+        hyde_query_transform = HyDEQueryTransformFactory().create_hyde_query_transform()
+        hyde_query_engine = TransformQueryEngine(engine, hyde_query_transform)
+        # 3. query
+        answer = await hyde_query_engine.aquery(TRAVEL_QUESTION)
+        self._print_query_result(answer)
+
     @staticmethod
     def _print_title(title):
         logger.info(f"{'#'*30} {title} {'#'*30}")
@@ -254,6 +272,7 @@ async def main():
     await e.init_objects()
     await e.init_and_query_chromadb()
     await e.init_and_query_es()
+    await e.use_HyDe()

 if __name__ == "__main__":

diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py
index 7a4b670e4..618880a22 100644
--- a/metagpt/rag/schema.py
+++ b/metagpt/rag/schema.py
@@ -189,10 +189,6 @@ class ElasticsearchKeywordIndexConfig(ElasticsearchIndexConfig):
     _no_embedding: bool = PrivateAttr(default=True)

-class HydeConfig(BaseRetrieverConfig):
-    """Config for HyDe"""
-
 class ObjectNodeMetadata(BaseModel):
     """Metadata of ObjectNode."""

From 2819b2e91c9b967bd4610548f0d9e05f49d66282 Mon Sep 17 00:00:00 2001
From: liaojianxing
Date: Fri, 26 Jul 2024 11:39:51 +0800
Subject: [PATCH 04/12] Move HyDEQueryTransformFactory out of HyDE.py and fix
 the embedding test mocks

Submission notes: simulation functions (mock_openai_embedding,
mock_azure_embedding, mock_gemini_embedding, and mock_ollama_embedding) have
been added at module level. Reason for adding: the previous code passed the
static methods into a parametrized test, but a static method referenced this
way is not a callable object, resulting in a TypeError.

---
 config/config2.yaml                           |  7 ++--
 examples/rag_pipeline.py                      |  2 +-
 .../factories/HyDEQueryTransformFactory.py    | 20 +++++++++++
 metagpt/rag/query_analysis/HyDE.py            | 18 ----------
 tests/metagpt/rag/factories/test_embedding.py | 36 +++++++++----------
 5 files changed, 44 insertions(+), 39 deletions(-)
 create mode 100644 metagpt/rag/factories/HyDEQueryTransformFactory.py

diff --git a/config/config2.yaml b/config/config2.yaml
index 1ab8e5181..5f875a7bb 100644
--- a/config/config2.yaml
+++ b/config/config2.yaml
@@ -1,6 +1,9 @@
 # Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
 # Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
 # Config Docs: https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html
+llm:
+  api_type: "openai" # or azure / ollama / groq etc.
+ model: "gpt-4-turbo" # or gpt-3.5-turbo + base_url: "https://api.openai.com/v1" # or forward url / other llm url + api_key: "YOUR_API_KEY" -hyde: - include_original: true \ No newline at end of file diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index be56c32bb..4fba52a66 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -8,7 +8,7 @@ from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine -from metagpt.rag.query_analysis.HyDE import HyDEQueryTransformFactory +from metagpt.rag.factories.HyDEQueryTransformFactory import HyDEQueryTransformFactory from metagpt.rag.schema import ( ChromaIndexConfig, ChromaRetrieverConfig, diff --git a/metagpt/rag/factories/HyDEQueryTransformFactory.py b/metagpt/rag/factories/HyDEQueryTransformFactory.py new file mode 100644 index 000000000..c0946f8ee --- /dev/null +++ b/metagpt/rag/factories/HyDEQueryTransformFactory.py @@ -0,0 +1,20 @@ +from llama_index.core.llms import LLM + +from metagpt.config2 import config +from metagpt.rag.factories import get_rag_llm +from metagpt.rag.factories.base import ConfigBasedFactory +from metagpt.rag.query_analysis.HyDE import HyDEQuery + + +class HyDEQueryTransformFactory(ConfigBasedFactory): + """Factory for creating HyDEQueryTransform instances with LLM configuration.""" + + llm: LLM = None + hyde_config: dict = None + + def __init__(self): + self.hyde_config = config.hyde + self.llm = get_rag_llm() + + def create_hyde_query_transform(self) -> HyDEQuery: + return HyDEQuery(include_original=self.hyde_config.include_original, llm=self.llm) diff --git a/metagpt/rag/query_analysis/HyDE.py b/metagpt/rag/query_analysis/HyDE.py index 89ee2ec8f..c1bcedcae 100644 --- a/metagpt/rag/query_analysis/HyDE.py +++ b/metagpt/rag/query_analysis/HyDE.py @@ -1,13 +1,9 @@ from typing import Dict from llama_index.core.indices.query.query_transform import HyDEQueryTransform -from llama_index.core.llms import LLM from llama_index.core.schema import QueryBundle -from metagpt.config2 import config from metagpt.logs import logger -from metagpt.rag.factories import get_rag_llm -from metagpt.rag.factories.base import ConfigBasedFactory class HyDEQuery(HyDEQueryTransform): @@ -27,17 +23,3 @@ def _run(self, query_bundle: QueryBundle, metadata: Dict) -> QueryBundle: query_str=query_str, custom_embedding_strs=embedding_strs, ) - - -class HyDEQueryTransformFactory(ConfigBasedFactory): - """Factory for creating HyDEQueryTransform instances with LLM configuration.""" - - llm: LLM = None - hyde_config: dict = None - - def __init__(self): - self.hyde_config = config.hyde - self.llm = get_rag_llm() - - def create_hyde_query_transform(self) -> HyDEQuery: - return HyDEQuery(include_original=self.hyde_config.include_original, llm=self.llm) diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1a9e9b2c9..3b84a8f86 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -5,6 +5,22 @@ from metagpt.rag.factories.embedding import RAGEmbeddingFactory +def mock_azure_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") + + +def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") + + +def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") + + +def mock_openai_embedding(mocker): + return 
mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") + + class TestRAGEmbeddingFactory: @pytest.fixture(autouse=True) def mock_embedding_factory(self): @@ -14,22 +30,6 @@ def mock_embedding_factory(self): def mock_config(self, mocker): return mocker.patch("metagpt.rag.factories.embedding.config") - @staticmethod - def mock_openai_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - - @staticmethod - def mock_azure_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - - @staticmethod - def mock_gemini_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - - @staticmethod - def mock_ollama_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") - @pytest.mark.parametrize( ("mock_func", "embedding_type"), [ @@ -53,7 +53,7 @@ def test_get_rag_embedding(self, mock_func, embedding_type, mocker): def test_get_rag_embedding_default(self, mocker, mock_config): # Mock - mock_openai_embedding = self.mock_openai_embedding(mocker) + mock_openai_emb = mock_openai_embedding(mocker) mock_config.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI @@ -62,7 +62,7 @@ def test_get_rag_embedding_default(self, mocker, mock_config): self.embedding_factory.get_rag_embedding() # Assert - mock_openai_embedding.assert_called_once() + mock_openai_emb.assert_called_once() @pytest.mark.parametrize( "model, embed_batch_size, expected_params", From ec400435394013d4258530ff2555d92ab00b8591 Mon Sep 17 00:00:00 2001 From: liaojianxing Date: Mon, 19 Aug 2024 15:32:37 +0800 Subject: [PATCH 05/12] Update hyde function --- config/config2.example.yaml | 22 +++-- examples/rag_pipeline.py | 32 +++++++- metagpt/configs/hyde_config.py | 7 ++ metagpt/configs/query_analysis_config.py | 6 ++ metagpt/configs/rag_config.py | 8 ++ metagpt/rag/engines/simple.py | 82 ++++++++++--------- metagpt/rag/factories/embedding.py | 28 +++---- .../simple_query_transformer.py | 34 ++++++++ metagpt/rag/schema.py | 8 +- tests/metagpt/rag/factories/test_embedding.py | 46 +++++------ 10 files changed, 183 insertions(+), 90 deletions(-) create mode 100644 metagpt/configs/hyde_config.py create mode 100644 metagpt/configs/query_analysis_config.py create mode 100644 metagpt/configs/rag_config.py create mode 100644 metagpt/rag/query_analysis/simple_query_transformer.py diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 0fe11df4e..34b4f2b75 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -9,16 +9,22 @@ llm: pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's +rag: # RAG Embedding. # For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used. -embedding: - api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options. - base_url: "" - api_key: "" - model: "" - api_version: "" - embed_batch_size: 100 - dimensions: # output dimension of embedding model + embedding: + api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options. 
+    base_url: ""
+    api_key: ""
+    model: ""
+    api_version: ""
+    embed_batch_size: 100
+    dimensions: # output dimension of embedding model
+  # RAG Query Analysis
+  query_analysis:
+    hyde:
+      include_original: true # whether to include the original question in the rewritten query

 repair_llm_output: true # when the output is not a valid json, try to repair it

diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py
index 5b716ce03..81e75c2ea 100644
--- a/examples/rag_pipeline.py
+++ b/examples/rag_pipeline.py
@@ -2,6 +2,7 @@

 import asyncio

+from llama_index.core.query_engine import TransformQueryEngine
 from pydantic import BaseModel

 from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
@@ -17,12 +18,16 @@
     LLMRankerConfig,
 )
 from metagpt.utils.exceptions import handle_exception
+from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer
+from metagpt.rag.query_analysis.HyDE import HyDEQuery
+
+

 LLM_TIP = "If you not sure, just answer I don't know."

 DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
 QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}"
-
+QUESTION2 = "What are the key factors in maintaining high productivity?"
 TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
 TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}"

@@ -212,9 +217,31 @@ async def init_and_query_es(self):
         answer = await engine.aquery(TRAVEL_QUESTION)
         self._print_query_result(answer)

+    async def use_hyde(self):
+        """This example shows how to use HyDE: HyDE enhances search results by generating a hypothetical
+        document (a virtual article) for the query; for more details, please refer to the paper:
+        http://arxiv.org/abs/2212.10496
+        """
+
+        self._print_title("Use HyDE to analyze the query")
+        # 1. save docs
+        engine = SimpleEngine.from_docs(input_files=[DOC_PATH])
+        # 2. Initialize HyDE query analysis method
+        hyde_query = HyDEQuery()
+        # 3. Add HyDE to the engine
+        hyde_query_engine = SimpleQueryTransformer(engine, hyde_query)
+        answer = await hyde_query_engine.aquery(QUESTION2)
+        self._print_query_result(answer)
+
+        self._print_title("Query without HyDE")
+        engine = SimpleEngine.from_docs(input_files=[DOC_PATH])
+        answer = await engine.aquery(QUESTION2)
+        self._print_query_result(answer)
+
     @staticmethod
     def _print_title(title):
-        logger.info(f"{'#'*30} {title} {'#'*30}")
+        logger.info(f"{'#' * 30} {title} {'#' * 30}")

     @staticmethod
     def _print_retrieve_result(result):
@@ -254,6 +281,7 @@ async def main():
     await e.init_objects()
     await e.init_and_query_chromadb()
     await e.init_and_query_es()
+    await e.use_hyde()


 if __name__ == "__main__":
diff --git a/metagpt/configs/hyde_config.py b/metagpt/configs/hyde_config.py
new file mode 100644
index 000000000..6cf547630
--- /dev/null
+++ b/metagpt/configs/hyde_config.py
@@ -0,0 +1,7 @@
+from metagpt.utils.yaml_model import YamlModel
+
+
+class HyDEConfig(YamlModel):
+    include_original: bool = True
+
+
diff --git a/metagpt/configs/query_analysis_config.py b/metagpt/configs/query_analysis_config.py
new file mode 100644
index 000000000..292c6b8e8
--- /dev/null
+++ b/metagpt/configs/query_analysis_config.py
@@ -0,0 +1,6 @@
+from metagpt.utils.yaml_model import YamlModel
+from metagpt.configs.hyde_config import HyDEConfig
+
+
+class QueryAnalysisConfig(YamlModel):
+    hyde: HyDEConfig = HyDEConfig()
diff --git a/metagpt/configs/rag_config.py b/metagpt/configs/rag_config.py
new file mode 100644
index 000000000..f469ff919
--- /dev/null
+++ b/metagpt/configs/rag_config.py
@@ -0,0 +1,8 @@
+from metagpt.utils.yaml_model import YamlModel
+from metagpt.configs.embedding_config import EmbeddingConfig
+from metagpt.configs.query_analysis_config import QueryAnalysisConfig
+
+
+class RAGConfig(YamlModel):
+    embedding: EmbeddingConfig = EmbeddingConfig()
+    query_analysis: QueryAnalysisConfig = QueryAnalysisConfig()
diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py
index c237dcf69..10b8e1efa 100644
--- a/metagpt/rag/engines/simple.py
+++ b/metagpt/rag/engines/simple.py
@@ -27,6 +27,8 @@
     QueryType,
     TransformComponent,
 )
+from llama_index.core.query_engine import TransformQueryEngine
+from llama_index.legacy.indices.query.query_transform.base import BaseQueryTransform

 from metagpt.rag.factories import (
     get_index,
@@ -58,12 +60,12 @@ class SimpleEngine(RetrieverQueryEngine):
     """

     def __init__(
-        self,
-        retriever: BaseRetriever,
-        response_synthesizer: Optional[BaseSynthesizer] = None,
-        node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
-        callback_manager: Optional[CallbackManager] = None,
-        transformations: Optional[list[TransformComponent]] = None,
+            self,
+            retriever: BaseRetriever,
+            response_synthesizer: Optional[BaseSynthesizer] = None,
+            node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
+            callback_manager: Optional[CallbackManager] = None,
+            transformations: Optional[list[TransformComponent]] = None,
     ) -> None:
         super().__init__(
             retriever=retriever,
@@ -75,14 +77,14 @@ def __init__(

     @classmethod
     def from_docs(
-        cls,
-        input_dir: str = None,
-        input_files: list[str] = None,
-        transformations: Optional[list[TransformComponent]] = None,
-        embed_model: BaseEmbedding = None,
-        llm: LLM = None,
-        retriever_configs: list[BaseRetrieverConfig] = None,
-        ranker_configs: list[BaseRankerConfig] = None,
+            cls,
+            input_dir: str = None,
+            input_files: list[str] = None,
+            transformations: Optional[list[TransformComponent]] = None,
+            embed_model:
BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """From docs. @@ -117,13 +119,13 @@ def from_docs( @classmethod def from_objs( - cls, - objs: Optional[list[RAGObject]] = None, - transformations: Optional[list[TransformComponent]] = None, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + objs: Optional[list[RAGObject]] = None, + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """From objs. @@ -154,12 +156,12 @@ def from_objs( @classmethod def from_index( - cls, - index_config: BaseIndexConfig, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + index_config: BaseIndexConfig, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """Load from previously maintained index by self.persist(), index_config contains persis_path.""" index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config])) @@ -209,13 +211,13 @@ def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): @classmethod def _from_nodes( - cls, - nodes: list[BaseNode], - transformations: Optional[list[TransformComponent]] = None, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + nodes: list[BaseNode], + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": embed_model = cls._resolve_embed_model(embed_model, retriever_configs) llm = llm or get_rag_llm() @@ -232,11 +234,11 @@ def _from_nodes( @classmethod def _from_index( - cls, - index: BaseIndex, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + index: BaseIndex, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": llm = llm or get_rag_llm() @@ -301,3 +303,5 @@ def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = @staticmethod def _default_transformations(): return [SentenceSplitter()] + + diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 3613fd228..45a300cbc 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -40,8 +40,8 @@ def _resolve_embedding_type(self) -> EmbeddingType | LLMType: If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. Raise TypeError if embedding type not found. 
""" - if config.embedding.api_type: - return config.embedding.api_type + if config.rag.embedding.api_type: + return config.rag.embedding.api_type if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: return config.llm.api_type @@ -50,8 +50,8 @@ def _resolve_embedding_type(self) -> EmbeddingType | LLMType: def _create_openai(self) -> OpenAIEmbedding: params = dict( - api_key=config.embedding.api_key or config.llm.api_key, - api_base=config.embedding.base_url or config.llm.base_url, + api_key=config.rag.embedding.api_key or config.llm.api_key, + api_base=config.rag.embedding.base_url or config.llm.base_url, ) self._try_set_model_and_batch_size(params) @@ -60,9 +60,9 @@ def _create_openai(self) -> OpenAIEmbedding: def _create_azure(self) -> AzureOpenAIEmbedding: params = dict( - api_key=config.embedding.api_key or config.llm.api_key, - azure_endpoint=config.embedding.base_url or config.llm.base_url, - api_version=config.embedding.api_version or config.llm.api_version, + api_key=config.rag.embedding.api_key or config.llm.api_key, + azure_endpoint=config.rag.embedding.base_url or config.llm.base_url, + api_version=config.rag.embedding.api_version or config.llm.api_version, ) self._try_set_model_and_batch_size(params) @@ -71,8 +71,8 @@ def _create_azure(self) -> AzureOpenAIEmbedding: def _create_gemini(self) -> GeminiEmbedding: params = dict( - api_key=config.embedding.api_key, - api_base=config.embedding.base_url, + api_key=config.rag.embedding.api_key, + api_base=config.rag.embedding.base_url, ) self._try_set_model_and_batch_size(params) @@ -81,7 +81,7 @@ def _create_gemini(self) -> GeminiEmbedding: def _create_ollama(self) -> OllamaEmbedding: params = dict( - base_url=config.embedding.base_url, + base_url=config.rag.embedding.base_url, ) self._try_set_model_and_batch_size(params) @@ -90,11 +90,11 @@ def _create_ollama(self) -> OllamaEmbedding: def _try_set_model_and_batch_size(self, params: dict): """Set the model_name and embed_batch_size only when they are specified.""" - if config.embedding.model: - params["model_name"] = config.embedding.model + if config.rag.embedding.model: + params["model_name"] = config.rag.embedding.model - if config.embedding.embed_batch_size: - params["embed_batch_size"] = config.embedding.embed_batch_size + if config.rag.embedding.embed_batch_size: + params["embed_batch_size"] = config.rag.embedding.embed_batch_size def _raise_for_key(self, key: Any): raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") diff --git a/metagpt/rag/query_analysis/simple_query_transformer.py b/metagpt/rag/query_analysis/simple_query_transformer.py new file mode 100644 index 000000000..4eaacdb46 --- /dev/null +++ b/metagpt/rag/query_analysis/simple_query_transformer.py @@ -0,0 +1,34 @@ +from llama_index.core.base.response.schema import RESPONSE_TYPE +from llama_index.core.callbacks import CallbackManager +from llama_index.core.indices.query.query_transform.base import BaseQueryTransform +from llama_index.core.query_engine import TransformQueryEngine, BaseQueryEngine +from typing import List, Optional, Sequence + +from metagpt.rag.engines.simple import SimpleEngine + + +class SimpleQueryTransformer(TransformQueryEngine): + """Simple query engine + + Extends the TransformQueryEngine to handle simpler queries using a basic query engine. + + Args: + query_engine (BaseQueryEngine): A simple query engine object. + query_transform (BaseQueryTransform): A query transform object. 
+ transform_metadata (Optional[dict]): metadata to pass to the query transform. + callback_manager (Optional[CallbackManager]): A callback manager. + """ + + def __init__( + self, + query_engine: SimpleEngine, + query_transform: BaseQueryTransform, + transform_metadata: Optional[dict] = None, + callback_manager: Optional[CallbackManager] = None, + ) -> None: + super().__init__( + query_engine=query_engine, + query_transform=query_transform, + transform_metadata=transform_metadata, + callback_manager=callback_manager + ) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 618880a22..ba1d67941 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -45,12 +45,12 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: - self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( - config.embedding.api_type, 1536 + self.dimensions = config.rag.embedding.dimensions or self._embedding_type_to_dimensions.get( + config.rag.embedding.api_type, 1536 ) - if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: + if not config.rag.embedding.dimensions and config.rag.embedding.api_type not in self._embedding_type_to_dimensions: logger.warning( - f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" + f"You didn't set dimensions in config when using {config.rag.embedding.api_type}, default to 1536" ) return self diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1a9e9b2c9..332092a9b 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -5,6 +5,22 @@ from metagpt.rag.factories.embedding import RAGEmbeddingFactory +def mock_azure_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") + + +def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") + + +def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") + + +def mock_openai_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") + + class TestRAGEmbeddingFactory: @pytest.fixture(autouse=True) def mock_embedding_factory(self): @@ -14,22 +30,6 @@ def mock_embedding_factory(self): def mock_config(self, mocker): return mocker.patch("metagpt.rag.factories.embedding.config") - @staticmethod - def mock_openai_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - - @staticmethod - def mock_azure_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - - @staticmethod - def mock_gemini_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - - @staticmethod - def mock_ollama_embedding(mocker): - return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") - @pytest.mark.parametrize( ("mock_func", "embedding_type"), [ @@ -53,16 +53,16 @@ def test_get_rag_embedding(self, mock_func, embedding_type, mocker): def test_get_rag_embedding_default(self, mocker, mock_config): # Mock - mock_openai_embedding = self.mock_openai_embedding(mocker) + mock_openai_emb = mock_openai_embedding(mocker) - mock_config.embedding.api_type = None + mock_config.rag.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI # Exec 
self.embedding_factory.get_rag_embedding() # Assert - mock_openai_embedding.assert_called_once() + mock_openai_emb.assert_called_once() @pytest.mark.parametrize( "model, embed_batch_size, expected_params", @@ -70,8 +70,8 @@ def test_get_rag_embedding_default(self, mocker, mock_config): ) def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): # Mock - mock_config.embedding.model = model - mock_config.embedding.embed_batch_size = embed_batch_size + mock_config.rag.embedding.model = model + mock_config.rag.embedding.embed_batch_size = embed_batch_size # Setup test_params = {} @@ -84,7 +84,7 @@ def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size def test_resolve_embedding_type(self, mock_config): # Mock - mock_config.embedding.api_type = EmbeddingType.OPENAI + mock_config.rag.embedding.api_type = EmbeddingType.OPENAI # Exec embedding_type = self.embedding_factory._resolve_embedding_type() @@ -94,7 +94,7 @@ def test_resolve_embedding_type(self, mock_config): def test_resolve_embedding_type_exception(self, mock_config): # Mock - mock_config.embedding.api_type = None + mock_config.rag.embedding.api_type = None mock_config.llm.api_type = LLMType.GEMINI # Assert From b2458d8ed5a0e65de85d5dd24d100f8cb76c397e Mon Sep 17 00:00:00 2001 From: liaojianxing Date: Mon, 19 Aug 2024 16:24:09 +0800 Subject: [PATCH 06/12] add rag config --- examples/rag_pipeline.py | 4 +- metagpt/config2.py | 6 +- metagpt/rag/query_analysis/__init__.py | 9 +++ metagpt/rag/query_analysis/hyde.py | 63 +++++++++++++++++++ .../simple_query_transformer.py | 1 - 5 files changed, 76 insertions(+), 7 deletions(-) create mode 100644 metagpt/rag/query_analysis/__init__.py create mode 100644 metagpt/rag/query_analysis/hyde.py diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 81e75c2ea..5a1942eee 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -19,9 +19,7 @@ ) from metagpt.utils.exceptions import handle_exception from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer -from metagpt.rag.query_analysis.HyDE import HyDEQuery - - +from metagpt.rag.query_analysis.hyde import HyDEQuery LLM_TIP = "If you not sure, just answer I don't know." diff --git a/metagpt/config2.py b/metagpt/config2.py index 58a99c920..9c6900dd9 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -12,7 +12,6 @@ from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig -from metagpt.configs.embedding_config import EmbeddingConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -21,6 +20,7 @@ from metagpt.configs.workspace_config import WorkspaceConfig from metagpt.const import CONFIG_ROOT, METAGPT_ROOT from metagpt.utils.yaml_model import YamlModel +from MetaGPT.metagpt.configs.rag_config import RAGConfig class CLIParams(BaseModel): @@ -48,8 +48,8 @@ class Config(CLIParams, YamlModel): # Key Parameters llm: LLMConfig - # RAG Embedding - embedding: EmbeddingConfig = EmbeddingConfig() + # RAG + rag: RAGConfig = RAGConfig() # Global Proxy. 
Will be used if llm.proxy is not set
     proxy: str = ""

diff --git a/metagpt/rag/query_analysis/__init__.py b/metagpt/rag/query_analysis/__init__.py
new file mode 100644
index 000000000..caa35405f
--- /dev/null
+++ b/metagpt/rag/query_analysis/__init__.py
@@ -0,0 +1,9 @@
+"""RAG factories"""
+
+from metagpt.rag.factories.retriever import get_retriever
+from metagpt.rag.factories.ranker import get_rankers
+from metagpt.rag.factories.embedding import get_rag_embedding
+from metagpt.rag.factories.index import get_index
+from metagpt.rag.factories.llm import get_rag_llm
+
+__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index", "get_rag_llm"]
diff --git a/metagpt/rag/query_analysis/hyde.py b/metagpt/rag/query_analysis/hyde.py
new file mode 100644
index 000000000..359cf0439
--- /dev/null
+++ b/metagpt/rag/query_analysis/hyde.py
@@ -0,0 +1,63 @@
+from typing import Any, Dict, Optional
+from llama_index.core.llms import LLM
+from llama_index.core.indices.query.query_transform import HyDEQueryTransform
+from llama_index.core.prompts.default_prompts import DEFAULT_HYDE_PROMPT
+from llama_index.core.schema import QueryBundle
+from llama_index.core.prompts import BasePromptTemplate
+from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
+from metagpt.logs import logger
+from metagpt.rag.factories import get_rag_llm
+from metagpt.config2 import config
+
+
+class HyDEQuery(HyDEQueryTransform):
+    def __init__(
+        self,
+        llm: Optional[LLMPredictorType] = None,
+        hyde_prompt: Optional[BasePromptTemplate] = None,
+        include_original: Optional[bool] = None,
+    ) -> None:
+        """Initialize the HyDEQueryTransform class with optional parameters.
+
+        Args:
+            llm (Optional[LLMPredictorType]): An LLM (Large Language Model) used for generating
+                hypothetical documents. If not provided, defaults to rag_llm.
+            hyde_prompt (Optional[BasePromptTemplate]): Custom prompt template for HyDE.
+                If not provided, the default prompt is used.
+            include_original (Optional[bool]): Flag to include the original query string in the output.
+                If not provided, the setting is fetched from config or defaults to True.
+        """
+        # Set LLM, using a default if not provided
+        self._llm = llm or get_rag_llm()
+        # Set the prompt template, using a default if not provided
+        self._hyde_prompt = hyde_prompt or DEFAULT_HYDE_PROMPT
+        # Set the flag to include the original query, fetching from config if not provided
+        if include_original is not None:
+            self._include_original = include_original
+        else:
+            try:
+                self._include_original = config.rag.query_analysis.hyde.include_original
+            except AttributeError:
+                self._include_original = True
+
+    def _run(self, query_bundle: QueryBundle, metadata: Dict[str, Any]) -> QueryBundle:
+        """Process the query bundle to include hypothetical document embeddings.
+
+        Args:
+            query_bundle (QueryBundle): The original query bundle containing query information.
+            metadata (Dict[str, Any]): Additional metadata for processing.
+
+        Returns:
+            QueryBundle: Updated query bundle with additional hypothetical document embeddings.
+        """
+        # Log the operation
+        logger.info(f"{'#' * 5} Running HyDEQuery... 
{'#' * 5}") + # Generate the hypothetical document using the LLM and prompt + query_str = query_bundle.query_str + hypothetical_doc = self._llm.predict(self._hyde_prompt, context_str=query_str) + embedding_strs = [hypothetical_doc] + # Include the original query strings if specified + if self._include_original: + embedding_strs.extend(query_bundle.embedding_strs) + + return QueryBundle(query_str=query_str, custom_embedding_strs=embedding_strs) \ No newline at end of file diff --git a/metagpt/rag/query_analysis/simple_query_transformer.py b/metagpt/rag/query_analysis/simple_query_transformer.py index 4eaacdb46..322f9fee2 100644 --- a/metagpt/rag/query_analysis/simple_query_transformer.py +++ b/metagpt/rag/query_analysis/simple_query_transformer.py @@ -1,4 +1,3 @@ -from llama_index.core.base.response.schema import RESPONSE_TYPE from llama_index.core.callbacks import CallbackManager from llama_index.core.indices.query.query_transform.base import BaseQueryTransform from llama_index.core.query_engine import TransformQueryEngine, BaseQueryEngine From 008fe374eaf349e08ac93240c7fdf1b997b4290f Mon Sep 17 00:00:00 2001 From: liaojianxing Date: Mon, 19 Aug 2024 16:41:33 +0800 Subject: [PATCH 07/12] Formatting Files --- examples/rag_pipeline.py | 5 +- metagpt/config2.py | 2 +- metagpt/configs/hyde_config.py | 2 - metagpt/configs/query_analysis_config.py | 2 +- metagpt/configs/rag_config.py | 2 +- metagpt/rag/engines/simple.py | 82 +++++++++---------- metagpt/rag/query_analysis/__init__.py | 9 -- metagpt/rag/query_analysis/hyde.py | 9 +- .../simple_query_transformer.py | 17 ++-- metagpt/rag/schema.py | 5 +- 10 files changed, 62 insertions(+), 73 deletions(-) delete mode 100644 metagpt/rag/query_analysis/__init__.py diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 5a1942eee..0703f1656 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -2,12 +2,13 @@ import asyncio -from llama_index.core.query_engine import TransformQueryEngine from pydantic import BaseModel from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine +from metagpt.rag.query_analysis.hyde import HyDEQuery +from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer from metagpt.rag.schema import ( ChromaIndexConfig, ChromaRetrieverConfig, @@ -18,8 +19,6 @@ LLMRankerConfig, ) from metagpt.utils.exceptions import handle_exception -from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer -from metagpt.rag.query_analysis.hyde import HyDEQuery LLM_TIP = "If you not sure, just answer I don't know." 
diff --git a/metagpt/config2.py b/metagpt/config2.py index 9c6900dd9..b43bc52a8 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -19,8 +19,8 @@ from metagpt.configs.search_config import SearchConfig from metagpt.configs.workspace_config import WorkspaceConfig from metagpt.const import CONFIG_ROOT, METAGPT_ROOT -from metagpt.utils.yaml_model import YamlModel from MetaGPT.metagpt.configs.rag_config import RAGConfig +from metagpt.utils.yaml_model import YamlModel class CLIParams(BaseModel): diff --git a/metagpt/configs/hyde_config.py b/metagpt/configs/hyde_config.py index 6cf547630..f951d3a71 100644 --- a/metagpt/configs/hyde_config.py +++ b/metagpt/configs/hyde_config.py @@ -3,5 +3,3 @@ class HyDEConfig(YamlModel): include_original: bool = True - - diff --git a/metagpt/configs/query_analysis_config.py b/metagpt/configs/query_analysis_config.py index 292c6b8e8..d0e5db738 100644 --- a/metagpt/configs/query_analysis_config.py +++ b/metagpt/configs/query_analysis_config.py @@ -1,5 +1,5 @@ -from metagpt.utils.yaml_model import YamlModel from metagpt.configs.hyde_config import HyDEConfig +from metagpt.utils.yaml_model import YamlModel class QueryAnalysisConfig(YamlModel): diff --git a/metagpt/configs/rag_config.py b/metagpt/configs/rag_config.py index f469ff919..af0294a45 100644 --- a/metagpt/configs/rag_config.py +++ b/metagpt/configs/rag_config.py @@ -1,6 +1,6 @@ -from metagpt.utils.yaml_model import YamlModel from metagpt.configs.embedding_config import EmbeddingConfig from metagpt.configs.query_analysis_config import QueryAnalysisConfig +from metagpt.utils.yaml_model import YamlModel class RAGConfig(YamlModel): diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 10b8e1efa..c237dcf69 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -27,8 +27,6 @@ QueryType, TransformComponent, ) -from llama_index.core.query_engine import TransformQueryEngine -from llama_index.legacy.indices.query.query_transform.base import BaseQueryTransform from metagpt.rag.factories import ( get_index, @@ -60,12 +58,12 @@ class SimpleEngine(RetrieverQueryEngine): """ def __init__( - self, - retriever: BaseRetriever, - response_synthesizer: Optional[BaseSynthesizer] = None, - node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, - callback_manager: Optional[CallbackManager] = None, - transformations: Optional[list[TransformComponent]] = None, + self, + retriever: BaseRetriever, + response_synthesizer: Optional[BaseSynthesizer] = None, + node_postprocessors: Optional[list[BaseNodePostprocessor]] = None, + callback_manager: Optional[CallbackManager] = None, + transformations: Optional[list[TransformComponent]] = None, ) -> None: super().__init__( retriever=retriever, @@ -77,14 +75,14 @@ def __init__( @classmethod def from_docs( - cls, - input_dir: str = None, - input_files: list[str] = None, - transformations: Optional[list[TransformComponent]] = None, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + input_dir: str = None, + input_files: list[str] = None, + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """From docs. 
@@ -119,13 +117,13 @@ def from_docs( @classmethod def from_objs( - cls, - objs: Optional[list[RAGObject]] = None, - transformations: Optional[list[TransformComponent]] = None, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + objs: Optional[list[RAGObject]] = None, + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """From objs. @@ -156,12 +154,12 @@ def from_objs( @classmethod def from_index( - cls, - index_config: BaseIndexConfig, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + index_config: BaseIndexConfig, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": """Load from previously maintained index by self.persist(), index_config contains persis_path.""" index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config])) @@ -211,13 +209,13 @@ def persist(self, persist_dir: Union[str, os.PathLike], **kwargs): @classmethod def _from_nodes( - cls, - nodes: list[BaseNode], - transformations: Optional[list[TransformComponent]] = None, - embed_model: BaseEmbedding = None, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + nodes: list[BaseNode], + transformations: Optional[list[TransformComponent]] = None, + embed_model: BaseEmbedding = None, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": embed_model = cls._resolve_embed_model(embed_model, retriever_configs) llm = llm or get_rag_llm() @@ -234,11 +232,11 @@ def _from_nodes( @classmethod def _from_index( - cls, - index: BaseIndex, - llm: LLM = None, - retriever_configs: list[BaseRetrieverConfig] = None, - ranker_configs: list[BaseRankerConfig] = None, + cls, + index: BaseIndex, + llm: LLM = None, + retriever_configs: list[BaseRetrieverConfig] = None, + ranker_configs: list[BaseRankerConfig] = None, ) -> "SimpleEngine": llm = llm or get_rag_llm() @@ -303,5 +301,3 @@ def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = @staticmethod def _default_transformations(): return [SentenceSplitter()] - - diff --git a/metagpt/rag/query_analysis/__init__.py b/metagpt/rag/query_analysis/__init__.py deleted file mode 100644 index caa35405f..000000000 --- a/metagpt/rag/query_analysis/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""RAG factories""" - -from metagpt.rag.factories.retriever import get_retriever -from metagpt.rag.factories.ranker import get_rankers -from metagpt.rag.factories.embedding import get_rag_embedding -from metagpt.rag.factories.index import get_index -from metagpt.rag.factories.llm import get_rag_llm - -__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index", "get_rag_llm"] diff --git a/metagpt/rag/query_analysis/hyde.py b/metagpt/rag/query_analysis/hyde.py index 359cf0439..1700e3a8b 100644 --- a/metagpt/rag/query_analysis/hyde.py +++ b/metagpt/rag/query_analysis/hyde.py @@ -1,13 +1,14 @@ from typing import Any, Dict, Optional -from 
llama_index.core.llms import LLM + from llama_index.core.indices.query.query_transform import HyDEQueryTransform +from llama_index.core.prompts import BasePromptTemplate from llama_index.core.prompts.default_prompts import DEFAULT_HYDE_PROMPT from llama_index.core.schema import QueryBundle -from llama_index.core.prompts import BasePromptTemplate from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType + +from metagpt.config2 import config from metagpt.logs import logger from metagpt.rag.factories import get_rag_llm -from metagpt.config2 import config class HyDEQuery(HyDEQueryTransform): @@ -60,4 +61,4 @@ def _run(self, query_bundle: QueryBundle, metadata: Dict[str, Any]) -> QueryBund if self._include_original: embedding_strs.extend(query_bundle.embedding_strs) - return QueryBundle(query_str=query_str, custom_embedding_strs=embedding_strs) \ No newline at end of file + return QueryBundle(query_str=query_str, custom_embedding_strs=embedding_strs) diff --git a/metagpt/rag/query_analysis/simple_query_transformer.py b/metagpt/rag/query_analysis/simple_query_transformer.py index 322f9fee2..b44ac998c 100644 --- a/metagpt/rag/query_analysis/simple_query_transformer.py +++ b/metagpt/rag/query_analysis/simple_query_transformer.py @@ -1,7 +1,8 @@ +from typing import Optional + from llama_index.core.callbacks import CallbackManager from llama_index.core.indices.query.query_transform.base import BaseQueryTransform -from llama_index.core.query_engine import TransformQueryEngine, BaseQueryEngine -from typing import List, Optional, Sequence +from llama_index.core.query_engine import TransformQueryEngine from metagpt.rag.engines.simple import SimpleEngine @@ -19,15 +20,15 @@ class SimpleQueryTransformer(TransformQueryEngine): """ def __init__( - self, - query_engine: SimpleEngine, - query_transform: BaseQueryTransform, - transform_metadata: Optional[dict] = None, - callback_manager: Optional[CallbackManager] = None, + self, + query_engine: SimpleEngine, + query_transform: BaseQueryTransform, + transform_metadata: Optional[dict] = None, + callback_manager: Optional[CallbackManager] = None, ) -> None: super().__init__( query_engine=query_engine, query_transform=query_transform, transform_metadata=transform_metadata, - callback_manager=callback_manager + callback_manager=callback_manager, ) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index ba1d67941..2850aad3d 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -48,7 +48,10 @@ def check_dimensions(self): self.dimensions = config.rag.embedding.dimensions or self._embedding_type_to_dimensions.get( config.rag.embedding.api_type, 1536 ) - if not config.rag.embedding.dimensions and config.rag.embedding.api_type not in self._embedding_type_to_dimensions: + if ( + not config.rag.embedding.dimensions + and config.rag.embedding.api_type not in self._embedding_type_to_dimensions + ): logger.warning( f"You didn't set dimensions in config when using {config.rag.embedding.api_type}, default to 1536" ) From 3e53b335d5548a1da2e17e0b1a6bf882c9277737 Mon Sep 17 00:00:00 2001 From: liaojianxing Date: Tue, 20 Aug 2024 14:54:16 +0800 Subject: [PATCH 08/12] add hotpotqa.py --- config/config2.example.yaml | 2 +- examples/rag_pipeline.py | 3 - metagpt/config2.py | 4 - metagpt/rag/benchmark/hotpotqa.py | 127 ++++++++++++++++++ .../factories/HyDEQueryTransformFactory.py | 2 +- metagpt/rag/query_analysis/HyDE.py | 25 ---- 6 files changed, 129 insertions(+), 34 deletions(-) create mode 100644 metagpt/rag/benchmark/hotpotqa.py 
delete mode 100644 metagpt/rag/query_analysis/HyDE.py

diff --git a/config/config2.example.yaml b/config/config2.example.yaml
index 15f3d028f..0d7dd91af 100644
--- a/config/config2.example.yaml
+++ b/config/config2.example.yaml
@@ -23,7 +23,7 @@ rag:
   # RAG Query Analysis
   query_analysis:
     hyde:
-      include_original: true # whether to include the original question in the rewritten query
+      include_original: True # whether to include the original question in the rewritten query

diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py
index 4f3338522..558e0570e 100644
--- a/examples/rag_pipeline.py
+++ b/examples/rag_pipeline.py
@@ -2,7 +2,6 @@

 import asyncio

-from llama_index.core.query_engine import TransformQueryEngine
 from pydantic import BaseModel

 from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
@@ -233,7 +232,6 @@ async def use_hyde(self):
         self._print_query_result(answer)

         self._print_title("Query without HyDE")
-        engine = SimpleEngine.from_docs(input_files=[DOC_PATH])
         answer = await engine.aquery(QUESTION2)
         self._print_query_result(answer)

@@ -283,6 +281,5 @@ async def main():
     await e.use_hyde()

-
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/metagpt/config2.py b/metagpt/config2.py
index b8d9e3113..b43bc52a8 100644
--- a/metagpt/config2.py
+++ b/metagpt/config2.py
@@ -14,7 +14,6 @@
 from metagpt.configs.browser_config import BrowserConfig
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.configs.mermaid_config import MermaidConfig
-from metagpt.configs.query_analysis_config import HydeConfig
 from metagpt.configs.redis_config import RedisConfig
 from metagpt.configs.s3_config import S3Config
 from metagpt.configs.search_config import SearchConfig
@@ -52,9 +51,6 @@ class Config(CLIParams, YamlModel):
     # RAG
     rag: RAGConfig = RAGConfig()

-    # RAG Analysis
-    hyde: HydeConfig = HydeConfig()
-
     # Global Proxy. Will be used if llm.proxy is not set
     proxy: str = ""

diff --git a/metagpt/rag/benchmark/hotpotqa.py b/metagpt/rag/benchmark/hotpotqa.py
new file mode 100644
index 000000000..c88810f3b
--- /dev/null
+++ b/metagpt/rag/benchmark/hotpotqa.py
@@ -0,0 +1,127 @@
+import json
+from typing import Any, List, Optional
+
+from llama_index.core.base.base_query_engine import BaseQueryEngine
+from llama_index.core.base.base_retriever import BaseRetriever
+from llama_index.core.evaluation.benchmarks import HotpotQAEvaluator
+from llama_index.core.evaluation.benchmarks.hotpotqa import exact_match_score, f1_score
+from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
+from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
+
+from metagpt.const import EXAMPLE_DATA_PATH
+from metagpt.rag.engines import SimpleEngine
+from metagpt.rag.query_analysis.hyde import HyDEQuery
+from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer
+from metagpt.rag.schema import FAISSRetrieverConfig
+
+
+class HotpotQA(HotpotQAEvaluator):
+    def run(
+        self,
+        query_engine: BaseQueryEngine,
+        queries: int = 10,
+        queries_fraction: Optional[float] = None,
+        show_result: bool = False,
+        hyde: bool = False,
+    ) -> None:
+        dataset_paths = self._download_datasets()
+        dataset = "hotpot_dev_distractor"
+        dataset_path = dataset_paths[dataset]
+        print("Evaluating on dataset:", dataset)
+        print("-------------------------------------")
+
+        with open(dataset_path) as f:
+            query_objects = json.loads(f.read())
+        if queries_fraction:
+            queries_to_load = int(len(query_objects) * queries_fraction)
+        else:
+            queries_to_load = queries
+            queries_fraction = round(queries / len(query_objects), 5)
+
+        print(
+            f"Loading {queries_to_load} queries out of \
+{len(query_objects)} (fraction: {queries_fraction})"
+        )
+        query_objects = query_objects[:queries_to_load]
+
+        assert isinstance(
+            query_engine, RetrieverQueryEngine
+        ), "query_engine must be a RetrieverQueryEngine for this evaluation"
+        retriever = HotpotQARetriever(query_objects, hyde)
+        # Mock the query engine with a retriever
+        query_engine = query_engine.with_retriever(retriever=retriever)
+        if hyde:
+            hyde_query = HyDEQuery()
+            query_engine = SimpleQueryTransformer(query_engine, hyde_query)
+        scores = {"exact_match": 0.0, "f1": 0.0}
+
+        for query in query_objects:
+            if hyde:
+                response = query_engine.query(
+                    query["question"] + " Give a short factoid answer (as few words as possible)."
+                )
+            else:
+                query_bundle = QueryBundle(
+                    query_str=query["question"] + " Give a short factoid answer (as few words as possible).",
+                    custom_embedding_strs=[query["question"]],
+                )
+                response = query_engine.query(query_bundle)
+            em = int(exact_match_score(prediction=str(response), ground_truth=query["answer"]))
+            f1, _, _ = f1_score(prediction=str(response), ground_truth=query["answer"])
+            scores["exact_match"] += em
+            scores["f1"] += f1
+            if show_result:
+                print("Question: ", query["question"])
+                print("Response:", response)
+                print("Correct answer: ", query["answer"])
+                print("EM:", em, "F1:", f1)
+                print("-------------------------------------")
+
+        for score in scores:
+            scores[score] /= len(query_objects)
+
+        print("Scores: ", scores)
+
+
+class HotpotQARetriever(BaseRetriever):
+    """
+    This is a mocked retriever for the HotpotQA dataset. It is only meant to be used
+    with the hotpotqa dev dataset in the distractor setting. This is the setting that
This is the setting that
+    does not require retrieval but requires identifying the supporting facts from
+    a list of 10 sources.
+    """
+
+    def __init__(self, query_objects: Any, hyde: bool) -> None:
+        self.hyde = hyde
+        assert isinstance(
+            query_objects,
+            list,
+        ), f"query_objects must be a list, got: {type(query_objects)}"
+        self._queries = {}
+        for query_object in query_objects:
+            self._queries[query_object["question"]] = query_object
+
+    def _retrieve(self, query: QueryBundle) -> List[NodeWithScore]:
+        if query.custom_embedding_strs and self.hyde is False:
+            query_str = query.custom_embedding_strs[0]
+        else:
+            query_str = query.query_str.replace(" Give a short factoid answer (as few words as possible).", "")
+        contexts = self._queries[query_str]["context"]
+        node_with_scores = []
+        for ctx in contexts:
+            text_list = ctx[1]
+            text = "\n".join(text_list)
+            node = TextNode(text=text, metadata={"title": ctx[0]})
+            node_with_scores.append(NodeWithScore(node=node, score=1.0))
+
+        return node_with_scores
+
+    def __str__(self) -> str:
+        return "HotpotQARetriever"
+
+
+if __name__ == "__main__":
+    DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
+    engine = SimpleEngine.from_docs(input_files=[DOC_PATH], retriever_configs=[FAISSRetrieverConfig()])
+    HotpotQA().run(engine, queries=100, show_result=True, hyde=True)
+    HotpotQA().run(engine, queries=100, show_result=True, hyde=False)
diff --git a/metagpt/rag/factories/HyDEQueryTransformFactory.py b/metagpt/rag/factories/HyDEQueryTransformFactory.py
index c0946f8ee..d08f97181 100644
--- a/metagpt/rag/factories/HyDEQueryTransformFactory.py
+++ b/metagpt/rag/factories/HyDEQueryTransformFactory.py
@@ -3,7 +3,7 @@
 from metagpt.config2 import config
 from metagpt.rag.factories import get_rag_llm
 from metagpt.rag.factories.base import ConfigBasedFactory
-from metagpt.rag.query_analysis.HyDE import HyDEQuery
+from metagpt.rag.query_analysis.hyde import HyDEQuery
 
 
 class HyDEQueryTransformFactory(ConfigBasedFactory):
diff --git a/metagpt/rag/query_analysis/HyDE.py b/metagpt/rag/query_analysis/HyDE.py
deleted file mode 100644
index c1bcedcae..000000000
--- a/metagpt/rag/query_analysis/HyDE.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from typing import Dict
-
-from llama_index.core.indices.query.query_transform import HyDEQueryTransform
-from llama_index.core.schema import QueryBundle
-
-from metagpt.logs import logger
-
-
-class HyDEQuery(HyDEQueryTransform):
-    def _run(self, query_bundle: QueryBundle, metadata: Dict) -> QueryBundle:
-        logger.info(f"{'#' * 5} running HyDEQuery... 
{'#' * 5}") - query_str = query_bundle.query_str - - hypothetical_doc = self._llm.predict(self._hyde_prompt, context_str=query_str) - - embedding_strs = [hypothetical_doc] - - if self._include_original: - embedding_strs.extend(query_bundle.embedding_strs) - logger.info(f" Hypothetical doc:{embedding_strs} ") - - return QueryBundle( - query_str=query_str, - custom_embedding_strs=embedding_strs, - ) From d5c3c20ac36ed41e6194c2d300c899e3d00ebf7e Mon Sep 17 00:00:00 2001 From: liaojianxing Date: Tue, 20 Aug 2024 15:34:46 +0800 Subject: [PATCH 09/12] fix config --- metagpt/config2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/config2.py b/metagpt/config2.py index f17f7fe87..0e7c266de 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -15,12 +15,12 @@ from metagpt.configs.file_parser_config import OmniParseConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig +from metagpt.configs.rag_config import RAGConfig from metagpt.configs.redis_config import RedisConfig from metagpt.configs.s3_config import S3Config from metagpt.configs.search_config import SearchConfig from metagpt.configs.workspace_config import WorkspaceConfig from metagpt.const import CONFIG_ROOT, METAGPT_ROOT -from MetaGPT.metagpt.configs.rag_config import RAGConfig from metagpt.utils.yaml_model import YamlModel From 621cb22113cdca709f650e3bce5bcfaa83a404cc Mon Sep 17 00:00:00 2001 From: jasonliao <57182856+lanlanguai@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:26:52 +0800 Subject: [PATCH 10/12] Update config2.example.yaml --- config/config2.example.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index df6ec5b31..8e0d3af17 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -26,10 +26,6 @@ rag: include_original: True # In the query rewrite, determines whether to include the original -# RAG Analysis -hyde: - include_original: true # In the query rewrite the content whether to add the original question - repair_llm_output: true # when the output is not a valid json, try to repair it proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc. 
From 83041d29106144bf4cd3281d239b5b14b3f39695 Mon Sep 17 00:00:00 2001 From: jasonliao <57182856+lanlanguai@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:27:42 +0800 Subject: [PATCH 11/12] Update config2.example.yaml --- config/config2.example.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 8e0d3af17..58078fd3d 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -23,7 +23,7 @@ rag: # RAG Query Analysis query_analysis: hyde: - include_original: True # In the query rewrite, determines whether to include the original + include_original: true # In the query rewrite, determines whether to include the original repair_llm_output: true # when the output is not a valid json, try to repair it From 26b028504bb481e8a4fa784285924dcb7c3141d1 Mon Sep 17 00:00:00 2001 From: jasonliao <57182856+lanlanguai@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:28:57 +0800 Subject: [PATCH 12/12] Delete metagpt/rag/factories/HyDEQueryTransformFactory.py --- .../factories/HyDEQueryTransformFactory.py | 20 ------------------- 1 file changed, 20 deletions(-) delete mode 100644 metagpt/rag/factories/HyDEQueryTransformFactory.py diff --git a/metagpt/rag/factories/HyDEQueryTransformFactory.py b/metagpt/rag/factories/HyDEQueryTransformFactory.py deleted file mode 100644 index d08f97181..000000000 --- a/metagpt/rag/factories/HyDEQueryTransformFactory.py +++ /dev/null @@ -1,20 +0,0 @@ -from llama_index.core.llms import LLM - -from metagpt.config2 import config -from metagpt.rag.factories import get_rag_llm -from metagpt.rag.factories.base import ConfigBasedFactory -from metagpt.rag.query_analysis.hyde import HyDEQuery - - -class HyDEQueryTransformFactory(ConfigBasedFactory): - """Factory for creating HyDEQueryTransform instances with LLM configuration.""" - - llm: LLM = None - hyde_config: dict = None - - def __init__(self): - self.hyde_config = config.hyde - self.llm = get_rag_llm() - - def create_hyde_query_transform(self) -> HyDEQuery: - return HyDEQuery(include_original=self.hyde_config.include_original, llm=self.llm)
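
With HyDEQueryTransformFactory deleted, callers construct the transform directly, as the HotpotQA benchmark already does via HyDEQuery(). The deleted HyDE.py shows that HyDEQuery adds only logging on top of llama_index's stock HyDEQueryTransform, so the underlying contract is the upstream one. A small sketch of that contract, assuming a default LLM is configured for llama_index; the question text is illustrative:

    from llama_index.core.indices.query.query_transform import HyDEQueryTransform

    # The transform has the LLM draft a hypothetical answer document and embeds
    # that, optionally alongside the original question (include_original is the
    # same switch exposed in config2.example.yaml).
    transform = HyDEQueryTransform(include_original=True)
    bundle = transform.run("What did the author do growing up?")
    print(bundle.query_str)              # the original question, unchanged
    print(bundle.custom_embedding_strs)  # [hypothetical_doc, original question]
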