Skip to content

Commit

Permalink
Minor formatting/include cleanups.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678453559
  • Loading branch information
laramiel authored and copybara-github committed Sep 25, 2024
1 parent 55916e1 commit bf3189c
Show file tree
Hide file tree
Showing 21 changed files with 300 additions and 302 deletions.
4 changes: 2 additions & 2 deletions pybind11_protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pybind_library(
"//visibility:public",
],
deps = [
"@com_google_protobuf//:protobuf",
"//third_party/protobuf:protobuf_lite",
],
)

Expand Down Expand Up @@ -39,6 +39,7 @@ pybind_library(
],
deps = [
":check_unknown_fields",
"//third_party/protobuf:descriptor_pb",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -71,7 +72,6 @@ cc_library(
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_protobuf//:protobuf",
Expand Down
44 changes: 22 additions & 22 deletions pybind11_protobuf/check_unknown_fields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@
#include <string>
#include <vector>

#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/unknown_field_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/unknown_field_set.h"
#include "python/google/protobuf/proto_api.h"

namespace pybind11_protobuf::check_unknown_fields {
namespace {

using AllowListSet = absl::flat_hash_set<std::string>;
using MayContainExtensionsMap =
absl::flat_hash_map<const ::google::protobuf::Descriptor*, bool>;
absl::flat_hash_map<const google::protobuf::Descriptor*, bool>;

AllowListSet* GetAllowList() {
static auto* allow_list = new AllowListSet();
Expand All @@ -37,7 +38,7 @@ std::string MakeAllowListKey(

/// Recurses through the message Descriptor class looking for valid extensions.
/// Stores the result to `memoized`.
bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor* descriptor,
bool MessageMayContainExtensionsRecursive(const google::protobuf::Descriptor* descriptor,
MayContainExtensionsMap* memoized) {
if (descriptor->extension_range_count() > 0) return true;

Expand All @@ -48,7 +49,7 @@ bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor*

for (int i = 0; i < descriptor->field_count(); i++) {
auto* fd = descriptor->field(i);
if (fd->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) continue;
if (fd->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) continue;
if (MessageMayContainExtensionsRecursive(fd->message_type(), memoized)) {
(*memoized)[descriptor] = true;
return true;
Expand All @@ -58,16 +59,16 @@ bool MessageMayContainExtensionsRecursive(const ::google::protobuf::Descriptor*
return false;
}

bool MessageMayContainExtensionsMemoized(const ::google::protobuf::Descriptor* descriptor) {
bool MessageMayContainExtensionsMemoized(const google::protobuf::Descriptor* descriptor) {
static auto* memoized = new MayContainExtensionsMap();
static absl::Mutex lock;
absl::MutexLock l(&lock);
return MessageMayContainExtensionsRecursive(descriptor, memoized);
}

struct HasUnknownFields {
HasUnknownFields(const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Descriptor* root_descriptor)
HasUnknownFields(const google::protobuf::python::PyProto_API* py_proto_api,
const google::protobuf::Descriptor* root_descriptor)
: py_proto_api(py_proto_api), root_descriptor(root_descriptor) {}

std::string FieldFQN() const { return absl::StrJoin(field_fqn_parts, "."); }
Expand All @@ -77,34 +78,33 @@ struct HasUnknownFields {
: absl::StrCat(FieldFQN(), ".", unknown_field_number);
}

bool FindUnknownFieldsRecursive(const ::google::protobuf::Message* sub_message,
bool FindUnknownFieldsRecursive(const google::protobuf::Message* sub_message,
uint32_t depth);

std::string BuildErrorMessage() const;

const ::google::protobuf::python::PyProto_API* py_proto_api;
const ::google::protobuf::Descriptor* root_descriptor = nullptr;
const ::google::protobuf::Descriptor* unknown_field_parent_descriptor = nullptr;
const google::protobuf::python::PyProto_API* py_proto_api;
const google::protobuf::Descriptor* root_descriptor = nullptr;
const google::protobuf::Descriptor* unknown_field_parent_descriptor = nullptr;
std::vector<std::string> field_fqn_parts;
int unknown_field_number;
};

/// Recurses through the message fields class looking for UnknownFields.
bool HasUnknownFields::FindUnknownFieldsRecursive(
const ::google::protobuf::Message* sub_message, uint32_t depth) {
const ::google::protobuf::Reflection& reflection = *sub_message->GetReflection();
const google::protobuf::Message* sub_message, uint32_t depth) {
const google::protobuf::Reflection& reflection = *sub_message->GetReflection();

// If there are unknown fields, stop searching.
const ::google::protobuf::UnknownFieldSet& unknown_field_set =
const google::protobuf::UnknownFieldSet& unknown_field_set =
reflection.GetUnknownFields(*sub_message);
if (!unknown_field_set.empty()) {
unknown_field_parent_descriptor = sub_message->GetDescriptor();
unknown_field_number = unknown_field_set.field(0).number();

// Stop only if the extension is known by Python.
if (py_proto_api->GetDefaultDescriptorPool()->FindExtensionByNumber(
unknown_field_parent_descriptor,
unknown_field_number)) {
unknown_field_parent_descriptor, unknown_field_number)) {
field_fqn_parts.resize(depth);
return true;
}
Expand All @@ -118,11 +118,11 @@ bool HasUnknownFields::FindUnknownFieldsRecursive(

// Otherwise the method has to check all present fields, including
// extensions to determine if they include unknown fields.
std::vector<const ::google::protobuf::FieldDescriptor*> present_fields;
std::vector<const google::protobuf::FieldDescriptor*> present_fields;
reflection.ListFields(*sub_message, &present_fields);

for (const auto* field : present_fields) {
if (field->cpp_type() != ::google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
if (field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
continue;
}
if (field->is_repeated()) {
Expand Down Expand Up @@ -182,8 +182,8 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
}

std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* message) {
const google::protobuf::python::PyProto_API* py_proto_api,
const google::protobuf::Message* message) {
const auto* root_descriptor = message->GetDescriptor();
HasUnknownFields search{py_proto_api, root_descriptor};
if (!search.FindUnknownFieldsRecursive(message, 0u)) {
Expand Down
7 changes: 4 additions & 3 deletions pybind11_protobuf/check_unknown_fields.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#define PYBIND11_PROTOBUF_CHECK_UNKNOWN_FIELDS_H_

#include <optional>
#include <string>

#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"
#include "python/google/protobuf/proto_api.h"
#include "absl/strings/string_view.h"

namespace pybind11_protobuf::check_unknown_fields {

Expand Down Expand Up @@ -46,8 +47,8 @@ void AllowUnknownFieldsFor(absl::string_view top_message_descriptor_full_name,
absl::string_view unknown_field_parent_message_fqn);

std::optional<std::string> CheckRecursively(
const ::google::protobuf::python::PyProto_API* py_proto_api,
const ::google::protobuf::Message* top_message);
const google::protobuf::python::PyProto_API* py_proto_api,
const google::protobuf::Message* top_message);

} // namespace pybind11_protobuf::check_unknown_fields

Expand Down
16 changes: 7 additions & 9 deletions pybind11_protobuf/enum_type_caster.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,19 @@
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>

#include <string>
#include <type_traits>

#include "google/protobuf/descriptor.h"
#include "google/protobuf/generated_enum_reflection.h"
#include "google/protobuf/generated_enum_util.h"

// pybind11 type_caster specialization which translates Proto::Enum types
// to/from ints. This will have ODR conflicts when users specify wrappers for
// enums using py::enum_<T>.
//
// ::google::protobuf::is_proto_enum and ::google::protobuf::GetEnumDescriptor are require
// google::protobuf::is_proto_enum and google::protobuf::GetEnumDescriptor
// are require
//
// NOTE: The protobuf compiler does not generate ::google::protobuf::is_proto_enum traits
// for enumerations of oneof fields.
// NOTE: The protobuf compiler does not generate
// google::protobuf::is_proto_enum traits for enumerations of oneof fields.
//
// Example:
// #include <pybind11/pybind11.h>
Expand Down Expand Up @@ -100,17 +98,17 @@ constexpr bool pybind11_protobuf_enable_enum_type_caster(...) { return true; }
#if defined(PYBIND11_HAS_NATIVE_ENUM)
template <typename EnumType>
struct type_caster_enum_type_enabled<
EnumType, std::enable_if_t<(::google::protobuf::is_proto_enum<EnumType>::value &&
EnumType, std::enable_if_t<(google::protobuf::is_proto_enum<EnumType>::value &&
pybind11_protobuf_enable_enum_type_caster(
static_cast<EnumType*>(nullptr)))>>
: std::false_type {};
#endif

// Specialization of pybind11::detail::type_caster<T> for types satisfying
// ::google::protobuf::is_proto_enum.
// google::protobuf::is_proto_enum.
template <typename EnumType>
struct type_caster<EnumType,
std::enable_if_t<(::google::protobuf::is_proto_enum<EnumType>::value &&
std::enable_if_t<(google::protobuf::is_proto_enum<EnumType>::value &&
pybind11_protobuf_enable_enum_type_caster(
static_cast<EnumType*>(nullptr)))>>
: public pybind11_protobuf::enum_type_caster<EnumType> {};
Expand Down
33 changes: 14 additions & 19 deletions pybind11_protobuf/native_proto_caster.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,17 @@
// IWYU
#include <Python.h>

#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>

#include "google/protobuf/message.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"
#include "pybind11_protobuf/enum_type_caster.h"
#include "pybind11_protobuf/proto_caster_impl.h"

// pybind11::type_caster<> specialization for ::google::protobuf::Message types that
// that converts protocol buffer objects between C++ and python representations.
// This binder supports binaries linked with both native python protos
// and fast cpp python protos.
// pybind11::type_caster<> specialization for google::protobuf::Message types
// that that converts protocol buffer objects between C++ and python
// representations. This binder supports binaries linked with both native python
// protos and fast cpp python protos.
//
// When passing protos between python and C++, if possible, an underlying C++
// object may have ownership transferred, or may be copied if both instances
Expand Down Expand Up @@ -85,7 +80,7 @@ constexpr bool pybind11_protobuf_enable_type_caster(...) { return true; }
template <typename ProtoType>
struct type_caster<
ProtoType,
std::enable_if_t<(std::is_base_of<::google::protobuf::Message, ProtoType>::value &&
std::enable_if_t<(std::is_base_of<google::protobuf::Message, ProtoType>::value &&
pybind11_protobuf_enable_type_caster(
static_cast<ProtoType *>(nullptr)))>>
: public pybind11_protobuf::proto_caster<
Expand All @@ -95,12 +90,12 @@ struct type_caster<

template <typename ProtoType>
struct copyable_holder_caster_shared_ptr_with_smart_holder_support_enabled<
ProtoType, enable_if_t<std::is_base_of<::google::protobuf::Message, ProtoType>::value>>
ProtoType, enable_if_t<std::is_base_of<google::protobuf::Message, ProtoType>::value>>
: std::false_type {};

template <typename ProtoType>
struct move_only_holder_caster_unique_ptr_with_smart_holder_support_enabled<
ProtoType, enable_if_t<std::is_base_of<::google::protobuf::Message, ProtoType>::value>>
ProtoType, enable_if_t<std::is_base_of<google::protobuf::Message, ProtoType>::value>>
: std::false_type {};

#endif // PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT
Expand All @@ -118,7 +113,7 @@ struct move_only_holder_caster_unique_ptr_with_smart_holder_support_enabled<
template <typename ProtoType, typename HolderType>
struct move_only_holder_caster<
ProtoType, HolderType,
std::enable_if_t<(std::is_base_of<::google::protobuf::Message, ProtoType>::value &&
std::enable_if_t<(std::is_base_of<google::protobuf::Message, ProtoType>::value &&
pybind11_protobuf_enable_type_caster(
static_cast<ProtoType *>(nullptr)))>>
: public pybind11_protobuf::move_only_holder_caster_impl<ProtoType,
Expand All @@ -136,18 +131,18 @@ struct move_only_holder_caster<
template <typename ProtoType, typename HolderType>
struct copyable_holder_caster<
ProtoType, HolderType,
std::enable_if_t<(std::is_base_of<::google::protobuf::Message, ProtoType>::value &&
std::enable_if_t<(std::is_base_of<google::protobuf::Message, ProtoType>::value &&
pybind11_protobuf_enable_type_caster(
static_cast<ProtoType *>(nullptr)))>>
: public pybind11_protobuf::copyable_holder_caster_impl<ProtoType,
HolderType> {};

// NOTE: We also need to add support and/or test classes:
//
// ::google::protobuf::Descriptor
// ::google::protobuf::EnumDescriptor
// ::google::protobuf::EnumValueDescriptor
// ::google::protobuf::FieldDescriptor
// google::protobuf::Descriptor
// google::protobuf::EnumDescriptor
// google::protobuf::EnumValueDescriptor
// google::protobuf::FieldDescriptor
//

} // namespace detail
Expand Down
13 changes: 5 additions & 8 deletions pybind11_protobuf/proto_cast_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "google/protobuf/descriptor.pb.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
Expand All @@ -25,6 +23,7 @@
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/dynamic_message.h"

Expand All @@ -34,7 +33,6 @@ using ::google::protobuf::Descriptor;
using ::google::protobuf::DescriptorDatabase;
using ::google::protobuf::DescriptorPool;
using ::google::protobuf::DynamicMessageFactory;
using ::google::protobuf::FileDescriptor;
using ::google::protobuf::FileDescriptorProto;
using ::google::protobuf::Message;
using ::google::protobuf::MessageFactory;
Expand Down Expand Up @@ -183,11 +181,9 @@ GlobalState::GlobalState() {

// pybind11_protobuf casting needs a dependency on proto internals to work.
try {
ImportCached("google.protobuf.descriptor");
auto descriptor_pool =
ImportCached("google.protobuf.descriptor_pool");
auto message_factory =
ImportCached("google.protobuf.message_factory");
ImportCached("google.protobuf");
auto descriptor_pool = ImportCached("google.protobuf.descriptor_pool");
auto message_factory = ImportCached("google.protobuf.message_factory");
global_pool_ = descriptor_pool.attr("Default")();
find_message_type_by_name_ = global_pool_.attr("FindMessageTypeByName");
if (hasattr(message_factory, "GetMessageClass")) {
Expand Down Expand Up @@ -221,6 +217,7 @@ py::module_ GlobalState::ImportCached(const std::string& module_name) {
if (cached != import_cache_.end()) {
return cached->second;
}
LOG(INFO) << "ImportCached " << module_name;
auto module = py::module_::import(module_name.c_str());
import_cache_[module_name] = module;
return module;
Expand Down
Loading

0 comments on commit bf3189c

Please sign in to comment.