Skip to content

Commit

Permalink
Use more explicit warning regarding serialization warning for missing…
Browse files Browse the repository at this point in the history
… fields (#1415)
  • Loading branch information
sydney-runkle authored Aug 22, 2024
1 parent 4113638 commit f4a0675
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 85 deletions.
46 changes: 0 additions & 46 deletions src/errors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use core::fmt;
use std::borrow::Cow;

use pyo3::prelude::*;

mod line_error;
Expand Down Expand Up @@ -33,46 +30,3 @@ pub fn py_err_string(py: Python, err: PyErr) -> String {
Err(_) => "Unknown Error".to_string(),
}
}

// TODO: is_utf8_char_boundary, floor_char_boundary and ceil_char_boundary
// with builtin methods once https://github.com/rust-lang/rust/issues/93743 is resolved
// These are just copy pasted from the current implementation
const fn is_utf8_char_boundary(value: u8) -> bool {
// This is bit magic equivalent to: b < 128 || b >= 192
(value as i8) >= -0x40
}

pub fn floor_char_boundary(value: &str, index: usize) -> usize {
if index >= value.len() {
value.len()
} else {
let lower_bound = index.saturating_sub(3);
let new_index = value.as_bytes()[lower_bound..=index]
.iter()
.rposition(|b| is_utf8_char_boundary(*b));

// SAFETY: we know that the character boundary will be within four bytes
unsafe { lower_bound + new_index.unwrap_unchecked() }
}
}

pub fn ceil_char_boundary(value: &str, index: usize) -> usize {
let upper_bound = Ord::min(index + 4, value.len());
value.as_bytes()[index..upper_bound]
.iter()
.position(|b| is_utf8_char_boundary(*b))
.map_or(upper_bound, |pos| pos + index)
}

pub fn write_truncated_to_50_bytes<F: fmt::Write>(f: &mut F, val: Cow<'_, str>) -> std::fmt::Result {
if val.len() > 50 {
write!(
f,
"{}...{}",
&val[0..floor_char_boundary(&val, 25)],
&val[ceil_char_boundary(&val, val.len() - 24)..]
)
} else {
write!(f, "{val}")
}
}
4 changes: 2 additions & 2 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::errors::LocItem;
use crate::get_pydantic_version;
use crate::input::InputType;
use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState};
use crate::tools::{safe_repr, SchemaDict};
use crate::tools::{safe_repr, write_truncated_to_limited_bytes, SchemaDict};

