Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d335a2f
feat(filters): enhance reasoning trace handling and extraction
Pouyanpi Apr 24, 2025
80f9a5f
feat(config): add guardrail reasoning traces option
Pouyanpi Apr 25, 2025
c59e5dd
feat(taskmanager): enhance task output parsing logic
Pouyanpi Apr 25, 2025
166b209
test: update reasoning trace handling in tests
Pouyanpi Apr 25, 2025
26e3f7d
refactor(llm): standardize access to parsed task output text
Pouyanpi Apr 25, 2025
52edbcf
feat(taskmanager): enhance reasoning trace handling
Pouyanpi Apr 29, 2025
753af44
feat(llmrails): prepend reasoning trace to response content
Pouyanpi Apr 30, 2025
3ba82dc
refactor: rename guardrail_reasoning_traces field
Pouyanpi May 1, 2025
72c1974
test: add tests for LLM Rails reasoning output
Pouyanpi May 1, 2025
45153ab
refactor: rename reasoning trace function
Pouyanpi May 1, 2025
ade7776
add edge cases for token removal logic
Pouyanpi May 2, 2025
8439e0d
add async reasoning trace tests
Pouyanpi May 2, 2025
2838164
fix(llmrails): handle reasoning trace with and without prompt
Pouyanpi May 2, 2025
1e3ba77
fix: set contextvar to None
Pouyanpi May 2, 2025
5719b04
add case when there is no reasoning tag
Pouyanpi May 2, 2025
ccc2482
remove unused code
Pouyanpi May 2, 2025
43ea1d9
enhance reasoning trace extraction test cases
Pouyanpi May 2, 2025
12afd11
enhance trace extraction and stripping
Pouyanpi May 2, 2025
73a32fa
revert style changes
Pouyanpi May 2, 2025
b266c48
rename `apply_to_traces` to `apply_to_reasoning_traces` across test c…
Pouyanpi May 2, 2025
95a4fc6
review: fix docstring
Pouyanpi May 2, 2025
e1ee5e3
review: make start/end tokens optional in helpers
Pouyanpi May 2, 2025
bbe80c2
review: use None instead of empty strings
Pouyanpi May 2, 2025
64ba92a
review: add docstring for parse_task_output
Pouyanpi May 2, 2025
c5c85e8
review: update apply_to_reasoning_traces description
Pouyanpi May 2, 2025
5f63018
review: improve docstring
Pouyanpi May 2, 2025
759eb63
review: update reasoning traces contextvar typehint
Pouyanpi May 2, 2025
0e34af6
resolve conflict after rebase
Pouyanpi May 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@
generation_options_var,
llm_call_info_var,
raw_llm_request,
reasoning_trace_var,
streaming_handler_var,
)
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.kb.kb import KnowledgeBase
from nemoguardrails.llm.params import llm_params
from nemoguardrails.llm.prompts import get_prompt
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.llm.taskmanager import LLMTaskManager, ParsedTaskOutput
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
Expand Down Expand Up @@ -442,6 +443,7 @@ async def generate_user_intent(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_USER_INTENT, output=result
)
result = result.text

user_intent = get_first_nonempty_line(result)
if user_intent is None:
Expand Down Expand Up @@ -530,6 +532,11 @@ async def generate_user_intent(
text = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=text
)

text = _process_parsed_output(
text, self._include_reasoning_traces()
)

else:
# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
Expand Down Expand Up @@ -565,6 +572,8 @@ async def generate_user_intent(
text = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)

text = _process_parsed_output(text, self._include_reasoning_traces())
text = text.strip()
if text.startswith('"'):
text = text[1:-1]
Expand Down Expand Up @@ -646,6 +655,7 @@ async def generate_next_step(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_NEXT_STEPS, output=result
)
result = result.text

# If we don't have multi-step generation enabled, we only look at the first line.
if not self.config.enable_multi_step_generation:
Expand Down Expand Up @@ -900,6 +910,10 @@ async def generate_bot_message(
Task.GENERAL, output=result
)

result = _process_parsed_output(
result, self._include_reasoning_traces()
)

log.info(
"--- :: LLM Bot Message Generation passthrough call took %.2f seconds",
time() - t0,
Expand Down Expand Up @@ -963,6 +977,10 @@ async def generate_bot_message(
Task.GENERATE_BOT_MESSAGE, output=result
)

result = _process_parsed_output(
result, self._include_reasoning_traces()
)

# TODO: catch openai.error.InvalidRequestError from exceeding max token length

result = get_multiline_response(result)
Expand Down Expand Up @@ -1055,6 +1073,7 @@ async def generate_value(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_VALUE, output=result
)
result = result.text

# We only use the first line for now
# TODO: support multi-line values?
Expand Down Expand Up @@ -1266,6 +1285,7 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERATE_INTENT_STEPS_MESSAGE, output=result
)
result = result.text

