Skip to content

Commit

Permalink
Change abstraction point for transport protocol
Browse files Browse the repository at this point in the history
The previous implementation had a transport switch point in
InboundPipeline when the bytes were initially pulled off the wire. There
was no implementation for any other protocol as the `canHandleBytes`
method was hardcoded to return true. I believe this is the wrong point
to switch on the protocol. This change makes NativeInboundBytesHandler
protocol agnostic beyond the header. With this change, a complete
message is parsed from the stream of bytes, with the header schema being
unchanged from what exists today. The protocol switch point will now be
at `InboundHandler::inboundMessage`. The header will indicate what
protocol was used to serialize the the non-header bytes of the message
and then invoke the appropriate handler based on that field.

Signed-off-by: Andrew Ross <andrross@amazon.com>
  • Loading branch information
andrross committed Aug 27, 2024
1 parent 20ebe6e commit d4217c1
Show file tree
Hide file tree
Showing 16 changed files with 370 additions and 373 deletions.
10 changes: 9 additions & 1 deletion server/src/main/java/org/opensearch/transport/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class Header {

private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES";

private final TransportProtocol protocol;
private final int networkMessageSize;
private final Version version;
private final long requestId;
Expand All @@ -64,13 +65,18 @@ public class Header {
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
Set<String> features;

Header(int networkMessageSize, long requestId, byte status, Version version) {
Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) {
this.protocol = protocol;
this.networkMessageSize = networkMessageSize;
this.version = version;
this.requestId = requestId;
this.status = status;
}

TransportProtocol getTransportProtocol() {
return protocol;
}

public int getNetworkMessageSize() {
return networkMessageSize;
}
Expand Down Expand Up @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException {
@Override
public String toString() {
return "Header{"
+ protocol
+ "}{"
+ networkMessageSize
+ "}{"
+ version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -114,7 +113,7 @@ public void aggregate(ReleasableBytesReference content) {
}
}

public NativeInboundMessage finishAggregation() throws IOException {
public ProtocolInboundMessage finishAggregation() throws IOException {
ensureOpen();
final ReleasableBytesReference releasableContent;
if (isFirstContent()) {
Expand All @@ -128,7 +127,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
}

final BreakerControl breakerControl = new BreakerControl(circuitBreaker);
final NativeInboundMessage aggregated = new NativeInboundMessage(currentHeader, releasableContent, breakerControl);
final ProtocolInboundMessage aggregated = new ProtocolInboundMessage(currentHeader, releasableContent, breakerControl);
boolean success = false;
try {
if (aggregated.getHeader().needsToReadVariableHeader()) {
Expand All @@ -143,7 +142,7 @@ public NativeInboundMessage finishAggregation() throws IOException {
if (isShortCircuited()) {
aggregated.close();
success = true;
return new NativeInboundMessage(aggregated.getHeader(), aggregationException);
return new ProtocolInboundMessage(aggregated.getHeader(), aggregationException);
} else {
success = true;
return aggregated;
Expand Down
133 changes: 126 additions & 7 deletions server/src/main/java/org/opensearch/transport/InboundBytesHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,143 @@
package org.opensearch.transport;

import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.CompositeBytesReference;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.function.BiConsumer;

/**
* Interface for handling inbound bytes. Can be implemented by different transport protocols.
* Handler for inbound bytes for the native protocol.
*/
public interface InboundBytesHandler extends Closeable {
class InboundBytesHandler {

private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);

private final ArrayDeque<ReleasableBytesReference> pending;
private final InboundDecoder decoder;
private final InboundAggregator aggregator;
private final StatsTracker statsTracker;
private boolean isClosed = false;

InboundBytesHandler(
ArrayDeque<ReleasableBytesReference> pending,
InboundDecoder decoder,
InboundAggregator aggregator,
StatsTracker statsTracker
) {
this.pending = pending;
this.decoder = decoder;
this.aggregator = aggregator;
this.statsTracker = statsTracker;
}

public void close() {
isClosed = true;
}

public void doHandleBytes(
TcpChannel channel,
ReleasableBytesReference reference,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException;
) throws IOException {
final ArrayList<Object> fragments = fragmentList.get();
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}
}

if (fragments.isEmpty()) {
continueHandling = false;
} else {
try {
forwardFragments(channel, fragments, messageHandler);
} finally {
for (Object fragment : fragments) {
if (fragment instanceof ReleasableBytesReference) {
((ReleasableBytesReference) fragment).close();
}
}
fragments.clear();
}
}
}
}

public boolean canHandleBytes(ReleasableBytesReference reference);
private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
assert reference != null;
if (bytesToRelease < reference.length()) {
pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease));
bytesToRelease -= bytesToRelease;
} else {
bytesToRelease -= reference.length();
}
}
}
}

private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void forwardFragments(
TcpChannel channel,
ArrayList<Object> fragments,
BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler
) throws IOException {
for (Object fragment : fragments) {
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, ProtocolInboundMessage.PING);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (ProtocolInboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;
aggregator.aggregate((ReleasableBytesReference) fragment);
}
}
}

@Override
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) {
// exposed for use in tests
static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException {
try (StreamInput streamInput = bytesReference.streamInput()) {
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte());
streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE);
long requestId = streamInput.readLong();
byte status = streamInput.readByte();
Version remoteVersion = Version.fromId(streamInput.readInt());
Header header = new Header(networkMessageSize, requestId, status, remoteVersion);
Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion);
final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake());
if (invalidVersion != null) {
throw invalidVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;

import java.io.IOException;
import java.util.Map;
Expand All @@ -56,7 +55,7 @@ public class InboundHandler {

private volatile long slowLogThresholdMs = Long.MAX_VALUE;

private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;
private final Map<TransportProtocol, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Expand All @@ -75,7 +74,7 @@ public class InboundHandler {
) {
this.threadPool = threadPool;
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
TransportProtocol.NATIVE,
new NativeMessageHandler(
nodeName,
version,
Expand Down Expand Up @@ -114,9 +113,9 @@ void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws E
}

private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol());
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol());
if (protocolMessageHandler == null) {
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol());
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol());

