Skip to content

Commit

Permalink
add "allow_partial" support
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 30, 2024
1 parent c5a4261 commit 3ba83fb
Show file tree
Hide file tree
Showing 19 changed files with 197 additions and 44 deletions.
17 changes: 16 additions & 1 deletion python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class SchemaValidator:
from_attributes: bool | None = None,
context: Any | None = None,
self_instance: Any | None = None,
allow_partial: bool = False,
) -> Any:
"""
Validate a Python object against the schema and return the validated object.
Expand All @@ -105,6 +106,8 @@ class SchemaValidator:
[`info.context`][pydantic_core.core_schema.ValidationInfo.context].
self_instance: An instance of a model set attributes on from validation, this is used when running
validation from the `__init__` method of a model.
allow_partial: Whether to allow partial validation, if `True` errors in the last element of sequences
and mappings are ignored.
Raises:
ValidationError: If validation fails.
Expand Down Expand Up @@ -138,6 +141,7 @@ class SchemaValidator:
strict: bool | None = None,
context: Any | None = None,
self_instance: Any | None = None,
allow_partial: bool = False,
) -> Any:
"""
Validate JSON data directly against the schema and return the validated Python object.
Expand All @@ -155,6 +159,8 @@ class SchemaValidator:
context: The context to use for validation, this is passed to functional validators as
[`info.context`][pydantic_core.core_schema.ValidationInfo.context].
self_instance: An instance of a model set attributes on from validation.
allow_partial: Whether to allow partial validation, if `True` errors in the last element of sequences
and mappings are ignored.
Raises:
ValidationError: If validation fails or if the JSON data is invalid.
Expand All @@ -163,7 +169,14 @@ class SchemaValidator:
Returns:
The validated Python object.
"""
def validate_strings(self, input: _StringInput, *, strict: bool | None = None, context: Any | None = None) -> Any:
def validate_strings(
self,
input: _StringInput,
*,
strict: bool | None = None,
context: Any | None = None,
allow_partial: bool = False,
) -> Any:
"""
Validate a string against the schema and return the validated Python object.
Expand All @@ -176,6 +189,8 @@ class SchemaValidator:
If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used.
context: The context to use for validation, this is passed to functional validators as
[`info.context`][pydantic_core.core_schema.ValidationInfo.context].
allow_partial: Whether to allow partial validation, if `True` errors in the last element of sequences
and mappings are ignored.
Raises:
ValidationError: If validation fails or if the JSON data is invalid.
Expand Down
42 changes: 28 additions & 14 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,45 +122,58 @@ class CoreConfig(TypedDict, total=False):

class SerializationInfo(Protocol):
@property
def include(self) -> IncExCall: ...
def include(self) -> IncExCall:
...

@property
def exclude(self) -> IncExCall: ...
def exclude(self) -> IncExCall:
...

@property
def context(self) -> Any | None:
"""Current serialization context."""

@property
def mode(self) -> str: ...
def mode(self) -> str:
...

@property
def by_alias(self) -> bool: ...
def by_alias(self) -> bool:
...

@property
def exclude_unset(self) -> bool: ...
def exclude_unset(self) -> bool:
...

@property
def exclude_defaults(self) -> bool: ...
def exclude_defaults(self) -> bool:
...

@property
def exclude_none(self) -> bool: ...
def exclude_none(self) -> bool:
...

@property
def serialize_as_any(self) -> bool: ...
def serialize_as_any(self) -> bool:
...

def round_trip(self) -> bool: ...
def round_trip(self) -> bool:
...

def mode_is_json(self) -> bool: ...
def mode_is_json(self) -> bool:
...

def __str__(self) -> str: ...
def __str__(self) -> str:
...

def __repr__(self) -> str: ...
def __repr__(self) -> str:
...


class FieldSerializationInfo(SerializationInfo, Protocol):
@property
def field_name(self) -> str: ...
def field_name(self) -> str:
...


