Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Separate out streaming route #2111

Merged
merged 11 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/swarm/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def run(self, agent_name: str, message: str):
# print(self.client.get_agent(agent_id).tools)
# TODO: implement with sending multiple messages
if len(history) == 0:
response = self.client.send_message(agent_id=agent_id, message=message, role="user", include_full_message=True)
response = self.client.send_message(agent_id=agent_id, message=message, role="user")
else:
response = self.client.send_messages(agent_id=agent_id, messages=history, include_full_message=True)
response = self.client.send_messages(agent_id=agent_id, messages=history)

# update history
history += response.messages
Expand Down
47 changes: 13 additions & 34 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,10 @@ def send_message(
stream: Optional[bool] = False,
stream_steps: bool = False,
stream_tokens: bool = False,
include_full_message: Optional[bool] = False,
) -> LettaResponse:
raise NotImplementedError

def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> LettaResponse:
def user_message(self, agent_id: str, message: str) -> LettaResponse:
raise NotImplementedError

def create_human(self, name: str, text: str) -> Human:
Expand Down Expand Up @@ -839,7 +838,7 @@ def get_in_context_messages(self, agent_id: str) -> List[Message]:

# agent interactions

def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> LettaResponse:
def user_message(self, agent_id: str, message: str) -> LettaResponse:
"""
Send a message to an agent as a user

Expand All @@ -850,7 +849,7 @@ def user_message(self, agent_id: str, message: str, include_full_message: Option
Returns:
response (LettaResponse): Response from the agent
"""
return self.send_message(agent_id, message, role="user", include_full_message=include_full_message)
return self.send_message(agent_id=agent_id, message=message, role="user")

