Commit

Merge pull request #167 from biocypher/multimodal

Multimodal model support

slobentanzer committed Jun 21, 2024
2 parents 113c874 + b394b38 commit b86a4df

Showing 4 changed files with 265 additions and 13 deletions.
130 changes: 117 additions & 13 deletions biochatter/llm_connect.py
@@ -11,6 +11,7 @@
    st = None

from abc import ABC, abstractmethod
import base64
from typing import Optional
import json
import logging
@@ -20,6 +21,7 @@
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
import nltk
import openai
import urllib.parse

from ._stats import get_stats
from .rag_agent import RagAgent
@@ -56,6 +58,12 @@
}


# Encode a local image file as a base64 string for embedding in a data URL
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
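
For instance, the encoded payload is what `append_image_message` below embeds in a data URL. A minimal usage sketch (the path points to the test image added in this PR):

```python
# Minimal sketch: embed a local image as an OpenAI-style data URL.
b64_payload = encode_image("test/figure_panel.jpg")
image_url = f"data:image/jpeg;base64,{b64_payload}"
```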


class Conversation(ABC):
    """
@@ -121,34 +129,88 @@ def get_prompts(self):
    def set_prompts(self, prompts: dict):
        self.prompts = prompts

-    def append_ai_message(self, message: str):
    def append_ai_message(self, message: str) -> None:
        """
        Add a message from the AI to the conversation.

        Args:
            message (str): The message from the AI.
        """
        self.messages.append(
            AIMessage(
                content=message,
            ),
        )

-    def append_system_message(self, message: str):
    def append_system_message(self, message: str) -> None:
        """
        Add a system message to the conversation.

        Args:
            message (str): The system message.
        """
        self.messages.append(
            SystemMessage(
                content=message,
            ),
        )

-    def append_ca_message(self, message: str):
    def append_ca_message(self, message: str) -> None:
        """
        Add a message to the correcting agent conversation.

        Args:
            message (str): The message to the correcting agent.
        """
        self.ca_messages.append(
            SystemMessage(
                content=message,
            ),
        )

-    def append_user_message(self, message: str):
    def append_user_message(self, message: str) -> None:
        """
        Add a message from the user to the conversation.

        Args:
            message (str): The message from the user.
        """
        self.messages.append(
            HumanMessage(
                content=message,
            ),
        )

    def append_image_message(
        self, message: str, image_url: str, local: bool = False
    ) -> None:
        """
        Add a user message with an image to the conversation. In addition to
        respecting the `local` flag, checks whether the image URL points to a
        local file path; local images are encoded as base64 strings before
        being passed to the LLM.

        Args:
            message (str): The message from the user.

            image_url (str): The URL of the image, or a local file path.

            local (bool): Whether the image is local. If True, it will be
                encoded as a base64 string to be passed to the LLM.
        """
        parsed_url = urllib.parse.urlparse(image_url)
        if local or not parsed_url.netloc:
            image_url = f"data:image/jpeg;base64,{encode_image(image_url)}"

        self.messages.append(
            HumanMessage(
                content=[
                    {"type": "text", "text": message},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            ),
        )
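
The local-path check above leans on `urllib.parse`: bare file paths parse with an empty netloc, while web URLs carry their host there. A quick sketch (the paths are placeholders):

```python
from urllib.parse import urlparse

# A bare file path has no network location...
assert urlparse("my/local/image.jpg").netloc == ""
# ...while a web URL carries its host in netloc.
assert urlparse("https://example.com/image.jpg").netloc == "example.com"
```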

    def setup(self, context: str):
        """
        Set up the conversation with general prompts and a context.
@@ -178,8 +240,27 @@ def setup_data_input_tool(self, df, input_file_name: str):
        msg = self.prompts["tool_prompts"][tool_name].format(df=df)
        self.append_system_message(msg)

-    def query(self, text: str):
-        self.append_user_message(text)
    def query(self, text: str, image_url: str = None) -> tuple[str, dict, str]:
        """
        The main workflow for querying the LLM API. Appends the most recent
        query to the conversation, optionally injects context from the RAG
        agent, and runs the primary query method of the child class.

        Args:
            text (str): The user query.

            image_url (str): The URL of an image to include in the
                conversation. Optional; only supported for models with vision
                capabilities.

        Returns:
            tuple: A tuple containing the response from the API, the token
                usage information, and the correction if necessary/desired.
        """
        if not image_url:
            self.append_user_message(text)
        else:
            self.append_image_message(text, image_url)

        self._inject_context(text)

@@ -587,6 +668,19 @@ def _primary_query(self):

    def _create_history(self):
        history = []
        # extract text components from message contents
        msg_texts = [
            m.content[0]["text"] if isinstance(m.content, list) else m.content
            for m in self.messages
        ]

        # check if last message is an image message
        is_image_message = False
        if isinstance(self.messages[-1].content, list):
            is_image_message = (
                self.messages[-1].content[1]["type"] == "image_url"
            )

        # find location of last AI message (if any)
        last_ai_message = None
        for i, m in enumerate(self.messages):
@@ -599,15 +693,15 @@ def _create_history(self):
                {
                    "role": "user",
                    "content": "\n".join(
-                        [m.content for m in self.messages[:last_ai_message]]
                        [m for m in msg_texts[:last_ai_message]]
                    ),
                }
            )
            # then append the last AI message
            history.append(
                {
                    "role": "assistant",
-                    "content": self.messages[last_ai_message].content,
                    "content": msg_texts[last_ai_message],
                }
            )

@@ -617,10 +711,7 @@
                {
                    "role": "user",
                    "content": "\n".join(
-                        [
-                            m.content
-                            for m in self.messages[last_ai_message + 1 :]
-                        ]
                        [m for m in msg_texts[last_ai_message + 1 :]]
                    ),
                }
            )
@@ -631,10 +722,21 @@
            history.append(
                {
                    "role": "user",
-                    "content": "\n".join([m.content for m in self.messages]),
                    "content": "\n".join([m for m in msg_texts[:]]),
                }
            )

        # if the last message is an image message, add the image to the history
        if is_image_message:
            history[-1]["content"] = [
                {"type": "text", "text": history[-1]["content"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": self.messages[-1].content[1]["image_url"]["url"]
                    },
                },
            ]

        return history
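
For illustration, a hypothetical example (not from the source) of the history this produces when the conversation ends in an image message: text parts collapse into OpenAI-style role entries, and the final entry becomes a mixed text/image content list.

```python
# Hypothetical _create_history output; the data URL is truncated for display.
history = [
    {"role": "user", "content": "System message\nHuman message"},
    {"role": "assistant", "content": "AI message"},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Is the description accurate?"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,..."},
            },
        ],
    },
]
```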

    def _correct_response(self, msg: str):
@@ -930,6 +1032,8 @@ def __init__(
            prompts (dict): A dictionary of prompts to use for the conversation.

            correct (bool): Whether to correct the model output.

            split_correction (bool): Whether to correct the model output by
                splitting the output into sentences and correcting each
                sentence individually.
46 changes: 46 additions & 0 deletions docs/chat.md
@@ -59,3 +59,49 @@ but requires the provision of an API key to the OpenAI API. To do this, you can
designate the `OPENAI_API_KEY` variable in your environment directly (`export
OPENAI_API_KEY=sk-...`) by adding it to your shell configuration (e.g., the
`zshrc`).

## Multimodal models - Text and image

We support multimodal queries in models that offer these capabilities,
following the blueprint of the OpenAI API. We can either add an
image-containing message to the conversation using the `append_image_message`
method, or pass an image URL directly to the `query` method:

```python
# Either: append an image message
conversation.append_image_message(
    message="Here is an attached image",
    image_url="https://example.com/image.jpg"
)

# Or: query with an image included
msg, token_usage, correction = conversation.query(
    "What's in this image?",
    image_url="https://example.com/image.jpg"
)
```

### Using local images

Following OpenAI's recommendations, we can pass local images as
base64-encoded strings. This is enabled by setting the `local` flag to `True`
in the `append_image_message` method:

```python
conversation.append_image_message(
    message="Here is an attached image",
    image_url="my/local/image.jpg",
    local=True
)
```

We also support local images in the `query` method by inspecting the network
location (netloc) of the image URL. If the netloc is empty, we assume that the
image is local and read it as a base64-encoded string:

```python
msg, token_usage, correction = conversation.query(
    "What's in this image?",
    image_url="my/local/image.jpg"
)
```
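
Under the hood, this is equivalent to the following sketch, which mirrors the logic in `biochatter/llm_connect.py` (`resolve_image_url` is a hypothetical helper name; note that the MIME type of the data URL is currently fixed to `image/jpeg`):

```python
import base64
import urllib.parse

def resolve_image_url(image_url: str, local: bool = False) -> str:
    # Hypothetical helper mirroring append_image_message: a bare file path
    # has an empty netloc, so it is treated as local and embedded as a
    # base64 data URL.
    if local or not urllib.parse.urlparse(image_url).netloc:
        with open(image_url, "rb") as image_file:
            encoded = base64.b64encode(image_file.read()).decode("utf-8")
        return f"data:image/jpeg;base64,{encoded}"
    return image_url
```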
Binary file added test/figure_panel.jpg
102 changes: 102 additions & 0 deletions test/test_llm_connect.py
@@ -1,3 +1,4 @@
import base64
from unittest.mock import Mock, patch
import os

@@ -324,3 +325,104 @@ def test_multiple_cycles_of_ai_and_human(xinference_conversation):
"role": "user",
"content": "System message\nHuman message",
}


@pytest.mark.skip(reason="Live test for development purposes")
def test_append_local_image_gpt():
    convo = GptConversation(
        model_name="gpt-4o",
        prompts={},
        correct=False,
        split_correction=False,
    )
    convo.set_api_key(api_key=os.getenv("OPENAI_API_KEY"), user="test_user")

    convo.append_system_message(
        "You are an editorial assistant to a journal in biomedical science."
    )

    convo.append_image_message(
        message=(
            "This text describes the attached image: "
            "Live confocal imaging of liver stage P. berghei expressing UIS4-mCherry and cytoplasmic GFP reveals different morphologies of the LS-TVN: elongated membrane clusters (left), vesicles in the host cell cytoplasm (center), and a thin tubule protruding from the PVM (right). Live imaging was performed 20 h after infection of hepatoma cells. Features are marked with white arrowheads."
        ),
        image_url="test/figure_panel.jpg",
        local=True,
    )

    result, _, _ = convo.query("Is the description accurate?")
    assert "yes" in result.lower()


@pytest.mark.skip(reason="Live test for development purposes")
def test_local_image_query_gpt():
    convo = GptConversation(
        model_name="gpt-4o",
        prompts={},
        correct=False,
        split_correction=False,
    )
    convo.set_api_key(api_key=os.getenv("OPENAI_API_KEY"), user="test_user")

    convo.append_system_message(
        "You are an editorial assistant to a journal in biomedical science."
    )

    result, _, _ = convo.query(
        "Does this text describe the attached image: Live confocal imaging of liver stage P. berghei expressing UIS4-mCherry and cytoplasmic GFP reveals different morphologies of the LS-TVN: elongated membrane clusters (left), vesicles in the host cell cytoplasm (center), and a thin tubule protruding from the PVM (right). Live imaging was performed 20 h after infection of hepatoma cells. Features are marked with white arrowheads.",
        image_url="test/figure_panel.jpg",
    )
    assert "yes" in result.lower()


@pytest.mark.skip(reason="Live test for development purposes")
def test_append_online_image_gpt():
    convo = GptConversation(
        model_name="gpt-4o",
        prompts={},
        correct=False,
        split_correction=False,
    )
    convo.set_api_key(api_key=os.getenv("OPENAI_API_KEY"), user="test_user")

    convo.append_image_message(
        "This is a picture from the internet.",
        image_url="https://upload.wikimedia.org/wikipedia/commons/8/8f/The-Transformer-model-architecture.png",
    )

    result, _, _ = convo.query("What does this picture show?")
    assert "transformer" in result.lower()


@pytest.mark.skip(reason="Live test for development purposes")
def test_online_image_query_gpt():
    convo = GptConversation(
        model_name="gpt-4o",
        prompts={},
        correct=False,
        split_correction=False,
    )
    convo.set_api_key(api_key=os.getenv("OPENAI_API_KEY"), user="test_user")

    result, _, _ = convo.query(
        "What does this picture show?",
        image_url="https://upload.wikimedia.org/wikipedia/commons/8/8f/The-Transformer-model-architecture.png",
    )
    assert "transformer" in result.lower()


@pytest.mark.skip(reason="Live test for development purposes")
def test_local_image_query_xinference():
    url = "http://localhost:9997"
    convo = XinferenceConversation(
        base_url=url,
        prompts={},
        correct=False,
    )
    assert convo.set_api_key()

    result, _, _ = convo.query(
        "Does this text describe the attached image: Live confocal imaging of liver stage P. berghei expressing UIS4-mCherry and cytoplasmic GFP reveals different morphologies of the LS-TVN: elongated membrane clusters (left), vesicles in the host cell cytoplasm (center), and a thin tubule protruding from the PVM (right). Live imaging was performed 20 h after infection of hepatoma cells. Features are marked with white arrowheads.",
        image_url="test/figure_panel.jpg",
    )
    assert isinstance(result, str)
