Skip to content

Commit

Permalink
fix memory leak in create_class (pydantic#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored May 11, 2022
1 parent 6d0da78 commit 6339ef5
Showing 1 changed file with 29 additions and 32 deletions.
61 changes: 29 additions & 32 deletions src/validators/model_class.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::os::raw::c_int;
use std::ptr::null_mut;

use pyo3::conversion::{AsPyPointer, FromPyPointer};
use pyo3::conversion::AsPyPointer;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple, PyType};
use pyo3::{ffi, intern, ToBorrowedObject};
use pyo3::{ffi, intern};

use crate::build_tools::{py_error, SchemaDict};
use crate::errors::{as_internal, context, err_val_error, ErrorKind, InputValue, ValError, ValResult};
Expand Down Expand Up @@ -67,7 +67,7 @@ impl Validator for ModelClassValidator {
)
} else {
let output = self.validator.validate(py, input, extra, slots)?;
unsafe { self.create_class(py, output).map_err(as_internal) }
self.create_class(py, output).map_err(as_internal)
}
}

Expand Down Expand Up @@ -99,51 +99,48 @@ impl Validator for ModelClassValidator {
}

impl ModelClassValidator {
unsafe fn create_class(&self, py: Python, output: PyObject) -> PyResult<PyObject> {
let t: &PyTuple = output.extract(py)?;
let model_dict = t.get_item(0)?;
let fields_set = t.get_item(1)?;
fn create_class(&self, py: Python, output: PyObject) -> PyResult<PyObject> {
let (model_dict, fields_set): (&PyAny, &PyAny) = output.extract(py)?;

// based on the following but with the second argument of new_func set to an empty tuple as required
// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77
let args = PyTuple::empty(py);
let raw_type = self.class.as_ref(py).as_type_ptr();
let instance_ptr = match (*raw_type).tp_new {
Some(new_func) => {
let obj = new_func(raw_type, args.as_ptr(), null_mut());
if obj.is_null() {
return Err(PyErr::fetch(py));
} else {
obj
}
let instance = unsafe {
// Safety: raw_type is known to be a non-null type object pointer
match (*raw_type).tp_new {
// Safety: the result of new_func is guaranteed to be either an owned pointer or null on error returns.
Some(new_func) => PyObject::from_owned_ptr_or_err(
py,
// Safety: the non-null pointers are known to be valid, and it's allowed to call tp_new with a
// null kwargs dict.
new_func(raw_type, args.as_ptr(), null_mut()),
)?,
None => return Err(PyTypeError::new_err("base type without tp_new")),
}
None => return Err(PyTypeError::new_err("base type without tp_new")),
};

force_setattr(instance_ptr, py, intern!(py, "__dict__"), model_dict)?;
force_setattr(instance_ptr, py, intern!(py, "__fields_set__"), fields_set)?;
let instance_ref = instance.as_ref(py);
force_setattr(py, instance_ref, intern!(py, "__dict__"), model_dict)?;
force_setattr(py, instance_ref, intern!(py, "__fields_set__"), fields_set)?;

match PyAny::from_borrowed_ptr_or_opt(py, instance_ptr) {
Some(instance) => Ok(instance.into()),
None => Err(PyTypeError::new_err("failed to create instance of class")),
}
Ok(instance)
}
}

/// copied and modified from
/// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/instance.rs#L587-L597
/// to use `PyObject_GenericSetAttr` thereby bypassing `__setattr__` methods on the instance,
/// see https://github.com/PyO3/pyo3/discussions/2321 for discussion
pub fn force_setattr<N, V>(obj: *mut ffi::PyObject, py: Python<'_>, attr_name: N, value: V) -> PyResult<()>
pub fn force_setattr<N, V>(py: Python<'_>, obj: &PyAny, attr_name: N, value: V) -> PyResult<()>
where
N: ToPyObject,
V: ToPyObject,
{
attr_name.with_borrowed_ptr(py, move |attr_name| {
value.with_borrowed_ptr(py, |value| unsafe {
error_on_minusone(py, ffi::PyObject_GenericSetAttr(obj, attr_name, value))
})
})
let attr_name = attr_name.to_object(py);
let value = value.to_object(py);
unsafe {
error_on_minusone(
py,
ffi::PyObject_GenericSetAttr(obj.as_ptr(), attr_name.as_ptr(), value.as_ptr()),
)
}
}

// Defined here as it's not exported by pyo3
Expand Down

0 comments on commit 6339ef5

Please sign in to comment.