use super::line_error::ValLineError;
use super::location::Location;
Expand Down Expand Up @@ -526,7 +526,7 @@ impl PyLineError {
let input_value = self.input_value.bind(py);
let input_str = safe_repr(input_value);
write!(output, ", input_value=")?;
super::write_truncated_to_50_bytes(&mut output, input_str.to_cow())?;
write_truncated_to_limited_bytes(&mut output, &input_str.to_string(), 50)?;

if let Ok(type_) = input_value.get_type().qualname() {
write!(output, ", input_type={type_}")?;
Expand Down
11 changes: 3 additions & 8 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::recursion_guard::ContainsRecursionState;
use crate::recursion_guard::RecursionError;
use crate::recursion_guard::RecursionGuard;
use crate::recursion_guard::RecursionState;
use crate::tools::safe_repr;
use crate::tools::truncate_safe_repr;
use crate::PydanticSerializationError;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -426,15 +426,10 @@ impl CollectWarnings {
.qualname()
.unwrap_or_else(|_| PyString::new_bound(value.py(), "<unknown python object>"));

let input_str = safe_repr(value);
let mut value_str = String::with_capacity(100);
value_str.push_str("with value `");
crate::errors::write_truncated_to_50_bytes(&mut value_str, input_str.to_cow())
.expect("Writing to a `String` failed");
value_str.push('`');
let value_str = truncate_safe_repr(value, None);

self.add_warning(format!(
"Expected `{field_type}` but got `{type_name}` {value_str} - serialized value may not be as expected"
"Expected `{field_type}` but got `{type_name}` with value `{value_str}` - serialized value may not be as expected"
));
}
}
Expand Down
20 changes: 19 additions & 1 deletion src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use smallvec::SmallVec;

use crate::serializers::extra::SerCheck;
use crate::serializers::DuckTypingSerMode;
use crate::tools::truncate_safe_repr;
use crate::PydanticSerializationUnexpectedValue;

use super::computed_fields::ComputedFields;
Expand Down Expand Up @@ -210,7 +211,24 @@ impl GeneralFieldsSerializer {
// Check for missing fields, we can't have extra fields here
&& self.required_fields > used_req_fields
{
Err(PydanticSerializationUnexpectedValue::new_err(None))
let required_fields = self.required_fields;
let type_name = match extra.model {
Some(model) => model
.get_type()
.qualname()
.ok()
.unwrap_or_else(|| PyString::new_bound(py, "<unknown python object>"))
.to_string(),
None => "<unknown python object>".to_string(),
};
let field_value = match extra.model {
Some(model) => truncate_safe_repr(model, Some(100)),
None => "<unknown python object>".to_string(),
};

Err(PydanticSerializationUnexpectedValue::new_err(Some(format!(
"Expected {required_fields} fields but got {used_req_fields} for type `{type_name}` with value `{field_value}` - serialized value may not be as expected."
))))
} else {
Ok(output_dict)
}
Expand Down
3 changes: 1 addition & 2 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ impl TypeSerializer for ModelSerializer {
) -> PyResult<PyObject> {
let model = Some(value);
let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode();

let model_extra = Extra {
model,
field_name: None,
duck_typing_ser_mode,
..*extra
};
Expand Down Expand Up @@ -221,7 +221,6 @@ impl TypeSerializer for ModelSerializer {
let duck_typing_ser_mode = extra.duck_typing_ser_mode.next_mode();
let model_extra = Extra {
model,
field_name: None,
duck_typing_ser_mode,
..*extra
};
Expand Down
12 changes: 3 additions & 9 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use std::borrow::Cow;
use crate::build_tools::py_schema_err;
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
use crate::definitions::DefinitionsBuilder;
use crate::errors::write_truncated_to_50_bytes;
use crate::lookup_key::LookupKey;
use crate::serializers::type_serializers::py_err_se_err;
use crate::tools::{safe_repr, SchemaDict};
use crate::tools::{truncate_safe_repr, SchemaDict};
use crate::PydanticSerializationUnexpectedValue;

use super::{
Expand Down Expand Up @@ -446,15 +445,10 @@ impl TaggedUnionSerializer {
Discriminator::Function(func) => func.call1(py, (value,)).ok(),
};
if discriminator_value.is_none() {
let input_str = safe_repr(value);
let mut value_str = String::with_capacity(100);
value_str.push_str("with value `");
write_truncated_to_50_bytes(&mut value_str, input_str.to_cow()).expect("Writing to a `String` failed");
value_str.push('`');

let value_str = truncate_safe_repr(value, None);
extra.warnings.custom_warning(
format!(
"Failed to get discriminator value for tagged union serialization {value_str} - defaulting to left to right union serialization."
"Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization."
)
);
}
Expand Down
64 changes: 54 additions & 10 deletions src/tools.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use core::fmt;

use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -96,15 +96,6 @@ pub enum ReprOutput<'py> {
Fallback(String),
}

impl ReprOutput<'_> {
pub fn to_cow(&self) -> Cow<'_, str> {
match self {
ReprOutput::Python(s) => s.to_string_lossy(),
ReprOutput::Fallback(s) => s.into(),
}
}
}

impl std::fmt::Display for ReprOutput<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -124,6 +115,15 @@ pub fn safe_repr<'py>(v: &Bound<'py, PyAny>) -> ReprOutput<'py> {
}
}

pub fn truncate_safe_repr(v: &Bound<'_, PyAny>, max_len: Option<usize>) -> String {
let max_len = max_len.unwrap_or(50); // default to 100 bytes
let input_str = safe_repr(v);
let mut limited_str = String::with_capacity(max_len);
write_truncated_to_limited_bytes(&mut limited_str, &input_str.to_string(), max_len)
.expect("Writing to a `String` failed");
limited_str
}

