
Multiple retriever systems #155

Open · wants to merge 10 commits into base: main
200 changes: 200 additions & 0 deletions examples/run_storm_wiki_gpt_with_mult_sources.py
@@ -0,0 +1,200 @@
"""
This STORM Wiki pipeline is powered by GPT models and combines a local retrieval model that uses Qdrant (VectorRM) with an internet search engine (You.com or Bing).
You need to set up the following environment variables to run this script:
- OPENAI_API_KEY: OpenAI API key
- OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
- QDRANT_API_KEY: Qdrant API key (needed ONLY if an online vector store is used)
- YDC_API_KEY: You.com API key, or BING_SEARCH_API_KEY: Bing Search API key

You will also need an existing Qdrant vector store, either saved locally in a folder (offline mode) or hosted on a server (online mode).
If you do not have one, provide a CSV file with your documents and the script will create the vector store for you.
The CSV should be in the following format:
content | title | url | description
I am a document. | Document 1 | docu-n-112 | A self-explanatory document.
I am another document. | Document 2 | docu-l-13 | Another self-explanatory document.

Note that the URL serves as the unique identifier for each document, so make sure different documents have different URLs.
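
For illustration, such a CSV could be created with pandas (a minimal sketch; the file name
'corpus.csv' and the example rows are arbitrary, and the file is later passed via --csv-file-path):

    import pandas as pd

    rows = [
        {'content': 'I am a document.', 'title': 'Document 1',
         'url': 'docu-n-112', 'description': 'A self-explanatory document.'},
        {'content': 'I am another document.', 'title': 'Document 2',
         'url': 'docu-l-13', 'description': 'Another self-explanatory document.'},
    ]
    pd.DataFrame(rows).to_csv('corpus.csv', index=False)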

The output will be structured as follows:
args.output_dir/
    topic_name/  # topic_name follows the convention of an underscore-connected topic name w/o spaces and slashes
        conversation_log.json  # Log of the information-seeking conversation
        raw_search_results.json  # Raw search results from the search engine
        direct_gen_outline.txt  # Outline directly generated with the LLM's parametric knowledge
        storm_gen_outline.txt  # Outline refined with the collected information
        url_to_info.json  # Sources that are used in the final article
        storm_gen_article.txt  # Final article generated
        storm_gen_article_polished.txt  # Polished final article (if args.do_polish_article is True)
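
Example invocation (hypothetical; adjust flags, paths, and keys to your setup):

    python examples/run_storm_wiki_gpt_with_mult_sources.py \
        --csv-file-path corpus.csv \
        --vector-db-mode offline --offline-vector-db-dir ./vector_store \
        --internet-retriever you \
        --do-research --do-generate-outline --do-generate-article --do-polish-article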
"""

import os
import sys
from argparse import ArgumentParser

sys.path.append('./')
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.rm import VectorRM, BingSearch, YouRM
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.utils import load_api_key, QdrantVectorStoreManager


def main(args):
    # Load the API key from the specified toml file path
    load_api_key(toml_file_path='secrets.toml')

    # Initialize the language model configurations
    engine_lm_configs = STORMWikiLMConfigs()
    openai_kwargs = {
        'api_key': os.getenv("OPENAI_API_KEY"),
        'api_provider': os.getenv('OPENAI_API_TYPE'),
        'temperature': 1.0,
        'top_p': 0.9,
    }

    # STORM is an LM system, so different components can be powered by different models.
    # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm,
    # which is used to split queries and synthesize answers in the conversation. We recommend using stronger models
    # for outline_gen_lm, which is responsible for organizing the collected information, and article_gen_lm,
    # which is responsible for generating sections with citations.
    conv_simulator_lm = OpenAIModel(model='gpt-4o-mini', max_tokens=500, **openai_kwargs)
    question_asker_lm = OpenAIModel(model='gpt-4o', max_tokens=500, **openai_kwargs)
    outline_gen_lm = OpenAIModel(model='gpt-4o', max_tokens=400, **openai_kwargs)
    article_gen_lm = OpenAIModel(model='gpt-4o', max_tokens=700, **openai_kwargs)
    article_polish_lm = OpenAIModel(model='gpt-4o', max_tokens=4000, **openai_kwargs)

    engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm)
    engine_lm_configs.set_question_asker_lm(question_asker_lm)
    engine_lm_configs.set_outline_gen_lm(outline_gen_lm)
    engine_lm_configs.set_article_gen_lm(article_gen_lm)
    engine_lm_configs.set_article_polish_lm(article_polish_lm)

    # Initialize the engine arguments
    engine_args = STORMWikiRunnerArguments(
        output_dir=args.output_dir,
        max_conv_turn=args.max_conv_turn,
        max_perspective=args.max_perspective,
        search_top_k=args.search_top_k,
        max_thread_num=args.max_thread_num,
    )

    # Create or update the vector store with the documents in the CSV file
    if args.csv_file_path:
        kwargs = {
            'file_path': args.csv_file_path,
            'content_column': 'content',
            'title_column': 'title',
            'url_column': 'url',
            'desc_column': 'description',
            'batch_size': args.embed_batch_size,
            'vector_db_mode': args.vector_db_mode,
            'collection_name': args.collection_name,
            'embedding_model': args.embedding_model,
            'device': args.device,
        }
        if args.vector_db_mode == 'offline':
            QdrantVectorStoreManager.create_or_update_vector_store(
                vector_store_path=args.offline_vector_db_dir,
                **kwargs
            )
        elif args.vector_db_mode == 'online':
            QdrantVectorStoreManager.create_or_update_vector_store(
                url=args.online_vector_db_url,
                api_key=os.getenv('QDRANT_API_KEY'),
                **kwargs
            )

    # Set up VectorRM to retrieve information from your own data
    rm = VectorRM(collection_name=args.collection_name, embedding_model=args.embedding_model,
                  device=args.device, k=engine_args.search_top_k,
                  nickname='vector_db',
                  description=args.vector_db_desc)

    # Initialize the vector store, either online (store the db on a Qdrant server) or offline (store the db locally):
    if args.vector_db_mode == 'offline':
        rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir)
    elif args.vector_db_mode == 'online':
        rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY'))

    # Set up the internet retrieval model
    if args.internet_retriever == 'bing':
        rm_internet = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'),
                                 k=engine_args.search_top_k,
                                 nickname='bing_api',
                                 description="Bing is a search engine that provides information from the internet. "
                                             "This is where you look up any information you think the other "
                                             "retrieval systems will not have, by searching the live internet.")
    elif args.internet_retriever == 'you':
        rm_internet = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'),
                            k=engine_args.search_top_k,
                            nickname='you_api',
                            description="You.com is a search engine that provides information from the internet. "
                                        "This is where you look up any information you think the other "
                                        "retrieval systems will not have, by searching the live internet.")

    # Initialize the STORM Wiki Runner with both retrieval models
    runner = STORMWikiRunner(engine_args, engine_lm_configs, [rm, rm_internet])

    # Run the pipeline
    topic = input('Topic: ')
    runner.run(
        topic=topic,
        do_research=args.do_research,
        do_generate_outline=args.do_generate_outline,
        do_generate_article=args.do_generate_article,
        do_polish_article=args.do_polish_article,
    )
    runner.post_run()
    runner.summary()


if __name__ == "__main__":
    parser = ArgumentParser()
    # global arguments
    parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval',
                        help='Directory to store the outputs.')
    parser.add_argument('--max-thread-num', type=int, default=3,
                        help='Maximum number of threads to use. The information seeking part and the article '
                             'generation part can be sped up by using multiple threads. Consider reducing it if '
                             'you keep getting "Exceed rate limit" errors when calling the LM API.')
    # provide local corpus and set up vector db
    parser.add_argument('--collection-name', type=str, default="my_documents",
                        help='The collection name for the vector store.')
    parser.add_argument('--embedding-model', type=str, default="BAAI/bge-m3",
                        help='The embedding model to use for the vector store.')
    parser.add_argument('--device', type=str, default="mps",
                        help='The device used to run the retrieval model (mps, cuda, cpu, etc.).')
    parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'],
                        help='The mode of the Qdrant vector store (offline or online).')
    parser.add_argument('--vector-db-desc', type=str, default="A custom collection of documents stored with Qdrant.",
                        help='The description of the vector store.')
    parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store',
                        help='If using offline mode, provide the directory in which to store the vector store.')
    parser.add_argument('--online-vector-db-url', type=str,
                        help='If using online mode, provide the URL of the Qdrant server.')
    parser.add_argument('--csv-file-path', type=str, default=None,
                        help='The path of the custom document corpus in CSV format. The CSV file should include '
                             'content, title, url, and description columns.')
    parser.add_argument('--embed-batch-size', type=int, default=64,
                        help='Batch size for embedding the documents in the CSV file.')
    parser.add_argument('--internet-retriever', type=str, choices=['bing', 'you'], default='you',
                        help='The search engine API to use for retrieving information.')
    # stage of the pipeline
    parser.add_argument('--do-research', action='store_true',
                        help='If True, simulate conversation to research the topic; otherwise, load the results.')
    parser.add_argument('--do-generate-outline', action='store_true',
                        help='If True, generate an outline for the topic; otherwise, load the results.')
    parser.add_argument('--do-generate-article', action='store_true',
                        help='If True, generate an article for the topic; otherwise, load the results.')
    parser.add_argument('--do-polish-article', action='store_true',
                        help='If True, polish the article by adding a summarization section and (optionally) removing '
                             'duplicate content.')
    # hyperparameters for the pre-writing stage
    parser.add_argument('--max-conv-turn', type=int, default=3,
                        help='Maximum number of questions in conversational question asking.')
    parser.add_argument('--max-perspective', type=int, default=3,
                        help='Maximum number of perspectives to consider in perspective-guided question asking.')
    parser.add_argument('--search-top-k', type=int, default=3,
                        help='Top k search results to consider for each search query.')
    # hyperparameters for the writing stage
    parser.add_argument('--retrieve-top-k', type=int, default=3,
                        help='Top k collected references for each section title.')
    parser.add_argument('--remove-duplicate', action='store_true',
                        help='If True, remove duplicate content from the article.')
    main(parser.parse_args())
4 changes: 2 additions & 2 deletions knowledge_storm/interface.py
@@ -169,11 +169,11 @@ class Retriever(ABC):
     The retrieval model/search engine used for each part should be declared with a suffix '_rm' in the attribute name.
     """

-    def __init__(self, search_top_k):
+    def __init__(self, search_top_k):  # this search_top_k is not used anywhere later on
         self.search_top_k = search_top_k

     def update_search_top_k(self, k):
-        self.search_top_k = k
+        self.search_top_k = k  # same thing for this as well

     def collect_and_reset_rm_usage(self):
         combined_usage = []