Skip to content

Commit

Permalink
[Hotfix] Address URL Formatting Issue Did Not Get Fixed During 0.3.0 (#…
Browse files Browse the repository at this point in the history
…312)

### Description
- A previous PR was not tested thoroughly and contained a bug, so it did not
actually address the problem.
- Previous PR: #306 
- The goal is to use a different system prompt for each client (Slack vs. web
app)
- The previous PR was modifying the user's question prompt not system
prompt

### Related Issue
closes #301 
This issue is supposed to be fixed in 0.3.0, but due to inadequate
testing during that PR merge, it wasn't caught that the PR didn't
address the problem. The bug is re-discovered after 0.3.0 release.

### Tests
1. Sent a request on front end
2. Got URL embedded into the keywords
<img width="1456" alt="image"
src="https://github.com/astronomer/ask-astro/assets/26350341/1f93cecf-6ff4-40db-90bd-6e055330e3ce">
3. Checked Langsmith; metadata has client = webapp
<img width="904" alt="image"
src="https://github.com/astronomer/ask-astro/assets/26350341/df4c2e26-ddeb-4b33-b1f7-f32f7009bafa">
4. Checked Langsmith; the correct system prompt is used
<img width="1044" alt="image"
src="https://github.com/astronomer/ask-astro/assets/26350341/3245201d-73ed-4ebc-b4fe-7deff770cc4a">

