diff --git a/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategy.kt b/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategy.kt index e360157379..01f82ecb5d 100644 --- a/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategy.kt +++ b/generator/graphql-kotlin-schema-generator/src/main/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategy.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 Expedia, Inc + * Copyright 2021 Expedia, Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,15 +28,14 @@ import graphql.execution.SubscriptionExecutionStrategy import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters import graphql.execution.instrumentation.parameters.InstrumentationFieldParameters -import graphql.execution.reactive.CompletionStageMappingPublisher import graphql.schema.GraphQLObjectType import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.reactive.asPublisher +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.future.await +import kotlinx.coroutines.reactive.asFlow import org.reactivestreams.Publisher import java.util.Collections import java.util.concurrent.CompletableFuture -import java.util.concurrent.CompletionStage -import java.util.function.Function /** * [SubscriptionExecutionStrategy] replacement that and allows schema subscription functions @@ -62,20 +61,18 @@ class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : Exec // // when the upstream source event stream completes, subscribe to it and wire in our adapter - val overallResult: CompletableFuture = sourceEventStream.thenApply { publisher -> - if (publisher == null) { + val overallResult: CompletableFuture = sourceEventStream.thenApply { flow -> + if (flow == null) { ExecutionResultImpl(null, executionContext.errors) } else { - val mapperFunction = Function> { eventPayload: Any? -> + val returnFlow = flow.map { eventPayload: Any? -> executeSubscriptionEvent( executionContext, parameters, eventPayload - ) + ).await() } - // we need explicit cast as Kotlin Flow is covariant (Flow vs Publisher) - val mapSourceToResponse = CompletionStageMappingPublisher(publisher as Publisher, mapperFunction) - ExecutionResultImpl(mapSourceToResponse, executionContext.errors) + ExecutionResultImpl(returnFlow, executionContext.errors) } } @@ -102,18 +99,18 @@ class FlowSubscriptionExecutionStrategy(dfe: DataFetcherExceptionHandler) : Exec private fun createSourceEventStream( executionContext: ExecutionContext, parameters: ExecutionStrategyParameters - ): CompletableFuture?> { + ): CompletableFuture?> { val newParameters = firstFieldOfSubscriptionSelection(parameters) val fieldFetched = fetchField(executionContext, newParameters) return fieldFetched.thenApply { fetchedValue -> - val publisher = when (val publisherOrFlow: Any? = fetchedValue.fetchedValue) { - is Publisher<*> -> publisherOrFlow + val flow = when (val publisherOrFlow: Any? = fetchedValue.fetchedValue) { + is Publisher<*> -> publisherOrFlow.asFlow() // below explicit cast is required due to the type erasure and Kotlin declaration-site variance vs Java use-site variance - is Flow<*> -> (publisherOrFlow as? Flow)?.asPublisher() + is Flow<*> -> publisherOrFlow else -> null } - publisher + flow } } diff --git a/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategyTest.kt b/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategyTest.kt index dc6522f02f..0935983941 100644 --- a/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategyTest.kt +++ b/generator/graphql-kotlin-schema-generator/src/test/kotlin/com/expediagroup/graphql/generator/execution/FlowSubscriptionExecutionStrategyTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 Expedia, Inc + * Copyright 2021 Expedia, Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,9 +34,9 @@ import graphql.schema.GraphQLSchema import kotlinx.coroutines.InternalCoroutinesApi import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow import kotlinx.coroutines.reactive.asPublisher -import kotlinx.coroutines.reactive.collect import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Test import org.reactivestreams.Publisher @@ -66,9 +66,9 @@ class FlowSubscriptionExecutionStrategyTest { fun `verify subscription to flow`() = runBlocking { val request = ExecutionInput.newExecutionInput().query("subscription { ticker }").build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val list = mutableListOf() - publisher.collect { + flow.collect { list.add(it.getData>().getValue("ticker")) assertEquals(it.extensions["testKey"], "testValue") } @@ -82,9 +82,9 @@ class FlowSubscriptionExecutionStrategyTest { fun `verify subscription to datafetcher flow`() = runBlocking { val request = ExecutionInput.newExecutionInput().query("subscription { datafetcher }").build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val list = mutableListOf() - publisher.collect { + flow.collect { val intVal = it.getData>().getValue("datafetcher") list.add(intVal) assertEquals(it.extensions["testKey"], "testValue") @@ -99,9 +99,9 @@ class FlowSubscriptionExecutionStrategyTest { fun `verify subscription to publisher`() = runBlocking { val request = ExecutionInput.newExecutionInput().query("subscription { publisherTicker }").build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val list = mutableListOf() - publisher.collect { + flow.collect { list.add(it.getData>().getValue("publisherTicker")) } assertEquals(5, list.size) @@ -117,9 +117,9 @@ class FlowSubscriptionExecutionStrategyTest { .context(SubscriptionContext("junitHandler")) .build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val list = mutableListOf() - publisher.collect { + flow.collect { val contextValue = it.getData>().getValue("contextualTicker") assertTrue(contextValue.startsWith("junitHandler:")) list.add(contextValue.substringAfter("junitHandler:").toInt()) @@ -134,11 +134,11 @@ class FlowSubscriptionExecutionStrategyTest { fun `verify subscription to failing flow`() = runBlocking { val request = ExecutionInput.newExecutionInput().query("subscription { alwaysThrows }").build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val errors = mutableListOf() val results = mutableListOf() try { - publisher.collect { + flow.collect { val dataMap = it.getData>() if (dataMap != null) { results.add(dataMap.getValue("alwaysThrows")) @@ -161,9 +161,9 @@ class FlowSubscriptionExecutionStrategyTest { fun `verify subscription to exploding flow`() = runBlocking { val request = ExecutionInput.newExecutionInput().query("subscription { throwsFast }").build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val errors = response.errors - assertNull(publisher) + assertNull(flow) assertEquals(1, errors.size) assertEquals("JUNIT flow failure", errors[0].message.substringAfter(" : ")) } @@ -172,10 +172,10 @@ class FlowSubscriptionExecutionStrategyTest { fun `verify subscription alias`() = runBlocking { val request = ExecutionInput.newExecutionInput().query("subscription { t: ticker }").build() val response = testGraphQL.execute(request) - val publisher = response.getData>() + val flow = response.getData>() val list = mutableListOf() - publisher.collect { - list.add(it.getData>().getValue("t")) + flow.collect { executionResult -> + list.add(executionResult.getData>().getValue("t")) } assertEquals(5, list.size) for (i in list.indices) { diff --git a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandler.kt b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandler.kt index da01f3f872..6847eda83d 100644 --- a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandler.kt +++ b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandler.kt @@ -31,6 +31,7 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.kotlin.convertValue import com.fasterxml.jackson.module.kotlin.readValue import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.reactor.asFlux import kotlinx.coroutines.runBlocking import org.slf4j.LoggerFactory import org.springframework.web.reactive.socket.WebSocketSession @@ -130,7 +131,7 @@ class ApolloSubscriptionProtocolHandler( try { val request = objectMapper.convertValue(payload) return subscriptionHandler.executeSubscription(request, context) - .toFlux() + .asFlux() .map { if (it.errors?.isNotEmpty() == true) { SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id, payload = it) diff --git a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SpringGraphQLSubscriptionHandler.kt b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SpringGraphQLSubscriptionHandler.kt index ed551e4f10..5851141c91 100644 --- a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SpringGraphQLSubscriptionHandler.kt +++ b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/subscriptions/SpringGraphQLSubscriptionHandler.kt @@ -1,5 +1,5 @@ /* - * Copyright 2019 Expedia, Inc + * Copyright 2021 Expedia, Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,9 +26,9 @@ import com.expediagroup.graphql.server.types.GraphQLRequest import com.expediagroup.graphql.server.types.GraphQLResponse import graphql.ExecutionResult import graphql.GraphQL -import org.reactivestreams.Publisher -import reactor.core.publisher.Flux -import reactor.kotlin.core.publisher.toFlux +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.map /** * Default Spring implementation of GraphQL subscription handler. @@ -38,17 +38,16 @@ open class SpringGraphQLSubscriptionHandler( private val dataLoaderRegistryFactory: DataLoaderRegistryFactory? = null ) { - fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flux> { + fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContext: GraphQLContext?): Flow> { val dataLoaderRegistry = dataLoaderRegistryFactory?.generate() val input = graphQLRequest.toExecutionInput(graphQLContext, dataLoaderRegistry) return graphQL.execute(input) - .getData>() - .toFlux() + .getData>() .map { result -> result.toGraphQLResponse() } - .onErrorResume { throwable -> + .catch { throwable -> val error = throwable.toGraphQLError() - Flux.just(GraphQLResponse(errors = listOf(error.toGraphQLKotlinType()))) + emit(GraphQLResponse(errors = listOf(error.toGraphQLKotlinType()))) } } } diff --git a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/SubscriptionConfigurationTest.kt b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/SubscriptionConfigurationTest.kt index 63f96b7ccb..587fb7054b 100644 --- a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/SubscriptionConfigurationTest.kt +++ b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/SubscriptionConfigurationTest.kt @@ -27,6 +27,7 @@ import graphql.GraphQL import graphql.schema.GraphQLSchema import io.mockk.every import io.mockk.mockk +import kotlinx.coroutines.flow.flowOf import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.springframework.boot.autoconfigure.AutoConfigurations @@ -125,7 +126,7 @@ class SubscriptionConfigurationTest { @Bean fun subscriptionHandler(): SpringGraphQLSubscriptionHandler = mockk { - every { executeSubscription(any(), any()) } returns Flux.empty() + every { executeSubscription(any(), any()) } returns flowOf() } @Bean diff --git a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLSubscriptionHandlerTest.kt b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLSubscriptionHandlerTest.kt index 0c18a67e02..f3b99b1c02 100644 --- a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLSubscriptionHandlerTest.kt +++ b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLSubscriptionHandlerTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2020 Expedia, Inc + * Copyright 2021 Expedia, Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package com.expediagroup.graphql.server.spring.execution import com.expediagroup.graphql.generator.SchemaGeneratorConfig import com.expediagroup.graphql.generator.TopLevelObject import com.expediagroup.graphql.generator.exceptions.GraphQLKotlinException +import com.expediagroup.graphql.generator.execution.FlowSubscriptionExecutionStrategy import com.expediagroup.graphql.generator.execution.GraphQLContext import com.expediagroup.graphql.generator.toSchema import com.expediagroup.graphql.server.execution.DefaultDataLoaderRegistryFactory @@ -30,6 +31,7 @@ import graphql.GraphQL import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLSchema import io.mockk.mockk +import kotlinx.coroutines.reactor.asFlux import org.dataloader.DataLoader import org.junit.jupiter.api.Test import reactor.core.publisher.Flux @@ -51,7 +53,9 @@ class SpringGraphQLSubscriptionHandlerTest { queries = listOf(TopLevelObject(BasicQuery())), subscriptions = listOf(TopLevelObject(BasicSubscription())) ) - private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema).build() + private val testGraphQL: GraphQL = GraphQL.newGraphQL(testSchema) + .subscriptionExecutionStrategy(FlowSubscriptionExecutionStrategy()) + .build() private val mockLoader: KotlinDataLoader = object : KotlinDataLoader { override val dataLoaderName: String = "MockDataLoader" override fun getDataLoader(): DataLoader = DataLoader { ids -> @@ -66,7 +70,7 @@ class SpringGraphQLSubscriptionHandlerTest { @Test fun `verify subscription`() { val request = GraphQLRequest(query = "subscription { ticker }") - val responseFlux = subscriptionHandler.executeSubscription(request, mockk()) + val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux() StepVerifier.create(responseFlux) .thenConsumeWhile { response -> @@ -84,7 +88,7 @@ class SpringGraphQLSubscriptionHandlerTest { @Test fun `verify subscription with data loader`() { val request = GraphQLRequest(query = "subscription { dataLoaderValue }") - val responseFlux = subscriptionHandler.executeSubscription(request, mockk()) + val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux() StepVerifier.create(responseFlux) .thenConsumeWhile { response -> @@ -105,7 +109,7 @@ class SpringGraphQLSubscriptionHandlerTest { fun `verify subscription with context`() { val request = GraphQLRequest(query = "subscription { contextualTicker }") val context = SubscriptionContext("junitHandler") - val responseFlux = subscriptionHandler.executeSubscription(request, context) + val responseFlux = subscriptionHandler.executeSubscription(request, context).asFlux() StepVerifier.create(responseFlux) .thenConsumeWhile { response -> @@ -126,7 +130,7 @@ class SpringGraphQLSubscriptionHandlerTest { @Test fun `verify subscription to failing publisher`() { val request = GraphQLRequest(query = "subscription { alwaysThrows }") - val responseFlux = subscriptionHandler.executeSubscription(request, mockk()) + val responseFlux = subscriptionHandler.executeSubscription(request, mockk()).asFlux() StepVerifier.create(responseFlux) .assertNext { response -> diff --git a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandlerTest.kt b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandlerTest.kt index d3bdc3f30c..a9c693663b 100644 --- a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandlerTest.kt +++ b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/subscriptions/ApolloSubscriptionProtocolHandlerTest.kt @@ -37,9 +37,10 @@ import io.mockk.mockk import io.mockk.verify import io.mockk.verifyOrder import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map import org.junit.jupiter.api.Test import org.springframework.web.reactive.socket.WebSocketSession -import reactor.core.publisher.Flux import reactor.test.StepVerifier import java.time.Duration import kotlin.test.assertEquals @@ -297,7 +298,7 @@ class ApolloSubscriptionProtocolHandlerTest { every { id } returns "123" } val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk { - every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData")) + every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData")) } val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks) @@ -329,7 +330,7 @@ class ApolloSubscriptionProtocolHandlerTest { } val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk { // Never closes - every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.interval(Duration.ofSeconds(1)).map { GraphQLResponse("myData") } + every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(Duration.ofSeconds(1)).map { GraphQLResponse("myData") } } val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks) @@ -360,7 +361,7 @@ class ApolloSubscriptionProtocolHandlerTest { every { id } returns "123" } val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk { - every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData")) + every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData")) } val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks) @@ -394,7 +395,7 @@ class ApolloSubscriptionProtocolHandlerTest { every { id } returns "123" } val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk { - every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse("myData")) + every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse("myData")) } val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks) @@ -427,7 +428,7 @@ class ApolloSubscriptionProtocolHandlerTest { } val errors = listOf(GraphQLServerError("My GraphQL Error")) val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk { - every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(GraphQLResponse(errors = errors)) + every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(GraphQLResponse(errors = errors)) } val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks) @@ -503,7 +504,7 @@ class ApolloSubscriptionProtocolHandlerTest { } val expectedResponse = GraphQLResponse("myData") val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk { - every { executeSubscription(eq(graphQLRequest), any()) } returns Flux.just(expectedResponse) + every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(expectedResponse) } val subscriptionHooks: ApolloSubscriptionHooks = mockk { every { onConnect(any(), any(), any()) } returns null