Skip to content

Commit a44566e

Browse files
Authored by smruthi33 (Smruthi Raj Mohan), with co-author Pouyanpi
feat: add support for system messages to RunnableRails (#1106)
* Add support for system messages
* test: add test for system message handling in RunnableRails

Co-authored-by: Smruthi Raj Mohan <smruthi.rajmohan@ibm.com>
Co-authored-by: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com>
1 parent 6d55b82 commit a44566e

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Any, List, Optional
1919

2020
from langchain_core.language_models import BaseLanguageModel
21-
from langchain_core.messages import AIMessage, HumanMessage
21+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
2222
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
2323
from langchain_core.runnables import Runnable
2424
from langchain_core.runnables.config import RunnableConfig
@@ -139,6 +139,8 @@ def _transform_input_to_rails_format(self, _input):
139139
messages.append({"role": "assistant", "content": msg.content})
140140
elif isinstance(msg, HumanMessage):
141141
messages.append({"role": "user", "content": msg.content})
142+
elif isinstance(msg, SystemMessage):
143+
messages.append({"role": "system", "content": msg.content})
142144
elif isinstance(_input, StringPromptValue):
143145
messages.append({"role": "user", "content": _input.text})
144146
elif isinstance(_input, dict):

tests/test_runnable_rails.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from typing import List, Optional
1717

1818
import pytest
19-
from langchain_core.messages import AIMessage
19+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
20+
from langchain_core.prompt_values import ChatPromptValue
2021
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
2122
from langchain_core.runnables import (
2223
Runnable,
@@ -153,6 +154,63 @@ def test_dict_messages_in_dict_messages_out():
153154
assert result["output"] == {"role": "assistant", "content": "Paris."}
154155

155156

157+
def test_dict_system_message_in_dict_messages_out():
158+
"""Tests that SystemMessage is correctly handled."""
159+
llm = FakeLLM(
160+
responses=[
161+
"Okay.",
162+
]
163+
)
164+
config = RailsConfig.from_content(config={"models": []})
165+
model_with_rails = RunnableRails(config, llm=llm)
166+
167+
original_generate_async = model_with_rails.rails.generate_async
168+
messages_passed = None
169+
170+
async def mock_generate_async(*args, **kwargs):
171+
nonlocal messages_passed
172+
messages_passed = kwargs.get("messages")
173+
return await original_generate_async(*args, **kwargs)
174+
175+
model_with_rails.rails.generate_async = mock_generate_async
176+
177+
result = model_with_rails.invoke(
178+
input={
179+
"input": [
180+
{"role": "system", "content": "You are a helpful assistant."},
181+
{"role": "user", "content": "Got it?"},
182+
]
183+
}
184+
)
185+
186+
assert isinstance(result, dict)
187+
assert result["output"] == {"role": "assistant", "content": "Okay."}
188+
assert messages_passed == [
189+
{"role": "system", "content": "You are a helpful assistant."},
190+
{"role": "user", "content": "Got it?"},
191+
]
192+
193+
194+
def test_list_system_message_in_list_messages_out():
195+
"""Tests that SystemMessage is correctly handled when input is ChatPromptValue."""
196+
llm_response = "Intent: user asks question"
197+
llm = FakeLLM(responses=[llm_response])
198+
199+
config = RailsConfig.from_content(config={"models": []})
200+
model_with_rails = RunnableRails(config)
201+
202+
chain = model_with_rails | llm
203+
204+
input_messages = [
205+
SystemMessage(content="You are a helpful assistant."),
206+
HumanMessage(content="Got it?"),
207+
]
208+
result = chain.invoke(input=ChatPromptValue(messages=input_messages))
209+
210+
assert isinstance(result, AIMessage)
211+
assert result.content == llm_response
212+
213+
156214
def test_context_passing():
157215
llm = FakeLLM(
158216
responses=[

Comments (0)