From 260710e9968e5f298bff435f7589e85f80a1e675 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Wed, 7 Feb 2024 08:07:59 -0500 Subject: [PATCH] Make Inflater and Deflater symmetric (#1426) I expect this to simplify adding DeflaterSink and InflaterSource. --- .../nativeMain/kotlin/okio/DataProcessor.kt | 64 +++++++++++++++++++ okio/src/nativeMain/kotlin/okio/Deflater.kt | 59 +++-------------- okio/src/nativeMain/kotlin/okio/Inflater.kt | 37 +++++------ .../nativeTest/kotlin/okio/DeflaterTest.kt | 36 ++++++----- .../nativeTest/kotlin/okio/InflaterTest.kt | 29 ++++++--- 5 files changed, 128 insertions(+), 97 deletions(-) create mode 100644 okio/src/nativeMain/kotlin/okio/DataProcessor.kt diff --git a/okio/src/nativeMain/kotlin/okio/DataProcessor.kt b/okio/src/nativeMain/kotlin/okio/DataProcessor.kt new file mode 100644 index 0000000000..27549f1eb1 --- /dev/null +++ b/okio/src/nativeMain/kotlin/okio/DataProcessor.kt @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2024 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +private val emptyByteArray = byteArrayOf() + +/** + * Transform a stream of source bytes into a stream of target bytes, one segment at a time. The + * relationship between input byte count and output byte count is arbitrary: a sequence of input + * bytes may produce zero output bytes, or many segments of output bytes. + * + * To use: + * + * 1. Create an instance. + * + * 2. Populate [source] with input data. Set [sourcePos] and [sourceLimit] to a readable slice of + * this array. + * + * 3. Populate [target] with a destination for output data. Set [targetPos] and [targetLimit] to a + * writable slice of this array. + * + * 4. Call [process] to read input data from [source] and write output to [target]. This function + * advances [sourcePos] if input data was read and [targetPos] if compressed output was written. + * If the input array is exhausted (`sourcePos == sourceLimit`) or the output array is full + * (`targetPos == targetLimit`), make an adjustment and call [process] again. + * + * 5. Repeat steps 2 through 4 until the input data is completely exhausted. + * + * 6. Close the processor. + * + * See also, the [zlib manual](https://www.zlib.net/manual.html). + */ +internal abstract class DataProcessor : Closeable { + var source: ByteArray = emptyByteArray + var sourcePos: Int = 0 + var sourceLimit: Int = 0 + + var target: ByteArray = emptyByteArray + var targetPos: Int = 0 + var targetLimit: Int = 0 + + var closed: Boolean = false + protected set + + /** + * Returns true if no further calls to [process] are required to complete the operation. + * Otherwise, make space available in [target] and call [process] again. + */ + @Throws(ProtocolException::class) + abstract fun process(): Boolean +} diff --git a/okio/src/nativeMain/kotlin/okio/Deflater.kt b/okio/src/nativeMain/kotlin/okio/Deflater.kt index fe8d4672fb..a1e30bc3c2 100644 --- a/okio/src/nativeMain/kotlin/okio/Deflater.kt +++ b/okio/src/nativeMain/kotlin/okio/Deflater.kt @@ -37,36 +37,16 @@ import platform.zlib.deflateEnd import platform.zlib.deflateInit2 import platform.zlib.z_stream_s -internal val emptyByteArray = byteArrayOf() - /** * Deflate using Kotlin/Native's built-in zlib bindings. This uses the raw deflate format and omits * the zlib header and trailer, and does not compute a check value. * - * To use: - * - * 1. Create an instance. - * - * 2. Populate [source] with uncompressed data. Set [sourcePos] and [sourceLimit] to a readable - * slice of this array. - * - * 3. Populate [target] with a destination for compressed data. Set [targetPos] and [targetLimit] to - * a writable slice of this array. - * - * 4. Call [deflate] to read input data from [source] and write compressed output to [target]. This - * function advances [sourcePos] if input data was read and [targetPos] if compressed output was - * written. If the input array is exhausted (`sourcePos == sourceLimit`) or the output array is - * full (`targetPos == targetLimit`), make an adjustment and call [deflate] again. - * - * 5. Repeat steps 2 through 4 until the input data is completely exhausted. Set [sourceFinished] - * to true before the last call to [deflate]. (It is okay to call deflate() when the source is - * exhausted.) - * - * 6. Close the Deflater. + * Note that you must set [flush] to [Z_FINISH] before the last call to [process]. (It is okay to + * call process() when the source is exhausted.) * * See also, the [zlib manual](https://www.zlib.net/manual.html). */ -internal class Deflater : Closeable { +internal class Deflater : DataProcessor() { private val zStream: z_stream_s = nativeHeap.alloc { zalloc = null zfree = null @@ -83,22 +63,10 @@ internal class Deflater : Closeable { ) } - var source: ByteArray = emptyByteArray - var sourcePos: Int = 0 - var sourceLimit: Int = 0 - var sourceFinished = false - - var target: ByteArray = emptyByteArray - var targetPos: Int = 0 - var targetLimit: Int = 0 - - private var closed = false + /** Probably [Z_NO_FLUSH], [Z_FINISH], or [Z_SYNC_FLUSH]. */ + var flush: Int = Z_NO_FLUSH - /** - * Returns true if no further calls to [deflate] are required to complete the operation. - * Otherwise, make space available in [target] and call [deflate] again with the same arguments. - */ - fun deflate(flush: Boolean = false): Boolean { + override fun process(): Boolean { check(!closed) { "closed" } require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size) require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size) @@ -119,23 +87,16 @@ internal class Deflater : Closeable { } zStream.avail_out = targetByteCount.toUInt() - val deflateFlush = when { - sourceFinished -> Z_FINISH - flush -> Z_SYNC_FLUSH - else -> Z_NO_FLUSH - } - // One of Z_OK, Z_STREAM_END, Z_STREAM_ERROR, or Z_BUF_ERROR. - val deflateResult = deflate(zStream.ptr, deflateFlush) + val deflateResult = deflate(zStream.ptr, flush) check(deflateResult != Z_STREAM_ERROR) sourcePos += sourceByteCount - zStream.avail_in.toInt() targetPos += targetByteCount - zStream.avail_out.toInt() - return when { - sourceFinished -> deflateResult == Z_STREAM_END - flush -> targetPos < targetLimit - else -> true + return when (deflateResult) { + Z_STREAM_END -> true + else -> targetPos < targetLimit } } } diff --git a/okio/src/nativeMain/kotlin/okio/Inflater.kt b/okio/src/nativeMain/kotlin/okio/Inflater.kt index ddac5da2a6..22c7a6c264 100644 --- a/okio/src/nativeMain/kotlin/okio/Inflater.kt +++ b/okio/src/nativeMain/kotlin/okio/Inflater.kt @@ -34,10 +34,8 @@ import platform.zlib.z_stream_s /** * Inflate using Kotlin/Native's built-in zlib bindings. - * - * The API is symmetric with [Deflater]. */ -internal class Inflater : Closeable { +internal class Inflater : DataProcessor() { private val zStream: z_stream_s = nativeHeap.alloc { zalloc = null zfree = null @@ -50,22 +48,11 @@ internal class Inflater : Closeable { ) } - var source: ByteArray = emptyByteArray - var sourcePos: Int = 0 - var sourceLimit: Int = 0 - - var target: ByteArray = emptyByteArray - var targetPos: Int = 0 - var targetLimit: Int = 0 - - private var closed = false + var sourceFinished: Boolean = false + private set - /** - * Returns true if no further calls to [inflate] are required because the source stream is - * finished. Otherwise, ensure there's input data in [source] and output space in [target] and - * call this again. - */ - fun inflate(): Boolean { + @Throws(ProtocolException::class) + override fun process(): Boolean { check(!closed) { "closed" } require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size) require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size) @@ -91,10 +78,16 @@ internal class Inflater : Closeable { sourcePos += sourceByteCount - zStream.avail_in.toInt() targetPos += targetByteCount - zStream.avail_out.toInt() - return when (inflateResult) { - Z_OK -> false - Z_BUF_ERROR -> false // Non-fatal but the caller needs to update source and/or target. - Z_STREAM_END -> true + when (inflateResult) { + Z_OK, Z_BUF_ERROR -> { + return targetPos < targetLimit + } + + Z_STREAM_END -> { + sourceFinished = true + return true + } + Z_DATA_ERROR -> throw ProtocolException("Z_DATA_ERROR") // One of Z_NEED_DICT, Z_STREAM_ERROR, Z_MEM_ERROR. diff --git a/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt b/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt index f0374fda85..505e7131b5 100644 --- a/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt +++ b/okio/src/nativeTest/kotlin/okio/DeflaterTest.kt @@ -23,6 +23,9 @@ import kotlin.test.assertTrue import okio.ByteString.Companion.decodeBase64 import okio.ByteString.Companion.encodeUtf8 import okio.ByteString.Companion.toByteString +import platform.zlib.Z_FINISH +import platform.zlib.Z_NO_FLUSH +import platform.zlib.Z_SYNC_FLUSH class DeflaterTest { @Test @@ -31,14 +34,14 @@ class DeflaterTest { source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray() sourcePos = 0 sourceLimit = source.size - sourceFinished = true + flush = Z_FINISH target = ByteArray(256) targetPos = 0 targetLimit = target.size } - assertTrue(deflater.deflate()) + assertTrue(deflater.process()) assertEquals(deflater.sourceLimit, deflater.sourcePos) val deflated = deflater.target.toByteString(0, deflater.targetPos) @@ -62,15 +65,14 @@ class DeflaterTest { deflater.source = "God help us, we're in the hands".encodeUtf8().toByteArray() deflater.sourcePos = 0 deflater.sourceLimit = deflater.source.size - deflater.sourceFinished = false - assertTrue(deflater.deflate()) + assertTrue(deflater.process()) assertEquals(deflater.sourceLimit, deflater.sourcePos) deflater.source = " of engineers.".encodeUtf8().toByteArray() deflater.sourcePos = 0 deflater.sourceLimit = deflater.source.size - deflater.sourceFinished = true - assertTrue(deflater.deflate()) + deflater.flush = Z_FINISH + assertTrue(deflater.process()) assertEquals(deflater.sourceLimit, deflater.sourcePos) val deflated = deflater.target.toByteString(0, deflater.targetPos) @@ -97,19 +99,21 @@ class DeflaterTest { deflater.target = ByteArray(10) deflater.targetPos = 0 deflater.targetLimit = deflater.target.size - assertFalse(deflater.deflate(flush = true)) + deflater.flush = Z_SYNC_FLUSH + assertFalse(deflater.process()) assertEquals(deflater.targetLimit, deflater.targetPos) targetBuffer.write(deflater.target) deflater.target = ByteArray(256) deflater.targetPos = 0 deflater.targetLimit = deflater.target.size - assertTrue(deflater.deflate()) + deflater.flush = Z_NO_FLUSH + assertTrue(deflater.process()) assertEquals(deflater.sourcePos, deflater.sourceLimit) targetBuffer.write(deflater.target, 0, deflater.targetPos) - deflater.sourceFinished = true - assertTrue(deflater.deflate()) + deflater.flush = Z_FINISH + assertTrue(deflater.process()) // Golden compressed output. assertEquals( @@ -128,20 +132,20 @@ class DeflaterTest { source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray() sourcePos = 0 sourceLimit = source.size - sourceFinished = true + flush = Z_FINISH } deflater.target = ByteArray(10) deflater.targetPos = 0 deflater.targetLimit = deflater.target.size - assertFalse(deflater.deflate()) + assertFalse(deflater.process()) assertEquals(deflater.targetLimit, deflater.targetPos) targetBuffer.write(deflater.target) deflater.target = ByteArray(256) deflater.targetPos = 0 deflater.targetLimit = deflater.target.size - assertTrue(deflater.deflate()) + assertTrue(deflater.process()) assertEquals(deflater.sourcePos, deflater.sourceLimit) targetBuffer.write(deflater.target, 0, deflater.targetPos) @@ -157,14 +161,14 @@ class DeflaterTest { @Test fun deflateEmptySource() { val deflater = Deflater().apply { - sourceFinished = true + flush = Z_FINISH target = ByteArray(256) targetPos = 0 targetLimit = target.size } - assertTrue(deflater.deflate()) + assertTrue(deflater.process()) val deflated = deflater.target.toByteString(0, deflater.targetPos) // Golden compressed output. @@ -182,7 +186,7 @@ class DeflaterTest { deflater.close() assertFailsWith { - deflater.deflate() + deflater.process() } } diff --git a/okio/src/nativeTest/kotlin/okio/InflaterTest.kt b/okio/src/nativeTest/kotlin/okio/InflaterTest.kt index 2114ac1748..a7293baa52 100644 --- a/okio/src/nativeTest/kotlin/okio/InflaterTest.kt +++ b/okio/src/nativeTest/kotlin/okio/InflaterTest.kt @@ -38,7 +38,8 @@ class InflaterTest { targetLimit = target.size } - assertTrue(inflater.inflate()) + assertTrue(inflater.process()) + assertTrue(inflater.sourceFinished) assertEquals(inflater.sourceLimit, inflater.sourcePos) val inflated = inflater.target.toByteString(0, inflater.targetPos) @@ -61,13 +62,15 @@ class InflaterTest { inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJ".decodeBase64()!!.toByteArray() inflater.sourcePos = 0 inflater.sourceLimit = inflater.source.size - assertFalse(inflater.inflate()) + assertTrue(inflater.process()) + assertFalse(inflater.sourceFinished) assertEquals(inflater.sourceLimit, inflater.sourcePos) inflater.source = "SFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64()!!.toByteArray() inflater.sourcePos = 0 inflater.sourceLimit = inflater.source.size - assertTrue(inflater.inflate()) + assertTrue(inflater.process()) + assertTrue(inflater.sourceFinished) assertEquals(inflater.sourceLimit, inflater.sourcePos) val inflated = inflater.target.toByteString(0, inflater.targetPos) @@ -93,14 +96,16 @@ class InflaterTest { inflater.target = ByteArray(31) inflater.targetPos = 0 inflater.targetLimit = inflater.target.size - assertFalse(inflater.inflate()) + assertFalse(inflater.process()) + assertFalse(inflater.sourceFinished) assertEquals(inflater.targetLimit, inflater.targetPos) targetBuffer.write(inflater.target) inflater.target = ByteArray(256) inflater.targetPos = 0 inflater.targetLimit = inflater.target.size - assertTrue(inflater.inflate()) + assertTrue(inflater.process()) + assertTrue(inflater.sourceFinished) assertEquals(inflater.sourcePos, inflater.sourceLimit) targetBuffer.write(inflater.target, 0, inflater.targetPos) @@ -124,7 +129,8 @@ class InflaterTest { targetLimit = target.size } - assertTrue(inflater.inflate()) + assertTrue(inflater.process()) + assertTrue(inflater.sourceFinished) val inflated = inflater.target.toByteString(0, inflater.targetPos) assertEquals( @@ -146,13 +152,15 @@ class InflaterTest { inflater.source = ByteArray(256) inflater.sourcePos = 0 inflater.sourceLimit = 0 - assertFalse(inflater.inflate()) + assertTrue(inflater.process()) + assertFalse(inflater.sourceFinished) inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=" .decodeBase64()!!.toByteArray() inflater.sourcePos = 0 inflater.sourceLimit = inflater.source.size - assertTrue(inflater.inflate()) + assertTrue(inflater.process()) + assertTrue(inflater.sourceFinished) val inflated = inflater.target.toByteString(0, inflater.targetPos) assertEquals( @@ -175,8 +183,9 @@ class InflaterTest { inflater.sourcePos = 0 inflater.sourceLimit = inflater.source.size val exception = assertFailsWith { - inflater.inflate() + inflater.process() } + assertFalse(inflater.sourceFinished) assertEquals("Z_DATA_ERROR", exception.message) inflater.close() @@ -188,7 +197,7 @@ class InflaterTest { inflater.close() assertFailsWith { - inflater.inflate() + inflater.process() } }