diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 210ca8988..299a0f1e8 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -6,10 +6,8 @@ use pyo3::prelude::*; use pyo3::types::{ PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, - PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType, + PyMapping, PySet, PyString, PyTime, PyTuple, PyType, }; -#[cfg(not(PyPy))] -use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; use speedate::MicrosecondsPrecisionOverflowBehavior; @@ -34,71 +32,14 @@ use super::ConsumeIterator; use super::KeywordArgs; use super::PositionalArgs; use super::ValidatedDict; +use super::ValidatedList; +use super::ValidatedSet; +use super::ValidatedTuple; use super::{ - py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterable, - GenericIterator, Input, + py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator, + Input, }; -#[cfg(not(PyPy))] -macro_rules! extract_dict_keys { - ($py:expr, $obj:expr) => { - $obj.downcast::() - .ok() - .map(|v| PyIterator::from_bound_object(v).unwrap()) - }; -} - -#[cfg(PyPy)] -macro_rules! extract_dict_keys { - ($py:expr, $obj:expr) => { - if is_dict_keys_type($obj) { - Some(PyIterator::from_bound_object($obj).unwrap()) - } else { - None - } - }; -} - -#[cfg(not(PyPy))] -macro_rules! extract_dict_values { - ($py:expr, $obj:expr) => { - $obj.downcast::() - .ok() - .map(|v| PyIterator::from_bound_object(v).unwrap()) - }; -} - -#[cfg(PyPy)] -macro_rules! extract_dict_values { - ($py:expr, $obj:expr) => { - if is_dict_values_type($obj) { - Some(PyIterator::from_bound_object($obj).unwrap()) - } else { - None - } - }; -} - -#[cfg(not(PyPy))] -macro_rules! extract_dict_items { - ($py:expr, $obj:expr) => { - $obj.downcast::() - .ok() - .map(|v| PyIterator::from_bound_object(v).unwrap()) - }; -} - -#[cfg(PyPy)] -macro_rules! extract_dict_items { - ($py:expr, $obj:expr) => { - if is_dict_items_type($obj) { - Some(PyIterator::from_bound_object($obj).unwrap()) - } else { - None - } - }; -} - impl From<&Bound<'_, PyAny>> for LocItem { fn from(py_any: &Bound<'_, PyAny>) -> Self { if let Ok(py_str) = py_any.downcast::() { @@ -476,82 +417,54 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { } } - type List<'a> = GenericIterable<'a, 'py> where Self: 'a; + type List<'a> = PySequenceIterable<'a, 'py> where Self: 'a; - fn validate_list<'a>(&'a self, strict: bool) -> ValMatch> { + fn validate_list<'a>(&'a self, strict: bool) -> ValMatch> { if let Ok(list) = self.downcast::() { - return Ok(ValidationMatch::exact(GenericIterable::List(list))); + return Ok(ValidationMatch::exact(PySequenceIterable::List(list))); } else if !strict { - match extract_generic_iterable(self) { - Ok( - GenericIterable::PyString(_) - | GenericIterable::Bytes(_) - | GenericIterable::Dict(_) - | GenericIterable::Mapping(_), - ) - | Err(_) => {} - Ok(other) => return Ok(ValidationMatch::lax(other)), + if let Ok(other) = extract_sequence_iterable(self) { + return Ok(ValidationMatch::lax(other)); } } Err(ValError::new(ErrorTypeDefaults::ListType, self)) } - type Tuple<'a> = GenericIterable<'a, 'py> where Self: 'a; + type Tuple<'a> = PySequenceIterable<'a, 'py> where Self: 'a; - fn validate_tuple<'a>(&'a self, strict: bool) -> ValMatch> { + fn validate_tuple<'a>(&'a self, strict: bool) -> ValMatch> { if let Ok(tup) = self.downcast::() { - return Ok(ValidationMatch::exact(GenericIterable::Tuple(tup))); + return Ok(ValidationMatch::exact(PySequenceIterable::Tuple(tup))); } else if !strict { - match extract_generic_iterable(self) { - Ok( - GenericIterable::PyString(_) - | GenericIterable::Bytes(_) - | GenericIterable::Dict(_) - | GenericIterable::Mapping(_), - ) - | Err(_) => {} - Ok(other) => return Ok(ValidationMatch::lax(other)), + if let Ok(other) = extract_sequence_iterable(self) { + return Ok(ValidationMatch::lax(other)); } } Err(ValError::new(ErrorTypeDefaults::TupleType, self)) } - type Set<'a> = GenericIterable<'a, 'py> where Self: 'a; + type Set<'a> = PySequenceIterable<'a, 'py> where Self: 'a; - fn validate_set<'a>(&'a self, strict: bool) -> ValMatch> { + fn validate_set<'a>(&'a self, strict: bool) -> ValMatch> { if let Ok(set) = self.downcast::() { - return Ok(ValidationMatch::exact(GenericIterable::Set(set))); + return Ok(ValidationMatch::exact(PySequenceIterable::Set(set))); } else if !strict { - match extract_generic_iterable(self) { - Ok( - GenericIterable::PyString(_) - | GenericIterable::Bytes(_) - | GenericIterable::Dict(_) - | GenericIterable::Mapping(_), - ) - | Err(_) => {} - Ok(other) => return Ok(ValidationMatch::lax(other)), + if let Ok(other) = extract_sequence_iterable(self) { + return Ok(ValidationMatch::lax(other)); } } Err(ValError::new(ErrorTypeDefaults::SetType, self)) } - fn validate_frozenset<'a>(&'a self, strict: bool) -> ValMatch> { + fn validate_frozenset<'a>(&'a self, strict: bool) -> ValMatch> { if let Ok(frozenset) = self.downcast::() { - return Ok(ValidationMatch::exact(GenericIterable::FrozenSet(frozenset))); + return Ok(ValidationMatch::exact(PySequenceIterable::FrozenSet(frozenset))); } else if !strict { - match extract_generic_iterable(self) { - Ok( - GenericIterable::PyString(_) - | GenericIterable::Bytes(_) - | GenericIterable::Dict(_) - | GenericIterable::Mapping(_), - ) - | Err(_) => {} - Ok(other) => return Ok(ValidationMatch::lax(other)), + if let Ok(other) = extract_sequence_iterable(self) { + return Ok(ValidationMatch::lax(other)); } } @@ -932,37 +845,100 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> { } } -fn extract_generic_iterable<'a, 'py>(obj: &'a Bound<'py, PyAny>) -> ValResult> { +/// Container for all the collections (sized iterable containers) types, which +/// can mostly be converted to each other in lax mode. +/// This mostly matches python's definition of `Collection`. +pub enum PySequenceIterable<'a, 'py> { + List(&'a Bound<'py, PyList>), + Tuple(&'a Bound<'py, PyTuple>), + Set(&'a Bound<'py, PySet>), + FrozenSet(&'a Bound<'py, PyFrozenSet>), + Iterator(Bound<'py, PyIterator>), +} + +/// Extract types which can be iterated to produce a sequence-like container like a list, tuple, set +/// or frozenset +fn extract_sequence_iterable<'a, 'py>(obj: &'a Bound<'py, PyAny>) -> ValResult> { // Handle concrete non-overlapping types first, then abstract types if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::List(iterable)) + Ok(PySequenceIterable::List(iterable)) } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::Tuple(iterable)) + Ok(PySequenceIterable::Tuple(iterable)) } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::Set(iterable)) + Ok(PySequenceIterable::Set(iterable)) } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::FrozenSet(iterable)) - } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::Dict(iterable)) - } else if let Some(iterable) = extract_dict_keys!(obj.py(), obj) { - Ok(GenericIterable::DictKeys(iterable)) - } else if let Some(iterable) = extract_dict_values!(obj.py(), obj) { - Ok(GenericIterable::DictValues(iterable)) - } else if let Some(iterable) = extract_dict_items!(obj.py(), obj) { - Ok(GenericIterable::DictItems(iterable)) - } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::Mapping(iterable)) - } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::PyString(iterable)) - } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::Bytes(iterable)) - } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::PyByteArray(iterable)) - } else if let Ok(iterable) = obj.downcast::() { - Ok(GenericIterable::Sequence(iterable)) - } else if let Ok(iterable) = obj.iter() { - Ok(GenericIterable::Iterator(iterable)) + Ok(PySequenceIterable::FrozenSet(iterable)) } else { + // Try to get this as a generable iterable thing, but exclude string and mapping types + if !(obj.is_instance_of::() + || obj.is_instance_of::() + || obj.is_instance_of::() + || obj.is_instance_of::() + || obj.downcast::().is_ok()) + { + if let Ok(iter) = obj.iter() { + return Ok(PySequenceIterable::Iterator(iter)); + } + } + Err(ValError::new(ErrorTypeDefaults::IterableType, obj)) } } + +impl<'py> PySequenceIterable<'_, 'py> { + pub fn generic_len(&self) -> Option { + match &self { + PySequenceIterable::List(iter) => Some(iter.len()), + PySequenceIterable::Tuple(iter) => Some(iter.len()), + PySequenceIterable::Set(iter) => Some(iter.len()), + PySequenceIterable::FrozenSet(iter) => Some(iter.len()), + PySequenceIterable::Iterator(iter) => iter.len().ok(), + } + } + + fn generic_iterate( + self, + consumer: impl ConsumeIterator>, Output = R>, + ) -> ValResult { + match self { + PySequenceIterable::List(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + PySequenceIterable::Tuple(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + PySequenceIterable::Set(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + PySequenceIterable::FrozenSet(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), + PySequenceIterable::Iterator(iter) => Ok(consumer.consume_iterator(iter.iter()?)), + } + } +} + +impl<'py> ValidatedList<'py> for PySequenceIterable<'_, 'py> { + type Item = Bound<'py, PyAny>; + fn len(&self) -> Option { + self.generic_len() + } + fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { + self.generic_iterate(consumer) + } + fn as_py_list(&self) -> Option<&Bound<'py, PyList>> { + match self { + PySequenceIterable::List(iter) => Some(iter), + _ => None, + } + } +} + +impl<'py> ValidatedTuple<'py> for PySequenceIterable<'_, 'py> { + type Item = Bound<'py, PyAny>; + fn len(&self) -> Option { + self.generic_len() + } + fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { + self.generic_iterate(consumer) + } +} + +impl<'py> ValidatedSet<'py> for PySequenceIterable<'_, 'py> { + type Item = Bound<'py, PyAny>; + fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { + self.generic_iterate(consumer) + } +} diff --git a/src/input/mod.rs b/src/input/mod.rs index a742f6b43..5a2fdeb69 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -22,7 +22,7 @@ pub(crate) use input_abstract::{ pub(crate) use input_string::StringMapping; pub(crate) use return_enums::{ no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat, - EitherInt, EitherString, GenericIterable, GenericIterator, Int, MaxLengthCheck, ValidationMatch, + EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch, }; // Defined here as it's not exported by pyo3 diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 1ff1d51ca..b4a699735 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -12,10 +12,7 @@ use pyo3::intern; use pyo3::prelude::*; #[cfg(not(PyPy))] use pyo3::types::PyFunction; -use pyo3::types::{ - PyByteArray, PyBytes, PyDict, PyFloat, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString, - PyTuple, -}; +use pyo3::types::{PyBytes, PyFloat, PyFrozenSet, PyIterator, PyMapping, PySet, PyString}; use serde::{ser::Error, Serialize, Serializer}; @@ -25,7 +22,7 @@ use crate::errors::{ use crate::tools::{extract_i64, py_err}; use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; -use super::{py_error_on_minusone, BorrowInput, ConsumeIterator, Input, ValidatedList, ValidatedSet, ValidatedTuple}; +use super::{py_error_on_minusone, BorrowInput, Input}; pub struct ValidationMatch(T, Exactness); @@ -60,29 +57,6 @@ impl ValidationMatch { } } -/// Container for all the collections (sized iterable containers) types, which -/// can mostly be converted to each other in lax mode. -/// This mostly matches python's definition of `Collection`. -#[cfg_attr(debug_assertions, derive(Debug))] -pub enum GenericIterable<'a, 'py> { - List(&'a Bound<'py, PyList>), - Tuple(&'a Bound<'py, PyTuple>), - Set(&'a Bound<'py, PySet>), - FrozenSet(&'a Bound<'py, PyFrozenSet>), - Dict(&'a Bound<'py, PyDict>), - // Treat dict values / keys / items as generic iterators - // since PyPy doesn't export the concrete types - DictKeys(Bound<'py, PyIterator>), - DictValues(Bound<'py, PyIterator>), - DictItems(Bound<'py, PyIterator>), - Mapping(&'a Bound<'py, PyMapping>), - PyString(&'a Bound<'py, PyString>), - Bytes(&'a Bound<'py, PyBytes>), - PyByteArray(&'a Bound<'py, PyByteArray>), - Sequence(&'a Bound<'py, PySequence>), - Iterator(Bound<'py, PyIterator>), -} - pub struct MaxLengthCheck<'a, INPUT: ?Sized> { current_length: usize, max_length: Option, @@ -269,84 +243,6 @@ pub(crate) fn no_validator_iter_to_vec<'py>( .collect() } -impl<'py> GenericIterable<'_, 'py> { - pub fn generic_len(&self) -> Option { - match &self { - GenericIterable::List(iter) => Some(iter.len()), - GenericIterable::Tuple(iter) => Some(iter.len()), - GenericIterable::Set(iter) => Some(iter.len()), - GenericIterable::FrozenSet(iter) => Some(iter.len()), - GenericIterable::Dict(iter) => Some(iter.len()), - GenericIterable::DictKeys(iter) => iter.len().ok(), - GenericIterable::DictValues(iter) => iter.len().ok(), - GenericIterable::DictItems(iter) => iter.len().ok(), - GenericIterable::Mapping(iter) => iter.len().ok(), - GenericIterable::PyString(iter) => iter.len().ok(), - GenericIterable::Bytes(iter) => iter.len().ok(), - GenericIterable::PyByteArray(iter) => Some(iter.len()), - GenericIterable::Sequence(iter) => iter.len().ok(), - GenericIterable::Iterator(iter) => iter.len().ok(), - } - } - - fn generic_iterate( - self, - consumer: impl ConsumeIterator>, Output = R>, - ) -> ValResult { - match self { - GenericIterable::List(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), - GenericIterable::Tuple(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), - GenericIterable::Set(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), - GenericIterable::FrozenSet(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))), - // Note that this iterates over only the keys, just like doing iter({}) in Python - GenericIterable::Dict(iter) => Ok(consumer.consume_iterator(iter.iter().map(|(k, _)| Ok(k)))), - GenericIterable::DictKeys(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - GenericIterable::DictValues(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - GenericIterable::DictItems(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - // Note that this iterates over only the keys, just like doing iter({}) in Python - GenericIterable::Mapping(iter) => Ok(consumer.consume_iterator(iter.keys()?.iter()?)), - GenericIterable::PyString(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - GenericIterable::Bytes(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - GenericIterable::PyByteArray(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - GenericIterable::Sequence(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - GenericIterable::Iterator(iter) => Ok(consumer.consume_iterator(iter.iter()?)), - } - } -} - -impl<'py> ValidatedList<'py> for GenericIterable<'_, 'py> { - type Item = Bound<'py, PyAny>; - fn len(&self) -> Option { - self.generic_len() - } - fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { - self.generic_iterate(consumer) - } - fn as_py_list(&self) -> Option<&Bound<'py, PyList>> { - match self { - GenericIterable::List(iter) => Some(iter), - _ => None, - } - } -} - -impl<'py> ValidatedTuple<'py> for GenericIterable<'_, 'py> { - type Item = Bound<'py, PyAny>; - fn len(&self) -> Option { - self.generic_len() - } - fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { - self.generic_iterate(consumer) - } -} - -impl<'py> ValidatedSet<'py> for GenericIterable<'_, 'py> { - type Item = Bound<'py, PyAny>; - fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { - self.generic_iterate(consumer) - } -} - pub(crate) fn iterate_mapping_items<'a, 'py>( mapping: &'a Bound<'py, PyMapping>, ) -> ValResult, Bound<'py, PyAny>)>> + 'a> {