class ValidationInfo(Protocol):
Expand Down Expand Up @@ -304,7 +317,8 @@ def plain_serializer_function_ser_schema(


class SerializerFunctionWrapHandler(Protocol): # pragma: no cover
def __call__(self, input_value: Any, index_key: int | str | None = None, /) -> Any: ...
def __call__(self, input_value: Any, index_key: int | str | None = None, /) -> Any:
...


# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
Expand Down
8 changes: 8 additions & 0 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ impl ValLineError {
self.error_type = error_type;
self
}

pub fn last_loc_item(&self) -> Option<&LocItem> {
match &self.location {
Location::Empty => None,
// first because order is reversed
Location::List(loc_items) => loc_items.first(),
}
}
}

#[cfg_attr(debug_assertions, derive(Debug))]
Expand Down
2 changes: 1 addition & 1 deletion src/errors/location.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::lookup_key::{LookupPath, PathItem};

/// Used to store individual items of the error location, e.g. a string for key/field names
/// or a number for array indices.
#[derive(Clone)]
#[derive(Clone, Eq, PartialEq)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub enum LocItem {
/// string type key, used to identify items from a dict or anything that implements `__getitem__`
Expand Down
38 changes: 38 additions & 0 deletions src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::validators::ValidationState;
use pyo3::prelude::*;

mod line_error;
Expand Down Expand Up @@ -30,3 +31,40 @@ pub fn py_err_string(py: Python, err: PyErr) -> String {
Err(_) => "Unknown Error".to_string(),
}
}

/// If we're in `allow_partial` mode, whether all errors occurred in the last element of the input.
pub fn sequence_valid_as_partial(state: &ValidationState, input_length: usize, errors: &[ValLineError]) -> bool {
if !state.extra().allow_partial {
return false;
}
// for the error to be in the last element, the index of all errors must be `input_length - 1`
let last_index = (input_length - 1) as i64;
errors.iter().all(|error| {
if let Some(LocItem::I(loc_index)) = error.last_loc_item() {
*loc_index == last_index
} else {
false
}
})
}

/// If we're in `allow_partial` mode, whether all errors occurred in the last value of the input.
pub fn mapping_valid_as_partial(
state: &ValidationState,
opt_last_key: Option<impl Into<LocItem>>,
errors: &[ValLineError],
) -> bool {
if !state.extra().allow_partial {
return false;
}
let Some(last_key) = opt_last_key.map(Into::into) else {
return false;
};
errors.iter().all(|error| {
if let Some(loc_item) = error.last_loc_item() {
loc_item == &last_key
} else {
false
}
})
}
5 changes: 5 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ pub trait ValidatedDict<'py> {
&'a self,
consumer: impl ConsumeIterator<ValResult<(Self::Key<'a>, Self::Item<'a>)>, Output = R>,
) -> ValResult<R>;
// used in partial mode to check all errors occurred in the last key value pair
fn last_key(&self) -> Option<Self::Key<'_>>;
}

/// For validations from a list
Expand Down Expand Up @@ -289,6 +291,9 @@ impl<'py> ValidatedDict<'py> for Never {
) -> ValResult<R> {
unreachable!()
}
fn last_key(&self) -> Option<Self::Key<'_>> {
unreachable!()
}
}

impl<'py> ValidatedList<'py> for Never {
Expand Down
4 changes: 4 additions & 0 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,10 @@ impl<'py, 'data> ValidatedDict<'py> for &'_ JsonObject<'data> {
) -> ValResult<R> {
Ok(consumer.consume_iterator(LazyIndexMap::iter(self).map(|(k, v)| Ok((k.as_ref(), v)))))
}

fn last_key(&self) -> Option<Self::Key<'_>> {
self.keys().last().map(AsRef::as_ref)
}
}

impl<'a, 'py, 'data> ValidatedList<'py> for &'a JsonArray<'data> {
Expand Down
8 changes: 8 additions & 0 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,14 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> {
Self::GetAttr(obj, _) => Ok(consumer.consume_iterator(iterate_attributes(obj)?)),
}
}

fn last_key(&self) -> Option<Self::Key<'_>> {
match self {
Self::Dict(dict) => dict.keys().iter().last(),
Self::Mapping(mapping) => mapping.keys().ok()?.iter().ok()?.last()?.ok(),
Self::GetAttr(_, _) => None,
}
}
}

