Skip to content

Commit

Permalink
Merge pull request #234 from Chuckame/fix/remove-bytes
Browse files Browse the repository at this point in the history
fix: Only handle ByteArrays as bytes or fixed, and collection of Byte as arrays of int
  • Loading branch information
Chuckame authored Jul 10, 2024
2 parents ad17bb3 + 83934fc commit 489b45a
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import kotlinx.serialization.SerializationException
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.builtins.ByteArraySerializer
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.internal.AbstractCollectionSerializer
Expand Down Expand Up @@ -56,14 +54,12 @@ internal object SerializerLocatorMiddleware {

fun apply(descriptor: SerialDescriptor): SerialDescriptor {
return when {
descriptor.isCollectionOfBytes() -> SerialDescriptorWithAvroSchemaDelegate(descriptor, AvroByteArraySerializer)
descriptor == String.serializer().descriptor -> AvroStringSerialDescriptor
descriptor == Duration.serializer().descriptor -> KotlinDurationSerializer.descriptor
descriptor === ByteArraySerializer().descriptor -> AvroByteArraySerializer.descriptor
descriptor === String.serializer().descriptor -> AvroStringSerialDescriptor
descriptor === Duration.serializer().descriptor -> KotlinDurationSerializer.descriptor
else -> descriptor
}
}

private fun SerialDescriptor.isCollectionOfBytes() = kind === StructureKind.LIST && elementsCount == 1 && getElementDescriptor(0).kind === PrimitiveKind.BYTE
}

private val AvroStringSerialDescriptor: SerialDescriptor =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,38 +61,18 @@ internal abstract class AbstractAvroDirectDecoder(
override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
return when (descriptor.kind) {
StructureKind.LIST ->
decodeResolvingAny({
UnexpectedDecodeSchemaError(
descriptor.nonNullSerialName,
Schema.Type.ARRAY,
Schema.Type.BYTES,
Schema.Type.FIXED
)
}) {
decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.ARRAY) }) {
when (it.type) {
Schema.Type.ARRAY -> {
AnyValueDecoder { ArrayBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) }
}

Schema.Type.BYTES -> {
AnyValueDecoder { BytesDirectDecoder(avro, binaryDecoder) }
}

Schema.Type.FIXED -> {
AnyValueDecoder { FixedDirectDecoder(avro, it.fixedSize, binaryDecoder) }
}

else -> null
}
}

