forked from autogenhub/autogen
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add types agnostic to role (autogenhub#11)
- Loading branch information
1 parent
8d1f4ae
commit 52f6f79
Showing
2 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import List, Union | ||
|
||
from agnext.agent_components.image import Image | ||
from agnext.agent_components.types import FunctionCall | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class BaseMessage: | ||
# Name of the agent that sent this message | ||
source: str | ||
|
||
|
||
@dataclass | ||
class TextMessage(BaseMessage): | ||
content: str | ||
|
||
|
||
@dataclass | ||
class MultiModalMessage(BaseMessage): | ||
content: List[Union[str, Image]] | ||
|
||
|
||
@dataclass | ||
class FunctionCallMessage(BaseMessage): | ||
content: List[FunctionCall] | ||
|
||
|
||
@dataclass | ||
class FunctionExecutionResult: | ||
content: str | ||
call_id: str | ||
|
||
|
||
@dataclass | ||
class FunctionExecutionResultMessage(BaseMessage): | ||
content: List[FunctionExecutionResult] | ||
|
||
|
||
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import List, Optional, Union | ||
|
||
from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage | ||
from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage | ||
from typing_extensions import Literal | ||
|
||
|
||
def convert_content_message_to_assistant_message( | ||
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage], | ||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", | ||
) -> Optional[AssistantMessage]: | ||
match message: | ||
case TextMessage() | FunctionCallMessage(): | ||
return AssistantMessage(content=message.content, source=message.source) | ||
case MultiModalMessage(): | ||
if handle_unrepresentable == "error": | ||
raise ValueError("Cannot represent multimodal message as AssistantMessage") | ||
elif handle_unrepresentable == "ignore": | ||
return None | ||
elif handle_unrepresentable == "try_slice": | ||
return AssistantMessage( | ||
content="".join([x for x in message.content if isinstance(x, str)]), source=message.source | ||
) | ||
|
||
|
||
def convert_content_message_to_user_message( | ||
message: Union[TextMessage, MultiModalMessage, FunctionCallMessage], | ||
handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", | ||
) -> Optional[UserMessage]: | ||
match message: | ||
case TextMessage() | MultiModalMessage(): | ||
return UserMessage(content=message.content, source=message.source) | ||
case FunctionCallMessage(): | ||
if handle_unrepresentable == "error": | ||
raise ValueError("Cannot represent multimodal message as UserMessage") | ||
elif handle_unrepresentable == "ignore": | ||
return None | ||
elif handle_unrepresentable == "try_slice": | ||
# TODO: what is a sliced function call? | ||
raise NotImplementedError("Sliced function calls not yet implemented") | ||
|
||
|
||
def convert_messages_to_llm_messages( | ||
messages: List[Message], self_name: str, handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error" | ||
) -> List[LLMMessage]: | ||
result: List[LLMMessage] = [] | ||
for message in messages: | ||
match message: | ||
case ( | ||
TextMessage(_, source=source) | ||
| MultiModalMessage(_, source=source) | ||
| FunctionCallMessage(_, source=source) | ||
) if source == self_name: | ||
converted_message_1 = convert_content_message_to_assistant_message(message, handle_unrepresentable) | ||
if converted_message_1 is not None: | ||
result.append(converted_message_1) | ||
case ( | ||
TextMessage(_, source=source) | ||
| MultiModalMessage(_, source=source) | ||
| FunctionCallMessage(_, source=source) | ||
) if source != self_name: | ||
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable) | ||
if converted_message_2 is not None: | ||
result.append(converted_message_2) | ||
case _: | ||
raise AssertionError("unreachable") | ||
|
||
return result |