Skip to content

KRPC-146 Nested types in gRPC #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion protobuf-plugin/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -40,7 +40,6 @@ sourceSets {
"**/funny_types.proto",
"**/map.proto",
"**/multiple_files.proto",
"**/nested.proto",
"**/one_of.proto",
"**/options.proto",
"**/with_comments.proto",
Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@ package kotlinx.rpc.protobuf
import kotlinx.rpc.protobuf.CodeGenerator.DeclarationType
import kotlinx.rpc.protobuf.model.*
import org.slf4j.Logger
import kotlin.sequences.forEach

private const val RPC_INTERNAL_PACKAGE_SUFFIX = "_rpc_internal"

@@ -93,8 +92,27 @@ class ModelToKotlinGenerator(

private fun CodeGenerator.generateInternalDeclaredEntities(fileDeclaration: FileDeclaration) {
fileDeclaration.messageDeclarations.forEach { generateInternalMessage(it) }
fileDeclaration.enumDeclarations.forEach { generateInternalEnum(it) }
fileDeclaration.serviceDeclarations.forEach { generateInternalService(it) }

fileDeclaration.messageDeclarations.forEach {
generateToAndFromPlatformCastsRec(it)
}

fileDeclaration.enumDeclarations.forEach {
generateToAndFromPlatformCastsEnum(it)
}
}

private fun CodeGenerator.generateToAndFromPlatformCastsRec(declaration: MessageDeclaration) {
generateToAndFromPlatformCasts(declaration)

declaration.nestedDeclarations.forEach { nested ->
generateToAndFromPlatformCastsRec(nested)
}

declaration.enumDeclarations.forEach { nested ->
generateToAndFromPlatformCastsEnum(nested)
}
}

private fun MessageDeclaration.fields() = actualFields.map {
@@ -112,20 +130,20 @@ class ModelToKotlinGenerator(
newLine()
}

newLine()

// KRPC-147 OneOf Types
// declaration.oneOfDeclarations.forEach { oneOf ->
// generateOneOf(oneOf)
// }
//
// KRPC-146 Nested Types
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

// declaration.nestedDeclarations.forEach { nested ->
// generateMessage(nested)
// }
//
// KRPC-141 Enum Types
// declaration.enumDeclarations.forEach { enum ->
// generateEnum(enum)
// }
declaration.nestedDeclarations.forEach { nested ->
generatePublicMessage(nested)
}

declaration.enumDeclarations.forEach { enum ->
generatePublicEnum(enum)
}

clazz("", modifiers = "companion", declarationType = DeclarationType.Object)
}
@@ -157,20 +175,25 @@ class ModelToKotlinGenerator(
code("override var $fieldDeclaration $value")
newLine()
}

declaration.nestedDeclarations.forEach { nested ->
generateInternalMessage(nested)
}
}
}

