feat: Modifying schema to support multi modal inputs. #1673
Changes from 2 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -156,6 +156,7 @@ class MessageModel(Base): | |
# openai info | ||
role = Column(String, nullable=False) | ||
text = Column(String) # optional: can be null if function call | ||
mm_content = Column(JSON) # optional: multi-modal input | ||
model = Column(String) # optional: can be null if LLM backend doesn't require specifying | ||
name = Column(String) # optional: multi-agent only | ||
|
||
|
@@ -192,6 +193,7 @@ def to_record(self): | |
role=self.role, | ||
name=self.name, | ||
text=self.text, | ||
Review comment: Similarly, should delete/deprecate |
||
mm_content=self.mm_content, | ||
model=self.model, | ||
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None, | ||
tool_calls=self.tool_calls, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,8 @@ class Message(BaseMessage): | |
id: str = BaseMessage.generate_id_field() | ||
role: MessageRole = Field(..., description="The role of the participant.") | ||
text: str = Field(..., description="The text of the message.") | ||
# Field mm_content is only used when role is 'user'. It needs to be mapped to MultiModalMessage | ||
Review comment: Same comment here - deprecating |
||
mm_content: List[dict] = Field(None, description="Multi modal content entered by the user.") | ||
user_id: str = Field(None, description="The unique identifier of the user.") | ||
agent_id: str = Field(None, description="The unique identifier of the agent.") | ||
model: Optional[str] = Field(None, description="The model used to make the function call.") | ||
|
@@ -223,8 +225,9 @@ def to_openai_dict( | |
|
||
elif self.role == "user": | ||
assert all([v is not None for v in [self.text, self.role]]), vars(self) | ||
content = self.mm_content if self.mm_content is not None else self.text | ||
Review comment: Once we do the above (replace |
||
openai_message = { | ||
- "content": self.text, | ||
+ "content": content, | ||
"role": self.role, | ||
} | ||
# Optional field, do not include if null | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,13 +8,15 @@ class SystemMessage(BaseModel): | |
role: str = "system" | ||
name: Optional[str] = None | ||
|
||
class MultiModalMessage(BaseModel): | ||
type: str | ||
image_url: str | ||
|
||
class UserMessage(BaseModel): | ||
- content: Union[str, List[str]] | ||
+ content: Union[str, List[MultiModalMessage]] | ||
role: str = "user" | ||
name: Optional[str] = None | ||
|
||
|
||
class ToolCallFunction(BaseModel): | ||
name: str | ||
arguments: str | ||
|
Review comment:
Here we should probably delete
text = Column(String) # optional: can be null if function call
and replace with
content = Column(JSON) # optional: multi-modal input
which in the pydantic model is
Optional[Union[str, List[MultiModalMessagePart]]]
but in the database itself is stored as an optional JSON field
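
A minimal sketch of what this suggestion could look like, assuming a single `content` field that is either a plain string, a list of parts, or null. To stay runnable on its own it uses plain dataclasses and `json` as stand-ins for the pydantic model and the SQLAlchemy `JSON` column (which would normally handle serialization itself). `MultiModalMessagePart` takes its fields from `MultiModalMessage` in the diff, and the helpers `content_to_json` / `content_from_json` are hypothetical names that are not part of this PR.

```python
import json
from dataclasses import asdict, dataclass
from typing import List, Optional, Union


@dataclass
class MultiModalMessagePart:
    # Fields mirror MultiModalMessage from the diff (assumption: same shape).
    type: str
    image_url: str


# The type proposed in the comment for the unified field.
Content = Optional[Union[str, List[MultiModalMessagePart]]]


def content_to_json(content: Content) -> Optional[str]:
    """Hypothetical helper: serialize content for an optional JSON column."""
    if content is None:
        return None  # stored as NULL, e.g. when the message is a function call
    if isinstance(content, str):
        return json.dumps(content)  # plain text stays a JSON string
    return json.dumps([asdict(part) for part in content])


def content_from_json(raw: Optional[str]) -> Content:
    """Hypothetical helper: rebuild the typed value from the stored JSON."""
    if raw is None:
        return None
    value = json.loads(raw)
    if isinstance(value, str):
        return value
    return [MultiModalMessagePart(**item) for item in value]


# Usage: both shapes round-trip through the same column.
text_only = content_from_json(content_to_json("hello"))
with_image = content_from_json(
    content_to_json([MultiModalMessagePart(type="image_url", image_url="placeholder-url")])
)
```

With one field covering both cases, the `mm_content if ... else text` fallback in `to_openai_dict` would no longer be needed, since `content` could be passed through as is.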