Skip to content

Commit

Permalink
changed message base and human message, and groq
Browse files Browse the repository at this point in the history
  • Loading branch information
Ishaan Gupta committed Oct 1, 2024
1 parent 7398ad0 commit 2a6d773
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
23 changes: 14 additions & 9 deletions pkgs/swarmauri/swarmauri/llms/concrete/GroqModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List, Optional, Dict, Literal
from typing import List, Optional, Dict, Literal, Any
from groq import Groq
from swarmauri_core.typing import SubclassUnion

Expand Down Expand Up @@ -32,14 +32,19 @@ class GroqModel(LLMBase):
name: str = "gemma-7b-it"
type: Literal["GroqModel"] = "GroqModel"

def _format_messages(
self, messages: List[SubclassUnion[MessageBase]]
) -> List[Dict[str, str]]:
message_properties = ["content", "role", "name"]
formatted_messages = [
message.model_dump(include=message_properties, exclude_none=True)
for message in messages
]
def _format_messages(messages: List[SubclassUnion[MessageBase]]) -> List[Dict[str, Any]]:
formatted_messages = []
for message in messages:
formatted_message = message.model_dump(
include=["content", "role", "name"], exclude_none=True
)

if isinstance(formatted_message["content"], list):
formatted_message["content"] = [
{"type": item["type"], **item} for item in formatted_message["content"]
]

formatted_messages.append(formatted_message)
return formatted_messages

def predict(
Expand Down
4 changes: 2 additions & 2 deletions pkgs/swarmauri/swarmauri/messages/base/MessageBase.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional, Tuple, Literal
from typing import Optional, Tuple, Literal, List, Dict, Union
from pydantic import PrivateAttr, ConfigDict, Field
from swarmauri_core.ComponentBase import ComponentBase, ResourceTypes
from swarmauri_core.messages.IMessage import IMessage

class MessageBase(IMessage, ComponentBase):
content: str
content: Union[str, Dict , List[Dict]]
role: str
model_config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
resource: Optional[str] = Field(default=ResourceTypes.MESSAGE.value, frozen=True)
Expand Down
16 changes: 14 additions & 2 deletions pkgs/swarmauri/swarmauri/messages/concrete/HumanMessage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from typing import Optional, Any, Literal
from typing import Optional, Any, Literal, List
from pydantic import Field
from swarmauri.messages.base.MessageBase import MessageBase
from typing import Union, Dict
from typing_extensions import TypedDict
# Define specific content types
class TextContent(TypedDict):
type: str
text: str

class ImageUrlContent(TypedDict):
type: str
image_url: Union[str, Dict]

contentItem = Union[TextContent, ImageUrlContent]

class HumanMessage(MessageBase):
content: str
content: Optional[Union[str, List[contentItem]]]
role: str = Field(default='user')
name: Optional[str] = None
type: Literal['HumanMessage'] = 'HumanMessage'

0 comments on commit 2a6d773

Please sign in to comment.