Skip to content

Commit

Permalink
Added support for null values for nullable enums in lanient mode (#2176)
Browse files Browse the repository at this point in the history
Fixed #2170

Co-authored-by: Leonid Startsev <sandwwraith@users.noreply.github.com>
  • Loading branch information
shanshin and sandwwraith authored Feb 6, 2023
1 parent b454f34 commit 2cb7f7d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package kotlinx.serialization.json

import kotlinx.serialization.*
import kotlinx.serialization.json.internal.*
import kotlinx.serialization.test.assertFailsWithSerial
import kotlin.test.*

Expand All @@ -25,6 +24,11 @@ class JsonCoerceInputValuesTest : JsonTestBase() {
val foo: String
)

@Serializable
data class NullableEnumHolder(
val enum: SampleEnum?
)

val json = Json {
coerceInputValues = true
isLenient = true
Expand Down Expand Up @@ -99,4 +103,13 @@ class JsonCoerceInputValuesTest : JsonTestBase() {
assertEquals(expected, json.decodeFromString(MultipleValues.serializer(), input), "Failed on input: $input")
}
}

@Test
fun testNullSupportForEnums() = parametrizedTest(json) {
var decoded = decodeFromString<NullableEnumHolder>("""{"enum": null}""")
assertNull(decoded.enum)

decoded = decodeFromString<NullableEnumHolder>("""{"enum": OptionA}""")
assertEquals(SampleEnum.OptionA, decoded.enum)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,16 @@ internal fun SerialDescriptor.getJsonNameIndexOrThrow(json: Json, name: String,
@OptIn(ExperimentalSerializationApi::class)
internal inline fun Json.tryCoerceValue(
elementDescriptor: SerialDescriptor,
peekNull: () -> Boolean,
peekNull: (consume: Boolean) -> Boolean,
peekString: () -> String?,
onEnumCoercing: () -> Unit = {}
): Boolean {
if (!elementDescriptor.isNullable && peekNull()) return true
if (!elementDescriptor.isNullable && peekNull(true)) return true
if (elementDescriptor.kind == SerialKind.ENUM) {
if (elementDescriptor.isNullable && peekNull(false)) {
return false
}

val enumValue = peekString()
?: return false // if value is not a string, decodeEnum() will throw correct exception
val enumIndex = elementDescriptor.getJsonNameIndex(this, enumValue)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ internal open class StreamingJsonDecoder(
}

override fun decodeNotNullMark(): Boolean {
return !(elementMarker?.isUnmarkedNull ?: false) && lexer.tryConsumeNotNull()
return !(elementMarker?.isUnmarkedNull ?: false) && !lexer.tryConsumeNull()
}

override fun decodeNull(): Nothing? {
Expand Down Expand Up @@ -208,7 +208,7 @@ internal open class StreamingJsonDecoder(
*/
private fun coerceInputValue(descriptor: SerialDescriptor, index: Int): Boolean = json.tryCoerceValue(
descriptor.getElementDescriptor(index),
{ !lexer.tryConsumeNotNull() },
{ lexer.tryConsumeNull(it) },
{ lexer.peekString(configuration.isLenient) },
{ lexer.consumeString() /* skip unknown enum string*/ }
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,25 +244,28 @@ internal abstract class AbstractJsonLexer {

/**
* Tries to consume `null` token from input.
* Returns `true` if the next 4 chars in input are not `null`,
* `false` otherwise and consumes it.
* Returns `false` if the next 4 chars in input are not `null`,
* `true` otherwise and consumes it if [doConsume] is `true`.
*/
fun tryConsumeNotNull(): Boolean {
fun tryConsumeNull(doConsume: Boolean = true): Boolean {
var current = skipWhitespaces()
current = prefetchOrEof(current)
// Cannot consume null due to EOF, maybe something else
val len = source.length - current
if (len < 4 || current == -1) return true
if (len < 4 || current == -1) return false
for (i in 0..3) {
if (NULL[i] != source[current + i]) return true
if (NULL[i] != source[current + i]) return false
}
/*
* If we're in lenient mode, this might be the string with 'null' prefix,
* distinguish it from 'null'
*/
if (len > 4 && charToTokenClass(source[current + 4]) == TC_OTHER) return true
currentPosition = current + 4
return false
if (len > 4 && charToTokenClass(source[current + 4]) == TC_OTHER) return false

if (doConsume) {
currentPosition = current + 4
}
return true
}

open fun skipWhitespaces(): Int {
Expand Down

0 comments on commit 2cb7f7d

Please sign in to comment.