Skip to content

Commit 0344763

Browse files
davidhewitt and Viicos authored
fix various RootModel serialization issues (#1836)
Co-authored-by: Victorien <65306057+Viicos@users.noreply.github.com>
1 parent d90a93b commit 0344763

File tree

5 files changed

+236
-156
lines changed

5 files changed

+236
-156
lines changed

src/serializers/errors.rs

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,15 @@ pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
1818
T::custom(py_error.to_string())
1919
}
2020

21+
/// Wrapper type which allows convenient conversion between `PyErr` and `ser::Error` in `?` expressions.
22+
pub(super) struct WrappedSerError<T: ser::Error>(pub T);
23+
24+
impl<T: ser::Error> From<PyErr> for WrappedSerError<T> {
25+
fn from(py_err: PyErr) -> Self {
26+
WrappedSerError(T::custom(py_err.to_string()))
27+
}
28+
}
29+
2130
#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
2231
#[derive(Debug, Clone)]
2332
pub struct PythonSerializerError {

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ pub mod with_default;
3232

3333
use super::computed_fields::ComputedFields;
3434
use super::config::utf8_py_error;
35-
use super::errors::{py_err_se_err, PydanticSerializationError};
35+
use super::errors::{py_err_se_err, PydanticSerializationError, WrappedSerError};
3636
use super::extra::{Extra, ExtraOwned, SerCheck, SerMode};
3737
use super::fields::{FieldsMode, GeneralFieldsSerializer, SerField};
3838
use super::filter::{AnyFilter, SchemaFilter};

src/serializers/type_serializers/model.rs

Lines changed: 97 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,16 @@ use ahash::AHashMap;
99
use pyo3::IntoPyObjectExt;
1010

1111
use super::{
12-
infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer,
13-
CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField,
14-
TypeSerializer,
12+
infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer,
13+
ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField, TypeSerializer,
14+
WrappedSerError,
1515
};
1616
use crate::build_tools::py_schema_err;
1717
use crate::build_tools::{py_schema_error_type, ExtraBehavior};
1818
use crate::definitions::DefinitionsBuilder;
19-
use crate::serializers::errors::PydanticSerializationUnexpectedValue;
19+
use crate::serializers::type_serializers::any::AnySerializer;
20+
use crate::serializers::type_serializers::function::FunctionPlainSerializer;
21+
use crate::serializers::type_serializers::function::FunctionWrapSerializer;
2022
use crate::tools::SchemaDict;
2123

2224
const ROOT_FIELD: &str = "root";
@@ -138,17 +140,80 @@ fn has_extra(schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) ->
138140
}
139141