pub fn extract_i64(v: &Bound<'_, PyAny>) -> Option<i64> {
#[cfg(PyPy)]
if !v.is_instance_of::<pyo3::types::PyInt>() {
Expand All @@ -146,3 +146,47 @@ pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCach
pystring_fast_new(py, s, ascii_only)
}
}

// TODO: is_utf8_char_boundary, floor_char_boundary and ceil_char_boundary
// with builtin methods once https://github.com/rust-lang/rust/issues/93743 is resolved
// These are just copy pasted from the current implementation
const fn is_utf8_char_boundary(value: u8) -> bool {
// This is bit magic equivalent to: b < 128 || b >= 192
(value as i8) >= -0x40
}

pub fn floor_char_boundary(value: &str, index: usize) -> usize {
if index >= value.len() {
value.len()
} else {
let lower_bound = index.saturating_sub(3);
let new_index = value.as_bytes()[lower_bound..=index]
.iter()
.rposition(|b| is_utf8_char_boundary(*b));

// SAFETY: we know that the character boundary will be within four bytes
unsafe { lower_bound + new_index.unwrap_unchecked() }
}
}

pub fn ceil_char_boundary(value: &str, index: usize) -> usize {
let upper_bound = Ord::min(index + 4, value.len());
value.as_bytes()[index..upper_bound]
.iter()
.position(|b| is_utf8_char_boundary(*b))
.map_or(upper_bound, |pos| pos + index)
}

pub fn write_truncated_to_limited_bytes<F: fmt::Write>(f: &mut F, val: &str, max_len: usize) -> std::fmt::Result {
if val.len() > max_len {
let mid_point = max_len.div_ceil(2);
write!(
f,
"{}...{}",
&val[0..floor_char_boundary(val, mid_point)],
&val[ceil_char_boundary(val, val.len() - (mid_point - 1))..]
)
} else {
write!(f, "{val}")
}
}
68 changes: 61 additions & 7 deletions tests/serializers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
import pytest
from dirty_equals import IsJson

from pydantic_core import PydanticSerializationError, SchemaSerializer, SchemaValidator, core_schema
from pydantic_core import (
PydanticSerializationError,
PydanticSerializationUnexpectedValue,
SchemaSerializer,
SchemaValidator,
core_schema,
)

from ..conftest import plain_repr

Expand Down Expand Up @@ -1084,20 +1090,68 @@ class Model:


def test_no_warn_on_exclude() -> None:
warnings.simplefilter('error')
with warnings.catch_warnings():
warnings.simplefilter('error')

s = SchemaSerializer(
core_schema.model_schema(
BasicModel,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.int_schema()),
'b': core_schema.model_field(core_schema.int_schema()),
}
),
)
)

value = BasicModel(a=0, b=1)
assert s.to_python(value, exclude={'b'}) == {'a': 0}
assert s.to_python(value, mode='json', exclude={'b'}) == {'a': 0}


def test_warn_on_missing_field() -> None:
class AModel(BasicModel): ...

class BModel(BasicModel): ...

s = SchemaSerializer(
core_schema.model_schema(
BasicModel,
core_schema.model_fields_schema(
{
'a': core_schema.model_field(core_schema.int_schema()),
'b': core_schema.model_field(core_schema.int_schema()),
'root': core_schema.model_field(
core_schema.tagged_union_schema(
choices={
'a': core_schema.model_schema(
AModel,
core_schema.model_fields_schema(
{
'type': core_schema.model_field(core_schema.literal_schema(['a'])),
'a': core_schema.model_field(core_schema.int_schema()),
}
),
),
'b': core_schema.model_schema(
BModel,
core_schema.model_fields_schema(
{
'type': core_schema.model_field(core_schema.literal_schema(['b'])),
'b': core_schema.model_field(core_schema.int_schema()),
}
),
),
},
discriminator='type',
)
),
}
),
)
)

value = BasicModel(a=0, b=1)
assert s.to_python(value, exclude={'b'}) == {'a': 0}
assert s.to_python(value, mode='json', exclude={'b'}) == {'a': 0}
with pytest.raises(
PydanticSerializationUnexpectedValue, match='Expected 2 fields but got 1 for type `.*AModel` with value `.*`.+'
):
value = BasicModel(root=AModel(type='a'))
s.to_python(value)

0 comments on commit f4a0675

Please sign in to comment.