Skip to content

Commit

Permalink
Reduce access on headers for STOMP messaging
Browse files Browse the repository at this point in the history
Issue: SPR-16165
  • Loading branch information
jhoeller committed Nov 14, 2017
1 parent d5f34ed commit f861f18
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2015 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -205,16 +205,20 @@ public String getDetailedLogMessage(Object payload) {

private StringBuilder getBaseLogMessage() {
StringBuilder sb = new StringBuilder();
sb.append(getMessageType().name());
if (getDestination() != null) {
sb.append(" destination=").append(getDestination());
SimpMessageType messageType = getMessageType();
sb.append(messageType != null ? messageType.name() : SimpMessageType.OTHER);
String destination = getDestination();
if (destination != null) {
sb.append(" destination=").append(destination);
}
if (getSubscriptionId() != null) {
sb.append(" subscriptionId=").append(getSubscriptionId());
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
sb.append(" subscriptionId=").append(subscriptionId);
}
sb.append(" session=").append(getSessionId());
if (getUser() != null) {
sb.append(" user=").append(getUser().getName());
Principal user = getUser();
if (user != null) {
sb.append(" user=").append(user.getName());
}
return sb;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,12 @@ private Message<byte[]> decodeMessage(ByteBuffer buffer, MultiValueMap<String, S
payload = readPayload(buffer, headerAccessor);
}
if (payload != null) {
if (payload.length > 0 && !headerAccessor.getCommand().isBodyAllowed()) {
throw new StompConversionException(headerAccessor.getCommand() +
" shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
if (payload.length > 0) {
StompCommand stompCommand = headerAccessor.getCommand();
if (stompCommand != null && !stompCommand.isBodyAllowed()) {
throw new StompConversionException(stompCommand +
" shouldn't have a payload: length=" + payload.length + ", headers=" + headers);
}
}
headerAccessor.updateSimpMessageHeadersFromStompHeaders();
headerAccessor.setLeaveMutable(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.messaging.simp.stomp;

import java.nio.charset.Charset;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -163,11 +164,13 @@ else if (StompCommand.CONNECT.equals(command)) {
}

void updateStompHeadersFromSimpMessageHeaders() {
if (getDestination() != null) {
setNativeHeader(STOMP_DESTINATION_HEADER, getDestination());
String destination = getDestination();
if (destination != null) {
setNativeHeader(STOMP_DESTINATION_HEADER, destination);
}
if (getContentType() != null) {
setNativeHeader(STOMP_CONTENT_TYPE_HEADER, getContentType().toString());
MimeType contentType = getContentType();
if (contentType != null) {
setNativeHeader(STOMP_CONTENT_TYPE_HEADER, contentType.toString());
}
trySetStompHeaderForSubscriptionId();
}
Expand All @@ -185,21 +188,24 @@ Map<String, List<String>> getNativeHeaders() {
}

public StompCommand updateStompCommandAsClientMessage() {
if (getMessageType() != SimpMessageType.MESSAGE) {
throw new IllegalStateException("Unexpected message type " + getMessageType());
SimpMessageType messageType = getMessageType();
if (messageType != SimpMessageType.MESSAGE) {
throw new IllegalStateException("Unexpected message type " + messageType);
}
if (getCommand() == null) {
StompCommand command = getCommand();
if (command == null) {
setHeader(COMMAND_HEADER, StompCommand.SEND);
}
else if (!getCommand().equals(StompCommand.SEND)) {
throw new IllegalStateException("Unexpected STOMP command " + getCommand());
else if (!command.equals(StompCommand.SEND)) {
throw new IllegalStateException("Unexpected STOMP command " + command);
}
return getCommand();
return command;
}

public void updateStompCommandAsServerMessage() {
if (getMessageType() != SimpMessageType.MESSAGE) {
throw new IllegalStateException("Unexpected message type " + getMessageType());
SimpMessageType messageType = getMessageType();
if (messageType != SimpMessageType.MESSAGE) {
throw new IllegalStateException("Unexpected message type " + messageType);
}
StompCommand command = getCommand();
if ((command == null) || StompCommand.SEND.equals(command)) {
Expand Down Expand Up @@ -273,7 +279,8 @@ public void setSubscriptionId(String subscriptionId) {
private void trySetStompHeaderForSubscriptionId() {
String subscriptionId = getSubscriptionId();
if (subscriptionId != null) {
if (getCommand() != null && StompCommand.MESSAGE.equals(getCommand())) {
StompCommand command = getCommand();
if (command != null && StompCommand.MESSAGE.equals(command)) {
setNativeHeader(STOMP_SUBSCRIPTION_HEADER, subscriptionId);
}
else {
Expand All @@ -286,10 +293,8 @@ private void trySetStompHeaderForSubscriptionId() {
}

public Integer getContentLength() {
if (containsNativeHeader(STOMP_CONTENT_LENGTH_HEADER)) {
return Integer.valueOf(getFirstNativeHeader(STOMP_CONTENT_LENGTH_HEADER));
}
return null;
String header = getFirstNativeHeader(STOMP_CONTENT_LENGTH_HEADER);
return (header != null ? Integer.valueOf(header) : null);
}

public void setContentLength(int contentLength) {
Expand Down Expand Up @@ -390,23 +395,26 @@ public void setVersion(String version) {

@Override
public String getShortLogMessage(Object payload) {
if (StompCommand.SUBSCRIBE.equals(getCommand())) {
StompCommand command = getCommand();
if (StompCommand.SUBSCRIBE.equals(command)) {
return "SUBSCRIBE " + getDestination() + " id=" + getSubscriptionId() + appendSession();
}
else if (StompCommand.UNSUBSCRIBE.equals(getCommand())) {
else if (StompCommand.UNSUBSCRIBE.equals(command)) {
return "UNSUBSCRIBE id=" + getSubscriptionId() + appendSession();
}
else if (StompCommand.SEND.equals(getCommand())) {
else if (StompCommand.SEND.equals(command)) {
return "SEND " + getDestination() + appendSession() + appendPayload(payload);
}
else if (StompCommand.CONNECT.equals(getCommand())) {
return "CONNECT" + (getUser() != null ? " user=" + getUser().getName() : "") + appendSession();
else if (StompCommand.CONNECT.equals(command)) {
Principal user = getUser();
return "CONNECT" + (user != null ? " user=" + user.getName() : "") + appendSession();
}
else if (StompCommand.CONNECTED.equals(getCommand())) {
else if (StompCommand.CONNECTED.equals(command)) {
return "CONNECTED heart-beat=" + Arrays.toString(getHeartbeat()) + appendSession();
}
else if (StompCommand.DISCONNECT.equals(getCommand())) {
return "DISCONNECT" + (getReceipt() != null ? " receipt=" + getReceipt() : "") + appendSession();
else if (StompCommand.DISCONNECT.equals(command)) {
String receipt = getReceipt();
return "DISCONNECT" + (receipt != null ? " receipt=" + receipt : "") + appendSession();
}
else {
return getDetailedLogMessage(payload);
Expand Down Expand Up @@ -444,11 +452,12 @@ private String appendPayload(Object payload) {
"Expected byte array payload but got: " + ClassUtils.getQualifiedName(payload.getClass()));
}
byte[] bytes = (byte[]) payload;
String contentType = (getContentType() != null ? " " + getContentType().toString() : "");
if (bytes.length == 0 || getContentType() == null || !isReadableContentType()) {
MimeType mimeType = getContentType();
String contentType = (mimeType != null ? " " + mimeType.toString() : "");
if (bytes.length == 0 || mimeType == null || !isReadableContentType()) {
return contentType;
}
Charset charset = getContentType().getCharset();
Charset charset = mimeType.getCharset();
charset = (charset != null ? charset : StompDecoder.UTF8_CHARSET);
return (bytes.length < 80) ?
contentType + " payload=" + new String(bytes, charset) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,9 @@ else if (payload instanceof byte[]) {
}

protected boolean isReadableContentType() {
MimeType contentType = getContentType();
for (MimeType mimeType : READABLE_MIME_TYPES) {
if (mimeType.includes(getContentType())) {
if (mimeType.includes(contentType)) {
return true;
}
}
Expand All @@ -557,6 +558,8 @@ public String toString() {
* its type does not match the required type.
* <p>This is for cases where the existence of an accessor is strongly expected
* (followed up with an assertion) or where an accessor will be created otherwise.
* @param message the message to get an accessor for
* @param requiredType the required accessor type (or {@code null} for any)
* @return an accessor instance of the specified type, or {@code null} if none
* @since 4.1
*/
Expand All @@ -568,6 +571,8 @@ public static <T extends MessageHeaderAccessor> T getAccessor(Message<?> message
* A variation of {@link #getAccessor(org.springframework.messaging.Message, Class)}
* with a {@code MessageHeaders} instance instead of a {@code Message}.
* <p>This is for cases when a full message may not have been created yet.
* @param messageHeaders the message headers to get an accessor for
* @param requiredType the required accessor type (or {@code null} for any)
* @return an accessor instance of the specified type, or {@code null} if none
* @since 4.1
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -285,11 +285,12 @@ else if (webSocketMessage instanceof BinaryMessage) {
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
}

boolean isConnect = StompCommand.CONNECT.equals(headerAccessor.getCommand());
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command);
if (isConnect) {
this.stats.incrementConnectCount();
}
else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
else if (StompCommand.DISCONNECT.equals(command)) {
this.stats.incrementDisconnectCount();
}

Expand All @@ -308,10 +309,10 @@ else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
if (isConnect) {
publishEvent(new SessionConnectEvent(this, message, getUser(session)));
}
else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
else if (StompCommand.SUBSCRIBE.equals(command)) {
publishEvent(new SessionSubscribeEvent(this, message, getUser(session)));
}
else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
else if (StompCommand.UNSUBSCRIBE.equals(command)) {
publishEvent(new SessionUnsubscribeEvent(this, message, getUser(session)));
}
}
Expand Down

0 comments on commit f861f18

Please sign in to comment.