diff --git a/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgent.kt b/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgent.kt index 18a20786a2..0878044f0e 100644 --- a/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgent.kt +++ b/src/main/kotlin/com/sourcegraph/cody/agent/CodyAgent.kt @@ -12,6 +12,7 @@ import com.intellij.util.system.CpuArch import com.sourcegraph.cody.agent.protocol.* import com.sourcegraph.cody.agent.protocol_extensions.ClientCapabilitiesFactory import com.sourcegraph.cody.agent.protocol_extensions.ClientInfoFactory +import com.sourcegraph.cody.agent.protocol_generated.ProtocolTypeAdapters import com.sourcegraph.cody.vscode.CancellationToken import com.sourcegraph.config.ConfigUtil import java.io.* @@ -271,16 +272,24 @@ private constructor( ): Launcher { return Launcher.Builder() .configureGson { gsonBuilder -> - gsonBuilder - // emit `null` instead of leaving fields undefined because Cody - // VSC has many `=== null` checks that return false for undefined fields. - .serializeNulls() - .registerTypeAdapter(CompletionItemID::class.java, CompletionItemIDSerializer) - .registerTypeAdapter(ContextItem::class.java, ContextItem.deserializer) - .registerTypeAdapter(Speaker::class.java, speakerDeserializer) - .registerTypeAdapter(Speaker::class.java, speakerSerializer) - .registerTypeAdapter(URI::class.java, uriDeserializer) - .registerTypeAdapter(URI::class.java, uriSerializer) + run { + gsonBuilder + // emit `null` instead of leaving fields undefined because Cody + // VSC has many `=== null` checks that return false for undefined fields. + .serializeNulls() + .registerTypeAdapter(ContextItem::class.java, ContextItem.deserializer) + .registerTypeAdapter(CompletionItemID::class.java, CompletionItemIDSerializer) + // TODO: Remove legacy enum conversions + .registerTypeAdapter(Speaker::class.java, speakerDeserializer) + .registerTypeAdapter(Speaker::class.java, speakerSerializer) + .registerTypeAdapter(URI::class.java, uriDeserializer) + .registerTypeAdapter(URI::class.java, uriSerializer) + + ProtocolTypeAdapters.register(gsonBuilder) + // This ensures that by default all enums are always serialized to their string + // equivalents + gsonBuilder.registerTypeAdapterFactory(EnumTypeAdapterFactory()) + } } .setRemoteInterface(CodyAgentServer::class.java) .traceMessages(traceWriter()) diff --git a/src/main/kotlin/com/sourcegraph/cody/agent/EnumTypeAdapterFactory.kt b/src/main/kotlin/com/sourcegraph/cody/agent/EnumTypeAdapterFactory.kt new file mode 100644 index 0000000000..875ee70dbd --- /dev/null +++ b/src/main/kotlin/com/sourcegraph/cody/agent/EnumTypeAdapterFactory.kt @@ -0,0 +1,57 @@ +package com.sourcegraph.cody.agent + +import com.google.gson.* +import com.google.gson.annotations.SerializedName +import com.google.gson.reflect.TypeToken +import com.google.gson.stream.JsonReader +import com.google.gson.stream.JsonWriter +import java.io.IOException +import java.lang.reflect.Field + +class EnumTypeAdapterFactory : TypeAdapterFactory { + override fun create(gson: Gson, type: TypeToken): TypeAdapter? { + val rawType = type.rawType as? Class<*> ?: return null + if (!rawType.isEnum) { + return null + } + @Suppress("UNCHECKED_CAST") + return EnumTypeAdapter(rawType as Class>) as TypeAdapter + } +} + +class EnumTypeAdapter>(private val classOfT: Class) : TypeAdapter() { + private val nameToConstant: Map = HashMap() + private val constantToName: Map = HashMap() + + init { + for (constant in classOfT.enumConstants) { + val name = getSerializedName(constant) ?: constant.name + (nameToConstant as HashMap)[name.lowercase()] = constant + (constantToName as HashMap)[constant] = name + } + } + + private fun getSerializedName(enumConstant: T): String? { + return try { + val field: Field = classOfT.getField(enumConstant.name) + field.getAnnotation(SerializedName::class.java)?.value + } catch (e: NoSuchFieldException) { + null + } + } + + override fun write(out: JsonWriter, value: T?) { + if (value == null) { + out.nullValue() + } else { + out.value(constantToName[value]) + } + } + + @Throws(IOException::class) + override fun read(`in`: JsonReader): T? { + val value = `in`.nextString() + return nameToConstant[value.lowercase()] + ?: throw JsonParseException("Unknown enum value: $value") + } +}