Skip to content

Commit d772fc9

Browse files
committed
feat: add multimodal io rails support
feat: handle list type in event text processing feat: add support for Together engine fix(logging): ensure msg.content is a string in callbacks fix(llm): ensure all responses and key items are strings feat(prompts): add general task for OpenAI models fix(prompts): ensure model name is lowercase fix(llm): improve multimodal content handling fix(taskmanager): handle multimodal content in message processing feat: add tests for multimodality feat: add example configs for vision content safety
1 parent 4114202 commit d772fc9

File tree

13 files changed

+433
-16
lines changed

13 files changed

+433
-16
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: gpt-4o-mini
5+
6+
- type: vision_rails
7+
engine: openai
8+
model: gpt-4o-mini
9+
rails:
10+
input:
11+
flows:
12+
- content safety check input $model=vision_rails
13+
dialog:
14+
single_call:
15+
enabled: False
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
prompts:
2+
- task: content_safety_check_input $model=vision_rails
3+
messages:
4+
- type: system
5+
content: |
6+
You are an AI assistant that evaluates if a user's message contains unsafe content.
7+
Carefully examine any text and images in the user's message.
8+
9+
{%- if categories is not defined -%}
10+
{%- set categories = {
11+
"S1": "Violent Crimes.",
12+
"S2": "Non-Violent Crimes.",
13+
"S3": "Sex Crimes.",
14+
"S4": "Child Exploitation.",
15+
"S5": "Defamation.",
16+
"S6": "Specialized Advice.",
17+
"S7": "Privacy.",
18+
"S8": "Intellectual Property.",
19+
"S10": "Hate.",
20+
"S11": "Self-Harm.",
21+
"S12": "Sexual Content.",
22+
"S13": "Elections."
23+
} -%}
24+
{%- endif -%}
25+
26+
Task: Check if there is unsafe content in "User" messages in conversations according to our safety policy with the below categories.
27+
28+
<BEGIN UNSAFE CONTENT CATEGORIES>
29+
{%- for key, desc in categories.items() -%}
30+
{%- if key not in excluded_category_keys -%}
31+
{{ key }}: {{ desc }}
32+
{%- endif -%}
33+
{%- endfor -%}
34+
<END UNSAFE CONTENT CATEGORIES>
35+
36+
IMPORTANT:
37+
Provide your safety assessment for the user message:
38+
- First line must read 'safe' or 'unsafe' and nothing more
39+
- If unsafe, a second line must include a comma-separated list of violated categories.
40+
- type: user
41+
content: "{{ user_input }}"
42+
43+
stop: ["<|eot_id|>", "<|eom_id|>"]
44+
output_parser: is_content_safe
45+
max_tokens: 200

