Skip to content

Commit

Permalink
prune python generic iterable types
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Mar 20, 2024
1 parent 7addbfd commit e5c5ded
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 244 deletions.
250 changes: 113 additions & 137 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ use pyo3::prelude::*;

use pyo3::types::{
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType,
PyMapping, PySet, PyString, PyTime, PyTuple, PyType,
};
#[cfg(not(PyPy))]
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};

use speedate::MicrosecondsPrecisionOverflowBehavior;

Expand All @@ -34,71 +32,14 @@ use super::ConsumeIterator;
use super::KeywordArgs;
use super::PositionalArgs;
use super::ValidatedDict;
use super::ValidatedList;
use super::ValidatedSet;
use super::ValidatedTuple;
use super::{
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterable,
GenericIterator, Input,
py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericIterator,
Input,
};

#[cfg(not(PyPy))]
macro_rules! extract_dict_keys {
($py:expr, $obj:expr) => {
$obj.downcast::<PyDictKeys>()
.ok()
.map(|v| PyIterator::from_bound_object(v).unwrap())
};
}

#[cfg(PyPy)]
macro_rules! extract_dict_keys {
($py:expr, $obj:expr) => {
if is_dict_keys_type($obj) {
Some(PyIterator::from_bound_object($obj).unwrap())
} else {
None
}
};
}

#[cfg(not(PyPy))]
macro_rules! extract_dict_values {
($py:expr, $obj:expr) => {
$obj.downcast::<PyDictValues>()
.ok()
.map(|v| PyIterator::from_bound_object(v).unwrap())
};
}

#[cfg(PyPy)]
macro_rules! extract_dict_values {
($py:expr, $obj:expr) => {
if is_dict_values_type($obj) {
Some(PyIterator::from_bound_object($obj).unwrap())
} else {
None
}
};
}

#[cfg(not(PyPy))]
macro_rules! extract_dict_items {
($py:expr, $obj:expr) => {
$obj.downcast::<PyDictItems>()
.ok()
.map(|v| PyIterator::from_bound_object(v).unwrap())
};
}

#[cfg(PyPy)]
macro_rules! extract_dict_items {
($py:expr, $obj:expr) => {
if is_dict_items_type($obj) {
Some(PyIterator::from_bound_object($obj).unwrap())
} else {
None
}
};
}

impl From<&Bound<'_, PyAny>> for LocItem {
fn from(py_any: &Bound<'_, PyAny>) -> Self {
if let Ok(py_str) = py_any.downcast::<PyString>() {
Expand Down Expand Up @@ -476,82 +417,54 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}
}

type List<'a> = GenericIterable<'a, 'py> where Self: 'a;
type List<'a> = PySequenceIterable<'a, 'py> where Self: 'a;

