diff --git a/okio/src/nativeMain/kotlin/okio/Deflater.kt b/okio/src/nativeMain/kotlin/okio/Deflater.kt index 0baa6e5f3a..fe8d4672fb 100644 --- a/okio/src/nativeMain/kotlin/okio/Deflater.kt +++ b/okio/src/nativeMain/kotlin/okio/Deflater.kt @@ -37,7 +37,7 @@ import platform.zlib.deflateEnd import platform.zlib.deflateInit2 import platform.zlib.z_stream_s -private val emptyByteArray = byteArrayOf() +internal val emptyByteArray = byteArrayOf() /** * Deflate using Kotlin/Native's built-in zlib bindings. This uses the raw deflate format and omits @@ -145,8 +145,7 @@ internal class Deflater : Closeable { if (closed) return closed = true - val deflateEndResult = deflateEnd(zStream.ptr) - check(deflateEndResult == Z_OK) + deflateEnd(zStream.ptr) nativeHeap.free(zStream) } } diff --git a/okio/src/nativeMain/kotlin/okio/Inflater.kt b/okio/src/nativeMain/kotlin/okio/Inflater.kt new file mode 100644 index 0000000000..ddac5da2a6 --- /dev/null +++ b/okio/src/nativeMain/kotlin/okio/Inflater.kt @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2024 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.UByteVar +import kotlinx.cinterop.addressOf +import kotlinx.cinterop.alloc +import kotlinx.cinterop.free +import kotlinx.cinterop.nativeHeap +import kotlinx.cinterop.ptr +import kotlinx.cinterop.usePinned +import platform.zlib.Z_BUF_ERROR +import platform.zlib.Z_DATA_ERROR +import platform.zlib.Z_NO_FLUSH +import platform.zlib.Z_OK +import platform.zlib.Z_STREAM_END +import platform.zlib.inflateEnd +import platform.zlib.inflateInit2 +import platform.zlib.z_stream_s + +/** + * Inflate using Kotlin/Native's built-in zlib bindings. + * + * The API is symmetric with [Deflater]. + */ +internal class Inflater : Closeable { + private val zStream: z_stream_s = nativeHeap.alloc { + zalloc = null + zfree = null + opaque = null + check( + inflateInit2( + strm = ptr, + windowBits = -15, // Default value for raw deflate. + ) == Z_OK, + ) + } + + var source: ByteArray = emptyByteArray + var sourcePos: Int = 0 + var sourceLimit: Int = 0 + + var target: ByteArray = emptyByteArray + var targetPos: Int = 0 + var targetLimit: Int = 0 + + private var closed = false + + /** + * Returns true if no further calls to [inflate] are required because the source stream is + * finished. Otherwise, ensure there's input data in [source] and output space in [target] and + * call this again. + */ + fun inflate(): Boolean { + check(!closed) { "closed" } + require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size) + require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size) + + source.usePinned { pinnedSource -> + target.usePinned { pinnedTarget -> + val sourceByteCount = sourceLimit - sourcePos + zStream.next_in = when { + sourceByteCount > 0 -> pinnedSource.addressOf(sourcePos) as CPointer + else -> null + } + zStream.avail_in = sourceByteCount.toUInt() + + val targetByteCount = targetLimit - targetPos + zStream.next_out = when { + targetByteCount > 0 -> pinnedTarget.addressOf(targetPos) as CPointer + else -> null + } + zStream.avail_out = targetByteCount.toUInt() + + val inflateResult = platform.zlib.inflate(zStream.ptr, Z_NO_FLUSH) + + sourcePos += sourceByteCount - zStream.avail_in.toInt() + targetPos += targetByteCount - zStream.avail_out.toInt() + + return when (inflateResult) { + Z_OK -> false + Z_BUF_ERROR -> false // Non-fatal but the caller needs to update source and/or target. + Z_STREAM_END -> true + Z_DATA_ERROR -> throw ProtocolException("Z_DATA_ERROR") + + // One of Z_NEED_DICT, Z_STREAM_ERROR, Z_MEM_ERROR. + else -> throw ProtocolException("unexpected inflate result: $inflateResult") + } + } + } + } + + override fun close() { + if (closed) return + closed = true + + inflateEnd(zStream.ptr) + nativeHeap.free(zStream) + } +} diff --git a/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt b/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt index 4ab91a0410..f0374fda85 100644 --- a/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt +++ b/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt @@ -17,6 +17,7 @@ package okio import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertTrue import okio.ByteString.Companion.decodeBase64 @@ -174,4 +175,21 @@ class DeflaterTest { deflater.close() } + + @Test + fun cannotDeflateAfterClose() { + val deflater = Deflater() + deflater.close() + + assertFailsWith { + deflater.deflate() + } + } + + @Test + fun closeIsIdemptent() { + val deflater = Deflater() + deflater.close() + deflater.close() + } } diff --git a/okio/src/nativeTest/kotlin/okio/InflaterTest.kt b/okio/src/nativeTest/kotlin/okio/InflaterTest.kt new file mode 100644 index 0000000000..2114ac1748 --- /dev/null +++ b/okio/src/nativeTest/kotlin/okio/InflaterTest.kt @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2024 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import okio.ByteString.Companion.decodeBase64 +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.toByteString + +class InflaterTest { + @Test + fun happyPath() { + val inflater = Inflater().apply { + source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" + .decodeBase64()!!.toByteArray() + sourcePos = 0 + sourceLimit = source.size + + target = ByteArray(256) + targetPos = 0 + targetLimit = target.size + } + + assertTrue(inflater.inflate()) + assertEquals(inflater.sourceLimit, inflater.sourcePos) + + val inflated = inflater.target.toByteString(0, inflater.targetPos) + assertEquals( + "God help us, we're in the hands of engineers.", + inflated.utf8(), + ) + + inflater.close() + } + + @Test + fun inflateInParts() { + val inflater = Inflater().apply { + target = ByteArray(256) + targetPos = 0 + targetLimit = target.size + } + + inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJ".decodeBase64()!!.toByteArray() + inflater.sourcePos = 0 + inflater.sourceLimit = inflater.source.size + assertFalse(inflater.inflate()) + assertEquals(inflater.sourceLimit, inflater.sourcePos) + + inflater.source = "SFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64()!!.toByteArray() + inflater.sourcePos = 0 + inflater.sourceLimit = inflater.source.size + assertTrue(inflater.inflate()) + assertEquals(inflater.sourceLimit, inflater.sourcePos) + + val inflated = inflater.target.toByteString(0, inflater.targetPos) + assertEquals( + "God help us, we're in the hands of engineers.", + inflated.utf8(), + ) + + inflater.close() + } + + @Test + fun inflateInsufficientSpaceInTarget() { + val targetBuffer = Buffer() + + val inflater = Inflater().apply { + source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" + .decodeBase64()!!.toByteArray() + sourcePos = 0 + sourceLimit = source.size + } + + inflater.target = ByteArray(31) + inflater.targetPos = 0 + inflater.targetLimit = inflater.target.size + assertFalse(inflater.inflate()) + assertEquals(inflater.targetLimit, inflater.targetPos) + targetBuffer.write(inflater.target) + + inflater.target = ByteArray(256) + inflater.targetPos = 0 + inflater.targetLimit = inflater.target.size + assertTrue(inflater.inflate()) + assertEquals(inflater.sourcePos, inflater.sourceLimit) + targetBuffer.write(inflater.target, 0, inflater.targetPos) + + assertEquals( + "God help us, we're in the hands of engineers.", + targetBuffer.readUtf8(), + ) + + inflater.close() + } + + @Test + fun inflateEmptyContent() { + val inflater = Inflater().apply { + source = "AwA=".decodeBase64()!!.toByteArray() + sourcePos = 0 + sourceLimit = source.size + + target = ByteArray(256) + targetPos = 0 + targetLimit = target.size + } + + assertTrue(inflater.inflate()) + + val inflated = inflater.target.toByteString(0, inflater.targetPos) + assertEquals( + "", + inflated.utf8(), + ) + + inflater.close() + } + + @Test + fun inflateInPartsStartingWithEmptySource() { + val inflater = Inflater().apply { + target = ByteArray(256) + targetPos = 0 + targetLimit = target.size + } + + inflater.source = ByteArray(256) + inflater.sourcePos = 0 + inflater.sourceLimit = 0 + assertFalse(inflater.inflate()) + + inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" + .decodeBase64()!!.toByteArray() + inflater.sourcePos = 0 + inflater.sourceLimit = inflater.source.size + assertTrue(inflater.inflate()) + + val inflated = inflater.target.toByteString(0, inflater.targetPos) + assertEquals( + "God help us, we're in the hands of engineers.", + inflated.utf8(), + ) + + inflater.close() + } + + @Test + fun inflateInvalidData() { + val inflater = Inflater().apply { + target = ByteArray(256) + targetPos = 0 + targetLimit = target.size + } + + inflater.source = "ffffffffffffffff".decodeHex().toByteArray() + inflater.sourcePos = 0 + inflater.sourceLimit = inflater.source.size + val exception = assertFailsWith { + inflater.inflate() + } + assertEquals("Z_DATA_ERROR", exception.message) + + inflater.close() + } + + @Test + fun cannotInflateAfterClose() { + val inflater = Inflater() + inflater.close() + + assertFailsWith { + inflater.inflate() + } + } + + @Test + fun closeIsIdemptent() { + val inflater = Inflater() + inflater.close() + inflater.close() + } +}