Commit 4344daf

Feat: Support models with reasoning traces (#996)
* Fix jailbreak models test for cases when the JAILBREAK_SETUP_PRESENT env variable is not defined
* Add support for removing reasoning traces for reasoning LLMs (e.g. DeepSeek-R1)
* Add default prompts for DeepSeek models
* Change how output parsers are used by the self-check rails
* Add a test for a config with LLM reasoning for LLM rails
* Add a DeepSeek-R1 example config and update the documentation on reasoning models
* Add support for deepseek and google_genai as providers
* Refactor model name config for LangChain providers
1 parent 255d579 commit 4344daf

File tree: 15 files changed (+598, -56 lines)

docs/user-guides/configuration-guide.md

Lines changed: 23 additions & 0 deletions
````diff
@@ -94,6 +94,29 @@ To use any of the providers, you must install additional packages; when you firs
 Although you can instantiate any of the previously mentioned LLM providers, depending on the capabilities of the model, the NeMo Guardrails toolkit works better with some providers than others. The toolkit includes prompts that have been optimized for certain types of models, such as `openai` and `nemollm`. For others, you can optimize the prompts yourself following the information in the [LLM Prompts](#llm-prompts) section.
 ```
 
+#### Using LLMs with Reasoning Traces
+
+To use an LLM that outputs reasoning traces as part of the response (e.g. [DeepSeek-R1](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d)), use the following model configuration:
+
+```yaml
+models:
+  - type: main
+    engine: deepseek
+    model: deepseek-reasoner
+    reasoning_config:
+      remove_thinking_traces: True
+      start_token: "<think>"
+      end_token: "</think>"
+```
+
+The `reasoning_config` attribute for a model contains all the required configuration for a reasoning model that outputs reasoning traces.
+In most cases, the reasoning traces need to be removed, so that the guardrails runtime only processes the actual responses from the LLM.
+
+The attributes that can be configured for a reasoning model are:
+- `remove_thinking_traces`: whether the reasoning traces should be ignored (defaults to `True`).
+- `start_token`: the start token for the reasoning process (e.g. `<think>` for DeepSeek-R1).
+- `end_token`: the end token for the reasoning process (e.g. `</think>` for DeepSeek-R1).
+
 #### NIM for LLMs
 
 [NVIDIA NIM](https://docs.nvidia.com/nim/index.html) is a set of easy-to-use microservices designed to accelerate the deployment of generative AI models across the cloud, data center, and workstations.
````
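For reference, a minimal usage sketch of the documented configuration (not part of this commit; the `./config` path and the user message are placeholders):

```python
from nemoguardrails import LLMRails, RailsConfig

# Load a config directory whose config.yml defines the deepseek-reasoner
# model with the reasoning_config shown above.
config = RailsConfig.from_path("./config")
rails = LLMRails(config)

# With remove_thinking_traces: True, the text between <think> and </think>
# is stripped before the guardrails runtime processes the LLM response.
response = rails.generate(messages=[{"role": "user", "content": "What can you do?"}])
print(response["content"])
```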
Lines changed: 8 additions & 0 deletions
```diff
@@ -0,0 +1,8 @@
+models:
+  - type: main
+    engine: deepseek
+    model: deepseek-reasoner
+    reasoning_config:
+      remove_thinking_traces: True
+      start_token: "<think>"
+      end_token: "</think>"
```
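The example file above maps one-to-one onto a model entry in the rails configuration. As a purely illustrative sketch (class and field names below only mirror the YAML keys; the actual definitions live in `nemoguardrails/rails/llm/config.py` and are not reproduced here), the new `reasoning_config` block could be modeled like this:

```python
# Illustrative sketch only: approximates how the YAML above maps to typed
# config objects. Not the real nemoguardrails config classes.
from typing import Optional
from pydantic import BaseModel


class ReasoningModelConfig(BaseModel):
    remove_thinking_traces: bool = True
    start_token: Optional[str] = "<think>"
    end_token: Optional[str] = "</think>"


class ModelSketch(BaseModel):
    type: str
    engine: str
    model: Optional[str] = None
    reasoning_config: Optional[ReasoningModelConfig] = None


entry = ModelSketch(
    type="main",
    engine="deepseek",
    model="deepseek-reasoner",
    reasoning_config=ReasoningModelConfig(),
)
print(entry.reasoning_config.remove_thinking_traces)  # True
```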

nemoguardrails/library/self_check/facts/actions.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -78,7 +78,9 @@ async def self_check_facts(
     if llm_task_manager.has_output_parser(task):
         result = llm_task_manager.parse_task_output(task, output=response)
     else:
-        result = llm_task_manager.output_parsers["is_content_safe"](response)
+        result = llm_task_manager.parse_task_output(
+            task, output=response, forced_output_parser="is_content_safe"
+        )
 
     is_not_safe, _ = result
 
```

nemoguardrails/library/self_check/input_check/actions.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -79,7 +79,9 @@ async def self_check_input(
         result = llm_task_manager.parse_task_output(task, output=response)
 
     else:
-        result = llm_task_manager.output_parsers["is_content_safe"](response)
+        result = llm_task_manager.parse_task_output(
+            task, output=response, forced_output_parser="is_content_safe"
+        )
 
     is_safe, _ = result
 
```

nemoguardrails/library/self_check/output_check/actions.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -83,7 +83,9 @@ async def self_check_output(
     if llm_task_manager.has_output_parser(task):
         result = llm_task_manager.parse_task_output(task, output=response)
     else:
-        result = llm_task_manager.output_parsers["is_content_safe"](response)
+        result = llm_task_manager.parse_task_output(
+            task, output=response, forced_output_parser="is_content_safe"
+        )
 
     is_safe, _ = result
 
```
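The same two-line change is applied in all three self-check actions: instead of calling the `is_content_safe` parser directly from the `output_parsers` dictionary, the fallback path now goes through `parse_task_output` with `forced_output_parser="is_content_safe"`, so any response post-processing (for example, stripping reasoning traces) happens in a single place. A standalone sketch of that dispatch pattern follows; it is a simplified illustration, not the actual `LLMTaskManager` implementation:

```python
# Simplified illustration of routing all parsing through one entry point.
# This is NOT the real LLMTaskManager from nemoguardrails.
from typing import Callable, Dict, Optional, Tuple


def is_content_safe_parser(text: str) -> Tuple[bool, Optional[str]]:
    # Toy parser: treats an answer starting with "yes" as safe.
    return text.strip().lower().startswith("yes"), None


class ToyTaskManager:
    def __init__(self):
        self.output_parsers: Dict[str, Callable] = {
            "is_content_safe": is_content_safe_parser
        }

    def _preprocess(self, output: str) -> str:
        # Central place for response clean-up, e.g. removing "<think>...</think>".
        if "</think>" in output:
            return output.split("</think>")[-1].strip()
        return output

    def parse_task_output(
        self, task: str, output: str, forced_output_parser: Optional[str] = None
    ):
        output = self._preprocess(output)
        parser = self.output_parsers.get(forced_output_parser) if forced_output_parser else None
        return parser(output) if parser else output


tm = ToyTaskManager()
print(
    tm.parse_task_output(
        "self_check_input",
        "<think>reasoning trace</think> Yes",
        forced_output_parser="is_content_safe",
    )
)
# -> (True, None)
```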

nemoguardrails/llm/filters.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -482,3 +482,26 @@ def conversation_to_events(conversation: List) -> List[dict]:
         )
 
     return events
+
+
+def remove_reasoning_traces(response: str, start_token: str, end_token: str) -> str:
+    """Removes the text between the first occurrence of the start token and the
+    last occurrence of the end token, if these tokens exist in the response.
+
+    This utility function is useful to strip reasoning traces from reasoning LLMs
+    that encode the reasoning traces between specific tokens.
+    """
+    if start_token and end_token:
+        start_index = response.find(start_token)
+        # If the start token is missing, this is probably a continuation of a bot message
+        # started in the prompt.
+        if start_index == -1:
+            start_index = 0
+        end_index = response.rfind(end_token)
+        if end_index == -1:
+            return response
+
+        if start_index != -1 and end_index != -1 and start_index < end_index:
+            return response[:start_index] + response[end_index + len(end_token) :]
+
+    return response
```
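A short usage example for the new filter (the sample responses are made up, not taken from the commit):

```python
from nemoguardrails.llm.filters import remove_reasoning_traces

# Hypothetical DeepSeek-R1 style output with a reasoning trace.
response = "<think>The user greets me, so I should greet back.</think>Hello! How can I help?"
print(remove_reasoning_traces(response, "<think>", "</think>"))
# -> "Hello! How can I help?"

# If the start token is missing (e.g. the bot message was continued from the
# prompt), everything up to the last end token is removed.
print(remove_reasoning_traces("reasoning continued from the prompt.</think>Sure, here it is.", "<think>", "</think>"))
# -> "Sure, here it is."
```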

nemoguardrails/llm/prompts.py

Lines changed: 20 additions & 8 deletions
```diff
@@ -20,7 +20,7 @@
 import yaml
 
 from nemoguardrails.llm.types import Task
-from nemoguardrails.rails.llm.config import RailsConfig, TaskPrompt
+from nemoguardrails.rails.llm.config import Model, RailsConfig, TaskPrompt
 
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -116,21 +116,33 @@ def _get_prompt(
     raise ValueError(f"Could not find prompt for task {task_name} and model {model}")
 
 
-def get_prompt(config: RailsConfig, task: Union[str, Task]) -> TaskPrompt:
-    """Return the prompt for the given task."""
-
+def get_task_model(config: RailsConfig, task: Union[str, Task]) -> Model:
+    """Return the model for the given task in the current config."""
     # Fetch current task parameters like name, models to use, and the prompting mode
     task_name = str(task.value) if isinstance(task, Task) else task
 
-    task_model = "unknown"
     if config.models:
         _models = [model for model in config.models if model.type == task_name]
         if not _models:
             _models = [model for model in config.models if model.type == "main"]
 
-        task_model = _models[0].engine
-        if _models[0].model:
-            task_model += "/" + _models[0].model
+        return _models[0]
+
+    return None
+
+
+def get_prompt(config: RailsConfig, task: Union[str, Task]) -> TaskPrompt:
+    """Return the prompt for the given task."""
+
+    # Fetch current task parameters like name, models to use, and the prompting mode
+    task_name = str(task.value) if isinstance(task, Task) else task
+
+    task_model = "unknown"
+    _model = get_task_model(config, task)
+    if _model:
+        task_model = _model.engine
+        if _model.model:
+            task_model += "/" + _model.model
 
     task_prompting_mode = "standard"
     if config.prompting_mode:
```
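The refactor extracts model resolution into `get_task_model`, which returns the full `Model` object (falling back to the `main` model when no model is registered for the task), while `get_prompt` rebuilds the `engine/model` string from it. A rough usage sketch, assuming the inline YAML below is a valid rails configuration:

```python
from nemoguardrails import RailsConfig
from nemoguardrails.llm.prompts import get_task_model

# Illustrative config content; only the models section is shown.
config = RailsConfig.from_content(yaml_content="""
models:
  - type: main
    engine: deepseek
    model: deepseek-reasoner
    reasoning_config:
      remove_thinking_traces: True
      start_token: "<think>"
      end_token: "</think>"
""")

# No model is registered for the "self_check_input" task, so the "main"
# model is returned; get_prompt would then look up prompts for
# "deepseek/deepseek-reasoner".
model = get_task_model(config, "self_check_input")
print(model.engine, model.model)
```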
