Skip to content

Commit 8215095

Browse files
authored
feat!: add support for preserving and optionally applying guardrails to reasoning traces (#1145)
1 parent 5b5fea0 commit 8215095

File tree

19 files changed

+1187
-114
lines changed

19 files changed

+1187
-114
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,14 @@
5050
generation_options_var,
5151
llm_call_info_var,
5252
raw_llm_request,
53+
reasoning_trace_var,
5354
streaming_handler_var,
5455
)
5556
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
5657
from nemoguardrails.kb.kb import KnowledgeBase
5758
from nemoguardrails.llm.params import llm_params
5859
from nemoguardrails.llm.prompts import get_prompt
59-
from nemoguardrails.llm.taskmanager import LLMTaskManager
60+
from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
6061
from nemoguardrails.llm.types import Task
6162
from nemoguardrails.logging.explain import LLMCallInfo
6263
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
@@ -442,6 +443,7 @@ async def generate_user_intent(
442443
result = self.llm_task_manager.parse_task_output(
443444
Task.GENERATE_USER_INTENT, output=result
444445
)
446+
result = result.text
445447

446448
user_intent = get_first_nonempty_line(result)
447449
if user_intent is None:
@@ -530,6 +532,11 @@ async def generate_user_intent(
530532
text = self.llm_task_manager.parse_task_output(
531533
Task.GENERAL, output=text
532534
)
535+
536+
text = _process_parsed_output(
537+
text, self._include_reasoning_traces()
538+
)
539+
533540
else:
534541
# Initialize the LLMCallInfo object
535542
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
@@ -565,6 +572,8 @@ async def generate_user_intent(
565572
text = self.llm_task_manager.parse_task_output(
566573
Task.GENERAL, output=result
567574
)
575+
576+
text = _process_parsed_output(text, self._include_reasoning_traces())
568577
text = text.strip()
569578
if text.startswith('"'):
570579
text = text[1:-1]
@@ -646,6 +655,7 @@ async def generate_next_step(
646655
result = self.llm_task_manager.parse_task_output(
647656
Task.GENERATE_NEXT_STEPS, output=result
648657
)
658+
result = result.text
649659

650660
# If we don't have multi-step generation enabled, we only look at the first line.
651661
if not self.config.enable_multi_step_generation:
@@ -900,6 +910,10 @@ async def generate_bot_message(
900910
Task.GENERAL, output=result
901911
)
902912

913+
result = _process_parsed_output(
914+
result, self._include_reasoning_traces()
915+
)
916+
903917
log.info(
904918
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
905919
time() - t0,
@@ -963,6 +977,10 @@ async def generate_bot_message(
963977
Task.GENERATE_BOT_MESSAGE, output=result
964978
)
965979

980+
result = _process_parsed_output(
981+
result, self._include_reasoning_traces()
982+
)
983+
966984
# TODO: catch openai.error.InvalidRequestError from exceeding max token length
967985

968986
result = get_multiline_response(result)
@@ -1055,6 +1073,7 @@ async def generate_value(
10551073
result = self.llm_task_manager.parse_task_output(
10561074
Task.GENERATE_VALUE, output=result
10571075
)
1076+
result = result.text
10581077

10591078
# We only use the first line for now
10601079
# TODO: support multi-line values?
@@ -1266,6 +1285,7 @@ async def generate_intent_steps_message(
12661285
result = self.llm_task_manager.parse_task_output(
12671286
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
12681287
)
1288+
result = result.text
12691289

12701290
# TODO: Implement logic for generating more complex Colang next steps (multi-step),
12711291
# not just a single bot intent.
@@ -1348,6 +1368,7 @@ async def generate_intent_steps_message(
13481368
result = self.llm_task_manager.parse_task_output(
13491369
Task.GENERAL, output=result
13501370
)
1371+
result = _process_parsed_output(result, self._include_reasoning_traces())
13511372
text = result.strip()
13521373
if text.startswith('"'):
13531374
text = text[1:-1]
@@ -1360,6 +1381,10 @@ async def generate_intent_steps_message(
13601381
events=[new_event_dict("BotMessage", text=text)],
13611382
)
13621383

1384+
def _include_reasoning_traces(self) -> bool:
    """Report whether reasoning traces should be kept in generated output.

    Returns:
        bool: The value ``_get_apply_to_reasoning_traces`` yields for
        ``self.config``.
    """
    rails_config = self.config
    return _get_apply_to_reasoning_traces(rails_config)
1387+
13631388

13641389
def clean_utterance_content(utterance: str) -> str:
13651390
"""
@@ -1377,3 +1402,27 @@ def clean_utterance_content(utterance: str) -> str:
13771402
# It should be translated to an actual \n character.
13781403
utterance = utterance.replace("\\n", "\n")
13791404
return utterance
1405+
1406+
1407+
def _record_reasoning_trace(trace: str) -> None:
    """Save ``trace`` into the ``reasoning_trace_var`` context variable.

    Args:
        trace: The reasoning trace text to keep for later retrieval.
    """
    reasoning_trace_var.set(trace)
1410+
1411+
1412+
def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str:
1413+
"""Combine trace and text if requested, otherwise just return text."""
1414+
return (trace + text) if (trace and include_reasoning) else text
1415+
1416+
1417+
def _process_parsed_output(
    output: ParsedTaskOutput, include_reasoning_trace: bool
) -> str:
    """Store any reasoning trace from ``output`` and build the response text.

    Args:
        output: Parsed task output exposing ``text`` and ``reasoning_trace``.
        include_reasoning_trace: Whether the trace should be prepended to
            the returned text.

    Returns:
        The string produced by ``_assemble_response`` for this output.
    """
    trace = output.reasoning_trace
    # Only a truthy trace is recorded; a missing/empty one leaves the
    # context variable untouched.
    if trace:
        _record_reasoning_trace(trace)
    return _assemble_response(output.text, trace, include_reasoning_trace)
1424+
1425+
1426+
def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool:
    """Read the ``apply_to_reasoning_traces`` flag from the output rails config.

    Args:
        config: The rails configuration to inspect.

    Returns:
        The value of ``config.rails.output.apply_to_reasoning_traces``.
    """
    output_rails = config.rails.output
    return output_rails.apply_to_reasoning_traces

nemoguardrails/actions/llm/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
2626
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
27-
from nemoguardrails.context import llm_call_info_var
27+
from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
2828
from nemoguardrails.logging.callbacks import logging_callbacks
2929
from nemoguardrails.logging.explain import LLMCallInfo
3030

@@ -192,7 +192,7 @@ def get_colang_history(
192192
and event["action_name"] == "retrieve_relevant_chunks"
193193
):
194194
continue
195-
history += f'execute {event["action_name"]}\n'
195+
history += f"execute {event['action_name']}\n"
196196
elif event["type"] == "InternalSystemActionFinished" and not event.get(
197197
"is_system_action"
198198
):
@@ -577,3 +577,15 @@ def escape_flow_name(name: str) -> str:
577577
# removes non-word chars and leading digits in a word
578578
result = re.sub(r"\b\d+|[^\w\s]", "", result)
579579
return result
580+
581+
582+
def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
    """Fetch the reasoning trace from the context and reset it to ``None``.

    Returns:
        Optional[str]: The stored reasoning trace when it is non-empty,
        otherwise ``None``.
    """
    trace = reasoning_trace_var.get()
    # A falsy value (None or empty string) is reported as "no trace" and the
    # context variable is left as it is.
    if not trace:
        return None
    reasoning_trace_var.set(None)
    return trace

nemoguardrails/actions/v2_x/generation.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ async def _collect_user_intent_and_examples(
197197

198198
# We add these in reverse order so the most relevant is towards the end.
199199
for result in reversed(results):
200-
examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n"
200+
examples += f'user action: user said "{result.text}"\nuser intent: {result.meta["intent"]}\n\n'
201201
if result.meta["intent"] not in potential_user_intents:
202202
potential_user_intents.append(result.meta["intent"])
203203

@@ -302,6 +302,8 @@ async def generate_user_intent(
302302
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
303303
)
304304

305+
result = result.text
306+
305307
user_intent = get_first_nonempty_line(result)
306308
# GTP-4o often adds 'user intent: ' in front
307309
if user_intent and ":" in user_intent:
@@ -378,6 +380,8 @@ async def generate_user_intent_and_bot_action(
378380
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
379381
)
380382

383+
result = result.text
384+
381385
user_intent = get_first_nonempty_line(result)
382386

383387
if user_intent and ":" in user_intent:
@@ -458,6 +462,8 @@ async def passthrough_llm_action(
458462

459463
text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
460464

465+
text = result.text
466+
461467
return text
462468

463469
@action(name="CheckValidFlowExistsAction", is_system_action=True)
@@ -541,6 +547,8 @@ async def generate_flow_from_instructions(
541547
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
542548
)
543549

550+
result = result.text
551+
544552
# TODO: why this is not part of a filter or output_parser?
545553
#
546554
lines = _remove_leading_empty_lines(result).split("\n")
@@ -613,6 +621,8 @@ async def generate_flow_from_name(
613621
task=Task.GENERATE_FLOW_FROM_NAME, output=result
614622
)
615623

624+
result = result.text
625+
616626
lines = _remove_leading_empty_lines(result).split("\n")
617627

618628
if lines[0].startswith("flow"):
@@ -680,6 +690,8 @@ async def generate_flow_continuation(
680690
task=Task.GENERATE_FLOW_CONTINUATION, output=result
681691
)
682692

693+
result = result.text
694+
683695
lines = _remove_leading_empty_lines(result).split("\n")
684696

685697
if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
@@ -806,6 +818,8 @@ async def generate_value(
806818
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
807819
)
808820

821+
result = result.text
822+
809823
# We only use the first line for now
810824
# TODO: support multi-line values?
811825
value = result.strip().split("\n")[0]
@@ -913,6 +927,8 @@ async def generate_flow(
913927
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
914928
)
915929

930+
result = result.text
931+
916932
result = _remove_leading_empty_lines(result)
917933
lines = result.split("\n")
918934
if "codeblock" in lines[0]:

nemoguardrails/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import contextvars
17+
from typing import Optional
1718

1819
streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
1920

@@ -32,3 +33,7 @@
3233
# The raw LLM request that comes from the user.
3334
# This is used in passthrough mode.
3435
raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)
36+
37+
reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
38+
"reasoning_trace", default=None
39+
)

nemoguardrails/library/content_safety/actions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def content_safety_check_input(
8080
result = await llm_call(llm, check_input_prompt, stop=stop)
8181

8282
result = llm_task_manager.parse_task_output(task, output=result)
83+
result = result.text
8384

8485
try:
8586
is_safe, violated_policies = result
@@ -162,6 +163,8 @@ async def content_safety_check_output(
162163

163164
result = llm_task_manager.parse_task_output(task, output=result)
164165

166+
result = result.text
167+
165168
try:
166169
is_safe, violated_policies = result
167170
except TypeError:

nemoguardrails/library/self_check/facts/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ async def self_check_facts(
8282
task, output=response, forced_output_parser="is_content_safe"
8383
)
8484

85+
result = result.text
8586
is_not_safe, _ = result
8687

8788
result = float(not is_not_safe)

nemoguardrails/library/self_check/input_check/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ async def self_check_input(
8383
task, output=response, forced_output_parser="is_content_safe"
8484
)
8585

86+
result = result.text
8687
is_safe, _ = result
8788

8889
if not is_safe:

nemoguardrails/library/self_check/output_check/actions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def self_check_output(
8787
task, output=response, forced_output_parser="is_content_safe"
8888
)
8989

90+
result = result.text
9091
is_safe, _ = result
9192

9293
return is_safe

0 commit comments

Comments
 (0)