From 43230dd11c7b4c63609aeb7311c96eaf02f82962 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 17 Nov 2023 09:43:18 -0600 Subject: [PATCH 1/3] Add method stream type to generated code It is useful to have the stream type in the generated code (similar to other Connect implementations). In the future we may use this information to optimize stream handling - for example, avoiding the pipe/duplex/oneshot for server streaming calls to enable Connect over HTTP/1.1. Update the generator to include the stream type in the generated code. Additionally, move the HTTP method to the HTTPRequest class - mutating a method spec to determine which HTTP method to use seems wrong. We should keep MethodSpec to contain only fields set from the generator. --- .../main/kotlin/com/connectrpc/MethodSpec.kt | 9 ++--- .../main/kotlin/com/connectrpc/StreamType.kt | 35 +++++++++++++++++++ .../kotlin/com/connectrpc/http/HTTPRequest.kt | 11 ++++++ .../com/connectrpc/impl/ProtocolClient.kt | 2 +- .../protocols/ConnectInterceptor.kt | 18 ++++------ .../com/connectrpc/InterceptorChainTest.kt | 18 +++++++--- .../impl/BiDirectionalStreamTest.kt | 3 ++ .../com/connectrpc/impl/ProtocolClientTest.kt | 7 ++++ .../protocols/ConnectInterceptorTest.kt | 24 +++++++++---- .../protocols/GRPCInterceptorTest.kt | 8 +++++ .../protocols/GRPCWebInterceptorTest.kt | 8 +++++ .../connectrpc/okhttp/ConnectOkHttpClient.kt | 4 +-- .../connectrpc/protocgen/connect/Generator.kt | 30 +++++++++++----- 13 files changed, 137 insertions(+), 40 deletions(-) create mode 100644 library/src/main/kotlin/com/connectrpc/StreamType.kt diff --git a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt index 4141d5c6..08842422 100644 --- a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt +++ b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt @@ -16,11 +16,6 @@ package com.connectrpc import kotlin.reflect.KClass -internal object Method { - internal const val GET_METHOD = "GET" - internal const val POST_METHOD = "POST" -} - /** * Represents the minimum set of information to execute an RPC method. * Primarily used in generated code. @@ -29,12 +24,12 @@ internal object Method { * @param requestClass The Kotlin Class for the request message. * @param responseClass The Kotlin Class for the response message. * @param idempotency The declared idempotency of a method. - * @param method The HTTP method of a request. + * @param streamType The method's stream type. */ class MethodSpec( val path: String, val requestClass: KClass, val responseClass: KClass, val idempotency: Idempotency = Idempotency.UNKNOWN, - val method: String = Method.POST_METHOD, + val streamType: StreamType = StreamType.UNKNOWN, ) diff --git a/library/src/main/kotlin/com/connectrpc/StreamType.kt b/library/src/main/kotlin/com/connectrpc/StreamType.kt new file mode 100644 index 00000000..ba25e811 --- /dev/null +++ b/library/src/main/kotlin/com/connectrpc/StreamType.kt @@ -0,0 +1,35 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc + +/** + * Represents the RPC stream type. Set by the code generator on each [MethodSpec]. + */ +enum class StreamType { + /** Unknown stream type. */ + UNKNOWN, + + /** Unary RPC. */ + UNARY, + + /** Client streaming RPC. */ + CLIENT, + + /** Server streaming RPC. */ + SERVER, + + /** Bidirectional streaming RPC. */ + BIDI, +} diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt index 7b4a0ffc..60002b3d 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt @@ -18,6 +18,11 @@ import com.connectrpc.Headers import com.connectrpc.MethodSpec import java.net.URL +internal object HTTPMethod { + internal const val GET = "GET" + internal const val POST = "POST" +} + /** * HTTP request used for sending primitive data to the server. */ @@ -32,6 +37,9 @@ class HTTPRequest internal constructor( val message: ByteArray? = null, // The method spec associated with the request. val methodSpec: MethodSpec<*, *>, + // HTTP method to use with the request. + // Almost always POST, but side effect free unary RPCs may be made with GET. + val httpMethod: String = HTTPMethod.POST, ) { /** * Clones the [HTTPRequest] with override values. @@ -50,6 +58,8 @@ class HTTPRequest internal constructor( message: ByteArray? = this.message, // The method spec associated with the request. methodSpec: MethodSpec<*, *> = this.methodSpec, + // The HTTP method to use with the request. + httpMethod: String = this.httpMethod, ): HTTPRequest { return HTTPRequest( url, @@ -57,6 +67,7 @@ class HTTPRequest internal constructor( headers, message, methodSpec, + httpMethod, ) } } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 79aa94bc..e0441341 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -216,7 +216,7 @@ class ProtocolClient( isComplete = true when (streamResult.code) { Code.OK -> channel.close() - else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code)) + else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code, exception = streamResult.cause)) } } } diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index fbc96c09..6fb30ea1 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -21,15 +21,15 @@ import com.connectrpc.ConnectException import com.connectrpc.Headers import com.connectrpc.Idempotency import com.connectrpc.Interceptor -import com.connectrpc.Method.GET_METHOD -import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.StreamFunction import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.Trailers import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool +import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.squareup.moshi.Moshi @@ -189,7 +189,8 @@ internal class ConnectInterceptor( } private fun shouldUseGETRequest(request: HTTPRequest, finalRequestBody: Buffer): Boolean { - return request.methodSpec.idempotency == Idempotency.NO_SIDE_EFFECTS && + return request.methodSpec.streamType == StreamType.UNARY && + request.methodSpec.idempotency == Idempotency.NO_SIDE_EFFECTS && clientConfig.getConfiguration.useGET(finalRequestBody) } @@ -210,13 +211,8 @@ internal class ConnectInterceptor( url = url, contentType = "application/${requestCodec.encodingName()}", headers = request.headers, - methodSpec = MethodSpec( - path = request.methodSpec.path, - requestClass = request.methodSpec.requestClass, - responseClass = request.methodSpec.responseClass, - idempotency = request.methodSpec.idempotency, - method = GET_METHOD, - ), + methodSpec = request.methodSpec, + httpMethod = HTTPMethod.GET, ) } @@ -261,7 +257,7 @@ internal class ConnectInterceptor( errorJSON, ) } catch (e: Exception) { - return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON) + return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON, e) } val errorDetails = parseErrorDetails(errorPayloadJSON) ConnectException( diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt index 23f47a7f..b9c47cb5 100644 --- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt +++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt @@ -26,10 +26,18 @@ import org.junit.Test import org.mockito.kotlin.mock import java.net.URL -private val METHOD_SPEC = MethodSpec( +private val UNARY_METHOD_SPEC = MethodSpec( path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, +) + +private val STREAM_METHOD_SPEC = MethodSpec( + path = "", + requestClass = Any::class, + responseClass = Any::class, + streamType = StreamType.BIDI, ) class InterceptorChainTest { @@ -64,7 +72,7 @@ class InterceptorChainTest { @Test fun fifo_request_unary() { - val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, METHOD_SPEC)) + val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC)) assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -76,7 +84,7 @@ class InterceptorChainTest { @Test fun fifo_request_stream() { - val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, METHOD_SPEC)) + val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC)) assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -112,7 +120,7 @@ class InterceptorChainTest { it.contentType, headers, it.message, - METHOD_SPEC, + UNARY_METHOD_SPEC, ) }, responseFunction = { @@ -144,7 +152,7 @@ class InterceptorChainTest { it.contentType, headers, it.message, - METHOD_SPEC, + STREAM_METHOD_SPEC, ) }, requestBodyFunction = { diff --git a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt b/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt index b6a1cd5e..4bde9895 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt @@ -18,6 +18,7 @@ import com.connectrpc.Codec import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.SerializationStrategy +import com.connectrpc.StreamType import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -54,6 +55,7 @@ class BiDirectionalStreamTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) stream.sendClose() @@ -83,6 +85,7 @@ class BiDirectionalStreamTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) val result = stream.send("input") diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 8d2ae974..9177e823 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -18,6 +18,7 @@ import com.connectrpc.Codec import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.SerializationStrategy +import com.connectrpc.StreamType import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest import kotlinx.coroutines.CoroutineScope @@ -57,6 +58,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) { _ -> } } @@ -81,6 +83,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) { _ -> } } @@ -105,6 +108,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) } @@ -130,6 +134,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) } @@ -154,6 +159,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) {} @@ -182,6 +188,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) {} diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 212a3cd7..323314f2 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -19,14 +19,14 @@ import com.connectrpc.Codec import com.connectrpc.ConnectException import com.connectrpc.ErrorDetailParser import com.connectrpc.Idempotency -import com.connectrpc.Method.GET_METHOD -import com.connectrpc.Method.POST_METHOD import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.SerializationStrategy import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.compression.GzipCompressionPool +import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo @@ -79,6 +79,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -109,6 +110,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -136,6 +138,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -163,6 +166,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -191,6 +195,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -368,6 +373,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -399,6 +405,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -426,6 +433,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -681,6 +689,7 @@ class ConnectInterceptorTest { requestClass = Any::class, responseClass = Any::class, idempotency = Idempotency.NO_SIDE_EFFECTS, + streamType = StreamType.UNARY, ), ), ) @@ -690,7 +699,7 @@ class ConnectInterceptorTest { assertThat(queryMap.get(GETConstants.BASE64_QUERY_PARAM_KEY)).isEqualTo("1") assertThat(queryMap.get(GETConstants.ENCODING_QUERY_PARAM_KEY)).isEqualTo("encoding_name") assertThat(queryMap.get(GETConstants.CONNECT_VERSION_QUERY_PARAM_KEY)).isEqualTo("v1") - assertThat(request.methodSpec.method).isEqualTo(GET_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.GET) } @Test @@ -717,11 +726,12 @@ class ConnectInterceptorTest { requestClass = Any::class, responseClass = Any::class, idempotency = Idempotency.NO_SIDE_EFFECTS, + streamType = StreamType.UNARY, ), ), ) assertThat(request.url.query).isNull() - assertThat(request.methodSpec.method).isEqualTo(POST_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.POST) } @Test @@ -745,6 +755,7 @@ class ConnectInterceptorTest { requestClass = Any::class, responseClass = Any::class, idempotency = Idempotency.NO_SIDE_EFFECTS, + streamType = StreamType.UNARY, ), ), ) @@ -753,7 +764,7 @@ class ConnectInterceptorTest { assertThat(queryMap.get(GETConstants.BASE64_QUERY_PARAM_KEY)).isEqualTo("1") assertThat(queryMap.get(GETConstants.ENCODING_QUERY_PARAM_KEY)).isEqualTo("encoding_name") assertThat(queryMap.get(GETConstants.CONNECT_VERSION_QUERY_PARAM_KEY)).isEqualTo("v1") - assertThat(request.methodSpec.method).isEqualTo(GET_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.GET) } @Test @@ -778,11 +789,12 @@ class ConnectInterceptorTest { requestClass = Any::class, responseClass = Any::class, idempotency = Idempotency.NO_SIDE_EFFECTS, + streamType = StreamType.UNARY, ), ), ) assertThat(request.url.query).isNull() - assertThat(request.methodSpec.method).isEqualTo(POST_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.POST) } private fun parseQuery(request: HTTPRequest) = request.url.query diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt index 83375274..54d34239 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt @@ -23,6 +23,7 @@ import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.SerializationStrategy import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse @@ -72,6 +73,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -99,6 +101,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -126,6 +129,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -154,6 +158,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -313,6 +318,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.SERVER, ), ), ) @@ -342,6 +348,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.CLIENT, ), ), ) @@ -368,6 +375,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt index b165991e..0d8e78ba 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt @@ -22,6 +22,7 @@ import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.SerializationStrategy import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse @@ -68,6 +69,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -96,6 +98,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -123,6 +126,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -151,6 +155,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -331,6 +336,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -361,6 +367,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -387,6 +394,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index 9eeeb0f7..e7270260 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -50,7 +50,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( } } val content = request.message ?: ByteArray(0) - val method = request.methodSpec.method + val method = request.httpMethod val requestBody = if (HttpMethod.requiresRequestBody(method)) content.toRequestBody(request.contentType.toMediaType()) else null val callRequest = builder .url(request.url) @@ -127,7 +127,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( request: HTTPRequest, onResult: suspend (StreamResult) -> Unit, ): Stream { - return streamClient.initializeStream(request.methodSpec.method, request, onResult) + return streamClient.initializeStream(request.httpMethod, request, onResult) } } diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt index 6b04a6ff..1f71185f 100644 --- a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt +++ b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt @@ -21,6 +21,7 @@ import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientInterface import com.connectrpc.ResponseMessage import com.connectrpc.ServerOnlyStreamInterface +import com.connectrpc.StreamType import com.connectrpc.UnaryBlockingCall import com.connectrpc.protocgen.connect.internal.CodeGenerator import com.connectrpc.protocgen.connect.internal.Configuration @@ -32,6 +33,7 @@ import com.connectrpc.protocgen.connect.internal.parse import com.connectrpc.protocgen.connect.internal.withSourceInfo import com.google.protobuf.DescriptorProtos import com.google.protobuf.DescriptorProtos.FileDescriptorProto +import com.google.protobuf.DescriptorProtos.MethodOptions.IdempotencyLevel import com.google.protobuf.Descriptors import com.google.protobuf.compiler.PluginProtos import com.squareup.kotlinpoet.ClassName @@ -73,7 +75,7 @@ class Generator : CodeGenerator { this.descriptorSource = descriptorSource configuration = parse(request.parameter) for (protoFile in request.protoFileList) { - protoFileMap.put(protoFile.name, protoFile) + protoFileMap[protoFile.name] = protoFile } for (fileName in request.fileToGenerateList) { val file = @@ -108,11 +110,12 @@ class Generator : CodeGenerator { .addFileComment("Source: ${file.name}\n") .addType(serviceClientInterface(packageName, service, sourceInfo)) .build() - fileSpecs.put(serviceClientInterfaceClassName(packageName, service), interfaceFileSpec) + fileSpecs[serviceClientInterfaceClassName(packageName, service)] = interfaceFileSpec val implementationFileSpecBuilder = FileSpec.builder(packageName, file.name) // Manually import `method()` since it is a method and not a class. .addImport(MethodSpec::class.java.`package`.name, "MethodSpec") + .addImport(StreamType::class.java.`package`.name, "StreamType") .addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n") .addFileComment("\n") .addFileComment("Source: ${file.name}\n") @@ -120,12 +123,12 @@ class Generator : CodeGenerator { .addType(serviceClientImplementation(packageName, service, sourceInfo)) for (method in service.methods) { if (method.options.hasIdempotencyLevel()) { - implementationFileSpecBuilder.addImport(Idempotency::class.java, "NO_SIDE_EFFECTS") + implementationFileSpecBuilder.addImport(Idempotency::class.java.`package`.name, "Idempotency") break } } val implementationFileSpec = implementationFileSpecBuilder.build() - fileSpecs.put(serviceClientImplementationClassName(packageName, service), implementationFileSpec) + fileSpecs[serviceClientImplementationClassName(packageName, service)] = implementationFileSpec } return fileSpecs } @@ -279,11 +282,22 @@ class Generator : CodeGenerator { .indent() .addStatement("$inputClassName::class,") .addStatement("$outputClassName::class,") - if (!method.isClientStreaming && !method.isServerStreaming) { - if (method.options.idempotencyLevel == DescriptorProtos.MethodOptions.IdempotencyLevel.NO_SIDE_EFFECTS) { - methodSpecBuilder.addStatement("NO_SIDE_EFFECTS") + when (method.options.idempotencyLevel) { + IdempotencyLevel.NO_SIDE_EFFECTS -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.NO_SIDE_EFFECTS.name},") + IdempotencyLevel.IDEMPOTENT -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.IDEMPOTENT.name},") + else -> { + // Use default value in method spec. } } + if (method.isClientStreaming && method.isServerStreaming) { + methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.BIDI.name},") + } else if (method.isClientStreaming) { + methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.CLIENT.name},") + } else if (method.isServerStreaming) { + methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.SERVER.name},") + } else { + methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.UNARY.name},") + } val methodSpecCallBlock = methodSpecBuilder .unindent() .addStatement("),") @@ -462,7 +476,7 @@ class Generator : CodeGenerator { return ClassName(packageName, names.first()) } - internal fun String.sanitizeKdoc(): String { + private fun String.sanitizeKdoc(): String { return this // Remove trailing whitespace on each line. .replace("[^\\S\n]+\n".toRegex(), "\n") From e2c8eb48e945e8d032162a173f8c932305536675 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 8 Dec 2023 07:54:44 -0600 Subject: [PATCH 2/3] review comments --- .../{BiDirectionalStreamTest.kt => BidirectionalStreamTest.kt} | 2 +- .../main/kotlin/com/connectrpc/protocgen/connect/Generator.kt | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) rename library/src/test/kotlin/com/connectrpc/impl/{BiDirectionalStreamTest.kt => BidirectionalStreamTest.kt} (99%) diff --git a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt b/library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt similarity index 99% rename from library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt rename to library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt index 4bde9895..c180a6f5 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt @@ -30,7 +30,7 @@ import org.mockito.kotlin.mock import org.mockito.kotlin.whenever import java.lang.IllegalArgumentException -class BiDirectionalStreamTest { +class BidirectionalStreamTest { private val serializationStrategy: SerializationStrategy = mock { } private val codec: Codec = mock { } diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt index 1f71185f..4da61209 100644 --- a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt +++ b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt @@ -104,7 +104,6 @@ class Generator : CodeGenerator { FileDescriptorProto.SERVICE_FIELD_NUMBER, )) { val interfaceFileSpec = FileSpec.builder(packageName, file.name) - // Manually import `method()` since it is a method and not a class. .addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n") .addFileComment("\n") .addFileComment("Source: ${file.name}\n") @@ -113,7 +112,6 @@ class Generator : CodeGenerator { fileSpecs[serviceClientInterfaceClassName(packageName, service)] = interfaceFileSpec val implementationFileSpecBuilder = FileSpec.builder(packageName, file.name) - // Manually import `method()` since it is a method and not a class. .addImport(MethodSpec::class.java.`package`.name, "MethodSpec") .addImport(StreamType::class.java.`package`.name, "StreamType") .addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n") From 3c7fc54d89fe8080e7b1b01b8a59356088851af6 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 8 Dec 2023 09:19:35 -0600 Subject: [PATCH 3/3] remove StreamType.UNKNOWN --- .../main/kotlin/com/connectrpc/MethodSpec.kt | 2 +- .../main/kotlin/com/connectrpc/StreamType.kt | 3 --- .../protocols/ConnectInterceptorTest.kt | 8 ++++---- .../connectrpc/protocgen/connect/Generator.kt | 18 +++++++++--------- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt index 08842422..8572ac61 100644 --- a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt +++ b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt @@ -30,6 +30,6 @@ class MethodSpec( val path: String, val requestClass: KClass, val responseClass: KClass, + val streamType: StreamType, val idempotency: Idempotency = Idempotency.UNKNOWN, - val streamType: StreamType = StreamType.UNKNOWN, ) diff --git a/library/src/main/kotlin/com/connectrpc/StreamType.kt b/library/src/main/kotlin/com/connectrpc/StreamType.kt index ba25e811..1fd7ac1f 100644 --- a/library/src/main/kotlin/com/connectrpc/StreamType.kt +++ b/library/src/main/kotlin/com/connectrpc/StreamType.kt @@ -18,9 +18,6 @@ package com.connectrpc * Represents the RPC stream type. Set by the code generator on each [MethodSpec]. */ enum class StreamType { - /** Unknown stream type. */ - UNKNOWN, - /** Unary RPC. */ UNARY, diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 323314f2..8718ec88 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -688,8 +688,8 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, - idempotency = Idempotency.NO_SIDE_EFFECTS, streamType = StreamType.UNARY, + idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), ) @@ -725,8 +725,8 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, - idempotency = Idempotency.NO_SIDE_EFFECTS, streamType = StreamType.UNARY, + idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), ) @@ -754,8 +754,8 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, - idempotency = Idempotency.NO_SIDE_EFFECTS, streamType = StreamType.UNARY, + idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), ) @@ -788,8 +788,8 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, - idempotency = Idempotency.NO_SIDE_EFFECTS, streamType = StreamType.UNARY, + idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), ) diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt index 4da61209..116aa90f 100644 --- a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt +++ b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt @@ -280,6 +280,15 @@ class Generator : CodeGenerator { .indent() .addStatement("$inputClassName::class,") .addStatement("$outputClassName::class,") + if (method.isClientStreaming && method.isServerStreaming) { + methodSpecBuilder.addStatement("StreamType.${StreamType.BIDI.name},") + } else if (method.isClientStreaming) { + methodSpecBuilder.addStatement("StreamType.${StreamType.CLIENT.name},") + } else if (method.isServerStreaming) { + methodSpecBuilder.addStatement("StreamType.${StreamType.SERVER.name},") + } else { + methodSpecBuilder.addStatement("StreamType.${StreamType.UNARY.name},") + } when (method.options.idempotencyLevel) { IdempotencyLevel.NO_SIDE_EFFECTS -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.NO_SIDE_EFFECTS.name},") IdempotencyLevel.IDEMPOTENT -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.IDEMPOTENT.name},") @@ -287,15 +296,6 @@ class Generator : CodeGenerator { // Use default value in method spec. } } - if (method.isClientStreaming && method.isServerStreaming) { - methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.BIDI.name},") - } else if (method.isClientStreaming) { - methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.CLIENT.name},") - } else if (method.isServerStreaming) { - methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.SERVER.name},") - } else { - methodSpecBuilder.addStatement("streamType = StreamType.${StreamType.UNARY.name},") - } val methodSpecCallBlock = methodSpecBuilder .unindent() .addStatement("),")