diff --git a/src/google/protobuf/dynamic_message.cc b/src/google/protobuf/dynamic_message.cc index 561047f2f00f..9de81bd0c012 100644 --- a/src/google/protobuf/dynamic_message.cc +++ b/src/google/protobuf/dynamic_message.cc @@ -214,8 +214,9 @@ class DynamicMapField final static void ClearMapNoSyncImpl(MapFieldBase& base); static bool DeleteMapValueImpl(MapFieldBase& map, const MapKey& map_key); static void SetMapIteratorValueImpl(MapIterator* map_iter); - static bool LookupMapValueImpl(const MapFieldBase& self, - const MapKey& map_key, MapValueConstRef* val); + static bool LookupMapValueNoSyncImpl(const MapFieldBase& self, + const MapKey& map_key, + MapValueConstRef* val); static void UnsafeShallowSwapImpl(MapFieldBase& lhs, MapFieldBase& rhs) { static_cast(lhs).Swap( @@ -278,10 +279,10 @@ void DynamicMapField::SetMapIteratorValueImpl(MapIterator* map_iter) { map_iter->value_.CopyFrom(iter->second); } -bool DynamicMapField::LookupMapValueImpl(const MapFieldBase& self, - const MapKey& map_key, - MapValueConstRef* val) { - const auto& map = static_cast(self).GetMap(); +bool DynamicMapField::LookupMapValueNoSyncImpl(const MapFieldBase& self, + const MapKey& map_key, + MapValueConstRef* val) { + const auto& map = static_cast(self).map_; auto iter = map.find(map_key); if (map.end() == iter) { return false; diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index 739d17283e53..8e3b42cc1be4 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -680,6 +680,10 @@ class KeyMapBase : public UntypedMapBase { friend class RustMapHelper; friend class v2::TableDriven; + Key* GetKey(NodeBase* node) const { + return UntypedMapBase::GetKey(node); + } + PROTOBUF_NOINLINE size_type EraseImpl(map_index_t b, KeyNode* node, bool do_destroy) { // Force bucket_index to be in range. diff --git a/src/google/protobuf/map_field.cc b/src/google/protobuf/map_field.cc index 00cd7bc5ab05..f2af91e7e082 100644 --- a/src/google/protobuf/map_field.cc +++ b/src/google/protobuf/map_field.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include "absl/functional/overload.h" #include "absl/log/absl_check.h" @@ -35,6 +36,58 @@ MapFieldBase::~MapFieldBase() { delete maybe_payload(); } +template +auto VisitMapKey(const MapKey& map_key, Map& map, F f) { + switch (map_key.type()) { +#define HANDLE_TYPE(CPPTYPE, Type, KeyBaseType) \ + case FieldDescriptor::CPPTYPE_##CPPTYPE: { \ + using KMB = KeyMapBase; \ + return f( \ + static_cast< \ + std::conditional_t, const KMB&, KMB&>>(map), \ + TransparentSupport::ToView(map_key.Get##Type##Value())); \ + } + HANDLE_TYPE(INT32, Int32, uint32_t); + HANDLE_TYPE(UINT32, UInt32, uint32_t); + HANDLE_TYPE(INT64, Int64, uint64_t); + HANDLE_TYPE(UINT64, UInt64, uint64_t); + HANDLE_TYPE(BOOL, Bool, bool); + HANDLE_TYPE(STRING, String, std::string); +#undef HANDLE_TYPE + default: + Unreachable(); + } +} + +bool MapFieldBase::InsertOrLookupMapValueNoSyncImpl(MapFieldBase& self, + const MapKey& map_key, + MapValueRef* val) { + if (LookupMapValueNoSyncImpl(self, map_key, + static_cast(val))) { + return false; + } + + auto& map = self.GetMapRaw(); + + NodeBase* node = map.AllocNode(); + map.VisitValue(node, [&](auto* v) { self.InitializeKeyValue(v); }); + val->SetValue(map.GetVoidValue(node)); + + return VisitMapKey(map_key, map, [&](auto& map, const auto& key) { + self.InitializeKeyValue(map.GetKey(node), key); + map.InsertOrReplaceNode( + static_cast::KeyNode*>(node)); + return true; + }); +} + +bool MapFieldBase::DeleteMapValueImpl(MapFieldBase& self, + const MapKey& map_key) { + return VisitMapKey( + map_key, *self.MutableMap(), + [](auto& map, const auto& key) { return map.EraseImpl(key); }); +} + void MapFieldBase::ClearMapNoSyncImpl(MapFieldBase& self) { self.GetMapRaw().ClearTable(true, nullptr); } @@ -56,35 +109,22 @@ void MapFieldBase::SetMapIteratorValueImpl(MapIterator* map_iter) { map_iter->value_.SetValue(map.GetVoidValue(node)); } -bool MapFieldBase::LookupMapValueImpl(const MapFieldBase& self, - const MapKey& map_key, - MapValueConstRef* val) { - auto& map = self.GetMap(); +bool MapFieldBase::LookupMapValueNoSyncImpl(const MapFieldBase& self, + const MapKey& map_key, + MapValueConstRef* val) { + auto& map = self.GetMapRaw(); if (map.empty()) return false; - switch (map_key.type()) { -#define HANDLE_TYPE(CPPTYPE, Type, KeyBaseType) \ - case FieldDescriptor::CPPTYPE_##CPPTYPE: { \ - auto& key_map = static_cast&>(map); \ - auto res = key_map.FindHelper(map_key.Get##Type##Value()); \ - if (res.node == nullptr) { \ - return false; \ - } \ - if (val != nullptr) { \ - val->SetValue(map.GetVoidValue(res.node)); \ - } \ - return true; \ - } - HANDLE_TYPE(INT32, Int32, uint32_t); - HANDLE_TYPE(UINT32, UInt32, uint32_t); - HANDLE_TYPE(INT64, Int64, uint64_t); - HANDLE_TYPE(UINT64, UInt64, uint64_t); - HANDLE_TYPE(BOOL, Bool, bool); - HANDLE_TYPE(STRING, String, std::string); -#undef HANDLE_TYPE - default: - Unreachable(); - } + return VisitMapKey(map_key, map, [&](auto& map, const auto& key) { + auto res = map.FindHelper(key); + if (res.node == nullptr) { + return false; + } + if (val != nullptr) { + val->SetValue(map.GetVoidValue(res.node)); + } + return true; + }); } size_t MapFieldBase::SpaceUsedExcludingSelfNoLockImpl(const MapFieldBase& map) { diff --git a/src/google/protobuf/map_field.h b/src/google/protobuf/map_field.h index 5b1874ee4d4a..29fd6718ca06 100644 --- a/src/google/protobuf/map_field.h +++ b/src/google/protobuf/map_field.h @@ -285,6 +285,14 @@ template struct MapDynamicFieldInfo; struct MapFieldTestPeer; +// Return the prototype message for a Map entry. +// REQUIRES: `default_entry` is a map entry message. +// REQUIRES: mapped_type is of type message. +inline const Message& GetMapEntryValuePrototype(const Message& default_entry) { + return default_entry.GetReflection()->GetMessage( + default_entry, default_entry.GetDescriptor()->map_value()); +} + // This class provides access to map field using reflection, which is the same // as those provided for RepeatedPtrField. It is used for internal // reflection implementation only. Users should never use this directly. @@ -302,8 +310,9 @@ class PROTOBUF_EXPORT MapFieldBase : public MapFieldBaseForParse { ~MapFieldBase(); struct VTable : MapFieldBaseForParse::VTable { - bool (*lookup_map_value)(const MapFieldBase& map, const MapKey& map_key, - MapValueConstRef* val); + bool (*lookup_map_value_no_sync)(const MapFieldBase& map, + const MapKey& map_key, + MapValueConstRef* val); bool (*delete_map_value)(MapFieldBase& map, const MapKey& map_key); void (*set_map_iterator_value)(MapIterator* map_iter); bool (*insert_or_lookup_no_sync)(MapFieldBase& map, const MapKey& map_key, @@ -321,7 +330,7 @@ class PROTOBUF_EXPORT MapFieldBase : public MapFieldBaseForParse { static constexpr VTable MakeVTable() { VTable out{}; out.get_map = &T::GetMapImpl; - out.lookup_map_value = &T::LookupMapValueImpl; + out.lookup_map_value_no_sync = &T::LookupMapValueNoSyncImpl; out.delete_map_value = &T::DeleteMapValueImpl; out.set_map_iterator_value = &T::SetMapIteratorValueImpl; out.insert_or_lookup_no_sync = &T::InsertOrLookupMapValueNoSyncImpl; @@ -349,7 +358,8 @@ class PROTOBUF_EXPORT MapFieldBase : public MapFieldBaseForParse { return LookupMapValue(map_key, static_cast(nullptr)); } bool LookupMapValue(const MapKey& map_key, MapValueConstRef* val) const { - return vtable()->lookup_map_value(*this, map_key, val); + SyncMapWithRepeatedField(); + return vtable()->lookup_map_value_no_sync(*this, map_key, val); } bool LookupMapValue(const MapKey&, MapValueRef*) const = delete; @@ -503,8 +513,13 @@ class PROTOBUF_EXPORT MapFieldBase : public MapFieldBaseForParse { bool is_mutable); static void ClearMapNoSyncImpl(MapFieldBase& self); static void SetMapIteratorValueImpl(MapIterator* map_iter); - static bool LookupMapValueImpl(const MapFieldBase& self, - const MapKey& map_key, MapValueConstRef* val); + static bool LookupMapValueNoSyncImpl(const MapFieldBase& self, + const MapKey& map_key, + MapValueConstRef* val); + static bool InsertOrLookupMapValueNoSyncImpl(MapFieldBase& self, + const MapKey& map_key, + MapValueRef* val); + static bool DeleteMapValueImpl(MapFieldBase& self, const MapKey& map_key); private: friend class ContendedMapCleanTest; @@ -513,6 +528,21 @@ class PROTOBUF_EXPORT MapFieldBase : public MapFieldBaseForParse { friend class google::protobuf::Reflection; friend class google::protobuf::DynamicMessage; + template + void InitializeKeyValue(T* v, const U&... init) { + ::new (static_cast(v)) T(init...); + if constexpr (std::is_same_v) { + if (arena() != nullptr) { + arena()->OwnDestructor(v); + } + } + } + + void InitializeKeyValue(MessageLite* msg) { + GetClassData(GetMapEntryValuePrototype(*GetPrototype())) + ->PlacementNew(msg, arena()); + } + // See assertion in TypeDefinedMapFieldBase::TypeDefinedMapFieldBase() const UntypedMapBase& GetMapRaw() const { return *reinterpret_cast(this + 1); @@ -614,11 +644,6 @@ class TypeDefinedMapFieldBase : public MapFieldBase { using Iter = typename Map::const_iterator; - static bool DeleteMapValueImpl(MapFieldBase& map, const MapKey& map_key); - static bool InsertOrLookupMapValueNoSyncImpl(MapFieldBase& map, - const MapKey& map_key, - MapValueRef* val); - static void MergeFromImpl(MapFieldBase& base, const MapFieldBase& other); static void SwapImpl(MapFieldBase& lhs, MapFieldBase& rhs); static void UnsafeShallowSwapImpl(MapFieldBase& lhs, MapFieldBase& rhs); diff --git a/src/google/protobuf/map_field_inl.h b/src/google/protobuf/map_field_inl.h index 45c782bb1ae9..7e6f3b81bb0b 100644 --- a/src/google/protobuf/map_field_inl.h +++ b/src/google/protobuf/map_field_inl.h @@ -33,51 +33,7 @@ namespace google { namespace protobuf { namespace internal { -// UnwrapMapKey template. We're using overloading rather than template -// specialization so that we can return a value or reference type depending on -// `T`. -inline int32_t UnwrapMapKeyImpl(const MapKey& map_key, const int32_t*) { - return map_key.GetInt32Value(); -} -inline uint32_t UnwrapMapKeyImpl(const MapKey& map_key, const uint32_t*) { - return map_key.GetUInt32Value(); -} -inline int64_t UnwrapMapKeyImpl(const MapKey& map_key, const int64_t*) { - return map_key.GetInt64Value(); -} -inline uint64_t UnwrapMapKeyImpl(const MapKey& map_key, const uint64_t*) { - return map_key.GetUInt64Value(); -} -inline bool UnwrapMapKeyImpl(const MapKey& map_key, const bool*) { - return map_key.GetBoolValue(); -} -inline absl::string_view UnwrapMapKeyImpl(const MapKey& map_key, - const std::string*) { - return map_key.GetStringValue(); -} - -template -decltype(auto) UnwrapMapKey(const MapKey& map_key) { - return UnwrapMapKeyImpl(map_key, static_cast(nullptr)); -} - // ------------------------TypeDefinedMapFieldBase--------------- -template -bool TypeDefinedMapFieldBase::InsertOrLookupMapValueNoSyncImpl( - MapFieldBase& map, const MapKey& map_key, MapValueRef* val) { - auto res = static_cast(map).map_.try_emplace( - UnwrapMapKey(map_key)); - val->SetValue(&res.first->second); - return res.second; -} - -template -bool TypeDefinedMapFieldBase::DeleteMapValueImpl( - MapFieldBase& map, const MapKey& map_key) { - return static_cast(map).MutableMap()->erase( - UnwrapMapKey(map_key)); -} - template void TypeDefinedMapFieldBase::SwapImpl(MapFieldBase& lhs, MapFieldBase& rhs) {