Skip to content

Commit cbff2af

Browse files
committed
feat(tests): add validation tests for models
1 parent 7fe3276 commit cbff2af

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

tests/eval/test_models.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717

1818
import pytest
19+
from pydantic import ValidationError
1920

2021
from nemoguardrails.eval.models import (
2122
ComplianceCheckLog,
@@ -140,3 +141,130 @@ def test_eval_output():
140141
assert len(output.results[0].compliance_checks) == 1
141142
assert output.results[0].compliance_checks[0].id == "check_id"
142143
assert output.results[0].compliance_checks[0].interaction_id == "test_id"
144+
145+
146+
def test_eval_config_policy_validation_empty_lists():
147+
"""Test that empty policies and interactions lists are handled correctly."""
148+
config = EvalConfig.model_validate(
149+
{
150+
"policies": [],
151+
"interactions": [],
152+
}
153+
)
154+
assert len(config.policies) == 0
155+
assert len(config.interactions) == 0
156+
157+
158+
def test_eval_config_policy_validation_invalid_polocy_format_missing_description():
159+
"""Test that invalid policy formats are rejected."""
160+
with pytest.raises(ValueError):
161+
EvalConfig.model_validate(
162+
{
163+
"policies": [{"id": "policy1"}],
164+
"interactions": [],
165+
}
166+
)
167+
168+
169+
def test_eval_config_policy_validation_invalid_interaction_format_missing_inputs():
170+
"""Test that invalid interaction formats are rejected."""
171+
with pytest.raises(ValueError):
172+
EvalConfig.model_validate(
173+
{
174+
"policies": [{"id": "policy1", "description": "Test policy"}],
175+
"interactions": [
176+
{
177+
"id": "test_id",
178+
"expected_output": [{"type": "string", "policy": "policy1"}],
179+
}
180+
],
181+
}
182+
)
183+
184+
185+
def test_interaction_set_empty_expected_output():
186+
"""Test that empty expected_output list is handled correctly."""
187+
interaction_set = InteractionSet.model_validate(
188+
{"id": "test_id", "inputs": ["test input"], "expected_output": []}
189+
)
190+
assert len(interaction_set.expected_output) == 0
191+
192+
193+
def test_interaction_set_invalid_format():
194+
"""Test that invalid expected_output format is rejected."""
195+
with pytest.raises(ValueError):
196+
InteractionSet.model_validate(
197+
{
198+
"id": "test_id",
199+
"inputs": ["test input"],
200+
"expected_output": [{"type": "string"}],
201+
}
202+
)
203+
204+
# TODO: The model currently doesn't validate the type field values.
205+
# This test should pass once type validation is implemented.
206+
# with pytest.raises(ValueError):
207+
# InteractionSet.model_validate(
208+
# {
209+
# "id": "test_id",
210+
# "inputs": ["test input"],
211+
# "expected_output": [{"type": "invalid_type", "policy": "test_policy"}],
212+
# }
213+
# )
214+
215+
216+
def test_compliance_check_log_invalid_format():
217+
"""Test that invalid ComplianceCheckLog format is rejected."""
218+
with pytest.raises(ValueError):
219+
ComplianceCheckLog.model_validate({})
220+
221+
# invalid llm_calls format
222+
with pytest.raises(ValueError):
223+
ComplianceCheckLog.model_validate({"id": "test_id", "llm_calls": "invalid"})
224+
225+
226+
def test_policy_creation():
227+
policy = Policy(
228+
id="policy_1",
229+
description="Test policy description",
230+
weight=50,
231+
apply_to_all=False,
232+
)
233+
assert policy.id == "policy_1"
234+
assert policy.description == "Test policy description"
235+
assert policy.weight == 50
236+
assert not policy.apply_to_all
237+
238+
239+
def test_policy_default_values():
240+
policy = Policy(
241+
id="policy_2",
242+
description="Another test policy",
243+
)
244+
assert policy.weight == 100
245+
assert policy.apply_to_all
246+
247+
248+
def test_policy_invalid_weight():
249+
with pytest.raises(ValidationError):
250+
Policy(
251+
id="policy_3",
252+
description="Invalid weight test",
253+
weight="invalid_weight",
254+
)
255+
256+
257+
def test_expected_output_creation():
258+
output = ExpectedOutput(
259+
type="refusal",
260+
policy="policy_1",
261+
)
262+
assert output.type == "refusal"
263+
assert output.policy == "policy_1"
264+
265+
266+
def test_expected_output_missing_field():
267+
with pytest.raises(ValidationError):
268+
ExpectedOutput(
269+
type="refusal",
270+
)

0 commit comments

Comments
 (0)