/// Container for all the collections (sized iterable containers) types, which
Expand Down
8 changes: 8 additions & 0 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,12 @@ impl<'py> ValidatedDict<'py> for StringMappingDict<'py> {
.map(|(key, val)| Ok((StringMapping::new_key(key)?, StringMapping::new_value(val)?))),
))
}

fn last_key(&self) -> Option<Self::Key<'_>> {
self.0
.keys()
.iter()
.last()
.and_then(|key| StringMapping::new_key(key).ok())
}
}
19 changes: 12 additions & 7 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, P
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{
py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult,
py_err_string, sequence_valid_as_partial, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError,
ValLineError, ValResult,
};
use crate::py_gc::PyGcTraverse;
use crate::tools::{extract_i64, new_py_string, py_err};
Expand Down Expand Up @@ -128,7 +129,9 @@ pub(crate) fn validate_iter_to_vec<'py>(
) -> ValResult<Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(capacity);
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let mut index = 0;
for item_result in iter {
index += 1;
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
Expand All @@ -139,15 +142,15 @@ pub(crate) fn validate_iter_to_vec<'py>(
max_length_check.incr()?;
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
if fail_fast {
break;
return Err(ValError::LineErrors(errors));
}
}
Err(ValError::Omit) => (),
Err(err) => return Err(err),
}
}

if errors.is_empty() {
if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) {
Ok(output)
} else {
Err(ValError::LineErrors(errors))
Expand Down Expand Up @@ -197,7 +200,9 @@ pub(crate) fn validate_iter_to_set<'py>(
fail_fast: bool,
) -> ValResult<()> {
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let mut index = 0;
for item_result in iter {
index += 1;
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
match validator.validate(py, item.borrow_input(), state) {
Ok(item) => {
Expand Down Expand Up @@ -226,11 +231,11 @@ pub(crate) fn validate_iter_to_set<'py>(
Err(err) => return Err(err),
}
if fail_fast && !errors.is_empty() {
break;
return Err(ValError::LineErrors(errors));
}
}

if errors.is_empty() {
if errors.is_empty() || sequence_valid_as_partial(state, index, &errors) {
Ok(())
} else {
Err(ValError::LineErrors(errors))
Expand Down
4 changes: 2 additions & 2 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl PyUrl {
pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult<Self> {
let schema_obj = SCHEMA_DEFINITION_URL
.get_or_init(py, || build_schema_validator(py, "url"))
.validate_python(py, url, None, None, None, None)?;
.validate_python(py, url, None, None, None, None, false)?;
schema_obj.extract(py)
}

Expand Down Expand Up @@ -225,7 +225,7 @@ impl PyMultiHostUrl {
pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult<Self> {
let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL
.get_or_init(py, || build_schema_validator(py, "multi-host-url"))
.validate_python(py, url, None, None, None, None)?;
.validate_python(py, url, None, None, None, None, false)?;
schema_obj.extract(py)
}

Expand Down
6 changes: 4 additions & 2 deletions src/validators/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::build_tools::is_strict;
use crate::errors::{LocItem, ValError, ValLineError, ValResult};
use crate::errors::{sequence_valid_as_partial, LocItem, ValError, ValLineError, ValResult};
use crate::input::BorrowInput;
use crate::input::ConsumeIterator;
use crate::input::{Input, ValidatedDict};
Expand Down Expand Up @@ -109,8 +109,10 @@ where
fn consume_iterator(self, iterator: impl Iterator<Item = ValResult<(Key, Value)>>) -> ValResult<PyObject> {
let output = PyDict::new_bound(self.py);
let mut errors: Vec<ValLineError> = Vec::new();
let mut input_length = 0;

for item_result in iterator {
input_length += 1;
let (key, value) = item_result?;
let output_key = match self.key_validator.validate(self.py, key.borrow_input(), self.state) {
Ok(value) => Some(value),
Expand Down Expand Up @@ -140,7 +142,7 @@ where
}
}

if errors.is_empty() {
if errors.is_empty() || sequence_valid_as_partial(self.state, input_length, &errors) {
let input = self.input;
length_check!(input, "Dictionary", self.min_length, self.max_length, output);
Ok(output.into())
Expand Down
Loading

0 comments on commit 3ba83fb

Please sign in to comment.