Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ability to pass context to serialization (pydantic#7143) #1215

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> Any:
"""
Expand Down Expand Up @@ -131,7 +131,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> bool:
"""
Expand All @@ -148,7 +148,7 @@ class SchemaValidator:
input: str | bytes | bytearray,
*,
strict: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> Any:
"""
Expand Down Expand Up @@ -176,7 +176,7 @@ class SchemaValidator:
The validated Python object.
"""
def validate_strings(
self, input: _StringInput, *, strict: bool | None = None, context: 'dict[str, Any] | None' = None
self, input: _StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
) -> Any:
"""
Validate a string against the schema and return the validated Python object.
Expand Down Expand Up @@ -206,7 +206,7 @@ class SchemaValidator:
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: 'dict[str, Any] | None' = None,
context: dict[str, Any] | None = None,
) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]:
"""
Validate an assignment to a field on a model.
Expand Down Expand Up @@ -278,6 +278,7 @@ class SchemaSerializer:
round_trip: bool = False,
warnings: bool = True,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> Any:
"""
Serialize/marshal a Python object to a Python object including transforming and filtering data.
Expand All @@ -297,6 +298,8 @@ class SchemaSerializer:
warnings: Whether to log warnings when invalid fields are encountered.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].

Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand All @@ -318,6 +321,7 @@ class SchemaSerializer:
round_trip: bool = False,
warnings: bool = True,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> bytes:
"""
Serialize a Python object to JSON including transforming and filtering data.
Expand All @@ -336,6 +340,8 @@ class SchemaSerializer:
warnings: Whether to log warnings when invalid fields are encountered.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].

Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand All @@ -358,6 +364,7 @@ def to_json(
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> bytes:
"""
Serialize a Python object to JSON including transforming and filtering data.
Expand All @@ -379,6 +386,8 @@ def to_json(
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].

Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand Down Expand Up @@ -419,6 +428,7 @@ def to_jsonable_python(
inf_nan_mode: Literal['null', 'constants'] = 'constants',
serialize_unknown: bool = False,
fallback: Callable[[Any], Any] | None = None,
context: dict[str, Any] | None = None,
) -> Any:
"""
Serialize/marshal a Python object to a JSON-serializable Python object including transforming and filtering data.
Expand All @@ -440,6 +450,8 @@ def to_jsonable_python(
`"<Unserializable {value_type} object>"` will be used.
fallback: A function to call when an unknown value is encountered,
if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
context: The context to use for serialization, this is passed to functional serializers as
[`info.context`][pydantic_core.core_schema.SerializationInfo.context].

Raises:
PydanticSerializationError: If serialization fails and no `fallback` function is provided.
Expand Down
4 changes: 4 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def include(self) -> IncExCall: ...
@property
def exclude(self) -> IncExCall: ...

@property
def context(self) -> Any | None:
"""Current serialization context."""

@property
def mode(self) -> str: ...

Expand Down
2 changes: 1 addition & 1 deletion src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl ValidationError {
include_input: bool,
) -> PyResult<&'py PyString> {
let state = SerializationState::new("iso8601", "utf8", "constants")?;
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None);
let extra = state.extra(py, &SerMode::Json, true, false, false, true, None, None);
let serializer = ValidationErrorSerializer {
py,
line_errors: &self.line_errors,
Expand Down
12 changes: 10 additions & 2 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl SerializationState {
round_trip: bool,
serialize_unknown: bool,
fallback: Option<&'py PyAny>,
context: Option<&'py PyAny>,
) -> Extra<'py> {
Extra::new(
py,
Expand All @@ -59,6 +60,7 @@ impl SerializationState {
&self.rec_guard,
serialize_unknown,
fallback,
context,
)
}

Expand Down Expand Up @@ -90,6 +92,7 @@ pub(crate) struct Extra<'a> {
pub field_name: Option<&'a str>,
pub serialize_unknown: bool,
pub fallback: Option<&'a PyAny>,
pub context: Option<&'a PyAny>,
}

impl<'a> Extra<'a> {
Expand All @@ -107,6 +110,7 @@ impl<'a> Extra<'a> {
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
context: Option<&'a PyAny>,
) -> Self {
Self {
mode,
Expand All @@ -124,6 +128,7 @@ impl<'a> Extra<'a> {
field_name: None,
serialize_unknown,
fallback,
context,
}
}

Expand Down Expand Up @@ -178,10 +183,11 @@ pub(crate) struct ExtraOwned {
config: SerializationConfig,
rec_guard: SerRecursionState,
check: SerCheck,
model: Option<PyObject>,
pub model: Option<PyObject>,
field_name: Option<String>,
serialize_unknown: bool,
fallback: Option<PyObject>,
pub fallback: Option<PyObject>,
pub context: Option<PyObject>,
}

impl ExtraOwned {
Expand All @@ -201,6 +207,7 @@ impl ExtraOwned {
field_name: extra.field_name.map(ToString::to_string),
serialize_unknown: extra.serialize_unknown,
fallback: extra.fallback.map(Into::into),
context: extra.context.map(Into::into),
}
}

Expand All @@ -221,6 +228,7 @@ impl ExtraOwned {
field_name: self.field_name.as_deref(),
serialize_unknown: self.serialize_unknown,
fallback: self.fallback.as_ref().map(|m| m.as_ref(py)),
context: self.context.as_ref().map(|m| m.as_ref(py)),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub(crate) fn infer_to_python_known(
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
extra.context,
);
serializer.serializer.to_python(value, include, exclude, &extra)
};
Expand Down Expand Up @@ -468,6 +469,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
extra.context,
);
let pydantic_serializer =
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
Expand Down
18 changes: 14 additions & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl SchemaSerializer {
rec_guard: &'a SerRecursionState,
serialize_unknown: bool,
fallback: Option<&'a PyAny>,
context: Option<&'a PyAny>,
) -> Extra<'b> {
Extra::new(
py,
Expand All @@ -69,6 +70,7 @@ impl SchemaSerializer {
rec_guard,
serialize_unknown,
fallback,
context,
)
}
}
Expand All @@ -95,7 +97,7 @@ impl SchemaSerializer {
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (value, *, mode = None, include = None, exclude = None, by_alias = true,
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true,
fallback = None))]
fallback = None, context = None))]
pub fn to_python(
&self,
py: Python,
Expand All @@ -110,6 +112,7 @@ impl SchemaSerializer {
round_trip: bool,
warnings: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let mode: SerMode = mode.into();
let warnings = CollectWarnings::new(warnings);
Expand All @@ -126,6 +129,7 @@ impl SchemaSerializer {
&rec_guard,
false,
fallback,
context,
);
let v = self.serializer.to_python(value, include, exclude, &extra)?;
warnings.final_check(py)?;
Expand All @@ -135,7 +139,7 @@ impl SchemaSerializer {
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true,
fallback = None))]
fallback = None, context = None))]
pub fn to_json(
&self,
py: Python,
Expand All @@ -150,6 +154,7 @@ impl SchemaSerializer {
round_trip: bool,
warnings: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let warnings = CollectWarnings::new(warnings);
let rec_guard = SerRecursionState::default();
Expand All @@ -165,6 +170,7 @@ impl SchemaSerializer {
&rec_guard,
false,
fallback,
context,
);
let bytes = to_json_bytes(
value,
Expand Down Expand Up @@ -213,7 +219,7 @@ impl SchemaSerializer {
#[pyfunction]
#[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true,
exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8",
inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
inf_nan_mode = "constants", serialize_unknown = false, fallback = None, context = None))]
pub fn to_json(
py: Python,
value: &PyAny,
Expand All @@ -228,6 +234,7 @@ pub fn to_json(
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
Expand All @@ -238,6 +245,7 @@ pub fn to_json(
round_trip,
serialize_unknown,
fallback,
context,
);
let serializer = type_serializers::any::AnySerializer.into();
let bytes = to_json_bytes(value, &serializer, include, exclude, &extra, indent, 1024)?;
Expand All @@ -249,7 +257,7 @@ pub fn to_json(
#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false,
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))]
timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None, context = None))]
pub fn to_jsonable_python(
py: Python,
value: &PyAny,
Expand All @@ -263,6 +271,7 @@ pub fn to_jsonable_python(
inf_nan_mode: &str,
serialize_unknown: bool,
fallback: Option<&PyAny>,
context: Option<&PyAny>,
) -> PyResult<PyObject> {
let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?;
let extra = state.extra(
Expand All @@ -273,6 +282,7 @@ pub fn to_jsonable_python(
round_trip,
serialize_unknown,
fallback,
context,
);
let v = infer::infer_to_python(value, include, exclude, &extra)?;
state.final_check(py)?;
Expand Down
Loading