diff --git a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt index 4141d5c6..8572ac61 100644 --- a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt +++ b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt @@ -16,11 +16,6 @@ package com.connectrpc import kotlin.reflect.KClass -internal object Method { - internal const val GET_METHOD = "GET" - internal const val POST_METHOD = "POST" -} - /** * Represents the minimum set of information to execute an RPC method. * Primarily used in generated code. @@ -29,12 +24,12 @@ internal object Method { * @param requestClass The Kotlin Class for the request message. * @param responseClass The Kotlin Class for the response message. * @param idempotency The declared idempotency of a method. - * @param method The HTTP method of a request. + * @param streamType The method's stream type. */ class MethodSpec( val path: String, val requestClass: KClass, val responseClass: KClass, + val streamType: StreamType, val idempotency: Idempotency = Idempotency.UNKNOWN, - val method: String = Method.POST_METHOD, ) diff --git a/library/src/main/kotlin/com/connectrpc/StreamType.kt b/library/src/main/kotlin/com/connectrpc/StreamType.kt new file mode 100644 index 00000000..1fd7ac1f --- /dev/null +++ b/library/src/main/kotlin/com/connectrpc/StreamType.kt @@ -0,0 +1,32 @@ +// Copyright 2022-2023 The Connect Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.connectrpc + +/** + * Represents the RPC stream type. Set by the code generator on each [MethodSpec]. + */ +enum class StreamType { + /** Unary RPC. */ + UNARY, + + /** Client streaming RPC. */ + CLIENT, + + /** Server streaming RPC. */ + SERVER, + + /** Bidirectional streaming RPC. */ + BIDI, +} diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt index 7b4a0ffc..60002b3d 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt @@ -18,6 +18,11 @@ import com.connectrpc.Headers import com.connectrpc.MethodSpec import java.net.URL +internal object HTTPMethod { + internal const val GET = "GET" + internal const val POST = "POST" +} + /** * HTTP request used for sending primitive data to the server. */ @@ -32,6 +37,9 @@ class HTTPRequest internal constructor( val message: ByteArray? = null, // The method spec associated with the request. val methodSpec: MethodSpec<*, *>, + // HTTP method to use with the request. + // Almost always POST, but side effect free unary RPCs may be made with GET. + val httpMethod: String = HTTPMethod.POST, ) { /** * Clones the [HTTPRequest] with override values. @@ -50,6 +58,8 @@ class HTTPRequest internal constructor( message: ByteArray? = this.message, // The method spec associated with the request. methodSpec: MethodSpec<*, *> = this.methodSpec, + // The HTTP method to use with the request. + httpMethod: String = this.httpMethod, ): HTTPRequest { return HTTPRequest( url, @@ -57,6 +67,7 @@ class HTTPRequest internal constructor( headers, message, methodSpec, + httpMethod, ) } } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 6b5345a9..129a1991 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -232,7 +232,7 @@ class ProtocolClient( try { when (streamResult.code) { Code.OK -> channel.close() - else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code)) + else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code, exception = streamResult.cause)) } } finally { responseTrailers.complete(streamResult.trailers) diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index fbc96c09..6fb30ea1 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -21,15 +21,15 @@ import com.connectrpc.ConnectException import com.connectrpc.Headers import com.connectrpc.Idempotency import com.connectrpc.Interceptor -import com.connectrpc.Method.GET_METHOD -import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.StreamFunction import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.Trailers import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool +import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.squareup.moshi.Moshi @@ -189,7 +189,8 @@ internal class ConnectInterceptor( } private fun shouldUseGETRequest(request: HTTPRequest, finalRequestBody: Buffer): Boolean { - return request.methodSpec.idempotency == Idempotency.NO_SIDE_EFFECTS && + return request.methodSpec.streamType == StreamType.UNARY && + request.methodSpec.idempotency == Idempotency.NO_SIDE_EFFECTS && clientConfig.getConfiguration.useGET(finalRequestBody) } @@ -210,13 +211,8 @@ internal class ConnectInterceptor( url = url, contentType = "application/${requestCodec.encodingName()}", headers = request.headers, - methodSpec = MethodSpec( - path = request.methodSpec.path, - requestClass = request.methodSpec.requestClass, - responseClass = request.methodSpec.responseClass, - idempotency = request.methodSpec.idempotency, - method = GET_METHOD, - ), + methodSpec = request.methodSpec, + httpMethod = HTTPMethod.GET, ) } @@ -261,7 +257,7 @@ internal class ConnectInterceptor( errorJSON, ) } catch (e: Exception) { - return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON) + return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON, e) } val errorDetails = parseErrorDetails(errorPayloadJSON) ConnectException( diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt index 23f47a7f..b9c47cb5 100644 --- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt +++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt @@ -26,10 +26,18 @@ import org.junit.Test import org.mockito.kotlin.mock import java.net.URL -private val METHOD_SPEC = MethodSpec( +private val UNARY_METHOD_SPEC = MethodSpec( path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, +) + +private val STREAM_METHOD_SPEC = MethodSpec( + path = "", + requestClass = Any::class, + responseClass = Any::class, + streamType = StreamType.BIDI, ) class InterceptorChainTest { @@ -64,7 +72,7 @@ class InterceptorChainTest { @Test fun fifo_request_unary() { - val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, METHOD_SPEC)) + val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC)) assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -76,7 +84,7 @@ class InterceptorChainTest { @Test fun fifo_request_stream() { - val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, METHOD_SPEC)) + val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC)) assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -112,7 +120,7 @@ class InterceptorChainTest { it.contentType, headers, it.message, - METHOD_SPEC, + UNARY_METHOD_SPEC, ) }, responseFunction = { @@ -144,7 +152,7 @@ class InterceptorChainTest { it.contentType, headers, it.message, - METHOD_SPEC, + STREAM_METHOD_SPEC, ) }, requestBodyFunction = { diff --git a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt b/library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt similarity index 95% rename from library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt rename to library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt index b6a1cd5e..c180a6f5 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt @@ -18,6 +18,7 @@ import com.connectrpc.Codec import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.SerializationStrategy +import com.connectrpc.StreamType import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -29,7 +30,7 @@ import org.mockito.kotlin.mock import org.mockito.kotlin.whenever import java.lang.IllegalArgumentException -class BiDirectionalStreamTest { +class BidirectionalStreamTest { private val serializationStrategy: SerializationStrategy = mock { } private val codec: Codec = mock { } @@ -54,6 +55,7 @@ class BiDirectionalStreamTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) stream.sendClose() @@ -83,6 +85,7 @@ class BiDirectionalStreamTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) val result = stream.send("input") diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 8d2ae974..9177e823 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -18,6 +18,7 @@ import com.connectrpc.Codec import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.SerializationStrategy +import com.connectrpc.StreamType import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest import kotlinx.coroutines.CoroutineScope @@ -57,6 +58,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) { _ -> } } @@ -81,6 +83,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) { _ -> } } @@ -105,6 +108,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) } @@ -130,6 +134,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.BIDI, ), ) } @@ -154,6 +159,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) {} @@ -182,6 +188,7 @@ class ProtocolClientTest { path = "com.connectrpc.SomeService/Service", String::class, String::class, + streamType = StreamType.UNARY, ), ) {} diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 212a3cd7..8718ec88 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -19,14 +19,14 @@ import com.connectrpc.Codec import com.connectrpc.ConnectException import com.connectrpc.ErrorDetailParser import com.connectrpc.Idempotency -import com.connectrpc.Method.GET_METHOD -import com.connectrpc.Method.POST_METHOD import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.SerializationStrategy import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.compression.GzipCompressionPool +import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo @@ -79,6 +79,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -109,6 +110,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -136,6 +138,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -163,6 +166,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -191,6 +195,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -368,6 +373,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -399,6 +405,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -426,6 +433,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -680,6 +688,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), @@ -690,7 +699,7 @@ class ConnectInterceptorTest { assertThat(queryMap.get(GETConstants.BASE64_QUERY_PARAM_KEY)).isEqualTo("1") assertThat(queryMap.get(GETConstants.ENCODING_QUERY_PARAM_KEY)).isEqualTo("encoding_name") assertThat(queryMap.get(GETConstants.CONNECT_VERSION_QUERY_PARAM_KEY)).isEqualTo("v1") - assertThat(request.methodSpec.method).isEqualTo(GET_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.GET) } @Test @@ -716,12 +725,13 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), ) assertThat(request.url.query).isNull() - assertThat(request.methodSpec.method).isEqualTo(POST_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.POST) } @Test @@ -744,6 +754,7 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), @@ -753,7 +764,7 @@ class ConnectInterceptorTest { assertThat(queryMap.get(GETConstants.BASE64_QUERY_PARAM_KEY)).isEqualTo("1") assertThat(queryMap.get(GETConstants.ENCODING_QUERY_PARAM_KEY)).isEqualTo("encoding_name") assertThat(queryMap.get(GETConstants.CONNECT_VERSION_QUERY_PARAM_KEY)).isEqualTo("v1") - assertThat(request.methodSpec.method).isEqualTo(GET_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.GET) } @Test @@ -777,12 +788,13 @@ class ConnectInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, idempotency = Idempotency.NO_SIDE_EFFECTS, ), ), ) assertThat(request.url.query).isNull() - assertThat(request.methodSpec.method).isEqualTo(POST_METHOD) + assertThat(request.httpMethod).isEqualTo(HTTPMethod.POST) } private fun parseQuery(request: HTTPRequest) = request.url.query diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt index 83375274..54d34239 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt @@ -23,6 +23,7 @@ import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.SerializationStrategy import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse @@ -72,6 +73,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -99,6 +101,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -126,6 +129,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -154,6 +158,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -313,6 +318,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.SERVER, ), ), ) @@ -342,6 +348,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.CLIENT, ), ), ) @@ -368,6 +375,7 @@ class GRPCInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt index b165991e..0d8e78ba 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt @@ -22,6 +22,7 @@ import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression import com.connectrpc.SerializationStrategy import com.connectrpc.StreamResult +import com.connectrpc.StreamType import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse @@ -68,6 +69,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -96,6 +98,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -123,6 +126,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -151,6 +155,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.UNARY, ), ), ) @@ -331,6 +336,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -361,6 +367,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) @@ -387,6 +394,7 @@ class GRPCWebInterceptorTest { path = "", requestClass = Any::class, responseClass = Any::class, + streamType = StreamType.BIDI, ), ), ) diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index 9eeeb0f7..e7270260 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -50,7 +50,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( } } val content = request.message ?: ByteArray(0) - val method = request.methodSpec.method + val method = request.httpMethod val requestBody = if (HttpMethod.requiresRequestBody(method)) content.toRequestBody(request.contentType.toMediaType()) else null val callRequest = builder .url(request.url) @@ -127,7 +127,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( request: HTTPRequest, onResult: suspend (StreamResult) -> Unit, ): Stream { - return streamClient.initializeStream(request.methodSpec.method, request, onResult) + return streamClient.initializeStream(request.httpMethod, request, onResult) } } diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt index 6b04a6ff..116aa90f 100644 --- a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt +++ b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt @@ -21,6 +21,7 @@ import com.connectrpc.MethodSpec import com.connectrpc.ProtocolClientInterface import com.connectrpc.ResponseMessage import com.connectrpc.ServerOnlyStreamInterface +import com.connectrpc.StreamType import com.connectrpc.UnaryBlockingCall import com.connectrpc.protocgen.connect.internal.CodeGenerator import com.connectrpc.protocgen.connect.internal.Configuration @@ -32,6 +33,7 @@ import com.connectrpc.protocgen.connect.internal.parse import com.connectrpc.protocgen.connect.internal.withSourceInfo import com.google.protobuf.DescriptorProtos import com.google.protobuf.DescriptorProtos.FileDescriptorProto +import com.google.protobuf.DescriptorProtos.MethodOptions.IdempotencyLevel import com.google.protobuf.Descriptors import com.google.protobuf.compiler.PluginProtos import com.squareup.kotlinpoet.ClassName @@ -73,7 +75,7 @@ class Generator : CodeGenerator { this.descriptorSource = descriptorSource configuration = parse(request.parameter) for (protoFile in request.protoFileList) { - protoFileMap.put(protoFile.name, protoFile) + protoFileMap[protoFile.name] = protoFile } for (fileName in request.fileToGenerateList) { val file = @@ -102,17 +104,16 @@ class Generator : CodeGenerator { FileDescriptorProto.SERVICE_FIELD_NUMBER, )) { val interfaceFileSpec = FileSpec.builder(packageName, file.name) - // Manually import `method()` since it is a method and not a class. .addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n") .addFileComment("\n") .addFileComment("Source: ${file.name}\n") .addType(serviceClientInterface(packageName, service, sourceInfo)) .build() - fileSpecs.put(serviceClientInterfaceClassName(packageName, service), interfaceFileSpec) + fileSpecs[serviceClientInterfaceClassName(packageName, service)] = interfaceFileSpec val implementationFileSpecBuilder = FileSpec.builder(packageName, file.name) - // Manually import `method()` since it is a method and not a class. .addImport(MethodSpec::class.java.`package`.name, "MethodSpec") + .addImport(StreamType::class.java.`package`.name, "StreamType") .addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n") .addFileComment("\n") .addFileComment("Source: ${file.name}\n") @@ -120,12 +121,12 @@ class Generator : CodeGenerator { .addType(serviceClientImplementation(packageName, service, sourceInfo)) for (method in service.methods) { if (method.options.hasIdempotencyLevel()) { - implementationFileSpecBuilder.addImport(Idempotency::class.java, "NO_SIDE_EFFECTS") + implementationFileSpecBuilder.addImport(Idempotency::class.java.`package`.name, "Idempotency") break } } val implementationFileSpec = implementationFileSpecBuilder.build() - fileSpecs.put(serviceClientImplementationClassName(packageName, service), implementationFileSpec) + fileSpecs[serviceClientImplementationClassName(packageName, service)] = implementationFileSpec } return fileSpecs } @@ -279,9 +280,20 @@ class Generator : CodeGenerator { .indent() .addStatement("$inputClassName::class,") .addStatement("$outputClassName::class,") - if (!method.isClientStreaming && !method.isServerStreaming) { - if (method.options.idempotencyLevel == DescriptorProtos.MethodOptions.IdempotencyLevel.NO_SIDE_EFFECTS) { - methodSpecBuilder.addStatement("NO_SIDE_EFFECTS") + if (method.isClientStreaming && method.isServerStreaming) { + methodSpecBuilder.addStatement("StreamType.${StreamType.BIDI.name},") + } else if (method.isClientStreaming) { + methodSpecBuilder.addStatement("StreamType.${StreamType.CLIENT.name},") + } else if (method.isServerStreaming) { + methodSpecBuilder.addStatement("StreamType.${StreamType.SERVER.name},") + } else { + methodSpecBuilder.addStatement("StreamType.${StreamType.UNARY.name},") + } + when (method.options.idempotencyLevel) { + IdempotencyLevel.NO_SIDE_EFFECTS -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.NO_SIDE_EFFECTS.name},") + IdempotencyLevel.IDEMPOTENT -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.IDEMPOTENT.name},") + else -> { + // Use default value in method spec. } } val methodSpecCallBlock = methodSpecBuilder @@ -462,7 +474,7 @@ class Generator : CodeGenerator { return ClassName(packageName, names.first()) } - internal fun String.sanitizeKdoc(): String { + private fun String.sanitizeKdoc(): String { return this // Remove trailing whitespace on each line. .replace("[^\\S\n]+\n".toRegex(), "\n")