Skip to content

Commit 748e9cd

Browse files
authored
fix: error in LLMRails with tracing enabled (#1103)
Signed-off-by: Giovanni Liva <giovanni.liva@dynatrace.com>
1 parent eaaa58c commit 748e9cd

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,6 @@ async def generate_async(
938938
input=messages, response=res, adapters=self._log_adapters
939939
)
940940
await tracer.export_async()
941-
res = res.response[0]
942941
return res
943942
else:
944943
# If a prompt is used, we only return the content of the message.

tests/test_tracing.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
# limitations under the License.
1515

1616
import asyncio
17+
import os
18+
19+
import pytest
1720
import unittest
18-
from unittest.mock import AsyncMock, MagicMock
21+
from unittest.mock import AsyncMock, MagicMock, patch
1922

23+
from nemoguardrails import LLMRails
2024
from nemoguardrails.logging.explain import LLMCallInfo
21-
from nemoguardrails.rails.llm.config import TracingConfig
25+
from nemoguardrails.rails.llm.config import TracingConfig, RailsConfig
2226
from nemoguardrails.rails.llm.options import (
2327
ActivatedRail,
2428
ExecutedAction,
@@ -201,5 +205,39 @@ def test_export_async(self):
201205
adapter_non_empty.transform_async.assert_called_once()
202206

203207

208+
@patch.object(Tracer, "export_async", return_value="")
209+
@pytest.mark.asyncio
210+
async def test_tracing_enable_no_crash_issue_1093(mockTracer):
211+
config = RailsConfig.from_content(
212+
colang_content="""
213+
define user express greeting
214+
"hello"
215+
216+
define flow
217+
user express greeting
218+
bot express greeting
219+
220+
define bot express greeting
221+
"Hello World!\\n NewLine World!"
222+
""",
223+
config={
224+
"models": [],
225+
"rails": {"dialog": {"user_messages": {"embeddings_only": True}}},
226+
},
227+
)
228+
# Force Tracing to be enabled
229+
config.tracing.enabled = True
230+
rails = LLMRails(config)
231+
res = await rails.generate_async(
232+
messages=[
233+
{"role": "user", "content": "hi!"},
234+
{"role": "assistant", "content": "hi!"},
235+
{"role": "user", "content": "hi!"},
236+
]
237+
)
238+
assert mockTracer.called == True
239+
assert res.response != None
240+
241+
204242
if __name__ == "__main__":
205243
unittest.main()

0 commit comments

Comments
 (0)