 from typing import List, Optional

 import pytest
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.prompt_values import ChatPromptValue
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
 from langchain_core.runnables import (
     Runnable,
@@ -153,6 +154,68 @@ def test_dict_messages_in_dict_messages_out():
     assert result["output"] == {"role": "assistant", "content": "Paris."}


+def test_dict_system_message_in_dict_messages_out():
+    """Tests that a system message in dict input is forwarded unchanged to the rails."""
+    llm = FakeLLM(
+        responses=[
+            "Okay.",
+        ]
+    )
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config, llm=llm)
+
+    # Wrap generate_async so the test can capture the exact messages that
+    # reach the rails runtime while still delegating to the original method.
+    original_generate_async = model_with_rails.rails.generate_async
+    messages_passed = None
+
+    async def mock_generate_async(*args, **kwargs):
+        nonlocal messages_passed
+        messages_passed = kwargs.get("messages")
+        return await original_generate_async(*args, **kwargs)
+
+    model_with_rails.rails.generate_async = mock_generate_async
+
+    result = model_with_rails.invoke(
+        input={
+            "input": [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": "Got it?"},
+            ]
+        }
+    )
+
+    assert isinstance(result, dict)
+    assert result["output"] == {"role": "assistant", "content": "Okay."}
+    # The system message must arrive intact, ahead of the user turn.
+    assert messages_passed == [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Got it?"},
+    ]
+
+
+def test_list_system_message_in_list_messages_out():
+    """Tests that SystemMessage is correctly handled when the input is a ChatPromptValue."""
+    llm_response = "Intent: user asks question"
+    llm = FakeLLM(responses=[llm_response])
+
+    config = RailsConfig.from_content(config={"models": []})
+    model_with_rails = RunnableRails(config)
+
+    # Piping RunnableRails into an LLM wraps the LLM with the guardrails.
+    chain = model_with_rails | llm
+
+    input_messages = [
+        SystemMessage(content="You are a helpful assistant."),
+        HumanMessage(content="Got it?"),
+    ]
+    result = chain.invoke(input=ChatPromptValue(messages=input_messages))
+
+    # A ChatPromptValue input should yield an AIMessage response.
+    assert isinstance(result, AIMessage)
+    assert result.content == llm_response
+
+
 def test_context_passing():
     llm = FakeLLM(
         responses=[
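
For reference, a minimal sketch (not part of this diff) of the dict-in/dict-out pattern the new tests exercise; the "./config" path and the ChatOpenAI model are assumptions for illustration, not something this change prescribes:

    from langchain_openai import ChatOpenAI
    from nemoguardrails import RailsConfig
    from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails

    config = RailsConfig.from_path("./config")  # assumed local rails config
    guardrails = RunnableRails(config, llm=ChatOpenAI())

    result = guardrails.invoke(
        {
            "input": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ]
        }
    )
    # result["output"] is a {"role": "assistant", "content": "..."} dict,
    # mirroring the assertions in the tests above.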