Skip to content

Commit

Permalink
Leverage KType in Kotlin Serialization WebFlux support
Browse files Browse the repository at this point in the history
In order to take in account properly Kotlin null-safety with the
annotation programming model.

Closes spring-projectsgh-33016
  • Loading branch information
sdeleuze committed Jul 1, 2024
1 parent 23dccc5 commit 98e89d8
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,29 @@

package org.springframework.http.codec;

import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import kotlin.reflect.KFunction;
import kotlin.reflect.KType;
import kotlin.reflect.full.KCallables;
import kotlin.reflect.jvm.ReflectJvmMapping;
import kotlinx.serialization.KSerializer;
import kotlinx.serialization.SerialFormat;
import kotlinx.serialization.SerializersKt;
import kotlinx.serialization.descriptors.PolymorphicKind;
import kotlinx.serialization.descriptors.SerialDescriptor;

import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.springframework.util.MimeType;

Expand All @@ -46,7 +54,10 @@
*/
public abstract class KotlinSerializationSupport<T extends SerialFormat> {

private final Map<Type, KSerializer<Object>> serializerCache = new ConcurrentReferenceHashMap<>();
private final Map<Type, KSerializer<Object>> typeSerializerCache = new ConcurrentReferenceHashMap<>();

private final Map<KType, KSerializer<Object>> kTypeSerializerCache = new ConcurrentReferenceHashMap<>();


private final T format;

Expand Down Expand Up @@ -117,8 +128,33 @@ private boolean supports(@Nullable MimeType mimeType) {
*/
@Nullable
protected final KSerializer<Object> serializer(ResolvableType resolvableType) {
if (resolvableType.getSource() instanceof MethodParameter parameter) {
Method method = parameter.getMethod();
Assert.notNull(method, "Method must not be null");
if (KotlinDetector.isKotlinType(method.getDeclaringClass())) {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
Assert.notNull(function, "Kotlin function must not be null");
KType type = (parameter.getParameterIndex() == -1 ? function.getReturnType() :
KCallables.getValueParameters(function).get(parameter.getParameterIndex()).getType());
KSerializer<Object> serializer = this.kTypeSerializerCache.get(type);
if (serializer == null) {
try {
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
}
catch (IllegalArgumentException ignored) {
}
if (serializer != null) {
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
return null;
}
this.kTypeSerializerCache.put(type, serializer);
}
}
return serializer;
}
}
Type type = resolvableType.getType();
KSerializer<Object> serializer = this.serializerCache.get(type);
KSerializer<Object> serializer = this.typeSerializerCache.get(type);
if (serializer == null) {
try {
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
Expand All @@ -129,7 +165,7 @@ protected final KSerializer<Object> serializer(ResolvableType resolvableType) {
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
return null;
}
this.serializerCache.put(type, serializer);
this.typeSerializerCache.put(type, serializer);
}
}
return serializer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.springframework.http.codec.json
import kotlinx.serialization.Serializable
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.Ordered
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
import org.springframework.core.io.buffer.DataBufferUtils
import org.springframework.core.testfixture.codec.AbstractDecoderTests
import org.springframework.http.MediaType
import reactor.core.publisher.Flux
Expand All @@ -32,6 +34,7 @@ import java.lang.UnsupportedOperationException
import java.math.BigDecimal
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import kotlin.reflect.jvm.javaMethod

/**
* Tests for the JSON decoding using kotlinx.serialization.
Expand Down Expand Up @@ -128,6 +131,22 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests<KotlinSerializa
}, null, null)
}

@Test
fun decodeToMonoWithNullableWithNull() {
val input = Flux.concat(
stringBuffer("{\"value\":null}\n"),
)

val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
val elementType = ResolvableType.forMethodParameter(methodParameter)

testDecodeToMonoAll(input, elementType, {
it.expectNext(mapOf("value" to null))
.expectComplete()
.verify()
}, null, null)
}

private fun stringBuffer(value: String): Mono<DataBuffer> {
return stringBuffer(value, StandardCharsets.UTF_8)
}
Expand All @@ -145,4 +164,6 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests<KotlinSerializa
@Serializable
data class Pojo(val foo: String, val bar: String, val pojo: Pojo? = null)

fun handleMapWithNullable(map: Map<String, String?>) = map

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.springframework.http.codec.json
import kotlinx.serialization.Serializable
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.Ordered
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
Expand All @@ -31,6 +32,7 @@ import reactor.core.publisher.Mono
import reactor.test.StepVerifier.FirstStep
import java.math.BigDecimal
import java.nio.charset.StandardCharsets
import kotlin.reflect.jvm.javaMethod

/**
* Tests for the JSON encoding using kotlinx.serialization.
Expand Down Expand Up @@ -109,6 +111,17 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
}
}

@Test
fun encodeMonoWithNullableWithNull() {
val input = Mono.just(mapOf("value" to null))
val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
testEncode(input, ResolvableType.forMethodParameter(methodParameter), null, null) {
it.consumeNextWith(expectString("{\"value\":null}")
.andThen { dataBuffer: DataBuffer? -> DataBufferUtils.release(dataBuffer) })
.verifyComplete()
}
}

@Test
fun canNotEncode() {
assertThat(encoder.canEncode(ResolvableType.forClass(String::class.java), null)).isFalse()
Expand All @@ -123,4 +136,6 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
@Serializable
data class Pojo(val foo: String, val bar: String, val pojo: Pojo? = null)

fun handleMapWithNullable(map: Map<String, String?>) = map

}

0 comments on commit 98e89d8

Please sign in to comment.