From b90e60c09befaff836a2fc2ee4d678451b2ec75d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Quenaudon?= Date: Mon, 23 Sep 2024 16:30:34 +0100 Subject: [PATCH] Enforce recursion limit on nested groups --- .../squareup/wire/ByteArrayProtoReader32.kt | 13 ++++++++++++- .../kotlin/com/squareup/wire/ProtoReader.kt | 13 ++++++++++++- .../com/squareup/wire/ProtoReader32Test.kt | 18 ++++++++++++++++++ .../com/squareup/wire/ProtoReaderTest.kt | 18 ++++++++++++++++++ 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ByteArrayProtoReader32.kt b/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ByteArrayProtoReader32.kt index 97b312e0ee..61cf09136c 100644 --- a/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ByteArrayProtoReader32.kt +++ b/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ByteArrayProtoReader32.kt @@ -210,7 +210,18 @@ internal class ByteArrayProtoReader32( if (tagAndFieldEncoding == 0) throw ProtocolException("Unexpected tag 0") val tag = tagAndFieldEncoding shr TAG_FIELD_ENCODING_BITS when (val groupOrFieldEncoding = tagAndFieldEncoding and FIELD_ENCODING_MASK) { - STATE_START_GROUP -> skipGroup(tag) // Nested group. + STATE_START_GROUP -> { + recursionDepth++ + try { + if (recursionDepth > RECURSION_LIMIT) { + throw IOException("Wire recursion limit exceeded") + } + // Nested group. + skipGroup(tag) + } finally { + recursionDepth-- + } + } STATE_END_GROUP -> { if (tag == expectedEndTag) return // Success! throw ProtocolException("Unexpected end group") diff --git a/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ProtoReader.kt b/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ProtoReader.kt index a615f40cbe..94280330c7 100644 --- a/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ProtoReader.kt +++ b/wire-runtime/src/commonMain/kotlin/com/squareup/wire/ProtoReader.kt @@ -249,7 +249,18 @@ open class ProtoReader(private val source: BufferedSource) { if (tagAndFieldEncoding == 0) throw ProtocolException("Unexpected tag 0") val tag = tagAndFieldEncoding shr TAG_FIELD_ENCODING_BITS when (val groupOrFieldEncoding = tagAndFieldEncoding and FIELD_ENCODING_MASK) { - STATE_START_GROUP -> skipGroup(tag) // Nested group. + STATE_START_GROUP -> { + recursionDepth++ + try { + if (recursionDepth > RECURSION_LIMIT) { + throw IOException("Wire recursion limit exceeded") + } + // Nested group. + skipGroup(tag) + } finally { + recursionDepth-- + } + } STATE_END_GROUP -> { if (tag == expectedEndTag) return // Success! throw ProtocolException("Unexpected end group") diff --git a/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReader32Test.kt b/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReader32Test.kt index 02f80cf11e..8b1f74035c 100644 --- a/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReader32Test.kt +++ b/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReader32Test.kt @@ -15,9 +15,12 @@ */ package com.squareup.wire +import com.squareup.wire.ReverseProtoWriterTest.Person import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import okio.ByteString.Companion.decodeHex +import okio.IOException class ProtoReader32Test { @Test fun packedExposedAsRepeated() { @@ -58,5 +61,20 @@ class ProtoReader32Test { reader.endMessageAndGetUnknownFields(secondToken) } + /** We had a bug where we weren't enforcing recursion limits for groups. */ + @Test fun testSkipGroupNested() { + val data = ByteArray(50000) { + when { + it % 2 == 0 -> 0xa3.toByte() + else -> 0x01.toByte() + } + } + + val failed = assertFailsWith { + Person.ADAPTER.decode(data) + } + assertEquals("Wire recursion limit exceeded", failed.message) + } + // Consider pasting new tests into ProtoReaderTest.kt also. } diff --git a/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReaderTest.kt b/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReaderTest.kt index def84f293a..18e54f4067 100644 --- a/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReaderTest.kt +++ b/wire-runtime/src/commonTest/kotlin/com/squareup/wire/ProtoReaderTest.kt @@ -15,10 +15,13 @@ */ package com.squareup.wire +import com.squareup.wire.ReverseProtoWriterTest.Person import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import okio.Buffer import okio.ByteString.Companion.decodeHex +import okio.IOException class ProtoReaderTest { @Test fun packedExposedAsRepeated() { @@ -59,5 +62,20 @@ class ProtoReaderTest { reader.endMessageAndGetUnknownFields(secondToken) } + /** We had a bug where we weren't enforcing recursion limits for groups. */ + @Test fun testSkipGroupNested() { + val data = ByteArray(50000) { + when { + it % 2 == 0 -> 0xa3.toByte() + else -> 0x01.toByte() + } + } + + val failed = assertFailsWith { + Person.ADAPTER.decode(Buffer().write(data)) + } + assertEquals("Wire recursion limit exceeded", failed.message) + } + // Consider pasting new tests into ProtoReader32Test.kt also. }