diff --git a/server/src/main/java/org/opensearch/transport/InboundAggregator.java b/server/src/main/java/org/opensearch/transport/InboundAggregator.java index f52875d880b4f..e894331f3b64e 100644 --- a/server/src/main/java/org/opensearch/transport/InboundAggregator.java +++ b/server/src/main/java/org/opensearch/transport/InboundAggregator.java @@ -40,7 +40,6 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.bytes.CompositeBytesReference; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.ArrayList; @@ -114,7 +113,7 @@ public void aggregate(ReleasableBytesReference content) { } } - public NativeInboundMessage finishAggregation() throws IOException { + public InboundMessage finishAggregation() throws IOException { ensureOpen(); final ReleasableBytesReference releasableContent; if (isFirstContent()) { @@ -128,7 +127,7 @@ public NativeInboundMessage finishAggregation() throws IOException { } final BreakerControl breakerControl = new BreakerControl(circuitBreaker); - final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl); + final InboundMessage aggregated = new InboundMessage(currentHeader, releasableContent, breakerControl); boolean success = false; try { if (aggregated.getHeader().needsToReadVariableHeader()) { @@ -143,7 +142,7 @@ public NativeInboundMessage finishAggregation() throws IOException { if (isShortCircuited()) { aggregated.close(); success = true; - return new NativeInboundMessage(aggregated.getHeader(), aggregationException); + return new InboundMessage(aggregated.getHeader(), aggregationException); } else { success = true; return aggregated; diff --git a/server/src/main/java/org/opensearch/transport/InboundMessage.java b/server/src/main/java/org/opensearch/transport/InboundMessage.java new file mode 100644 index 0000000000000..5c68257557061 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/InboundMessage.java @@ -0,0 +1,108 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.transport; + +import org.opensearch.common.annotation.DeprecatedApi; +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; + +import java.io.IOException; + +/** + * Inbound data as a message + * This api is deprecated, please use {@link org.opensearch.transport.nativeprotocol.NativeInboundMessage} instead. + * @opensearch.api + */ +@DeprecatedApi(since = "2.14.0") +public class InboundMessage implements Releasable, ProtocolInboundMessage { + + private final NativeInboundMessage nativeInboundMessage; + + public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { + this.nativeInboundMessage = new NativeInboundMessage(header, content, breakerRelease); + } + + public InboundMessage(Header header, Exception exception) { + this.nativeInboundMessage = new NativeInboundMessage(header, exception); + } + + public InboundMessage(Header header, boolean isPing) { + this.nativeInboundMessage = new NativeInboundMessage(header, isPing); + } + + public Header getHeader() { + return this.nativeInboundMessage.getHeader(); + } + + public int getContentLength() { + return this.nativeInboundMessage.getContentLength(); + } + + public Exception getException() { + return this.nativeInboundMessage.getException(); + } + + public boolean isPing() { + return this.nativeInboundMessage.isPing(); + } + + public boolean isShortCircuit() { + return this.nativeInboundMessage.getException() != null; + } + + public Releasable takeBreakerReleaseControl() { + return this.nativeInboundMessage.takeBreakerReleaseControl(); + } + + public StreamInput openOrGetStreamInput() throws IOException { + return this.nativeInboundMessage.openOrGetStreamInput(); + } + + @Override + public void close() { + this.nativeInboundMessage.close(); + } + + @Override + public String toString() { + return this.nativeInboundMessage.toString(); + } + + @Override + public String getProtocol() { + return this.nativeInboundMessage.getProtocol(); + } + +} diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java index 4c972fdc14fa5..58adc2d3d68a5 100644 --- a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -52,7 +52,6 @@ import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.opensearch.transport.nativeprotocol.NativeOutboundHandler; import java.io.EOFException; @@ -119,7 +118,7 @@ public void messageReceived( long slowLogThresholdMs, TransportMessageListener messageListener ) throws IOException { - NativeInboundMessage inboundMessage = (NativeInboundMessage) message; + InboundMessage inboundMessage = (InboundMessage) message; TransportLogger.logInboundMessage(channel, inboundMessage); if (inboundMessage.isPing()) { keepAlive.receiveKeepAlive(channel); @@ -130,7 +129,7 @@ public void messageReceived( private void handleMessage( TcpChannel channel, - NativeInboundMessage message, + InboundMessage message, long startTime, long slowLogThresholdMs, TransportMessageListener messageListener @@ -202,7 +201,7 @@ private Map> extractHeaders(Map heade private void handleRequest( TcpChannel channel, Header header, - NativeInboundMessage message, + InboundMessage message, TransportMessageListener messageListener ) throws IOException { final String action = header.getActionName(); diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index ffa3168da0b3e..78452a25a58d6 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -777,6 +777,18 @@ protected void serverAcceptedChannel(TcpChannel channel) { */ protected abstract void stopInternal(); + /** + * @deprecated use {@link #inboundMessage(TcpChannel, ProtocolInboundMessage)} + * Handles inbound message that has been decoded. + * + * @param channel the channel the message is from + * @param message the message + */ + @Deprecated(since = "2.14.0", forRemoval = true) + public void inboundMessage(TcpChannel channel, InboundMessage message) { + inboundMessage(channel, (ProtocolInboundMessage) message); + } + /** * Handles inbound message that has been decoded. * diff --git a/server/src/main/java/org/opensearch/transport/TransportLogger.java b/server/src/main/java/org/opensearch/transport/TransportLogger.java index e780f643aafd7..997b3bb5ba18e 100644 --- a/server/src/main/java/org/opensearch/transport/TransportLogger.java +++ b/server/src/main/java/org/opensearch/transport/TransportLogger.java @@ -40,7 +40,6 @@ import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.compress.CompressorRegistry; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; @@ -65,7 +64,7 @@ static void logInboundMessage(TcpChannel channel, BytesReference message) { } } - static void logInboundMessage(TcpChannel channel, NativeInboundMessage message) { + static void logInboundMessage(TcpChannel channel, InboundMessage message) { if (logger.isTraceEnabled()) { try { String logMessage = format(channel, message, "READ"); @@ -137,7 +136,7 @@ private static String format(TcpChannel channel, BytesReference message, String return sb.toString(); } - private static String format(TcpChannel channel, NativeInboundMessage message, String event) throws IOException { + private static String format(TcpChannel channel, InboundMessage message, String event) throws IOException { final StringBuilder sb = new StringBuilder(); sb.append(channel); diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java index 97981aeb6736e..a8a4c0da7ec0f 100644 --- a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java @@ -16,6 +16,7 @@ import org.opensearch.transport.InboundAggregator; import org.opensearch.transport.InboundBytesHandler; import org.opensearch.transport.InboundDecoder; +import org.opensearch.transport.InboundMessage; import org.opensearch.transport.ProtocolInboundMessage; import org.opensearch.transport.StatsTracker; import org.opensearch.transport.TcpChannel; @@ -31,7 +32,7 @@ public class NativeInboundBytesHandler implements InboundBytesHandler { private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - private static final NativeInboundMessage PING_MESSAGE = new NativeInboundMessage(null, true); + private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); private final ArrayDeque pending; private final InboundDecoder decoder; @@ -151,7 +152,7 @@ private void forwardFragments( messageHandler.accept(channel, PING_MESSAGE); } else if (fragment == InboundDecoder.END_CONTENT) { assert aggregator.isAggregating(); - try (NativeInboundMessage aggregated = aggregator.finishAggregation()) { + try (InboundMessage aggregated = aggregator.finishAggregation()) { statsTracker.markMessageReceived(); messageHandler.accept(channel, aggregated); } diff --git a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java index 4ac78366360d7..2dd98a8efe2a3 100644 --- a/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundAggregatorTests.java @@ -42,7 +42,6 @@ import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.Before; import java.io.IOException; @@ -108,7 +107,7 @@ public void testInboundAggregation() throws IOException { } // Signal EOS - NativeInboundMessage aggregated = aggregator.finishAggregation(); + InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(aggregated.isPing()); @@ -139,7 +138,7 @@ public void testInboundUnknownAction() throws IOException { assertEquals(0, content.refCount()); // Signal EOS - NativeInboundMessage aggregated = aggregator.finishAggregation(); + InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertTrue(aggregated.isShortCircuit()); @@ -162,7 +161,7 @@ public void testCircuitBreak() throws IOException { content1.close(); // Signal EOS - NativeInboundMessage aggregated1 = aggregator.finishAggregation(); + InboundMessage aggregated1 = aggregator.finishAggregation(); assertEquals(0, content1.refCount()); assertThat(aggregated1, notNullValue()); @@ -181,7 +180,7 @@ public void testCircuitBreak() throws IOException { content2.close(); // Signal EOS - NativeInboundMessage aggregated2 = aggregator.finishAggregation(); + InboundMessage aggregated2 = aggregator.finishAggregation(); assertEquals(1, content2.refCount()); assertThat(aggregated2, notNullValue()); @@ -200,7 +199,7 @@ public void testCircuitBreak() throws IOException { content3.close(); // Signal EOS - NativeInboundMessage aggregated3 = aggregator.finishAggregation(); + InboundMessage aggregated3 = aggregator.finishAggregation(); assertEquals(1, content3.refCount()); assertThat(aggregated3, notNullValue()); @@ -264,7 +263,7 @@ public void testFinishAggregationWillFinishHeader() throws IOException { content.close(); // Signal EOS - NativeInboundMessage aggregated = aggregator.finishAggregation(); + InboundMessage aggregated = aggregator.finishAggregation(); assertThat(aggregated, notNullValue()); assertFalse(header.needsToReadVariableHeader()); diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index 2553e7740990b..ea656f6651b1e 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -57,7 +57,6 @@ import org.opensearch.test.VersionUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import org.junit.After; import org.junit.Before; @@ -152,7 +151,7 @@ public void testPing() throws Exception { ); requestHandlers.registerHandler(registry); - handler.inboundMessage(channel, new NativeInboundMessage(null, true)); + handler.inboundMessage(channel, new InboundMessage(null, true)); if (channel.isServerChannel()) { BytesReference ping = channel.getMessageCaptor().get(); assertEquals('E', ping.get(0)); @@ -216,11 +215,7 @@ public TestResponse read(StreamInput in) throws IOException { ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} - ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -241,11 +236,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = channel.getMessageCaptor().get(); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( - responseHeader, - ReleasableBytesReference.wrap(responseContent), - () -> {} - ); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -272,7 +263,7 @@ public void testSendsErrorResponseToHandshakeFromCompatibleVersion() throws Exce TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -312,7 +303,7 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); + final InboundMessage requestMessage = unreadableInboundHandshake(remoteVersion, requestHeader); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Map.of(), Map.of()); requestHeader.features = Set.of(); @@ -343,17 +334,13 @@ public void testLogsSlowInboundProcessing() throws Exception { TransportStatus.setRequest(TransportStatus.setHandshake((byte) 0)), remoteVersion ); - final NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(BytesArray.EMPTY), - () -> { - try { - TimeUnit.SECONDS.sleep(1L); - } catch (InterruptedException e) { - throw new AssertionError(e); - } + final InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> { + try { + TimeUnit.SECONDS.sleep(1L); + } catch (InterruptedException e) { + throw new AssertionError(e); } - ); + }); requestHeader.actionName = TransportHandshaker.HANDSHAKE_ACTION_NAME; requestHeader.headers = Tuple.tuple(Collections.emptyMap(), Collections.emptyMap()); requestHeader.features = Set.of(); @@ -425,11 +412,7 @@ public void onResponseSent(long requestId, String action, Exception error) { BytesReference fullRequestBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} - ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -494,11 +477,7 @@ public void onResponseSent(long requestId, String action, Exception error) { // Create the request payload by intentionally stripping 1 byte away BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize - 1); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} - ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -562,11 +541,7 @@ public TestResponse read(StreamInput in) throws IOException { ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} - ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -588,11 +563,7 @@ public TestResponse read(StreamInput in) throws IOException { BytesReference fullResponseBytes = BytesReference.fromByteBuffer((ByteBuffer) buffer.flip()); BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize); Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( - responseHeader, - ReleasableBytesReference.wrap(responseContent), - () -> {} - ); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -656,11 +627,7 @@ public TestResponse read(StreamInput in) throws IOException { ); BytesReference requestContent = fullRequestBytes.slice(headerSize, fullRequestBytes.length() - headerSize); Header requestHeader = new Header(fullRequestBytes.length() - 6, requestId, TransportStatus.setRequest((byte) 0), version); - NativeInboundMessage requestMessage = new NativeInboundMessage( - requestHeader, - ReleasableBytesReference.wrap(requestContent), - () -> {} - ); + InboundMessage requestMessage = new InboundMessage(requestHeader, ReleasableBytesReference.wrap(requestContent), () -> {}); requestHeader.finishParsingHeader(requestMessage.openOrGetStreamInput()); handler.inboundMessage(channel, requestMessage); @@ -677,11 +644,7 @@ public TestResponse read(StreamInput in) throws IOException { // Create the response payload by intentionally stripping 1 byte away BytesReference responseContent = fullResponseBytes.slice(headerSize, fullResponseBytes.length() - headerSize - 1); Header responseHeader = new Header(fullResponseBytes.length() - 6, requestId, responseStatus, version); - NativeInboundMessage responseMessage = new NativeInboundMessage( - responseHeader, - ReleasableBytesReference.wrap(responseContent), - () -> {} - ); + InboundMessage responseMessage = new InboundMessage(responseHeader, ReleasableBytesReference.wrap(responseContent), () -> {}); responseHeader.finishParsingHeader(responseMessage.openOrGetStreamInput()); handler.inboundMessage(channel, responseMessage); @@ -690,8 +653,8 @@ public TestResponse read(StreamInput in) throws IOException { assertThat(exceptionCaptor.get().getMessage(), containsString("Failed to deserialize response from handler")); } - private static NativeInboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { - return new NativeInboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { + private static InboundMessage unreadableInboundHandshake(Version remoteVersion, Header requestHeader) { + return new InboundMessage(requestHeader, ReleasableBytesReference.wrap(BytesArray.EMPTY), () -> {}) { @Override public StreamInput openOrGetStreamInput() { final StreamInput streamInput = new InputStreamStreamInput(new InputStream() { diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index 5a89bf1e0ead3..74457e2b153fd 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -49,7 +49,6 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; import java.util.ArrayList; @@ -84,7 +83,7 @@ public void testPipelineHandlingForNativeProtocol() throws IOException { final List toRelease = new ArrayList<>(); final BiConsumer messageHandler = (c, m) -> { try { - NativeInboundMessage message = (NativeInboundMessage) m; + InboundMessage message = (InboundMessage) m; final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion();