diff --git a/benches/main.rs b/benches/main.rs index ae238f2fa..a89fb1241 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -5,7 +5,7 @@ extern crate test; use test::{black_box, Bencher}; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString}; +use pyo3::types::{PyDict, PySet, PyString}; use _pydantic_core::SchemaValidator; @@ -265,6 +265,7 @@ fn dict_python(bench: &mut Bencher) { .collect::>() .join(", ") ); + dbg!(code.clone()); let input = py.eval(&code, None, None).unwrap(); let input = black_box(input); bench.iter(|| { @@ -696,3 +697,58 @@ class Foo(Enum): } }) } + +const COLLECTION_SIZE: usize = 100_000; + +#[bench] +fn constructing_pyset_from_vec_without_capacity(bench: &mut Bencher) { + Python::with_gil(|py| { + let input: Vec = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect(); + + bench.iter(|| { + black_box({ + let mut output = Vec::new(); + for x in &input { + output.push(x); + } + let set = PySet::new(py, output.iter()).unwrap(); + set + }) + }) + }) +} + +#[bench] +fn constructing_pyset_from_vec_with_capacity(bench: &mut Bencher) { + Python::with_gil(|py| { + let input: Vec = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect(); + + bench.iter(|| { + black_box({ + let mut output = Vec::with_capacity(COLLECTION_SIZE); + for x in &input { + output.push(x); + } + let set = PySet::new(py, output.iter()).unwrap(); + set + }) + }) + }) +} + +#[bench] +fn constructing_pyset_from_vec_directly(bench: &mut Bencher) { + Python::with_gil(|py| { + let input: Vec = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect(); + + bench.iter(|| { + black_box({ + let output = PySet::new(py, &Vec::::new()).unwrap(); + for x in &input { + output.add(x).unwrap(); + } + output + }) + }) + }) +} diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 681a3b8a5..78c6a4572 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -1188,7 +1188,6 @@ class ListSchema(TypedDict, total=False): min_length: int max_length: int strict: bool - allow_any_iter: bool ref: str metadata: Any serialization: IncExSeqOrElseSerSchema @@ -1200,7 +1199,6 @@ def list_schema( min_length: int | None = None, max_length: int | None = None, strict: bool | None = None, - allow_any_iter: bool | None = None, ref: str | None = None, metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, @@ -1221,7 +1219,6 @@ def list_schema( min_length: The value must be a list with at least this many items max_length: The value must be a list with at most this many items strict: The value must be a list with exactly this many items - allow_any_iter: Whether the value can be any iterable ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema @@ -1232,7 +1229,6 @@ def list_schema( min_length=min_length, max_length=max_length, strict=strict, - allow_any_iter=allow_any_iter, ref=ref, metadata=metadata, serialization=serialization, @@ -1353,7 +1349,6 @@ class SetSchema(TypedDict, total=False): items_schema: CoreSchema min_length: int max_length: int - generator_max_length: int strict: bool ref: str metadata: Any @@ -1365,7 +1360,6 @@ def set_schema( *, min_length: int | None = None, max_length: int | None = None, - generator_max_length: int | None = None, strict: bool | None = None, ref: str | None = None, metadata: Any = None, @@ -1388,9 +1382,6 @@ def set_schema( items_schema: The value must be a set with items that match this schema min_length: The value must be a set with at least this many items max_length: The value must be a set with at most this many items - generator_max_length: At most this many items will be read from a generator before failing validation - This is important because generators can be infinite, and even with a `max_length` on the set, - an infinite generator could run forever without producing more than `max_length` distinct items. strict: The value must be a set with exactly this many items ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -1401,7 +1392,6 @@ def set_schema( items_schema=items_schema, min_length=min_length, max_length=max_length, - generator_max_length=generator_max_length, strict=strict, ref=ref, metadata=metadata, @@ -1414,7 +1404,6 @@ class FrozenSetSchema(TypedDict, total=False): items_schema: CoreSchema min_length: int max_length: int - generator_max_length: int strict: bool ref: str metadata: Any @@ -1426,7 +1415,6 @@ def frozenset_schema( *, min_length: int | None = None, max_length: int | None = None, - generator_max_length: int | None = None, strict: bool | None = None, ref: str | None = None, metadata: Any = None, @@ -1449,7 +1437,6 @@ def frozenset_schema( items_schema: The value must be a frozenset with items that match this schema min_length: The value must be a frozenset with at least this many items max_length: The value must be a frozenset with at most this many items - generator_max_length: The value must generate a frozenset with at most this many items strict: The value must be a frozenset with exactly this many items ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -1460,7 +1447,6 @@ def frozenset_schema( items_schema=items_schema, min_length=min_length, max_length=max_length, - generator_max_length=generator_max_length, strict=strict, ref=ref, metadata=metadata, diff --git a/src/input/generic_iterable.rs b/src/input/generic_iterable.rs new file mode 100644 index 000000000..8631518a0 --- /dev/null +++ b/src/input/generic_iterable.rs @@ -0,0 +1,210 @@ +use crate::errors::{py_err_string, ErrorType, ValError, ValResult}; + +use super::parse_json::{JsonInput, JsonObject}; +use pyo3::{ + exceptions::PyTypeError, + types::{ + PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString, PyTuple, + }, + PyAny, PyErr, PyResult, Python, ToPyObject, +}; + +#[derive(Debug)] +pub enum GenericIterable<'a> { + List(&'a PyList), + Tuple(&'a PyTuple), + Set(&'a PySet), + FrozenSet(&'a PyFrozenSet), + Dict(&'a PyDict), + // Treat dict values / keys / items as generic iterators + // since PyPy doesn't export the concrete types + DictKeys(&'a PyIterator), + DictValues(&'a PyIterator), + DictItems(&'a PyIterator), + Mapping(&'a PyMapping), + String(&'a PyString), + Bytes(&'a PyBytes), + PyByteArray(&'a PyByteArray), + Sequence(&'a PySequence), + Iterator(&'a PyIterator), + JsonArray(&'a [JsonInput]), + JsonObject(&'a JsonObject), +} + +type PyMappingItems<'a> = (&'a PyAny, &'a PyAny); + +#[inline(always)] +fn extract_items(item: PyResult<&PyAny>) -> PyResult> { + match item { + Ok(v) => v.extract::(), + Err(e) => Err(e), + } +} + +#[inline(always)] +fn map_err<'data>(py: Python<'data>, err: PyErr, input: &'data PyAny) -> ValError<'data> { + ValError::new( + ErrorType::IterationError { + error: py_err_string(py, err), + }, + input, + ) +} + +impl<'a, 'py: 'a> GenericIterable<'a> { + pub fn len(&self) -> Option { + match &self { + GenericIterable::List(iter) => Some(iter.len()), + GenericIterable::Tuple(iter) => Some(iter.len()), + GenericIterable::Set(iter) => Some(iter.len()), + GenericIterable::FrozenSet(iter) => Some(iter.len()), + GenericIterable::Dict(iter) => Some(iter.len()), + GenericIterable::DictKeys(iter) => iter.len().ok(), + GenericIterable::DictValues(iter) => iter.len().ok(), + GenericIterable::DictItems(iter) => iter.len().ok(), + GenericIterable::Mapping(iter) => iter.len().ok(), + GenericIterable::String(iter) => iter.len().ok(), + GenericIterable::Bytes(iter) => iter.len().ok(), + GenericIterable::PyByteArray(iter) => Some(iter.len()), + GenericIterable::Sequence(iter) => iter.len().ok(), + GenericIterable::Iterator(iter) => iter.len().ok(), + GenericIterable::JsonArray(iter) => Some(iter.len()), + GenericIterable::JsonObject(iter) => Some(iter.len()), + } + } + pub fn into_sequence_iterator( + self, + py: Python<'py>, + ) -> PyResult> + 'a>> { + match self { + GenericIterable::List(iter) => Ok(Box::new(iter.iter().map(Ok))), + GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(Ok))), + GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(Ok))), + GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(Ok))), + // Note that this iterates over only the keys, just like doing iter({}) in Python + GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(|(k, _)| Ok(k)))), + GenericIterable::DictKeys(iter) => Ok(Box::new( + iter.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::DictValues(iter) => Ok(Box::new( + iter.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::DictItems(iter) => Ok(Box::new( + iter.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + // Note that this iterates over only the keys, just like doing iter({}) in Python + GenericIterable::Mapping(iter) => Ok(Box::new( + iter.keys()? + .iter()? + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::String(iter) => Ok(Box::new( + iter.iter()?.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::Bytes(iter) => Ok(Box::new( + iter.iter()?.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::PyByteArray(iter) => Ok(Box::new( + iter.iter()?.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::Sequence(iter) => Ok(Box::new( + iter.iter()?.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::Iterator(iter) => Ok(Box::new( + iter.map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::JsonArray(iter) => Ok(Box::new(iter.iter().map(move |v| { + let v = v.to_object(py); + Ok(v.into_ref(py)) + }))), + // Note that this iterates over only the keys, just like doing iter({}) in Python, just for consistency + GenericIterable::JsonObject(iter) => Ok(Box::new( + iter.iter().map(move |(k, _)| Ok(k.to_object(py).into_ref(py))), + )), + } + } + + pub fn into_mapping_items_iterator( + self, + py: Python<'a>, + ) -> PyResult>> + 'a>> { + match self { + GenericIterable::List(iter) => { + Ok(Box::new(iter.iter().map(move |v| { + extract_items(Ok(v)).map_err(|e| map_err(py, e, iter.as_ref())) + }))) + } + GenericIterable::Tuple(iter) => { + Ok(Box::new(iter.iter().map(move |v| { + extract_items(Ok(v)).map_err(|e| map_err(py, e, iter.as_ref())) + }))) + } + GenericIterable::Set(iter) => { + Ok(Box::new(iter.iter().map(move |v| { + extract_items(Ok(v)).map_err(|e| map_err(py, e, iter.as_ref())) + }))) + } + GenericIterable::FrozenSet(iter) => { + Ok(Box::new(iter.iter().map(move |v| { + extract_items(Ok(v)).map_err(|e| map_err(py, e, iter.as_ref())) + }))) + } + // Note that we iterate over (key, value), unlike doing iter({}) in Python + GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(Ok))), + // Keys or values can be tuples + GenericIterable::DictKeys(iter) => Ok(Box::new( + iter.map(extract_items) + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::DictValues(iter) => Ok(Box::new( + iter.map(extract_items) + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::DictItems(iter) => Ok(Box::new( + iter.map(extract_items) + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + // Note that we iterate over (key, value), unlike doing iter({}) in Python + GenericIterable::Mapping(iter) => Ok(Box::new( + iter.items()? + .iter()? + .map(extract_items) + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + // In Python if you do dict("foobar") you get "dictionary update sequence element #0 has length 1; 2 is required" + // This is similar but arguably a better error message + GenericIterable::String(_) => Err(PyTypeError::new_err( + "Expected an iterable of (key, value) pairs, got a string", + )), + GenericIterable::Bytes(_) => Err(PyTypeError::new_err( + "Expected an iterable of (key, value) pairs, got a bytes", + )), + GenericIterable::PyByteArray(_) => Err(PyTypeError::new_err( + "Expected an iterable of (key, value) pairs, got a bytearray", + )), + // Obviously these may be things that are not convertible to a tuple of (Hashable, Any) + // Python fails with a similar error message to above, ours will be slightly different (PyO3 will fail to extract) but similar enough + GenericIterable::Sequence(iter) => Ok(Box::new( + iter.iter()? + .map(extract_items) + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::Iterator(iter) => Ok(Box::new( + iter.iter()? + .map(extract_items) + .map(move |r| r.map_err(|e| map_err(py, e, iter.as_ref()))), + )), + GenericIterable::JsonArray(iter) => Ok(Box::new( + iter.iter() + .map(move |v| extract_items(Ok(v.to_object(py).into_ref(py)))) + .map(move |r| r.map_err(|e| map_err(py, e, iter.to_object(py).into_ref(py)))), + )), + // Note that we iterate over (key, value), unlike doing iter({}) in Python + GenericIterable::JsonObject(iter) => Ok(Box::new(iter.iter().map(move |(k, v)| { + let k = PyString::new(py, k).as_ref(); + let v = v.to_object(py).into_ref(py); + Ok((k, v)) + }))), + } + } +} diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 79ab52308..779feab1d 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -7,8 +7,9 @@ use crate::errors::{InputValue, LocItem, ValResult}; use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; +use super::generic_iterable::GenericIterable; use super::return_enums::{EitherBytes, EitherString}; -use super::{GenericArguments, GenericCollection, GenericIterator, GenericMapping, JsonInput}; +use super::{GenericArguments, GenericIterator, GenericMapping, JsonInput}; #[derive(Debug, Clone, Copy)] pub enum InputType { @@ -166,57 +167,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.validate_dict(strict) } - fn validate_list(&'a self, strict: bool, allow_any_iter: bool) -> ValResult> { - if strict && !allow_any_iter { - self.strict_list() - } else { - self.lax_list(allow_any_iter) - } - } - fn strict_list(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] - fn lax_list(&'a self, _allow_any_iter: bool) -> ValResult> { - self.strict_list() - } - - fn validate_tuple(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_tuple() - } else { - self.lax_tuple() - } - } - fn strict_tuple(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] - fn lax_tuple(&'a self) -> ValResult> { - self.strict_tuple() - } - - fn validate_set(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_set() - } else { - self.lax_set() - } - } - fn strict_set(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] - fn lax_set(&'a self) -> ValResult> { - self.strict_set() - } - - fn validate_frozenset(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_frozenset() - } else { - self.lax_frozenset() - } - } - fn strict_frozenset(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] - fn lax_frozenset(&'a self) -> ValResult> { - self.strict_frozenset() - } + fn extract_iterable(&'a self) -> ValResult>; fn validate_iter(&self) -> ValResult; diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 9afddb088..4311767a9 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -10,8 +10,8 @@ use super::datetime::{ use super::parse_json::JsonArray; use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int}; use super::{ - EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericCollection, GenericIterator, GenericMapping, - Input, JsonArgs, JsonInput, + EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericIterator, GenericMapping, Input, JsonArgs, + JsonInput, }; impl<'a> Input<'a> for JsonInput { @@ -187,52 +187,13 @@ impl<'a> Input<'a> for JsonInput { self.validate_dict(false) } - fn validate_list(&'a self, _strict: bool, _allow_any_iter: bool) -> ValResult> { + fn extract_iterable(&'a self) -> ValResult> { match self { - JsonInput::Array(a) => Ok(a.into()), - _ => Err(ValError::new(ErrorType::ListType, self)), - } - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_list(&'a self) -> ValResult> { - self.validate_list(false, false) - } - - fn validate_tuple(&'a self, _strict: bool) -> ValResult> { - // just as in set's case, List has to be allowed - match self { - JsonInput::Array(a) => Ok(a.into()), - _ => Err(ValError::new(ErrorType::TupleType, self)), - } - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_tuple(&'a self) -> ValResult> { - self.validate_tuple(false) - } - - fn validate_set(&'a self, _strict: bool) -> ValResult> { - // we allow a list here since otherwise it would be impossible to create a set from JSON - match self { - JsonInput::Array(a) => Ok(a.into()), - _ => Err(ValError::new(ErrorType::SetType, self)), + JsonInput::Array(a) => Ok(super::generic_iterable::GenericIterable::JsonArray(a)), + JsonInput::Object(o) => Ok(super::generic_iterable::GenericIterable::JsonObject(o)), + _ => Err(ValError::new(ErrorType::IterableType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_set(&'a self) -> ValResult> { - self.validate_set(false) - } - - fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { - // we allow a list here since otherwise it would be impossible to create a frozenset from JSON - match self { - JsonInput::Array(a) => Ok(a.into()), - _ => Err(ValError::new(ErrorType::FrozenSetType, self)), - } - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_frozenset(&'a self) -> ValResult> { - self.validate_frozenset(false) - } fn validate_iter(&self) -> ValResult { match self { @@ -404,40 +365,8 @@ impl<'a> Input<'a> for String { self.validate_dict(false) } - #[cfg_attr(has_no_coverage, no_coverage)] - fn validate_list(&'a self, _strict: bool, _allow_any_iter: bool) -> ValResult> { - Err(ValError::new(ErrorType::ListType, self)) - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_list(&'a self) -> ValResult> { - self.validate_list(false, false) - } - - #[cfg_attr(has_no_coverage, no_coverage)] - fn validate_tuple(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorType::TupleType, self)) - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_tuple(&'a self) -> ValResult> { - self.validate_tuple(false) - } - - #[cfg_attr(has_no_coverage, no_coverage)] - fn validate_set(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorType::SetType, self)) - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_set(&'a self) -> ValResult> { - self.validate_set(false) - } - - #[cfg_attr(has_no_coverage, no_coverage)] - fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorType::FrozenSetType, self)) - } - #[cfg_attr(has_no_coverage, no_coverage)] - fn strict_frozenset(&'a self) -> ValResult> { - self.validate_frozenset(false) + fn extract_iterable(&'a self) -> ValResult> { + Err(ValError::new(ErrorType::IterableType, self)) } fn validate_iter(&self) -> ValResult { diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 021da881d..ff7356d02 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -5,12 +5,13 @@ use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{ PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyIterator, PyList, - PyMapping, PySet, PyString, PyTime, PyTuple, PyType, + PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType, }; #[cfg(not(PyPy))] use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo}; +use super::generic_iterable::GenericIterable; use crate::build_tools::safe_repr; use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult}; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; @@ -22,16 +23,18 @@ use super::datetime::{ }; use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int}; use super::{ - py_string_str, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericCollection, GenericIterator, - GenericMapping, Input, JsonInput, PyArgs, + py_string_str, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericIterator, GenericMapping, + Input, JsonInput, PyArgs, }; -/// Extract generators and deques into a `GenericCollection` -macro_rules! extract_shared_iter { - ($type:ty, $obj:ident) => { - if $obj.downcast::().is_ok() { +#[cfg(PyPy)] +macro_rules! extract_dict_iter { + ($obj:ident) => { + if is_dict_keys_type($obj) { Some($obj.into()) - } else if is_deque($obj) { + } else if is_dict_values_type($obj) { + Some($obj.into()) + } else if is_dict_items_type($obj) { Some($obj.into()) } else { None @@ -39,31 +42,60 @@ macro_rules! extract_shared_iter { }; } -/// Extract dict keys, values and items into a `GenericCollection` #[cfg(not(PyPy))] -macro_rules! extract_dict_iter { - ($obj:ident) => { - if $obj.is_instance_of::().unwrap_or(false) { - Some($obj.into()) - } else if $obj.is_instance_of::().unwrap_or(false) { - Some($obj.into()) - } else if $obj.is_instance_of::().unwrap_or(false) { - Some($obj.into()) +macro_rules! extract_dict_keys { + ($py:expr, $obj:ident) => { + $obj.downcast::() + .ok() + .map(|v| PyIterator::from_object($py, v).unwrap()) + }; +} + +#[cfg(PyPy)] +macro_rules! extract_dict_keys { + ($py:expr, $obj:ident) => { + if is_dict_keys_type($obj) { + Some(PyIterator::from_object($py, $obj).unwrap()) + } else { + None + } + }; +} + +#[cfg(not(PyPy))] +macro_rules! extract_dict_values { + ($py:expr, $obj:ident) => { + $obj.downcast::() + .ok() + .map(|v| PyIterator::from_object($py, v).unwrap()) + }; +} + +#[cfg(PyPy)] +macro_rules! extract_dict_values { + ($py:expr, $obj:ident) => { + if is_dict_values_type($obj) { + Some(PyIterator::from_object($py, $obj).unwrap()) } else { None } }; } +#[cfg(not(PyPy))] +macro_rules! extract_dict_items { + ($py:expr, $obj:ident) => { + $obj.downcast::() + .ok() + .map(|v| PyIterator::from_object($py, v).unwrap()) + }; +} + #[cfg(PyPy)] -macro_rules! extract_dict_iter { - ($obj:ident) => { - if is_dict_keys_type($obj) { - Some($obj.into()) - } else if is_dict_values_type($obj) { - Some($obj.into()) - } else if is_dict_items_type($obj) { - Some($obj.into()) +macro_rules! extract_dict_items { + ($py:expr, $obj:ident) => { + if is_dict_items_type($obj) { + Some(PyIterator::from_object($py, $obj).unwrap()) } else { None } @@ -375,101 +407,38 @@ impl<'a> Input<'a> for PyAny { } } - fn strict_list(&'a self) -> ValResult> { - if let Ok(list) = self.downcast::() { - Ok(list.into()) + fn extract_iterable(&'a self) -> ValResult> { + // Handle concrete non-overlapping types first, then abstract types + if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::List(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::Tuple(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::Set(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::FrozenSet(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::Dict(iterable)) + } else if let Some(iterable) = extract_dict_keys!(self.py(), self) { + Ok(GenericIterable::DictKeys(iterable)) + } else if let Some(iterable) = extract_dict_values!(self.py(), self) { + Ok(GenericIterable::DictValues(iterable)) + } else if let Some(iterable) = extract_dict_items!(self.py(), self) { + Ok(GenericIterable::DictItems(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::Mapping(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::String(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::Bytes(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::PyByteArray(iterable)) + } else if let Ok(iterable) = self.downcast::() { + Ok(GenericIterable::Sequence(iterable)) + } else if let Ok(iterable) = self.iter() { + Ok(GenericIterable::Iterator(iterable)) } else { - Err(ValError::new(ErrorType::ListType, self)) - } - } - - fn lax_list(&'a self, allow_any_iter: bool) -> ValResult> { - if let Ok(list) = self.downcast::() { - Ok(list.into()) - } else if let Ok(tuple) = self.downcast::() { - Ok(tuple.into()) - } else if let Some(collection) = extract_dict_iter!(self) { - Ok(collection) - } else if allow_any_iter && self.iter().is_ok() { - Ok(self.into()) - } else if let Some(collection) = extract_shared_iter!(PyList, self) { - Ok(collection) - } else { - Err(ValError::new(ErrorType::ListType, self)) - } - } - - fn strict_tuple(&'a self) -> ValResult> { - if let Ok(tuple) = self.downcast::() { - Ok(tuple.into()) - } else { - Err(ValError::new(ErrorType::TupleType, self)) - } - } - - fn lax_tuple(&'a self) -> ValResult> { - if let Ok(tuple) = self.downcast::() { - Ok(tuple.into()) - } else if let Ok(list) = self.downcast::() { - Ok(list.into()) - } else if let Some(collection) = extract_dict_iter!(self) { - Ok(collection) - } else if let Some(collection) = extract_shared_iter!(PyTuple, self) { - Ok(collection) - } else { - Err(ValError::new(ErrorType::TupleType, self)) - } - } - - fn strict_set(&'a self) -> ValResult> { - if let Ok(set) = self.downcast::() { - Ok(set.into()) - } else { - Err(ValError::new(ErrorType::SetType, self)) - } - } - - fn lax_set(&'a self) -> ValResult> { - if let Ok(set) = self.downcast::() { - Ok(set.into()) - } else if let Ok(list) = self.downcast::() { - Ok(list.into()) - } else if let Ok(tuple) = self.downcast::() { - Ok(tuple.into()) - } else if let Ok(frozen_set) = self.downcast::() { - Ok(frozen_set.into()) - } else if let Some(collection) = extract_dict_iter!(self) { - Ok(collection) - } else if let Some(collection) = extract_shared_iter!(PyTuple, self) { - Ok(collection) - } else { - Err(ValError::new(ErrorType::SetType, self)) - } - } - - fn strict_frozenset(&'a self) -> ValResult> { - if let Ok(set) = self.downcast::() { - Ok(set.into()) - } else { - Err(ValError::new(ErrorType::FrozenSetType, self)) - } - } - - fn lax_frozenset(&'a self) -> ValResult> { - if let Ok(frozen_set) = self.downcast::() { - Ok(frozen_set.into()) - } else if let Ok(set) = self.downcast::() { - Ok(set.into()) - } else if let Ok(list) = self.downcast::() { - Ok(list.into()) - } else if let Ok(tuple) = self.downcast::() { - Ok(tuple.into()) - } else if let Some(collection) = extract_dict_iter!(self) { - Ok(collection) - } else if let Some(collection) = extract_shared_iter!(PyTuple, self) { - Ok(collection) - } else { - Err(ValError::new(ErrorType::FrozenSetType, self)) + Err(ValError::new(ErrorType::IterableType, self)) } } @@ -623,21 +592,6 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult> = GILOnceCell::new(); - -fn is_deque(v: &PyAny) -> bool { - let py = v.py(); - let deque_type = DEQUE_TYPE - .get_or_init(py, || import_type(py, "collections", "deque").unwrap()) - .as_ref(py); - v.is_instance(deque_type).unwrap_or(false) -} - -fn import_type(py: Python, module: &str, attr: &str) -> PyResult> { - let obj = py.import(module)?.getattr(attr)?; - Ok(obj.downcast::()?.into()) -} - fn is_builtin_str(py_str: &PyString) -> bool { py_str.get_type().is(PyString::type_object(py_str.py())) } diff --git a/src/input/iterator.rs b/src/input/iterator.rs new file mode 100644 index 000000000..a84880d82 --- /dev/null +++ b/src/input/iterator.rs @@ -0,0 +1,144 @@ +use std::marker::PhantomData; + +use pyo3::{PyObject, PyResult, Python}; + +use super::Input; + +use crate::validators::Validator; +use crate::{ + definitions::Definitions, + errors::{ErrorType, ValError, ValLineError, ValResult}, + recursion_guard::RecursionGuard, + validators::CombinedValidator, + validators::Extra, +}; + +pub fn calculate_output_init_capacity(iterator_size: Option, max_length: Option) -> usize { + // The smaller number of either the input size or the max output length + match (iterator_size, max_length) { + (None, _) => 0, + (Some(l), None) => l, + (Some(l), Some(r)) => std::cmp::min(l, r), + } +} + +#[derive(Debug, Clone)] +pub struct LengthConstraints { + pub min_length: usize, + pub max_length: Option, +} + +pub struct IterableValidationChecks<'data, I> { + output_length: usize, + min_length: usize, + max_length: Option, + field_type: &'static str, + errors: Vec>, + p: PhantomData, +} + +impl<'data, I: Input<'data> + 'data> IterableValidationChecks<'data, I> { + pub fn new(length_constraints: LengthConstraints, field_type: &'static str) -> Self { + Self { + output_length: 0, + min_length: length_constraints.min_length, + max_length: length_constraints.max_length, + field_type, + errors: vec![], + p: PhantomData, + } + } + pub fn add_error(&mut self, error: ValLineError<'data>) { + self.errors.push(error) + } + pub fn filter_validation_result( + &mut self, + result: ValResult<'data, R>, + input: &'data I, + ) -> ValResult<'data, Option> { + match result { + Ok(v) => Ok(Some(v)), + Err(ValError::LineErrors(line_errors)) => { + self.errors.extend(line_errors); + if let Some(max_length) = self.max_length { + self.check_max_length(self.output_length + self.errors.len(), max_length, input)?; + } + Ok(None) + } + Err(ValError::Omit) => Ok(None), + Err(e) => Err(e), + } + } + pub fn check_output_length(&mut self, output_length: usize, input: &'data I) -> ValResult<'data, ()> { + self.output_length = output_length; + if let Some(max_length) = self.max_length { + self.check_max_length(output_length + self.errors.len(), max_length, input)?; + } + Ok(()) + } + pub fn finish(&mut self, input: &'data I) -> ValResult<'data, ()> { + if self.min_length > self.output_length { + let err = ValLineError::new( + ErrorType::TooShort { + field_type: self.field_type.to_string(), + min_length: self.min_length, + actual_length: self.output_length, + }, + input, + ); + self.errors.push(err); + } + if self.errors.is_empty() { + Ok(()) + } else { + Err(ValError::LineErrors(std::mem::take(&mut self.errors))) + } + } + fn check_max_length(&self, current_length: usize, max_length: usize, input: &'data I) -> ValResult<'data, ()> { + if max_length < current_length { + return Err(ValError::new( + ErrorType::TooLong { + field_type: self.field_type.to_string(), + max_length, + actual_length: current_length, + }, + input, + )); + } + Ok(()) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn validate_iterator<'s, 'data, V, O, W, L, I>( + py: Python<'data>, + input: &'data I, + extra: &'s Extra<'s>, + definitions: &'data Definitions, + recursion_guard: &'s mut RecursionGuard, + checks: &mut IterableValidationChecks<'data, I>, + iter: impl Iterator>, + items_validator: &'s CombinedValidator, + output: &mut O, + write: &mut W, + len: &L, +) -> ValResult<'data, ()> +where + I: Input<'data> + 'data, + V: Input<'data> + 'data, + W: FnMut(&mut O, PyObject) -> PyResult<()>, + L: Fn(&O) -> usize, +{ + for (index, result) in iter.enumerate() { + let value = result?; + let result = items_validator + .validate(py, value, extra, definitions, recursion_guard) + .map_err(|e| e.with_outer_location(index.into())); + if let Some(value) = checks.filter_validation_result(result, input)? { + write(output, value)?; + checks.check_output_length(len(output), input)?; + } + } + checks.finish(input)?; + Ok(()) +} diff --git a/src/input/mod.rs b/src/input/mod.rs index 19144b2f6..48ad06632 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -3,9 +3,11 @@ use std::os::raw::c_int; use pyo3::prelude::*; mod datetime; +mod generic_iterable; mod input_abstract; mod input_json; mod input_python; +pub mod iterator; mod parse_json; mod return_enums; mod shared; @@ -14,12 +16,12 @@ pub(crate) use datetime::{ pydate_as_date, pydatetime_as_datetime, pytime_as_time, pytimedelta_as_duration, EitherDate, EitherDateTime, EitherTime, EitherTimedelta, }; +pub(crate) use generic_iterable::GenericIterable; pub(crate) use input_abstract::{Input, InputType}; pub(crate) use parse_json::{JsonInput, JsonObject}; pub(crate) use return_enums::{ py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherString, GenericArguments, - GenericCollection, GenericIterator, GenericMapping, JsonArgs, JsonObjectGenericIterator, MappingGenericIterator, - PyArgs, + GenericIterator, GenericMapping, JsonArgs, JsonObjectGenericIterator, MappingGenericIterator, PyArgs, }; // Defined here as it's not exported by pyo3 diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index f8741a7d4..3e12f159d 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,33 +4,18 @@ use std::slice::Iter as SliceIter; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::iter::PyDictIterator; -use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping, PySet, PyString, PyTuple}; +use pyo3::types::{PyBytes, PyDict, PyIterator, PyList, PyMapping, PyString, PyTuple}; #[cfg(not(PyPy))] use pyo3::types::PyFunction; #[cfg(not(PyPy))] use pyo3::PyTypeInfo; -use crate::errors::{py_err_string, ErrorType, InputValue, ValError, ValLineError, ValResult}; -use crate::recursion_guard::RecursionGuard; -use crate::validators::{CombinedValidator, Extra, Validator}; +use crate::errors::{py_err_string, ErrorType, InputValue, ValError, ValResult}; use super::parse_json::{JsonArray, JsonInput, JsonObject}; use super::Input; -/// Container for all the collections (sized iterable containers) types, which -/// can mostly be converted to each other in lax mode. -/// This mostly matches python's definition of `Collection`. -#[cfg_attr(debug_assertions, derive(Debug))] -pub enum GenericCollection<'a> { - List(&'a PyList), - Tuple(&'a PyTuple), - Set(&'a PySet), - FrozenSet(&'a PyFrozenSet), - PyAny(&'a PyAny), - JsonArray(&'a [JsonInput]), -} - macro_rules! derive_from { ($enum:ident, $key:ident, $type:ty $(, $extra_types:ident )*) => { impl<'a> From<&'a $type> for $enum<'a> { @@ -40,203 +25,6 @@ macro_rules! derive_from { } }; } -derive_from!(GenericCollection, List, PyList); -derive_from!(GenericCollection, Tuple, PyTuple); -derive_from!(GenericCollection, Set, PySet); -derive_from!(GenericCollection, FrozenSet, PyFrozenSet); -derive_from!(GenericCollection, PyAny, PyAny); -derive_from!(GenericCollection, JsonArray, JsonArray); -derive_from!(GenericCollection, JsonArray, [JsonInput]); - -fn validate_iter_to_vec<'a, 's>( - py: Python<'a>, - iter: impl Iterator + 'a)>, - capacity: usize, - validator: &'s CombinedValidator, - extra: &Extra, - definitions: &'a [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, -) -> ValResult<'a, Vec> { - let mut output: Vec = Vec::with_capacity(capacity); - let mut errors: Vec = Vec::new(); - for (index, item) in iter.enumerate() { - match validator.validate(py, item, extra, definitions, recursion_guard) { - Ok(item) => output.push(item), - Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); - } - Err(ValError::Omit) => (), - Err(err) => return Err(err), - } - } - - if errors.is_empty() { - Ok(output) - } else { - Err(ValError::LineErrors(errors)) - } -} - -macro_rules! any_next_error { - ($py:expr, $err:ident, $input:ident, $index:ident) => { - ValError::new_with_loc( - ErrorType::IterationError { - error: py_err_string($py, $err), - }, - $input, - $index, - ) - }; -} - -macro_rules! generator_too_long { - ($input:ident, $index:ident, $max_length:expr, $field_type:ident) => { - if let Some(max_length) = $max_length { - if $index > max_length { - return Err(ValError::new( - ErrorType::TooLong { - field_type: $field_type.to_string(), - max_length, - actual_length: $index, - }, - $input, - )); - } - } - }; -} - -// pretty arbitrary default capacity when creating vecs from iteration -static DEFAULT_CAPACITY: usize = 10; - -impl<'a> GenericCollection<'a> { - pub fn generic_len(&self) -> PyResult { - match self { - Self::List(v) => Ok(v.len()), - Self::Tuple(v) => Ok(v.len()), - Self::Set(v) => Ok(v.len()), - Self::FrozenSet(v) => Ok(v.len()), - Self::PyAny(v) => v.len(), - Self::JsonArray(v) => Ok(v.len()), - } - } - - #[allow(clippy::too_many_arguments)] - pub fn validate_to_vec<'s>( - &'s self, - py: Python<'a>, - input: &'a impl Input<'a>, - max_length: Option, - field_type: &'static str, - generator_max_length: Option, - validator: &'s CombinedValidator, - extra: &Extra, - definitions: &'a [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, - ) -> ValResult<'a, Vec> { - let capacity = self - .generic_len() - .unwrap_or_else(|_| max_length.unwrap_or(DEFAULT_CAPACITY)); - match self { - Self::List(collection) => validate_iter_to_vec( - py, - collection.iter(), - capacity, - validator, - extra, - definitions, - recursion_guard, - ), - Self::Tuple(collection) => validate_iter_to_vec( - py, - collection.iter(), - capacity, - validator, - extra, - definitions, - recursion_guard, - ), - Self::Set(collection) => validate_iter_to_vec( - py, - collection.iter(), - capacity, - validator, - extra, - definitions, - recursion_guard, - ), - Self::FrozenSet(collection) => validate_iter_to_vec( - py, - collection.iter(), - capacity, - validator, - extra, - definitions, - recursion_guard, - ), - Self::PyAny(collection) => { - let iter = collection.iter()?; - let mut output: Vec = Vec::with_capacity(capacity); - let mut errors: Vec = Vec::new(); - for (index, item_result) in iter.enumerate() { - let item = item_result.map_err(|e| any_next_error!(collection.py(), e, input, index))?; - match validator.validate(py, item, extra, definitions, recursion_guard) { - Ok(item) => { - generator_too_long!(input, index, generator_max_length, field_type); - output.push(item); - } - Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); - } - Err(ValError::Omit) => (), - Err(err) => return Err(err), - } - } - // TODO do too small check here - - if errors.is_empty() { - Ok(output) - } else { - Err(ValError::LineErrors(errors)) - } - } - Self::JsonArray(collection) => validate_iter_to_vec( - py, - collection.iter(), - capacity, - validator, - extra, - definitions, - recursion_guard, - ), - } - } - - pub fn to_vec<'s>( - &'s self, - py: Python<'a>, - input: &'a impl Input<'a>, - field_type: &'static str, - generator_max_length: Option, - ) -> ValResult<'a, Vec> { - match self { - Self::List(collection) => Ok(collection.iter().map(|i| i.to_object(py)).collect()), - Self::Tuple(collection) => Ok(collection.iter().map(|i| i.to_object(py)).collect()), - Self::Set(collection) => Ok(collection.iter().map(|i| i.to_object(py)).collect()), - Self::FrozenSet(collection) => Ok(collection.iter().map(|i| i.to_object(py)).collect()), - Self::PyAny(collection) => collection - .iter()? - .enumerate() - .map(|(index, item_result)| { - generator_too_long!(input, index, generator_max_length, field_type); - let item = item_result.map_err(|e| any_next_error!(collection.py(), e, input, index))?; - Ok(item.to_object(py)) - }) - .collect(), - Self::JsonArray(collection) => Ok(collection.iter().map(|i| i.to_object(py)).collect()), - } - } -} #[cfg_attr(debug_assertions, derive(Debug))] pub enum GenericMapping<'a> { diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 23d196005..14a37657d 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -1,16 +1,16 @@ use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyMapping}; +use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; -use crate::errors::{ValError, ValLineError, ValResult}; -use crate::input::{ - DictGenericIterator, GenericMapping, Input, JsonObject, JsonObjectGenericIterator, MappingGenericIterator, -}; + +use crate::errors::{ErrorType, ValError, ValResult}; +use crate::input::iterator::IterableValidationChecks; +use crate::input::Input; +use crate::input::{iterator, GenericIterable}; use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; -use super::list::length_check; use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; #[derive(Debug, Clone)] @@ -58,28 +58,143 @@ impl BuildValidator for DictValidator { } } +const FIELD_TYPE: &str = "Dictionary"; + +#[allow(clippy::too_many_arguments)] +fn validation_function<'s, 'data, K, V>( + py: Python<'data>, + extra: &'s Extra<'s>, + definitions: &'data Definitions, + recursion_guard: &'s mut RecursionGuard, + key_validator: &'s CombinedValidator, + value_validator: &'s CombinedValidator, + key: &'data K, + value: &'data V, +) -> ValResult<'data, (PyObject, PyObject)> +where + K: Input<'data>, + V: Input<'data>, +{ + let v_key = key_validator + .validate(py, key, extra, definitions, recursion_guard) + .map_err(|e| { + e.with_outer_location("[key]".into()) + .with_outer_location(key.as_loc_item()) + })?; + let v_value = value_validator + .validate(py, value, extra, definitions, recursion_guard) + .map_err(|e| e.with_outer_location(key.as_loc_item()))?; + Ok((v_key, v_value)) +} + +#[allow(clippy::too_many_arguments)] +fn validate_mapping<'s, 'data, K, V, I>( + py: Python<'data>, + input: &'data I, + extra: &'s Extra<'s>, + definitions: &'data Definitions, + recursion_guard: &'s mut RecursionGuard, + checks: &mut IterableValidationChecks<'data, I>, + iter: impl Iterator>, + key_validator: &'s CombinedValidator, + value_validator: &'s CombinedValidator, + output: &'data PyDict, +) -> ValResult<'data, ()> +where + I: Input<'data> + 'data, + K: Input<'data> + 'data, + V: Input<'data> + 'data, +{ + for result in iter { + let (key, value) = result?; + let result = validation_function( + py, + extra, + definitions, + recursion_guard, + key_validator, + value_validator, + key, + value, + ); + if let Some((key, value)) = checks.filter_validation_result(result, input)? { + output.set_item(key, value)?; + checks.check_output_length(output.len(), input)?; + } + } + checks.finish(input)?; + Ok(()) +} + impl Validator for DictValidator { fn validate<'s, 'data>( &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, + extra: &'s Extra<'s>, definitions: &'data Definitions, recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let dict = input.validate_dict(extra.strict.unwrap_or(self.strict))?; - match dict { - GenericMapping::PyDict(py_dict) => { - self.validate_dict(py, input, py_dict, extra, definitions, recursion_guard) - } - GenericMapping::PyMapping(mapping) => { - self.validate_mapping(py, input, mapping, extra, definitions, recursion_guard) - } - GenericMapping::PyGetAttr(_, _) => unreachable!(), - GenericMapping::JsonObject(json_object) => { - self.validate_json_object(py, input, json_object, extra, definitions, recursion_guard) - } - } + let strict = extra.strict.unwrap_or(self.strict); + + let length_constraints = iterator::LengthConstraints { + min_length: self.min_length.unwrap_or_default(), + max_length: self.max_length, + }; + + let mut checks = IterableValidationChecks::new(length_constraints, FIELD_TYPE); + + let output = PyDict::new(py); + + let generic_iterable = input + .extract_iterable() + .map_err(|_| ValError::new(ErrorType::DictType, input))?; + match (generic_iterable, strict) { + // Always allow actual dicts or JSON objects + (GenericIterable::Dict(iter), _) => validate_mapping( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.key_validator, + &self.value_validator, + output, + )?, + (GenericIterable::JsonObject(iter), _) => validate_mapping( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(|(k, v)| (k, v)).map(Ok), + &self.key_validator, + &self.value_validator, + output, + )?, + // If we're not in strict mode, accept other iterables, equivalent to calling dict(thing) + (generic_iterable, false) => match generic_iterable.into_mapping_items_iterator(py) { + Ok(iter) => validate_mapping( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter, + &self.key_validator, + &self.value_validator, + output, + )?, + Err(_) => return Err(ValError::new(ErrorType::DictType, input)), + }, + _ => return Err(ValError::new(ErrorType::DictType, input)), + }; + + Ok(output.into_py(py)) } fn different_strict_behavior( @@ -104,68 +219,3 @@ impl Validator for DictValidator { self.value_validator.complete(definitions) } } - -macro_rules! build_validate { - ($name:ident, $dict_type:ty, $iter:ty) => { - fn $name<'s, 'data>( - &'s self, - py: Python<'data>, - input: &'data impl Input<'data>, - dict: &'data $dict_type, - extra: &Extra, - definitions: &'data Definitions, - recursion_guard: &'s mut RecursionGuard, - ) -> ValResult<'data, PyObject> { - let output = PyDict::new(py); - let mut errors: Vec = Vec::new(); - - let key_validator = self.key_validator.as_ref(); - let value_validator = self.value_validator.as_ref(); - for item_result in <$iter>::new(dict)? { - let (key, value) = item_result?; - let output_key = match key_validator.validate(py, key, extra, definitions, recursion_guard) { - Ok(value) => Some(value), - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - // these are added in reverse order so [key] is shunted along by the second call - errors.push( - err.with_outer_location("[key]".into()) - .with_outer_location(key.as_loc_item()), - ); - } - None - } - Err(ValError::Omit) => continue, - Err(err) => return Err(err), - }; - let output_value = match value_validator.validate(py, value, extra, definitions, recursion_guard) { - Ok(value) => Some(value), - Err(ValError::LineErrors(line_errors)) => { - for err in line_errors { - errors.push(err.with_outer_location(key.as_loc_item())); - } - None - } - Err(ValError::Omit) => continue, - Err(err) => return Err(err), - }; - if let (Some(key), Some(value)) = (output_key, output_value) { - output.set_item(key, value)?; - } - } - - if errors.is_empty() { - length_check!(input, "Dictionary", self.min_length, self.max_length, output); - Ok(output.into()) - } else { - Err(ValError::LineErrors(errors)) - } - } - }; -} - -impl DictValidator { - build_validate!(validate_dict, PyDict, DictGenericIterator); - build_validate!(validate_mapping, PyMapping, MappingGenericIterator); - build_validate!(validate_json_object, JsonObject, JsonObjectGenericIterator); -} diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index 368b9d15f..05e1de4fc 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -1,22 +1,25 @@ -use pyo3::prelude::*; use pyo3::types::{PyDict, PyFrozenSet}; +use pyo3::{ffi, prelude::*, AsPyPointer}; +use super::any::AnyValidator; +use super::set::set_build; +use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; use crate::build_tools::SchemaDict; +use crate::errors::ErrorType; use crate::errors::ValResult; -use crate::input::{GenericCollection, Input}; +use crate::input::iterator::{validate_iterator, IterableValidationChecks, LengthConstraints}; +use crate::input::Input; +use crate::input::{py_error_on_minusone, GenericIterable}; use crate::recursion_guard::RecursionGuard; -use super::list::{get_items_schema, length_check}; -use super::set::set_build; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use crate::errors::ValError; #[derive(Debug, Clone)] pub struct FrozenSetValidator { strict: bool, - item_validator: Option>, - min_length: Option, + item_validator: Box, + min_length: usize, max_length: Option, - generator_max_length: Option, name: String, } @@ -25,6 +28,14 @@ impl BuildValidator for FrozenSetValidator { set_build!(); } +// See https://github.com/PyO3/pyo3/pull/3156 +fn frozen_set_add(set: &PyFrozenSet, key: K) -> PyResult<()> +where + K: ToPyObject, +{ + unsafe { py_error_on_minusone(set.py(), ffi::PySet_Add(set.as_ptr(), key.to_object(set.py()).as_ptr())) } +} + impl Validator for FrozenSetValidator { fn validate<'s, 'data>( &'s self, @@ -34,30 +45,85 @@ impl Validator for FrozenSetValidator { definitions: &'data Definitions, recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let seq = input.validate_frozenset(extra.strict.unwrap_or(self.strict))?; + let create_err = |input| ValError::new(ErrorType::FrozenSetType, input); + + let field_type = "Frozenset"; + + let generic_iterable = input.extract_iterable().map_err(|_| create_err(input))?; + + let strict = extra.strict.unwrap_or(self.strict); + + let length_constraints = LengthConstraints { + min_length: self.min_length, + max_length: self.max_length, + }; - let f_set = match self.item_validator { - Some(ref v) => PyFrozenSet::new( + let mut checks = IterableValidationChecks::new(length_constraints, field_type); + + let mut output = PyFrozenSet::empty(py)?; + + let len = |output: &&'data PyFrozenSet| output.len(); + let mut write = |output: &mut &'data PyFrozenSet, ob: PyObject| frozen_set_add(output, ob); + + match (generic_iterable, strict) { + // Always allow actual lists or JSON arrays + (GenericIterable::JsonArray(iter), _) => validate_iterator( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + (GenericIterable::FrozenSet(iter), _) => validate_iterator( py, - &seq.validate_to_vec( + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + // If not in strict mode we also accept any iterable except str, bytes or mappings + // This may seem counterintuitive since a Mapping is a less generic type than an arbitrary + // iterable (which we do accept) but doing something like `x: list[int] = {1: 'a'}` is commonly + // a mistake, so we don't parse it by default + ( + GenericIterable::String(_) + | GenericIterable::Bytes(_) + | GenericIterable::Dict(_) + | GenericIterable::Mapping(_), + false, + ) => return Err(create_err(input)), + (generic_iterable, false) => match generic_iterable.into_sequence_iterator(py) { + Ok(iter) => validate_iterator( py, input, - self.max_length, - "Frozenset", - self.generator_max_length, - v, extra, definitions, recursion_guard, + &mut checks, + iter, + &self.item_validator, + &mut output, + &mut write, + &len, )?, - )?, - None => match seq { - GenericCollection::FrozenSet(f_set) => f_set, - _ => PyFrozenSet::new(py, &seq.to_vec(py, input, "Frozenset", self.generator_max_length)?)?, + Err(_) => return Err(create_err(input)), }, + _ => return Err(create_err(input)), }; - length_check!(input, "Frozenset", self.min_length, self.max_length, f_set); - Ok(f_set.into_py(py)) + + Ok(output.into_py(py)) } fn different_strict_behavior( @@ -66,10 +132,7 @@ impl Validator for FrozenSetValidator { ultra_strict: bool, ) -> bool { if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), - None => false, - } + self.item_validator.different_strict_behavior(definitions, true) } else { true } @@ -80,9 +143,6 @@ impl Validator for FrozenSetValidator { } fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), - None => Ok(()), - } + self.item_validator.complete(definitions) } } diff --git a/src/validators/list.rs b/src/validators/list.rs index c12de8b55..7527805d3 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -2,18 +2,20 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::SchemaDict; -use crate::errors::ValResult; -use crate::input::{GenericCollection, Input}; +use crate::errors::{ErrorType, ValError, ValResult}; +use crate::input::iterator::{calculate_output_init_capacity, validate_iterator, IterableValidationChecks}; +use crate::input::Input; +use crate::input::{iterator::LengthConstraints, GenericIterable}; use crate::recursion_guard::RecursionGuard; +use super::any::AnyValidator; use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; #[derive(Debug, Clone)] pub struct ListValidator { strict: bool, - allow_any_iter: bool, - item_validator: Option>, - min_length: Option, + item_validator: Box, + min_length: usize, max_length: Option, name: String, } @@ -35,40 +37,6 @@ pub fn get_items_schema( } } -macro_rules! length_check { - ($input:ident, $field_type:literal, $min_length:expr, $max_length:expr, $obj:ident) => {{ - let mut op_actual_length: Option = None; - if let Some(min_length) = $min_length { - let actual_length = $obj.len(); - if actual_length < min_length { - return Err(crate::errors::ValError::new( - crate::errors::ErrorType::TooShort { - field_type: $field_type.to_string(), - min_length, - actual_length, - }, - $input, - )); - } - op_actual_length = Some(actual_length); - } - if let Some(max_length) = $max_length { - let actual_length = op_actual_length.unwrap_or_else(|| $obj.len()); - if actual_length > max_length { - return Err(crate::errors::ValError::new( - crate::errors::ErrorType::TooLong { - field_type: $field_type.to_string(), - max_length, - actual_length, - }, - $input, - )); - } - } - }}; -} -pub(crate) use length_check; - impl BuildValidator for ListValidator { const EXPECTED_TYPE: &'static str = "list"; @@ -78,14 +46,16 @@ impl BuildValidator for ListValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; - let inner_name = item_validator.as_ref().map(|v| v.get_name()).unwrap_or("any"); + let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { + Some(d) => build_validator(d, config, definitions)?, + None => CombinedValidator::Any(AnyValidator), + }; + let inner_name = item_validator.get_name(); let name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); Ok(Self { strict: crate::build_tools::is_strict(schema, config)?, - allow_any_iter: schema.get_as(pyo3::intern!(py, "allow_any_iter"))?.unwrap_or(false), - item_validator, - min_length: schema.get_as(pyo3::intern!(py, "min_length"))?, + item_validator: Box::new(item_validator), + min_length: schema.get_as(pyo3::intern!(py, "min_length"))?.unwrap_or_default(), max_length: schema.get_as(pyo3::intern!(py, "max_length"))?, name, } @@ -93,38 +63,96 @@ impl BuildValidator for ListValidator { } } +const FIELD_TYPE: &str = "List"; + impl Validator for ListValidator { fn validate<'s, 'data>( &'s self, py: Python<'data>, input: &'data impl Input<'data>, - extra: &Extra, + extra: &'s Extra<'s>, definitions: &'data Definitions, recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let seq = input.validate_list(extra.strict.unwrap_or(self.strict), self.allow_any_iter)?; + let generic_iterable = input + .extract_iterable() + .map_err(|_| ValError::new(ErrorType::ListType, input))?; + + let strict = extra.strict.unwrap_or(self.strict); + + let length_constraints = LengthConstraints { + min_length: self.min_length, + max_length: self.max_length, + }; + + let mut checks = IterableValidationChecks::new(length_constraints, FIELD_TYPE); - let output = match self.item_validator { - Some(ref v) => seq.validate_to_vec( + let mut output: Vec = + Vec::with_capacity(calculate_output_init_capacity(generic_iterable.len(), self.max_length)); + + let len = |output: &Vec| output.len(); + let mut write = |output: &mut Vec, ob: PyObject| { + output.push(ob); + Ok(()) + }; + + match (generic_iterable, strict) { + // Always allow actual lists or JSON arrays + (GenericIterable::JsonArray(iter), _) => validate_iterator( py, input, - self.max_length, - "List", - self.max_length, - v, extra, definitions, recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, )?, - None => match seq { - GenericCollection::List(list) => { - length_check!(input, "List", self.min_length, self.max_length, list); - return Ok(list.into_py(py)); - } - _ => seq.to_vec(py, input, "List", self.max_length)?, + (GenericIterable::List(iter), _) => validate_iterator( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + // If not in strict mode we also accept any iterable except str, bytes or mappings + // This may seem counterintuitive since a Mapping is a less generic type than an arbitrary + // iterable (which we do accept) but doing something like `x: list[int] = {1: 'a'}` is commonly + // a mistake, so we don't parse it by default + ( + GenericIterable::String(_) + | GenericIterable::Bytes(_) + | GenericIterable::Dict(_) + | GenericIterable::Mapping(_), + _, + ) => return Err(ValError::new(ErrorType::ListType, input)), + (generic_iterable, false) => match generic_iterable.into_sequence_iterator(py) { + Ok(iter) => validate_iterator( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter, + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + Err(_) => return Err(ValError::new(ErrorType::ListType, input)), }, + _ => return Err(ValError::new(ErrorType::ListType, input)), }; - length_check!(input, "List", self.min_length, self.max_length, output); Ok(output.into_py(py)) } @@ -134,10 +162,7 @@ impl Validator for ListValidator { ultra_strict: bool, ) -> bool { if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), - None => false, - } + self.item_validator.different_strict_behavior(definitions, true) } else { true } @@ -148,11 +173,9 @@ impl Validator for ListValidator { } fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - if let Some(ref mut v) = self.item_validator { - v.complete(definitions)?; - let inner_name = v.get_name(); - self.name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); - } + self.item_validator.complete(definitions)?; + let inner_name = self.item_validator.get_name(); + self.name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); Ok(()) } } diff --git a/src/validators/set.rs b/src/validators/set.rs index f8383b496..49cae841b 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -1,24 +1,23 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PySet}; +use super::{BuildValidator, Definitions, DefinitionsBuilder, Extra, Validator}; use crate::build_tools::SchemaDict; -use crate::errors::ValResult; -use crate::input::{GenericCollection, Input}; +use crate::errors::{ErrorType, ValError, ValResult}; +use crate::input::iterator::{validate_iterator, IterableValidationChecks, LengthConstraints}; +use crate::input::{GenericIterable, Input}; use crate::recursion_guard::RecursionGuard; - -use super::list::{get_items_schema, length_check}; -use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; +use crate::validators::any::AnyValidator; +use crate::validators::{build_validator, CombinedValidator}; #[derive(Debug, Clone)] pub struct SetValidator { strict: bool, - item_validator: Option>, - min_length: Option, + item_validator: Box, + min_length: usize, max_length: Option, - generator_max_length: Option, name: String, } -pub static MAX_LENGTH_GEN_MULTIPLE: usize = 10; macro_rules! set_build { () => { @@ -28,20 +27,20 @@ macro_rules! set_build { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; - let inner_name = item_validator.as_ref().map(|v| v.get_name()).unwrap_or("any"); - let max_length = schema.get_as(pyo3::intern!(py, "max_length"))?; - let generator_max_length = match schema.get_as(pyo3::intern!(py, "generator_max_length"))? { - Some(v) => Some(v), - None => max_length.map(|v| v * super::set::MAX_LENGTH_GEN_MULTIPLE), + let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { + Some(d) => build_validator(d, config, definitions)?, + None => CombinedValidator::Any(AnyValidator), }; + let inner_name = item_validator.get_name(); + let max_length = schema.get_as(pyo3::intern!(py, "max_length"))?; let name = format!("{}[{}]", Self::EXPECTED_TYPE, inner_name); Ok(Self { strict: crate::build_tools::is_strict(schema, config)?, - item_validator, - min_length: schema.get_as(pyo3::intern!(py, "min_length"))?, + item_validator: Box::new(item_validator), + min_length: schema + .get_as(pyo3::intern!(py, "min_length"))? + .unwrap_or_default(), max_length, - generator_max_length, name, } .into()) @@ -64,30 +63,85 @@ impl Validator for SetValidator { definitions: &'data Definitions, recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let seq = input.validate_set(extra.strict.unwrap_or(self.strict))?; + let create_err = |input| ValError::new(ErrorType::SetType, input); + + let field_type = "Set"; + + let generic_iterable = input.extract_iterable().map_err(|_| create_err(input))?; + + let strict = extra.strict.unwrap_or(self.strict); + + let length_constraints = LengthConstraints { + min_length: self.min_length, + max_length: self.max_length, + }; + + let mut checks = IterableValidationChecks::new(length_constraints, field_type); - let set = match self.item_validator { - Some(ref v) => PySet::new( + let mut output = PySet::empty(py)?; + + let len = |output: &&'data PySet| output.len(); + let mut write = |output: &mut &'data PySet, ob: PyObject| output.add(ob); + + match (generic_iterable, strict) { + // Always allow actual lists or JSON arrays + (GenericIterable::JsonArray(iter), _) => validate_iterator( py, - &seq.validate_to_vec( + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + (GenericIterable::Set(iter), _) => validate_iterator( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + // If not in strict mode we also accept any iterable except str, bytes or mappings + // This may seem counterintuitive since a Mapping is a less generic type than an arbitrary + // iterable (which we do accept) but doing something like `x: list[int] = {1: 'a'}` is commonly + // a mistake, so we don't parse it by default + ( + GenericIterable::String(_) + | GenericIterable::Bytes(_) + | GenericIterable::Dict(_) + | GenericIterable::Mapping(_), + false, + ) => return Err(create_err(input)), + (generic_iterable, false) => match generic_iterable.into_sequence_iterator(py) { + Ok(iter) => validate_iterator( py, input, - self.max_length, - "Set", - self.generator_max_length, - v, extra, definitions, recursion_guard, + &mut checks, + iter, + &self.item_validator, + &mut output, + &mut write, + &len, )?, - )?, - None => match seq { - GenericCollection::Set(set) => set, - _ => PySet::new(py, &seq.to_vec(py, input, "Set", self.generator_max_length)?)?, + Err(_) => return Err(create_err(input)), }, + _ => return Err(create_err(input)), }; - length_check!(input, "Set", self.min_length, self.max_length, set); - Ok(set.into_py(py)) + + Ok(output.into_py(py)) } fn different_strict_behavior( @@ -96,10 +150,7 @@ impl Validator for SetValidator { ultra_strict: bool, ) -> bool { if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), - None => false, - } + self.item_validator.different_strict_behavior(definitions, true) } else { true } @@ -110,9 +161,6 @@ impl Validator for SetValidator { } fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), - None => Ok(()), - } + self.item_validator.complete(definitions) } } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 57f3f49d7..8d88ecbcd 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -3,18 +3,23 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; use crate::build_tools::{is_strict, SchemaDict}; -use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; -use crate::input::{GenericCollection, Input}; +use crate::errors::ValLineError; +use crate::errors::{ErrorType, ValError, ValResult}; +use crate::input::iterator::calculate_output_init_capacity; +use crate::input::iterator::validate_iterator; +use crate::input::iterator::IterableValidationChecks; +use crate::input::iterator::LengthConstraints; +use crate::input::{GenericIterable, Input}; use crate::recursion_guard::RecursionGuard; -use super::list::{get_items_schema, length_check}; +use super::any::AnyValidator; use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator}; #[derive(Debug, Clone)] pub struct TupleVariableValidator { strict: bool, - item_validator: Option>, - min_length: Option, + item_validator: Box, + min_length: usize, max_length: Option, name: String, } @@ -27,13 +32,16 @@ impl BuildValidator for TupleVariableValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; - let inner_name = item_validator.as_ref().map(|v| v.get_name()).unwrap_or("any"); + let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { + Some(d) => build_validator(d, config, definitions)?, + None => CombinedValidator::Any(AnyValidator), + }; + let inner_name = item_validator.get_name(); let name = format!("tuple[{inner_name}, ...]"); Ok(Self { strict: crate::build_tools::is_strict(schema, config)?, - item_validator, - min_length: schema.get_as(pyo3::intern!(py, "min_length"))?, + item_validator: Box::new(item_validator), + min_length: schema.get_as(pyo3::intern!(py, "min_length"))?.unwrap_or_default(), max_length: schema.get_as(pyo3::intern!(py, "max_length"))?, name, } @@ -41,6 +49,8 @@ impl BuildValidator for TupleVariableValidator { } } +const FIELD_TYPE: &str = "Tuple"; + impl Validator for TupleVariableValidator { fn validate<'s, 'data>( &'s self, @@ -50,30 +60,84 @@ impl Validator for TupleVariableValidator { definitions: &'data Definitions, recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let seq = input.validate_tuple(extra.strict.unwrap_or(self.strict))?; + let generic_iterable = input + .extract_iterable() + .map_err(|_| ValError::new(ErrorType::TupleType, input))?; + + let strict = extra.strict.unwrap_or(self.strict); + + let length_constraints = LengthConstraints { + min_length: self.min_length, + max_length: self.max_length, + }; - let output = match self.item_validator { - Some(ref v) => seq.validate_to_vec( + let mut checks = IterableValidationChecks::new(length_constraints, FIELD_TYPE); + + let mut output = Vec::with_capacity(calculate_output_init_capacity(generic_iterable.len(), self.max_length)); + let len = |output: &Vec| output.len(); + let mut write = |output: &mut Vec, ob: PyObject| { + output.push(ob); + Ok(()) + }; + + match (generic_iterable, strict) { + // Always allow actual lists or JSON arrays + (GenericIterable::JsonArray(iter), _) => validate_iterator( py, input, - self.max_length, - "Tuple", - self.max_length, - v, extra, definitions, recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, )?, - None => match seq { - GenericCollection::Tuple(tuple) => { - length_check!(input, "Tuple", self.min_length, self.max_length, tuple); - return Ok(tuple.into_py(py)); - } - _ => seq.to_vec(py, input, "Tuple", self.max_length)?, + (GenericIterable::Tuple(iter), _) => validate_iterator( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + // If not in strict mode we also accept any iterable except str, bytes or mappings + // This may seem counterintuitive since a Mapping is a less generic type than an arbitrary + // iterable (which we do accept) but doing something like `x: list[int] = {1: 'a'}` is commonly + // a mistake, so we don't parse it by default + ( + GenericIterable::String(_) + | GenericIterable::Bytes(_) + | GenericIterable::Dict(_) + | GenericIterable::Mapping(_), + _, + ) => return Err(ValError::new(ErrorType::TupleType, input)), + (generic_iterable, false) => match generic_iterable.into_sequence_iterator(py) { + Ok(iter) => validate_iterator( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter, + &self.item_validator, + &mut output, + &mut write, + &len, + )?, + Err(_) => return Err(ValError::new(ErrorType::TupleType, input)), }, + _ => return Err(ValError::new(ErrorType::TupleType, input)), }; - length_check!(input, "Tuple", self.min_length, self.max_length, output); - Ok(PyTuple::new(py, &output).into_py(py)) + Ok(PyTuple::new(py, output.into_iter()).into_py(py)) } fn different_strict_behavior( @@ -82,10 +146,7 @@ impl Validator for TupleVariableValidator { ultra_strict: bool, ) -> bool { if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), - None => false, - } + self.item_validator.different_strict_behavior(definitions, true) } else { true } @@ -96,10 +157,7 @@ impl Validator for TupleVariableValidator { } fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), - None => Ok(()), - } + self.item_validator.complete(definitions) } } @@ -139,6 +197,75 @@ impl BuildValidator for TuplePositionalValidator { } } +#[allow(clippy::too_many_arguments)] +fn validate_iterator_tuple_positional<'s, 'data, V, I>( + py: Python<'data>, + input: &'data I, + extra: &'s Extra<'s>, + definitions: &'data Definitions, + recursion_guard: &'s mut RecursionGuard, + checks: &mut IterableValidationChecks<'data, I>, + iter: impl Iterator>, + items_validators: &[CombinedValidator], + extra_validator: &Option>, + output: &mut Vec, +) -> ValResult<'data, ()> +where + I: Input<'data> + 'data, + V: Input<'data> + 'data, +{ + for (index, result) in iter.enumerate() { + let value = result?; + match items_validators.get(output.len()) { + Some(item_validator) => { + let result = item_validator + .validate(py, value, extra, definitions, recursion_guard) + .map_err(|e| e.with_outer_location(index.into())); + if let Some(value) = checks.filter_validation_result(result, input)? { + output.push(value); + checks.check_output_length(output.len(), input)?; + } + } + None => { + // Extra item + match extra_validator { + Some(ref validator) => { + let result = validator + .validate(py, value, extra, definitions, recursion_guard) + .map_err(|e| e.with_outer_location(index.into())); + if let Some(value) = checks.filter_validation_result(result, input)? { + output.push(value); + checks.check_output_length(output.len(), input)?; + } + } + None => { + return Err(ValError::new( + ErrorType::TooLong { + field_type: "Tuple".to_string(), + max_length: items_validators.len(), + actual_length: output.len() + 1, + }, + input, + )) + } + } + } + } + } + for (idx, validator) in items_validators.iter().enumerate().skip(output.len()) { + let default = validator.default_value(py, Some(output.len()), extra, definitions, recursion_guard)?; + match default { + Some(v) => { + output.push(v); + checks.check_output_length(output.len(), input)?; + } + None => checks.add_error(ValLineError::new_with_loc(ErrorType::Missing, input, idx)), + } + } + checks.finish(input)?; + Ok(()) +} + impl Validator for TuplePositionalValidator { fn validate<'s, 'data>( &'s self, @@ -148,94 +275,76 @@ impl Validator for TuplePositionalValidator { definitions: &'data Definitions, recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let collection = input.validate_tuple(extra.strict.unwrap_or(self.strict))?; - let expected_length = self.items_validators.len(); + let generic_iterable = input + .extract_iterable() + .map_err(|_| ValError::new(ErrorType::TupleType, input))?; - let mut output: Vec = Vec::with_capacity(expected_length); - let mut errors: Vec = Vec::new(); - macro_rules! iter { - ($collection_iter:expr) => {{ - for (index, validator) in self.items_validators.iter().enumerate() { - match $collection_iter.next() { - Some(item) => match validator.validate(py, item, extra, definitions, recursion_guard) { - Ok(item) => output.push(item), - Err(ValError::LineErrors(line_errors)) => { - errors.extend( - line_errors - .into_iter() - .map(|err| err.with_outer_location(index.into())), - ); - } - Err(err) => return Err(err), - }, - None => { - if let Some(value) = - validator.default_value(py, Some(index), extra, definitions, recursion_guard)? - { - output.push(value); - } else { - errors.push(ValLineError::new_with_loc(ErrorType::Missing, input, index)); - } - } - } - } - for (index, item) in $collection_iter.enumerate() { - match self.extra_validator { - Some(ref extra_validator) => { - match extra_validator.validate(py, item, extra, definitions, recursion_guard) { - Ok(item) => output.push(item), - Err(ValError::LineErrors(line_errors)) => { - errors.extend( - line_errors - .into_iter() - .map(|err| err.with_outer_location((index + expected_length).into())), - ); - } - Err(ValError::Omit) => (), - Err(err) => return Err(err), - } - } - None => { - errors.push(ValLineError::new( - ErrorType::TooLong { - field_type: "Tuple".to_string(), - max_length: expected_length, - actual_length: collection.generic_len()?, - }, - input, - )); - // no need to continue through further items - break; - } - } - } - }}; - } - match collection { - GenericCollection::List(collection) => { - let mut iter = collection.iter(); - iter!(iter) - } - GenericCollection::Tuple(collection) => { - let mut iter = collection.iter(); - iter!(iter) - } - GenericCollection::PyAny(collection) => { - let vec: Vec<&PyAny> = collection.iter()?.collect::>()?; - let mut iter = vec.into_iter(); - iter!(iter) - } - GenericCollection::JsonArray(collection) => { - let mut iter = collection.iter(); - iter!(iter) - } - _ => unreachable!(), - } - if errors.is_empty() { - Ok(PyTuple::new(py, &output).into_py(py)) - } else { - Err(ValError::LineErrors(errors)) - } + let strict = extra.strict.unwrap_or(self.strict); + + let length_constraints = LengthConstraints { + min_length: 0, + max_length: None, + }; + + let mut checks = IterableValidationChecks::new(length_constraints, FIELD_TYPE); + + let mut output = Vec::with_capacity(calculate_output_init_capacity(generic_iterable.len(), None)); + + match (generic_iterable, strict) { + // Always allow actual lists or JSON arrays + (GenericIterable::JsonArray(iter), _) => validate_iterator_tuple_positional( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.items_validators, + &self.extra_validator, + &mut output, + )?, + (GenericIterable::Tuple(iter), _) => validate_iterator_tuple_positional( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter.iter().map(Ok), + &self.items_validators, + &self.extra_validator, + &mut output, + )?, + // If not in strict mode we also accept any iterable except str, bytes or mappings + // This may seem counterintuitive since a Mapping is a less generic type than an arbitrary + // iterable (which we do accept) but doing something like `x: list[int] = {1: 'a'}` is commonly + // a mistake, so we don't parse it by default + ( + GenericIterable::String(_) + | GenericIterable::Bytes(_) + | GenericIterable::Dict(_) + | GenericIterable::Mapping(_), + _, + ) => return Err(ValError::new(ErrorType::TupleType, input)), + (generic_iterable, false) => match generic_iterable.into_sequence_iterator(py) { + Ok(iter) => validate_iterator_tuple_positional( + py, + input, + extra, + definitions, + recursion_guard, + &mut checks, + iter, + &self.items_validators, + &self.extra_validator, + &mut output, + )?, + Err(_) => return Err(ValError::new(ErrorType::TupleType, input)), + }, + _ => return Err(ValError::new(ErrorType::TupleType, input)), + }; + Ok(PyTuple::new(py, output.into_iter()).into_py(py)) } fn different_strict_behavior( diff --git a/tests/benchmarks/test_complete_benchmark.py b/tests/benchmarks/test_complete_benchmark.py index dc8767af0..f4b60b367 100644 --- a/tests/benchmarks/test_complete_benchmark.py +++ b/tests/benchmarks/test_complete_benchmark.py @@ -91,7 +91,7 @@ def test_complete_invalid(): lax_validator = SchemaValidator(lax_schema) with pytest.raises(ValidationError) as exc_info: lax_validator.validate_python(input_data_wrong()) - assert len(exc_info.value.errors(include_url=False)) == 738 + assert len(exc_info.value.errors(include_url=False)) == 638 model = pydantic_model() if model is None: diff --git a/tests/conftest.py b/tests/conftest.py index d86b6069b..45817a427 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import Any, Type +from typing import Any, List, Type, TypeVar, Union import hypothesis import pytest @@ -42,6 +42,16 @@ def __repr__(self): return f'Err({self.message!r})' +@dataclass +class AnyOf: + expected: List[Any] + + +T = TypeVar('T') + +Result = Union[T, Err] + + def json_default(obj): if isinstance(obj, ArgsKwargs): raise pytest.skip('JSON skipping ArgsKwargs') diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index 6d1a96516..686f029db 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -159,9 +159,6 @@ class RemovedContains(ImplicitContains): ({'a': 'dict'}, 'Input should be a valid set'), ({4.2}, 'Input should be a valid integer, got a number with a fractional part'), ({'a'}, 'Input should be a valid integer, unable to parse string as an integer'), - (ImplicitContains(), 'Input should be a valid set'), - (ExplicitContains(), re.compile('.*Invalid Schema:.*Input should be a valid set.*', re.DOTALL)), - (RemovedContains(), re.compile('.*Invalid Schema:.*Input should be a valid set.*', re.DOTALL)), ], ) @pytest.mark.parametrize('schema_func', [core_schema.list_schema, core_schema.tuple_variable_schema]) diff --git a/tests/test_errors.py b/tests/test_errors.py index 2815ae5af..e84f28684 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -2,7 +2,7 @@ from decimal import Decimal import pytest -from dirty_equals import HasRepr, IsInstance, IsJson, IsStr +from dirty_equals import Contains, HasRepr, IsInstance, IsJson, IsStr from pydantic_core import ( PydanticCustomError, @@ -615,21 +615,25 @@ def test_loc_with_dots(): with pytest.raises(ValidationError) as exc_info: v.validate_python({'foo.bar': ('x', 42)}) # insert_assert(exc_info.value.errors(include_url=False)) - assert exc_info.value.errors(include_url=False) == [ + assert exc_info.value.errors(include_url=False) == Contains( { 'type': 'int_parsing', 'loc': ('foo.bar', 0), 'msg': 'Input should be a valid integer, unable to parse string as an integer', 'input': 'x', } - ] + ) # insert_assert(str(exc_info.value)) - assert str(exc_info.value) == ( - "1 validation error for typed-dict\n" - "`foo.bar`.0\n" - " Input should be a valid integer, unable to parse string as an integer " - "[type=int_parsing, input_value='x', input_type=str]\n" - f' For further information visit https://errors.pydantic.dev/{__version__}/v/int_parsing' + assert ( + str(exc_info.value) + == f"""\ +2 validation errors for typed-dict +`foo.bar`.0 + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='x', input_type=str] + For further information visit https://errors.pydantic.dev/{__version__}/v/int_parsing +`foo.bar`.1 + Field required [type=missing, input_value=('x', 42), input_type=tuple] + For further information visit https://errors.pydantic.dev/{__version__}/v/missing""" # noqa: E501 ) diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 8d09ee5a6..2e4e5470f 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -2,7 +2,7 @@ from typing import List, Optional import pytest -from dirty_equals import AnyThing, HasAttributes, IsList, IsPartialDict, IsStr, IsTuple +from dirty_equals import AnyThing, Contains, HasAttributes, IsList, IsPartialDict, IsStr, IsTuple from pydantic_core import SchemaError, SchemaValidator, ValidationError, __version__, core_schema @@ -509,7 +509,7 @@ def test_multiple_tuple_recursion(multiple_tuple_schema: SchemaValidator): with pytest.raises(ValidationError) as exc_info: multiple_tuple_schema.validate_python({'f1': data, 'f2': data}) - assert exc_info.value.errors(include_url=False) == [ + assert exc_info.value.errors(include_url=False) == Contains( { 'type': 'recursion_loop', 'loc': ('f1', 1), @@ -522,7 +522,7 @@ def test_multiple_tuple_recursion(multiple_tuple_schema: SchemaValidator): 'msg': 'Recursion error - cyclic reference detected', 'input': [1, IsList(length=2)], }, - ] + ) def test_multiple_tuple_recursion_once(multiple_tuple_schema: SchemaValidator): @@ -531,7 +531,7 @@ def test_multiple_tuple_recursion_once(multiple_tuple_schema: SchemaValidator): with pytest.raises(ValidationError) as exc_info: multiple_tuple_schema.validate_python({'f1': data, 'f2': data}) - assert exc_info.value.errors(include_url=False) == [ + assert exc_info.value.errors(include_url=False) == Contains( { 'type': 'recursion_loop', 'loc': ('f1', 1), @@ -544,7 +544,7 @@ def test_multiple_tuple_recursion_once(multiple_tuple_schema: SchemaValidator): 'msg': 'Recursion error - cyclic reference detected', 'input': [1, IsList(length=2)], }, - ] + ) def test_definition_wrap(): @@ -571,14 +571,14 @@ def wrap_func(input_value, validator, info): t.append(t) with pytest.raises(ValidationError) as exc_info: v.validate_python(t) - assert exc_info.value.errors(include_url=False) == [ + assert exc_info.value.errors(include_url=False) == Contains( { 'type': 'recursion_loop', 'loc': (1,), 'msg': 'Recursion error - cyclic reference detected', 'input': IsList(positions={0: 1}, length=2), } - ] + ) def test_union_ref_strictness(): diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index d804f4ba5..69c85aa05 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -1,7 +1,7 @@ import re from collections import OrderedDict from collections.abc import Mapping -from typing import Any, Dict +from typing import Any, Dict, List import pytest from dirty_equals import HasRepr, IsStr @@ -24,21 +24,23 @@ def test_dict(py_and_json: PyAndJson): @pytest.mark.parametrize( 'input_value,expected', [ - ({'1': b'1', '2': b'2'}, {'1': '1', '2': '2'}), - (OrderedDict(a=b'1', b='2'), {'a': '1', 'b': '2'}), + ({'1': b'1', '2': b'2'}, {'1': 1, '2': 2}), + (OrderedDict(a=b'1', b='2'), {'a': 1, 'b': 2}), ({}, {}), ('foobar', Err("Input should be a valid dictionary [type=dict_type, input_value='foobar', input_type=str]")), - ([], Err('Input should be a valid dictionary [type=dict_type,')), - ([('x', 'y')], Err('Input should be a valid dictionary [type=dict_type,')), - ([('x', 'y'), ('z', 'z')], Err('Input should be a valid dictionary [type=dict_type,')), - ((), Err('Input should be a valid dictionary [type=dict_type,')), - ((('x', 'y'),), Err('Input should be a valid dictionary [type=dict_type,')), - ((type('Foobar', (), {'x': 1})()), Err('Input should be a valid dictionary [type=dict_type,')), + ([], {}), + ([('x', '1')], {'x': 1}), + ([('x', '1'), ('z', b'2')], {'x': 1, 'z': 2}), + ((), {}), + ((('x', '1'),), {'x': 1}), + pytest.param( + (type('Foobar', (), {'x': 1})()), Err('Input should be a valid dictionary [type=dict_type,'), id='Foobar' + ), ], ids=repr, ) def test_dict_cases(input_value, expected): - v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': {'type': 'str'}}) + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': {'type': 'int'}}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) @@ -157,17 +159,60 @@ def __len__(self): assert exc_info.value.errors(include_url=False) == [ { - 'type': 'mapping_type', + 'type': 'dict_type', 'loc': (), - 'msg': 'Input should be a valid mapping, error: RuntimeError: intentional error', + 'msg': 'Input should be a valid dictionary', 'input': HasRepr(IsStr(regex='.+BadMapping object at.+')), - 'ctx': {'error': 'RuntimeError: intentional error'}, } ] -@pytest.mark.parametrize('mapping_items', [[(1,)], ['foobar'], [(1, 2, 3)], 'not list']) -def test_mapping_error_yield_1(mapping_items): +@pytest.mark.parametrize( + 'mapping_items,error', + [ + ( + [(1,)], + { + 'type': 'iteration_error', + 'loc': (), + 'msg': 'Error iterating over object, error: ValueError: expected tuple of length 2, but got tuple of length 1', # noqa: E501 + 'input': HasRepr(IsStr(regex='.+BadMapping object at.+')), + 'ctx': {'error': 'ValueError: expected tuple of length 2, but got tuple of length 1'}, + }, + ), + ( + ['foobar'], + { + 'type': 'iteration_error', + 'loc': (), + 'msg': "Error iterating over object, error: TypeError: 'str' object cannot be converted to 'PyTuple'", # noqa: E501 + 'input': HasRepr(IsStr(regex='.+BadMapping object at.+')), + 'ctx': {'error': "TypeError: 'str' object cannot be converted to 'PyTuple'"}, + }, + ), + ( + [(1, 2, 3)], + { + 'type': 'iteration_error', + 'loc': (), + 'msg': 'Error iterating over object, error: ValueError: expected tuple of length 2, but got tuple of length 3', # noqa: E501 + 'input': HasRepr(IsStr(regex='.+BadMapping object at.+')), + 'ctx': {'error': 'ValueError: expected tuple of length 2, but got tuple of length 3'}, + }, + ), + ( + 'not list', + { + 'type': 'iteration_error', + 'loc': (), + 'msg': "Error iterating over object, error: TypeError: 'str' object cannot be converted to 'PyTuple'", # noqa: E501 + 'input': HasRepr(IsStr(regex='.+BadMapping object at.+')), + 'ctx': {'error': "TypeError: 'str' object cannot be converted to 'PyTuple'"}, + }, + ), + ], +) +def test_mapping_error_yield_1(mapping_items: List[Any], error: Any): class BadMapping(Mapping): def items(self): return mapping_items @@ -185,15 +230,7 @@ def __len__(self): with pytest.raises(ValidationError) as exc_info: v.validate_python(BadMapping()) - assert exc_info.value.errors(include_url=False) == [ - { - 'type': 'mapping_type', - 'loc': (), - 'msg': 'Input should be a valid mapping, error: Mapping items must be tuples of (key, value) pairs', - 'input': HasRepr(IsStr(regex='.+BadMapping object at.+')), - 'ctx': {'error': 'Mapping items must be tuples of (key, value) pairs'}, - } - ] + assert exc_info.value.errors(include_url=False) == [error] @pytest.mark.parametrize( diff --git a/tests/validators/test_frozenset.py b/tests/validators/test_frozenset.py index 1d932e291..3d5866915 100644 --- a/tests/validators/test_frozenset.py +++ b/tests/validators/test_frozenset.py @@ -36,8 +36,8 @@ def test_no_copy(): input_value = frozenset([1, 2, 3]) output = v.validate_python(input_value) assert output == input_value - assert output is input_value - assert id(output) == id(input_value) + assert output is not input_value + assert id(output) != id(input_value) @pytest.mark.parametrize( @@ -157,11 +157,6 @@ def generate_repeats(): ( {'max_length': 3}, infinite_generator(), - Err('Frozenset should have at most 30 items after validation, not 31 [type=too_long,'), - ), - ( - {'max_length': 3, 'generator_max_length': 3}, - infinite_generator(), Err('Frozenset should have at most 3 items after validation, not 4 [type=too_long,'), ), ], @@ -248,7 +243,7 @@ def test_repr(): 'SchemaValidator(' 'title="frozenset[any]",' 'validator=FrozenSet(FrozenSetValidator{' - 'strict:true,item_validator:None,min_length:Some(42),max_length:None,generator_max_length:None,' + 'strict:true,item_validator:Any(AnyValidator),min_length:42,max_length:None,' 'name:"frozenset[any]"' '}),definitions=[])' ) diff --git a/tests/validators/test_list.py b/tests/validators/test_list.py index 3b019df3f..d8694e1fe 100644 --- a/tests/validators/test_list.py +++ b/tests/validators/test_list.py @@ -1,14 +1,15 @@ +import collections.abc import re from collections import deque -from collections.abc import Sequence -from typing import Any, Dict +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterator, List, Union import pytest -from dirty_equals import HasRepr, IsInstance, IsStr +from dirty_equals import Contains, HasRepr, IsInstance, IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import PydanticOmit, SchemaValidator, ValidationError, core_schema -from ..conftest import Err, PyAndJson, infinite_generator +from ..conftest import Err, PyAndJson, Result @pytest.mark.parametrize( @@ -40,31 +41,125 @@ def test_list_strict(): ] -def gen_ints(): - yield 1 - yield 2 - yield '3' +class MySequence(collections.abc.Sequence): + def __init__(self, data: List[Any]): + self._data = data + + def __getitem__(self, index: int) -> Any: + return self._data[index] + + def __len__(self): + return len(self._data) + + def __repr__(self) -> str: + return f'MySequence({repr(self._data)})' + + +class MyMapping(collections.abc.Mapping): + def __init__(self, data: Dict[Any, Any]) -> None: + self._data = data + + def __getitem__(self, key: Any) -> Any: + return self._data[key] + + def __iter__(self) -> Iterator[Any]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __repr__(self) -> str: + return f'MyMapping({repr(self._data)})' + + +@dataclass +class ListInputTestCase: + input: Any + output: Result[Any] + strict: Union[bool, None] = None + + +LAX_MODE_INPUTS: List[Any] = [ + (1, 2, 3), + frozenset((1, 2, 3)), + set((1, 2, 3)), + deque([1, 2, 3]), + {1: 'a', 2: 'b', 3: 'c'}.keys(), + {'a': 1, 'b': 2, 'c': 3}.values(), + MySequence([1, 2, 3]), + MyMapping({1: 'a', 2: 'b', 3: 'c'}).keys(), + MyMapping({'a': 1, 'b': 2, 'c': 3}).values(), + (x for x in [1, 2, 3]), +] + + +@pytest.mark.parametrize( + 'testcase', + [ + *[ListInputTestCase([1, 2, 3], [1, 2, 3], strict) for strict in (True, False, None)], + *[ + ListInputTestCase(inp, Err('Input should be a valid list [type=list_type,'), True) + for inp in [*LAX_MODE_INPUTS, '123', b'123'] + ], + *[ListInputTestCase(inp, [1, 2, 3], False) for inp in LAX_MODE_INPUTS], + *[ + ListInputTestCase(inp, Err('Input should be a valid list [type=list_type,'), False) + for inp in ['123', b'123', MyMapping({1: 'a', 2: 'b', 3: 'c'}), {1: 'a', 2: 'b', 3: 'c'}] + ], + ], + ids=repr, +) +def test_list_allowed_inputs_python(testcase: ListInputTestCase): + v = SchemaValidator(core_schema.list_schema(core_schema.int_schema(), strict=testcase.strict)) + if isinstance(testcase.output, Err): + with pytest.raises(ValidationError, match=re.escape(testcase.output.message)): + v.validate_python(testcase.input) + else: + output = v.validate_python(testcase.input) + assert output == testcase.output + assert output is not testcase.input + + +@pytest.mark.parametrize( + 'testcase', + [ + ListInputTestCase({1: 1, 2: 2, 3: 3}.items(), Err('Input should be a valid list [type=list_type,'), True), + ListInputTestCase( + MyMapping({1: 1, 2: 2, 3: 3}).items(), Err('Input should be a valid list [type=list_type,'), True + ), + ListInputTestCase({1: 1, 2: 2, 3: 3}.items(), [(1, 1), (2, 2), (3, 3)], False), + ListInputTestCase(MyMapping({1: 1, 2: 2, 3: 3}).items(), [(1, 1), (2, 2), (3, 3)], False), + ], + ids=repr, +) +def test_list_dict_items_input(testcase: ListInputTestCase) -> None: + v = SchemaValidator( + core_schema.list_schema( + core_schema.tuple_positional_schema([core_schema.int_schema(), core_schema.int_schema()]), + strict=testcase.strict, + ) + ) + if isinstance(testcase.output, Err): + with pytest.raises(ValidationError, match=re.escape(testcase.output.message)): + v.validate_python(testcase.input) + else: + output = v.validate_python(testcase.input) + assert output == testcase.output + assert output is not testcase.input @pytest.mark.parametrize( 'input_value,expected', [ - ([1, 2, '3'], [1, 2, 3]), - ((1, 2, '3'), [1, 2, 3]), - (deque((1, 2, '3')), [1, 2, 3]), - ({1, 2, '3'}, Err('Input should be a valid list [type=list_type,')), - (gen_ints(), [1, 2, 3]), - (frozenset({1, 2, '3'}), Err('Input should be a valid list [type=list_type,')), - ({1: 10, 2: 20, '3': '30'}.keys(), [1, 2, 3]), - ({1: 10, 2: 20, '3': '30'}.values(), [10, 20, 30]), - ({1: 10, 2: 20, '3': '30'}, Err('Input should be a valid list [type=list_type,')), - ((x for x in [1, 2, '3']), [1, 2, 3]), + ([1, b'2', '3'], [1, 2, 3]), + ((1, b'2', '3'), [1, 2, 3]), + ((x for x in (1, b'2', '3')), [1, 2, 3]), ('456', Err("Input should be a valid list [type=list_type, input_value='456', input_type=str]")), (b'789', Err("Input should be a valid list [type=list_type, input_value=b'789', input_type=bytes]")), ], ids=repr, ) -def test_list_int(input_value, expected): +def test_list_int(input_value: Any, expected: Any): v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): @@ -86,25 +181,10 @@ def test_list_json(): ] -@pytest.mark.parametrize( - 'input_value,expected', - [ - ([], []), - ([1, '2', b'3'], [1, '2', b'3']), - (frozenset([1, '2', b'3']), Err('Input should be a valid list [type=list_type,')), - ((), []), - ((1, '2', b'3'), [1, '2', b'3']), - (deque([1, '2', b'3']), [1, '2', b'3']), - ({1, '2', b'3'}, Err('Input should be a valid list [type=list_type,')), - ], -) +@pytest.mark.parametrize('input_value,expected', [([], []), ([1, '2', b'3'], [1, '2', b'3'])]) def test_list_any(input_value, expected): v = SchemaValidator({'type': 'list'}) - if isinstance(expected, Err): - with pytest.raises(ValidationError, match=re.escape(expected.message)): - v.validate_python(input_value) - else: - assert v.validate_python(input_value) == expected + assert v.validate_python(input_value) == expected @pytest.mark.parametrize( @@ -148,13 +228,24 @@ def test_list_error(input_value, index): ({'max_length': 1}, [1, 2], Err('List should have at most 1 item after validation, not 2 [type=too_long,')), ( {'max_length': 44}, - infinite_generator(), + [1] * 100, Err('List should have at most 44 items after validation, not 45 [type=too_long,'), ), + ( + {'max_length': 3}, + ['a', 'b', 'c', 'd'], + Err('List should have at most 3 items after validation, not 4 [type=too_long,'), + ), + ( + {'min_length': 2}, + ['a', 'b'], + Err('List should have at least 2 items after validation, not 0 [type=too_short,'), + ), + ({'min_length': 1}, [], Err('List should have at least 1 item after validation, not 0 [type=too_short,')), ], ) def test_list_length_constraints(kwargs: Dict[str, Any], input_value, expected): - v = SchemaValidator({'type': 'list', **kwargs}) + v = SchemaValidator({'type': 'list', 'items_schema': core_schema.int_schema(), **kwargs}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) @@ -270,7 +361,7 @@ def gen(error: bool): assert exc_info.value.errors(include_url=False) == [ { 'type': 'iteration_error', - 'loc': (2,), + 'loc': (), 'msg': 'Error iterating over object, error: RuntimeError: error', 'input': HasRepr(IsStr(regex='.gen at 0x[0-9a-fA-F]+>')), 'ctx': {'error': 'RuntimeError: error'}, @@ -303,68 +394,6 @@ def test_list_from_dict_items(input_value, items_schema, expected): assert output == expected -@pytest.fixture(scope='session', name='MySequence') -def my_sequence(): - class MySequence(Sequence): - def __init__(self): - self._data = [1, 2, 3] - - def __getitem__(self, index): - return self._data[index] - - def __len__(self): - return len(self._data) - - def count(self, value): - return self._data.count(value) - - assert isinstance(MySequence(), Sequence) - return MySequence - - -def test_sequence(MySequence): - v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}}) - with pytest.raises(ValidationError) as exc_info: - v.validate_python(MySequence()) - # insert_assert(exc_info.value.errors(include_url=False)) - assert exc_info.value.errors(include_url=False) == [ - {'type': 'list_type', 'loc': (), 'msg': 'Input should be a valid list', 'input': IsInstance(MySequence)} - ] - - -@pytest.mark.parametrize( - 'input_value,expected', - [ - ([1, 2, 3], [1, 2, 3]), - ((1, 2, 3), [1, 2, 3]), - (range(3), [0, 1, 2]), - (gen_ints(), [1, 2, 3]), - ({1: 2, 3: 4}, [1, 3]), - ('123', [1, 2, 3]), - ( - 123, - Err( - '1 validation error for list[int]', - [{'type': 'list_type', 'loc': (), 'msg': 'Input should be a valid list', 'input': 123}], - ), - ), - ], -) -def test_allow_any_iter(input_value, expected): - v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}, 'allow_any_iter': True}) - if isinstance(expected, Err): - with pytest.raises(ValidationError, match=re.escape(expected.message)) as exc_info: - v.validate_python(input_value) - assert exc_info.value.errors(include_url=False) == expected.errors - else: - assert v.validate_python(input_value) == expected - - -def test_sequence_allow_any_iter(MySequence): - v = SchemaValidator({'type': 'list', 'items_schema': {'type': 'int'}, 'allow_any_iter': True}) - assert v.validate_python(MySequence()) == [1, 2, 3] - - @pytest.mark.parametrize('items_schema', ['int', 'any']) def test_bad_iter(items_schema): class BadIter: @@ -387,7 +416,7 @@ def __next__(self): else: raise RuntimeError('broken') - v = SchemaValidator({'type': 'list', 'items_schema': {'type': items_schema}, 'allow_any_iter': True}) + v = SchemaValidator({'type': 'list', 'items_schema': {'type': items_schema}}) assert v.validate_python(BadIter(True)) == [1] with pytest.raises(ValidationError) as exc_info: v.validate_python(BadIter(False)) @@ -395,9 +424,108 @@ def __next__(self): assert exc_info.value.errors(include_url=False) == [ { 'type': 'iteration_error', - 'loc': (1,), + 'loc': (), 'msg': 'Error iterating over object, error: RuntimeError: broken', 'input': IsInstance(BadIter), 'ctx': {'error': 'RuntimeError: broken'}, } ] + + +def infinite_int_gen() -> Iterator[int]: + num = 0 + while True: + yield num + num += 1 + + +def infinite_str_gen() -> Iterator[str]: + num = 0 + while True: + yield f'a_{num}' + num += 1 + + +# consumed the first item into an error, when we found a second item we errored so 3rd item is next +@pytest.mark.parametrize('gen_factory,nxt', [(infinite_int_gen, 2), (infinite_str_gen, 'a_2')]) +def test_stop_iterating_on_error(gen_factory: Callable[[], Iterator[Any]], nxt: Any) -> None: + v = SchemaValidator(core_schema.list_schema(core_schema.int_schema(), max_length=1)) + + gen = gen_factory() + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(gen) + + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'too_long', + 'loc': (), + 'msg': 'List should have at most 1 item after validation, not 2', + 'input': gen, + 'ctx': {'field_type': 'List', 'max_length': 1, 'actual_length': 2}, + } + ] + + assert next(gen) == nxt + + +def test_stop_iterating_func_raises_omit() -> None: + def f(x: int) -> int: + if x < 100: + raise PydanticOmit + return x + + v = SchemaValidator( + core_schema.list_schema(core_schema.no_info_after_validator_function(f, core_schema.int_schema()), max_length=1) + ) + + gen = infinite_int_gen() + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(gen) + + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'too_long', + 'loc': (), + 'msg': 'List should have at most 1 item after validation, not 2', + 'input': gen, + 'ctx': {'field_type': 'List', 'max_length': 1, 'actual_length': 2}, + } + ] + + assert next(gen) == 102 + + +@pytest.mark.parametrize('error_in_func', [True, False]) +def test_max_length_fail_fast(error_in_func: bool) -> None: + calls: list[int] = [] + + def f(v: int) -> int: + calls.append(v) + if error_in_func: + assert v < 10 + return v + + s = core_schema.list_schema( + core_schema.no_info_after_validator_function(f, core_schema.int_schema()), max_length=10 + ) + + v = SchemaValidator(s) + + data = list(range(15)) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(data) + + assert len(calls) <= 11, len(calls) # we still run validation on the "extra" item + + assert exc_info.value.errors(include_url=False) == Contains( + { + 'type': 'too_long', + 'loc': (), + 'msg': 'List should have at most 10 items after validation, not 11', + 'input': data, + 'ctx': {'field_type': 'List', 'max_length': 10, 'actual_length': 11}, + } + ) diff --git a/tests/validators/test_set.py b/tests/validators/test_set.py index d94ccb359..9f9af8ca7 100644 --- a/tests/validators/test_set.py +++ b/tests/validators/test_set.py @@ -142,11 +142,6 @@ def generate_repeats(): ( {'max_length': 3}, infinite_generator(), - Err('Set should have at most 30 items after validation, not 31 [type=too_long,'), - ), - ( - {'max_length': 3, 'generator_max_length': 3}, - infinite_generator(), Err('Set should have at most 3 items after validation, not 4 [type=too_long,'), ), ], diff --git a/tests/validators/test_tuple.py b/tests/validators/test_tuple.py index baceb3e9b..9d2a5716c 100644 --- a/tests/validators/test_tuple.py +++ b/tests/validators/test_tuple.py @@ -1,13 +1,14 @@ +import itertools import re from collections import deque from typing import Any, Dict, Type import pytest -from dirty_equals import IsNonNegative +from dirty_equals import Contains, IsNonNegative from pydantic_core import SchemaValidator, ValidationError -from ..conftest import Err, PyAndJson, infinite_generator +from ..conftest import AnyOf, Err, PyAndJson, infinite_generator @pytest.mark.parametrize( @@ -39,13 +40,13 @@ def test_tuple_json(py_and_json: PyAndJson, mode, items, input_value, expected): assert v.validate_test(input_value) == expected -def test_any_no_copy(): +def test_any_copied(): v = SchemaValidator({'type': 'tuple-variable'}) input_value = (1, '2', b'3') output = v.validate_python(input_value) assert output == input_value - assert output is input_value - assert id(output) == id(input_value) + assert output is not input_value + assert id(output) != id(input_value) @pytest.mark.parametrize( @@ -144,8 +145,8 @@ def test_tuple_var_len_kwargs(kwargs: Dict[str, Any], input_value, expected): ({1: 10, 2: 20, '3': '30'}.keys(), (1, 2, 3)), ({1: 10, 2: 20, '3': '30'}.values(), (10, 20, 30)), ({1: 10, 2: 20, '3': '30'}, Err('Input should be a valid tuple [type=tuple_type,')), - ({1, 2, '3'}, Err('Input should be a valid tuple [type=tuple_type,')), - (frozenset([1, 2, '3']), Err('Input should be a valid tuple [type=tuple_type,')), + ({1, 2, '3'}, AnyOf([tuple(o) for o in sorted(itertools.permutations([1, 2, 3]))])), + (frozenset([1, 2, '3']), AnyOf([tuple(o) for o in sorted(itertools.permutations([1, 2, 3]))])), ], ids=repr, ) @@ -154,6 +155,8 @@ def test_tuple_validate(input_value, expected, mode, items): if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) + elif isinstance(expected, AnyOf): + assert v.validate_python(input_value) in expected.expected else: assert v.validate_python(input_value) == expected @@ -211,14 +214,12 @@ def test_tuple_fix_len_errors(input_value, items, index): v = SchemaValidator({'type': 'tuple-positional', 'items_schema': items}) with pytest.raises(ValidationError) as exc_info: assert v.validate_python(input_value) - assert exc_info.value.errors(include_url=False) == [ - { - 'type': 'int_parsing', - 'loc': (index,), - 'msg': 'Input should be a valid integer, unable to parse string as an integer', - 'input': 'wrong', - } - ] + assert { + 'type': 'int_parsing', + 'loc': (index,), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } in exc_info.value.errors(include_url=False) def test_multiple_missing(py_and_json: PyAndJson): @@ -231,11 +232,12 @@ def test_multiple_missing(py_and_json: PyAndJson): assert v.validate_test([1, 2, 3, 4]) == (1, 2, 3, 4) with pytest.raises(ValidationError) as exc_info: v.validate_test([1]) - assert exc_info.value.errors(include_url=False) == [ + + assert exc_info.value.errors(include_url=False) == Contains( {'type': 'missing', 'loc': (1,), 'msg': 'Field required', 'input': [1]}, {'type': 'missing', 'loc': (2,), 'msg': 'Field required', 'input': [1]}, {'type': 'missing', 'loc': (3,), 'msg': 'Field required', 'input': [1]}, - ] + ) with pytest.raises(ValidationError) as exc_info: v.validate_test([1, 2, 3]) assert exc_info.value.errors(include_url=False) == [ @@ -253,9 +255,9 @@ def test_extra_arguments(py_and_json: PyAndJson): { 'type': 'too_long', 'loc': (), - 'msg': 'Tuple should have at most 2 items after validation, not 4', + 'msg': 'Tuple should have at most 2 items after validation, not 3', 'input': [1, 2, 3, 4], - 'ctx': {'field_type': 'Tuple', 'max_length': 2, 'actual_length': 4}, + 'ctx': {'field_type': 'Tuple', 'max_length': 2, 'actual_length': 3}, } ]