diff --git a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt
index 4141d5c6..8572ac61 100644
--- a/library/src/main/kotlin/com/connectrpc/MethodSpec.kt
+++ b/library/src/main/kotlin/com/connectrpc/MethodSpec.kt
@@ -16,11 +16,6 @@ package com.connectrpc
import kotlin.reflect.KClass
-internal object Method {
- internal const val GET_METHOD = "GET"
- internal const val POST_METHOD = "POST"
-}
-
/**
* Represents the minimum set of information to execute an RPC method.
* Primarily used in generated code.
@@ -29,12 +24,12 @@ internal object Method {
* @param requestClass The Kotlin Class for the request message.
* @param responseClass The Kotlin Class for the response message.
* @param idempotency The declared idempotency of a method.
- * @param method The HTTP method of a request.
+ * @param streamType The method's stream type.
*/
class MethodSpec (
val path: String,
val requestClass: KClass ,
val responseClass: KClass,
+ val streamType: StreamType,
val idempotency: Idempotency = Idempotency.UNKNOWN,
- val method: String = Method.POST_METHOD,
)
diff --git a/library/src/main/kotlin/com/connectrpc/StreamType.kt b/library/src/main/kotlin/com/connectrpc/StreamType.kt
new file mode 100644
index 00000000..1fd7ac1f
--- /dev/null
+++ b/library/src/main/kotlin/com/connectrpc/StreamType.kt
@@ -0,0 +1,32 @@
+// Copyright 2022-2023 The Connect Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.connectrpc
+
+/**
+ * Represents the RPC stream type. Set by the code generator on each [MethodSpec].
+ */
+enum class StreamType {
+ /** Unary RPC. */
+ UNARY,
+
+ /** Client streaming RPC. */
+ CLIENT,
+
+ /** Server streaming RPC. */
+ SERVER,
+
+ /** Bidirectional streaming RPC. */
+ BIDI,
+}
diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt
index 7b4a0ffc..60002b3d 100644
--- a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt
+++ b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt
@@ -18,6 +18,11 @@ import com.connectrpc.Headers
import com.connectrpc.MethodSpec
import java.net.URL
+internal object HTTPMethod {
+ internal const val GET = "GET"
+ internal const val POST = "POST"
+}
+
/**
* HTTP request used for sending primitive data to the server.
*/
@@ -32,6 +37,9 @@ class HTTPRequest internal constructor(
val message: ByteArray? = null,
// The method spec associated with the request.
val methodSpec: MethodSpec<*, *>,
+ // HTTP method to use with the request.
+ // Almost always POST, but side effect free unary RPCs may be made with GET.
+ val httpMethod: String = HTTPMethod.POST,
) {
/**
* Clones the [HTTPRequest] with override values.
@@ -50,6 +58,8 @@ class HTTPRequest internal constructor(
message: ByteArray? = this.message,
// The method spec associated with the request.
methodSpec: MethodSpec<*, *> = this.methodSpec,
+ // The HTTP method to use with the request.
+ httpMethod: String = this.httpMethod,
): HTTPRequest {
return HTTPRequest(
url,
@@ -57,6 +67,7 @@ class HTTPRequest internal constructor(
headers,
message,
methodSpec,
+ httpMethod,
)
}
}
diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
index 6b5345a9..129a1991 100644
--- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
+++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
@@ -232,7 +232,7 @@ class ProtocolClient(
try {
when (streamResult.code) {
Code.OK -> channel.close()
- else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code))
+ else -> channel.close(streamResult.connectException() ?: ConnectException(code = streamResult.code, exception = streamResult.cause))
}
} finally {
responseTrailers.complete(streamResult.trailers)
diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt
index fbc96c09..6fb30ea1 100644
--- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt
+++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt
@@ -21,15 +21,15 @@ import com.connectrpc.ConnectException
import com.connectrpc.Headers
import com.connectrpc.Idempotency
import com.connectrpc.Interceptor
-import com.connectrpc.Method.GET_METHOD
-import com.connectrpc.MethodSpec
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.StreamFunction
import com.connectrpc.StreamResult
+import com.connectrpc.StreamType
import com.connectrpc.Trailers
import com.connectrpc.UnaryFunction
import com.connectrpc.compression.CompressionPool
+import com.connectrpc.http.HTTPMethod
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
import com.squareup.moshi.Moshi
@@ -189,7 +189,8 @@ internal class ConnectInterceptor(
}
private fun shouldUseGETRequest(request: HTTPRequest, finalRequestBody: Buffer): Boolean {
- return request.methodSpec.idempotency == Idempotency.NO_SIDE_EFFECTS &&
+ return request.methodSpec.streamType == StreamType.UNARY &&
+ request.methodSpec.idempotency == Idempotency.NO_SIDE_EFFECTS &&
clientConfig.getConfiguration.useGET(finalRequestBody)
}
@@ -210,13 +211,8 @@ internal class ConnectInterceptor(
url = url,
contentType = "application/${requestCodec.encodingName()}",
headers = request.headers,
- methodSpec = MethodSpec(
- path = request.methodSpec.path,
- requestClass = request.methodSpec.requestClass,
- responseClass = request.methodSpec.responseClass,
- idempotency = request.methodSpec.idempotency,
- method = GET_METHOD,
- ),
+ methodSpec = request.methodSpec,
+ httpMethod = HTTPMethod.GET,
)
}
@@ -261,7 +257,7 @@ internal class ConnectInterceptor(
errorJSON,
)
} catch (e: Exception) {
- return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON)
+ return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON, e)
}
val errorDetails = parseErrorDetails(errorPayloadJSON)
ConnectException(
diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt
index 23f47a7f..b9c47cb5 100644
--- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt
@@ -26,10 +26,18 @@ import org.junit.Test
import org.mockito.kotlin.mock
import java.net.URL
-private val METHOD_SPEC = MethodSpec(
+private val UNARY_METHOD_SPEC = MethodSpec(
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
+)
+
+private val STREAM_METHOD_SPEC = MethodSpec(
+ path = "",
+ requestClass = Any::class,
+ responseClass = Any::class,
+ streamType = StreamType.BIDI,
)
class InterceptorChainTest {
@@ -64,7 +72,7 @@ class InterceptorChainTest {
@Test
fun fifo_request_unary() {
- val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, METHOD_SPEC))
+ val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC))
assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4")
}
@@ -76,7 +84,7 @@ class InterceptorChainTest {
@Test
fun fifo_request_stream() {
- val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, METHOD_SPEC))
+ val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC))
assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4")
}
@@ -112,7 +120,7 @@ class InterceptorChainTest {
it.contentType,
headers,
it.message,
- METHOD_SPEC,
+ UNARY_METHOD_SPEC,
)
},
responseFunction = {
@@ -144,7 +152,7 @@ class InterceptorChainTest {
it.contentType,
headers,
it.message,
- METHOD_SPEC,
+ STREAM_METHOD_SPEC,
)
},
requestBodyFunction = {
diff --git a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt b/library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt
similarity index 95%
rename from library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt
rename to library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt
index b6a1cd5e..c180a6f5 100644
--- a/library/src/test/kotlin/com/connectrpc/impl/BiDirectionalStreamTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/impl/BidirectionalStreamTest.kt
@@ -18,6 +18,7 @@ import com.connectrpc.Codec
import com.connectrpc.MethodSpec
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.SerializationStrategy
+import com.connectrpc.StreamType
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
@@ -29,7 +30,7 @@ import org.mockito.kotlin.mock
import org.mockito.kotlin.whenever
import java.lang.IllegalArgumentException
-class BiDirectionalStreamTest {
+class BidirectionalStreamTest {
private val serializationStrategy: SerializationStrategy = mock { }
private val codec: Codec = mock { }
@@ -54,6 +55,7 @@ class BiDirectionalStreamTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.BIDI,
),
)
stream.sendClose()
@@ -83,6 +85,7 @@ class BiDirectionalStreamTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.BIDI,
),
)
val result = stream.send("input")
diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
index 8d2ae974..9177e823 100644
--- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt
@@ -18,6 +18,7 @@ import com.connectrpc.Codec
import com.connectrpc.MethodSpec
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.SerializationStrategy
+import com.connectrpc.StreamType
import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.http.HTTPRequest
import kotlinx.coroutines.CoroutineScope
@@ -57,6 +58,7 @@ class ProtocolClientTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.UNARY,
),
) { _ -> }
}
@@ -81,6 +83,7 @@ class ProtocolClientTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.UNARY,
),
) { _ -> }
}
@@ -105,6 +108,7 @@ class ProtocolClientTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.BIDI,
),
)
}
@@ -130,6 +134,7 @@ class ProtocolClientTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.BIDI,
),
)
}
@@ -154,6 +159,7 @@ class ProtocolClientTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.UNARY,
),
) {}
@@ -182,6 +188,7 @@ class ProtocolClientTest {
path = "com.connectrpc.SomeService/Service",
String::class,
String::class,
+ streamType = StreamType.UNARY,
),
) {}
diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt
index 212a3cd7..8718ec88 100644
--- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt
@@ -19,14 +19,14 @@ import com.connectrpc.Codec
import com.connectrpc.ConnectException
import com.connectrpc.ErrorDetailParser
import com.connectrpc.Idempotency
-import com.connectrpc.Method.GET_METHOD
-import com.connectrpc.Method.POST_METHOD
import com.connectrpc.MethodSpec
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.SerializationStrategy
import com.connectrpc.StreamResult
+import com.connectrpc.StreamType
import com.connectrpc.compression.GzipCompressionPool
+import com.connectrpc.http.HTTPMethod
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
import com.connectrpc.http.TracingInfo
@@ -79,6 +79,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -109,6 +110,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -136,6 +138,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -163,6 +166,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -191,6 +195,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -368,6 +373,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.BIDI,
),
),
)
@@ -399,6 +405,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -426,6 +433,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.BIDI,
),
),
)
@@ -680,6 +688,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
idempotency = Idempotency.NO_SIDE_EFFECTS,
),
),
@@ -690,7 +699,7 @@ class ConnectInterceptorTest {
assertThat(queryMap.get(GETConstants.BASE64_QUERY_PARAM_KEY)).isEqualTo("1")
assertThat(queryMap.get(GETConstants.ENCODING_QUERY_PARAM_KEY)).isEqualTo("encoding_name")
assertThat(queryMap.get(GETConstants.CONNECT_VERSION_QUERY_PARAM_KEY)).isEqualTo("v1")
- assertThat(request.methodSpec.method).isEqualTo(GET_METHOD)
+ assertThat(request.httpMethod).isEqualTo(HTTPMethod.GET)
}
@Test
@@ -716,12 +725,13 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
idempotency = Idempotency.NO_SIDE_EFFECTS,
),
),
)
assertThat(request.url.query).isNull()
- assertThat(request.methodSpec.method).isEqualTo(POST_METHOD)
+ assertThat(request.httpMethod).isEqualTo(HTTPMethod.POST)
}
@Test
@@ -744,6 +754,7 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
idempotency = Idempotency.NO_SIDE_EFFECTS,
),
),
@@ -753,7 +764,7 @@ class ConnectInterceptorTest {
assertThat(queryMap.get(GETConstants.BASE64_QUERY_PARAM_KEY)).isEqualTo("1")
assertThat(queryMap.get(GETConstants.ENCODING_QUERY_PARAM_KEY)).isEqualTo("encoding_name")
assertThat(queryMap.get(GETConstants.CONNECT_VERSION_QUERY_PARAM_KEY)).isEqualTo("v1")
- assertThat(request.methodSpec.method).isEqualTo(GET_METHOD)
+ assertThat(request.httpMethod).isEqualTo(HTTPMethod.GET)
}
@Test
@@ -777,12 +788,13 @@ class ConnectInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
idempotency = Idempotency.NO_SIDE_EFFECTS,
),
),
)
assertThat(request.url.query).isNull()
- assertThat(request.methodSpec.method).isEqualTo(POST_METHOD)
+ assertThat(request.httpMethod).isEqualTo(HTTPMethod.POST)
}
private fun parseQuery(request: HTTPRequest) = request.url.query
diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt
index 83375274..54d34239 100644
--- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt
@@ -23,6 +23,7 @@ import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.SerializationStrategy
import com.connectrpc.StreamResult
+import com.connectrpc.StreamType
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
@@ -72,6 +73,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -99,6 +101,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -126,6 +129,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -154,6 +158,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -313,6 +318,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.SERVER,
),
),
)
@@ -342,6 +348,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.CLIENT,
),
),
)
@@ -368,6 +375,7 @@ class GRPCInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.BIDI,
),
),
)
diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt
index b165991e..0d8e78ba 100644
--- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt
+++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt
@@ -22,6 +22,7 @@ import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.SerializationStrategy
import com.connectrpc.StreamResult
+import com.connectrpc.StreamType
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.HTTPResponse
@@ -68,6 +69,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -96,6 +98,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -123,6 +126,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -151,6 +155,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.UNARY,
),
),
)
@@ -331,6 +336,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.BIDI,
),
),
)
@@ -361,6 +367,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.BIDI,
),
),
)
@@ -387,6 +394,7 @@ class GRPCWebInterceptorTest {
path = "",
requestClass = Any::class,
responseClass = Any::class,
+ streamType = StreamType.BIDI,
),
),
)
diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt
index 9eeeb0f7..e7270260 100644
--- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt
+++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt
@@ -50,7 +50,7 @@ class ConnectOkHttpClient @JvmOverloads constructor(
}
}
val content = request.message ?: ByteArray(0)
- val method = request.methodSpec.method
+ val method = request.httpMethod
val requestBody = if (HttpMethod.requiresRequestBody(method)) content.toRequestBody(request.contentType.toMediaType()) else null
val callRequest = builder
.url(request.url)
@@ -127,7 +127,7 @@ class ConnectOkHttpClient @JvmOverloads constructor(
request: HTTPRequest,
onResult: suspend (StreamResult) -> Unit,
): Stream {
- return streamClient.initializeStream(request.methodSpec.method, request, onResult)
+ return streamClient.initializeStream(request.httpMethod, request, onResult)
}
}
diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt
index 6b04a6ff..116aa90f 100644
--- a/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt
+++ b/protoc-gen-connect-kotlin/src/main/kotlin/com/connectrpc/protocgen/connect/Generator.kt
@@ -21,6 +21,7 @@ import com.connectrpc.MethodSpec
import com.connectrpc.ProtocolClientInterface
import com.connectrpc.ResponseMessage
import com.connectrpc.ServerOnlyStreamInterface
+import com.connectrpc.StreamType
import com.connectrpc.UnaryBlockingCall
import com.connectrpc.protocgen.connect.internal.CodeGenerator
import com.connectrpc.protocgen.connect.internal.Configuration
@@ -32,6 +33,7 @@ import com.connectrpc.protocgen.connect.internal.parse
import com.connectrpc.protocgen.connect.internal.withSourceInfo
import com.google.protobuf.DescriptorProtos
import com.google.protobuf.DescriptorProtos.FileDescriptorProto
+import com.google.protobuf.DescriptorProtos.MethodOptions.IdempotencyLevel
import com.google.protobuf.Descriptors
import com.google.protobuf.compiler.PluginProtos
import com.squareup.kotlinpoet.ClassName
@@ -73,7 +75,7 @@ class Generator : CodeGenerator {
this.descriptorSource = descriptorSource
configuration = parse(request.parameter)
for (protoFile in request.protoFileList) {
- protoFileMap.put(protoFile.name, protoFile)
+ protoFileMap[protoFile.name] = protoFile
}
for (fileName in request.fileToGenerateList) {
val file =
@@ -102,17 +104,16 @@ class Generator : CodeGenerator {
FileDescriptorProto.SERVICE_FIELD_NUMBER,
)) {
val interfaceFileSpec = FileSpec.builder(packageName, file.name)
- // Manually import `method()` since it is a method and not a class.
.addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n")
.addFileComment("\n")
.addFileComment("Source: ${file.name}\n")
.addType(serviceClientInterface(packageName, service, sourceInfo))
.build()
- fileSpecs.put(serviceClientInterfaceClassName(packageName, service), interfaceFileSpec)
+ fileSpecs[serviceClientInterfaceClassName(packageName, service)] = interfaceFileSpec
val implementationFileSpecBuilder = FileSpec.builder(packageName, file.name)
- // Manually import `method()` since it is a method and not a class.
.addImport(MethodSpec::class.java.`package`.name, "MethodSpec")
+ .addImport(StreamType::class.java.`package`.name, "StreamType")
.addFileComment("Code generated by connect-kotlin. DO NOT EDIT.\n")
.addFileComment("\n")
.addFileComment("Source: ${file.name}\n")
@@ -120,12 +121,12 @@ class Generator : CodeGenerator {
.addType(serviceClientImplementation(packageName, service, sourceInfo))
for (method in service.methods) {
if (method.options.hasIdempotencyLevel()) {
- implementationFileSpecBuilder.addImport(Idempotency::class.java, "NO_SIDE_EFFECTS")
+ implementationFileSpecBuilder.addImport(Idempotency::class.java.`package`.name, "Idempotency")
break
}
}
val implementationFileSpec = implementationFileSpecBuilder.build()
- fileSpecs.put(serviceClientImplementationClassName(packageName, service), implementationFileSpec)
+ fileSpecs[serviceClientImplementationClassName(packageName, service)] = implementationFileSpec
}
return fileSpecs
}
@@ -279,9 +280,20 @@ class Generator : CodeGenerator {
.indent()
.addStatement("$inputClassName::class,")
.addStatement("$outputClassName::class,")
- if (!method.isClientStreaming && !method.isServerStreaming) {
- if (method.options.idempotencyLevel == DescriptorProtos.MethodOptions.IdempotencyLevel.NO_SIDE_EFFECTS) {
- methodSpecBuilder.addStatement("NO_SIDE_EFFECTS")
+ if (method.isClientStreaming && method.isServerStreaming) {
+ methodSpecBuilder.addStatement("StreamType.${StreamType.BIDI.name},")
+ } else if (method.isClientStreaming) {
+ methodSpecBuilder.addStatement("StreamType.${StreamType.CLIENT.name},")
+ } else if (method.isServerStreaming) {
+ methodSpecBuilder.addStatement("StreamType.${StreamType.SERVER.name},")
+ } else {
+ methodSpecBuilder.addStatement("StreamType.${StreamType.UNARY.name},")
+ }
+ when (method.options.idempotencyLevel) {
+ IdempotencyLevel.NO_SIDE_EFFECTS -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.NO_SIDE_EFFECTS.name},")
+ IdempotencyLevel.IDEMPOTENT -> methodSpecBuilder.addStatement("idempotency = Idempotency.${Idempotency.IDEMPOTENT.name},")
+ else -> {
+ // Use default value in method spec.
}
}
val methodSpecCallBlock = methodSpecBuilder
@@ -462,7 +474,7 @@ class Generator : CodeGenerator {
return ClassName(packageName, names.first())
}
- internal fun String.sanitizeKdoc(): String {
+ private fun String.sanitizeKdoc(): String {
return this
// Remove trailing whitespace on each line.
.replace("[^\\S\n]+\n".toRegex(), "\n")