diff --git a/protokt-codegen/src/main/kotlin/protokt/v1/codegen/generate/Implements.kt b/protokt-codegen/src/main/kotlin/protokt/v1/codegen/generate/Implements.kt index 7e7806d8..b148af00 100644 --- a/protokt-codegen/src/main/kotlin/protokt/v1/codegen/generate/Implements.kt +++ b/protokt-codegen/src/main/kotlin/protokt/v1/codegen/generate/Implements.kt @@ -16,12 +16,12 @@ package protokt.v1.codegen.generate import com.squareup.kotlinpoet.ClassName -import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.asTypeName +import org.checkerframework.checker.signature.qual.CanonicalName import protokt.v1.codegen.generate.CodeGenerator.Context import protokt.v1.codegen.util.Message import protokt.v1.codegen.util.StandardField @@ -49,59 +49,66 @@ internal object Implements { fun TypeSpec.Builder.handleSuperInterface(msg: Message, ctx: Context) = apply { - if (msg.options.protokt.implements.isNotEmpty()) { - // can't actually delegate because message types are nullable - if (msg.options.protokt.implements.delegates()) { - val interfaceClassName = inferClassName(msg.options.protokt.implements.substringBefore(" by "), ctx) - val fieldsByName = msg.fields.filterIsInstance().associateBy { it.fieldName } - val interfaceFields = - ctx.info.context.classLookup.properties(interfaceClassName.canonicalName) - .associateBy { it.name } - - interfaceFields.values.forEach { - require(it.returnType.isMarkedNullable) { - "Delegated properties must be nullable because message types are nullable; " + - "property ${it.name} is non-nullable" - } - } - - addSuperinterface(interfaceClassName) - interfaceFields.values.filter { it.name !in fieldsByName.keys }.forEach { - addProperty( - PropertySpec.builder( - it.name, - (it.returnType.classifier as KClass<*>).asTypeName().copy(nullable = true) - ) - .addModifiers(KModifier.OVERRIDE) - .getter( - FunSpec.getterBuilder() - .addCode( - CodeBlock.of( - "return %L?.%L", - msg.options.protokt.implements.substringAfter(" by "), - it.name - ) - ) - .build() - ) - .build() - ) - } - } else { - addSuperinterface(msg.superInterface(ctx)!!) + val superInterface = msg.superInterface(ctx) + if (superInterface != null) { + addSuperinterface(superInterface.`interface`) + if (superInterface.delegate != null) { + // can't actually delegate because message types are nullable + delegateProperties(msg, ctx, superInterface.canonicalName, superInterface.delegate) } } } - private fun String.delegates() = - contains(" by ") + private fun TypeSpec.Builder.delegateProperties(msg: Message, ctx: Context, canonicalName: String, fieldName: String) { + val fieldsByName = msg.fields.filterIsInstance().associateBy { it.fieldName } + val interfaceFields = + ctx.info.context.classLookup + .properties(canonicalName) + .associateBy { it.name } - private fun Message.superInterface(ctx: Context) = - options.protokt.implements.let { - if (it.isNotEmpty() && !it.delegates()) { - inferClassName(it, ctx) - } else { - null + val implementFields = interfaceFields.values.filter { it.name !in fieldsByName.keys } + + implementFields.forEach { + require(it.returnType.isMarkedNullable) { + "Delegated properties must be nullable because message types are nullable; " + + "property ${it.name} is non-nullable" } } + + implementFields.forEach { + addProperty( + PropertySpec.builder( + it.name, + (it.returnType.classifier as KClass<*>).asTypeName().copy(nullable = true) + ) + .addModifiers(KModifier.OVERRIDE) + .getter( + FunSpec.getterBuilder() + .addCode("return %L?.%L", fieldName, it.name) + .build() + ) + .build() + ) + } + } + + private class SuperInterface( + val `interface`: ClassName, + val delegate: String? + ) { + val canonicalName = `interface`.canonicalName + } + + private fun Message.superInterface(ctx: Context): SuperInterface? { + val implements = options.protokt.implements + return when { + implements.isEmpty() -> null + implements.contains(" by ") -> + SuperInterface( + inferClassName(implements.substringBefore(" by "), ctx), + implements.substringAfter(" by ") + ) + else -> SuperInterface(inferClassName(implements, ctx), null) + } + } }