diff --git a/src/google/protobuf/extension_set_inl.h b/src/google/protobuf/extension_set_inl.h index 95936cc2439b..e4e711721d4d 100644 --- a/src/google/protobuf/extension_set_inl.h +++ b/src/google/protobuf/extension_set_inl.h @@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata, internal::ParseContext* ctx) { std::string payload; - uint32_t type_id = 0; - bool payload_read = false; + uint32_t type_id; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + while (!ctx->Done(&ptr)) { uint32_t tag = static_cast(*ptr++); if (tag == WireFormatLite::kMessageSetTypeIdTag) { uint64_t tmp; ptr = ParseBigVarint(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - type_id = tmp; - if (payload_read) { + if (state == State::kNoTag) { + type_id = tmp; + state = State::kHasType; + } else if (state == State::kHasPayload) { + type_id = tmp; ExtensionInfo extension; bool was_packed_on_wire; if (!FindExtension(2, type_id, extendee, ctx, &extension, @@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl( GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && tmp_ctx.EndedAtLimit()); } - type_id = 0; + state = State::kDone; } } else if (tag == WireFormatLite::kMessageSetMessageTag) { - if (type_id != 0) { + if (state == State::kHasType) { ptr = ParseFieldMaybeLazily(static_cast(type_id) * 8 + 2, ptr, extendee, metadata, ctx); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); - type_id = 0; + state = State::kDone; } else { + std::string tmp; int32_t size = ReadSize(&ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - ptr = ctx->ReadString(ptr, size, &payload); + ptr = ctx->ReadString(ptr, size, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - payload_read = true; + if (state == State::kNoTag) { + payload = std::move(tmp); + state = State::kHasPayload; + } } } else { ptr = ReadTag(ptr - 1, &tag); diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index e44c6ebeee7f..6fe63c86cebb 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser { const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) { // Parse a MessageSetItem auto metadata = reflection->MutableInternalMetadata(msg); + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + std::string payload; uint32_t type_id = 0; - bool payload_read = false; while (!ctx->Done(&ptr)) { // We use 64 bit tags in order to allow typeid's that span the whole // range of 32 bit numbers. @@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser { uint64_t tmp; ptr = ParseBigVarint(ptr, &tmp); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - type_id = tmp; - if (payload_read) { + if (state == State::kNoTag) { + type_id = tmp; + state = State::kHasType; + } else if (state == State::kHasPayload) { + type_id = tmp; const FieldDescriptor* field; if (ctx->data().pool == nullptr) { field = reflection->FindKnownExtensionByNumber(type_id); @@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser { GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) && tmp_ctx.EndedAtLimit()); } - type_id = 0; + state = State::kDone; } continue; } else if (tag == WireFormatLite::kMessageSetMessageTag) { - if (type_id == 0) { + if (state == State::kNoTag) { int32_t size = ReadSize(&ptr); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); ptr = ctx->ReadString(ptr, size, &payload); GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); - payload_read = true; - } else { + state = State::kHasPayload; + } else if (state == State::kHasType) { // We're now parsing the payload const FieldDescriptor* field = nullptr; if (descriptor->IsExtensionNumber(type_id)) { @@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser { ptr = WireFormat::_InternalParseAndMergeField( msg, ptr, ctx, static_cast(type_id) * 8 + 2, reflection, field); - type_id = 0; + state = State::kDone; + } else { + int32_t size = ReadSize(&ptr); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ptr = ctx->Skip(ptr, size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); } } else { // An unknown field in MessageSetItem. diff --git a/src/google/protobuf/wire_format_lite.h b/src/google/protobuf/wire_format_lite.h index 32fe0c7711bc..ef83aeb1d422 100644 --- a/src/google/protobuf/wire_format_lite.h +++ b/src/google/protobuf/wire_format_lite.h @@ -1830,6 +1830,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { // we can parse it later. std::string message_data; + enum class State { kNoTag, kHasType, kHasPayload, kDone }; + State state = State::kNoTag; + while (true) { const uint32_t tag = input->ReadTagNoLastTag(); if (tag == 0) return false; @@ -1838,26 +1841,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { case WireFormatLite::kMessageSetTypeIdTag: { uint32_t type_id; if (!input->ReadVarint32(&type_id)) return false; - last_type_id = type_id; - - if (!message_data.empty()) { + if (state == State::kNoTag) { + last_type_id = type_id; + state = State::kHasType; + } else if (state == State::kHasPayload) { // We saw some message data before the type_id. Have to parse it // now. io::CodedInputStream sub_input( reinterpret_cast(message_data.data()), static_cast(message_data.size())); sub_input.SetRecursionLimit(input->RecursionBudget()); - if (!ms.ParseField(last_type_id, &sub_input)) { + if (!ms.ParseField(type_id, &sub_input)) { return false; } message_data.clear(); + state = State::kDone; } break; } case WireFormatLite::kMessageSetMessageTag: { - if (last_type_id == 0) { + if (state == State::kHasType) { + // Already saw type_id, so we can parse this directly. + if (!ms.ParseField(last_type_id, input)) { + return false; + } + state = State::kDone; + } else if (state == State::kNoTag) { // We haven't seen a type_id yet. Append this data to message_data. uint32_t length; if (!input->ReadVarint32(&length)) return false; @@ -1868,11 +1879,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) { auto ptr = reinterpret_cast(&message_data[0]); ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr); if (!input->ReadRaw(ptr, length)) return false; + state = State::kHasPayload; } else { - // Already saw type_id, so we can parse this directly. - if (!ms.ParseField(last_type_id, input)) { - return false; - } + if (!ms.SkipField(tag, input)) return false; } break; diff --git a/src/google/protobuf/wire_format_unittest.inc b/src/google/protobuf/wire_format_unittest.inc index 4b7862ce24fa..7b5aa5671f88 100644 --- a/src/google/protobuf/wire_format_unittest.inc +++ b/src/google/protobuf/wire_format_unittest.inc @@ -581,28 +581,54 @@ TEST(WireFormatTest, ParseMessageSet) { EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString()); } -TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { +namespace { +std::string BuildMessageSetItemStart() { std::string data; { - UNITTEST::TestMessageSetExtension1 message; - message.set_i(123); - // Build a MessageSet manually with its message content put before its - // type_id. io::StringOutputStream output_stream(&data); io::CodedOutputStream coded_output(&output_stream); coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag); + } + return data; +} +std::string BuildMessageSetItemEnd() { + std::string data; + { + io::StringOutputStream output_stream(&data); + io::CodedOutputStream coded_output(&output_stream); + coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); + } + return data; +} +std::string BuildMessageSetTestExtension1(int value = 123) { + std::string data; + { + UNITTEST::TestMessageSetExtension1 message; + message.set_i(value); + io::StringOutputStream output_stream(&data); + io::CodedOutputStream coded_output(&output_stream); // Write the message content first. WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber, WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &coded_output); coded_output.WriteVarint32(message.ByteSizeLong()); message.SerializeWithCachedSizes(&coded_output); - // Write the type id. - uint32_t type_id = message.GetDescriptor()->extension(0)->number(); + } + return data; +} +std::string BuildMessageSetItemTypeId(int extension_number) { + std::string data; + { + io::StringOutputStream output_stream(&data); + io::CodedOutputStream coded_output(&output_stream); WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber, - type_id, &coded_output); - coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag); + extension_number, &coded_output); } + return data; +} +void ValidateTestMessageSet(const std::string& test_case, + const std::string& data) { + SCOPED_TRACE(test_case); { PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set; ASSERT_TRUE(message_set.ParseFromString(data)); @@ -612,6 +638,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { .GetExtension( UNITTEST::TestMessageSetExtension1::message_set_extension) .i()); + + // Make sure it does not contain anything else. + message_set.ClearExtension( + UNITTEST::TestMessageSetExtension1::message_set_extension); + EXPECT_EQ(message_set.SerializeAsString(), ""); } { // Test parse the message via Reflection. @@ -627,6 +658,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) { UNITTEST::TestMessageSetExtension1::message_set_extension) .i()); } + { + // Test parse the message via DynamicMessage. + DynamicMessageFactory factory; + std::unique_ptr msg( + factory + .GetPrototype( + PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor()) + ->New()); + msg->ParseFromString(data); + auto* reflection = msg->GetReflection(); + std::vector fields; + reflection->ListFields(*msg, &fields); + ASSERT_EQ(fields.size(), 1); + const auto& sub = reflection->GetMessage(*msg, fields[0]); + reflection = sub.GetReflection(); + EXPECT_EQ(123, reflection->GetInt32( + sub, sub.GetDescriptor()->FindFieldByName("i"))); + } +} +} // namespace + +TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) { + std::string start = BuildMessageSetItemStart(); + std::string end = BuildMessageSetItemEnd(); + std::string id = BuildMessageSetItemTypeId( + UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number()); + std::string message = BuildMessageSetTestExtension1(); + + ValidateTestMessageSet("id + message", start + id + message + end); + ValidateTestMessageSet("message + id", start + message + id + end); +} + +TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) { + std::string start = BuildMessageSetItemStart(); + std::string end = BuildMessageSetItemEnd(); + std::string id = BuildMessageSetItemTypeId( + UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number()); + std::string other_id = BuildMessageSetItemTypeId(123456); + std::string message = BuildMessageSetTestExtension1(); + std::string other_message = BuildMessageSetTestExtension1(321); + + // Double id + ValidateTestMessageSet("id + other_id + message", + start + id + other_id + message + end); + ValidateTestMessageSet("id + message + other_id", + start + id + message + other_id + end); + ValidateTestMessageSet("message + id + other_id", + start + message + id + other_id + end); + // Double message + ValidateTestMessageSet("id + message + other_message", + start + id + message + other_message + end); + ValidateTestMessageSet("message + id + other_message", + start + message + id + other_message + end); + ValidateTestMessageSet("message + other_message + id", + start + message + other_message + id + end); } void SerializeReverseOrder(