diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java index 07264863afb6..c1608e1bb03d 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompEncoder.java @@ -81,7 +81,10 @@ public byte[] encode(Map headers, byte[] payload) { } else { StompCommand command = StompHeaderAccessor.getCommand(headers); - Assert.notNull(command, "Missing STOMP command: " + headers); + if (command == null) { + throw new IllegalStateException("Missing STOMP command: " + headers); + } + output.write(command.toString().getBytes(StandardCharsets.UTF_8)); output.write(LF); writeHeaders(command, headers, payload, output); @@ -115,22 +118,25 @@ private void writeHeaders(StompCommand command, Map headers, byt boolean shouldEscape = (command != StompCommand.CONNECT && command != StompCommand.CONNECTED); for (Entry> entry : nativeHeaders.entrySet()) { - byte[] key = encodeHeaderString(entry.getKey(), shouldEscape); if (command.requiresContentLength() && "content-length".equals(entry.getKey())) { continue; } + List values = entry.getValue(); if (StompCommand.CONNECT.equals(command) && StompHeaderAccessor.STOMP_PASSCODE_HEADER.equals(entry.getKey())) { values = Arrays.asList(StompHeaderAccessor.getPasscode(headers)); } + + byte[] encodedKey = encodeHeaderString(entry.getKey(), shouldEscape); for (String value : values) { - output.write(key); + output.write(encodedKey); output.write(COLON); output.write(encodeHeaderString(value, shouldEscape)); output.write(LF); } } + if (command.requiresContentLength()) { int contentLength = payload.length; output.write("content-length:".getBytes(StandardCharsets.UTF_8));