Skip to content

Commit

Permalink
feat: custom prompt support for slash commands!
Browse files Browse the repository at this point in the history
resolves #13
  • Loading branch information
meetbryce committed Oct 11, 2024
1 parent c123133 commit ee36898
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 42 deletions.
65 changes: 42 additions & 23 deletions ossai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_since_timeframe_presets,
)

_custom_prompt_cache = {}

def handler_feedback(body):
"""
Expand Down Expand Up @@ -102,22 +103,19 @@ async def handler_tldr_extended_slash_command(
client: WebClient, ack, payload, say, user_id: str
):
await ack()
text = payload.get("text", None)
channel_name = payload["channel_name"]
channel_id = payload["channel_id"]
dm_channel_id = None

dm_channel_id = await get_direct_message_channel_id(client, user_id)
await say(channel=dm_channel_id, text="...")

if text:
return await say("ERROR: custom prompt support coming soon!")

history = await get_channel_history(client, channel_id)
history.reverse()
user = await get_user_context(client, user_id)
title = f"*Summary of #{channel_name}* (last {len(history)} messages)\n"
summarizer = Summarizer()
custom_prompt = payload.get("text", None)
summarizer = Summarizer(custom_prompt=custom_prompt)
summary, run_id = summarizer.summarize_slack_messages(
client,
history,
Expand All @@ -126,7 +124,7 @@ async def handler_tldr_extended_slash_command(
user=user,
)
text, blocks = get_text_and_blocks_for_say(
title=title, run_id=run_id, messages=summary
title=title, run_id=run_id, messages=summary, custom_prompt=custom_prompt
)
return await say(channel=dm_channel_id, text=text, blocks=blocks)

Expand All @@ -136,9 +134,6 @@ async def handler_topics_slash_command(
client: WebClient, ack, payload, say, user_id: str
):
await ack()
text = payload.get("text", None)
if text:
return await say("ERROR: custom prompt support coming soon!")
channel_id = payload["channel_id"]
dm_channel_id = await get_direct_message_channel_id(client, user_id)
await say(channel=dm_channel_id, text="...")
Expand All @@ -149,6 +144,11 @@ async def handler_topics_slash_command(
messages = get_parsed_messages(client, history, with_names=False)
user = await get_user_context(client, user_id)
is_private, channel_name = get_is_private_and_channel_name(client, channel_id)
custom_prompt = payload.get("text", None)
if custom_prompt:
# todo: add support for custom prompts to /tldr
await say(channel=dm_channel_id, text="Sorry, this command doesn't support custom prompts yet so I'm processing your request without it.")

topic_overview, run_id = await analyze_topics_of_history(
channel_name, messages, user=user, is_private=is_private
)
Expand All @@ -164,11 +164,10 @@ async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say)
await ack()
title = "Choose your summary timeframe."
dm_channel_id = await get_direct_message_channel_id(client, payload["user_id"])
text = payload.get("text", None)
if text:
return await say("ERROR: custom prompt support coming soon!")

custom_prompt = payload.get("text", None)

client.chat_postEphemeral(
result = client.chat_postEphemeral(
channel=payload["channel_id"],
user=payload["user_id"],
text=title,
Expand All @@ -184,17 +183,23 @@ async def handler_tldr_since_slash_command(client: WebClient, ack, payload, say)
"text": "Select a date",
"emoji": True,
},
"action_id": "summarize_since",
"action_id": f"summarize_since",
},
],
}
],
)

# get `custom_prompt` into handler_action_summarize_since_date()
key = f"{result['message_ts']}__{payload['user_id']}"
_custom_prompt_cache[key] = custom_prompt
logger.debug(f"Storing `custom_prompt` at {key}: {custom_prompt}")

await say(
channel=dm_channel_id,
text=f'In #{payload["channel_name"]}, choose a date or timeframe to get your summary',
)
return


@catch_errors_dm_user
Expand Down Expand Up @@ -227,14 +232,19 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body):
history = await get_channel_history(client, channel_id, since=since_datetime)
history.reverse()
user = await get_user_context(client, user_id)
summarizer = Summarizer()
custom_prompt = None
if 'container' in body and 'message_ts' in body['container']:
key = f"{body['container']['message_ts']}__{user_id}"
custom_prompt = _custom_prompt_cache.get(key, None)
summarizer = Summarizer(custom_prompt=custom_prompt)
summary, run_id = summarizer.summarize_slack_messages(
client, history, channel_id, feature_name=feature_name, user=user
)
text, blocks = get_text_and_blocks_for_say(
title=f'*Summary of #{channel_name}* since {since_datetime.strftime("%A %b %-d, %Y")} ({len(history)} messages)\n',
run_id=run_id,
messages=summary,
custom_prompt=custom_prompt,
)
# todo: somehow add date/preset choice to langsmith metadata
# feature_name: str -> feature: str || Tuple[str, List(Tuple[str, str])]
Expand All @@ -245,17 +255,26 @@ async def handler_action_summarize_since_date(client: WebClient, ack, body):
async def handler_sandbox_slash_command(
client: WebClient, ack, payload, say, user_id: str
):
text = payload.get("text", None)
if text:
return await say("ERROR: custom prompt support coming soon!")
logger.debug(f"Handling /sandbox command")
await ack()
run_id = str(uuid.uuid4())
run_id = None
text = """-- Better error handling coming soon! Useful summary of content goes here -- (no run id)"""
lines = text.strip().split("\n")
channel_id = payload["channel_id"]
custom_prompt = payload.get("text", None)
summarizer = Summarizer(custom_prompt=custom_prompt)
summary, run_id = summarizer.summarize_slack_messages(
client,
[
{"text": "bacon", "user": user_id},
{"text": "eggs", "user": user_id},
{"text": "spam", "user": user_id},
{"text": "orange juice", "user": user_id},
{"text": "coffee", "user": user_id},
],
channel_id=channel_id,
feature_name="sandbox",
user=user_id,
)
title = "This is a test of the /sandbox command."
text, blocks = get_text_and_blocks_for_say(
title=title, run_id=run_id, messages=lines
title=title, run_id=run_id, messages=summary, custom_prompt=custom_prompt
)
return await say(text=text, blocks=blocks)
18 changes: 14 additions & 4 deletions ossai/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@


class Summarizer:
def __init__(self):
def __init__(self, custom_prompt: str | None = None):
# todo: apply pydantic model
self.config = get_llm_config()
self.model = ChatOpenAI(model=self.config["chat_model"], temperature=self.config["temperature"])
self.parser = StrOutputParser()
self.custom_prompt = custom_prompt

def summarize(
self,
Expand Down Expand Up @@ -62,18 +64,21 @@ def summarize(
So, The assistant needs to speak in {language}.
"""