fn validate_list<'a>(&'a self, strict: bool) -> ValMatch<GenericIterable<'a, 'py>> {
fn validate_list<'a>(&'a self, strict: bool) -> ValMatch<PySequenceIterable<'a, 'py>> {
if let Ok(list) = self.downcast::<PyList>() {
return Ok(ValidationMatch::exact(GenericIterable::List(list)));
return Ok(ValidationMatch::exact(PySequenceIterable::List(list)));
} else if !strict {
match extract_generic_iterable(self) {
Ok(
GenericIterable::PyString(_)
| GenericIterable::Bytes(_)
| GenericIterable::Dict(_)
| GenericIterable::Mapping(_),
)
| Err(_) => {}
Ok(other) => return Ok(ValidationMatch::lax(other)),
if let Ok(other) = extract_sequence_iterable(self) {
return Ok(ValidationMatch::lax(other));
}
}

Err(ValError::new(ErrorTypeDefaults::ListType, self))
}

type Tuple<'a> = GenericIterable<'a, 'py> where Self: 'a;
type Tuple<'a> = PySequenceIterable<'a, 'py> where Self: 'a;

fn validate_tuple<'a>(&'a self, strict: bool) -> ValMatch<GenericIterable<'a, 'py>> {
fn validate_tuple<'a>(&'a self, strict: bool) -> ValMatch<PySequenceIterable<'a, 'py>> {
if let Ok(tup) = self.downcast::<PyTuple>() {
return Ok(ValidationMatch::exact(GenericIterable::Tuple(tup)));
return Ok(ValidationMatch::exact(PySequenceIterable::Tuple(tup)));
} else if !strict {
match extract_generic_iterable(self) {
Ok(
GenericIterable::PyString(_)
| GenericIterable::Bytes(_)
| GenericIterable::Dict(_)
| GenericIterable::Mapping(_),
)
| Err(_) => {}
Ok(other) => return Ok(ValidationMatch::lax(other)),
if let Ok(other) = extract_sequence_iterable(self) {
return Ok(ValidationMatch::lax(other));
}
}

Err(ValError::new(ErrorTypeDefaults::TupleType, self))
}

type Set<'a> = GenericIterable<'a, 'py> where Self: 'a;
type Set<'a> = PySequenceIterable<'a, 'py> where Self: 'a;

fn validate_set<'a>(&'a self, strict: bool) -> ValMatch<GenericIterable<'a, 'py>> {
fn validate_set<'a>(&'a self, strict: bool) -> ValMatch<PySequenceIterable<'a, 'py>> {
if let Ok(set) = self.downcast::<PySet>() {
return Ok(ValidationMatch::exact(GenericIterable::Set(set)));
return Ok(ValidationMatch::exact(PySequenceIterable::Set(set)));
} else if !strict {
match extract_generic_iterable(self) {
Ok(
GenericIterable::PyString(_)
| GenericIterable::Bytes(_)
| GenericIterable::Dict(_)
| GenericIterable::Mapping(_),
)
| Err(_) => {}
Ok(other) => return Ok(ValidationMatch::lax(other)),
if let Ok(other) = extract_sequence_iterable(self) {
return Ok(ValidationMatch::lax(other));
}
}

Err(ValError::new(ErrorTypeDefaults::SetType, self))
}

fn validate_frozenset<'a>(&'a self, strict: bool) -> ValMatch<GenericIterable<'a, 'py>> {
fn validate_frozenset<'a>(&'a self, strict: bool) -> ValMatch<PySequenceIterable<'a, 'py>> {
if let Ok(frozenset) = self.downcast::<PyFrozenSet>() {
return Ok(ValidationMatch::exact(GenericIterable::FrozenSet(frozenset)));
return Ok(ValidationMatch::exact(PySequenceIterable::FrozenSet(frozenset)));
} else if !strict {
match extract_generic_iterable(self) {
Ok(
GenericIterable::PyString(_)
| GenericIterable::Bytes(_)
| GenericIterable::Dict(_)
| GenericIterable::Mapping(_),
)
| Err(_) => {}
Ok(other) => return Ok(ValidationMatch::lax(other)),
if let Ok(other) = extract_sequence_iterable(self) {
return Ok(ValidationMatch::lax(other));
}
}

Expand Down Expand Up @@ -932,37 +845,100 @@ impl<'py> ValidatedDict<'py> for GenericPyMapping<'_, 'py> {
}
}

fn extract_generic_iterable<'a, 'py>(obj: &'a Bound<'py, PyAny>) -> ValResult<GenericIterable<'a, 'py>> {
/// Container for all the collections (sized iterable containers) types, which
/// can mostly be converted to each other in lax mode.
/// This mostly matches python's definition of `Collection`.
pub enum PySequenceIterable<'a, 'py> {
List(&'a Bound<'py, PyList>),
Tuple(&'a Bound<'py, PyTuple>),
Set(&'a Bound<'py, PySet>),
FrozenSet(&'a Bound<'py, PyFrozenSet>),
Iterator(Bound<'py, PyIterator>),
}

/// Extract types which can be iterated to produce a sequence-like container like a list, tuple, set
/// or frozenset
fn extract_sequence_iterable<'a, 'py>(obj: &'a Bound<'py, PyAny>) -> ValResult<PySequenceIterable<'a, 'py>> {
// Handle concrete non-overlapping types first, then abstract types
if let Ok(iterable) = obj.downcast::<PyList>() {
Ok(GenericIterable::List(iterable))
Ok(PySequenceIterable::List(iterable))
} else if let Ok(iterable) = obj.downcast::<PyTuple>() {
Ok(GenericIterable::Tuple(iterable))
Ok(PySequenceIterable::Tuple(iterable))
} else if let Ok(iterable) = obj.downcast::<PySet>() {
Ok(GenericIterable::Set(iterable))
Ok(PySequenceIterable::Set(iterable))
} else if let Ok(iterable) = obj.downcast::<PyFrozenSet>() {
Ok(GenericIterable::FrozenSet(iterable))
} else if let Ok(iterable) = obj.downcast::<PyDict>() {
Ok(GenericIterable::Dict(iterable))
} else if let Some(iterable) = extract_dict_keys!(obj.py(), obj) {
Ok(GenericIterable::DictKeys(iterable))
} else if let Some(iterable) = extract_dict_values!(obj.py(), obj) {
Ok(GenericIterable::DictValues(iterable))
} else if let Some(iterable) = extract_dict_items!(obj.py(), obj) {
Ok(GenericIterable::DictItems(iterable))
} else if let Ok(iterable) = obj.downcast::<PyMapping>() {
Ok(GenericIterable::Mapping(iterable))
} else if let Ok(iterable) = obj.downcast::<PyString>() {
Ok(GenericIterable::PyString(iterable))
} else if let Ok(iterable) = obj.downcast::<PyBytes>() {
Ok(GenericIterable::Bytes(iterable))
} else if let Ok(iterable) = obj.downcast::<PyByteArray>() {
Ok(GenericIterable::PyByteArray(iterable))
} else if let Ok(iterable) = obj.downcast::<PySequence>() {
Ok(GenericIterable::Sequence(iterable))
} else if let Ok(iterable) = obj.iter() {
Ok(GenericIterable::Iterator(iterable))
Ok(PySequenceIterable::FrozenSet(iterable))
} else {
// Try to get this as a generable iterable thing, but exclude string and mapping types
if !(obj.is_instance_of::<PyString>()
|| obj.is_instance_of::<PyBytes>()
|| obj.is_instance_of::<PyByteArray>()
|| obj.is_instance_of::<PyDict>()
|| obj.downcast::<PyMapping>().is_ok())
{
if let Ok(iter) = obj.iter() {
return Ok(PySequenceIterable::Iterator(iter));
}
}

Err(ValError::new(ErrorTypeDefaults::IterableType, obj))
}
}

impl<'py> PySequenceIterable<'_, 'py> {
pub fn generic_len(&self) -> Option<usize> {
match &self {
PySequenceIterable::List(iter) => Some(iter.len()),
PySequenceIterable::Tuple(iter) => Some(iter.len()),
PySequenceIterable::Set(iter) => Some(iter.len()),
PySequenceIterable::FrozenSet(iter) => Some(iter.len()),
PySequenceIterable::Iterator(iter) => iter.len().ok(),
}
}

fn generic_iterate<R>(
self,
consumer: impl ConsumeIterator<PyResult<Bound<'py, PyAny>>, Output = R>,
) -> ValResult<R> {
match self {
PySequenceIterable::List(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))),
PySequenceIterable::Tuple(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))),
PySequenceIterable::Set(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))),
PySequenceIterable::FrozenSet(iter) => Ok(consumer.consume_iterator(iter.iter().map(Ok))),
PySequenceIterable::Iterator(iter) => Ok(consumer.consume_iterator(iter.iter()?)),
}
}
}

