From f32b2359baa6e0c8fac42ba53da3c3d45bc4aabf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Br=C3=BCns?= Date: Thu, 13 Jun 2024 12:24:56 +0200 Subject: [PATCH] Make protobuf python_api fully optional In case `PYBIND11_PROTOBUF_ENABLE_PYPROTO_API` is not defined, the py_proto_api_ member of the GlobalState singleton is never changed from its default nullptr value. Any code protected by a `GlobalState::instance()->py_proto_api()` check can thus also be made dependent on the `PYPROTO_API` define. This allows to remove the dependency on the proto_api.h header file. As the call to check_unknown_fields::CheckRecursively is also protected by the `py_proto_api()` it can be stubbed out. See #127. --- pybind11_protobuf/check_unknown_fields.cc | 4 ++++ pybind11_protobuf/check_unknown_fields.h | 7 ++++++- pybind11_protobuf/proto_cast_util.cc | 18 +++++++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pybind11_protobuf/check_unknown_fields.cc b/pybind11_protobuf/check_unknown_fields.cc index 0639d09..fa6e30c 100644 --- a/pybind11_protobuf/check_unknown_fields.cc +++ b/pybind11_protobuf/check_unknown_fields.cc @@ -34,6 +34,7 @@ std::string MakeAllowListKey( return absl::StrCat(top_message_descriptor_full_name, ":", unknown_field_parent_message_fqn); } +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) /// Recurses through the message Descriptor class looking for valid extensions. /// Stores the result to `memoized`. @@ -173,6 +174,7 @@ std::string HasUnknownFields::BuildErrorMessage() const { return emsg; } +#endif } // namespace void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, @@ -181,6 +183,7 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, unknown_field_parent_message_fqn)); } +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, const ::google::protobuf::Message* message) { @@ -195,5 +198,6 @@ std::optional CheckRecursively( } return search.BuildErrorMessage(); } +#endif } // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/check_unknown_fields.h b/pybind11_protobuf/check_unknown_fields.h index 79ac001..4713620 100644 --- a/pybind11_protobuf/check_unknown_fields.h +++ b/pybind11_protobuf/check_unknown_fields.h @@ -3,9 +3,12 @@ #include +#include "absl/strings/string_view.h" #include "google/protobuf/message.h" + +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) #include "python/google/protobuf/proto_api.h" -#include "absl/strings/string_view.h" +#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API namespace pybind11_protobuf::check_unknown_fields { @@ -45,9 +48,11 @@ class ExtensionsWithUnknownFieldsPolicy { void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name, absl::string_view unknown_field_parent_message_fqn); +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) std::optional CheckRecursively( const ::google::protobuf::python::PyProto_API* py_proto_api, const ::google::protobuf::Message* top_message); +#endif // PYBIND11_PROTOBUF_ENABLE_PYPROTO_API } // namespace pybind11_protobuf::check_unknown_fields diff --git a/pybind11_protobuf/proto_cast_util.cc b/pybind11_protobuf/proto_cast_util.cc index 223b011..14ae6e7 100644 --- a/pybind11_protobuf/proto_cast_util.cc +++ b/pybind11_protobuf/proto_cast_util.cc @@ -27,7 +27,13 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/descriptor_database.h" #include "google/protobuf/dynamic_message.h" +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) #include "python/google/protobuf/proto_api.h" +#else +namespace google::protobuf::python { +struct PyProto_API; +} +#endif #include "pybind11_protobuf/check_unknown_fields.h" #if defined(GOOGLE_PROTOBUF_VERSION) @@ -46,7 +52,6 @@ using ::google::protobuf::FileDescriptorProto; using ::google::protobuf::Message; using ::google::protobuf::MessageFactory; using ::google::protobuf::python::PyProto_API; -using ::google::protobuf::python::PyProtoAPICapsuleName; namespace pybind11_protobuf { @@ -266,6 +271,7 @@ GlobalState::GlobalState() { // // By default (3) is used, however if the define is set *and* the version // matches, then pybind11_protobuf will assume that this will work. + using ::google::protobuf::python::PyProtoAPICapsuleName; py_proto_api_ = static_cast(PyCapsule_Import(PyProtoAPICapsuleName(), 0)); if (py_proto_api_ == nullptr) { @@ -355,6 +361,7 @@ py::object GlobalState::PyMessageInstance(const Descriptor* descriptor) { module_name + "?"); } +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) std::pair GlobalState::PyFastCppProtoMessageInstance( const Descriptor* descriptor) { assert(descriptor != nullptr); @@ -395,6 +402,7 @@ std::pair GlobalState::PyFastCppProtoMessageInstance( } return {std::move(result), message}; } +#endif // Create C++ DescriptorPools based on Python DescriptorPools. // The Python pool will provide message definitions when they are needed. @@ -534,6 +542,7 @@ class PythonDescriptorPoolWrapper { private: bool CopyToFileDescriptorProto(py::handle py_file_descriptor, FileDescriptorProto* output) { +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) if (GlobalState::instance()->py_proto_api()) { try { py::object c_proto = py::reinterpret_steal( @@ -552,6 +561,7 @@ class PythonDescriptorPoolWrapper { PyErr_Print(); } } +#endif return output->ParsePartialFromString( PyBytesAsStringView(py_file_descriptor.attr("serialized_pb"))); @@ -750,6 +760,7 @@ py::handle GenericPyProtoCast(Message* src, py::return_value_policy policy, return py_proto.release(); } +#if defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy, py::handle parent, bool is_const) { assert(policy != pybind11::return_value_policy::automatic); @@ -823,6 +834,7 @@ py::handle GenericFastCppProtoCast(Message* src, py::return_value_policy policy, throw py::cast_error(message + ReturnValuePolicyName(policy)); } } +#endif py::handle GenericProtoCast(Message* src, py::return_value_policy policy, py::handle parent, bool is_const) { @@ -833,6 +845,9 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, // 1. The binary does not have a py_proto_api instance, or // 2. a) the proto is from the default pool and // b) the binary is not using fast_cpp_protos. +#if ! defined(PYBIND11_PROTOBUF_ENABLE_PYPROTO_API) + return GenericPyProtoCast(src, policy, parent, is_const); +#else if (GlobalState::instance()->py_proto_api() == nullptr || (src->GetDescriptor()->file()->pool() == DescriptorPool::generated_pool() && @@ -861,6 +876,7 @@ py::handle GenericProtoCast(Message* src, py::return_value_policy policy, // construct a mapping between C++ pool() and python pool(), and then // use the PyProto_API to make it work. return GenericFastCppProtoCast(src, policy, parent, is_const); +#endif } } // namespace pybind11_protobuf