From a3f5af066fcb16bc2e90e28c006597aef1904eab Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Tue, 20 Feb 2024 09:35:45 -0800 Subject: [PATCH] Add GetStringView() and GetRepeatedStringView() with scratch. PiperOrigin-RevId: 608636288 --- .../protobuf/generated_message_reflection.cc | 44 +++++++++ .../generated_message_reflection_unittest.cc | 93 +++++++++++++++++++ src/google/protobuf/message.h | 38 ++++++++ src/google/protobuf/string_view_test.cc | 12 +++ 4 files changed, 187 insertions(+) diff --git a/src/google/protobuf/generated_message_reflection.cc b/src/google/protobuf/generated_message_reflection.cc index 340bcaad706b9..3f328dd642550 100644 --- a/src/google/protobuf/generated_message_reflection.cc +++ b/src/google/protobuf/generated_message_reflection.cc @@ -1868,6 +1868,32 @@ absl::Cord Reflection::GetCord(const Message& message, } } +absl::string_view Reflection::GetStringView(const Message& message, + const FieldDescriptor* field, + ScratchSpace& scratch) const { + USAGE_CHECK_ALL(GetStringView, SINGULAR, STRING); + + if (field->is_extension()) { + return GetExtensionSet(message).GetString(field->number(), + field->default_value_string()); + } + if (schema_.InRealOneof(field) && !HasOneofField(message, field)) { + return field->default_value_string(); + } + + switch (internal::cpp::EffectiveStringCType(field)) { + case FieldOptions::CORD: { + const auto& cord = schema_.InRealOneof(field) + ? *GetField(message, field) + : GetField(message, field); + return scratch.CopyFromCord(cord); + } + default: + auto str = GetField(message, field); + return str.IsDefault() ? field->default_value_string() : str.Get(); + } +} + void Reflection::SetString(Message* message, const FieldDescriptor* field, std::string value) const { @@ -2002,6 +2028,24 @@ const std::string& Reflection::GetRepeatedStringReference( } } +// See GetStringView(), above. +absl::string_view Reflection::GetRepeatedStringView( + const Message& message, const FieldDescriptor* field, int index, + ScratchSpace& scratch) const { + (void)scratch; + USAGE_CHECK_ALL(GetRepeatedStringView, REPEATED, STRING); + + if (field->is_extension()) { + return GetExtensionSet(message).GetRepeatedString(field->number(), index); + } + + switch (internal::cpp::EffectiveStringCType(field)) { + case FieldOptions::STRING: + default: + return GetRepeatedPtrField(message, field, index); + } +} + void Reflection::SetRepeatedString(Message* message, const FieldDescriptor* field, int index, diff --git a/src/google/protobuf/generated_message_reflection_unittest.cc b/src/google/protobuf/generated_message_reflection_unittest.cc index 4925f53ada434..7879dafeabe60 100644 --- a/src/google/protobuf/generated_message_reflection_unittest.cc +++ b/src/google/protobuf/generated_message_reflection_unittest.cc @@ -162,6 +162,99 @@ TEST(GeneratedMessageReflectionTest, GetStringReferenceCopy) { &cord_scratch)); } +TEST(GeneratedMessageReflectionTest, GetStringView) { + unittest::TestAllTypes message; + TestUtil::SetAllFields(&message); + + const Reflection* reflection = message.GetReflection(); + Reflection::ScratchSpace scratch; + + EXPECT_EQ("115", + reflection->GetStringView(message, F("optional_string"), scratch)); + EXPECT_EQ("124", reflection->GetStringView( + message, F("optional_string_piece"), scratch)); + EXPECT_EQ("125", + reflection->GetStringView(message, F("optional_cord"), scratch)); +} + +TEST(GeneratedMessageReflectionTest, GetStringViewWithExtensions) { + unittest::TestAllExtensions message; + google::protobuf::FileDescriptor const* descriptor_file = + message.GetDescriptor()->file(); + google::protobuf::FieldDescriptor const* string_ext = + descriptor_file->FindExtensionByName("optional_string_extension"); + google::protobuf::FieldDescriptor const* string_piece_ext = + descriptor_file->FindExtensionByName("optional_string_piece_extension"); + google::protobuf::FieldDescriptor const* cord_ext = + descriptor_file->FindExtensionByName("optional_cord_extension"); + message.SetExtension(protobuf_unittest::optional_string_extension, "foo"); + message.SetExtension(protobuf_unittest::optional_string_piece_extension, "bar"); + message.SetExtension(protobuf_unittest::optional_cord_extension, "baz"); + const Reflection* reflection = message.GetReflection(); + Reflection::ScratchSpace scratch; + + EXPECT_EQ("foo", reflection->GetStringView(message, string_ext, scratch)); + EXPECT_EQ("bar", + reflection->GetStringView(message, string_piece_ext, scratch)); + EXPECT_EQ("baz", reflection->GetStringView(message, cord_ext, scratch)); +} + +TEST(GeneratedMessageReflectionTest, GetStringViewWithOneof) { + unittest::TestOneof2 message; + const Reflection* reflection = message.GetReflection(); + const FieldDescriptor* string_field = + message.GetDescriptor()->FindFieldByName("foo_string"); + const FieldDescriptor* string_piece_field = + message.GetDescriptor()->FindFieldByName("foo_string_piece"); + Reflection::ScratchSpace scratch; + + message.set_foo_string("foo"); + EXPECT_EQ("foo", reflection->GetStringView(message, string_field, scratch)); + EXPECT_EQ("", + reflection->GetStringView(message, string_piece_field, scratch)); + +} + +TEST(GeneratedMessageReflectionTest, GetRepeatedStringView) { + unittest::TestAllTypes message; + TestUtil::AddRepeatedFields1(&message); + TestUtil::AddRepeatedFields2(&message); + + const Reflection* reflection = message.GetReflection(); + Reflection::ScratchSpace scratch; + + EXPECT_EQ("215", reflection->GetRepeatedStringView( + message, F("repeated_string"), 0, scratch)); + EXPECT_EQ("224", reflection->GetRepeatedStringView( + message, F("repeated_string_piece"), 0, scratch)); + EXPECT_EQ("225", reflection->GetRepeatedStringView( + message, F("repeated_cord"), 0, scratch)); +} + +TEST(GeneratedMessageReflectionTest, GetRepeatedStringViewWithExtensions) { + unittest::TestAllExtensions message; + google::protobuf::FileDescriptor const* descriptor_file = + message.GetDescriptor()->file(); + google::protobuf::FieldDescriptor const* string_ext = + descriptor_file->FindExtensionByName("repeated_string_extension"); + google::protobuf::FieldDescriptor const* string_piece_ext = + descriptor_file->FindExtensionByName("repeated_string_piece_extension"); + google::protobuf::FieldDescriptor const* cord_ext = + descriptor_file->FindExtensionByName("repeated_cord_extension"); + message.AddExtension(protobuf_unittest::repeated_string_extension, "foo"); + message.AddExtension(protobuf_unittest::repeated_string_piece_extension, "bar"); + message.AddExtension(protobuf_unittest::repeated_cord_extension, "baz"); + const Reflection* reflection = message.GetReflection(); + Reflection::ScratchSpace scratch; + + EXPECT_EQ("foo", + reflection->GetRepeatedStringView(message, string_ext, 0, scratch)); + EXPECT_EQ("bar", reflection->GetRepeatedStringView(message, string_piece_ext, + 0, scratch)); + EXPECT_EQ("baz", + reflection->GetRepeatedStringView(message, cord_ext, 0, scratch)); +} + class GeneratedMessageReflectionSwapTest : public testing::TestWithParam { protected: diff --git a/src/google/protobuf/message.h b/src/google/protobuf/message.h index 2d08e9a8b1d76..e2a8d418055bc 100644 --- a/src/google/protobuf/message.h +++ b/src/google/protobuf/message.h @@ -89,6 +89,7 @@ #include #include +#include #include #include #include @@ -97,6 +98,7 @@ #include "absl/base/call_once.h" #include "google/protobuf/stubs/common.h" #include "absl/log/absl_check.h" +#include "absl/memory/memory.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "google/protobuf/arena.h" @@ -585,6 +587,37 @@ class PROTOBUF_EXPORT Reflection final { absl::Cord GetCord(const Message& message, const FieldDescriptor* field) const; + // Enables GetStringView() and GetRepeatedStringView() APIs to return + // absl::string_view even though the underlying implementation doesn't have + // contiguous bytes; e.g. absl::Cord. + class ScratchSpace { + public: + ScratchSpace() = default; + + ScratchSpace(const ScratchSpace&) = delete; + ScratchSpace& operator=(const ScratchSpace&) = delete; + + private: + friend class Reflection; + + absl::string_view CopyFromCord(const absl::Cord& cord) { + if (!buffer_) { + buffer_ = absl::make_unique(); + } + absl::CopyCordToString(cord, buffer_.get()); + return *buffer_; + } + + std::unique_ptr buffer_; + }; + + // Returns a view into the contents of a string field. "scratch" is used to + // flatten bytes if it is non-contiguous. The lifetime of absl::string_view is + // either tied to "message" (contiguous) or "scratch" (otherwise). + absl::string_view GetStringView( + const Message& message, const FieldDescriptor* field, + ScratchSpace& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + // Singular field mutators ----------------------------------------- // These mutate the value of a non-repeated field. @@ -706,6 +739,11 @@ class PROTOBUF_EXPORT Reflection final { int index, std::string* scratch) const; + // See GetStringView(), above. + absl::string_view GetRepeatedStringView( + const Message& message, const FieldDescriptor* field, int index, + ScratchSpace& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) const; + // Repeated field mutators ----------------------------------------- // These mutate the value of one element of a repeated field. diff --git a/src/google/protobuf/string_view_test.cc b/src/google/protobuf/string_view_test.cc index 4415bd5ee4288..1cdb860d94416 100644 --- a/src/google/protobuf/string_view_test.cc +++ b/src/google/protobuf/string_view_test.cc @@ -109,6 +109,9 @@ TEST(StringViewFieldTest, SingularSetAndGetByReflection) { reflection->SetString(&message, field, std::string{STRING_PAYLOAD}); EXPECT_THAT(reflection->GetString(message, field), StrEq(STRING_PAYLOAD)); + Reflection::ScratchSpace scratch; + EXPECT_THAT(reflection->GetStringView(message, field, scratch), + StrEq(STRING_PAYLOAD)); EXPECT_THAT(message.singular_string(), StrEq(STRING_PAYLOAD)); } @@ -292,6 +295,15 @@ TEST(StringViewFieldTest, RepeatedSetAndGetByReflection) { EXPECT_THAT(rep_str, ElementsAre("000000000000", "111111111111", "222222222222")); } + + // GetRepeatedStringView. + Reflection::ScratchSpace scratch; + EXPECT_THAT(reflection->GetRepeatedStringView(message, field, 0, scratch), + StrEq("000000000000")); + EXPECT_THAT(reflection->GetRepeatedStringView(message, field, 1, scratch), + StrEq("111111111111")); + EXPECT_THAT(reflection->GetRepeatedStringView(message, field, 2, scratch), + StrEq("222222222222")); } } // namespace