Merge pull request #358 from artificialfintelligence/blogger
Blog-Writer Feature
Showing 8 changed files with 475 additions and 0 deletions.
@@ -0,0 +1,6 @@
# Enter your OpenAI API key and Serper API key immediately after the equals sign, without spaces or enclosing quotation marks.
OPENAI_API_KEY=
SERPER_API_KEY=

# Uncomment the line below to enable debugging logs.
# LOG_LEVEL=DEBUG
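Since `python-dotenv` appears in the project's requirements, these keys are presumably loaded from this file into the environment at startup. A minimal sketch of that pattern (the variable names come from the file above; where exactly the loading happens in this project is an assumption):

```python
import os

from dotenv import load_dotenv

# Read key=value pairs from .env in the working directory into the
# process environment; existing environment variables win by default.
load_dotenv()

openai_key = os.environ.get("OPENAI_API_KEY")
serper_key = os.environ.get("SERPER_API_KEY")
if not openai_key or not serper_key:
    raise RuntimeError("Set OPENAI_API_KEY and SERPER_API_KEY in .env")
```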
@@ -0,0 +1,5 @@
# Blog Writer Project
This community project uses Sherpa as a library (`sherpa_ai`), together with direct LLM calls (to the OpenAI API) via LangChain, to construct a blog post from the raw transcript of a lecture or presentation.

Please refer to [this How-to Guide](https://github.com/Aggregate-Intellect/sherpa/tree/main/docs/How_To/Tutorials/blog_writer.rst) for further details.
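Given the argparse defaults in the driver script later in this diff, a typical invocation is presumably `python main.py --config agent_config.yaml --transcript transcript.txt`; the script name `main.py` is a guess here, since file names are not shown in this diff.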
@@ -0,0 +1,55 @@
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from langchain.vectorstores.chroma import Chroma
from loguru import logger

from sherpa_ai.actions.base import BaseAction


class DocumentSearch(BaseAction):
    def __init__(self, filename, embedding_function, k=4):
        # file name of the pdf
        self.filename = filename
        # the embedding function to use
        self.embedding_function = embedding_function
        # number of results to return in search
        self.k = k

        # load the pdf and create the vector store
        self.chroma = Chroma(embedding_function=embedding_function)
        documents = PDFMinerLoader(self.filename).load()
        documents = SentenceTransformersTokenTextSplitter(
            chunk_overlap=0
        ).split_documents(documents)

        logger.info(f"Adding {len(documents)} documents to the vector store")
        self.chroma.add_documents(documents)
        logger.info("Finished adding documents to the vector store")

    def execute(self, query):
        """
        Execute the action by searching the document store for the query.

        Args:
            query (str): The query to search for

        Returns:
            str: The search results combined into a single string
        """
        results = self.chroma.search(query, search_type="mmr", k=self.k)
        return "\n\n".join([result.page_content for result in results])

    @property
    def name(self) -> str:
        """
        The name of the action, used to describe the action to the agent.
        """
        return "DocumentSearch"

    @property
    def args(self) -> dict:
        """
        The arguments that the action takes, used to describe the action to the agent.
        """
        return {"query": "string"}
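A standalone usage sketch for this action outside the agent loop; the embedding model and PDF name mirror the YAML configuration below, and the query string is purely illustrative:

```python
from langchain.embeddings import SentenceTransformerEmbeddings

from actions import DocumentSearch

# Same sentence-transformers model the agent configuration uses.
embeddings = SentenceTransformerEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

# Builds the Chroma store from the PDF once, at construction time.
search = DocumentSearch("transcript.pdf", embeddings, k=4)

# Returns the top-k MMR matches joined into one string.
print(search.execute("What is the main topic of the lecture?"))
```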
@@ -0,0 +1,51 @@
shared_memory:
  _target_: sherpa_ai.memory.shared_memory.SharedMemory  # The absolute path to the shared memory class in the library
  objective: Answer the question  # The agent's objective; since this is a question-answering agent, it is to answer questions

agent_config:  # The default configuration is used for the demo; change it as per your requirements
  _target_: sherpa_ai.config.task_config.AgentConfig


llm:  # Configuration for the LLM; here we use the OpenAI GPT-3.5-turbo model
  _target_: langchain.chat_models.ChatOpenAI
  model_name: gpt-3.5-turbo
  temperature: 0

embedding_func:
  _target_: langchain.embeddings.SentenceTransformerEmbeddings
  model_name: sentence-transformers/all-mpnet-base-v2

doc_search:
  _target_: actions.DocumentSearch
  filename: transcript.pdf
  embedding_function: ${embedding_func}
  k: 4

google_search:
  _target_: sherpa_ai.actions.GoogleSearch
  role_description: Act as a question answering agent
  task: Question answering
  llm: ${llm}
  include_metadata: true
  config: ${agent_config}

citation_validation:  # The tool used to validate and add citations to the answer
  _target_: sherpa_ai.output_parsers.citation_validation.CitationValidation
  sequence_threshold: 0.6
  jaccard_threshold: 0.6
  token_overlap: 0.6

qa_agent:
  _target_: sherpa_ai.agents.qa_agent.QAAgent
  llm: ${llm}
  shared_memory: ${shared_memory}
  name: QA Sherpa
  description: You are a technical writing assistant that helps users write articles. For each prompt, use Google Search to find detailed information that supports and expands on the prompt.
  agent_config: ${agent_config}
  num_runs: 1
  validation_steps: 1
  actions:
    - ${google_search}
    # - ${doc_search}
  validations:
    - ${citation_validation}
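The `${...}` values are OmegaConf interpolations: they resolve against sibling keys in this file, so `qa_agent` and `google_search` share one `llm` definition, and Hydra's `instantiate` then builds each `_target_` class recursively. A quick sketch of inspecting and materializing the config (assuming the file is saved as agent_config.yaml, the driver script's default):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

config = OmegaConf.load("agent_config.yaml")

# Interpolations such as ${llm} resolve on access against sibling keys.
print(OmegaConf.to_yaml(config.qa_agent, resolve=True))

# instantiate() walks the node, imports each _target_, and calls it
# with the remaining keys as constructor arguments.
qa_agent = instantiate(config.qa_agent)
```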
@@ -0,0 +1,87 @@
import json
from argparse import ArgumentParser

from hydra.utils import instantiate
from omegaconf import OmegaConf
from sherpa_ai.agents import QAAgent
from sherpa_ai.events import EventType

from outliner import Outliner

# from sherpa_ai.memory import Belief


def get_qa_agent_from_config_file(
    config_path: str,
) -> QAAgent:
    """
    Create a QAAgent from a config file.

    Args:
        config_path: Path to the config file

    Returns:
        QAAgent: A QAAgent instance
    """
    config = OmegaConf.load(config_path)

    agent_config = instantiate(config.agent_config)
    qa_agent: QAAgent = instantiate(config.qa_agent, agent_config=agent_config)

    return qa_agent


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, default="agent_config.yaml")
    parser.add_argument("--transcript", type=str, default="transcript.txt")
    args = parser.parse_args()

    writer_agent = get_qa_agent_from_config_file(args.config)

    outliner = Outliner(args.transcript)
    blueprint = outliner.full_transcript2outline_json(verbose=True)
    if blueprint.startswith("```"):
        # The first and last lines are code block delimiters; remove them
        lines = blueprint.split("\n")[1:-1]
        pure_json_str = "\n".join(lines)
    else:
        pure_json_str = blueprint

    with open("blueprint.json", "w") as f:
        f.write(pure_json_str)

    # with open("blueprint_manual.json", "r") as f:
    #     pure_json_str = f.read()

    parsed_json = json.loads(pure_json_str)

    blog = ""
    thesis = parsed_json.get("Thesis Statement", "")
    blog += f"# Introduction\n{thesis}\n"
    arguments = parsed_json.get("Supporting Arguments", [])
    for argument in arguments:
        blog += f"## {argument['Argument']}\n"
        evidences = argument.get("Evidence", [])
        for evidence in evidences:
            writer_agent.shared_memory.add(EventType.task, "human", evidence)
            result = writer_agent.run()
            # writer_agent.belief = Belief()
            blog += f"{result}\n"

    with open("blog.md", "w") as f:
        f.write(blog)

    print("\nBlog generated successfully!\n")

    # save_format = None
    # while save_format is None:
    #     save_format = input(
    #         "Select format to save the blog in: 1. Markdown (Default) 2. ReStructured Text\n"
    #     )

    # if save_format == "2":
    #     output = pypandoc.convert("blog.md", "rst")
    #     if os.path.exists("blog.md"):
    #         os.remove("blog.md")
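For reference, the rendering loop above assumes the blueprint JSON has the shape requested by the Outliner prompt shown in the next file. A minimal illustration of that structure, with placeholder values only:

```python
# Placeholder blueprint matching the keys the loop above reads;
# real values are full sentences extracted from the transcript.
parsed_json = {
    "Thesis Statement": "<thesis sentence>",
    "Supporting Arguments": [
        {
            "Argument": "<supporting claim>",
            "Evidence": ["<quoted statement>", "<quoted statement>"],
        },
    ],
}
```

Each evidence string is pushed into shared memory as a task event, the agent expands it into prose using Google Search plus citation validation, and the results are concatenated under `##` headings in blog.md.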
@@ -0,0 +1,148 @@
import os
import time

import tiktoken
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import MarkdownTextSplitter


class Outliner:
    def __init__(self, transcript_file) -> None:
        with open(transcript_file, "r") as f:
            self.raw_transcript = f.read()
        # instantiate chat model
        self.chat = ChatOpenAI(
            openai_api_key=os.environ.get("OPENAI_API_KEY"),
            temperature=0,
            model="gpt-3.5-turbo",
        )

    def num_tokens_from_string(
        self, string: str, encoding_name="cl100k_base"
    ) -> int:
        """Returns the number of tokens in a text string."""
        encoding = tiktoken.get_encoding(encoding_name)
        num_tokens = len(encoding.encode(string))
        return num_tokens

    def transcript_splitter(self, chunk_size=3000, chunk_overlap=200):
        markdown_splitter = MarkdownTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        transcript_chunks = markdown_splitter.create_documents(
            [self.raw_transcript]
        )
        return transcript_chunks

    def transcript2insights(self, transcript):
        system_template = "You are a helpful assistant that summarizes transcripts of podcasts or lectures."
        system_prompt = SystemMessagePromptTemplate.from_template(
            system_template
        )
        human_template = """From this chunk of a presentation transcript, extract a short list of key insights. \
Skip explaining what you're doing, labeling the insights and writing conclusion paragraphs. \
The insights have to be phrased as statements of facts with no references to the presentation or the transcript. \
Statements have to be full sentences and in terms of words and phrases as close as possible to those used in the transcript. \
Keep as much detail as possible. The transcript of the presentation is delimited in triple backticks.
Desired output format:
- [Key insight #1]
- [Key insight #2]
- [...]
Transcript:
```{transcript}```"""
        human_prompt = HumanMessagePromptTemplate.from_template(human_template)
        chat_prompt = ChatPromptTemplate.from_messages(
            [system_prompt, human_prompt]
        )

        result = self.chat(
            chat_prompt.format_prompt(transcript=transcript).to_messages()
        )

        return result.content

    def create_essay_insights(self, transcript_chunks, verbose=True):
        response = ""
        for i, text in enumerate(transcript_chunks):
            insights = self.transcript2insights(text.page_content)
            response = "\n".join([response, insights])
            if verbose:
                print(
                    f"\nInsights extracted from chunk {i+1}/{len(transcript_chunks)}:\n{insights}"
                )
        return response

    def create_blueprint(self, statements, verbose=True):
        system_template = """You are a helpful AI blogger who writes essays on technical topics."""
        system_prompt = SystemMessagePromptTemplate.from_template(
            system_template
        )

        human_template = """Organize the following list of statements (delimited in triple backticks) to create the outline \
for a blog post in JSON format. The highest level is the most plausible statement as the overarching thesis \
statement of the post, the next layers are statements providing supporting arguments for the thesis statement. \
The last layer are pieces of evidence for each of the supporting arguments, directly quoted from the provided \
list of statements. Use as many of the provided statements as possible. Keep their wording as is without paraphrasing them. \
Retain as many technical details as possible. The thesis statement, supporting arguments, and evidences must be \
full sentences containing claims. Label each layer with the appropriate level title and create the desired JSON output format below. \
Only output the JSON and skip explaining what you're doing:
Desired output format:
{{
    "Thesis Statement": "...",
    "Supporting Arguments": [
        {{
            "Argument": "...",
            "Evidence": [
                "...", "...", "...", ...
            ]
        }},
        {{
            "Argument": "...",
            "Evidence": [
                "...", "...", "...", ...
            ]
        }},
        ...
    ]
}}
Statements:
```{statements}```"""
        human_prompt = HumanMessagePromptTemplate.from_template(human_template)
        chat_prompt = ChatPromptTemplate.from_messages(
            [system_prompt, human_prompt]
        )

        outline = self.chat(
            chat_prompt.format_prompt(statements=statements).to_messages()
        )

        if verbose:
            print(f"\nEssay outline: {outline.content}\n")
        return outline.content

    # @timer_decorator
    def full_transcript2outline_json(self, verbose=True):
        print("\nChunking transcript...")
        transcript_docs = self.transcript_splitter()
        t1 = time.time()
        print("\nExtracting key insights...")
        essay_insights = self.create_essay_insights(transcript_docs, verbose)
        t2 = time.time() - t1
        print("\nCreating essay outline...")
        t1 = time.time()
        blueprint = self.create_blueprint(essay_insights, verbose)
        t3 = time.time() - t1
        if verbose:
            print()
            print(f"Extracted essay insights in {t2:.2f} seconds.")
            print(f"Created essay blueprint in {t3:.2f} seconds.")
        return blueprint
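One design note: `transcript_splitter` measures chunks in characters (the text splitter's default length function), while the model limit is in tokens, so `num_tokens_from_string` can be used to sanity-check that 3000-character chunks fit comfortably in GPT-3.5-turbo's context window. A quick sketch of that check, assuming a transcript.txt alongside the script:

```python
outliner = Outliner("transcript.txt")
chunks = outliner.transcript_splitter()

# cl100k_base is the tokenizer gpt-3.5-turbo uses; English text runs
# roughly 4 characters per token, so 3000 characters is ~750 tokens.
for i, doc in enumerate(chunks):
    n = outliner.num_tokens_from_string(doc.page_content)
    print(f"chunk {i}: {n} tokens")
```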
@@ -0,0 +1,7 @@
pdfminer.six
sentence-transformers
langchain==0.0.332
python-dotenv>=1.0.0
openai>=0.28.0
tiktoken>=0.4.0
sherpa-ai>=0.2.1