impl<'py> ValidatedList<'py> for PySequenceIterable<'_, 'py> {
type Item = Bound<'py, PyAny>;
fn len(&self) -> Option<usize> {
self.generic_len()
}
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R> {
self.generic_iterate(consumer)
}
fn as_py_list(&self) -> Option<&Bound<'py, PyList>> {
match self {
PySequenceIterable::List(iter) => Some(iter),
_ => None,
}
}
}

impl<'py> ValidatedTuple<'py> for PySequenceIterable<'_, 'py> {
type Item = Bound<'py, PyAny>;
fn len(&self) -> Option<usize> {
self.generic_len()
}
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R> {
self.generic_iterate(consumer)
}
}

impl<'py> ValidatedSet<'py> for PySequenceIterable<'_, 'py> {
type Item = Bound<'py, PyAny>;
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Item>, Output = R>) -> ValResult<R> {
self.generic_iterate(consumer)
}
}
2 changes: 1 addition & 1 deletion src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub(crate) use input_abstract::{
pub(crate) use input_string::StringMapping;
pub(crate) use return_enums::{
no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat,
EitherInt, EitherString, GenericIterable, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
};

// Defined here as it's not exported by pyo3
Expand Down
Loading

0 comments on commit e5c5ded

Please sign in to comment.