StructureKind.MAP ->
decodeResolvingAny({
UnexpectedDecodeSchemaError(
descriptor.nonNullSerialName,
Schema.Type.MAP
)
}) {
decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.MAP) }) {
when (it.type) {
Schema.Type.MAP -> {
AnyValueDecoder { MapBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) }
Expand All @@ -103,12 +83,7 @@ internal abstract class AbstractAvroDirectDecoder(
}

StructureKind.CLASS, StructureKind.OBJECT ->
decodeResolvingAny({
UnexpectedDecodeSchemaError(
descriptor.nonNullSerialName,
Schema.Type.RECORD
)
}) {
decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.RECORD) }) {
when (it.type) {
Schema.Type.RECORD -> {
AnyValueDecoder { RecordDirectDecoder(it, descriptor, avro, binaryDecoder) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,8 @@ package com.github.avrokotlin.avro4k.internal.decoder.direct
import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

internal class BytesDirectDecoder(
private val avro: Avro,
binaryDecoder: org.apache.avro.io.Decoder,
) : AbstractDecoder() {
override val serializersModule: SerializersModule
get() = avro.serializersModule

private val bytes = binaryDecoder.readBytes(null)

override fun decodeByte(): Byte {
return bytes.get()
}

override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
return bytes.remaining()
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
}

internal class FixedDirectDecoder(
private val avro: Avro,
fixedSize: Int,
binaryDecoder: org.apache.avro.io.Decoder,
) : AbstractDecoder() {
override val serializersModule: SerializersModule
get() = avro.serializersModule

private val bytes = ByteArray(fixedSize).also { binaryDecoder.readFixed(it) }
private var nextPosition = 0

override fun decodeByte(): Byte {
return bytes[nextPosition++]
}

override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
return bytes.size
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
}

internal class ArrayBlockDirectDecoder(
private val arraySchema: Schema,
private val decodeFirstBlock: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.DecodedNullError
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

internal class MapGenericDecoder(
Expand Down Expand Up @@ -90,26 +88,4 @@ internal class ArrayGenericDecoder(
override fun decodeCollectionSize(descriptor: SerialDescriptor) = collection.size

override fun decodeSequentially() = true
}

internal class ByteArrayGenericDecoder(
private val avro: Avro,
private val bytes: ByteArray,
) : AbstractDecoder() {
override val serializersModule: SerializersModule
get() = avro.serializersModule

private val iterator = bytes.iterator()

override fun decodeByte() = iterator.nextByte()

override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
return bytes.size
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,18 @@ internal sealed class AbstractAvroDirectEncoder(
): CompositeEncoder {
return when (descriptor.kind) {
StructureKind.LIST ->
encodeResolving(
{ BadEncodedValueError(emptyList<Any?>(), currentWriterSchema, Schema.Type.ARRAY, Schema.Type.BYTES, Schema.Type.FIXED) }
) { schema ->
encodeResolving({ BadEncodedValueError(emptyList<Any?>(), currentWriterSchema, Schema.Type.ARRAY) }) { schema ->
when (schema.type) {
Schema.Type.ARRAY -> {
{ ArrayDirectEncoder(schema, collectionSize, avro, binaryEncoder) }
}

Schema.Type.BYTES -> {
{ BytesDirectEncoder(avro, binaryEncoder, collectionSize) }
}

Schema.Type.FIXED -> {
{ FixedDirectEncoder(schema, collectionSize, avro, binaryEncoder) }
}

else -> null
}
}

StructureKind.MAP ->
encodeResolving(
{ BadEncodedValueError(emptyMap<String, Any?>(), currentWriterSchema, Schema.Type.MAP) }
) { schema ->
encodeResolving({ BadEncodedValueError(emptyMap<String, Any?>(), currentWriterSchema, Schema.Type.MAP) }) { schema ->
when (schema.type) {
Schema.Type.MAP -> {
{ MapDirectEncoder(schema, collectionSize, avro, binaryEncoder) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package com.github.avrokotlin.avro4k.internal.encoder.direct

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractEncoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

internal class MapDirectEncoder(private val schema: Schema, mapSize: Int, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder) :
Expand Down Expand Up @@ -69,42 +66,4 @@ internal class ArrayDirectEncoder(
override fun endStructure(descriptor: SerialDescriptor) {
binaryEncoder.writeArrayEnd()
}
}

internal class FixedDirectEncoder(schema: Schema, arraySize: Int, private val avro: Avro, private val binaryEncoder: org.apache.avro.io.Encoder) : AbstractEncoder() {
private val buffer = ByteArray(schema.fixedSize)
private var pos = schema.fixedSize - arraySize

override val serializersModule: SerializersModule
get() = avro.serializersModule

init {
if (arraySize > schema.fixedSize) {
throw SerializationException("Actual collection size $arraySize is greater than schema fixed size $schema")
}
}

override fun encodeByte(value: Byte) {
buffer[pos++] = value
}

override fun endStructure(descriptor: SerialDescriptor) {
binaryEncoder.writeFixed(buffer)
}
}

internal class BytesDirectEncoder(private val avro: Avro, private val binaryEncoder: org.apache.avro.io.Encoder, collectionSize: Int) : AbstractEncoder() {
private val buffer = ByteArray(collectionSize)
private var pos = 0

override val serializersModule: SerializersModule
get() = avro.serializersModule

override fun encodeByte(value: Byte) {
buffer[pos++] = value
}

override fun endStructure(descriptor: SerialDescriptor) {
binaryEncoder.writeBytes(buffer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,8 @@ internal class ValueVisitor internal constructor(
val finalDescriptor = SerializerLocatorMiddleware.apply(unwrapNullable(descriptor))

(finalDescriptor.nonNullOriginal as? AvroSchemaSupplier)
?.getSchema(context)
?.let {
setSchema(it)
return
}
super.visitValue(finalDescriptor)
?.getSchema(context)?.let { setSchema(it) }
?: super.visitValue(finalDescriptor)
}

private fun unwrapNullable(descriptor: SerialDescriptor): SerialDescriptor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ internal class AvroEncodingAssertions<T>(
return this
}

fun generatesSchema(
expectedSchemaResourcePath: Path,
schemaTransformer: (Schema) -> Schema = { it },
): AvroEncodingAssertions<T> {
generatesSchema(Schema.Parser().parse(javaClass.getResourceAsStream(expectedSchemaResourcePath.toString())).let(schemaTransformer))
return this
}

fun isEncodedAs(
expectedEncodedGenericValue: Any?,
expectedDecodedValue: T = valueToEncode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ internal class AvroFixedEncodingTest : StringSpec({
.isEncodedAs(record(GenericData.Fixed(schema, "1234567".toByteArray())))
}

"support fixed on value classes" {
"support fixed on string value classes" {
AvroAssertions.assertThat<FixedNestedStringField>()
.generatesSchema(Path("/fixed_string.json"))

Expand All @@ -36,6 +36,12 @@ internal class AvroFixedEncodingTest : StringSpec({
.isEncodedAs(GenericData.Fixed(Avro.schema<FixedStringValueClass>(), "1234567".toByteArray()))
}

"support @AvroFixed on ByteArray" {
AvroAssertions.assertThat(FixedByteArrayField("1234567".toByteArray()))
.generatesSchema(Path("/fixed_string.json"))
.isEncodedAs(record(GenericData.Fixed(Avro.schema<FixedByteArrayField>().fields[0].schema(), "1234567".toByteArray())))
}

"top-est @AvroFixed annotation takes precedence over nested @AvroFixed annotations" {
AvroAssertions.assertThat<FieldPriorToValueClass>()
.generatesSchema(Path("/fixed_string_5.json"))
Expand Down Expand Up @@ -83,6 +89,25 @@ internal class AvroFixedEncodingTest : StringSpec({
@AvroFixed(7) val mystring: String,
)

@Serializable
@SerialName("Fixed")
private data class FixedByteArrayField(
@AvroFixed(7) val mystring: ByteArray,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as FixedByteArrayField

return mystring.contentEquals(other.mystring)
}

override fun hashCode(): Int {
return mystring.contentHashCode()
}
}

@Serializable
@SerialName("Fixed")
private data class FixedNestedStringField(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import kotlinx.serialization.Serializable
import org.apache.avro.Schema

internal class BytesEncodingTest : StringSpec({
"encode/decode nullable ByteArray" {
"encode/decode nullable ByteArray to BYTES" {
AvroAssertions.assertThat(NullableByteArrayTest(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))
AvroAssertions.assertThat(NullableByteArrayTest(null))
Expand All @@ -22,7 +22,7 @@ internal class BytesEncodingTest : StringSpec({
.isEncodedAs(null)
}

"encode/decode ByteArray" {
"encode/decode ByteArray to BYTES" {
AvroAssertions.assertThat(ByteArrayTest(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))

Expand All @@ -32,24 +32,24 @@ internal class BytesEncodingTest : StringSpec({
.isEncodedAs(byteArrayOf(1, 4, 9))
}

"encode/decode List<Byte>" {
"encode/decode List<Byte> to ARRAY[INT]" {
AvroAssertions.assertThat(ListByteTest(listOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(listOf(1, 4, 9)))

AvroAssertions.assertThat<List<Byte>>()
.generatesSchema(Schema.create(Schema.Type.BYTES))
.generatesSchema(Schema.createArray(Schema.create(Schema.Type.INT)))
AvroAssertions.assertThat(listOf<Byte>(1, 4, 9))
.isEncodedAs(byteArrayOf(1, 4, 9))
.isEncodedAs(listOf(1, 4, 9))
}

"encode/decode Array<Byte> to ByteBuffer" {
"encode/decode Array<Byte> to ARRAY[INT]" {
AvroAssertions.assertThat(ArrayByteTest(arrayOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(listOf(1, 4, 9)))

AvroAssertions.assertThat<Array<Byte>>()
.generatesSchema(Schema.create(Schema.Type.BYTES))
.generatesSchema(Schema.createArray(Schema.create(Schema.Type.INT)))
AvroAssertions.assertThat(arrayOf<Byte>(1, 4, 9))
.isEncodedAs(byteArrayOf(1, 4, 9))
.isEncodedAs(listOf(1, 4, 9))
}
}) {
@Serializable
Expand Down
Loading

0 comments on commit 489b45a

Please sign in to comment.