From f3d06d9c135fd2480928265cc2663cd4889b8dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9da=20Housni=20Alaoui?= Date: Sat, 10 May 2025 22:25:59 +0200 Subject: [PATCH] Add a heartbeat executor for SSE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Réda Housni Alaoui --- .../DefaultSseEmitterHeartbeatExecutor.java | 125 ++++++ .../RequestMappingHandlerAdapter.java | 12 +- ...ResponseBodyEmitterReturnValueHandler.java | 28 ++ .../SseEmitterHeartbeatExecutor.java | 25 ++ ...faultSseEmitterHeartbeatExecutorTests.java | 377 ++++++++++++++++++ 5 files changed, 566 insertions(+), 1 deletion(-) create mode 100644 spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java create mode 100644 spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java create mode 100644 spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java new file mode 100644 index 000000000000..438ecab64772 --- /dev/null +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutor.java @@ -0,0 +1,125 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.mvc.method.annotation; + + +import java.io.IOException; +import java.time.Duration; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledFuture; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.context.SmartLifecycle; +import org.springframework.http.MediaType; +import org.springframework.scheduling.TaskScheduler; + +/** + * @author Réda Housni Alaoui + */ +public class DefaultSseEmitterHeartbeatExecutor implements SmartLifecycle, SseEmitterHeartbeatExecutor { + + private static final Logger LOGGER = LoggerFactory.getLogger(DefaultSseEmitterHeartbeatExecutor.class); + + private final TaskScheduler taskScheduler; + private final Set emitters = ConcurrentHashMap.newKeySet(); + + private final Object lifecycleMonitor = new Object(); + + private Duration period = Duration.ofSeconds(5); + private String eventName = "ping"; + private String eventObject = "ping"; + + private volatile boolean running; + @Nullable + private volatile ScheduledFuture taskFuture; + + public DefaultSseEmitterHeartbeatExecutor(TaskScheduler taskScheduler) { + this.taskScheduler = taskScheduler; + } + + public void setPeriod(Duration period) { + this.period = period; + } + + public void setEventName(String eventName) { + this.eventName = eventName; + } + + public void setEventObject(String eventObject) { + this.eventObject = eventObject; + } + + @Override + public void start() { + synchronized (lifecycleMonitor) { + taskFuture = taskScheduler.scheduleAtFixedRate(this::ping, period); + running = true; + } + } + + @Override + public void register(SseEmitter emitter) { + Runnable closeCallback = () -> emitters.remove(emitter); + emitter.onCompletion(closeCallback); + emitter.onError(t -> closeCallback.run()); + emitter.onTimeout(closeCallback); + + emitters.add(emitter); + } + + @Override + public void stop() { + synchronized (lifecycleMonitor) { + ScheduledFuture future = taskFuture; + if (future != null) { + future.cancel(true); + } + emitters.clear(); + running = false; + } + } + + @Override + public boolean isRunning() { + return running; + } + + boolean isRegistered(SseEmitter emitter) { + return emitters.contains(emitter); + } + + private void ping() { + LOGGER.atDebug().log(() -> "Pinging %s emitter(s)".formatted(emitters.size())); + + for (SseEmitter emitter : emitters) { + if (Thread.currentThread().isInterrupted()) { + return; + } + LOGGER.trace("Pinging {}", emitter); + SseEmitter.SseEventBuilder eventBuilder = SseEmitter.event().name(eventName).data(eventObject, MediaType.TEXT_PLAIN); + try { + emitter.send(eventBuilder); + } catch (IOException | RuntimeException e) { + // According to SseEmitter's Javadoc, the container itself will call SseEmitter#completeWithError + LOGGER.debug(e.getMessage()); + } + } + } +} diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java index d20d9559ff69..577d86b2157b 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java @@ -123,6 +123,7 @@ * @author Rossen Stoyanchev * @author Juergen Hoeller * @author Sebastien Deleuze + * @author Réda Housni Alaoui * @since 3.1 * @see HandlerMethodArgumentResolver * @see HandlerMethodReturnValueHandler @@ -201,6 +202,8 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter private final Map> modelAttributeAdviceCache = new LinkedHashMap<>(); + @Nullable + private SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor; /** * Provide resolvers for custom argument types. Custom resolvers are ordered @@ -526,6 +529,13 @@ public void setParameterNameDiscoverer(ParameterNameDiscoverer parameterNameDisc this.parameterNameDiscoverer = parameterNameDiscoverer; } + /** + * Set the {@link SseEmitterHeartbeatExecutor} that will be used to periodically prob the SSE connection health + */ + public void setSseEmitterHeartbeatExecutor(@Nullable SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor) { + this.sseEmitterHeartbeatExecutor = sseEmitterHeartbeatExecutor; + } + /** * A {@link ConfigurableBeanFactory} is expected for resolving expressions * in method argument default values. @@ -735,7 +745,7 @@ private List getDefaultReturnValueHandlers() { handlers.add(new ViewMethodReturnValueHandler()); handlers.add(new ResponseBodyEmitterReturnValueHandler(getMessageConverters(), this.reactiveAdapterRegistry, this.taskExecutor, this.contentNegotiationManager, - initViewResolvers(), initLocaleResolver())); + initViewResolvers(), initLocaleResolver(), this.sseEmitterHeartbeatExecutor)); handlers.add(new StreamingResponseBodyReturnValueHandler()); handlers.add(new HttpEntityMethodProcessor(getMessageConverters(), this.contentNegotiationManager, this.requestResponseBodyAdvice, this.errorResponseInterceptors)); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java index 5f59fcdf9440..84812ee20325 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java @@ -25,6 +25,7 @@ import java.util.Collections; import java.util.List; import java.util.Locale; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -89,6 +90,7 @@ * * * @author Rossen Stoyanchev + * @author Réda Housni Alaoui * @since 4.2 */ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodReturnValueHandler { @@ -101,6 +103,8 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur private final LocaleResolver localeResolver; + @Nullable + private final SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor; /** * Simple constructor with reactive type support based on a default instance of @@ -143,11 +147,32 @@ public ResponseBodyEmitterReturnValueHandler( ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager, List viewResolvers, @Nullable LocaleResolver localeResolver) { + this(messageConverters, registry, executor, manager, viewResolvers, localeResolver, null); + } + + /** + * Constructor that with added arguments for view rendering. + * @param messageConverters converters to write emitted objects with + * @param registry for reactive return value type support + * @param executor for blocking I/O writes of items emitted from reactive types + * @param manager for detecting streaming media types + * @param viewResolvers resolvers for fragment stream rendering + * @param localeResolver the {@link LocaleResolver} for fragment stream rendering + * @param sseEmitterHeartbeatExecutor for sending periodic events to SSE clients + * @since 6.2 + */ + public ResponseBodyEmitterReturnValueHandler( + List> messageConverters, + ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager, + List viewResolvers, @Nullable LocaleResolver localeResolver, + @Nullable SseEmitterHeartbeatExecutor sseEmitterHeartbeatExecutor) { + Assert.notEmpty(messageConverters, "HttpMessageConverter List must not be empty"); this.sseMessageConverters = initSseConverters(messageConverters); this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager, null); this.viewResolvers = viewResolvers; this.localeResolver = (localeResolver != null ? localeResolver : new AcceptHeaderLocaleResolver()); + this.sseEmitterHeartbeatExecutor = sseEmitterHeartbeatExecutor; } private static List> initSseConverters(List> converters) { @@ -239,6 +264,9 @@ public void handleReturnValue(@Nullable Object returnValue, MethodParameter retu } emitter.initialize(emitterHandler); + if (emitter instanceof SseEmitter sseEmitter) { + Optional.ofNullable(sseEmitterHeartbeatExecutor).ifPresent(handler -> handler.register(sseEmitter)); + } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java new file mode 100644 index 000000000000..0b19305daa7b --- /dev/null +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/SseEmitterHeartbeatExecutor.java @@ -0,0 +1,25 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.mvc.method.annotation; + +/** + * @author Réda Housni Alaoui + */ +public interface SseEmitterHeartbeatExecutor { + + void register(SseEmitter emitter); +} diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java new file mode 100644 index 000000000000..325748bbacb5 --- /dev/null +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/DefaultSseEmitterHeartbeatExecutorTests.java @@ -0,0 +1,377 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.servlet.mvc.method.annotation; + +import static org.assertj.core.api.Assertions.*; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Delayed; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.jetbrains.annotations.NotNull; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.http.MediaType; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.scheduling.Trigger; + +/** + * @author Réda Housni Alaoui + */ +class DefaultSseEmitterHeartbeatExecutorTests { + + private static final MediaType TEXT_PLAIN_UTF8 = new MediaType("text", "plain", StandardCharsets.UTF_8); + + private TestTaskScheduler taskScheduler; + private DefaultSseEmitterHeartbeatExecutor executor; + + @BeforeEach + void beforeEach() { + this.taskScheduler = new TestTaskScheduler(); + executor = new DefaultSseEmitterHeartbeatExecutor(taskScheduler); + } + + @Test + @DisplayName("It sends heartbeat at a fixed rate") + void test1() { + executor.start(); + assertThat(taskScheduler.fixedRateTask).isNotNull(); + assertThat(taskScheduler.fixedRatePeriod).isEqualTo(Duration.ofSeconds(5)); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + taskScheduler.fixedRateTask.run(); + + emitter.handler.assertSentObjectCount(3); + emitter.handler.assertObject(0, "event:ping\ndata:", TEXT_PLAIN_UTF8); + emitter.handler.assertObject(1, "ping", MediaType.TEXT_PLAIN); + emitter.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + emitter.handler.assertWriteCount(1); + } + + @Test + @DisplayName("Emitter is unregistered on completion") + void test2() { + executor.start(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + + assertThat(executor.isRegistered(emitter.emitter)).isTrue(); + emitter.emitter.complete(); + assertThat(executor.isRegistered(emitter.emitter)).isFalse(); + } + + @Test + @DisplayName("Emitter is unregistered on error") + void test3() { + executor.start(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + + assertThat(executor.isRegistered(emitter.emitter)).isTrue(); + emitter.emitter.completeWithError(new RuntimeException()); + assertThat(executor.isRegistered(emitter.emitter)).isFalse(); + } + + @Test + @DisplayName("Emitter is unregistered on timeout") + void test4() { + executor.start(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + + assertThat(executor.isRegistered(emitter.emitter)).isTrue(); + emitter.handler.completeWithTimeout(); + assertThat(executor.isRegistered(emitter.emitter)).isFalse(); + } + + @Test + @DisplayName("Emitters are unregistered on executor shutdown") + void test5() { + executor.start(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + + assertThat(executor.isRegistered(emitter.emitter)).isTrue(); + executor.stop(); + assertThat(executor.isRegistered(emitter.emitter)).isFalse(); + } + + @Test + @DisplayName("The task schedule is canceled on executor shutdown") + void test6() { + executor.start(); + executor.stop(); + assertThat(taskScheduler.fixedRateFuture.canceled).isTrue(); + assertThat(taskScheduler.fixedRateFuture.interrupted).isTrue(); + } + + @Test + @DisplayName("The task never throws") + void test7() { + executor.start(); + assertThat(taskScheduler.fixedRateTask).isNotNull(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + emitter.handler.exceptionToThrowOnSend = new RuntimeException(); + + assertThatCode(() -> taskScheduler.fixedRateTask.run()).doesNotThrowAnyException(); + } + + @Test + @DisplayName("The heartbeat rate can be customized") + void test8() { + executor.setPeriod(Duration.ofSeconds(30)); + executor.start(); + assertThat(taskScheduler.fixedRateTask).isNotNull(); + assertThat(taskScheduler.fixedRatePeriod).isEqualTo(Duration.ofSeconds(30)); + } + + @Test + @DisplayName("The heartbeat event name can be customized") + void test9() { + executor.setEventName("foo"); + executor.start(); + assertThat(taskScheduler.fixedRateTask).isNotNull(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + taskScheduler.fixedRateTask.run(); + + emitter.handler.assertSentObjectCount(3); + emitter.handler.assertObject(0, "event:foo\ndata:", TEXT_PLAIN_UTF8); + emitter.handler.assertObject(1, "ping", MediaType.TEXT_PLAIN); + emitter.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + emitter.handler.assertWriteCount(1); + } + + @Test + @DisplayName("The heartbeat event object can be customized") + void test10() { + executor.setEventObject("foo"); + executor.start(); + assertThat(taskScheduler.fixedRateTask).isNotNull(); + + TestEmitter emitter = createEmitter(); + executor.register(emitter.emitter()); + taskScheduler.fixedRateTask.run(); + + emitter.handler.assertSentObjectCount(3); + emitter.handler.assertObject(0, "event:ping\ndata:", TEXT_PLAIN_UTF8); + emitter.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN); + emitter.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); + emitter.handler.assertWriteCount(1); + } + + private TestEmitter createEmitter() { + SseEmitter sseEmitter = new SseEmitter(); + TestEmitterHandler handler = new TestEmitterHandler(); + try { + sseEmitter.initialize(handler); + } catch (IOException e) { + throw new RuntimeException(e); + } + return new TestEmitter(sseEmitter, handler); + } + + private record TestEmitter(SseEmitter emitter, TestEmitterHandler handler) { + + } + + private static class TestEmitterHandler implements ResponseBodyEmitter.Handler { + + private final List objects = new ArrayList<>(); + + private final List<@Nullable MediaType> mediaTypes = new ArrayList<>(); + + private final List timeoutCallbacks = new ArrayList<>(); + private final List completionCallbacks = new ArrayList<>(); + private final List> errorCallbacks = new ArrayList<>(); + + private int writeCount; + @Nullable + private RuntimeException exceptionToThrowOnSend; + + public void assertSentObjectCount(int size) { + assertThat(this.objects).hasSize(size); + } + + public void assertObject(int index, Object object, MediaType mediaType) { + assertThat(index).isLessThanOrEqualTo(this.objects.size()); + assertThat(this.objects.get(index)).isEqualTo(object); + assertThat(this.mediaTypes.get(index)).isEqualTo(mediaType); + } + + public void assertWriteCount(int writeCount) { + assertThat(this.writeCount).isEqualTo(writeCount); + } + + @Override + public void send(Object data, @Nullable MediaType mediaType) { + failSendIfNeeded(); + this.objects.add(data); + this.mediaTypes.add(mediaType); + this.writeCount++; + } + + @Override + public void send(Set items) { + failSendIfNeeded(); + for (ResponseBodyEmitter.DataWithMediaType item : items) { + this.objects.add(item.getData()); + this.mediaTypes.add(item.getMediaType()); + } + this.writeCount++; + } + + private void failSendIfNeeded() { + Optional.ofNullable(exceptionToThrowOnSend) + .ifPresent(e -> { + throw e; + }); + } + + @Override + public void onCompletion(Runnable callback) { + completionCallbacks.add(callback); + } + + @Override + public void onTimeout(Runnable callback) { + timeoutCallbacks.add(callback); + } + + @Override + public void onError(Consumer callback) { + errorCallbacks.add(callback); + } + + @Override + public void complete() { + completionCallbacks.forEach(Runnable::run); + } + + @Override + public void completeWithError(Throwable failure) { + errorCallbacks.forEach(consumer -> consumer.accept(failure)); + } + + public void completeWithTimeout() { + timeoutCallbacks.forEach(Runnable::run); + } + } + + private static class TestTaskScheduler implements TaskScheduler { + + @Nullable + private Runnable fixedRateTask; + @Nullable + private Duration fixedRatePeriod; + private final TestScheduledFuture fixedRateFuture = new TestScheduledFuture<>(); + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable task, Duration period) { + this.fixedRateTask = task; + this.fixedRatePeriod = period; + return fixedRateFuture; + } + + @Override + public ScheduledFuture schedule(Runnable task, Trigger trigger) { + throw new UnsupportedOperationException(); + } + + @Override + public ScheduledFuture schedule(Runnable task, Instant startTime) { + throw new UnsupportedOperationException(); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable task, Instant startTime, Duration period) { + throw new UnsupportedOperationException(); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable task, Instant startTime, Duration delay) { + throw new UnsupportedOperationException(); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable task, Duration delay) { + throw new UnsupportedOperationException(); + } + } + + private static class TestScheduledFuture implements ScheduledFuture { + + private boolean canceled; + private boolean interrupted; + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + canceled = true; + interrupted = mayInterruptIfRunning; + return true; + } + + @Override + public long getDelay(@NotNull TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + + @Override + public int compareTo(@NotNull Delayed delayed) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCancelled() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isDone() { + throw new UnsupportedOperationException(); + } + + @Override + public T get() { + throw new UnsupportedOperationException(); + } + + @Override + public T get(long l, @NotNull TimeUnit timeUnit) { + throw new UnsupportedOperationException(); + } + } +}