Skip to content

Commit 4d4975a

Browse files
authored
refactor out common str serialization logic (#1867)
1 parent 1660728 commit 4d4975a

File tree

12 files changed

+106
-108
lines changed

12 files changed

+106
-108
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ float_cmp = "allow"
9292
fn_params_excessive_bools = "allow"
9393
if_not_else = "allow"
9494
match_bool = "allow"
95-
match_same_arms = "allow"
9695
missing_errors_doc = "allow"
9796
missing_panics_doc = "allow"
9897
module_name_repetitions = "allow"

src/errors/types.rs

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ impl ErrorType {
480480
}
481481

482482
pub fn message_template_python(&self) -> &'static str {
483+
#[allow(clippy::match_same_arms)] // much nicer to have the messages explicitly listed
483484
match self {
484485
Self::NoSuchAttribute {..} => "Object has no attribute '{attribute}'",
485486
Self::JsonInvalid {..} => "Invalid JSON: {error}",
@@ -636,12 +637,24 @@ impl ErrorType {
636637
};
637638
match self {
638639
Self::NoSuchAttribute { attribute, .. } => render!(tmpl, attribute),
639-
Self::JsonInvalid { error, .. } => render!(tmpl, error),
640+
Self::JsonInvalid { error, .. }
641+
| Self::GetAttributeError { error, .. }
642+
| Self::IterationError { error, .. }
643+
| Self::DatetimeObjectInvalid { error, .. }
644+
| Self::UrlParsing { error, .. }
645+
| Self::UuidParsing { error, .. } => render!(tmpl, error),
646+
Self::MappingType { error, .. }
647+
| Self::DateParsing { error, .. }
648+
| Self::DateFromDatetimeParsing { error, .. }
649+
| Self::TimeParsing { error, .. }
650+
| Self::DatetimeParsing { error, .. }
651+
| Self::DatetimeFromDateParsing { error, .. }
652+
| Self::TimeDeltaParsing { error, .. }
653+
| Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error),
640654
Self::NeedsPythonObject { method_name, .. } => render!(tmpl, method_name),
641-
Self::GetAttributeError { error, .. } => render!(tmpl, error),
642-
Self::ModelType { class_name, .. } => render!(tmpl, class_name),
643-
Self::DataclassType { class_name, .. } => render!(tmpl, class_name),
644-
Self::DataclassExactType { class_name, .. } => render!(tmpl, class_name),
655+
Self::ModelType { class_name, .. }
656+
| Self::DataclassType { class_name, .. }
657+
| Self::DataclassExactType { class_name, .. } => render!(tmpl, class_name),
645658
Self::GreaterThan { gt, .. } => to_string_render!(tmpl, gt),
646659
Self::GreaterThanEqual { ge, .. } => to_string_render!(tmpl, ge),
647660
Self::LessThan { lt, .. } => to_string_render!(tmpl, lt),
@@ -666,26 +679,18 @@ impl ErrorType {
666679
let actual_length = actual_length.map_or(Cow::Borrowed("more"), |v| Cow::Owned(v.to_string()));
667680
to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural,)
668681
}
669-
Self::IterationError { error, .. } => render!(tmpl, error),
670-
Self::StringTooShort { min_length, .. } => {
682+
Self::StringTooShort { min_length, .. } | Self::BytesTooShort { min_length, .. } => {
671683
let expected_plural = plural_s(*min_length);
672684
to_string_render!(tmpl, min_length, expected_plural)
673685
}
674-
Self::StringTooLong { max_length, .. } => {
686+
Self::StringTooLong { max_length, .. }
687+
| Self::BytesTooLong { max_length, .. }
688+
| Self::UrlTooLong { max_length, .. } => {
675689
let expected_plural = plural_s(*max_length);
676690
to_string_render!(tmpl, max_length, expected_plural)
677691
}
678692
Self::StringPatternMismatch { pattern, .. } => render!(tmpl, pattern),
679693
Self::Enum { expected, .. } => to_string_render!(tmpl, expected),
680-
Self::MappingType { error, .. } => render!(tmpl, error),
681-
Self::BytesTooShort { min_length, .. } => {
682-
let expected_plural = plural_s(*min_length);
683-
to_string_render!(tmpl, min_length, expected_plural)
684-
}
685-
Self::BytesTooLong { max_length, .. } => {
686-
let expected_plural = plural_s(*max_length);
687-
to_string_render!(tmpl, max_length, expected_plural)
688-
}
689694
Self::BytesInvalidEncoding {
690695
encoding,
691696
encoding_error,
@@ -709,33 +714,18 @@ impl ErrorType {
709714
..
710715
} => PydanticCustomError::format_message(message_template, context.as_ref().map(|c| c.bind(py))),
711716
Self::LiteralError { expected, .. } => render!(tmpl, expected),
712-
Self::DateParsing { error, .. } => render!(tmpl, error),
713-
Self::DateFromDatetimeParsing { error, .. } => render!(tmpl, error),
714-
Self::TimeParsing { error, .. } => render!(tmpl, error),
715-
Self::DatetimeParsing { error, .. } => render!(tmpl, error),
716-
Self::DatetimeFromDateParsing { error, .. } => render!(tmpl, error),
717-
Self::DatetimeObjectInvalid { error, .. } => render!(tmpl, error),
718717
Self::TimezoneOffset {
719718
tz_expected, tz_actual, ..
720719
} => to_string_render!(tmpl, tz_expected, tz_actual),
721-
Self::TimeDeltaParsing { error, .. } => render!(tmpl, error),
722-
Self::IsInstanceOf { class, .. } => render!(tmpl, class),
723-
Self::IsSubclassOf { class, .. } => render!(tmpl, class),
720+
Self::IsInstanceOf { class, .. } | Self::IsSubclassOf { class, .. } => render!(tmpl, class),
724721
Self::UnionTagInvalid {
725722
discriminator,
726723
tag,
727724
expected_tags,
728725
..
729726
} => render!(tmpl, discriminator, tag, expected_tags),
730727
Self::UnionTagNotFound { discriminator, .. } => render!(tmpl, discriminator),
731-
Self::UrlParsing { error, .. } => render!(tmpl, error),
732-
Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error),
733-
Self::UrlTooLong { max_length, .. } => {
734-
let expected_plural = plural_s(*max_length);
735-
to_string_render!(tmpl, max_length, expected_plural)
736-
}
737728
Self::UrlScheme { expected_schemes, .. } => render!(tmpl, expected_schemes),
738-
Self::UuidParsing { error, .. } => render!(tmpl, error),
739729
Self::UuidVersion { expected_version, .. } => to_string_render!(tmpl, expected_version),
740730
Self::DecimalMaxDigits { max_digits, .. } => {
741731
let expected_plural = plural_s(*max_digits);

src/input/datetime.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ impl<'py> IntoPyObject<'py> for EitherTimedelta<'py> {
113113
fn into_pyobject(self, py: Python<'py>) -> PyResult<Self::Output> {
114114
match self {
115115
Self::Raw(duration) => duration_as_pytimedelta(py, &duration),
116-
Self::PyExact(py_timedelta) => Ok(py_timedelta),
117-
Self::PySubclass(py_timedelta) => Ok(py_timedelta),
116+
Self::PyExact(py_timedelta) | Self::PySubclass(py_timedelta) => Ok(py_timedelta),
118117
}
119118
}
120119
}

src/serializers/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
2121
/// Wrapper type which allows convenient conversion between `PyErr` and `ser::Error` in `?` expressions.
2222
pub(super) struct WrappedSerError<T: ser::Error>(pub T);
2323

24+
pub fn unwrap_ser_error<T: ser::Error>(wrapped: WrappedSerError<T>) -> T {
25+
wrapped.0
26+
}
27+
2428
impl<T: ser::Error> From<PyErr> for WrappedSerError<T> {
2529
fn from(py_err: PyErr) -> Self {
2630
WrappedSerError(T::custom(py_err.to_string()))

src/serializers/extra.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,8 @@ impl From<Option<&str>> for SerMode {
372372
fn from(s: Option<&str>) -> Self {
373373
match s {
374374
Some("json") => SerMode::Json,
375-
Some("python") => SerMode::Python,
375+
Some("python") | None => SerMode::Python,
376376
Some(other) => SerMode::Other(other.to_string()),
377-
None => SerMode::Python,
378377
}
379378
}
380379
}

src/serializers/infer.rs

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ use std::cell::RefCell;
44
use pyo3::exceptions::PyTypeError;
55
use pyo3::intern;
66
use pyo3::prelude::*;
7-
use pyo3::pybacked::PyBackedStr;
87
use pyo3::types::PyComplex;
98
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PySet, PyString, PyTuple};
109

1110
use pyo3::IntoPyObjectExt;
1211
use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer};
1312

1413
use crate::input::{EitherTimedelta, Int};
14+
use crate::serializers::errors::unwrap_ser_error;
1515
use crate::serializers::shared::serialize_to_json;
1616
use crate::serializers::shared::serialize_to_python;
1717
use crate::serializers::shared::DoSerialize;
1818
use crate::serializers::type_serializers;
19+
use crate::serializers::type_serializers::format::serialize_via_str;
1920
use crate::serializers::SerializationState;
2021
use crate::tools::{extract_int, py_err, safe_repr};
2122
use crate::url::{PyMultiHostUrl, PyUrl};
@@ -167,14 +168,7 @@ pub(crate) fn infer_to_python_known<'py>(
167168
let either_delta = EitherTimedelta::try_from(value)?;
168169
state.config.temporal_mode.timedelta_to_json(value.py(), either_delta)?
169170
}
170-
ObType::Url => {
171-
let py_url: PyUrl = value.extract()?;
172-
py_url.__str__(py).into_py_any(py)?
173-
}
174-
ObType::MultiHostUrl => {
175-
let py_url: PyMultiHostUrl = value.extract()?;
176-
py_url.__str__(py).into_py_any(py)?
177-
}
171+
ObType::Url | ObType::MultiHostUrl | ObType::Path => serialize_via_str(value, serialize_to_python())?,
178172
ObType::Uuid => {
179173
let uuid = super::type_serializers::uuid::uuid_to_string(value)?;
180174
uuid.into_py_any(py)?
@@ -207,8 +201,7 @@ pub(crate) fn infer_to_python_known<'py>(
207201
let complex_str = type_serializers::complex::complex_to_str(v);
208202
complex_str.into_py_any(py)?
209203
}
210-
ObType::Path => value.str()?.into_py_any(py)?,
211-
ObType::Pattern => value.getattr(intern!(py, "pattern"))?.unbind(),
204+
ObType::Pattern => serialize_pattern(value, serialize_to_python())?,
212205
ObType::Unknown => {
213206
if let Some(fallback) = state.extra.fallback {
214207
let next_value = fallback.call1((value,))?;
@@ -377,7 +370,9 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
377370
ObType::Decimal => value.to_string().serialize(serializer),
378371
ObType::Str | ObType::StrSubclass => {
379372
let py_str = value.downcast::<PyString>().map_err(py_err_se_err)?;
380-
super::type_serializers::string::serialize_py_str(py_str, serializer)
373+
serialize_to_json(serializer)
374+
.serialize_str(py_str)
375+
.map_err(unwrap_ser_error)
381376
}
382377
ObType::Bytes => {
383378
let py_bytes = value.downcast::<PyBytes>().map_err(py_err_se_err)?;
@@ -418,16 +413,11 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
418413
let either_delta = EitherTimedelta::try_from(value).map_err(py_err_se_err)?;
419414
state.config.temporal_mode.timedelta_serialize(either_delta, serializer)
420415
}
421-
ObType::Url => {
422-
let py_url: PyUrl = value.extract().map_err(py_err_se_err)?;
423-
serializer.serialize_str(py_url.__str__(value.py()))
424-
}
425-
ObType::MultiHostUrl => {
426-
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
427-
serializer.serialize_str(&py_url.__str__(value.py()))
416+
ObType::Url | ObType::MultiHostUrl | ObType::Path => {
417+
serialize_via_str(value, serialize_to_json(serializer)).map_err(unwrap_ser_error)
428418
}
429419
ObType::PydanticSerializable => {
430-
call_pydantic_serializer(value, state, serialize_to_json(serializer)).map_err(|e| e.0)
420+
call_pydantic_serializer(value, state, serialize_to_json(serializer)).map_err(unwrap_ser_error)
431421
}
432422
ObType::Dataclass => {
433423
let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?;
@@ -456,20 +446,7 @@ pub(crate) fn infer_serialize_known<'py, S: Serializer>(
456446
}
457447
seq.end()
458448
}
459-
ObType::Path => {
460-
let s: PyBackedStr = value
461-
.str()
462-
.and_then(|value_str| value_str.extract())
463-
.map_err(py_err_se_err)?;
464-
serializer.serialize_str(&s)
465-
}
466-
ObType::Pattern => {
467-
let s: PyBackedStr = value
468-
.getattr(intern!(value.py(), "pattern"))
469-
.and_then(|pattern| pattern.str()?.extract())
470-
.map_err(py_err_se_err)?;
471-
serializer.serialize_str(&s)
472-
}
449+
ObType::Pattern => serialize_pattern(value, serialize_to_json(serializer)).map_err(unwrap_ser_error),
473450
ObType::Unknown => {
474451
if let Some(fallback) = state.extra.fallback {
475452
let next_value = fallback.call1((value,)).map_err(py_err_se_err)?;
@@ -497,6 +474,14 @@ fn unknown_type_error(value: &Bound<'_, PyAny>) -> PyErr {
497474
))
498475
}
499476

477+
fn serialize_pattern<'py, T, E: From<PyErr>>(
478+
value: &Bound<'py, PyAny>,
479+
do_serialize: impl DoSerialize<'py, T, E>,
480+
) -> Result<T, E> {
481+
let pattern = value.getattr(intern!(value.py(), "pattern"))?;
482+
serialize_via_str(&pattern, do_serialize)
483+
}
484+
500485
fn serialize_unknown<'py>(value: &Bound<'py, PyAny>) -> Cow<'py, str> {
501486
if let Ok(s) = value.str() {
502487
s.to_string_lossy().into_owned().into()

src/serializers/ob_type.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ pub struct ObTypeLookup {
4848
pattern_object: Py<PyAny>,
4949
// uuid type
5050
uuid_object: Py<PyAny>,
51+
// `complex` builtin
5152
complex: usize,
5253
}
5354

@@ -155,10 +156,10 @@ impl ObTypeLookup {
155156
ObType::Enum => self.enum_object.as_ptr() as usize == ob_type,
156157
ObType::Generator => self.generator_object.as_ptr() as usize == ob_type,
157158
ObType::Path => self.path_object.as_ptr() as usize == ob_type,
158-
ObType::Pattern => self.path_object.as_ptr() as usize == ob_type,
159+
ObType::Pattern => self.pattern_object.as_ptr() as usize == ob_type,
159160
ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type,
160-
ObType::Unknown => false,
161161
ObType::Complex => self.complex == ob_type,
162+
ObType::Unknown => false,
162163
};
163164

164165
if ans {
@@ -414,9 +415,10 @@ pub enum ObType {
414415
Pattern,
415416
// Uuid
416417
Uuid,
418+
// complex builtin
419+
Complex,
417420
// unknown type
418421
Unknown,
419-
Complex,
420422
}
421423

422424
impl PartialEq for ObType {
@@ -427,12 +429,12 @@ impl PartialEq for ObType {
427429
} else {
428430
match (self, other) {
429431
// special cases for subclasses
430-
(Self::IntSubclass, Self::Int) => true,
431-
(Self::Int, Self::IntSubclass) => true,
432-
(Self::FloatSubclass, Self::Float) => true,
433-
(Self::Float, Self::FloatSubclass) => true,
434-
(Self::StrSubclass, Self::Str) => true,
435-
(Self::Str, Self::StrSubclass) => true,
432+
(Self::IntSubclass, Self::Int)
433+
| (Self::Int, Self::IntSubclass)
434+
| (Self::FloatSubclass, Self::Float)
435+
| (Self::Float, Self::FloatSubclass)
436+
| (Self::StrSubclass, Self::Str)
437+
| (Self::Str, Self::StrSubclass) => true,
436438
_ => false,
437439
}
438440
}

src/serializers/shared.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ use std::io::{self, Write};
55
use std::sync::Arc;
66

77
use pyo3::exceptions::PyTypeError;
8-
use pyo3::prelude::*;
98
use pyo3::sync::PyOnceLock;
109
use pyo3::types::{PyDict, PyString};
1110
use pyo3::{intern, PyTraverseError, PyVisit};
11+
use pyo3::{prelude::*, IntoPyObjectExt};
1212

1313
use enum_dispatch::enum_dispatch;
1414
use serde::{Serialize, Serializer};
@@ -199,17 +199,19 @@ impl CombinedSerializer {
199199
)
200200
.map_err(|err| py_schema_error_type!("Error building `function-wrap` serializer:\n {}", err));
201201
}
202-
// applies to lists tuples and dicts, does not override the main schema `type`
203-
Some("include-exclude-sequence" | "include-exclude-dict") => (),
204-
// applies specifically to bytes, does not override the main schema `type`
205-
Some("base64") => (),
202+
Some(
203+
// applies to lists tuples and dicts, does not override the main schema `type`
204+
"include-exclude-sequence" | "include-exclude-dict"
205+
// applies specifically to bytes, does not override the main schema `type`
206+
| "base64"
207+
)
208+
// if `schema.serialization.type` is None, fall back to `schema.type`
209+
| None => (),
206210
Some(ser_type) => {
207211
// otherwise if `schema.serialization.type` is defined, use that with `find_serializer`
208212
// instead of `schema.type`. In this case it's an error if a serializer isn't found.
209213
return Self::find_serializer(ser_type, &ser_schema, config, definitions);
210214
}
211-
// if `schema.serialization.type` is None, fall back to `schema.type`
212-
None => (),
213215
};
214216
}
215217

@@ -625,6 +627,8 @@ pub trait DoSerialize<'py, OutputT, ErrorT> {
625627
value: &Bound<'py, PyAny>,
626628
state: &mut SerializationState<'_, 'py>,
627629
) -> Result<OutputT, ErrorT>;
630+
631+
fn serialize_str(self, value: &Bound<'py, PyString>) -> Result<OutputT, ErrorT>;
628632
}
629633

630634
/// Helper to create a `SerializeToPython` instance
@@ -660,6 +664,10 @@ impl<'py> DoSerialize<'py, Py<PyAny>, PyErr> for SerializeToPython {
660664
state.warn_fallback_py(name, value)?;
661665
infer_to_python(value, state)
662666
}
667+
668+
fn serialize_str(self, value: &Bound<'py, PyString>) -> Result<Py<PyAny>, PyErr> {
669+
value.into_py_any(value.py())
670+
}
663671
}
664672

665673
pub struct SerializeToJson<S> {
@@ -687,4 +695,9 @@ impl<'py, S: Serializer> DoSerialize<'py, S::Ok, WrappedSerError<S::Error>> for
687695
state.warn_fallback_ser::<S>(name, value).map_err(WrappedSerError)?;
688696
infer_serialize(value, self.serializer, state).map_err(WrappedSerError)
689697
}
698+
699+
fn serialize_str(self, value: &Bound<'py, PyString>) -> Result<S::Ok, WrappedSerError<S::Error>> {
700+
let s = value.to_str()?;
701+
self.serializer.serialize_str(s).map_err(WrappedSerError)
702+
}
690703
}

0 commit comments

Comments
 (0)