Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ float_cmp = "allow"
fn_params_excessive_bools = "allow"
if_not_else = "allow"
match_bool = "allow"
match_same_arms = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "allow"
module_name_repetitions = "allow"
Expand Down
56 changes: 23 additions & 33 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ impl ErrorType {
}

pub fn message_template_python(&self) -> &'static str {
#[allow(clippy::match_same_arms)] // much nicer to have the messages explicitly listed
match self {
Self::NoSuchAttribute {..} => "Object has no attribute '{attribute}'",
Self::JsonInvalid {..} => "Invalid JSON: {error}",
Expand Down Expand Up @@ -636,12 +637,24 @@ impl ErrorType {
};
match self {
Self::NoSuchAttribute { attribute, .. } => render!(tmpl, attribute),
Self::JsonInvalid { error, .. } => render!(tmpl, error),
Self::JsonInvalid { error, .. }
| Self::GetAttributeError { error, .. }
| Self::IterationError { error, .. }
| Self::DatetimeObjectInvalid { error, .. }
| Self::UrlParsing { error, .. }
| Self::UuidParsing { error, .. } => render!(tmpl, error),
Self::MappingType { error, .. }
| Self::DateParsing { error, .. }
| Self::DateFromDatetimeParsing { error, .. }
| Self::TimeParsing { error, .. }
| Self::DatetimeParsing { error, .. }
| Self::DatetimeFromDateParsing { error, .. }
| Self::TimeDeltaParsing { error, .. }
| Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error),
Self::NeedsPythonObject { method_name, .. } => render!(tmpl, method_name),
Self::GetAttributeError { error, .. } => render!(tmpl, error),
Self::ModelType { class_name, .. } => render!(tmpl, class_name),
Self::DataclassType { class_name, .. } => render!(tmpl, class_name),
Self::DataclassExactType { class_name, .. } => render!(tmpl, class_name),
Self::ModelType { class_name, .. }
| Self::DataclassType { class_name, .. }
| Self::DataclassExactType { class_name, .. } => render!(tmpl, class_name),
Self::GreaterThan { gt, .. } => to_string_render!(tmpl, gt),
Self::GreaterThanEqual { ge, .. } => to_string_render!(tmpl, ge),
Self::LessThan { lt, .. } => to_string_render!(tmpl, lt),
Expand All @@ -666,26 +679,18 @@ impl ErrorType {
let actual_length = actual_length.map_or(Cow::Borrowed("more"), |v| Cow::Owned(v.to_string()));
to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural,)
}
Self::IterationError { error, .. } => render!(tmpl, error),
Self::StringTooShort { min_length, .. } => {
Self::StringTooShort { min_length, .. } | Self::BytesTooShort { min_length, .. } => {
let expected_plural = plural_s(*min_length);
to_string_render!(tmpl, min_length, expected_plural)
}
Self::StringTooLong { max_length, .. } => {
Self::StringTooLong { max_length, .. }
| Self::BytesTooLong { max_length, .. }
| Self::UrlTooLong { max_length, .. } => {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::StringPatternMismatch { pattern, .. } => render!(tmpl, pattern),
Self::Enum { expected, .. } => to_string_render!(tmpl, expected),
Self::MappingType { error, .. } => render!(tmpl, error),
Self::BytesTooShort { min_length, .. } => {
let expected_plural = plural_s(*min_length);
to_string_render!(tmpl, min_length, expected_plural)
}
Self::BytesTooLong { max_length, .. } => {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::BytesInvalidEncoding {
encoding,
encoding_error,
Expand All @@ -709,33 +714,18 @@ impl ErrorType {
..
} => PydanticCustomError::format_message(message_template, context.as_ref().map(|c| c.bind(py))),
Self::LiteralError { expected, .. } => render!(tmpl, expected),
Self::DateParsing { error, .. } => render!(tmpl, error),
Self::DateFromDatetimeParsing { error, .. } => render!(tmpl, error),
Self::TimeParsing { error, .. } => render!(tmpl, error),
Self::DatetimeParsing { error, .. } => render!(tmpl, error),
Self::DatetimeFromDateParsing { error, .. } => render!(tmpl, error),
Self::DatetimeObjectInvalid { error, .. } => render!(tmpl, error),
Self::TimezoneOffset {
tz_expected, tz_actual, ..
} => to_string_render!(tmpl, tz_expected, tz_actual),
Self::TimeDeltaParsing { error, .. } => render!(tmpl, error),
Self::IsInstanceOf { class, .. } => render!(tmpl, class),
Self::IsSubclassOf { class, .. } => render!(tmpl, class),
Self::IsInstanceOf { class, .. } | Self::IsSubclassOf { class, .. } => render!(tmpl, class),
Self::UnionTagInvalid {
discriminator,
tag,
expected_tags,
..
} => render!(tmpl, discriminator, tag, expected_tags),
Self::UnionTagNotFound { discriminator, .. } => render!(tmpl, discriminator),
Self::UrlParsing { error, .. } => render!(tmpl, error),
Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error),
Self::UrlTooLong { max_length, .. } => {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::UrlScheme { expected_schemes, .. } => render!(tmpl, expected_schemes),
Self::UuidParsing { error, .. } => render!(tmpl, error),
Self::UuidVersion { expected_version, .. } => to_string_render!(tmpl, expected_version),
Self::DecimalMaxDigits { max_digits, .. } => {
let expected_plural = plural_s(*max_digits);
Expand Down
3 changes: 1 addition & 2 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ impl<'py> IntoPyObject<'py> for EitherTimedelta<'py> {
fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
match self {
Self::Raw(duration) => duration_as_pytimedelta(py, &duration),
Self::PyExact(py_timedelta) => Ok(py_timedelta),
Self::PySubclass(py_timedelta) => Ok(py_timedelta),
Self::PyExact(py_timedelta) | Self::PySubclass(py_timedelta) => Ok(py_timedelta),
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/serializers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
/// Wrapper type which allows convenient conversion between `PyErr` and `ser::Error` in `?` expressions.
pub(super) struct WrappedSerError<T: ser::Error>(pub T);

pub fn unwrap_ser_error<T: ser::Error>(wrapped: WrappedSerError<T>) -> T {
wrapped.0
}

impl<T: ser::Error> From<PyErr> for WrappedSerError<T> {
fn from(py_err: PyErr) -> Self {
WrappedSerError(T::custom(py_err.to_string()))
Expand Down
3 changes: 1 addition & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,8 @@ impl From<Option<&str>> for SerMode {
fn from(s: Option<&str>) -> Self {
match s {
Some("json") => SerMode::Json,
Some("python") => SerMode::Python,
Some("python") | None => SerMode::Python,
Some(other) => SerMode::Other(other.to_string()),
None => SerMode::Python,
}
}
}
Expand Down
53 changes: 19 additions & 34 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ use std::cell::RefCell;
use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyComplex;
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};

use pyo3::IntoPyObjectExt;
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};

use crate::input::{EitherTimedelta, Int};
use crate::serializers::errors::unwrap_ser_error;
use crate::serializers::shared::serialize_to_json;
use crate::serializers::shared::serialize_to_python;
use crate::serializers::shared::DoSerialize;
use crate::serializers::type_serializers;
use crate::serializers::type_serializers::format::serialize_via_str;
use crate::serializers::SerializationState;
use crate::tools::{extract_int, py_err, safe_repr};
use crate::url::{PyMultiHostUrl, PyUrl};
Expand Down Expand Up @@ -167,14 +168,7 @@ pub(crate) fn infer_to_python_known<'py>(
let either_delta = EitherTimedelta::try_from(value)?;
state.config.temporal_mode.timedelta_to_json(value.py(), either_delta)?
}
ObType::Url => {
let py_url: PyUrl = value.extract()?;
py_url.__str__(py).into_py_any(py)?
}
ObType::MultiHostUrl => {
let py_url: PyMultiHostUrl = value.extract()?;
py_url.__str__(py).into_py_any(py)?
}
ObType::Url | ObType::MultiHostUrl | ObType::Path => serialize_via_str(value, serialize_to_python())?,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. IPAddress pathway will be added here

ObType::Uuid => {
let uuid = super::type_serializers::uuid::uuid_to_string(value)?;
uuid.into_py_any(py)?
Expand Down Expand Up @@ -207,8 +201,7 @@ pub(crate) fn infer_to_python_known<'py>(
let complex_str = type_serializers::complex::complex_to_str(v);
complex_str.into_py_any(py)?
}
ObType::Path => value.str()?.into_py_any(py)?,
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.unbind(),
ObType::Pattern => serialize_pattern(value, serialize_to_python())?,
ObType::Unknown => {
if let Some(fallback) = state.extra.fallback {
let next_value = fallback.call1((value,))?;
Expand Down Expand Up @@ -377,7 +370,9 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
ObType::Decimal => value.to_string().serialize(serializer),
ObType::Str | ObType::StrSubclass => {
let py_str = value.downcast::<PyString>().map_err(py_err_se_err)?;
super::type_serializers::string::serialize_py_str(py_str, serializer)
serialize_to_json(serializer)
.serialize_str(py_str)
.map_err(unwrap_ser_error)
}
ObType::Bytes => {
let py_bytes = value.downcast::<PyBytes>().map_err(py_err_se_err)?;
Expand Down Expand Up @@ -418,16 +413,11 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
let either_delta = EitherTimedelta::try_from(value).map_err(py_err_se_err)?;
state.config.temporal_mode.timedelta_serialize(either_delta, serializer)
}
ObType::Url => {
let py_url: PyUrl = value.extract().map_err(py_err_se_err)?;
serializer.serialize_str(py_url.__str__(value.py()))
}
ObType::MultiHostUrl => {
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
serializer.serialize_str(&py_url.__str__(value.py()))
ObType::Url | ObType::MultiHostUrl | ObType::Path => {
serialize_via_str(value, serialize_to_json(serializer)).map_err(unwrap_ser_error)
}
ObType::PydanticSerializable => {
call_pydantic_serializer(value, state, serialize_to_json(serializer)).map_err(|e| e.0)
call_pydantic_serializer(value, state, serialize_to_json(serializer)).map_err(unwrap_ser_error)
}
ObType::Dataclass => {
let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?;
Expand Down Expand Up @@ -456,20 +446,7 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
}
seq.end()
}
ObType::Path => {
let s: PyBackedStr = value
.str()
.and_then(|value_str| value_str.extract())
.map_err(py_err_se_err)?;
serializer.serialize_str(&s)
}
ObType::Pattern => {
let s: PyBackedStr = value
.getattr(intern!(value.py(), "pattern"))
.and_then(|pattern| pattern.str()?.extract())
.map_err(py_err_se_err)?;
serializer.serialize_str(&s)
}
ObType::Pattern => serialize_pattern(value, serialize_to_json(serializer)).map_err(unwrap_ser_error),
ObType::Unknown => {
if let Some(fallback) = state.extra.fallback {
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
Expand Down Expand Up @@ -497,6 +474,14 @@ fn unknown_type_error(value: &Bound<'_, PyAny>) -> PyErr {
))
}

fn serialize_pattern<'py, T, E: From<PyErr>>(
value: &Bound<'py, PyAny>,
do_serialize: impl DoSerialize<'py, T, E>,
) -> Result<T, E> {
let pattern = value.getattr(intern!(value.py(), "pattern"))?;
serialize_via_str(&pattern, do_serialize)
}

fn serialize_unknown<'py>(value: &Bound<'py, PyAny>) -> Cow<'py, str> {
if let Ok(s) = value.str() {
s.to_string_lossy().into_owned().into()
Expand Down
20 changes: 11 additions & 9 deletions src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct ObTypeLookup {
pattern_object: Py<PyAny>,
// uuid type
uuid_object: Py<PyAny>,
// `complex` builtin
complex: usize,
}

Expand Down Expand Up @@ -155,10 +156,10 @@ impl ObTypeLookup {
ObType::Enum => self.enum_object.as_ptr() as usize == ob_type,
ObType::Generator => self.generator_object.as_ptr() as usize == ob_type,
ObType::Path => self.path_object.as_ptr() as usize == ob_type,
ObType::Pattern => self.path_object.as_ptr() as usize == ob_type,
ObType::Pattern => self.pattern_object.as_ptr() as usize == ob_type,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clippy caught a legitimate typo here, I'm not sure if it's actually observable in practice though.

ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
ObType::Unknown => false,
ObType::Complex => self.complex == ob_type,
ObType::Unknown => false,
};

if ans {
Expand Down Expand Up @@ -414,9 +415,10 @@ pub enum ObType {
Pattern,
// Uuid
Uuid,
// complex builtin
Complex,
// unknown type
Unknown,
Complex,
}

impl PartialEq for ObType {
Expand All @@ -427,12 +429,12 @@ impl PartialEq for ObType {
} else {
match (self, other) {
// special cases for subclasses
(Self::IntSubclass, Self::Int) => true,
(Self::Int, Self::IntSubclass) => true,
(Self::FloatSubclass, Self::Float) => true,
(Self::Float, Self::FloatSubclass) => true,
(Self::StrSubclass, Self::Str) => true,
(Self::Str, Self::StrSubclass) => true,
(Self::IntSubclass, Self::Int)
| (Self::Int, Self::IntSubclass)
| (Self::FloatSubclass, Self::Float)
| (Self::Float, Self::FloatSubclass)
| (Self::StrSubclass, Self::Str)
| (Self::Str, Self::StrSubclass) => true,
_ => false,
}
}
Expand Down
27 changes: 20 additions & 7 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use std::io::{self, Write};
use std::sync::Arc;

use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::sync::PyOnceLock;
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, PyTraverseError, PyVisit};
use pyo3::{prelude::*, IntoPyObjectExt};

use enum_dispatch::enum_dispatch;
use serde::{Serialize, Serializer};
Expand Down Expand Up @@ -199,17 +199,19 @@ impl CombinedSerializer {
)
.map_err(|err| py_schema_error_type!("Error building `function-wrap` serializer:\n {}", err));
}
// applies to lists tuples and dicts, does not override the main schema `type`
Some("include-exclude-sequence" | "include-exclude-dict") => (),
// applies specifically to bytes, does not override the main schema `type`
Some("base64") => (),
Some(
// applies to lists tuples and dicts, does not override the main schema `type`
"include-exclude-sequence" | "include-exclude-dict"
// applies specifically to bytes, does not override the main schema `type`
| "base64"
)
// if `schema.serialization.type` is None, fall back to `schema.type`
| None => (),
Some(ser_type) => {
// otherwise if `schema.serialization.type` is defined, use that with `find_serializer`
// instead of `schema.type`. In this case it's an error if a serializer isn't found.
return Self::find_serializer(ser_type, &ser_schema, config, definitions);
}
// if `schema.serialization.type` is None, fall back to `schema.type`
None => (),
};
}

Expand Down Expand Up @@ -625,6 +627,8 @@ pub trait DoSerialize<'py, OutputT, ErrorT> {
value: &Bound<'py, PyAny>,
state: &mut SerializationState<'_, 'py>,
) -> Result<OutputT, ErrorT>;

fn serialize_str(self, value: &Bound<'py, PyString>) -> Result<OutputT, ErrorT>;
}

/// Helper to create a `SerializeToPython` instance
Expand Down Expand Up @@ -660,6 +664,10 @@ impl<'py> DoSerialize<'py, Py<PyAny>, PyErr> for SerializeToPython {
state.warn_fallback_py(name, value)?;
infer_to_python(value, state)
}

fn serialize_str(self, value: &Bound<'py, PyString>) -> Result<Py<PyAny>, PyErr> {
value.into_py_any(value.py())
}
}

pub struct SerializeToJson<S> {
Expand Down Expand Up @@ -687,4 +695,9 @@ impl<'py, S: Serializer> DoSerialize<'py, S::Ok, WrappedSerError<S::Error>> for
state.warn_fallback_ser::<S>(name, value).map_err(WrappedSerError)?;
infer_serialize(value, self.serializer, state).map_err(WrappedSerError)
}

fn serialize_str(self, value: &Bound<'py, PyString>) -> Result<S::Ok, WrappedSerError<S::Error>> {
let s = value.to_str()?;
self.serializer.serialize_str(s).map_err(WrappedSerError)
}
}
Loading
Loading