Skip to content

Commit 4927302

Browse files
committed
Fix issues at flight transport layer; Add middleware for header management
Signed-off-by: Rishabh Maurya <rishabhmaurya05@gmail.com>
1 parent 6c0eb1f commit 4927302

File tree

19 files changed

+570
-310
lines changed

19 files changed

+570
-310
lines changed

plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/FlightTransportIT.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ public void setUp() throws Exception {
7777

7878
@LockFeatureFlag(STREAM_TRANSPORT)
7979
public void testArrowFlightProducer() throws Exception {
80-
final SearchRequest searchRequest = new SearchRequest("index");
8180
ActionFuture<SearchResponse> future = client().prepareStreamSearch("index").execute();
8281
SearchResponse resp = future.actionGet();
8382
assertNotNull(resp);

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/stream/ArrowStreamInput.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,6 @@ public ArrowStreamInput(VectorSchemaRoot root, NamedWriteableRegistry registry)
4848

4949
for (FieldVector vector : root.getFieldVectors()) {
5050
String fieldName = vector.getField().getName();
51-
// skip the header field
52-
if (fieldName.equals("_meta")) {
53-
continue;
54-
}
5551
String parentPath = extractParentPath(fieldName);
5652
vectorsByPath.computeIfAbsent(parentPath, k -> new ArrayList<>()).add(vector);
5753
}

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/stream/ArrowStreamOutput.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.opensearch.core.common.io.stream.Writeable;
3030

3131
import java.io.IOException;
32-
import java.nio.ByteBuffer;
3332
import java.nio.charset.StandardCharsets;
3433
import java.util.ArrayList;
3534
import java.util.HashMap;
@@ -300,16 +299,8 @@ public void writeMap(@Nullable Map<String, Object> map) throws IOException {
300299
structVector.setValueCount(row + 1);
301300
}
302301

303-
public VectorSchemaRoot getUnifiedRoot(ByteBuffer headers) {
302+
public VectorSchemaRoot getUnifiedRoot() {
304303
List<FieldVector> allFields = new ArrayList<>();
305-
// TODO: we need a better mechanism to serialize headers; maybe make use of Tcp headers
306-
if (headers != null) {
307-
Field field = new Field("_meta", new FieldType(true, new ArrowType.Binary(), null, null), null);
308-
VarBinaryVector fieldVector = new VarBinaryVector(field, allocator);
309-
fieldVector.setSafe(0, headers.array());
310-
fieldVector.setValueCount(1);
311-
allFields.add(fieldVector);
312-
}
313304
for (VectorSchemaRoot root : roots.values()) {
314305
allFields.addAll(root.getFieldVectors());
315306
}

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/ArrowFlightProducer.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import org.apache.arrow.flight.CallStatus;
1212
import org.apache.arrow.flight.FlightRuntimeException;
13+
import org.apache.arrow.flight.FlightServerMiddleware;
1314
import org.apache.arrow.flight.NoOpFlightProducer;
1415
import org.apache.arrow.flight.Ticket;
1516
import org.apache.arrow.memory.BufferAllocator;
@@ -28,8 +29,9 @@ public class ArrowFlightProducer extends NoOpFlightProducer {
2829
private final BufferAllocator allocator;
2930
private final InboundPipeline pipeline;
3031
private static final Logger logger = LogManager.getLogger(ArrowFlightProducer.class);
32+
private final FlightServerMiddleware.Key<ServerHeaderMiddleware> middlewareKey;
3133

32-
public ArrowFlightProducer(FlightTransport flightTransport, BufferAllocator allocator) {
34+
public ArrowFlightProducer(FlightTransport flightTransport, BufferAllocator allocator, FlightServerMiddleware.Key<ServerHeaderMiddleware> middlewareKey) {
3335
final ThreadPool threadPool = flightTransport.getThreadPool();
3436
final Transport.RequestHandlers requestHandlers = flightTransport.getRequestHandlers();
3537
this.pipeline = new InboundPipeline(
@@ -41,19 +43,22 @@ public ArrowFlightProducer(FlightTransport flightTransport, BufferAllocator allo
4143
requestHandlers::getHandler,
4244
flightTransport::inboundMessage
4345
);
46+
this.middlewareKey = middlewareKey;
4447
this.allocator = allocator;
4548
}
4649

4750
@Override
4851
public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
4952
try {
50-
FlightServerChannel channel = new FlightServerChannel(listener, allocator);
53+
FlightServerChannel channel = new FlightServerChannel(listener, allocator, context, context.getMiddleware(middlewareKey));
54+
listener.setUseZeroCopy(true);
5155
BytesArray buf = new BytesArray(ticket.getBytes());
5256
// nothing changes in inbound logic, so reusing native transport inbound pipeline
5357
try (ReleasableBytesReference reference = ReleasableBytesReference.wrap(buf)) {
5458
pipeline.handleBytes(channel, reference);
5559
}
5660
} catch (FlightRuntimeException ex) {
61+
logger.error("Unexpected error during stream processing", ex);
5762
listener.error(ex);
5863
throw ex;
5964
} catch (Exception ex) {
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.arrow.flight.transport;
10+
11+
import org.apache.arrow.flight.CallHeaders;
12+
import org.apache.arrow.flight.CallInfo;
13+
import org.apache.arrow.flight.CallStatus;
14+
import org.apache.arrow.flight.FlightClientMiddleware;
15+
import org.opensearch.Version;
16+
import org.opensearch.core.common.bytes.BytesArray;
17+
import org.opensearch.core.common.bytes.BytesReference;
18+
import org.opensearch.transport.Header;
19+
import org.opensearch.transport.InboundDecoder;
20+
import org.opensearch.transport.TransportException;
21+
import org.opensearch.transport.TransportStatus;
22+
23+
import java.io.IOException;
24+
import java.util.Base64;
25+
26+
/**
27+
* Client middleware for handling Arrow Flight headers. It assumes that one request is sent at a time to {@link FlightClientChannel}
28+
*/
29+
public class ClientHeaderMiddleware implements FlightClientMiddleware {
30+
private final HeaderContext context;
31+
private final Version version;
32+
33+
ClientHeaderMiddleware(HeaderContext context, Version version) {
34+
this.context = context;
35+
this.version = version;
36+
}
37+
38+
@Override
39+
public void onHeadersReceived(CallHeaders incomingHeaders) {
40+
String encodedHeader = incomingHeaders.get("raw-header");
41+
byte[] headerBuffer = Base64.getDecoder().decode(encodedHeader);
42+
BytesReference headerRef = new BytesArray(headerBuffer);
43+
Header header;
44+
try {
45+
header = InboundDecoder.readHeader(version, headerRef.length(), headerRef);
46+
} catch (IOException e) {
47+
throw new TransportException(e);
48+
}
49+
if (!Version.CURRENT.isCompatible(header.getVersion())) {
50+
throw new TransportException("Incompatible version: " + header.getVersion());
51+
}
52+
if (TransportStatus.isError(header.getStatus())) {
53+
throw new TransportException("Received error response");
54+
}
55+
context.setHeader(header);
56+
}
57+
58+
@Override
59+
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {}
60+
61+
@Override
62+
public void onCallCompleted(CallStatus status) {}
63+
64+
public static class Factory implements FlightClientMiddleware.Factory {
65+
private final Version version;
66+
private final HeaderContext context;
67+
68+
Factory(HeaderContext context, Version version) {
69+
this.version = version;
70+
this.context = context;
71+
}
72+
73+
@Override
74+
public ClientHeaderMiddleware onCallStarted(CallInfo callInfo) {
75+
return new ClientHeaderMiddleware(context, version);
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)