Skip to content

Commit fee43d6

Browse files
committed
Add RSocketServiceMethod support for suspending functions
See #34868 Signed-off-by: Dmitry Sulman <dmitry.sulman@gmail.com>
1 parent 2faed3c commit fee43d6

File tree

3 files changed

+106
-3
lines changed

3 files changed

+106
-3
lines changed

spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import reactor.core.publisher.Mono;
3131

3232
import org.springframework.core.DefaultParameterNameDiscoverer;
33+
import org.springframework.core.KotlinDetector;
3334
import org.springframework.core.MethodParameter;
3435
import org.springframework.core.ParameterizedTypeReference;
3536
import org.springframework.core.ReactiveAdapter;
@@ -82,6 +83,10 @@ private static MethodParameter[] initMethodParameters(Method method) {
8283
if (count == 0) {
8384
return new MethodParameter[0];
8485
}
86+
if (KotlinDetector.isSuspendingFunction(method)) {
87+
count -= 1;
88+
}
89+
8590
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
8691
MethodParameter[] parameters = new MethodParameter[count];
8792
for (int i = 0; i < count; i++) {
@@ -129,10 +134,15 @@ private static Function<RSocketRequestValues, Object> initResponseFunction(
129134

130135
MethodParameter returnParam = new MethodParameter(method, -1);
131136
Class<?> returnType = returnParam.getParameterType();
137+
boolean isSuspending = KotlinDetector.isSuspendingFunction(method);
138+
if (isSuspending) {
139+
returnType = Mono.class;
140+
}
141+
132142
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
133143

134144
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
135-
Class<?> actualType = actualParam.getNestedParameterType();
145+
Class<?> actualType = isSuspending ? actualParam.getParameterType() : actualParam.getNestedParameterType();
136146

137147
Function<RSocketRequestValues, Publisher<?>> responseFunction;
138148
if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) {
@@ -147,7 +157,8 @@ else if (reactiveAdapter == null) {
147157
}
148158
else {
149159
ParameterizedTypeReference<?> payloadType =
150-
ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType());
160+
ParameterizedTypeReference.forType(isSuspending ? actualParam.getGenericParameterType() :
161+
actualParam.getNestedGenericParameterType());
151162

152163
responseFunction = values -> (
153164
reactiveAdapter.isMultiValue() ?

spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import org.springframework.aop.framework.ProxyFactory;
3333
import org.springframework.aop.framework.ReflectiveMethodInvocation;
34+
import org.springframework.core.KotlinDetector;
3435
import org.springframework.core.MethodIntrospector;
3536
import org.springframework.core.ReactiveAdapterRegistry;
3637
import org.springframework.core.annotation.AnnotatedElementUtils;
@@ -246,7 +247,9 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
246247
Method method = invocation.getMethod();
247248
RSocketServiceMethod serviceMethod = this.serviceMethods.get(method);
248249
if (serviceMethod != null) {
249-
return serviceMethod.invoke(invocation.getArguments());
250+
@Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ?
251+
resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments();
252+
return serviceMethod.invoke(arguments);
250253
}
251254
if (method.isDefault()) {
252255
if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) {
@@ -256,6 +259,12 @@ private ServiceMethodInterceptor(List<RSocketServiceMethod> methods) {
256259
}
257260
throw new IllegalStateException("Unexpected method invocation: " + method);
258261
}
262+
263+
private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) {
264+
Object[] functionArgs = new Object[args.length - 1];
265+
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
266+
return functionArgs;
267+
}
259268
}
260269

261270
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright 2002-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.rsocket.service
18+
19+
import io.rsocket.util.DefaultPayload
20+
import kotlinx.coroutines.runBlocking
21+
import org.assertj.core.api.Assertions.assertThat
22+
import org.junit.jupiter.api.BeforeEach
23+
import org.junit.jupiter.api.Test
24+
import org.springframework.messaging.rsocket.RSocketRequester
25+
import org.springframework.messaging.rsocket.RSocketStrategies
26+
import org.springframework.messaging.rsocket.TestRSocket
27+
import org.springframework.util.MimeTypeUtils.TEXT_PLAIN
28+
import reactor.core.publisher.Mono
29+
30+
/**
31+
* Kotlin tests for [RSocketServiceMethod].
32+
*
33+
* @author Dmitry Sulman
34+
*/
35+
class RSocketServiceMethodKotlinTests {
36+
37+
private lateinit var rsocket: TestRSocket
38+
39+
private lateinit var proxyFactory: RSocketServiceProxyFactory
40+
41+
@BeforeEach
42+
fun setUp() {
43+
rsocket = TestRSocket()
44+
val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create())
45+
proxyFactory = RSocketServiceProxyFactory.builder(requester).build()
46+
}
47+
48+
@Test
49+
fun fireAndForget(): Unit = runBlocking {
50+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
51+
52+
val requestPayload = "request"
53+
service.fireAndForget(requestPayload)
54+
55+
assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget")
56+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff")
57+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
58+
}
59+
60+
@Test
61+
fun requestResponse(): Unit = runBlocking {
62+
val service = proxyFactory.createClient(SuspendingFunctionsService::class.java)
63+
64+
val requestPayload = "request"
65+
val responsePayload = "response"
66+
rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload)))
67+
val response = service.requestResponse(requestPayload)
68+
69+
assertThat(response).isEqualTo(responsePayload)
70+
assertThat(rsocket.savedMethodName).isEqualTo("requestResponse")
71+
assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr")
72+
assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload)
73+
}
74+
75+
private interface SuspendingFunctionsService {
76+
77+
@RSocketExchange("ff")
78+
suspend fun fireAndForget(input: String)
79+
80+
@RSocketExchange("rr")
81+
suspend fun requestResponse(input: String): String
82+
}
83+
}

0 commit comments

Comments
 (0)