# TODO: Implement logic for generating more complex Colang next steps (multi-step),
# not just a single bot intent.
Expand Down Expand Up @@ -1348,6 +1368,7 @@ async def generate_intent_steps_message(
result = self.llm_task_manager.parse_task_output(
Task.GENERAL, output=result
)
result = _process_parsed_output(result, self._include_reasoning_traces())
text = result.strip()
if text.startswith('"'):
text = text[1:-1]
Expand All @@ -1360,6 +1381,10 @@ async def generate_intent_steps_message(
events=[new_event_dict("BotMessage", text=text)],
)

def _include_reasoning_traces(self) -> bool:
    """Return True when reasoning traces should be kept in the final LLM output.

    Delegates to the module-level helper that reads the rails configuration.
    """
    include_traces = _get_apply_to_reasoning_traces(self.config)
    return include_traces


def clean_utterance_content(utterance: str) -> str:
"""
Expand All @@ -1377,3 +1402,27 @@ def clean_utterance_content(utterance: str) -> str:
# It should be translated to an actual \n character.
utterance = utterance.replace("\\n", "\n")
return utterance


def _record_reasoning_trace(trace: str) -> None:
    """Save *trace* into the reasoning-trace context variable.

    The stored value can be retrieved later (e.g. by
    ``get_and_clear_reasoning_trace_contextvar``) after the LLM call completes.
    """
    reasoning_trace_var.set(trace)


def _assemble_response(text: str, trace: Optional[str], include_reasoning: bool) -> str:
"""Combine trace and text if requested, otherwise just return text."""
return (trace + text) if (trace and include_reasoning) else text


def _process_parsed_output(
    output: ParsedTaskOutput, include_reasoning_trace: bool
) -> str:
    """Store any reasoning trace in context, then assemble the response string.

    Args:
        output: The parsed task output carrying ``text`` and ``reasoning_trace``.
        include_reasoning_trace: Whether the trace is prepended to the text.

    Returns:
        The final response text, with the trace prepended when requested.
    """
    trace = output.reasoning_trace
    if trace:
        _record_reasoning_trace(trace)
    return _assemble_response(output.text, trace, include_reasoning_trace)


def _get_apply_to_reasoning_traces(config: RailsConfig) -> bool:
    """Read the output-rails flag controlling reasoning-trace inclusion.

    Args:
        config: The rails configuration to inspect.

    Returns:
        The value of ``rails.output.apply_to_reasoning_traces``.
    """
    output_rails = config.rails.output
    return output_rails.apply_to_reasoning_traces
16 changes: 14 additions & 2 deletions nemoguardrails/actions/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
from nemoguardrails.colang.v2_x.runtime.flows import InternalEvent, InternalEvents
from nemoguardrails.context import llm_call_info_var
from nemoguardrails.context import llm_call_info_var, reasoning_trace_var
from nemoguardrails.logging.callbacks import logging_callbacks
from nemoguardrails.logging.explain import LLMCallInfo

Expand Down Expand Up @@ -192,7 +192,7 @@ def get_colang_history(
and event["action_name"] == "retrieve_relevant_chunks"
):
continue
history += f'execute {event["action_name"]}\n'
history += f"execute {event['action_name']}\n"
elif event["type"] == "InternalSystemActionFinished" and not event.get(
"is_system_action"
):
Expand Down Expand Up @@ -577,3 +577,15 @@ def escape_flow_name(name: str) -> str:
# removes non-word chars and leading digits in a word
result = re.sub(r"\b\d+|[^\w\s]", "", result)
return result


def get_and_clear_reasoning_trace_contextvar() -> Optional[str]:
    """Pop the stored reasoning trace from the context variable.

    A falsy value (None or an empty string) is treated as "no trace" and the
    context variable is left untouched in that case.

    Returns:
        Optional[str]: The reasoning trace if one exists, None otherwise.
    """
    trace = reasoning_trace_var.get()
    if not trace:
        return None
    reasoning_trace_var.set(None)
    return trace
18 changes: 17 additions & 1 deletion nemoguardrails/actions/v2_x/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def _collect_user_intent_and_examples(

# We add these in reverse order so the most relevant is towards the end.
for result in reversed(results):
examples += f"user action: user said \"{result.text}\"\nuser intent: {result.meta['intent']}\n\n"
examples += f'user action: user said "{result.text}"\nuser intent: {result.meta["intent"]}\n\n'
if result.meta["intent"] not in potential_user_intents:
potential_user_intents.append(result.meta["intent"])

Expand Down Expand Up @@ -302,6 +302,8 @@ async def generate_user_intent(
Task.GENERATE_USER_INTENT_FROM_USER_ACTION, output=result
)

result = result.text

user_intent = get_first_nonempty_line(result)
# GPT-4o often adds 'user intent: ' in front
if user_intent and ":" in user_intent:
Expand Down Expand Up @@ -378,6 +380,8 @@ async def generate_user_intent_and_bot_action(
Task.GENERATE_USER_INTENT_AND_BOT_ACTION_FROM_USER_ACTION, output=result
)

result = result.text

user_intent = get_first_nonempty_line(result)

if user_intent and ":" in user_intent:
Expand Down Expand Up @@ -458,6 +462,8 @@ async def passthrough_llm_action(

text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)

text = result.text

return text

@action(name="CheckValidFlowExistsAction", is_system_action=True)
Expand Down Expand Up @@ -541,6 +547,8 @@ async def generate_flow_from_instructions(
task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
)

result = result.text

# TODO: why this is not part of a filter or output_parser?
#
lines = _remove_leading_empty_lines(result).split("\n")
Expand Down Expand Up @@ -613,6 +621,8 @@ async def generate_flow_from_name(
task=Task.GENERATE_FLOW_FROM_NAME, output=result
)

result = result.text

lines = _remove_leading_empty_lines(result).split("\n")

if lines[0].startswith("flow"):
Expand Down Expand Up @@ -680,6 +690,8 @@ async def generate_flow_continuation(
task=Task.GENERATE_FLOW_CONTINUATION, output=result
)

result = result.text

lines = _remove_leading_empty_lines(result).split("\n")

if len(lines) == 0 or (len(lines) == 1 and lines[0] == ""):
Expand Down Expand Up @@ -806,6 +818,8 @@ async def generate_value(
Task.GENERATE_VALUE_FROM_INSTRUCTION, output=result
)

result = result.text

# We only use the first line for now
# TODO: support multi-line values?
value = result.strip().split("\n")[0]
Expand Down Expand Up @@ -913,6 +927,8 @@ async def generate_flow(
Task.GENERATE_FLOW_CONTINUATION_FROM_NLD, output=result
)

result = result.text

result = _remove_leading_empty_lines(result)
lines = result.split("\n")
if "codeblock" in lines[0]:
Expand Down
5 changes: 5 additions & 0 deletions nemoguardrails/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import contextvars
from typing import Optional

streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)

Expand All @@ -32,3 +33,7 @@
# The raw LLM request that comes from the user.
# This is used in passthrough mode.
raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)

# The most recent reasoning trace extracted from an LLM response, if any.
# Set during response parsing and read back (then cleared) when assembling
# the final output; None when no trace has been produced.
reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "reasoning_trace", default=None
)
3 changes: 3 additions & 0 deletions nemoguardrails/library/content_safety/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def content_safety_check_input(
result = await llm_call(llm, check_input_prompt, stop=stop)

result = llm_task_manager.parse_task_output(task, output=result)
result = result.text

try:
is_safe, violated_policies = result
Expand Down Expand Up @@ -162,6 +163,8 @@ async def content_safety_check_output(

result = llm_task_manager.parse_task_output(task, output=result)

result = result.text

try:
is_safe, violated_policies = result
except TypeError:
Expand Down
1 change: 1 addition & 0 deletions nemoguardrails/library/self_check/facts/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def self_check_facts(
task, output=response, forced_output_parser="is_content_safe"
)

result = result.text
is_not_safe, _ = result

result = float(not is_not_safe)
Expand Down
1 change: 1 addition & 0 deletions nemoguardrails/library/self_check/input_check/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def self_check_input(
task, output=response, forced_output_parser="is_content_safe"
)

result = result.text
is_safe, _ = result

if not is_safe:
Expand Down
1 change: 1 addition & 0 deletions nemoguardrails/library/self_check/output_check/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ async def self_check_output(
task, output=response, forced_output_parser="is_content_safe"
)

result = result.text
is_safe, _ = result

return is_safe
Loading