From 8bfa80b5422cc22d64bfba5aa3049b9d5143c99b Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Tue, 17 Sep 2024 12:03:36 -0700 Subject: [PATCH] Add recursion check when parsing unknown fields in Java. PiperOrigin-RevId: 675657198 --- java/core/BUILD.bazel | 2 + .../com/google/protobuf/ArrayDecoders.java | 28 +++ .../com/google/protobuf/CodedInputStream.java | 6 + .../com/google/protobuf/MessageSchema.java | 12 +- .../com/google/protobuf/MessageSetSchema.java | 3 +- .../google/protobuf/UnknownFieldSchema.java | 29 ++- .../google/protobuf/CodedInputStreamTest.java | 158 ++++++++++++ .../java/com/google/protobuf/LiteTest.java | 232 ++++++++++++++++++ 8 files changed, 458 insertions(+), 12 deletions(-) diff --git a/java/core/BUILD.bazel b/java/core/BUILD.bazel index 6fe4ccff6af4..1f7b0cb5b8bb 100644 --- a/java/core/BUILD.bazel +++ b/java/core/BUILD.bazel @@ -608,6 +608,7 @@ junit_tests( "src/test/java/com/google/protobuf/DescriptorsTest.java", "src/test/java/com/google/protobuf/DebugFormatTest.java", "src/test/java/com/google/protobuf/CodedOutputStreamTest.java", + "src/test/java/com/google/protobuf/CodedInputStreamTest.java", "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java", # Excluded in core_tests "src/test/java/com/google/protobuf/DecodeUtf8Test.java", @@ -656,6 +657,7 @@ junit_tests( "src/test/java/com/google/protobuf/DescriptorsTest.java", "src/test/java/com/google/protobuf/DebugFormatTest.java", "src/test/java/com/google/protobuf/CodedOutputStreamTest.java", + "src/test/java/com/google/protobuf/CodedInputStreamTest.java", "src/test/java/com/google/protobuf/ProtobufToStringOutputTest.java", # Excluded in core_tests "src/test/java/com/google/protobuf/DecodeUtf8Test.java", diff --git a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java index 9bf14396263a..bf5f922b073c 100644 --- a/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java +++ b/java/core/src/main/java/com/google/protobuf/ArrayDecoders.java @@ -23,6 +23,10 @@ */ @CheckReturnValue final class ArrayDecoders { + static final int DEFAULT_RECURSION_LIMIT = 100; + + @SuppressWarnings("NonFinalStaticField") + private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT; private ArrayDecoders() {} @@ -37,6 +41,7 @@ static final class Registers { public long long1; public Object object1; public final ExtensionRegistryLite extensionRegistry; + public int recursionDepth; Registers() { this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry(); @@ -244,7 +249,10 @@ static int mergeMessageField( if (length < 0 || length > limit - position) { throw InvalidProtocolBufferException.truncatedMessage(); } + registers.recursionDepth++; + checkRecursionLimit(registers.recursionDepth); schema.mergeFrom(msg, data, position, position + length, registers); + registers.recursionDepth--; registers.object1 = msg; return position + length; } @@ -262,8 +270,11 @@ static int mergeGroupField( // A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema // and it can't be used in group fields). final MessageSchema messageSchema = (MessageSchema) schema; + registers.recursionDepth++; + checkRecursionLimit(registers.recursionDepth); final int endPosition = messageSchema.parseMessage(msg, data, position, limit, endGroup, registers); + registers.recursionDepth--; registers.object1 = msg; return endPosition; } @@ -1024,6 +1035,8 @@ static int decodeUnknownField( final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance(); final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP; int lastTag = 0; + registers.recursionDepth++; + checkRecursionLimit(registers.recursionDepth); while (position < limit) { position = decodeVarint32(data, position, registers); lastTag = registers.int1; @@ -1032,6 +1045,7 @@ static int decodeUnknownField( } position = decodeUnknownField(lastTag, data, position, limit, child, registers); } + registers.recursionDepth--; if (position > limit || lastTag != endGroup) { throw InvalidProtocolBufferException.parseFailure(); } @@ -1078,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re throw InvalidProtocolBufferException.invalidTag(); } } + + /** + * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if + * the depth of the message exceeds this limit. + */ + public static void setRecursionLimit(int limit) { + recursionLimit = limit; + } + + private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException { + if (depth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + } } diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index a0f3a78c5227..6b3573bf92f1 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -230,7 +230,10 @@ public void skipMessage() throws IOException { if (tag == 0) { return; } + checkRecursionLimit(); + ++recursionDepth; boolean fieldSkipped = skipField(tag); + --recursionDepth; if (!fieldSkipped) { return; } @@ -247,7 +250,10 @@ public void skipMessage(CodedOutputStream output) throws IOException { if (tag == 0) { return; } + checkRecursionLimit(); + ++recursionDepth; boolean fieldSkipped = skipField(tag, output); + --recursionDepth; if (!fieldSkipped) { return; } diff --git a/java/core/src/main/java/com/google/protobuf/MessageSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSchema.java index f8f79fcdf8b4..5ad6762b0dc4 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageSchema.java +++ b/java/core/src/main/java/com/google/protobuf/MessageSchema.java @@ -3006,8 +3006,8 @@ private > void mergeFromHelper( unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } // Unknown field. - - if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { + if (unknownFieldSchema.mergeOneFieldFrom( + unknownFields, reader, /* currentDepth= */ 0)) { continue; } } @@ -3382,8 +3382,8 @@ private > void mergeFromHelper( if (unknownFields == null) { unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } - - if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { + if (!unknownFieldSchema.mergeOneFieldFrom( + unknownFields, reader, /* currentDepth= */ 0)) { return; } break; @@ -3399,8 +3399,8 @@ private > void mergeFromHelper( if (unknownFields == null) { unknownFields = unknownFieldSchema.getBuilderFromMessage(message); } - - if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) { + if (!unknownFieldSchema.mergeOneFieldFrom( + unknownFields, reader, /* currentDepth= */ 0)) { return; } } diff --git a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java index a17037e8efd4..ec37d41f98c5 100644 --- a/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java +++ b/java/core/src/main/java/com/google/protobuf/MessageSetSchema.java @@ -278,8 +278,7 @@ boolean parseMessageSetItemOrUnknownField( reader, extension, extensionRegistry, extensions); return true; } else { - - return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader); + return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0); } } else { return reader.skipField(); diff --git a/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java b/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java index a43bc2a9472d..80602b16359a 100644 --- a/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java +++ b/java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java @@ -13,6 +13,11 @@ @CheckReturnValue abstract class UnknownFieldSchema { + static final int DEFAULT_RECURSION_LIMIT = 100; + + @SuppressWarnings("NonFinalStaticField") + private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT; + /** Whether unknown fields should be dropped. */ abstract boolean shouldDiscardUnknownFields(Reader reader); @@ -55,7 +60,9 @@ abstract class UnknownFieldSchema { /** Marks unknown fields as immutable. */ abstract void makeImmutable(Object message); - final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException { + /** Merges one field into the unknown fields. */ + final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth) + throws IOException { int tag = reader.getTag(); int fieldNumber = WireFormat.getTagFieldNumber(tag); switch (WireFormat.getTagWireType(tag)) { @@ -74,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti case WireFormat.WIRETYPE_START_GROUP: final B subFields = newBuilder(); int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP); - mergeFrom(subFields, reader); + currentDepth++; + if (currentDepth >= recursionLimit) { + throw InvalidProtocolBufferException.recursionLimitExceeded(); + } + mergeFrom(subFields, reader, currentDepth); + currentDepth--; if (endGroupTag != reader.getTag()) { throw InvalidProtocolBufferException.invalidEndTag(); } @@ -87,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti } } - private final void mergeFrom(B unknownFields, Reader reader) throws IOException { + private final void mergeFrom(B unknownFields, Reader reader, int currentDepth) + throws IOException { while (true) { if (reader.getFieldNumber() == Reader.READ_DONE - || !mergeOneFieldFrom(unknownFields, reader)) { + || !mergeOneFieldFrom(unknownFields, reader, currentDepth)) { break; } } @@ -107,4 +120,12 @@ private final void mergeFrom(B unknownFields, Reader reader) throws IOException abstract int getSerializedSizeAsMessageSet(T message); abstract int getSerializedSize(T unknowns); + + /** + * Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if + * the depth of the message exceeds this limit. + */ + public void setRecursionLimit(int limit) { + recursionLimit = limit; + } } diff --git a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java index ff700587a160..f73cb3b0eecf 100644 --- a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +++ b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java @@ -11,6 +11,9 @@ import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertThrows; + +import com.google.common.primitives.Bytes; +import map_test.MapTestProto.MapContainer; import protobuf_unittest.UnittestProto.BoolMessage; import protobuf_unittest.UnittestProto.Int32Message; import protobuf_unittest.UnittestProto.Int64Message; @@ -35,6 +38,13 @@ public class CodedInputStreamTest { private static final int DEFAULT_BLOCK_SIZE = 4096; + private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP); + + private static final byte[] NESTING_SGROUP = generateSGroupTags(); + + private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField(); + + private enum InputType { ARRAY { @Override @@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) { return bytes; } + private static byte[] generateSGroupTags() { + byte[] bytes = new byte[100000]; + Arrays.fill(bytes, (byte) GROUP_TAP); + return bytes; + } + + private static byte[] generateSGroupTagsForMapField() { + byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12}; + return Bytes.concat(initialBytes, NESTING_SGROUP); + } + /** * An InputStream which limits the number of bytes it reads at a time. We use this to make sure * that CodedInputStream doesn't screw up when reading in small blocks. @@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception { } } + @Test + public void testMaliciousRecursion_unknownFields() throws Exception { + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestRecursiveMessage.parseFrom(NESTING_SGROUP)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousRecursion_skippingUnknownField() throws Exception { + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> + DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser()) + .parseFrom(NESTING_SGROUP)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception { + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> + MapContainer.parseFrom( + new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES))); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> + MapContainer.newBuilder() + .mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES))); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception { + ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP); + CodedInputStream input = CodedInputStream.newInstance(inputSteam); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception { + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES)); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES)); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("the input ended unexpectedly in the middle of a field"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("the input ended unexpectedly in the middle of a field"); + } + + @Test + public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception { + CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception { + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES))); + + assertThat(thrown) + .hasMessageThat() + .contains("the input ended unexpectedly in the middle of a field"); + } + + @Test + public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception { + CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception { + CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP); + CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]); + + Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage); + Throwable thrown2 = + assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output)); + + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + assertThat(thrown2) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + private void checkSizeLimitExceeded(InvalidProtocolBufferException e) { assertThat(e) .hasMessageThat() diff --git a/java/lite/src/test/java/com/google/protobuf/LiteTest.java b/java/lite/src/test/java/com/google/protobuf/LiteTest.java index 61009df109d4..1fc9c5607160 100644 --- a/java/lite/src/test/java/com/google/protobuf/LiteTest.java +++ b/java/lite/src/test/java/com/google/protobuf/LiteTest.java @@ -2463,6 +2463,211 @@ public void testParseFromByteBufferThrows() { } } + @Test + public void testParseFromInputStream_concurrent_nestingUnknownGroups() throws Exception { + int numThreads = 200; + ArrayList threads = new ArrayList<>(); + + ByteString byteString = generateNestingGroups(99); + AtomicBoolean thrown = new AtomicBoolean(false); + + for (int i = 0; i < numThreads; i++) { + Thread thread = + new Thread( + () -> { + try { + TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString); + } catch (IOException e) { + if (e.getMessage().contains("Protocol message had too many levels of nesting")) { + thrown.set(true); + } + } + }); + thread.start(); + threads.add(thread); + } + + for (Thread thread : threads) { + thread.join(); + } + + assertThat(thrown.get()).isFalse(); + } + + @Test + public void testParseFromInputStream_nestingUnknownGroups() throws IOException { + ByteString byteString = generateNestingGroups(99); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromInputStream_nestingUnknownGroups_exception() throws IOException { + ByteString byteString = generateNestingGroups(100); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromInputStream_setRecursionLimit_exception() throws IOException { + ByteString byteString = generateNestingGroups(199); + UnknownFieldSchema schema = SchemaUtil.unknownFieldSetLiteSchema(); + schema.setRecursionLimit(200); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> TestAllTypesLite.parseFrom(byteString)); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + schema.setRecursionLimit(UnknownFieldSchema.DEFAULT_RECURSION_LIMIT); + } + + @Test + public void testParseFromBytes_concurrent_nestingUnknownGroups() throws Exception { + int numThreads = 200; + ArrayList threads = new ArrayList<>(); + + ByteString byteString = generateNestingGroups(99); + AtomicBoolean thrown = new AtomicBoolean(false); + + for (int i = 0; i < numThreads; i++) { + Thread thread = + new Thread( + () -> { + try { + // Should pass in byte[] instead of ByteString to go into ArrayDecoders. + TestAllTypesLite unused = TestAllTypesLite.parseFrom(byteString.toByteArray()); + } catch (InvalidProtocolBufferException e) { + if (e.getMessage().contains("Protocol message had too many levels of nesting")) { + thrown.set(true); + } + } + }); + thread.start(); + threads.add(thread); + } + + for (Thread thread : threads) { + thread.join(); + } + + assertThat(thrown.get()).isFalse(); + } + + @Test + public void testParseFromBytes_nestingUnknownGroups() throws IOException { + ByteString byteString = generateNestingGroups(99); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromBytes_nestingUnknownGroups_exception() throws IOException { + ByteString byteString = generateNestingGroups(100); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromBytes_setRecursionLimit_exception() throws IOException { + ByteString byteString = generateNestingGroups(199); + ArrayDecoders.setRecursionLimit(200); + + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> TestAllTypesLite.parseFrom(byteString.toByteArray())); + assertThat(thrown) + .hasMessageThat() + .doesNotContain("Protocol message had too many levels of nesting"); + ArrayDecoders.setRecursionLimit(ArrayDecoders.DEFAULT_RECURSION_LIMIT); + } + + @Test + public void testParseFromBytes_recursiveMessages() throws Exception { + byte[] data99 = makeRecursiveMessage(99).toByteArray(); + byte[] data100 = makeRecursiveMessage(100).toByteArray(); + + RecursiveMessage unused = RecursiveMessage.parseFrom(data99); + Throwable thrown = + assertThrows( + InvalidProtocolBufferException.class, () -> RecursiveMessage.parseFrom(data100)); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testParseFromBytes_recursiveKnownGroups() throws Exception { + byte[] data99 = makeRecursiveGroup(99).toByteArray(); + byte[] data100 = makeRecursiveGroup(100).toByteArray(); + + RecursiveGroup unused = RecursiveGroup.parseFrom(data99); + Throwable thrown = + assertThrows(InvalidProtocolBufferException.class, () -> RecursiveGroup.parseFrom(data100)); + assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting"); + } + + @Test + @SuppressWarnings("ProtoParseFromByteString") + public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception { + ByteString byteString = generateNestingGroups(102); + + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(byteString.toByteArray())); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.newBuilder().mergeFrom(byteString.toByteArray())); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + + @Test + public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception { + byte[] bytes = generateNestingGroups(101).toByteArray(); + + Throwable parseFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.parseFrom(new ByteArrayInputStream(bytes))); + Throwable mergeFromThrown = + assertThrows( + InvalidProtocolBufferException.class, + () -> MapContainer.newBuilder().mergeFrom(new ByteArrayInputStream(bytes))); + + assertThat(parseFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + assertThat(mergeFromThrown) + .hasMessageThat() + .contains("Protocol message had too many levels of nesting"); + } + @Test public void testParseFromByteBuffer_extensions() throws Exception { TestAllExtensionsLite message = @@ -2819,4 +3024,31 @@ private static boolean contains(ByteString a, ByteString b) { } return false; } + + private static ByteString generateNestingGroups(int num) throws IOException { + int groupTap = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP); + ByteString.Output byteStringOutput = ByteString.newOutput(); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(byteStringOutput); + for (int i = 0; i < num; i++) { + codedOutput.writeInt32NoTag(groupTap); + } + codedOutput.flush(); + return byteStringOutput.toByteString(); + } + + private static RecursiveMessage makeRecursiveMessage(int num) { + if (num == 0) { + return RecursiveMessage.getDefaultInstance(); + } else { + return RecursiveMessage.newBuilder().setRecurse(makeRecursiveMessage(num - 1)).build(); + } + } + + private static RecursiveGroup makeRecursiveGroup(int num) { + if (num == 0) { + return RecursiveGroup.getDefaultInstance(); + } else { + return RecursiveGroup.newBuilder().setRecurse(makeRecursiveGroup(num - 1)).build(); + } + } }