From 554ec405ab064154f56bc11bf0163a6a56e98ef1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 30 Oct 2024 18:57:40 +0000 Subject: [PATCH] tests for set,frozenset,list,dict,typed_dict --- python/pydantic_core/core_schema.py | 4 +- src/errors/line_error.rs | 6 +- src/errors/mod.rs | 38 ----- src/input/return_enums.rs | 24 ++- src/validators/dict.rs | 20 ++- src/validators/typed_dict.rs | 32 +++- tests/requirements.txt | 1 + tests/validators/test_allow_partial.py | 214 +++++++++++++++++++++++++ tests/validators/test_list.py | 10 -- 9 files changed, 281 insertions(+), 68 deletions(-) create mode 100644 tests/validators/test_allow_partial.py diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 17e66f625..1f19c4a84 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -2854,7 +2854,7 @@ def typed_dict_field( Args: schema: The schema to use for the field - required: Whether the field is required + required: Whether the field is required; if omitted, the value of `total` on the typed dict is used validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing serialization_exclude: Whether to exclude the field when serializing @@ -2930,7 +2930,7 @@ class MyTypedDict(TypedDict): ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core extra_behavior: The extra behavior to use for the typed dict - total: Whether the typed dict is total + total: Whether the typed dict is total; if omitted, `typed_dict_total` from config is used populate_by_name: Whether the typed dict should populate by name serialization: Custom serialization schema """ diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index 417f5e790..03e770cc3 100644 --- a/src/errors/line_error.rs +++ 
b/src/errors/line_error.rs @@ -146,11 +146,11 @@ impl ValLineError { self } - pub fn last_loc_item(&self) -> Option<&LocItem> { + pub fn first_loc_item(&self) -> Option<&LocItem> { match &self.location { Location::Empty => None, - // first because order is reversed - Location::List(loc_items) => loc_items.first(), + // last because order is reversed + Location::List(loc_items) => loc_items.last(), } } } diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 2f0176695..ffdda90e3 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -1,4 +1,3 @@ -use crate::validators::ValidationState; use pyo3::prelude::*; mod line_error; @@ -31,40 +30,3 @@ pub fn py_err_string(py: Python, err: PyErr) -> String { Err(_) => "Unknown Error".to_string(), } } - -/// If we're in `allow_partial` mode, whether all errors occurred in the last element of the input. -pub fn sequence_valid_as_partial(state: &ValidationState, input_length: usize, errors: &[ValLineError]) -> bool { - if !state.extra().allow_partial { - return false; - } - // for the error to be in the last element, the index of all errors must be `input_length - 1` - let last_index = (input_length - 1) as i64; - errors.iter().all(|error| { - if let Some(LocItem::I(loc_index)) = error.last_loc_item() { - *loc_index == last_index - } else { - false - } - }) -} - -/// If we're in `allow_partial` mode, whether all errors occurred in the last value of the input. 
-pub fn mapping_valid_as_partial( - state: &ValidationState, - opt_last_key: Option>, - errors: &[ValLineError], -) -> bool { - if !state.extra().allow_partial { - return false; - } - let Some(last_key) = opt_last_key.map(Into::into) else { - return false; - }; - errors.iter().all(|error| { - if let Some(loc_item) = error.last_loc_item() { - loc_item == &last_key - } else { - false - } - }) -} diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 6980917a3..5f5e4c1ab 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -17,8 +17,7 @@ use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMappin use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{ - py_err_string, sequence_valid_as_partial, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, - ValLineError, ValResult, + py_err_string, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ToErrorValue, ValError, ValLineError, ValResult, }; use crate::py_gc::PyGcTraverse; use crate::tools::{extract_i64, extract_int, new_py_string, py_err}; @@ -131,7 +130,6 @@ pub(crate) fn validate_iter_to_vec<'py>( let mut errors: Vec = Vec::new(); let mut index = 0; for item_result in iter { - index += 1; let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?; match validator.validate(py, item.borrow_input(), state) { Ok(item) => { @@ -148,6 +146,7 @@ pub(crate) fn validate_iter_to_vec<'py>( Err(ValError::Omit) => (), Err(err) => return Err(err), } + index += 1; } if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) { @@ -157,6 +156,23 @@ pub(crate) fn validate_iter_to_vec<'py>( } } +/// If we're in `allow_partial` mode, whether all errors occurred in the last element of the input. 
+pub fn sequence_valid_as_partial(state: &ValidationState, input_length: usize, errors: &[ValLineError]) -> bool { + if !state.extra().allow_partial { + false + } else { + // for the error to be in the last element, the index of all errors must be `input_length - 1` + let last_index = (input_length - 1) as i64; + errors.iter().all(|error| { + if let Some(LocItem::I(loc_index)) = error.first_loc_item() { + *loc_index == last_index + } else { + false + } + }) + } +} + pub trait BuildSet { fn build_add(&self, item: PyObject) -> PyResult<()>; @@ -202,7 +218,6 @@ pub(crate) fn validate_iter_to_set<'py>( let mut errors: Vec = Vec::new(); let mut index = 0; for item_result in iter { - index += 1; let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?; match validator.validate(py, item.borrow_input(), state) { Ok(item) => { @@ -233,6 +248,7 @@ pub(crate) fn validate_iter_to_set<'py>( if fail_fast && !errors.is_empty() { return Err(ValError::LineErrors(errors)); } + index += 1; } if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) { diff --git a/src/validators/dict.rs b/src/validators/dict.rs index edd537ba7..02e4f379f 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; -use crate::errors::{sequence_valid_as_partial, LocItem, ValError, ValLineError, ValResult}; +use crate::errors::{LocItem, ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::ConsumeIterator; use crate::input::{Input, ValidatedDict}; @@ -109,10 +109,13 @@ where fn consume_iterator(self, iterator: impl Iterator>) -> ValResult { let output = PyDict::new_bound(self.py); let mut errors: Vec = Vec::new(); - let mut input_length = 0; + // this should only be set to true if: + // we get errors in a value, there are no previous errors, and no items come after that + // e.g. 
if we get errors just in the last value + let mut errors_in_last = false; for item_result in iterator { - input_length += 1; + errors_in_last = false; let (key, value) = item_result?; let output_key = match self.key_validator.validate(self.py, key.borrow_input(), self.state) { Ok(value) => Some(value), @@ -127,22 +130,23 @@ where Err(err) => return Err(err), }; let output_value = match self.value_validator.validate(self.py, value.borrow_input(), self.state) { - Ok(value) => Some(value), + Ok(value) => value, Err(ValError::LineErrors(line_errors)) => { + errors_in_last = errors.is_empty(); for err in line_errors { errors.push(err.with_outer_location(key.clone())); } - None + continue; } Err(ValError::Omit) => continue, Err(err) => return Err(err), }; - if let (Some(key), Some(value)) = (output_key, output_value) { - output.set_item(key, value)?; + if let Some(key) = output_key { + output.set_item(key, output_value)?; } } - if errors.is_empty() || sequence_valid_as_partial(self.state, input_length, &errors) { + if errors.is_empty() || (self.state.extra().allow_partial && errors_in_last) { let input = self.input; length_check!(input, "Dictionary", self.min_length, self.max_length, output); Ok(output.into()) diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index cfa42d52d..e75275231 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -6,7 +6,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior}; -use crate::errors::{mapping_valid_as_partial, LocItem}; +use crate::errors::LocItem; use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::ConsumeIterator; @@ -35,6 +35,7 @@ pub struct TypedDictValidator { extras_validator: Option>, strict: bool, loc_by_alias: bool, + allow_partial: bool, } impl BuildValidator for TypedDictValidator { @@ -124,13 +125,14 @@ 
impl BuildValidator for TypedDictValidator { required, }); } - + let allow_partial = fields.iter().all(|f| !f.required); Ok(Self { fields, extra_behavior, extras_validator, strict, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), + allow_partial, } .into()) } @@ -322,7 +324,7 @@ impl Validator for TypedDictValidator { })??; } - if errors.is_empty() || mapping_valid_as_partial(state, dict.last_key(), &errors) { + if errors.is_empty() || self.valid_as_partial(state, dict.last_key(), &errors) { Ok(output_dict.to_object(py)) } else { Err(ValError::LineErrors(errors)) @@ -333,3 +335,27 @@ impl Validator for TypedDictValidator { Self::EXPECTED_TYPE } } + +impl TypedDictValidator { + /// If we're in `allow_partial` mode, whether all errors occurred in the last value of the dict. + fn valid_as_partial( + &self, + state: &ValidationState, + opt_last_key: Option>, + errors: &[ValLineError], + ) -> bool { + if !state.extra().allow_partial || !self.allow_partial { + false + } else if let Some(last_key) = opt_last_key.map(Into::into) { + errors.iter().all(|error| { + if let Some(loc_item) = error.first_loc_item() { + loc_item == &last_key + } else { + false + } + }) + } else { + false + } + } +} diff --git a/tests/requirements.txt b/tests/requirements.txt index 5ee5ebfda..3fab8cc74 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,7 @@ backports.zoneinfo==0.2.1;python_version<"3.9" coverage==7.6.1 dirty-equals==0.8.0 +inline-snapshot==0.13.3 hypothesis==6.111.2 # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. 
aarch64 musllinux pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' diff --git a/tests/validators/test_allow_partial.py b/tests/validators/test_allow_partial.py new file mode 100644 index 000000000..559ca12c8 --- /dev/null +++ b/tests/validators/test_allow_partial.py @@ -0,0 +1,214 @@ +import pytest +from dirty_equals import IsStrictDict +from inline_snapshot import snapshot + +from generate_self_schema import core_schema +from pydantic_core import SchemaValidator, ValidationError + + +def test_list(): + v = SchemaValidator( + core_schema.list_schema( + core_schema.tuple_positional_schema([core_schema.int_schema(), core_schema.int_schema()]), + ) + ) + assert v.validate_python([[1, 2], [3, 4]]) == [(1, 2), (3, 4)] + assert v.validate_python([[1, 2], [3, 4]], allow_partial=True) == [(1, 2), (3, 4)] + with pytest.raises(ValidationError) as exc_info: + v.validate_python([[1, 2], 'wrong']) + assert exc_info.value.errors(include_url=False) == snapshot( + [ + { + 'type': 'tuple_type', + 'loc': (1,), + 'msg': 'Input should be a valid tuple', + 'input': 'wrong', + } + ] + ) + assert v.validate_python([[1, 2], 'wrong'], allow_partial=True) == [(1, 2)] + assert v.validate_python([[1, 2], []], allow_partial=True) == [(1, 2)] + assert v.validate_python([[1, 2], [3]], allow_partial=True) == [(1, 2)] + assert v.validate_python([[1, 2], [3, 'x']], allow_partial=True) == [(1, 2)] + with pytest.raises(ValidationError, match='Input should be a valid tuple'): + v.validate_python([[1, 2], 'wrong', [3, 4]]) + with pytest.raises(ValidationError, match='Input should be a valid tuple'): + v.validate_python([[1, 2], 'wrong', 'wrong']) + + +def test_list_partial_nested(): + v = SchemaValidator( + core_schema.tuple_positional_schema( + [core_schema.int_schema(), core_schema.list_schema(core_schema.int_schema())] + ), + ) + assert v.validate_python([1, [2, 3]]) == (1, [2, 3]) + assert v.validate_python([1, 
[2, 3]], allow_partial=True) == (1, [2, 3]) + with pytest.raises(ValidationError) as exc_info: + v.validate_python((1, [2, 3, 'x'])) + assert exc_info.value.errors(include_url=False) == snapshot( + [ + { + 'type': 'int_parsing', + 'loc': (1, 2), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'x', + } + ] + ) + assert v.validate_python((1, [2, 3, 'x']), allow_partial=True) == (1, [2, 3]) + + +@pytest.mark.parametrize('collection_type', [core_schema.set_schema, core_schema.frozenset_schema]) +def test_set_frozenset(collection_type): + v = SchemaValidator( + collection_type( + core_schema.tuple_positional_schema([core_schema.int_schema(), core_schema.int_schema()]), + ) + ) + assert v.validate_python([[1, 2], [3, 4]]) == snapshot({(1, 2), (3, 4)}) + assert v.validate_python([[1, 2], [3, 4]], allow_partial=True) == snapshot({(1, 2), (3, 4)}) + with pytest.raises(ValidationError) as exc_info: + v.validate_python([[1, 2], 'wrong']) + assert exc_info.value.errors(include_url=False) == snapshot( + [ + { + 'type': 'tuple_type', + 'loc': (1,), + 'msg': 'Input should be a valid tuple', + 'input': 'wrong', + } + ] + ) + assert v.validate_python([[1, 2], 'wrong'], allow_partial=True) == snapshot({(1, 2)}) + assert v.validate_python([[1, 2], [3, 4], 'wrong'], allow_partial=True) == snapshot({(1, 2), (3, 4)}) + assert v.validate_python([[1, 2], []], allow_partial=True) == snapshot({(1, 2)}) + assert v.validate_python([[1, 2], [3]], allow_partial=True) == snapshot({(1, 2)}) + assert v.validate_python([[1, 2], [3, 'x']], allow_partial=True) == snapshot({(1, 2)}) + with pytest.raises(ValidationError, match='Input should be a valid tuple'): + v.validate_python([[1, 2], 'wrong', [3, 4]]) + with pytest.raises(ValidationError, match='Input should be a valid tuple'): + v.validate_python([[1, 2], 'wrong', 'wrong']) + + +def test_dict(): + v = SchemaValidator(core_schema.dict_schema(core_schema.int_schema(), core_schema.int_schema())) + assert 
v.validate_python({'1': 2, 3: '4'}) == snapshot({1: 2, 3: 4}) + assert v.validate_python({'1': 2, 3: '4'}, allow_partial=True) == snapshot({1: 2, 3: 4}) + with pytest.raises(ValidationError) as exc_info: + v.validate_python({'1': 2, 3: 'wrong'}) + assert exc_info.value.errors(include_url=False) == snapshot( + [ + { + 'type': 'int_parsing', + 'loc': (3,), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ] + ) + assert v.validate_python({'1': 2, 3: 'wrong'}, allow_partial=True) == snapshot({1: 2}) + assert v.validate_python({'1': 2, 3: 4, 5: '6', 7: 'x'}, allow_partial=True) == snapshot({1: 2, 3: 4, 5: 6}) + with pytest.raises(ValidationError, match='Input should be a valid integer'): + v.validate_python({'1': 2, 3: 4, 5: 'x', 7: '8'}) + with pytest.raises(ValidationError, match='Input should be a valid integer'): + v.validate_python({'1': 2, 3: 4, 5: 'x', 7: 'x'}) + with pytest.raises(ValidationError, match='Input should be a valid integer'): + v.validate_python({'1': 2, 3: 4, 'x': 6}) + + +def test_partial_typed_dict(): + v = SchemaValidator( + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field(core_schema.int_schema(gt=10)), + 'b': core_schema.typed_dict_field(core_schema.int_schema(gt=10)), + 'c': core_schema.typed_dict_field(core_schema.int_schema(gt=10)), + }, + total=False, + ) + ) + + assert v.validate_python({'a': 11, 'b': '12', 'c': 13}) == snapshot(IsStrictDict(a=11, b=12, c=13)) + assert v.validate_python({'a': 11, 'c': 13, 'b': '12'}) == snapshot(IsStrictDict(a=11, b=12, c=13)) + + assert v.validate_python({'a': 11, 'b': '12', 'c': 13}, allow_partial=True) == snapshot({'a': 11, 'b': 12, 'c': 13}) + with pytest.raises(ValidationError) as exc_info: + v.validate_python({'a': 11, 'b': '12', 'c': 1}) + assert exc_info.value.errors(include_url=False) == snapshot( + [ + { + 'type': 'greater_than', + 'loc': ('c',), + 'msg': 'Input should be greater than 10', + 'input': 1, + 'ctx': 
{'gt': 10}, + } + ] + ) + assert v.validate_python({'a': 11, 'b': '12', 'c': 1}, allow_partial=True) == snapshot(IsStrictDict(a=11, b=12)) + assert v.validate_python({'a': 11, 'c': 13, 'b': 1}, allow_partial=True) == snapshot(IsStrictDict(a=11, c=13)) + with pytest.raises(ValidationError) as exc_info: + v.validate_python({'a': 11, 'c': 1, 'b': 12}, allow_partial=True) + assert exc_info.value.errors(include_url=False) == snapshot( + [ + { + 'type': 'greater_than', + 'loc': ('c',), + 'msg': 'Input should be greater than 10', + 'input': 1, + 'ctx': {'gt': 10}, + } + ] + ) + + +def test_non_partial_typed_dict(): + v = SchemaValidator( + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field(core_schema.int_schema(gt=10)), + 'b': core_schema.typed_dict_field(core_schema.int_schema(gt=10), required=True), + 'c': core_schema.typed_dict_field(core_schema.int_schema(gt=10)), + }, + total=False, + ) + ) + + assert v.validate_python({'a': 11, 'b': '12', 'c': 13}) == snapshot({'a': 11, 'b': 12, 'c': 13}) + with pytest.raises(ValidationError, match='Input should be greater than 10'): + v.validate_python({'a': 11, 'b': '12', 'c': 1}) + with pytest.raises(ValidationError, match='Input should be greater than 10'): + v.validate_python({'a': 11, 'b': '12', 'c': 1}, allow_partial=False) + + +def test_double_nested(): + v = SchemaValidator( + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field(core_schema.int_schema(gt=10)), + 'b': core_schema.typed_dict_field( + core_schema.list_schema( + core_schema.dict_schema(core_schema.str_schema(), core_schema.int_schema(ge=10)) + ) + ), + }, + total=False, + ) + ) + assert v.validate_python({'a': 11, 'b': [{'a': 10, 'b': 20}, {'a': 30, 'b': 40}]}) == snapshot( + {'a': 11, 'b': [{'a': 10, 'b': 20}, {'a': 30, 'b': 40}]} + ) + assert v.validate_python({'a': 11, 'b': [{'a': 10, 'b': 20}, {'a': 30, 'b': 4}]}, allow_partial=True) == snapshot( + {'a': 11, 'b': [{'a': 10, 'b': 20}, {'a': 30}]} + ) + assert 
v.validate_python({'a': 11, 'b': [{'a': 10, 'b': 20}, {'a': 30, 123: 4}]}, allow_partial=True) == snapshot( + {'a': 11, 'b': [{'a': 10, 'b': 20}]} + ) + # this is not the intended behaviour, but it's okay + assert v.validate_python({'a': 11, 'b': [{'a': 10, 'b': 2}, {'a': 30}]}, allow_partial=True) == snapshot( + {'a': 11, 'b': [{'a': 10}, {'a': 30}]} + ) + assert v.validate_python({'a': 11, 'b': [{'a': 1, 'b': 20}, {'a': 3, 'b': 40}]}, allow_partial=True) == snapshot( + {'a': 11} + ) diff --git a/tests/validators/test_list.py b/tests/validators/test_list.py index 92f3aefc3..b8bcee8ec 100644 --- a/tests/validators/test_list.py +++ b/tests/validators/test_list.py @@ -547,13 +547,3 @@ def test_list_dict_items_input(testcase: ListInputTestCase) -> None: output = v.validate_python(testcase.input) assert output == testcase.output assert output is not testcase.input - - -def test_validate_partial(): - v = SchemaValidator( - core_schema.list_schema( - core_schema.tuple_positional_schema([core_schema.int_schema(), core_schema.int_schema()]), - ) - ) - assert v.validate_python([[1, 2], [3, 4]]) == [(1, 2), (3, 4)] - assert v.validate_python([[1, 2], [3, 4]], allow_partial=True) == [(1, 2), (3, 4)]