Skip to content

Commit

Permalink
Add types agnostic to role (autogenhub#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored May 23, 2024
1 parent 8d1f4ae commit 52f6f79
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/agnext/chat/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Union

from agnext.agent_components.image import Image
from agnext.agent_components.types import FunctionCall


@dataclass(kw_only=True)
class BaseMessage:
# Name of the agent that sent this message
source: str


@dataclass
class TextMessage(BaseMessage):
content: str


@dataclass
class MultiModalMessage(BaseMessage):
content: List[Union[str, Image]]


@dataclass
class FunctionCallMessage(BaseMessage):
content: List[FunctionCall]


@dataclass
class FunctionExecutionResult:
content: str
call_id: str


@dataclass
class FunctionExecutionResultMessage(BaseMessage):
content: List[FunctionExecutionResult]


Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
68 changes: 68 additions & 0 deletions src/agnext/chat/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import List, Optional, Union

from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage
from typing_extensions import Literal


def convert_content_message_to_assistant_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[AssistantMessage]:
match message:
case TextMessage() | FunctionCallMessage():
return AssistantMessage(content=message.content, source=message.source)
case MultiModalMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as AssistantMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
return AssistantMessage(
content="".join([x for x in message.content if isinstance(x, str)]), source=message.source
)


def convert_content_message_to_user_message(
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage],
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error",
) -> Optional[UserMessage]:
match message:
case TextMessage() | MultiModalMessage():
return UserMessage(content=message.content, source=message.source)
case FunctionCallMessage():
if handle_unrepresentable == "error":
raise ValueError("Cannot represent multimodal message as UserMessage")
elif handle_unrepresentable == "ignore":
return None
elif handle_unrepresentable == "try_slice":
# TODO: what is a sliced function call?
raise NotImplementedError("Sliced function calls not yet implemented")


def convert_messages_to_llm_messages(
messages: List[Message], self_name: str, handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error"
) -> List[LLMMessage]:
result: List[LLMMessage] = []
for message in messages:
match message:
case (
TextMessage(_, source=source)
| MultiModalMessage(_, source=source)
| FunctionCallMessage(_, source=source)
) if source == self_name:
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable)
if converted_message_1 is not None:
result.append(converted_message_1)
case (
TextMessage(_, source=source)
| MultiModalMessage(_, source=source)
| FunctionCallMessage(_, source=source)
) if source != self_name:
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case _:
raise AssertionError("unreachable")

return result

0 comments on commit 52f6f79

Please sign in to comment.