From 5f3c4bbe1743606718d036a7e34865a5e6c36c88 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 8 Oct 2020 13:54:17 +0900 Subject: [PATCH] Change PyCFunction to take &'static str as a function name --- pyo3-derive-backend/src/module.rs | 7 ++++- src/class/methods.rs | 1 - src/types/function.rs | 47 +++++++------------------------ tests/test_pyfunction.rs | 2 +- 4 files changed, 17 insertions(+), 40 deletions(-) diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 657d1134a8c..1357079bd41 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -212,7 +212,12 @@ pub fn add_fn_to_module( fn #function_wrapper_ident<'a>( args: impl Into> ) -> pyo3::PyResult<&'a pyo3::types::PyCFunction> { - pyo3::types::PyCFunction::new_with_keywords(#wrapper_ident, stringify!(#python_name), #doc, args.into()) + pyo3::types::PyCFunction::new_with_keywords( + #wrapper_ident, + concat!(stringify!(#python_name), "\0"), + #doc, + args.into(), + ) } }) } diff --git a/src/class/methods.rs b/src/class/methods.rs index 4713d2acfaf..67269705706 100644 --- a/src/class/methods.rs +++ b/src/class/methods.rs @@ -35,7 +35,6 @@ pub enum PyMethodType { PyInitFunc(ffi::initproc), } -// TODO(kngwyu): We should also use &'static CStr for this? I'm not sure. #[derive(Clone, Debug)] pub struct PyMethodDef { ml_name: &'static CStr, diff --git a/src/types/function.rs b/src/types/function.rs index e185c83365c..aaf06a3acfc 100644 --- a/src/types/function.rs +++ b/src/types/function.rs @@ -1,9 +1,6 @@ -use std::ffi::{CStr, CString}; - use crate::derive_utils::PyFunctionArguments; -use crate::exceptions::PyValueError; use crate::prelude::*; -use crate::{class, ffi, AsPyPointer, PyMethodType}; +use crate::{ffi, AsPyPointer, PyMethodDef}; /// Represents a builtin Python function object. #[repr(transparent)] @@ -17,56 +14,32 @@ impl PyCFunction { /// See [raw_pycfunction] for documentation on how to get the `fun` argument. pub fn new_with_keywords<'a>( fun: ffi::PyCFunctionWithKeywords, - name: &str, + name: &'static str, doc: &'static str, py_or_module: PyFunctionArguments<'a>, ) -> PyResult<&'a PyCFunction> { - let fun = PyMethodType::PyCFunctionWithKeywords(fun); - Self::new_(fun, name, doc, py_or_module) + Self::new_( + PyMethodDef::cfunction_with_keywords(name, fun, 0, doc), + py_or_module, + ) } /// Create a new built-in function without keywords. pub fn new<'a>( fun: ffi::PyCFunction, - name: &str, + name: &'static str, doc: &'static str, py_or_module: PyFunctionArguments<'a>, ) -> PyResult<&'a PyCFunction> { - let fun = PyMethodType::PyCFunction(fun); - Self::new_(fun, name, doc, py_or_module) + Self::new_(PyMethodDef::cfunction(name, fun, doc), py_or_module) } fn new_<'a>( - fun: class::PyMethodType, - name: &str, - doc: &'static str, + def: PyMethodDef, py_or_module: PyFunctionArguments<'a>, ) -> PyResult<&'a PyCFunction> { let (py, module) = py_or_module.into_py_and_maybe_module(); - let doc: &'static CStr = CStr::from_bytes_with_nul(doc.as_bytes()) - .map_err(|_| PyValueError::new_err("docstring must end with NULL byte."))?; - let name = CString::new(name.as_bytes()).map_err(|_| { - PyValueError::new_err("Function name cannot contain contain NULL byte.") - })?; - let def = match fun { - PyMethodType::PyCFunction(fun) => ffi::PyMethodDef { - ml_name: name.into_raw() as _, - ml_meth: Some(fun), - ml_flags: ffi::METH_VARARGS, - ml_doc: doc.as_ptr() as _, - }, - PyMethodType::PyCFunctionWithKeywords(fun) => ffi::PyMethodDef { - ml_name: name.into_raw() as _, - ml_meth: Some(unsafe { std::mem::transmute(fun) }), - ml_flags: ffi::METH_VARARGS | ffi::METH_KEYWORDS, - ml_doc: doc.as_ptr() as _, - }, - _ => { - return Err(PyValueError::new_err( - "Only PyCFunction and PyCFunctionWithKeywords are valid.", - )) - } - }; + let def = def.as_method_def(); let (mod_ptr, module_name) = if let Some(m) = module { let mod_ptr = m.as_ptr(); let name = m.name()?.into_py(py); diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 89b4dbd01e3..a1563e31404 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -100,7 +100,7 @@ fn test_raw_function() { let gil = Python::acquire_gil(); let py = gil.python(); let raw_func = raw_pycfunction!(optional_bool); - let fun = PyCFunction::new_with_keywords(raw_func, "fun", "\0", py.into()).unwrap(); + let fun = PyCFunction::new_with_keywords(raw_func, "fun\0", "\0", py.into()).unwrap(); let res = fun.call((), None).unwrap().extract::<&str>().unwrap(); assert_eq!(res, "Some(true)"); let res = fun.call((false,), None).unwrap().extract::<&str>().unwrap();