def save(self):
raise NotImplementedError
Expand Down Expand Up @@ -937,13 +936,13 @@ def get_messages(

def send_message(
self,
agent_id: str,
message: str,
role: str,
agent_id: Optional[str] = None,
name: Optional[str] = None,
stream: Optional[bool] = False,
stream_steps: bool = False,
stream_tokens: bool = False,
include_full_message: bool = False,
) -> Union[LettaResponse, Generator[LettaStreamingResponse, None, None]]:
"""
Send a message to an agent
Expand All @@ -964,17 +963,11 @@ def send_message(
# TODO: figure out how to handle stream_steps and stream_tokens

# When streaming steps is True, stream_tokens must be False
request = LettaRequest(
messages=messages,
stream_steps=stream_steps,
stream_tokens=stream_tokens,
return_message_object=include_full_message,
)
request = LettaRequest(messages=messages)
if stream_tokens or stream_steps:
from letta.client.streaming import _sse_post

request.return_message_object = False
return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", request.model_dump(), self.headers)
return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/stream", request.model_dump(), self.headers)
else:
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers
Expand Down Expand Up @@ -2250,7 +2243,6 @@ def send_messages(
self,
agent_id: str,
messages: List[Union[Message | MessageCreate]],
include_full_message: Optional[bool] = False,
):
"""
Send pre-packed messages to an agent.
Expand All @@ -2270,15 +2262,7 @@ def send_messages(
self.save()

# format messages
messages = self.interface.to_list()
if include_full_message:
letta_messages = messages
else:
letta_messages = []
for m in messages:
letta_messages += m.to_letta_message()

return LettaResponse(messages=letta_messages, usage=usage)
return LettaResponse(messages=messages, usage=usage)

def send_message(
self,
Expand All @@ -2289,7 +2273,6 @@ def send_message(
agent_name: Optional[str] = None,
stream_steps: bool = False,
stream_tokens: bool = False,
include_full_message: Optional[bool] = False,
) -> LettaResponse:
"""
Send a message to an agent
Expand Down Expand Up @@ -2338,16 +2321,13 @@ def send_message(

# format messages
messages = self.interface.to_list()
if include_full_message:
letta_messages = messages
else:
letta_messages = []
for m in messages:
letta_messages += m.to_letta_message()
letta_messages = []
for m in messages:
letta_messages += m.to_letta_message()
sarahwooders marked this conversation as resolved.
Show resolved Hide resolved

return LettaResponse(messages=letta_messages, usage=usage)

def user_message(self, agent_id: str, message: str, include_full_message: Optional[bool] = False) -> LettaResponse:
def user_message(self, agent_id: str, message: str) -> LettaResponse:
"""
Send a message to an agent as a user

Expand All @@ -2359,7 +2339,7 @@ def user_message(self, agent_id: str, message: str, include_full_message: Option
response (LettaResponse): Response from the agent
"""
self.interface.clear()
return self.send_message(role="user", agent_id=agent_id, message=message, include_full_message=include_full_message)
return self.send_message(role="user", agent_id=agent_id, message=message)

def run_command(self, agent_id: str, command: str) -> LettaResponse:
"""
Expand Down Expand Up @@ -2951,7 +2931,6 @@ def get_messages(
after=after,
limit=limit,
reverse=True,
return_message_object=True,
)

def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
Expand Down
2 changes: 1 addition & 1 deletion letta/functions/function_sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# If the function fails, throw an exception


def send_message(self: Agent, message: str) -> Optional[str]:
def send_message(self: "Agent", message: str) -> Optional[str]:
"""
Sends a message to the human user.

Expand Down
34 changes: 11 additions & 23 deletions letta/schemas/letta_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,21 @@

class LettaRequest(BaseModel):
messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.")
run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement

stream_steps: bool = Field(
default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
)
stream_tokens: bool = Field(
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)

return_message_object: bool = Field(
default=False,
description="Set True to return the raw Message object. Set False to return the Message in the format of the Letta API.",
)

# Flags to support the use of AssistantMessage message types

use_assistant_message: bool = Field(
default=False,
description="[Only applicable if return_message_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.",
)

assistant_message_function_name: str = Field(
assistant_message_tool_name: str = Field(
default=DEFAULT_MESSAGE_TOOL,
description="[Only applicable if use_assistant_message is True] The name of the designated message tool.",
description="The name of the designated message tool.",
)
assistant_message_function_kwarg: str = Field(
assistant_message_tool_kwarg: str = Field(
default=DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
description="The name of the message argument in the designated message tool.",
)


class LettaStreamingRequest(LettaRequest):
stream_tokens: bool = Field(
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)
3 changes: 1 addition & 2 deletions letta/schemas/letta_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.utils import json_dumps

Expand All @@ -24,7 +23,7 @@ class LettaResponse(BaseModel):
usage (LettaUsageStatistics): The usage statistics
"""

messages: Union[List[Message], List[LettaMessageUnion]] = Field(..., description="The messages returned by the agent.")
messages: List[LettaMessageUnion] = Field(..., description="The messages returned by the agent.")
usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.")

def __str__(self):
Expand Down
6 changes: 3 additions & 3 deletions letta/schemas/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def to_json(self):
def to_letta_message(
self,
assistant_message: bool = False,
assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API"""

Expand All @@ -156,7 +156,7 @@ def to_letta_message(
for tool_call in self.tool_calls:
# If we're supporting using assistant message,
# then we want to treat certain function calls as a special case
if assistant_message and tool_call.function.name == assistant_message_function_name:
if assistant_message and tool_call.function.name == assistant_message_tool_name:
# We need to unpack the actual message contents from the function call
try:
func_args = json.loads(tool_call.function.arguments)
Expand Down
31 changes: 12 additions & 19 deletions letta/server/rest_api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,8 @@ def __init__(
self,
multi_step=True,
# Related to if we want to try and pass back the AssistantMessage as a special case function
use_assistant_message=False,
assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
# Related to if we expect inner_thoughts to be in the kwargs
inner_thoughts_in_kwargs=True,
inner_thoughts_kwarg=INNER_THOUGHTS_KWARG,
Expand All @@ -287,7 +286,7 @@ def __init__(
self.streaming_chat_completion_mode_function_name = None # NOTE: sadly need to track state during stream
# If chat completion mode, we need a special stream reader to
# turn function argument to send_message into a normal text stream
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg)
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)

self._chunks = deque()
self._event = asyncio.Event() # Use an event to notify when chunks are available
Expand All @@ -300,9 +299,9 @@ def __init__(
self.multi_step_gen_indicator = MessageStreamStatus.done_generation

# Support for AssistantMessage
self.use_assistant_message = use_assistant_message
self.assistant_message_function_name = assistant_message_function_name
self.assistant_message_function_kwarg = assistant_message_function_kwarg
self.use_assistant_message = False # TODO: Remove this
self.assistant_message_tool_name = assistant_message_tool_name
self.assistant_message_tool_kwarg = assistant_message_tool_kwarg

# Support for inner_thoughts_in_kwargs
self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
Expand Down Expand Up @@ -455,17 +454,14 @@ def _process_chunk_to_letta_style(

# If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk
# TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit?
# if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name:
if tool_call.function.name == self.assistant_message_function_name:
# if self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
if tool_call.function.name == self.assistant_message_tool_name:
self.streaming_chat_completion_json_reader.reset()
# early exit to turn into content mode
return None

# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
if (
tool_call.function.arguments
and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name
):
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
# Strip out any extras tokens
cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
Expand Down Expand Up @@ -500,9 +496,6 @@ def _process_chunk_to_letta_style(
)

elif self.inner_thoughts_in_kwargs and tool_call.function:
if self.use_assistant_message:
mattzh72 marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")

processed_chunk = None

if tool_call.function.name:
Expand Down Expand Up @@ -909,13 +902,13 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):

if (
self.use_assistant_message
and function_call.function.name == self.assistant_message_function_name
and self.assistant_message_function_kwarg in func_args
and function_call.function.name == self.assistant_message_tool_name
and self.assistant_message_tool_kwarg in func_args
):
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
assistant_message=func_args[self.assistant_message_function_kwarg],
assistant_message=func_args[self.assistant_message_tool_kwarg],
)
else:
processed_chunk = FunctionCallMessage(
Expand Down
1 change: 0 additions & 1 deletion letta/server/rest_api/routers/openai/assistants/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def list_messages(
before=before_uuid,
order_by="created_at",
reverse=reverse,
return_message_object=True,
)
assert isinstance(json_messages, List)
assert all([isinstance(message, Message) for message in json_messages])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ async def create_chat_completion(
stream_tokens=True,
# Turn on ChatCompletion mode (eg remaps send_message to content)
chat_completion_mode=True,
return_message_object=False,
)

else:
Expand All @@ -86,7 +85,6 @@ async def create_chat_completion(
# Turn streaming OFF
stream_steps=False,
stream_tokens=False,
return_message_object=False,
)
# print(response_messages)

Expand Down
Loading
Loading