Skip to content

Commit e7866ef

Browse files
committed
test: add tests for output mapping functions
1 parent 8405d89 commit e7866ef

File tree

2 files changed

+183
-1
lines changed

2 files changed

+183
-1
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import pytest
18+
19+
from nemoguardrails.actions import action
20+
from nemoguardrails.actions.output_mapping import (
21+
default_output_mapping,
22+
should_block_output,
23+
)
24+
25+
# Tests for default_output_mapping
26+
27+
28+
def test_default_output_mapping_boolean_true():
29+
# For booleans, the mapping returns the negation (block if result is False).
30+
# If result is True, not True == False, so output is not blocked.
31+
assert default_output_mapping(True) is False
32+
33+
34+
def test_default_output_mapping_boolean_false():
35+
# If result is False, then not False == True, so it is blocked.
36+
assert default_output_mapping(False) is True
37+
38+
39+
def test_default_output_mapping_numeric_below_threshold():
40+
# For numeric values, block if the value is less than 0.5.
41+
assert default_output_mapping(0.4) is True
42+
43+
44+
def test_default_output_mapping_numeric_above_threshold():
45+
# For numeric values greater than or equal to 0.5, do not block.
46+
assert default_output_mapping(0.5) is False
47+
assert default_output_mapping(0.6) is False
48+
49+
50+
def test_default_output_mapping_non_numeric_non_boolean():
51+
# For other types (e.g., strings), default mapping returns False (allowed).
52+
assert default_output_mapping("anything") is False
53+
54+
55+
# Tests for should_block_output
56+
57+
58+
# Create a dummy action function with an attached mapping in its metadata.
59+
def dummy_action_output_mapping(val):
60+
# For testing, block if the value equals "block", otherwise do not block.
61+
return val == "block"
62+
63+
64+
@action(output_mapping=dummy_action_output_mapping)
65+
def dummy_action(result):
66+
return result
67+
68+
69+
def test_should_block_output_with_tuple_result_and_mapping():
70+
# Test should_block_output when the result is a tuple and the dummy mapping is used.
71+
# When the first element equals "block", we expect True.
72+
result = ("block",)
73+
assert should_block_output(result, dummy_action) is True
74+
75+
# When the result is not "block", we expect False.
76+
result = ("allow",)
77+
assert should_block_output(result, dummy_action) is False
78+
79+
80+
def test_should_block_output_with_non_tuple_result_and_mapping():
81+
# Test should_block_output when the result is not a tuple.
82+
# The function should wrap it into a tuple.
83+
result = "block"
84+
assert should_block_output(result, dummy_action) is True
85+
86+
result = "allow"
87+
assert should_block_output(result, dummy_action) is False
88+
89+
90+
def test_should_block_output_without_action_meta():
91+
# Test should_block_output when the action function does not have an "action_meta" attribute.
92+
# In this case, default_output_mapping should be used.
93+
def action_without_meta(res):
94+
return res
95+
96+
# Ensure there is no action_meta attribute.
97+
if hasattr(action_without_meta, "action_meta"):
98+
del action_without_meta.action_meta
99+
100+
# Test with a boolean: default_output_mapping for True is False and for False is True.
101+
assert should_block_output(True, action_without_meta) is False
102+
assert should_block_output(False, action_without_meta) is True
103+
104+
# Test with a numeric value: block if < 0.5.
105+
assert should_block_output(0.4, action_without_meta) is True
106+
assert should_block_output(0.6, action_without_meta) is False

tests/test_llmrails.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import Optional
16+
from typing import Any, Dict, List, Optional, Union
1717

1818
import pytest
1919

2020
from nemoguardrails import LLMRails, RailsConfig
21+
from nemoguardrails.rails.llm.llmrails import _get_action_details_from_flow_id
2122
from tests.utils import FakeLLM, clean_events, event_sequence_conforms
2223

2324

@@ -621,3 +622,78 @@ async def compute(what: Optional[str] = "2 + 3"):
621622
"role": "assistant",
622623
"content": "The answer is 5\nAre you happy with the result?",
623624
}
625+
626+
627+
# get_action_details_from_flow_id used in llmrails.py
628+
629+
630+
@pytest.fixture
631+
def dummy_flows() -> List[Union[Dict, Any]]:
632+
return [
633+
{
634+
"id": "test_flow",
635+
"elements": [
636+
{
637+
"_type": "run_action",
638+
"_source_mapping": {
639+
"filename": "flows.v1.co",
640+
"line_text": "execute something",
641+
},
642+
"action_name": "test_action",
643+
"action_params": {"param1": "value1"},
644+
}
645+
],
646+
},
647+
# Additional flow that should match on a prefix
648+
{
649+
"id": "other_flow is prefix",
650+
"elements": [
651+
{
652+
"_type": "run_action",
653+
"_source_mapping": {
654+
"filename": "flows.v1.co",
655+
"line_text": "execute something else",
656+
},
657+
"action_name": "other_action",
658+
"action_params": {"param2": "value2"},
659+
}
660+
],
661+
},
662+
]
663+
664+
665+
def test_get_action_details_exact_match(dummy_flows):
666+
action_name, action_params = _get_action_details_from_flow_id(
667+
"test_flow", dummy_flows
668+
)
669+
assert action_name == "test_action"
670+
assert action_params == {"param1": "value1"}
671+
672+
673+
def test_get_action_details_prefix_match(dummy_flows):
674+
# For a flow_id that starts with the prefix "other_flow",
675+
# we expect to retrieve the action details from the flow whose id starts with that prefix.
676+
# we expect a result since we are passing the prefixes argument.
677+
action_name, action_params = _get_action_details_from_flow_id(
678+
"other_flow", dummy_flows, prefixes=["other_flow"]
679+
)
680+
assert action_name == "other_action"
681+
assert action_params == {"param2": "value2"}
682+
683+
684+
def test_get_action_details_prefix_match_unsupported_prefix(dummy_flows):
685+
# For a flow_id that starts with the prefix "other_flow",
686+
# we expect to retrieve the action details from the flow whose id starts with that prefix.
687+
# but as the prefix is not supported, we expect a ValueError.
688+
689+
with pytest.raises(ValueError) as exc_info:
690+
_get_action_details_from_flow_id("other_flow", dummy_flows)
691+
692+
assert "No action found for flow_id" in str(exc_info.value)
693+
694+
695+
def test_get_action_details_no_match(dummy_flows):
696+
# Tests that a non matching flow_id raises a ValueError
697+
with pytest.raises(ValueError) as exc_info:
698+
_get_action_details_from_flow_id("non_existing_flow", dummy_flows)
699+
assert "No action found for flow_id" in str(exc_info.value)

0 commit comments

Comments
 (0)