private fun CodeGenerator.generateToAndFromPlatformCasts(declaration: MessageDeclaration) {
function(
name = "invoke",
modifiers = "operator",
args = "body: ${declaration.name.simpleName}Builder.() -> Unit",
args = "body: ${declaration.name.safeFullName("Builder")}.() -> Unit",
contextReceiver = "${declaration.name.safeFullName()}.Companion",
returnType = declaration.name.safeFullName(),
) {
code("return ${declaration.name.simpleName}Builder().apply(body)")
code("return ${declaration.name.safeFullName("Builder")}().apply(body)")
}

val platformType = "${declaration.outerClassName.safeFullName()}.${declaration.name.simpleName}"

val platformType = "${declaration.outerClassName.safeFullName()}.${declaration.name.fullNestedName()}"
function(
name = "toPlatform",
contextReceiver = declaration.name.safeFullName(),
@@ -198,7 +221,7 @@ class ModelToKotlinGenerator(
contextReceiver = platformType,
returnType = declaration.name.safeFullName(),
) {
scope("return ${declaration.name.simpleName}") {
scope("return ${declaration.name.safeFullName()}") {
declaration.actualFields.forEach { field ->
val javaName = when (field.type) {
is FieldType.List -> "${field.name}List"
@@ -361,8 +384,8 @@ class ModelToKotlinGenerator(
}

@Suppress("unused")
private fun CodeGenerator.generateInternalEnum(declaration: EnumDeclaration) {
val platformType = "${declaration.outerClassName.safeFullName()}.${declaration.name.simpleName}"
private fun CodeGenerator.generateToAndFromPlatformCastsEnum(declaration: EnumDeclaration) {
val platformType = "${declaration.outerClassName.safeFullName()}.${declaration.name.fullNestedName()}"

function(
name = "toPlatform",
@@ -371,11 +394,11 @@ class ModelToKotlinGenerator(
) {
scope("return when (this)") {
declaration.aliases.forEach { field ->
code("${declaration.name.simpleName}.${field.name.simpleName} -> $platformType.${field.name.simpleName}")
code("${declaration.name.fullNestedName()}.${field.name.simpleName} -> $platformType.${field.name.simpleName}")
}

declaration.originalEntries.forEach { field ->
code("${declaration.name.simpleName}.${field.name.simpleName} -> $platformType.${field.name.simpleName}")
code("${declaration.name.fullNestedName()}.${field.name.simpleName} -> $platformType.${field.name.simpleName}")
}
}
}
@@ -387,11 +410,11 @@ class ModelToKotlinGenerator(
) {
scope("return when (this)") {
declaration.aliases.forEach { field ->
code("$platformType.${field.name.simpleName} -> ${declaration.name.simpleName}.${field.name.simpleName}")
code("$platformType.${field.name.simpleName} -> ${declaration.name.fullNestedName()}.${field.name.simpleName}")
}

declaration.originalEntries.forEach { field ->
code("$platformType.${field.name.simpleName} -> ${declaration.name.simpleName}.${field.name.simpleName}")
code("$platformType.${field.name.simpleName} -> ${declaration.name.fullNestedName()}.${field.name.simpleName}")
}
}
}
@@ -528,13 +551,13 @@ class ModelToKotlinGenerator(
}

private fun MessageDeclaration.toPlatformMessageType(): String {
return "${outerClassName.safeFullName()}.${name.simpleName}"
return "${outerClassName.safeFullName()}.${name.fullNestedName()}"
}

private fun FqName.safeFullName(): String {
private fun FqName.safeFullName(classSuffix: String = ""): String {
importRootDeclarationIfNeeded(this)

return fullName()
return fullName(classSuffix)
}

private fun importRootDeclarationIfNeeded(
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ class ProtoToModelInterpreter(
val fqName = parentResolver.declarationFqName(simpleName, parent ?: packageName)
val resolver = parentResolver.withScope(fqName)

val fields = fieldList.asSequence().mapNotNull {
val fields = fieldList.mapNotNull {
val oneOfName = if (it.hasOneofIndex()) {
oneofDeclList[it.oneofIndex].name
} else {
@@ -111,10 +111,9 @@ class ProtoToModelInterpreter(
outerClassName = outerClass,
name = fqName,
actualFields = fields,
oneOfDeclarations = oneofDeclList.asSequence().mapIndexedNotNull { i, desc -> desc.toModel(i, resolver) },
enumDeclarations = enumTypeList.asSequence()
.map { it.toModel(resolver, outerClass, parent ?: packageName) },
nestedDeclarations = nestedTypeList.asSequence().map { it.toModel(resolver, outerClass, fqName) },
oneOfDeclarations = oneofDeclList.mapIndexedNotNull { i, desc -> desc.toModel(i, resolver) },
enumDeclarations = enumTypeList.map { it.toModel(resolver, outerClass, fqName) },
nestedDeclarations = nestedTypeList.map { it.toModel(resolver, outerClass, fqName) },
deprecated = options.deprecated,
doc = null,
).apply {
Original file line number Diff line number Diff line change
@@ -53,12 +53,29 @@ sealed interface FqName {
}
}

internal fun FqName.fullName(): String {
internal fun FqName.fullName(classSuffix: String = ""): String {
val parentName = parent
val name = if (this is FqName.Declaration) "$simpleName$classSuffix" else simpleName
return when {
parentName == this -> simpleName
parentName == this -> name
else -> {
val fullParentName = parentName.fullName()
val fullParentName = parentName.fullName(classSuffix)
if (fullParentName.isEmpty()) {
name
} else {
"$fullParentName.$name"
}
}
}
}

internal fun FqName.fullNestedName(): String {
val parentName = parent
return when (parentName) {
is FqName.Package -> simpleName
this -> simpleName
else -> {
val fullParentName = parentName.fullNestedName()
if (fullParentName.isEmpty()) {
simpleName
} else {
Original file line number Diff line number Diff line change
@@ -7,10 +7,10 @@ package kotlinx.rpc.protobuf.model
data class MessageDeclaration(
val outerClassName: FqName,
val name: FqName,
val actualFields: Sequence<FieldDeclaration>, // excludes oneOf fields, but includes oneOf itself
val oneOfDeclarations: Sequence<OneOfDeclaration>,
val enumDeclarations: Sequence<EnumDeclaration>,
val nestedDeclarations: Sequence<MessageDeclaration>,
val actualFields: List<FieldDeclaration>, // excludes oneOf fields, but includes oneOf itself
val oneOfDeclarations: List<OneOfDeclaration>,
val enumDeclarations: List<EnumDeclaration>,
val nestedDeclarations: List<MessageDeclaration>,
val deprecated: Boolean,
val doc: String?,
)
Original file line number Diff line number Diff line change
@@ -177,4 +177,17 @@ internal class NameResolver private constructor(
return _list!!
}
}

@Suppress("unused")
fun pprint(): String {
return buildString { pprint(root, 0) }
}

private fun StringBuilder.pprint(node: Node, indent: Int) {
val spaces = " ".repeat(indent)
appendLine("$spaces${node.fqName.fullName()}")
for (child in node.children.values) {
pprint(child, indent + 4)
}
}
}
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ import kotlinx.rpc.registerService
import kotlinx.rpc.withService
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull

class ReferenceTestServiceImpl : ReferenceTestService {
override suspend fun Get(message: References): kotlinx.rpc.protobuf.test.References {
@@ -36,6 +37,10 @@ class ReferenceTestServiceImpl : ReferenceTestService {
override suspend fun Repeated(message: Repeated): Repeated {
return message
}

override suspend fun Nested(message: Nested): Nested {
return message
}
}

class TestReferenceService : GrpcServerTest() {
@@ -44,7 +49,7 @@ class TestReferenceService : GrpcServerTest() {
}

@Test
fun testReferenceService()= runGrpcTest { grpcClient ->
fun testReferenceService() = runGrpcTest { grpcClient ->
val service = grpcClient.withService<ReferenceTestService>()
val result = service.Get(References {
other = Other {
@@ -113,4 +118,83 @@ class TestReferenceService : GrpcServerTest() {
assertEquals(emptyList(), resultEmpty.listString)
assertEquals(emptyList(), resultEmpty.listReference)
}

@Test
fun testNested() = runGrpcTest { grpcClient ->
val service = grpcClient.withService<ReferenceTestService>()
val result = service.Nested(Nested {
inner1 = Nested.Inner1 {
inner11 = Nested.Inner1.Inner11 {
reference21 = null
reference12 = Nested.Inner1.Inner12 {
recursion = null
}
enum = Nested.Inner2.NestedEnum.ZERO
}

inner22 = Nested.Inner1.Inner12 {
recursion = Nested.Inner1.Inner12 {
recursion = null
}
}

string = "42_1"

inner1 = null
}

inner2 = Nested.Inner2 {
inner21 = Nested.Inner2.Inner21 {
reference11 = Nested.Inner1.Inner11 {
reference21 = null
reference12 = Nested.Inner1.Inner12 {
recursion = null
}
enum = Nested.Inner2.NestedEnum.ZERO
}

reference22 = Nested.Inner2.Inner22 {
enum = Nested.Inner2.NestedEnum.ZERO
}
}

inner22 = Nested.Inner2.Inner22 {
enum = Nested.Inner2.NestedEnum.ZERO
}
string = "42_2"
}

string = "42"
enum = Nested.Inner2.NestedEnum.ZERO
})

// Assert Inner1.Inner11
assertEquals(null, result.inner1.inner11.reference21)
assertEquals(null, result.inner1.inner11.reference12.recursion)
assertEquals(Nested.Inner2.NestedEnum.ZERO, result.inner1.inner11.enum)

// Assert Inner1.Inner12
assertNotNull(result.inner1.inner22.recursion)
assertEquals(null, result.inner1.inner22.recursion?.recursion)

// Assert Inner1
assertEquals("42_1", result.inner1.string)
assertEquals(null, result.inner1.inner1)

// Assert Inner2.Inner21
assertEquals(null, result.inner2.inner21.reference11.reference21)
assertEquals(null, result.inner2.inner21.reference11.reference12.recursion)
assertEquals(Nested.Inner2.NestedEnum.ZERO, result.inner2.inner21.reference11.enum)
assertEquals(Nested.Inner2.NestedEnum.ZERO, result.inner2.inner21.reference22.enum)

// Assert Inner2.Inner22
assertEquals(Nested.Inner2.NestedEnum.ZERO, result.inner2.inner22.enum)

// Assert Inner2
assertEquals("42_2", result.inner2.string)

// Assert root Nested
assertEquals("42", result.string)
assertEquals(Nested.Inner2.NestedEnum.ZERO, result.enum)
}
}
5 changes: 3 additions & 2 deletions protobuf-plugin/src/test/proto/nested.proto
Original file line number Diff line number Diff line change
@@ -5,18 +5,19 @@ package kotlinx.rpc.protobuf.test;
message Nested {
message Inner1 {
message Inner11 {
Nested.Inner2.Inner21 reference21 = 1;
optional Nested.Inner2.Inner21 reference21 = 1;
Nested.Inner1.Inner12 reference12 = 2;
Nested.Inner2.NestedEnum enum = 3;
}

message Inner12 {
Inner12 recursion = 1;
optional Inner12 recursion = 1;
}

Inner11 inner11 = 1;
Inner12 inner22 = 2;
string string = 3;
optional Inner1 inner1 = 4;
}

message Inner2 {
3 changes: 3 additions & 0 deletions protobuf-plugin/src/test/proto/reference_service.proto
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ import "reference_package.proto";
import "enum.proto";
import "optional.proto";
import "repeated.proto";
import "nested.proto";

service ReferenceTestService {
rpc Get(References) returns (kotlinx.rpc.protobuf.test.References);
@@ -14,4 +15,6 @@ service ReferenceTestService {
rpc Optional(kotlinx.rpc.protobuf.test.OptionalTypes) returns (kotlinx.rpc.protobuf.test.OptionalTypes);

rpc Repeated(kotlinx.rpc.protobuf.test.Repeated) returns (kotlinx.rpc.protobuf.test.Repeated);

rpc Nested(kotlinx.rpc.protobuf.test.Nested) returns (kotlinx.rpc.protobuf.test.Nested);
}