human_msg = """\
base_human_msg = """\
Please summarize the following chat log to a flat markdown formatted bullet list.
Do not write a line by line summary. Instead, summarize the overall conversation.
Do not include greeting/salutation/polite expressions in summary.
Make the summary easy to read while maintaining a conversational tone and retaining meaning.
Write in conversational English.
{custom_instructions}
{text}
"""

# todo: guard against prompt injection

prompt_template = ChatPromptTemplate.from_messages(
[("system", system_msg), ("user", human_msg)]
[("system", system_msg), ("user", base_human_msg)]
)

chain = prompt_template | self.model | self.parser
Expand All @@ -87,7 +92,12 @@ def summarize(
)
logger.info(f"{langsmith_config=}")
result = chain.invoke(
{"text": text, "language": self.config["language"]}, config=langsmith_config
{
"text": text,
"language": self.config["language"],
"custom_instructions": f"\n\nAdditionally, please follow these specific instructions for this summary:\n{self.custom_prompt}" if self.custom_prompt else "",
},
config=langsmith_config
)
return result, langsmith_config["run_id"]

Expand Down
14 changes: 13 additions & 1 deletion ossai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def parse_message(msg):


def get_text_and_blocks_for_say(
title: str, run_id: Union[uuid.UUID, None], messages: list
title: str, run_id: Union[uuid.UUID, None], messages: list, custom_prompt: str = None
) -> tuple[str, list]:
CHAR_LIMIT = 3000
text = "\n".join(messages)
Expand Down Expand Up @@ -260,6 +260,18 @@ def get_text_and_blocks_for_say(
}
)

if custom_prompt:
blocks.append({
"type": "context",
"elements": [
{
"type": "plain_text",
"text": f"Custom Prompt: {custom_prompt}",
"emoji": True
}
]
})

return text.split("\n")[0], blocks


Expand Down
11 changes: 4 additions & 7 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import pytest
from slack_sdk import WebClient
Expand Down Expand Up @@ -380,7 +380,7 @@ async def test_handler_tldr_extended_slash_command_non_public(
@patch("ossai.handlers.Summarizer")
@patch("ossai.handlers.get_text_and_blocks_for_say")
@patch("ossai.handlers.get_direct_message_channel_id")
@patch("aiohttp.ClientSession.post", new_callable=AsyncMock) # Change this line
@patch("aiohttp.ClientSession.post", new_callable=AsyncMock)
async def test_handler_action_summarize_since_date(
mock_post,
get_direct_message_channel_id_mock,
Expand Down Expand Up @@ -440,6 +440,7 @@ async def test_handler_action_summarize_since_date(
title="*Summary of #general* since Tuesday Feb 21, 2023 (2 messages)\n",
run_id="run_id",
messages="summary",
custom_prompt=None,
)
client.chat_postMessage.assert_called_with(
channel="DM123", text="text", blocks="blocks"
Expand All @@ -456,7 +457,7 @@ async def test_handler_tldr_since_slash_command_happy_path(
):
# Setup
client = AsyncMock(spec=WebClient)
client.chat_postEphemeral = AsyncMock()
client.chat_postEphemeral = MagicMock()
say = AsyncMock()
payload = {"user_id": "U123", "channel_id": "C123", "channel_name": "general"}
get_since_timeframe_presets_mock.return_value = {"foo": "bar"}
Expand Down Expand Up @@ -646,10 +647,6 @@ async def test_handler_sandbox_slash_command_happy_path():

await handler_sandbox_slash_command(client, ack, payload, say, user_id="foo123")
say.assert_called_once()
assert any(
"Useful summary of content goes here" in str(block)
for block in say.call_args[1]["blocks"]
)
assert any(
"This is a test of the /sandbox command." in str(block)
for block in say.call_args[1]["blocks"]
Expand Down
34 changes: 27 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

from ossai import utils


@pytest.fixture
def mock_client():
with patch("ossai.utils.WebClient") as mock_client:

def users_info_side_effect(user):
def users_info_side_effect(*args, **kwargs):
user = kwargs.get("user")
users = {
"U123": {
"ok": True,
Expand All @@ -29,10 +31,23 @@ def users_info_side_effect(user):
"profile": {"real_name": "Taylor Garcia", "title": "CTO"},
},
},
"B123": { # Mock bot user
"ok": True,
"bot": {"name": "Bender Bending Rodríguez"},
},
}
return users.get(user, {"ok": False})

mock_client.users_info.side_effect = users_info_side_effect
mock_client.bots_info.side_effect = users_info_side_effect # Handle bots_info similarly
mock_client.auth_test.return_value = {"bot_id": "B123"}
mock_client.conversations_info.return_value = {
"ok": True,
"channel": {"name": "general", "is_private": False},
}
mock_client.conversations_open.return_value = {"channel": {"id": "C123"}}
mock_client.conversations_history.return_value = {"messages": [{"bot_id": "B123"}]}
mock_client.team_info.return_value = {"ok": True, "team": {"name": "Workspace"}}
yield mock_client


Expand Down Expand Up @@ -281,27 +296,32 @@ def test_get_since_timeframe_presets_values(mock_gmtime):
def test_get_text_and_blocks_for_say_block_size():
title = "Test Title"
run_id = uuid.uuid4()

# Create a message that's longer than 3000 characters
long_message = "A" * 4000
messages = [long_message]

_, blocks = utils.get_text_and_blocks_for_say(title, run_id, messages)

# Check that the title is in the first block
assert blocks[0]['text']['text'] == title
assert blocks[0]["text"]["text"] == title

# Check that each block's text is no longer than 3000 characters
for block in blocks[1:-1]: # Exclude the first (title) and last (buttons) blocks
assert len(block['text']['text']) <= 3000, f"Block text exceeds 3000 characters: {len(block['text']['text'])}"
assert (
len(block["text"]["text"]) <= 3000
), f"Block text exceeds 3000 characters: {len(block['text']['text'])}"

# Check that all of the original message is included
combined_text = ''.join(block['text']['text'] for block in blocks[1:-1])
combined_text = "".join(block["text"]["text"] for block in blocks[1:-1])
assert combined_text == long_message

# Check that the last block contains the buttons
assert blocks[-1]['type'] == 'actions'
assert len(blocks[-1]['elements']) == 3 # Three buttons
assert blocks[-1]["type"] == "actions"
assert len(blocks[-1]["elements"]) == 3 # Three buttons


# todo: test get_text_and_blocks_for_say with custom prompt


def test_main_as_script():
Expand Down

0 comments on commit ee36898

Please sign in to comment.