Skip to content

Commit 4152e21

Browse files
committed
refactor: extract output mapping logic to separate module
1 parent 6d61c20 commit 4152e21

File tree

2 files changed

+55
-40
lines changed

2 files changed

+55
-40
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Any, Tuple
17+
18+
19+
def default_output_mapping(result: Any) -> bool:
20+
"""A fallback output mapping if an action does not provide one.
21+
22+
- For a boolean result: assume True means allowed (so block if False).
23+
- For a numeric result: use 0.5 as a threshold (block if the value is less).
24+
- Otherwise, assume the result is allowed.
25+
"""
26+
if isinstance(result, bool):
27+
return not result
28+
elif isinstance(result, (int, float)):
29+
return result < 0.5
30+
else:
31+
return False
32+
33+
34+
def should_block_output(result: Any, action_func: Any) -> bool:
35+
"""Determines if an action result should be blocked using its attached mapping.
36+
37+
Args:
38+
result: The value returned by the action.
39+
action_func: The action function (whose metadata contains the mapping).
40+
41+
Returns:
42+
True if the mapping indicates that the output should be blocked, False otherwise.
43+
"""
44+
mapping = getattr(action_func, "action_meta", {}).get("output_mapping")
45+
if mapping is None:
46+
mapping = default_output_mapping
47+
48+
if not isinstance(result, Tuple):
49+
result = (result,)
50+
51+
return mapping(result[0])

nemoguardrails/rails/llm/llmrails.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
from nemoguardrails.actions.llm.generation import LLMGenerationActions
3333
from nemoguardrails.actions.llm.utils import get_colang_history
34+
from nemoguardrails.actions.output_mapping import should_block_output
3435
from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
3536
from nemoguardrails.colang import parse_colang_file
3637
from nemoguardrails.colang.v1_0.runtime.flows import compute_context
@@ -1266,7 +1267,7 @@ def _update_explain_info():
12661267
output_rails_flows_id = self.config.rails.output.flows
12671268
stream_first = stream_first or output_rails_streaming_config.stream_first
12681269
get_action_details = partial(
1269-
get_action_details_from_flow_id, flows=self.config.flows
1270+
_get_action_details_from_flow_id, flows=self.config.flows
12701271
)
12711272

12721273
async for chunk_list, chunk_str_rep in buffer_strategy(streaming_handler):
@@ -1298,7 +1299,7 @@ def _update_explain_info():
12981299
action_func = self.runtime.action_dispatcher.get_action(action_name)
12991300

13001301
# Use the mapping to decide if the result indicates blocked content.
1301-
if is_blocked(result, action_func):
1302+
if should_block_output(result, action_func):
13021303
# TODO: while whitespace issue is fixed, remove the space from below
13031304
yield " {DATA: STOP}"
13041305
return
@@ -1307,7 +1308,7 @@ def _update_explain_info():
13071308
yield chunk_str_rep
13081309

13091310

1310-
def get_action_details_from_flow_id(
1311+
def _get_action_details_from_flow_id(
13111312
flow_id: str, flows: List[Union[Dict, Any]]
13121313
) -> Tuple[str, Any]:
13131314
"""Get the action name and parameters from the flow id."""
@@ -1327,40 +1328,3 @@ def get_action_details_from_flow_id(
13271328
return element["action_name"], element["action_params"]
13281329

13291330
raise ValueError(f"No action found for flow_id: {flow_id}")
1330-
1331-
1332-
def default_mapping(result):
1333-
"""
1334-
A fallback mapping if an action does not provide one.
1335-
1336-
- For a boolean result: assume True means allowed (so block if False).
1337-
- For a numeric result: use 0.5 as a threshold (block if the value is less).
1338-
- Otherwise, assume the result is allowed.
1339-
"""
1340-
if isinstance(result, bool):
1341-
return not result # block if result is False
1342-
elif isinstance(result, (int, float)):
1343-
return result < 0.5
1344-
else:
1345-
return False
1346-
1347-
1348-
def is_blocked(result, action_func):
1349-
"""
1350-
Determines if an action result should be blocked using its attached mapping.
1351-
1352-
Args:
1353-
result: The value returned by the action.
1354-
action_func: The action function (whose metadata contains the mapping).
1355-
1356-
Returns:
1357-
True if the mapping indicates that the output should be blocked, False otherwise.
1358-
"""
1359-
mapping = getattr(action_func, "action_meta", {}).get("output_mapping")
1360-
if mapping is None:
1361-
mapping = default_mapping
1362-
1363-
if not isinstance(result, Tuple):
1364-
result = (result,)
1365-
1366-
return mapping(result[0])

0 commit comments

Comments
 (0)