 import os

 import pytest
+from pydantic import ValidationError

 from nemoguardrails.eval.models import (
     ComplianceCheckLog,
@@ -140,3 +141,130 @@ def test_eval_output():
     assert len(output.results[0].compliance_checks) == 1
     assert output.results[0].compliance_checks[0].id == "check_id"
     assert output.results[0].compliance_checks[0].interaction_id == "test_id"
+
+
+def test_eval_config_policy_validation_empty_lists():
+    """Test that empty policies and interactions lists are handled correctly."""
+    config = EvalConfig.model_validate(
+        {
+            "policies": [],
+            "interactions": [],
+        }
+    )
+    assert len(config.policies) == 0
+    assert len(config.interactions) == 0
+
+
+def test_eval_config_policy_validation_invalid_policy_format_missing_description():
+    """Test that invalid policy formats are rejected."""
+    with pytest.raises(ValueError):
+        EvalConfig.model_validate(
+            {
+                "policies": [{"id": "policy1"}],
+                "interactions": [],
+            }
+        )
+
+
+def test_eval_config_policy_validation_invalid_interaction_format_missing_inputs():
+    """Test that invalid interaction formats are rejected."""
+    with pytest.raises(ValueError):
+        EvalConfig.model_validate(
+            {
+                "policies": [{"id": "policy1", "description": "Test policy"}],
+                "interactions": [
+                    {
+                        "id": "test_id",
+                        "expected_output": [{"type": "string", "policy": "policy1"}],
+                    }
+                ],
+            }
+        )
+
+
+def test_interaction_set_empty_expected_output():
+    """Test that an empty expected_output list is handled correctly."""
+    interaction_set = InteractionSet.model_validate(
+        {"id": "test_id", "inputs": ["test input"], "expected_output": []}
+    )
+    assert len(interaction_set.expected_output) == 0
+
+
+def test_interaction_set_invalid_format():
+    """Test that an invalid expected_output format is rejected."""
+    with pytest.raises(ValueError):
+        InteractionSet.model_validate(
+            {
+                "id": "test_id",
+                "inputs": ["test input"],
+                "expected_output": [{"type": "string"}],
+            }
+        )
+
+    # TODO: The model currently doesn't validate the type field values.
+    # This test should pass once type validation is implemented.
+    # with pytest.raises(ValueError):
+    #     InteractionSet.model_validate(
+    #         {
+    #             "id": "test_id",
+    #             "inputs": ["test input"],
+    #             "expected_output": [{"type": "invalid_type", "policy": "test_policy"}],
+    #         }
+    #     )
+
+
+def test_compliance_check_log_invalid_format():
+    """Test that an invalid ComplianceCheckLog format is rejected."""
+    with pytest.raises(ValueError):
+        ComplianceCheckLog.model_validate({})
+
+    # Invalid llm_calls format.
+    with pytest.raises(ValueError):
+        ComplianceCheckLog.model_validate({"id": "test_id", "llm_calls": "invalid"})
+
+
+def test_policy_creation():
+    policy = Policy(
+        id="policy_1",
+        description="Test policy description",
+        weight=50,
+        apply_to_all=False,
+    )
+    assert policy.id == "policy_1"
+    assert policy.description == "Test policy description"
+    assert policy.weight == 50
+    assert not policy.apply_to_all
+
+
+def test_policy_default_values():
+    policy = Policy(
+        id="policy_2",
+        description="Another test policy",
+    )
+    assert policy.weight == 100
+    assert policy.apply_to_all
+
+
+def test_policy_invalid_weight():
+    with pytest.raises(ValidationError):
+        Policy(
+            id="policy_3",
+            description="Invalid weight test",
+            weight="invalid_weight",
+        )
+
+
+def test_expected_output_creation():
+    output = ExpectedOutput(
+        type="refusal",
+        policy="policy_1",
+    )
+    assert output.type == "refusal"
+    assert output.policy == "policy_1"
+
+
+def test_expected_output_missing_field():
+    with pytest.raises(ValidationError):
+        ExpectedOutput(
+            type="refusal",
+        )
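
Note on the TODO in test_interaction_set_invalid_format: the model does not yet validate the values of the type field. Below is a minimal sketch, not part of the diff above, of how that validation could look with the Pydantic v2 API these tests already use. The class name ExpectedOutputSketch is hypothetical, and the second allowed value is an illustrative placeholder; only "refusal" is confirmed by the tests.

from typing import Literal

from pydantic import BaseModel, ValidationError


class ExpectedOutputSketch(BaseModel):
    """Hypothetical variant of ExpectedOutput that rejects unknown type values."""

    # Constraining the field to a fixed set of values makes "invalid_type" fail
    # validation, which would let the commented-out assertion above pass.
    # "similar_message" is a placeholder; the real set of values is an assumption.
    type: Literal["refusal", "similar_message"]
    policy: str


# An unknown type value is rejected at validation time.
try:
    ExpectedOutputSketch.model_validate(
        {"type": "invalid_type", "policy": "test_policy"}
    )
except ValidationError:
    pass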