diff --git a/src/py_gc.rs b/src/py_gc.rs index c4f83f5b0..c9752e732 100644 --- a/src/py_gc.rs +++ b/src/py_gc.rs @@ -64,7 +64,7 @@ macro_rules! impl_py_gc_traverse { } } }; - ($name:ty { $($fields:ident),* }) => { + ($name:ty { $($fields:ident),* $(,)? }) => { impl crate::py_gc::PyGcTraverse for $name { fn py_gc_traverse(&self, visit: &pyo3::PyVisit<'_>) -> Result<(), pyo3::PyTraverseError> { $(self.$fields.py_gc_traverse(visit)?;)* diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 6d76db48f..5121224cd 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -261,6 +261,14 @@ pub(crate) struct ExtraOwned { exclude: Option>, } +impl_py_gc_traverse!(ExtraOwned { + model, + fallback, + context, + include, + exclude, +}); + #[derive(Clone)] enum FieldNameOwned { Root, diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index c44380bc4..63d288e2d 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -48,6 +48,13 @@ pub struct SchemaSerializer { py_config: Option>, } +impl_py_gc_traverse!(SchemaSerializer { + serializer, + definitions, + py_schema, + py_config, +}); + #[pymethods] impl SchemaSerializer { #[new] @@ -186,13 +193,7 @@ impl SchemaSerializer { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - visit.call(&self.py_schema)?; - if let Some(ref py_config) = self.py_config { - visit.call(py_config)?; - } - self.serializer.py_gc_traverse(&visit)?; - self.definitions.py_gc_traverse(&visit)?; - Ok(()) + self.py_gc_traverse(&visit) } } diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index 6bcb2fe6f..f08396bb4 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -11,6 +11,7 @@ use pyo3::PyTraverseError; use pyo3::types::PyString; use crate::definitions::DefinitionsBuilder; +use crate::py_gc::PyGcTraverse; use crate::serializers::SerializationState; use crate::tools::SchemaDict; use crate::tools::{function_name, py_err, py_error_type}; @@ -437,6 +438,11 @@ pub(crate) struct SerializationCallable { filter: AnyFilter, } +impl_py_gc_traverse!(SerializationCallable { + serializer, + extra_owned +}); + impl SerializationCallable { pub fn new(serializer: &Arc, state: &SerializationState<'_, '_>) -> Self { Self { @@ -447,16 +453,7 @@ impl SerializationCallable { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - if let Some(model) = &self.extra_owned.model { - visit.call(model)?; - } - if let Some(fallback) = &self.extra_owned.fallback { - visit.call(fallback)?; - } - if let Some(context) = &self.extra_owned.context { - visit.call(context)?; - } - Ok(()) + self.py_gc_traverse(&visit) } fn __clear__(&mut self) { @@ -542,6 +539,12 @@ struct SerializationInfo { serialize_as_any: bool, } +impl_py_gc_traverse!(SerializationInfo { + include, + exclude, + context +}); + impl SerializationInfo { fn new(state: &SerializationState<'_, '_>, is_field_serializer: bool) -> PyResult { let extra = &state.extra; @@ -584,16 +587,7 @@ impl SerializationInfo { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - if let Some(include) = &self.include { - visit.call(include)?; - } - if let Some(exclude) = &self.exclude { - visit.call(exclude)?; - } - if let Some(context) = &self.context { - visit.call(context)?; - } - Ok(()) + self.py_gc_traverse(&visit) } fn __clear__(&mut self) { diff --git a/src/serializers/type_serializers/generator.rs b/src/serializers/type_serializers/generator.rs index ad6d3a984..56e41f924 100644 --- a/src/serializers/type_serializers/generator.rs +++ b/src/serializers/type_serializers/generator.rs @@ -11,6 +11,7 @@ use pyo3::PyTraverseError; use serde::ser::SerializeSeq; use crate::definitions::DefinitionsBuilder; +use crate::py_gc::PyGcTraverse; use crate::serializers::SerializationState; use crate::tools::SchemaDict; @@ -144,6 +145,12 @@ pub(crate) struct SerializationIterator { filter: SchemaFilter, } +impl_py_gc_traverse!(SerializationIterator { + iterator, + item_serializer, + extra_owned, +}); + impl SerializationIterator { pub fn new( py_iter: &Bound<'_, PyIterator>, @@ -161,16 +168,7 @@ impl SerializationIterator { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - if let Some(model) = &self.extra_owned.model { - visit.call(model)?; - } - if let Some(fallback) = &self.extra_owned.fallback { - visit.call(fallback)?; - } - if let Some(context) = &self.extra_owned.context { - visit.call(context)?; - } - Ok(()) + self.py_gc_traverse(&visit) } fn __clear__(&mut self) { diff --git a/src/validators/function.rs b/src/validators/function.rs index 8436a181e..69390184b 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -536,6 +536,13 @@ pub struct ValidationInfo { mode: InputType, } +impl_py_gc_traverse!(ValidationInfo { + config, + context, + data, + field_name +}); + impl ValidationInfo { fn new(py: Python, extra: &Extra<'_, '_>, config: &Py, field_name: Option>) -> Self { Self { @@ -548,11 +555,7 @@ impl ValidationInfo { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - visit.call(&self.config)?; - if let Some(context) = &self.context { - visit.call(context)?; - } - Ok(()) + self.py_gc_traverse(&visit) } fn __clear__(&mut self) { diff --git a/src/validators/mod.rs b/src/validators/mod.rs index adcf1ba55..6a4ef1348 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -121,6 +121,13 @@ pub struct SchemaValidator { cache_str: StringCacheMode, } +impl_py_gc_traverse!(SchemaValidator { + validator, + definitions, + py_schema, + py_config, +}); + #[pymethods] impl SchemaValidator { #[new] @@ -403,12 +410,7 @@ impl SchemaValidator { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { - self.validator.py_gc_traverse(&visit)?; - visit.call(&self.py_schema)?; - if let Some(ref py_config) = self.py_config { - visit.call(py_config)?; - } - Ok(()) + self.py_gc_traverse(&visit) } }