Check warning on line 118 in server/src/main/java/org/opensearch/transport/InboundHandler.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/transport/InboundHandler.java#L118

Added line #L118 was not covered by tests
}
protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener);
}
Expand Down
30 changes: 4 additions & 26 deletions server/src/main/java/org/opensearch/transport/InboundPipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.LongSupplier;
Expand All @@ -63,8 +61,7 @@ public class InboundPipeline implements Releasable {
private final ArrayDeque<ReleasableBytesReference> pending = new ArrayDeque<>(2);
private boolean isClosed = false;
private final BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler;
private final List<InboundBytesHandler> protocolBytesHandlers;
private InboundBytesHandler currentHandler;
private final InboundBytesHandler bytesHandler;

public InboundPipeline(
Version version,
Expand Down Expand Up @@ -95,17 +92,14 @@ public InboundPipeline(
this.statsTracker = statsTracker;
this.decoder = decoder;
this.aggregator = aggregator;
this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker));
this.bytesHandler = new InboundBytesHandler(pending, decoder, aggregator, statsTracker);
this.messageHandler = messageHandler;
}

@Override
public void close() {
isClosed = true;
if (currentHandler != null) {
currentHandler.close();
currentHandler = null;
}
bytesHandler.close();
Releasables.closeWhileHandlingException(decoder, aggregator);
Releasables.closeWhileHandlingException(pending);
pending.clear();
Expand All @@ -127,22 +121,6 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference
channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
statsTracker.markBytesRead(reference.length());
pending.add(reference.retain());

// If we don't have a current handler, we should try to find one based on the protocol of the incoming bytes.
if (currentHandler == null) {
for (InboundBytesHandler handler : protocolBytesHandlers) {
if (handler.canHandleBytes(reference)) {
currentHandler = handler;
break;
}
}
}

// If we have a current handler determined based on protocol, we should continue to use it for the fragmented bytes.
if (currentHandler != null) {
currentHandler.doHandleBytes(channel, reference, messageHandler);
} else {
throw new IllegalStateException("No bytes handler found for the incoming transport protocol");
}
bytesHandler.doHandleBytes(channel, reference, messageHandler);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeOutboundHandler;

import java.io.EOFException;
Expand Down Expand Up @@ -119,18 +118,17 @@ public void messageReceived(
long slowLogThresholdMs,
TransportMessageListener messageListener
) throws IOException {
NativeInboundMessage inboundMessage = (NativeInboundMessage) message;
TransportLogger.logInboundMessage(channel, inboundMessage);
if (inboundMessage.isPing()) {
TransportLogger.logInboundMessage(channel, message);
if (message.isPing()) {
keepAlive.receiveKeepAlive(channel);
} else {
handleMessage(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener);
handleMessage(channel, message, startTime, slowLogThresholdMs, messageListener);
}
}

private void handleMessage(
TcpChannel channel,
NativeInboundMessage message,
ProtocolInboundMessage message,
long startTime,
long slowLogThresholdMs,
TransportMessageListener messageListener
Expand Down Expand Up @@ -202,7 +200,7 @@ private Map<String, Collection<String>> extractHeaders(Map<String, String> heade
private <T extends TransportRequest> void handleRequest(
TcpChannel channel,
Header header,
NativeInboundMessage message,
ProtocolInboundMessage message,
TransportMessageListener messageListener
) throws IOException {
final String action = header.getActionName();
Expand Down
Loading

0 comments on commit d4217c1

Please sign in to comment.