140142
impl ModelSerializer {
141-
fn allow_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> PyResult<bool> {
142-
let class = self.class.bind(value.py());
143-
match extra.check {
144-
SerCheck::Strict => Ok(value.get_type().is(class)),
145-
SerCheck::Lax => value.is_instance(class),
143+
fn allow_value(&self, value: &Bound<'_, PyAny>, check: SerCheck) -> PyResult<bool> {
144+
match check {
145+
SerCheck::Strict => Ok(value.get_type().is(&self.class)),
146+
SerCheck::Lax => value.is_instance(self.class.bind(value.py())),
146147
SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")),
147148
}
148149
}
149150

151+
fn allow_value_root_model(&self, value: &Bound<'_, PyAny>, check: SerCheck) -> PyResult<bool> {
152+
match check {
153+
SerCheck::Strict => Ok(value.get_type().is(&self.class)),
154+
SerCheck::Lax | SerCheck::None => value.is_instance(self.class.bind(value.py())),
155+
}
156+
}
157+
158+
/// Performs serialization for the model. This handles
159+
/// - compatibility checks
160+
/// - extracting the inner value for root models
161+
/// - applying `serialize_as_any` where needed
162+
///
163+
/// `do_serialize` should be a function which performs the actual serialization, and should not
164+
/// apply any type inference. (`Model` serialization is strongly coupled with its child
165+
/// serializer, and in the few cases where `serialize_as_any` applies, it is handled here.)
166+
///
167+
/// If the value is not applicable, `do_serialize` will be called with `None` to indicate fallback
168+
/// behaviour should be used.
169+
fn serialize<T, E: From<PyErr>>(
170+
&self,
171+
value: &Bound<'_, PyAny>,
172+
extra: &Extra,
173+
do_serialize: impl FnOnce(Option<(&Arc<CombinedSerializer>, &Bound<'_, PyAny>, &Extra)>) -> Result<T, E>,
174+
) -> Result<T, E> {
175+
match self.root_model {
176+
true if self.allow_value_root_model(value, extra.check)? => {
177+
let root_extra = Extra {
178+
field_name: Some(ROOT_FIELD),
179+
model: Some(value),
180+
..extra.clone()
181+
};
182+
let root = value.getattr(intern!(value.py(), ROOT_FIELD))?;
183+
184+
// for root models, `serialize_as_any` may apply unless a `field_serializer` is used
185+
let serializer = if root_extra.serialize_as_any
186+
&& !matches!(
187+
self.serializer.as_ref(),
188+
CombinedSerializer::Function(FunctionPlainSerializer {
189+
is_field_serializer: true,
190+
..
191+
}) | CombinedSerializer::FunctionWrap(FunctionWrapSerializer {
192+
is_field_serializer: true,
193+
..
194+
}),
195+
) {
196+
AnySerializer::get()
197+
} else {
198+
&self.serializer
199+
};
200+
201+
do_serialize(Some((serializer, &root, &root_extra)))
202+
}
203+
false if self.allow_value(value, extra.check)? => {
204+
let model_extra = Extra {
205+
model: Some(value),
206+
..extra.clone()
207+
};
208+
let inner_value = self.get_inner_value(value, &model_extra)?;
209+
do_serialize(Some((&self.serializer, &inner_value, &model_extra)))
210+
}
211+
_ => do_serialize(None),
212+
}
213+
}
214+
150215
fn get_inner_value<'py>(&self, model: &Bound<'py, PyAny>, extra: &Extra) -> PyResult<Bound<'py, PyAny>> {
151-
let py = model.py();
216+
let py: Python<'_> = model.py();
152217
let mut attrs = model.getattr(intern!(py, "__dict__"))?.downcast_into::<PyDict>()?;
153218

154219
if extra.exclude_unset {
@@ -184,38 +249,18 @@ impl TypeSerializer for ModelSerializer {
184249
exclude: Option<&Bound<'_, PyAny>>,
185250
extra: &Extra,
186251
) -> PyResult<Py<PyAny>> {
187-
let model = Some(value);
188-
189-
let model_extra = Extra { model, ..*extra };
190-
if self.root_model {
191-
let field_name = Some(ROOT_FIELD);
192-
let root_extra = Extra {
193-
field_name,
194-
..model_extra
195-
};
196-
let py = value.py();
197-
let root = value.getattr(intern!(py, ROOT_FIELD)).map_err(|original_err| {
198-
if root_extra.check.enabled() {
199-
PydanticSerializationUnexpectedValue::new_from_msg(None).to_py_err()
200-
} else {
201-
original_err
202-
}
203-
})?;
204-
self.serializer.to_python(&root, include, exclude, &root_extra)
205-
} else if self.allow_value(value, &model_extra)? {
206-
let inner_value = self.get_inner_value(value, &model_extra)?;
207-
// There is strong coupling between a model serializer and its child, we should
208-
// not fall back to type inference in the middle.
209-
self.serializer
210-
.to_python_no_infer(&inner_value, include, exclude, &model_extra)
211-
} else {
212-
extra.warnings.on_fallback_py(self.get_name(), value, &model_extra)?;
213-
infer_to_python(value, include, exclude, &model_extra)
214-
}
252+
self.serialize(value, extra, |resolved| match resolved {
253+
Some((serializer, value, extra)) => serializer.to_python_no_infer(value, include, exclude, extra),
254+
None => {
255+
extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
256+
infer_to_python(value, include, exclude, extra)
257+
}
258+
})
215259
}
216260

217261
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
218-
if self.allow_value(key, extra)? {
262+
// FIXME: root model in json key position should serialize as inner value?
263+
if self.allow_value(key, extra.check)? {
219264
infer_json_key_known(ObType::PydanticSerializable, key, extra)
220265
} else {
221266
extra.warnings.on_fallback_py(&self.name, key, extra)?;
@@ -231,30 +276,19 @@ impl TypeSerializer for ModelSerializer {
231276
exclude: Option<&Bound<'_, PyAny>>,
232277
extra: &Extra,
233278
) -> Result<S::Ok, S::Error> {
234-
let model = Some(value);
235-
let model_extra = Extra { model, ..*extra };
236-
if self.root_model {
237-
let field_name = Some(ROOT_FIELD);
238-
let root_extra = Extra {
239-
field_name,
240-
..model_extra
241-
};
242-
let py = value.py();
243-
let root = value.getattr(intern!(py, ROOT_FIELD)).map_err(py_err_se_err)?;
244-
self.serializer
245-
.serde_serialize(&root, serializer, include, exclude, &root_extra)
246-
} else if self.allow_value(value, &model_extra).map_err(py_err_se_err)? {
247-
let inner_value = self.get_inner_value(value, &model_extra).map_err(py_err_se_err)?;
248-
// There is strong coupling between a model serializer and its child, we should
249-
// not fall back to type inference in the midddle.
250-
self.serializer
251-
.serde_serialize_no_infer(&inner_value, serializer, include, exclude, &model_extra)
252-
} else {
253-
extra
254-
.warnings
255-
.on_fallback_ser::<S>(self.get_name(), value, &model_extra)?;
256-
infer_serialize(value, serializer, include, exclude, &model_extra)
257-
}
279+
self.serialize(value, extra, |resolved| match resolved {
280+
Some((cs, value, extra)) => cs
281+
.serde_serialize_no_infer(value, serializer, include, exclude, extra)
282+
.map_err(WrappedSerError),
283+
None => {
284+
extra
285+
.warnings
286+
.on_fallback_ser::<S>(self.get_name(), value, extra)
287+
.map_err(WrappedSerError)?;
288+
infer_serialize(value, serializer, include, exclude, extra).map_err(WrappedSerError)
289+
}
290+
})
291+
.map_err(|e| e.0)
258292
}
259293

260294
fn get_name(&self) -> &str {

tests/serializers/test_model_root.py

Lines changed: 32 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,11 @@
11
import json
2+
import os
23
import platform
4+
from pathlib import Path
35
from typing import Any, Union
46

57
import pytest
68

7-
try:
8-
from functools import cached_property
9-
except ImportError:
10-
cached_property = None
11-
12-
139
from pydantic_core import SchemaSerializer, core_schema
1410

1511
from ..conftest import plain_repr
@@ -181,27 +177,41 @@ class Model(RootModel):
181177
assert s.to_json(m) == b'[1,2,{"value":"abc"}]'
182178

183179

184-
def test_construct_nested():
185-
class RModel(RootModel):
180+
def test_not_root_model():
181+
# https://github.com/pydantic/pydantic/issues/8963
182+
183+
class RootModel:
186184
root: int
187185

188-
class BModel(BaseModel):
189-
value: RModel
186+
v = RootModel()
187+
v.root = '123'
190188

191189
s = SchemaSerializer(
192190
core_schema.model_schema(
193-
BModel,
194-
core_schema.model_fields_schema(
195-
{
196-
'value': core_schema.model_field(
197-
core_schema.model_schema(RModel, core_schema.int_schema(), root_model=True)
198-
)
199-
}
200-
),
201-
)
191+
RootModel,
192+
core_schema.str_schema(),
193+
root_model=True,
194+
),
202195
)
203196

204-
m = BModel(value=42)
197+
assert s.to_python(v) == '123'
198+
assert s.to_json(v) == b'"123"'
199+
200+
# Path is chosen because it has a .root property
201+
# which could look like a root model in bad implementations
202+
203+
if os.name == 'nt':
204+
path_value = Path('C:\\a\\b')
205+
path_bytes = b'"C:\\\\a\\\\b"' # fixme double escaping?
206+
else:
207+
path_value = Path('/a/b')
208+
path_bytes = b'"/a/b"'
209+
210+
with pytest.warns(UserWarning, match=r'PydanticSerializationUnexpectedValue\(Expected `RootModel`'):
211+
assert s.to_python(path_value) == path_value
212+
213+
with pytest.warns(UserWarning, match=r'PydanticSerializationUnexpectedValue\(Expected `RootModel`'):
214+
assert s.to_json(path_value) == path_bytes
205215

206-
with pytest.raises(AttributeError, match="'int' object has no attribute 'root'"):
207-
s.to_python(m)
216+
assert s.to_python(path_value, warnings=False) == path_value
217+
assert s.to_json(path_value, warnings=False) == path_bytes

0 commit comments

Comments
 (0)