[Langsmith
link](https://smith.langchain.com/public/37f44ef4-af11-45a5-8246-e7f2b1997b58/r)
  • Loading branch information
davidgxue authored Mar 7, 2024
1 parent 09b0df1 commit cc76b19
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 18 deletions.
2 changes: 1 addition & 1 deletion airflow/include/streamlit/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def write_response(text: str):
question to see the prompt."""
)

prompt_path = Path(__file__).parent.joinpath("combine_docs_chat_prompt.txt")
prompt_path = Path(__file__).parent.joinpath("combine_docs_sys_prompt_webapp.txt")
prompt = prompt_path.read_text()

st.write(f"The prompt text can be found in the file __`{prompt_path}`__")
Expand Down
42 changes: 36 additions & 6 deletions api/ask_astro/chains/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@
MULTI_QUERY_RETRIEVER_TEMPERATURE,
)

with open("ask_astro/templates/combine_docs_chat_prompt.txt") as system_prompt_fd:
"""Load system prompt template from a file and structure it."""
messages = [
SystemMessagePromptTemplate.from_template(system_prompt_fd.read()),
# Build the chat-prompt message sequence for web-app requests:
# file-based system prompt, then prior conversation, then the current question.
with open("ask_astro/templates/combine_docs_sys_prompt_webapp.txt") as webapp_system_prompt_fd:
    """Load system prompt template for webapp messages"""
    # NOTE(review): the bare string above is an expression statement, not a real
    # docstring (with-blocks have none); kept byte-identical, but a `#` comment
    # would be the conventional form.
    webapp_messages = [
        # System prompt text is read once at import time and baked into the template.
        SystemMessagePromptTemplate.from_template(webapp_system_prompt_fd.read()),
        # Placeholder filled from the "messages" key when the chain is invoked.
        MessagesPlaceholder(variable_name="messages"),
        # The user's current question, substituted from the "question" key.
        HumanMessagePromptTemplate.from_template("{question}"),
    ]

# Build the chat-prompt message sequence for Slack requests: identical structure
# to the web-app variant, but sourced from the Slack-specific system prompt file.
with open("ask_astro/templates/combine_docs_sys_prompt_slack.txt") as slack_system_prompt_fd:
    """Load system prompt template for slack messages"""
    # NOTE(review): the bare string above is an expression statement, not a real
    # docstring; kept byte-identical for this doc-only pass.
    slack_messages = [
        # Slack system prompt text is read once at import time.
        SystemMessagePromptTemplate.from_template(slack_system_prompt_fd.read()),
        # Placeholder filled from the "messages" key at invocation time.
        MessagesPlaceholder(variable_name="messages"),
        # The user's current question, substituted from the "question" key.
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
Expand Down Expand Up @@ -92,7 +100,29 @@
)

# Set up a ConversationalRetrievalChain to generate answers using the retriever.
answer_question_chain = ConversationalRetrievalChain(
# Conversational QA chain for web-app clients. Differs from the Slack chain
# only in the system prompt used by the combine-docs step (webapp_messages).
webapp_answer_question_chain = ConversationalRetrievalChain(
    # Retriever with LLM-chain-filter compression, defined earlier in this module.
    retriever=llm_chain_filter_compression_retriever,
    # Include the retrieved source documents in the chain's result payload.
    return_source_documents=True,
    # Condenses the chat history + new question into a standalone question.
    question_generator=LLMChain(
        llm=AzureChatOpenAI(
            **AzureOpenAIParams.us_east2,
            deployment_name=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME,
            temperature=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_TEMPERATURE,
        ),
        prompt=CONDENSE_QUESTION_PROMPT,
    ),
    # Answers the condensed question from the retrieved documents.
    combine_docs_chain=load_qa_chain(
        AzureChatOpenAI(
            **AzureOpenAIParams.us_east2,
            deployment_name=CONVERSATIONAL_RETRIEVAL_LOAD_QA_CHAIN_DEPLOYMENT_NAME,
            temperature=CONVERSATIONAL_RETRIEVAL_LOAD_QA_CHAIN_TEMPERATURE,
        ),
        # "stuff" packs all retrieved docs into a single prompt context.
        chain_type="stuff",
        # Web-app-specific system prompt — the point of this hotfix.
        prompt=ChatPromptTemplate.from_messages(webapp_messages),
    ),
)

slack_answer_question_chain = ConversationalRetrievalChain(
retriever=llm_chain_filter_compression_retriever,
return_source_documents=True,
question_generator=LLMChain(
Expand All @@ -110,6 +140,6 @@
temperature=CONVERSATIONAL_RETRIEVAL_LOAD_QA_CHAIN_TEMPERATURE,
),
chain_type="stuff",
prompt=ChatPromptTemplate.from_messages(messages),
prompt=ChatPromptTemplate.from_messages(slack_messages),
),
)
31 changes: 21 additions & 10 deletions api/ask_astro/services/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def answer_question(request: AskAstroRequest) -> None:
try:
from langchain import callbacks

from ask_astro.chains.answer_question import answer_question_chain
from ask_astro.chains.answer_question import slack_answer_question_chain, webapp_answer_question_chain

# First, mark the request as in_progress and add it to the database
request.status = "in_progress"
Expand All @@ -93,16 +93,27 @@ async def answer_question(request: AskAstroRequest) -> None:

# Run the question answering chain
with callbacks.collect_runs() as cb:
result = await asyncio.to_thread(
lambda: answer_question_chain(
{
"question": request.prompt,
"chat_history": [],
"messages": request.messages,
},
metadata={"request_id": str(request.uuid)},
if request.client == "slack":
result = await asyncio.to_thread(
lambda: slack_answer_question_chain(
{
"question": request.prompt,
"chat_history": request.messages,
},
metadata={"request_id": str(request.uuid), "client": str(request.client)},
)
)
else:
result = await asyncio.to_thread(
lambda: webapp_answer_question_chain(
{
"question": request.prompt,
"chat_history": [],
"messages": request.messages,
},
metadata={"request_id": str(request.uuid), "client": str(request.client)},
)
)
)
request.langchain_run_id = cb.traced_runs[0].id

logger.info("Question answering chain finished with result %s", result)
Expand Down
17 changes: 17 additions & 0 deletions api/ask_astro/templates/combine_docs_sys_prompt_webapp.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
You are Ask Astro, a friendly and helpful bot.
Only answer questions related to Astronomer, the Astro platform and Apache Airflow.
If the question relates to pricing, licensing, or commercial usage, ask the user to contact support at www.astronomer.io/contact.
If you don't know the answer, just say that you don't know and ask the user to contact support, don't try to make up an answer.
If the supplied context below does not have sufficient information to help answer the question, make a note when answering to let the user know that the answer may contain false information and the user should contact support to verify.
Be concise and precise in your answers and do not apologize.
Format your response using Markdown syntax.
Surround text with SINGLE * to format it in bold or provide emphasis. Examples: GOOD: *This is bold!*. BAD: **This is bold!**.
Surround text with _ to format it in italic. Example: _This is italic._
Use the • character for unnumbered lists.
Use the ` character to surround inline code. Example: This is a sentence with some `inline *code*` in it.
Use ``` to surround multi-line code blocks. Do not specify a language in code blocks. Examples: GOOD: ```This is a code block\nAnd it is multi-line``` BAD: ```python print("Hello world!")```.
Format links using this format: [Text to display](URL). Examples: GOOD: [This message **is** a link](https://www.example.com). BAD: <https://www.example.com|This message **is** a link>.
12 character words that start with "<@U" and end with ">" are usernames. Example: <@U024BE7LH>.
Use the following pieces of context to answer the users question.
----------------
{context}
2 changes: 1 addition & 1 deletion tests/api/ask_astro/chains/test_answer_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def test_system_prompt_loading():
    """Test that the webapp system prompt template round-trips correctly.

    Reads the raw template file, builds a SystemMessagePromptTemplate from it,
    and asserts the template text is preserved unchanged.
    """
    # Fix: the diff render left both the old (combine_docs_chat_prompt.txt) and
    # new template paths as consecutive `with` headers — a syntax error. Keep
    # only the updated webapp prompt path introduced by this commit.
    with open("ask_astro/templates/combine_docs_sys_prompt_webapp.txt") as fd:
        expected_template = fd.read()
    template_instance = SystemMessagePromptTemplate.from_template(expected_template)
    assert template_instance.prompt.template == expected_template

0 comments on commit cc76b19

Please sign in to comment.