nemoguardrails/actions/llm/generation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import re
2222
import sys
2323
import threading
24-
import uuid
25-
from ast import literal_eval
2624
from functools import lru_cache
2725
from time import time
2826
from typing import Callable, List, Optional, cast
@@ -374,6 +372,12 @@ async def generate_user_intent(
374372
# We search for the most relevant similar user utterance
375373
examples = ""
376374
potential_user_intents = []
375+
if isinstance(event["text"], list):
376+
text = " ".join(
377+
[item["text"] for item in event["text"] if item["type"] == "text"]
378+
)
379+
else:
380+
text = event["text"]
377381

378382
if self.user_message_index is not None:
379383
threshold = None
@@ -384,7 +388,7 @@ async def generate_user_intent(
384388
)
385389

386390
results = await self.user_message_index.search(
387-
text=event["text"], max_results=5, threshold=threshold
391+
text=text, max_results=5, threshold=threshold
388392
)
389393

390394
# If the option to use only the embeddings is activated, we take the first
@@ -409,7 +413,7 @@ async def generate_user_intent(
409413
)
410414
else:
411415
results = await self.user_message_index.search(
412-
text=event["text"], max_results=5
416+
text=text, max_results=5
413417
)
414418
# We add these in reverse order so the most relevant is towards the end.
415419
for result in reversed(results):

nemoguardrails/llm/filters.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def co_v2(
100100
history += f' bot say "{event["script"]}"\n'
101101

102102
elif event["type"] == "StartTool":
103-
s = f' await {event["flow_name"]}'
103+
s = f" await {event['flow_name']}"
104104
for k, v in event.items():
105105
if k in [
106106
"type",
@@ -275,13 +275,19 @@ def verbose_v1(colang_history: str) -> str:
275275

276276

277277
def to_chat_messages(events: List[dict]) -> str:
278-
"""Filter that turns an array of events into a sequence of user/assistant messages."""
278+
"""Filter that turns an array of events into a sequence of user/assistant messages.
279+
280+
Properly handles multimodal content by preserving the structure when the content
281+
is in the format of a Message object with potential image_url content.
282+
"""
279283
messages = []
280284
for event in events:
281285
if event["type"] == "UserMessage":
282-
messages.append({"type": "user", "content": event["text"]})
286+
# Preserve the original structure when possible to support multimodal content
287+
content = event["text"]
288+
messages.append({"role": "user", "content": content})
283289
elif event["type"] == "StartUtteranceBotAction":
284-
messages.append({"type": "assistant", "content": event["script"]})
290+
messages.append({"role": "assistant", "content": event["script"]})
285291

286292
return messages
287293

@@ -296,11 +302,30 @@ def user_assistant_sequence(events: List[dict]) -> str:
296302
User: What can you do?
297303
Assistant: I can help with many things.
298304
```
305+
306+
For multimodal content, it extracts text content and indicates if there were images.
299307
"""
300308
history_items = []
301309
for event in events:
302310
if event["type"] == "UserMessage":
303-
history_items.append("User: " + event["text"])
311+
content = event["text"]
312+
# Handle multimodal content by extracting text
313+
if isinstance(content, list):
314+
text_parts = []
315+
has_images = False
316+
for item in content:
317+
if isinstance(item, dict):
318+
if item.get("type") == "text":
319+
text_parts.append(item.get("text", ""))
320+
elif item.get("type") == "image_url":
321+
has_images = True
322+
text_content = " ".join(text_parts)
323+
if has_images:
324+
text_content += " [+ image]"
325+
history_items.append("User: " + text_content)
326+
else:
327+
# Regular text content
328+
history_items.append("User: " + str(content))
304329
elif event["type"] == "StartUtteranceBotAction":
305330
history_items.append("Assistant: " + event["script"])
306331

@@ -375,7 +400,8 @@ def user_assistant_sequence_nemollm(events: List[dict]) -> str:
375400
history_items = []
376401
for event in events:
377402
if event["type"] == "UserMessage":
378-
history_items.append("<extra_id_1>User\n" + event["text"])
403+
# Convert text to string regardless of type (handles both text and multimodal)
404+
history_items.append("<extra_id_1>User\n" + str(event["text"]))
379405
elif event["type"] == "StartUtteranceBotAction":
380406
history_items.append("<extra_id_1>Assistant\n" + event["script"])
381407

nemoguardrails/llm/prompts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
"""Prompts for the various steps in the interaction."""
17+
1718
import os
1819
from typing import List, Union
1920

@@ -64,6 +65,7 @@ def _get_prompt(
6465
matching_prompt = None
6566
matching_score = 0
6667

68+
model = model.lower()
6769
for prompt in prompts:
6870
if prompt.task != task_name:
6971
continue

nemoguardrails/llm/prompts/openai-chatgpt.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
# Prompts for OpenAI ChatGPT.
22
prompts:
3+
- task: general
4+
models:
5+
- openai/gpt-3.5-turbo
6+
- openai/gpt-4
7+
messages:
8+
- type: system
9+
content: |
10+
{{ general_instructions }}{% if relevant_chunks != None and relevant_chunks != '' %}
11+
This is some relevant context:
12+
```markdown
13+
{{ relevant_chunks }}
14+
```{% endif %}
15+
- "{{ history | to_chat_messages }}"
16+
317
- task: generate_user_intent
418
models:
519
- openai/gpt-3.5-turbo
@@ -305,4 +319,4 @@ prompts:
305319
messages:
306320
- type: system
307321
content: |-
308-
{{ flow_nld }}
322+
{{ flow_nld }}

nemoguardrails/llm/providers/providers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,15 @@ def get_llm_provider(model_config: Model) -> Type[BaseLLM]:
304304
"Could not import langchain_google_vertexai, please install it with "
305305
"`pip install langchain-google-vertexai`."
306306
)
307+
elif model_config.engine == "together":
308+
try:
309+
from langchain_together.chat_models import ChatTogether
310+
311+
return ChatTogether
312+
except ImportError:
313+
raise ImportError(
314+
"Could not import langchain_together, please install it with "
315+
)
307316

308317
else:
309318
return _providers[model_config.engine]

nemoguardrails/llm/taskmanager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,16 @@ def _get_messages_text_length(self, messages: List[dict]) -> int:
207207
"""Return the length of the text in the messages."""
208208
text = ""
209209
for message in messages:
210-
text += message["content"] + "\n"
210+
content = message["content"]
211+
# Handle multimodal content (when content is a list)
212+
if isinstance(content, list):
213+
# Extract text from multimodal content
214+
for item in content:
215+
if isinstance(item, dict) and item.get("type") == "text":
216+
text += item.get("text", "") + "\n"
217+
else:
218+
# Regular string content
219+
text += content + "\n"
211220
return len(text)
212221

213222
def render_task_prompt(

nemoguardrails/logging/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ async def on_chat_model_start(
113113
)
114114
+ "[/]"
115115
+ "\n"
116-
+ msg.content
116+
+ (msg.content if isinstance(msg.content, str) else "")
117117
for msg in messages[0]
118118
]
119119
)

nemoguardrails/rails/llm/llmrails.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,13 @@ async def generate_async(
758758

759759
if exception:
760760
new_message = {"role": "exception", "content": exception}
761+
761762
else:
763+
# Ensure all items in responses are strings
764+
responses = [
765+
str(response) if not isinstance(response, str) else response
766+
for response in responses
767+
]
762768
new_message = {"role": "assistant", "content": "\n".join(responses)}
763769
if response_tool_calls:
764770
new_message["tool_calls"] = response_tool_calls

0 commit comments

Comments
 (0)