Added support for the HyDE method in query analysis for RAG pipelines #1413

Open · wants to merge 17 commits into base: main
config/config2.example.yaml (22 changes: 14 additions & 8 deletions)
@@ -9,16 +9,22 @@ llm:
   pricing_plan: ""  # Optional. Use for Azure LLM when its model name is not the same as OpenAI's


+rag:
 # RAG Embedding.
 # For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used.
-embedding:
-  api_type: ""  # openai / azure / gemini / ollama etc. Check EmbeddingType for more options.
-  base_url: ""
-  api_key: ""
-  model: ""
-  api_version: ""
-  embed_batch_size: 100
-  dimensions:  # output dimension of embedding model
+  embedding:

[Collaborator commented]: don't change this embedding one.

+    api_type: ""  # openai / azure / gemini / ollama etc. Check EmbeddingType for more options.
+    base_url: ""
+    api_key: ""
+    model: ""
+    api_version: ""
+    embed_batch_size: 100
+    dimensions:  # output dimension of embedding model
+
+  # RAG Query Analysis
+  query_analysis:
+    hyde:
+      include_original: true  # In the query rewrite, determines whether to include the original query


 repair_llm_output: true  # when the output is not a valid json, try to repair it
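Note: for orientation, a minimal sketch of how the new nested section is read through the config models this PR introduces (see metagpt/configs/rag_config.py below). Instantiating RAGConfig directly is illustrative only; in practice it is populated from config2.yaml:

from metagpt.configs.rag_config import RAGConfig

rag = RAGConfig()  # defaults; normally filled in from config2.yaml by Config
print(rag.query_analysis.hyde.include_original)  # True by default
print(rag.embedding.embed_batch_size)  # embedding settings now live under rag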
config/config2.yaml (3 changes: 2 additions & 1 deletion)
@@ -5,4 +5,5 @@ llm:
   api_type: "openai"  # or azure / ollama / groq etc.
   model: "gpt-4-turbo"  # or gpt-3.5-turbo
   base_url: "https://api.openai.com/v1"  # or forward url / other llm url
-  api_key: "YOUR_API_KEY"
+  api_key: "YOUR_API_KEY"
+
examples/rag/rag_pipeline.py (29 changes: 27 additions & 2 deletions)
@@ -7,6 +7,8 @@
 from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
 from metagpt.logs import logger
 from metagpt.rag.engines import SimpleEngine
+from metagpt.rag.query_analysis.hyde import HyDEQuery
+from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer
 from metagpt.rag.schema import (
     ChromaIndexConfig,
     ChromaRetrieverConfig,
@@ -22,7 +24,7 @@
 DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
 QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}"
-
+QUESTION2 = "What are the key factors in maintaining high productivity?"
 TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
 TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}"
@@ -212,9 +214,31 @@ async def init_and_query_es(self):
         answer = await engine.aquery(TRAVEL_QUESTION)
         self._print_query_result(answer)

+    async def use_hyde(self):
+        """This example shows how to use HyDE: HyDE enhances search results by generating a hypothetical
+        document (a virtual article); for more details, please refer to the paper: http://arxiv.org/abs/2212.10496
+        Query Result:
+            Bob likes traveling.

[Contributor commented on "Bob likes traveling."]: Is the comment correct?

+        """
+        self._print_title("Use HyDE to analyze the query")
+        # 1. save docs
+        engine = SimpleEngine.from_docs(input_files=[DOC_PATH])
+        # 2. initialize the HyDE query analysis method
+        hyde_query = HyDEQuery()
+        # 3. add HyDE to the engine
+        hyde_query_engine = SimpleQueryTransformer(engine, hyde_query)
+        answer = await hyde_query_engine.aquery(QUESTION2)
+        self._print_query_result(answer)
+
+        self._print_title("Query without HyDE")
+        answer = await engine.aquery(QUESTION2)
+        self._print_query_result(answer)
+
     @staticmethod
     def _print_title(title):
-        logger.info(f"{'#'*30} {title} {'#'*30}")
+        logger.info(f"{'#' * 30} {title} {'#' * 30}")

     @staticmethod
     def _print_retrieve_result(result):
@@ -254,6 +278,7 @@ async def main():
     await e.init_objects()
     await e.init_and_query_chromadb()
     await e.init_and_query_es()
+    await e.use_hyde()


if __name__ == "__main__":
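Note: for readers unfamiliar with HyDE, a hedged sketch of the underlying transform follows, using llama_index's stock HyDEQueryTransform, which this PR's HyDEQuery and SimpleQueryTransformer appear to wrap (an assumption, not confirmed by the diff). An LLM first writes a hypothetical answer document, and that document's embedding drives retrieval instead of, or alongside, the raw question's:

# Hedged sketch of HyDE (http://arxiv.org/abs/2212.10496) via llama_index's
# built-in transform; requires an LLM configured in llama_index's Settings.
from llama_index.core.indices.query.query_transform import HyDEQueryTransform

hyde = HyDEQueryTransform(include_original=True)  # mirrors rag.query_analysis.hyde.include_original
bundle = hyde.run("What are the key factors in maintaining high productivity?")
print(bundle.custom_embedding_strs)  # the generated hypothetical document(s) that get embedded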
metagpt/config2.py (10 changes: 4 additions & 6 deletions)
@@ -12,15 +12,16 @@
 from pydantic import BaseModel, model_validator

 from metagpt.configs.browser_config import BrowserConfig
-from metagpt.configs.embedding_config import EmbeddingConfig
 from metagpt.configs.file_parser_config import OmniParseConfig
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.configs.mermaid_config import MermaidConfig
+from metagpt.configs.rag_config import RAGConfig
 from metagpt.configs.redis_config import RedisConfig
 from metagpt.configs.s3_config import S3Config
 from metagpt.configs.search_config import SearchConfig
 from metagpt.configs.workspace_config import WorkspaceConfig
 from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
+from MetaGPT.metagpt.configs.rag_config import RAGConfig

[Contributor commented]: Needs to be deleted

 from metagpt.utils.yaml_model import YamlModel


@@ -49,12 +50,10 @@ class Config(CLIParams, YamlModel):
     # Key Parameters
     llm: LLMConfig

-    # RAG Embedding
-    embedding: EmbeddingConfig = EmbeddingConfig()
-
+    # RAG
+    rag: RAGConfig = RAGConfig()
     # omniparse
     omniparse: OmniParseConfig = OmniParseConfig()

     # Global Proxy. Will be used if llm.proxy is not set
     proxy: str = ""

@@ -73,7 +72,6 @@ class Config(CLIParams, YamlModel):
     workspace: WorkspaceConfig = WorkspaceConfig()
     enable_longterm_memory: bool = False
     code_review_k_times: int = 2
-    agentops_api_key: str = ""

     # Will be removed in the future
     metagpt_tti_url: str = ""
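Note: a quick illustration of the migration this change implies: any code that read the old top-level key must switch to the nested one (paths as in this PR; the attribute is an arbitrary example):

# Before this PR:
batch_size = config.embedding.embed_batch_size
# After this PR:
batch_size = config.rag.embedding.embed_batch_size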
metagpt/configs/hyde_config.py (new file, 5 additions)
@@ -0,0 +1,5 @@
from metagpt.utils.yaml_model import YamlModel


class HyDEConfig(YamlModel):
    include_original: bool = True  # whether the original query is kept alongside the hypothetical document
metagpt/configs/query_analysis_config.py (new file, 6 additions)
@@ -0,0 +1,6 @@
from metagpt.configs.hyde_config import HyDEConfig
from metagpt.utils.yaml_model import YamlModel


class QueryAnalysisConfig(YamlModel):
    hyde: HyDEConfig = HyDEConfig()

[Collaborator commented]: use rag_config.py to support independent rag configuration
metagpt/configs/rag_config.py (new file, 8 additions)
@@ -0,0 +1,8 @@
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.query_analysis_config import QueryAnalysisConfig
from metagpt.utils.yaml_model import YamlModel


class RAGConfig(YamlModel):
    embedding: EmbeddingConfig = EmbeddingConfig()
    query_analysis: QueryAnalysisConfig = QueryAnalysisConfig()

[Contributor commented]: It is recommended to add QueryAnalysisConfig and EmbeddingConfig in rag_config.py without hyde_config.py and query_analysis_config.py files.
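Note: a minimal sketch of the suggested consolidation, assuming the three models keep their current fields (this is the reviewer's proposal, not code from this PR):

# metagpt/configs/rag_config.py, consolidated per the review suggestion
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.utils.yaml_model import YamlModel


class HyDEConfig(YamlModel):
    include_original: bool = True  # keep the original query alongside the hypothetical document


class QueryAnalysisConfig(YamlModel):
    hyde: HyDEConfig = HyDEConfig()


class RAGConfig(YamlModel):
    embedding: EmbeddingConfig = EmbeddingConfig()
    query_analysis: QueryAnalysisConfig = QueryAnalysisConfig()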

metagpt/rag/benchmark/hotpotqa.py (new file, 127 additions)
@@ -0,0 +1,127 @@
import json
from typing import Any, List, Optional

from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.evaluation.benchmarks import HotpotQAEvaluator
from llama_index.core.evaluation.benchmarks.hotpotqa import exact_match_score, f1_score
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.query_analysis.hyde import HyDEQuery
from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer
from metagpt.rag.schema import FAISSRetrieverConfig


class HotpotQA(HotpotQAEvaluator):
    def run(
        self,
        query_engine: BaseQueryEngine,
        queries: int = 10,
        queries_fraction: Optional[float] = None,
        show_result: bool = False,
        hyde: bool = False,
    ) -> None:
        dataset_paths = self._download_datasets()
        dataset = "hotpot_dev_distractor"
        dataset_path = dataset_paths[dataset]
        print("Evaluating on dataset:", dataset)
        print("-------------------------------------")

        with open(dataset_path) as f:
            query_objects = json.load(f)
        if queries_fraction:
            queries_to_load = int(len(query_objects) * queries_fraction)
        else:
            queries_to_load = queries
            queries_fraction = round(queries / len(query_objects), 5)

        print(f"Loading {queries_to_load} queries out of {len(query_objects)} (fraction: {queries_fraction})")
        query_objects = query_objects[:queries_to_load]

        assert isinstance(
            query_engine, RetrieverQueryEngine
        ), "query_engine must be a RetrieverQueryEngine for this evaluation"
        retriever = HotpotQARetriever(query_objects, hyde)
        # Mock the query engine with a retriever
        query_engine = query_engine.with_retriever(retriever=retriever)
        if hyde:
            hyde_query = HyDEQuery()
            query_engine = SimpleQueryTransformer(query_engine, hyde_query)
        scores = {"exact_match": 0.0, "f1": 0.0}

        for query in query_objects:
            if hyde:
                response = query_engine.query(
                    query["question"] + " Give a short factoid answer (as few words as possible)."
                )
            else:
                query_bundle = QueryBundle(
                    query_str=query["question"] + " Give a short factoid answer (as few words as possible).",
                    custom_embedding_strs=[query["question"]],
                )
                response = query_engine.query(query_bundle)
            em = int(exact_match_score(prediction=str(response), ground_truth=query["answer"]))
            f1, _, _ = f1_score(prediction=str(response), ground_truth=query["answer"])
            scores["exact_match"] += em
            scores["f1"] += f1
            if show_result:
                print("Question: ", query["question"])
                print("Response:", response)
                print("Correct answer: ", query["answer"])
                print("EM:", em, "F1:", f1)
                print("-------------------------------------")

        for score in scores:
            scores[score] /= len(query_objects)

        print("Scores: ", scores)


class HotpotQARetriever(BaseRetriever):
    """A mocked retriever for the HotpotQA dataset. It is only meant to be used
    with the HotpotQA dev dataset in the distractor setting. This setting does
    not require retrieval, but requires identifying the supporting facts from
    a list of 10 sources.
    """

    def __init__(self, query_objects: Any, hyde: bool) -> None:
        self.hyde = hyde
        assert isinstance(
            query_objects,
            list,
        ), f"query_objects must be a list, got: {type(query_objects)}"
        self._queries = {}
        for obj in query_objects:
            self._queries[obj["question"]] = obj

    def _retrieve(self, query: QueryBundle) -> List[NodeWithScore]:
        if query.custom_embedding_strs and self.hyde is False:
            query_str = query.custom_embedding_strs[0]
        else:
            query_str = query.query_str.replace(" Give a short factoid answer (as few words as possible).", "")
        contexts = self._queries[query_str]["context"]
        node_with_scores = []
        for ctx in contexts:
            text_list = ctx[1]
            text = "\n".join(text_list)  # each context is a (title, [sentences]) pair
            node = TextNode(text=text, metadata={"title": ctx[0]})
            node_with_scores.append(NodeWithScore(node=node, score=1.0))

        return node_with_scores

    def __str__(self) -> str:
        return "HotpotQARetriever"


if __name__ == "__main__":
    DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
    engine = SimpleEngine.from_docs(input_files=[DOC_PATH], retriever_configs=[FAISSRetrieverConfig()])
    HotpotQA().run(engine, queries=100, show_result=True, hyde=True)
    HotpotQA().run(engine, queries=100, show_result=True, hyde=False)
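Note: the two scores printed above are the standard HotpotQA metrics. A simplified, token-level sketch of their definitions follows; the PR imports the real implementations from llama_index, which also normalize punctuation and articles:

from collections import Counter


def exact_match(prediction: str, ground_truth: str) -> int:
    # 1 when the normalized strings are identical, else 0
    return int(prediction.strip().lower() == ground_truth.strip().lower())


def f1(prediction: str, ground_truth: str) -> float:
    # token-level F1: harmonic mean of precision and recall over shared tokens
    pred, gt = prediction.lower().split(), ground_truth.lower().split()
    common = sum((Counter(pred) & Counter(gt)).values())
    if common == 0:
        return 0.0
    precision, recall = common / len(pred), common / len(gt)
    return 2 * precision * recall / (precision + recall)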
metagpt/rag/engines/simple.py (1 change: 0 additions & 1 deletion)
@@ -114,7 +114,6 @@ def from_docs(
         transformations = transformations or cls._default_transformations()
         nodes = run_transformations(documents, transformations=transformations)
-
         return cls._from_nodes(
             nodes=nodes,
             transformations=transformations,
metagpt/rag/factories/embedding.py (28 changes: 14 additions & 14 deletions)
@@ -40,8 +40,8 @@ def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
         If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
         Raise TypeError if embedding type not found.
         """
-        if config.embedding.api_type:
-            return config.embedding.api_type
+        if config.rag.embedding.api_type:
+            return config.rag.embedding.api_type

         if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
             return config.llm.api_type
@@ -50,8 +50,8 @@ def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
     def _create_openai(self) -> OpenAIEmbedding:
         params = dict(
-            api_key=config.embedding.api_key or config.llm.api_key,
-            api_base=config.embedding.base_url or config.llm.base_url,
+            api_key=config.rag.embedding.api_key or config.llm.api_key,
+            api_base=config.rag.embedding.base_url or config.llm.base_url,
         )

         self._try_set_model_and_batch_size(params)
@@ -60,9 +60,9 @@ def _create_openai(self) -> OpenAIEmbedding:
     def _create_azure(self) -> AzureOpenAIEmbedding:
         params = dict(
-            api_key=config.embedding.api_key or config.llm.api_key,
-            azure_endpoint=config.embedding.base_url or config.llm.base_url,
-            api_version=config.embedding.api_version or config.llm.api_version,
+            api_key=config.rag.embedding.api_key or config.llm.api_key,
+            azure_endpoint=config.rag.embedding.base_url or config.llm.base_url,
+            api_version=config.rag.embedding.api_version or config.llm.api_version,
         )

         self._try_set_model_and_batch_size(params)
@@ -71,8 +71,8 @@ def _create_azure(self) -> AzureOpenAIEmbedding:
     def _create_gemini(self) -> GeminiEmbedding:
         params = dict(
-            api_key=config.embedding.api_key,
-            api_base=config.embedding.base_url,
+            api_key=config.rag.embedding.api_key,
+            api_base=config.rag.embedding.base_url,
         )

         self._try_set_model_and_batch_size(params)
@@ -81,7 +81,7 @@ def _create_gemini(self) -> GeminiEmbedding:
     def _create_ollama(self) -> OllamaEmbedding:
         params = dict(
-            base_url=config.embedding.base_url,
+            base_url=config.rag.embedding.base_url,
         )

         self._try_set_model_and_batch_size(params)
@@ -90,11 +90,11 @@ def _create_ollama(self) -> OllamaEmbedding:
     def _try_set_model_and_batch_size(self, params: dict):
         """Set the model_name and embed_batch_size only when they are specified."""
-        if config.embedding.model:
-            params["model_name"] = config.embedding.model
+        if config.rag.embedding.model:
+            params["model_name"] = config.rag.embedding.model

-        if config.embedding.embed_batch_size:
-            params["embed_batch_size"] = config.embedding.embed_batch_size
+        if config.rag.embedding.embed_batch_size:
+            params["embed_batch_size"] = config.rag.embedding.embed_batch_size

     def _raise_for_key(self, key: Any):
         raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
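Note: callers that obtain the embedding through the factory are unaffected by the key move. A hedged usage sketch; get_rag_embedding is assumed to be the module's public entry point:

# Resolution order after this PR: config.rag.embedding.api_type first, then
# fall back to config.llm.api_type when it is openai/azure, else raise.
from metagpt.rag.factories.embedding import get_rag_embedding

embed_model = get_rag_embedding()  # returns a llama_index embedding instance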
Empty file.