Skip to content

Commit

Permalink
* sse: support method PUT/POST with body
Browse files Browse the repository at this point in the history
Signed-off-by: neo <1100909+neowu@users.noreply.github.com>
  • Loading branch information
neowu committed Feb 12, 2025
1 parent 78b18f8 commit 7c016e1
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 40 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
## Change log

### 9.1.6 (2/10/2025 - )
### 9.1.6-b0 (2/10/2025 - )

* http_client: tweak sse checking
* undertow: updated to 2.3.18.Final
> due to vulnerability of old versions, has to update to latest despite potential memory consumption is higher
* sse: support method PUT/POST with body

### 9.1.5 (11/11/2024 - 01/22/2025)

Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ apply(plugin = "project")

subprojects {
group = "core.framework"
version = "9.1.6"
version = "9.1.6-b0"
}

val elasticVersion = "8.15.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,10 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {

HttpString method = exchange.getRequestMethod();
HeaderMap headers = exchange.getRequestHeaders();
boolean sse = sseHandler != null && sseHandler.check(method, headers, path);
boolean ws = webSocketHandler != null && webSocketHandler.check(method, headers);
boolean active = !sse && !ws;

var requestHandler = new HTTPRequestHandler(exchange, handler, sseHandler);
boolean ws = webSocketHandler != null && webSocketHandler.check(method, headers); // TODO: retire ws and simplify
boolean active = !requestHandler.sse && !ws;
boolean shutdown = shutdownHandler.handle(exchange, active);
if (shutdown) return;

Expand All @@ -82,7 +83,7 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
return;
}

var reader = new RequestBodyReader(exchange, handler);
var reader = new RequestBodyReader(exchange, requestHandler);
StreamSourceChannel channel = exchange.getRequestChannel();
reader.read(channel); // channel will be null if getRequestChannel() is already called, but here should not be that case
if (!reader.complete()) {
Expand All @@ -92,12 +93,10 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
}
}

