Skip to content

Commit

Permalink
Add prefetching of subsequent extensions in ExtensionSet::ForEach.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671457336
  • Loading branch information
ezbr authored and copybara-github committed Sep 5, 2024
1 parent f72e5ce commit 9b019ee
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 131 deletions.
63 changes: 46 additions & 17 deletions src/google/protobuf/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ void ExtensionSet::RegisterMessageExtension(const MessageLite* extendee,
ExtensionSet::~ExtensionSet() {
// Deletes all allocated extensions.
if (arena_ == nullptr) {
ForEach([](int /* number */, Extension& ext) { ext.Free(); });
ForEach([](int /* number */, Extension& ext) { ext.Free(); },
PrefetchNta{});
if (PROTOBUF_PREDICT_FALSE(is_large())) {
delete map_.large;
} else {
Expand Down Expand Up @@ -225,7 +226,7 @@ bool ExtensionSet::HasLazy(int number) const {

int ExtensionSet::NumExtensions() const {
int result = 0;
ForEach([&result](int /* number */, const Extension& ext) {
ForEachNoPrefetch([&result](int /* number */, const Extension& ext) {
if (!ext.is_cleared) {
++result;
}
Expand Down Expand Up @@ -308,6 +309,7 @@ enum { REPEATED_FIELD, OPTIONAL_FIELD };
ABSL_DCHECK_EQ(cpp_type(extension->type), \
WireFormatLite::CPPTYPE_##UPPERCASE); \
extension->is_repeated = false; \
extension->is_pointer = false; \
} else { \
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, UPPERCASE); \
} \
Expand Down Expand Up @@ -351,6 +353,7 @@ enum { REPEATED_FIELD, OPTIONAL_FIELD };
ABSL_DCHECK_EQ(cpp_type(extension->type), \
WireFormatLite::CPPTYPE_##UPPERCASE); \
extension->is_repeated = true; \
extension->is_pointer = true; \
extension->is_packed = packed; \
extension->ptr.repeated_##LOWERCASE##_value = \
Arena::Create<RepeatedField<LOWERCASE>>(arena_); \
Expand Down Expand Up @@ -391,6 +394,7 @@ void* ExtensionSet::MutableRawRepeatedField(int number, FieldType field_type,
// extension.
if (MaybeNewExtension(number, desc, &extension)) {
extension->is_repeated = true;
extension->is_pointer = true;
extension->type = field_type;
extension->is_packed = packed;

Expand Down Expand Up @@ -487,6 +491,7 @@ void ExtensionSet::SetEnum(int number, FieldType type, int value,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
extension->is_repeated = false;
extension->is_pointer = false;
} else {
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, ENUM);
}
Expand Down Expand Up @@ -522,6 +527,7 @@ void ExtensionSet::AddEnum(int number, FieldType type, bool packed, int value,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_ENUM);
extension->is_repeated = true;
extension->is_pointer = true;
extension->is_packed = packed;
extension->ptr.repeated_enum_value =
Arena::Create<RepeatedField<int>>(arena_);
Expand Down Expand Up @@ -554,6 +560,7 @@ std::string* ExtensionSet::MutableString(int number, FieldType type,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
extension->is_repeated = false;
extension->is_pointer = true;
extension->ptr.string_value = Arena::Create<std::string>(arena_);
} else {
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, STRING);
Expand Down Expand Up @@ -584,6 +591,7 @@ std::string* ExtensionSet::AddString(int number, FieldType type,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_STRING);
extension->is_repeated = true;
extension->is_pointer = true;
extension->is_packed = false;
extension->ptr.repeated_string_value =
Arena::Create<RepeatedPtrField<std::string>>(arena_);
Expand Down Expand Up @@ -626,6 +634,7 @@ MessageLite* ExtensionSet::MutableMessage(int number, FieldType type,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
extension->is_repeated = false;
extension->is_pointer = true;
extension->is_lazy = false;
extension->ptr.message_value = prototype.New(arena_);
extension->is_cleared = false;
Expand Down Expand Up @@ -663,6 +672,7 @@ void ExtensionSet::SetAllocatedMessage(int number, FieldType type,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
extension->is_repeated = false;
extension->is_pointer = true;
extension->is_lazy = false;
if (message_arena == arena) {
extension->ptr.message_value = message;
Expand Down Expand Up @@ -707,6 +717,7 @@ void ExtensionSet::UnsafeArenaSetAllocatedMessage(
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
extension->is_repeated = false;
extension->is_pointer = true;
extension->is_lazy = false;
extension->ptr.message_value = message;
} else {
Expand Down Expand Up @@ -805,6 +816,7 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type,
extension->type = type;
ABSL_DCHECK_EQ(cpp_type(extension->type), WireFormatLite::CPPTYPE_MESSAGE);
extension->is_repeated = true;
extension->is_pointer = true;
extension->ptr.repeated_message_value =
Arena::Create<RepeatedPtrField<MessageLite>>(arena_);
} else {
Expand Down Expand Up @@ -920,7 +932,7 @@ void ExtensionSet::SwapElements(int number, int index1, int index2) {
// ===================================================================

void ExtensionSet::Clear() {
ForEach([](int /* number */, Extension& ext) { ext.Clear(); });
ForEach([](int /* number */, Extension& ext) { ext.Clear(); }, Prefetch{});
}

namespace {
Expand Down Expand Up @@ -969,9 +981,11 @@ void ExtensionSet::MergeFrom(const MessageLite* extendee,
other.map_.large->end()));
}
}
other.ForEach([extendee, this, &other](int number, const Extension& ext) {
this->InternalExtensionMergeFrom(extendee, number, ext, other.arena_);
});
other.ForEach(
[extendee, this, &other](int number, const Extension& ext) {
this->InternalExtensionMergeFrom(extendee, number, ext, other.arena_);
},
Prefetch{});
}

void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
Expand All @@ -987,6 +1001,7 @@ void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
extension->type = other_extension.type;
extension->is_packed = other_extension.is_packed;
extension->is_repeated = true;
extension->is_pointer = true;
} else {
ABSL_DCHECK_EQ(extension->type, other_extension.type);
ABSL_DCHECK_EQ(extension->is_packed, other_extension.is_packed);
Expand Down Expand Up @@ -1049,6 +1064,7 @@ void ExtensionSet::InternalExtensionMergeFrom(const MessageLite* extendee,
extension->type = other_extension.type;
extension->is_packed = other_extension.is_packed;
extension->is_repeated = false;
extension->is_pointer = true;
if (other_extension.is_lazy) {
extension->is_lazy = true;
extension->ptr.lazymessage_value =
Expand Down Expand Up @@ -1226,6 +1242,13 @@ const char* ExtensionSet::ParseMessageSetItem(
metadata, ctx);
}

// Reports whether a singular extension of wire type `type` is one of the
// string-like (STRING, BYTES) or message-like (GROUP, MESSAGE) types. These
// are the singular types for which `is_pointer` is set in this file; repeated
// extensions are handled separately by callers.
bool ExtensionSet::FieldTypeIsPointer(FieldType type) {
  const bool is_string_like = type == WireFormatLite::TYPE_STRING ||
                              type == WireFormatLite::TYPE_BYTES;
  const bool is_message_like = type == WireFormatLite::TYPE_GROUP ||
                               type == WireFormatLite::TYPE_MESSAGE;
  return is_string_like || is_message_like;
}

uint8_t* ExtensionSet::_InternalSerializeImpl(
const MessageLite* extendee, int start_field_number, int end_field_number,
uint8_t* target, io::EpsCopyOutputStream* stream) const {
Expand All @@ -1252,19 +1275,23 @@ uint8_t* ExtensionSet::InternalSerializeMessageSetWithCachedSizesToArray(
const MessageLite* extendee, uint8_t* target,
io::EpsCopyOutputStream* stream) const {
const ExtensionSet* extension_set = this;
ForEach([&target, extendee, stream, extension_set](int number,
const Extension& ext) {
target = ext.InternalSerializeMessageSetItemWithCachedSizesToArray(
extendee, extension_set, number, target, stream);
});
ForEach(
[&target, extendee, stream, extension_set](int number,
const Extension& ext) {
target = ext.InternalSerializeMessageSetItemWithCachedSizesToArray(
extendee, extension_set, number, target, stream);
},
Prefetch{});
return target;
}

size_t ExtensionSet::ByteSize() const {
size_t total_size = 0;
ForEach([&total_size](int number, const Extension& ext) {
total_size += ext.ByteSize(number);
});
ForEach(
[&total_size](int number, const Extension& ext) {
total_size += ext.ByteSize(number);
},
Prefetch{});
return total_size;
}

Expand Down Expand Up @@ -1932,9 +1959,11 @@ size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {

size_t ExtensionSet::MessageSetByteSize() const {
size_t total_size = 0;
ForEach([&total_size](int number, const Extension& ext) {
total_size += ext.MessageSetItemByteSize(number);
});
ForEach(
[&total_size](int number, const Extension& ext) {
total_size += ext.MessageSetItemByteSize(number);
},
Prefetch{});
return total_size;
}

Expand Down
96 changes: 85 additions & 11 deletions src/google/protobuf/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

#include "google/protobuf/stubs/common.h"
#include "absl/base/call_once.h"
#include "absl/base/casts.h"
#include "absl/base/prefetch.h"
#include "absl/container/btree_map.h"
#include "absl/log/absl_check.h"
#include "google/protobuf/internal_visibility.h"
Expand Down Expand Up @@ -555,6 +557,8 @@ class PROTOBUF_EXPORT ExtensionSet {

friend void internal::InitializeLazyExtensionSet();

static bool FieldTypeIsPointer(FieldType type);

const int32_t& GetRefInt32(int number, const int32_t& default_value) const;
const int64_t& GetRefInt64(int number, const int64_t& default_value) const;
const uint32_t& GetRefUInt32(int number, const uint32_t& default_value) const;
Expand Down Expand Up @@ -670,6 +674,12 @@ class PROTOBUF_EXPORT ExtensionSet {
size_t SpaceUsedExcludingSelfLong() const;
bool IsInitialized(const ExtensionSet* ext_set, const MessageLite* extendee,
int number, Arena* arena) const;
const void* PrefetchPtr() const {
ABSL_DCHECK_EQ(is_pointer, is_repeated || FieldTypeIsPointer(type));
// We don't want to prefetch invalid/null pointers so if there isn't a
// pointer to prefetch, then return `this`.
return is_pointer ? absl::bit_cast<const void*>(ptr) : this;
}

// The order of these fields packs Extension into 24 bytes when using 8
// byte alignment. Consider this when adding or removing fields here.
Expand Down Expand Up @@ -708,20 +718,23 @@ class PROTOBUF_EXPORT ExtensionSet {
FieldType type;
bool is_repeated;

// Whether the extension is a pointer. This is used for prefetching.
bool is_pointer : 1;

// For singular types, indicates if the extension is "cleared". This
// happens when an extension is set and then later cleared by the caller.
// We want to keep the Extension object around for reuse, so instead of
// removing it from the map, we just set is_cleared = true. This has no
// meaning for repeated types; for those, the size of the RepeatedField
// simply becomes zero when cleared.
bool is_cleared : 4;
bool is_cleared : 1;

// For singular message types, indicates whether lazy parsing is enabled
// for this extension. This field is only valid when type == TYPE_MESSAGE
// and !is_repeated because we only support lazy parsing for singular
// message types currently. If is_lazy = true, the extension is stored in
// lazymessage_value. Otherwise, the extension will be message_value.
bool is_lazy : 4;
bool is_lazy : 1;

// For repeated types, this indicates if the [packed=true] option is set.
bool is_packed;
Expand Down Expand Up @@ -779,32 +792,93 @@ class PROTOBUF_EXPORT ExtensionSet {
return PROTOBUF_PREDICT_FALSE(is_large()) ? map_.large->size() : flat_size_;
}

// For use as `PrefetchFunctor`s in `ForEach`.
// Prefetch functor that forwards the given address to
// absl::PrefetchToLocalCache.
struct Prefetch {
  void operator()(const void* ptr) const {
    absl::PrefetchToLocalCache(ptr);
  }
};
// Prefetch functor that forwards the given address to
// absl::PrefetchToLocalCacheNta (see Abseil's prefetch.h for the exact
// caching hint this variant carries).
struct PrefetchNta {
  void operator()(const void* ptr) const {
    absl::PrefetchToLocalCacheNta(ptr);
  }
};

// Applies `func(key, value)` to every element of [it, end) in iteration order
// while running `prefetch_func` on the `PrefetchPtr()` of elements up to
// kPrefetchDistance positions ahead of the one being visited. Every element is
// passed to `prefetch_func` exactly once and to `func` exactly once.
// Returns `func` (moved) so that callers can recover any state it accumulated.
template <typename Iterator, typename KeyValueFunctor,
          typename PrefetchFunctor>
static KeyValueFunctor ForEachPrefetchImpl(Iterator it, Iterator end,
                                           KeyValueFunctor func,
                                           PrefetchFunctor prefetch_func) {
  // Note: based on arena's ChunkList::Cleanup().
  // Prefetch distance 16 performs better than 8 in load tests.
  constexpr int kPrefetchDistance = 16;

  // Phase 1: warm up by prefetching the leading kPrefetchDistance elements
  // (or all of them when the range is shorter) before visiting anything.
  Iterator ahead = it;
  int warmed = 0;
  while (ahead != end && warmed < kPrefetchDistance) {
    prefetch_func(ahead->second.PrefetchPtr());
    ++ahead;
    ++warmed;
  }

  // Phase 2: steady state. Visit the current element, then prefetch the one
  // that sits kPrefetchDistance positions further along.
  while (ahead != end) {
    func(it->first, it->second);
    prefetch_func(ahead->second.PrefetchPtr());
    ++it;
    ++ahead;
  }

  // Phase 3: drain. Everything left has already been prefetched, so only the
  // visitor runs.
  while (it != end) {
    func(it->first, it->second);
    ++it;
  }
  return std::move(func);
}

// Similar to std::for_each, but with look-ahead prefetching.
// Visits the <int, Extension&> pairs in sorted order, calling `func` on each
// one while `prefetch_func` is applied to extensions further ahead. Iterators
// are only accessed through ->first and ->second, so `func` does not need to
// care whether the storage yields KeyValue or std::pair elements.
// Returns `func` after the traversal.
template <typename KeyValueFunctor, typename PrefetchFunctor>
KeyValueFunctor ForEach(KeyValueFunctor func, PrefetchFunctor prefetch_func) {
  if (PROTOBUF_PREDICT_FALSE(is_large())) {
    // Large representation: walk the map.
    return ForEachPrefetchImpl(map_.large->begin(), map_.large->end(),
                               std::move(func), std::move(prefetch_func));
  } else {
    // Flat representation: walk the [flat_begin(), flat_end()) range.
    return ForEachPrefetchImpl(flat_begin(), flat_end(), std::move(func),
                               std::move(prefetch_func));
  }
}
// Const overload of the prefetching ForEach: calls `func` on each extension
// entry, with `prefetch_func` applied to entries further ahead, and returns
// `func` afterwards.
template <typename KeyValueFunctor, typename PrefetchFunctor>
KeyValueFunctor ForEach(KeyValueFunctor func,
                        PrefetchFunctor prefetch_func) const {
  if (PROTOBUF_PREDICT_FALSE(is_large())) {
    // Large representation: walk the map.
    return ForEachPrefetchImpl(map_.large->begin(), map_.large->end(),
                               std::move(func), std::move(prefetch_func));
  } else {
    // Flat representation: walk the [flat_begin(), flat_end()) range.
    return ForEachPrefetchImpl(flat_begin(), flat_end(), std::move(func),
                               std::move(prefetch_func));
  }
}

// As above, but without prefetching. This is for use in cases where we never
// use the pointed-to extension values in `func`.
template <typename Iterator, typename KeyValueFunctor>
static KeyValueFunctor ForEach(Iterator begin, Iterator end,
KeyValueFunctor func) {
static KeyValueFunctor ForEachNoPrefetch(Iterator begin, Iterator end,
KeyValueFunctor func) {
for (Iterator it = begin; it != end; ++it) func(it->first, it->second);
return std::move(func);
}

// Applies a functor to the <int, Extension&> pairs in sorted order.
template <typename KeyValueFunctor>
KeyValueFunctor ForEach(KeyValueFunctor func) {
KeyValueFunctor ForEachNoPrefetch(KeyValueFunctor func) {
if (PROTOBUF_PREDICT_FALSE(is_large())) {
return ForEach(map_.large->begin(), map_.large->end(), std::move(func));
return ForEachNoPrefetch(map_.large->begin(), map_.large->end(),
std::move(func));
}
return ForEach(flat_begin(), flat_end(), std::move(func));
return ForEachNoPrefetch(flat_begin(), flat_end(), std::move(func));
}

// Applies a functor to the <int, const Extension&> pairs in sorted order.
// As above, but const.
template <typename KeyValueFunctor>
KeyValueFunctor ForEach(KeyValueFunctor func) const {
KeyValueFunctor ForEachNoPrefetch(KeyValueFunctor func) const {
if (PROTOBUF_PREDICT_FALSE(is_large())) {
return ForEach(map_.large->begin(), map_.large->end(), std::move(func));
return ForEachNoPrefetch(map_.large->begin(), map_.large->end(),
std::move(func));
}
return ForEach(flat_begin(), flat_end(), std::move(func));
return ForEachNoPrefetch(flat_begin(), flat_end(), std::move(func));
}

// Merges existing Extension from other_extension
Expand Down
Loading

0 comments on commit 9b019ee

Please sign in to comment.