Skip to content

Commit

Permalink
tests for set,frozenset,list,dict,typed_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 30, 2024
1 parent 2bf9d9b commit 554ec40
Show file tree
Hide file tree
Showing 9 changed files with 281 additions and 68 deletions.
4 changes: 2 additions & 2 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2854,7 +2854,7 @@ def typed_dict_field(
Args:
schema: The schema to use for the field
required: Whether the field is required
required: Whether the field is required, otherwise uses the value from `total` on the typed dict
validation_alias: The alias(es) to use to find the field in the validation data
serialization_alias: The alias to use as a key when serializing
serialization_exclude: Whether to exclude the field when serializing
Expand Down Expand Up @@ -2930,7 +2930,7 @@ class MyTypedDict(TypedDict):
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
extra_behavior: The extra behavior to use for the typed dict
total: Whether the typed dict is total
total: Whether the typed dict is total, otherwise uses `typed_dict_total` from config
populate_by_name: Whether the typed dict should populate by name
serialization: Custom serialization schema
"""
Expand Down
6 changes: 3 additions & 3 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,11 @@ impl ValLineError {
self
}

pub fn last_loc_item(&self) -> Option<&LocItem> {
pub fn first_loc_item(&self) -> Option<&LocItem> {
match &self.location {
Location::Empty => None,
// first because order is reversed
Location::List(loc_items) => loc_items.first(),
// last because order is reversed
Location::List(loc_items) => loc_items.last(),
}
}
}
Expand Down
38 changes: 0 additions & 38 deletions src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::validators::ValidationState;
use pyo3::prelude::*;

mod line_error;
Expand Down Expand Up @@ -31,40 +30,3 @@ pub fn py_err_string(py: Python, err: PyErr) -> String {
Err(_) => "Unknown Error".to_string(),
}
}

/// If we're in `allow_partial` mode, whether all errors occurred in the last element of the input.
pub fn sequence_valid_as_partial(state: &ValidationState, input_length: usize, errors: &[ValLineError]) -> bool {
if !state.extra().allow_partial {
return false;
}
// for the error to be in the last element, the index of all errors must be `input_length - 1`
let last_index = (input_length - 1) as i64;
errors.iter().all(|error| {
if let Some(LocItem::I(loc_index)) = error.last_loc_item() {
*loc_index == last_index
} else {
false
}
})
}

/// If we're in `allow_partial` mode, whether all errors occurred in the last value of the input.
pub fn mapping_valid_as_partial(
state: &ValidationState,
opt_last_key: Option<impl Into<LocItem>>,
errors: &[ValLineError],
) -> bool {
if !state.extra().allow_partial {
return false;
}
let Some(last_key) = opt_last_key.map(Into::into) else {
return false;
};
errors.iter().all(|error| {
if let Some(loc_item) = error.last_loc_item() {
loc_item == &last_key
} else {
false
}
})
}
24 changes: 20 additions & 4 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ use pyo3::types::{PyBytes, PyComplex, PyFloat, PyFrozenSet, PyIterator, PyMappin
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{
py_err_string, sequence_valid_as_partial, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError,
ValLineError, ValResult,
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ToErrorValue, ValError, ValLineError, ValResult,
};
use crate::py_gc::PyGcTraverse;
use crate::tools::{extract_i64, extract_int, new_py_string, py_err};
Expand Down Expand Up @@ -131,7 +130,6 @@ pub(crate) fn validate_iter_to_vec<'py>(
let mut errors: Vec<ValLineError> = Vec::new();
let mut index = 0;
for item_result in iter {
index += 1;
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
Expand All @@ -148,6 +146,7 @@ pub(crate) fn validate_iter_to_vec<'py>(
Err(ValError::Omit) => (),
Err(err) => return Err(err),
}
index += 1;
}

if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) {
Expand All @@ -157,6 +156,23 @@ pub(crate) fn validate_iter_to_vec<'py>(
}
}

/// If we're in `allow_partial` mode, whether all errors occurred in the last element of the input.
pub fn sequence_valid_as_partial(state: &ValidationState, input_length: usize, errors: &[ValLineError]) -> bool {
if !state.extra().allow_partial {
false
} else {
// for the error to be in the last element, the index of all errors must be `input_length - 1`
let last_index = (input_length - 1) as i64;
errors.iter().all(|error| {
if let Some(LocItem::I(loc_index)) = error.first_loc_item() {
*loc_index == last_index
} else {
false
}
})
}
}

pub trait BuildSet {
fn build_add(&self, item: PyObject) -> PyResult<()>;

Expand Down Expand Up @@ -202,7 +218,6 @@ pub(crate) fn validate_iter_to_set<'py>(
let mut errors: Vec<ValLineError> = Vec::new();
let mut index = 0;
for item_result in iter {
index += 1;
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
Expand Down Expand Up @@ -233,6 +248,7 @@ pub(crate) fn validate_iter_to_set<'py>(
if fail_fast && !errors.is_empty() {
return Err(ValError::LineErrors(errors));
}
index += 1;
}

if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) {
Expand Down
20 changes: 12 additions & 8 deletions src/validators/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::build_tools::is_strict;
use crate::errors::{sequence_valid_as_partial, LocItem, ValError, ValLineError, ValResult};
use crate::errors::{LocItem, ValError, ValLineError, ValResult};
use crate::input::BorrowInput;
use crate::input::ConsumeIterator;
use crate::input::{Input, ValidatedDict};
Expand Down Expand Up @@ -109,10 +109,13 @@ where
fn consume_iterator(self, iterator: impl Iterator<Item = ValResult<(Key, Value)>>) -> ValResult<PyObject> {
let output = PyDict::new_bound(self.py);
let mut errors: Vec<ValLineError> = Vec::new();
let mut input_length = 0;
// this should only be set to if:
// we get errors in a value, there are no previous errors, and no items come after that
// e.g. if we get errors just in the last value
let mut errors_in_last = false;

for item_result in iterator {
input_length += 1;
errors_in_last = false;
let (key, value) = item_result?;
let output_key = match self.key_validator.validate(self.py, key.borrow_input(), self.state) {
Ok(value) => Some(value),
Expand All @@ -127,22 +130,23 @@ where
Err(err) => return Err(err),
};
let output_value = match self.value_validator.validate(self.py, value.borrow_input(), self.state) {
Ok(value) => Some(value),
Ok(value) => value,
Err(ValError::LineErrors(line_errors)) => {
errors_in_last = errors.is_empty();
for err in line_errors {
errors.push(err.with_outer_location(key.clone()));
}
None
continue;
}
Err(ValError::Omit) => continue,
Err(err) => return Err(err),
};
if let (Some(key), Some(value)) = (output_key, output_value) {
output.set_item(key, value)?;
if let Some(key) = output_key {
output.set_item(key, output_value)?;
}
}

if errors.is_empty() || sequence_valid_as_partial(self.state, input_length, &errors) {
if errors.is_empty() || (self.state.extra().allow_partial && errors_in_last) {
let input = self.input;
length_check!(input, "Dictionary", self.min_length, self.max_length, output);
Ok(output.into())
Expand Down
32 changes: 29 additions & 3 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ahash::AHashSet;

use crate::build_tools::py_schema_err;
use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior};
use crate::errors::{mapping_valid_as_partial, LocItem};
use crate::errors::LocItem;
use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::BorrowInput;
use crate::input::ConsumeIterator;
Expand Down Expand Up @@ -35,6 +35,7 @@ pub struct TypedDictValidator {
extras_validator: Option<Box<CombinedValidator>>,
strict: bool,
loc_by_alias: bool,
allow_partial: bool,
}

impl BuildValidator for TypedDictValidator {
Expand Down Expand Up @@ -124,13 +125,14 @@ impl BuildValidator for TypedDictValidator {
required,
});
}

let allow_partial = fields.iter().all(|f| !f.required);
Ok(Self {
fields,
extra_behavior,
extras_validator,
strict,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
allow_partial,
}
.into())
}
Expand Down Expand Up @@ -322,7 +324,7 @@ impl Validator for TypedDictValidator {
})??;
}

if errors.is_empty() || mapping_valid_as_partial(state, dict.last_key(), &errors) {
if errors.is_empty() || self.valid_as_partial(state, dict.last_key(), &errors) {
Ok(output_dict.to_object(py))
} else {
Err(ValError::LineErrors(errors))
Expand All @@ -333,3 +335,27 @@ impl Validator for TypedDictValidator {
Self::EXPECTED_TYPE
}
}

impl TypedDictValidator {
/// If we're in `allow_partial` mode, whether all errors occurred in the last value of the dict.
fn valid_as_partial(
&self,
state: &ValidationState,
opt_last_key: Option<impl Into<LocItem>>,
errors: &[ValLineError],
) -> bool {
if !state.extra().allow_partial || !self.allow_partial {
false
} else if let Some(last_key) = opt_last_key.map(Into::into) {
errors.iter().all(|error| {
if let Some(loc_item) = error.first_loc_item() {
loc_item == &last_key
} else {
false
}
})
} else {
false
}
}
}
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
backports.zoneinfo==0.2.1;python_version<"3.9"
coverage==7.6.1
dirty-equals==0.8.0
inline-snapshot==0.13.3
hypothesis==6.111.2
# pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux
pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64'
Expand Down
Loading

0 comments on commit 554ec40

Please sign in to comment.