-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for the HyDE method in quey analysis for RAG plates #1413
base: main
Are you sure you want to change the base?
Changes from all commits
0e81347
2cc6d60
b4fa468
2819b2e
ec40043
b2458d8
008fe37
775130b
3e53b33
2b85048
d5c3c20
2619eff
c053214
5fd3670
621cb22
83041d2
26b0285
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ | |
from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH | ||
from metagpt.logs import logger | ||
from metagpt.rag.engines import SimpleEngine | ||
from metagpt.rag.query_analysis.hyde import HyDEQuery | ||
from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer | ||
from metagpt.rag.schema import ( | ||
ChromaIndexConfig, | ||
ChromaRetrieverConfig, | ||
|
@@ -22,7 +24,7 @@ | |
|
||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" | ||
QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}" | ||
|
||
QUESTION2 = "What are the key factors in maintaining high productivity?" | ||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt" | ||
TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}" | ||
|
||
|
@@ -212,9 +214,31 @@ async def init_and_query_es(self): | |
answer = await engine.aquery(TRAVEL_QUESTION) | ||
self._print_query_result(answer) | ||
|
||
async def use_hyde(self): | ||
"""This example show how to use HyDE: HyDE enhances search results by generating Hypothetical doc(virtual | ||
article), for more details please refer to the paper: http://arxiv.org/abs/2212.10496 | ||
Query Result: | ||
Bob likes traveling. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the comment correct? |
||
""" | ||
|
||
self._print_title("Use HyDE to analysis query") | ||
# 1. save docs | ||
engine = SimpleEngine.from_docs(input_files=[DOC_PATH]) | ||
# 2. Initialize HyDE query analysis method | ||
hyde_query = HyDEQuery() | ||
# 3. Add HyDE to the engine | ||
hyde_query_engine = SimpleQueryTransformer(engine, hyde_query) | ||
answer = await hyde_query_engine.aquery(QUESTION2) | ||
self._print_query_result(answer) | ||
|
||
self._print_title("No use HyDE to analysis query") | ||
answer = await engine.aquery(QUESTION2) | ||
|
||
self._print_query_result(answer) | ||
|
||
@staticmethod | ||
def _print_title(title): | ||
logger.info(f"{'#'*30} {title} {'#'*30}") | ||
logger.info(f"{'#' * 30} {title} {'#' * 30}") | ||
|
||
@staticmethod | ||
def _print_retrieve_result(result): | ||
|
@@ -254,6 +278,7 @@ async def main(): | |
await e.init_objects() | ||
await e.init_and_query_chromadb() | ||
await e.init_and_query_es() | ||
await e.use_hyde() | ||
|
||
|
||
if __name__ == "__main__": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,15 +12,16 @@ | |
from pydantic import BaseModel, model_validator | ||
|
||
from metagpt.configs.browser_config import BrowserConfig | ||
from metagpt.configs.embedding_config import EmbeddingConfig | ||
from metagpt.configs.file_parser_config import OmniParseConfig | ||
from metagpt.configs.llm_config import LLMConfig, LLMType | ||
from metagpt.configs.mermaid_config import MermaidConfig | ||
from metagpt.configs.rag_config import RAGConfig | ||
from metagpt.configs.redis_config import RedisConfig | ||
from metagpt.configs.s3_config import S3Config | ||
from metagpt.configs.search_config import SearchConfig | ||
from metagpt.configs.workspace_config import WorkspaceConfig | ||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT | ||
from MetaGPT.metagpt.configs.rag_config import RAGConfig | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs to be deleted |
||
from metagpt.utils.yaml_model import YamlModel | ||
|
||
|
||
|
@@ -49,12 +50,10 @@ class Config(CLIParams, YamlModel): | |
# Key Parameters | ||
llm: LLMConfig | ||
|
||
# RAG Embedding | ||
embedding: EmbeddingConfig = EmbeddingConfig() | ||
|
||
# RAG | ||
rag: RAGConfig = RAGConfig() | ||
# omniparse | ||
omniparse: OmniParseConfig = OmniParseConfig() | ||
|
||
# Global Proxy. Will be used if llm.proxy is not set | ||
proxy: str = "" | ||
|
||
|
@@ -73,7 +72,6 @@ class Config(CLIParams, YamlModel): | |
workspace: WorkspaceConfig = WorkspaceConfig() | ||
enable_longterm_memory: bool = False | ||
code_review_k_times: int = 2 | ||
agentops_api_key: str = "" | ||
|
||
# Will be removed in the future | ||
metagpt_tti_url: str = "" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from metagpt.utils.yaml_model import YamlModel | ||
|
||
|
||
class HyDEConfig(YamlModel): | ||
include_original: bool = True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from metagpt.configs.hyde_config import HyDEConfig | ||
from metagpt.utils.yaml_model import YamlModel | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use rag_config.py to support independent rag configuration |
||
|
||
class QueryAnalysisConfig(YamlModel): | ||
hyde: HyDEConfig = HyDEConfig() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from metagpt.configs.embedding_config import EmbeddingConfig | ||
from metagpt.configs.query_analysis_config import QueryAnalysisConfig | ||
from metagpt.utils.yaml_model import YamlModel | ||
|
||
|
||
class RAGConfig(YamlModel): | ||
embedding: EmbeddingConfig = EmbeddingConfig() | ||
query_analysis: QueryAnalysisConfig = QueryAnalysisConfig() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is recommended to add QueryAnalysisConfig and EmbeddingConfig in rag_config.py without hyde_config.py and query_analysis_config.py files. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import json | ||
from typing import Any, List, Optional | ||
|
||
from llama_index.core.base.base_query_engine import BaseQueryEngine | ||
from llama_index.core.base.base_retriever import BaseRetriever | ||
from llama_index.core.evaluation.benchmarks import HotpotQAEvaluator | ||
from llama_index.core.evaluation.benchmarks.hotpotqa import exact_match_score, f1_score | ||
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine | ||
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode | ||
|
||
from metagpt.const import EXAMPLE_DATA_PATH | ||
from metagpt.rag.engines import SimpleEngine | ||
from metagpt.rag.query_analysis.hyde import HyDEQuery | ||
from metagpt.rag.query_analysis.simple_query_transformer import SimpleQueryTransformer | ||
from metagpt.rag.schema import FAISSRetrieverConfig | ||
|
||
|
||
class HotpotQA(HotpotQAEvaluator): | ||
def run( | ||
self, | ||
query_engine: BaseQueryEngine, | ||
queries: int = 10, | ||
queries_fraction: Optional[float] = None, | ||
show_result: bool = False, | ||
hyde: bool = False, | ||
) -> None: | ||
dataset_paths = self._download_datasets() | ||
dataset = "hotpot_dev_distractor" | ||
dataset_path = dataset_paths[dataset] | ||
print("Evaluating on dataset:", dataset) | ||
print("-------------------------------------") | ||
|
||
f = open(dataset_path) | ||
query_objects = json.loads(f.read()) | ||
if queries_fraction: | ||
queries_to_load = int(len(query_objects) * queries_fraction) | ||
else: | ||
queries_to_load = queries | ||
queries_fraction = round(queries / len(query_objects), 5) | ||
|
||
print( | ||
f"Loading {queries_to_load} queries out of \ | ||
{len(query_objects)} (fraction: {queries_fraction})" | ||
) | ||
query_objects = query_objects[:queries_to_load] | ||
|
||
assert isinstance( | ||
query_engine, RetrieverQueryEngine | ||
), "query_engine must be a RetrieverQueryEngine for this evaluation" | ||
retriever = HotpotQARetriever(query_objects, hyde) | ||
# Mock the query engine with a retriever | ||
query_engine = query_engine.with_retriever(retriever=retriever) | ||
if hyde: | ||
hyde_query = HyDEQuery() | ||
query_engine = SimpleQueryTransformer(query_engine, hyde_query) | ||
scores = {"exact_match": 0.0, "f1": 0.0} | ||
|
||
for query in query_objects: | ||
if hyde: | ||
response = query_engine.query( | ||
query["question"] + " Give a short factoid answer (as few words as possible)." | ||
) | ||
else: | ||
query_bundle = QueryBundle( | ||
query_str=query["question"] + " Give a short factoid answer (as few words as possible).", | ||
custom_embedding_strs=[query["question"]], | ||
) | ||
response = query_engine.query(query_bundle) | ||
em = int(exact_match_score(prediction=str(response), ground_truth=query["answer"])) | ||
f1, _, _ = f1_score(prediction=str(response), ground_truth=query["answer"]) | ||
scores["exact_match"] += em | ||
scores["f1"] += f1 | ||
if show_result: | ||
print("Question: ", query["question"]) | ||
print("Response:", response) | ||
print("Correct answer: ", query["answer"]) | ||
print("EM:", em, "F1:", f1) | ||
print("-------------------------------------") | ||
|
||
for score in scores: | ||
scores[score] /= len(query_objects) | ||
|
||
print("Scores: ", scores) | ||
|
||
|
||
class HotpotQARetriever(BaseRetriever): | ||
""" | ||
This is a mocked retriever for HotpotQA dataset. It is only meant to be used | ||
with the hotpotqa dev dataset in the distractor setting. This is the setting that | ||
does not require retrieval but requires identifying the supporting facts from | ||
a list of 10 sources. | ||
""" | ||
|
||
def __init__(self, query_objects: Any, hyde: bool) -> None: | ||
self.hyde = hyde | ||
assert isinstance( | ||
query_objects, | ||
list, | ||
), f"query_objects must be a list, got: {type(query_objects)}" | ||
self._queries = {} | ||
for object in query_objects: | ||
self._queries[object["question"]] = object | ||
|
||
def _retrieve(self, query: QueryBundle) -> List[NodeWithScore]: | ||
if query.custom_embedding_strs and self.hyde is False: | ||
query_str = query.custom_embedding_strs[0] | ||
else: | ||
query_str = query.query_str.replace(" Give a short factoid answer (as few words as possible).", "") | ||
contexts = self._queries[query_str]["context"] | ||
node_with_scores = [] | ||
for ctx in contexts: | ||
text_list = ctx[1] | ||
text = "\n".join(text_list) | ||
node = TextNode(text=text, metadata={"title": ctx[0]}) | ||
node_with_scores.append(NodeWithScore(node=node, score=1.0)) | ||
|
||
return node_with_scores | ||
|
||
def __str__(self) -> str: | ||
return "HotpotQARetriever" | ||
|
||
|
||
if __name__ == "__main__": | ||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" | ||
engine = SimpleEngine.from_docs(input_files=[DOC_PATH], retriever_configs=[FAISSRetrieverConfig()]) | ||
HotpotQA().run(engine, queries=100, show_result=True, hyde=True) | ||
HotpotQA().run(engine, queries=100, show_result=True, hyde=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't change this embedding one.