if (active) {
exchange.dispatch(handler);
} else if (sse) {
sseHandler.handleRequest(exchange); // not dispatch, continue in io thread
} else {
if (ws) {
exchange.dispatch(webSocketHandler);
} else {
requestHandler.handle();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package core.framework.internal.web;

import core.framework.internal.web.sse.ServerSentEventHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderMap;
import io.undertow.util.HttpString;

public class HTTPRequestHandler {
final boolean sse;
private final HttpServerExchange exchange;
private final HTTPHandler handler;
private final ServerSentEventHandler sseHandler;

HTTPRequestHandler(HttpServerExchange exchange, HTTPHandler handler, ServerSentEventHandler sseHandler) {
this.exchange = exchange;
this.handler = handler;
this.sseHandler = sseHandler;

HttpString method = exchange.getRequestMethod();
String path = exchange.getRequestPath();
HeaderMap headers = exchange.getRequestHeaders();
sse = sseHandler != null && sseHandler.check(method, path, headers);
}

public void handle() {
if (sse) {
sseHandler.handleRequest(exchange); // not dispatch, continue in io thread
} else {
exchange.dispatch(handler);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package core.framework.internal.web.request;

import core.framework.internal.web.HTTPHandler;
import core.framework.internal.web.HTTPRequestHandler;
import core.framework.web.exception.BadRequestException;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.server.HttpServerExchange;
Expand All @@ -20,13 +20,13 @@ public final class RequestBodyReader implements ChannelListener<StreamSourceChan
static final AttachmentKey<RequestBody> REQUEST_BODY = AttachmentKey.create(RequestBody.class);

private final HttpServerExchange exchange;
private final HTTPHandler handler;
private final HTTPRequestHandler handler;
private final int contentLength;
private boolean complete;
private byte[] body;
private int position = 0;

public RequestBodyReader(HttpServerExchange exchange, HTTPHandler handler) {
public RequestBodyReader(HttpServerExchange exchange, HTTPRequestHandler handler) {
this.exchange = exchange;
this.handler = handler;
contentLength = (int) exchange.getRequestContentLength();
Expand All @@ -37,7 +37,7 @@ public RequestBodyReader(HttpServerExchange exchange, HTTPHandler handler) {
public void handleEvent(StreamSourceChannel channel) {
read(channel);
if (complete) {
exchange.dispatch(handler);
handler.handle();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package core.framework.internal.web.sse;

import core.framework.http.HTTPMethod;
import core.framework.internal.async.VirtualThread;
import core.framework.internal.log.ActionLog;
import core.framework.internal.log.LogManager;
Expand All @@ -8,19 +9,20 @@
import core.framework.internal.web.session.ReadOnlySession;
import core.framework.internal.web.session.SessionManager;
import core.framework.module.ServerSentEventConfig;
import core.framework.util.Strings;
import core.framework.web.sse.ChannelListener;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderMap;
import io.undertow.util.Headers;
import io.undertow.util.HttpString;
import io.undertow.util.Methods;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xnio.ChannelListeners;
import org.xnio.IoUtils;
import org.xnio.channels.StreamSinkChannel;

import java.io.IOException;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
Expand All @@ -41,25 +43,31 @@ public ServerSentEventHandler(LogManager logManager, SessionManager sessionManag
this.handlerContext = handlerContext;
}

public boolean check(HttpString method, HeaderMap headers, String path) {
return Methods.GET.equals(method) && "text/event-stream".equals(headers.getFirst(Headers.ACCEPT)) && supports.containsKey(path);
public boolean check(HttpString method, String path, HeaderMap headers) {
return "text/event-stream".equals(headers.getFirst(Headers.ACCEPT))
&& supports.containsKey(key(method.toString(), path));
}

@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
public void handleRequest(HttpServerExchange exchange) {
exchange.getResponseHeaders().put(Headers.CONTENT_TYPE, "text/event-stream");
exchange.setPersistent(false);
StreamSinkChannel sink = exchange.getResponseChannel();
if (sink.flush()) {
exchange.dispatch(() -> handle(exchange, sink));
} else {
var listener = ChannelListeners.flushingChannelListener(channel -> exchange.dispatch(() -> handle(exchange, sink)),
(channel, e) -> {
logger.warn("failed to establish sse connection, error={}", e.getMessage(), e);
IoUtils.safeClose(exchange.getConnection());
});
sink.getWriteSetter().set(listener);
sink.resumeWrites();
try {
if (sink.flush()) {
exchange.dispatch(() -> handle(exchange, sink));
} else {
var listener = ChannelListeners.flushingChannelListener(channel -> exchange.dispatch(() -> handle(exchange, sink)),
(channel, e) -> {
logger.warn("failed to establish sse connection, error={}", e.getMessage(), e);
IoUtils.safeClose(exchange.getConnection());
});
sink.getWriteSetter().set(listener);
sink.resumeWrites();
}
} catch (IOException e) {
logger.warn("failed to establish sse connection, error={}", e.getMessage(), e);
IoUtils.safeClose(exchange.getConnection());
}
}

Expand All @@ -78,7 +86,7 @@ void handle(HttpServerExchange exchange, StreamSinkChannel sink) {
actionLog.warningContext.maxProcessTimeInNano(MAX_PROCESS_TIME_IN_NANO);
String path = request.path();
@SuppressWarnings("unchecked")
ChannelSupport<Object> support = (ChannelSupport<Object>) supports.get(path); // ServerSentEventHandler.check() ensures path exists
ChannelSupport<Object> support = (ChannelSupport<Object>) supports.get(key(request.method().name(), path)); // ServerSentEventHandler.check() ensures path exists
actionLog.action("sse:" + path + ":open");
handlerContext.rateControl.validateRate(ServerSentEventConfig.SSE_OPEN_GROUP, request.clientIP());

Expand All @@ -104,9 +112,9 @@ void handle(HttpServerExchange exchange, StreamSinkChannel sink) {
}
}

public <T> void add(String path, Class<T> eventClass, ChannelListener<T> listener, ServerSentEventContextImpl<T> context) {
var previous = supports.put(path, new ChannelSupport<>(listener, eventClass, context));
if (previous != null) throw new Error("found duplicate sse listener, path=" + path);
public <T> void add(HTTPMethod method, String path, Class<T> eventClass, ChannelListener<T> listener, ServerSentEventContextImpl<T> context) {
var previous = supports.put(key(method.name(), path), new ChannelSupport<>(listener, eventClass, context));
if (previous != null) throw new Error(Strings.format("found duplicate sse listener, method={}, path={}", method, path));
}

public void shutdown() {
Expand All @@ -117,4 +125,8 @@ public void shutdown() {
}
}
}

private String key(String method, String path) {
return method + ":" + path;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package core.framework.module;

import core.framework.http.HTTPMethod;
import core.framework.internal.inject.InjectValidator;
import core.framework.internal.module.Config;
import core.framework.internal.module.ModuleContext;
Expand Down Expand Up @@ -37,15 +38,15 @@ protected void validate() {
}
}

public <T> void listen(String path, Class<T> eventClass, ChannelListener<T> listener) {
public <T> void listen(HTTPMethod method, String path, Class<T> eventClass, ChannelListener<T> listener) {
if (HTTPIOHandler.HEALTH_CHECK_PATH.equals(path)) throw new Error("/health-check is reserved path");
if (path.contains("/:")) throw new Error("listener path must be static, path=" + path);

if (listener.getClass().isSynthetic())
throw new Error("listener class must not be anonymous class or lambda, please create static class, listenerClass=" + listener.getClass().getCanonicalName());
new InjectValidator(listener).validate();

logger.info("sse, path={}, eventClass={}, listener={}", path, eventClass.getCanonicalName(), listener.getClass().getCanonicalName());
logger.info("sse, method={}, path={}, eventClass={}, listener={}", method, path, eventClass.getCanonicalName(), listener.getClass().getCanonicalName());

if (context.httpServer.sseHandler == null) {
context.httpServer.sseHandler = new ServerSentEventHandler(context.logManager, context.httpServer.siteManager.sessionManager, context.httpServer.handlerContext);
Expand All @@ -57,7 +58,7 @@ public <T> void listen(String path, Class<T> eventClass, ChannelListener<T> list
context.apiController.beanClasses.add(eventClass);

var sseContext = new ServerSentEventContextImpl<T>();
context.httpServer.sseHandler.add(path, eventClass, listener, sseContext);
context.httpServer.sseHandler.add(method, path, eventClass, listener, sseContext);
context.beanFactory.bind(Types.generic(ServerSentEventContext.class, eventClass), null, sseContext);
metrics.contexts.add(sseContext);
context.backgroundTask().scheduleWithFixedDelay(sseContext::keepAlive, Duration.ofSeconds(15));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package core.framework.module;

import core.framework.http.HTTPMethod;
import core.framework.internal.module.ModuleContext;
import core.framework.internal.web.HTTPIOHandler;
import core.framework.internal.web.sse.TestChannelListener;
Expand Down Expand Up @@ -28,22 +29,22 @@ void createWebSocketConfig() {

@Test
void withReservedPath() {
assertThatThrownBy(() -> config.listen(HTTPIOHandler.HEALTH_CHECK_PATH, TestEvent.class, new TestChannelListener()))
assertThatThrownBy(() -> config.listen(HTTPMethod.GET, HTTPIOHandler.HEALTH_CHECK_PATH, TestEvent.class, new TestChannelListener()))
.isInstanceOf(Error.class)
.hasMessageContaining("/health-check is reserved path");
}

@Test
void listen() {
assertThatThrownBy(() -> config.listen("/sse/:name", TestEvent.class, new TestChannelListener()))
assertThatThrownBy(() -> config.listen(HTTPMethod.GET, "/sse/:name", TestEvent.class, new TestChannelListener()))
.isInstanceOf(Error.class)
.hasMessageContaining("listener path must be static");

assertThatThrownBy(() -> config.listen("/sse", TestEvent.class, (request, channel, lastEventId) -> {
assertThatThrownBy(() -> config.listen(HTTPMethod.GET, "/sse", TestEvent.class, (request, channel, lastEventId) -> {
})).isInstanceOf(Error.class)
.hasMessageContaining("listener class must not be anonymous class or lambda");

config.listen("/sse2", TestEvent.class, new TestChannelListener());
config.listen(HTTPMethod.GET, "/sse2", TestEvent.class, new TestChannelListener());
@SuppressWarnings("unchecked")
ServerSentEventContext<TestEvent> context = (ServerSentEventContext<TestEvent>) this.config.context.beanFactory.bean(Types.generic(ServerSentEventContext.class, TestEvent.class), null);
assertThat(context).isNotNull();
Expand Down

0 comments on commit 7c016e1

Please sign in to comment.