From bcaa5997a43780e07d09d65c5f4abb7d0f3b87b1 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 18 Dec 2020 14:44:58 -0500
Subject: [PATCH] ARROW-10962: [FlightRPC][Java] fill in empty body buffer if
needed
---
.../org/apache/arrow/flight/ArrowMessage.java | 29 ++++++++++++-
.../arrow/flight/TestBasicOperation.java | 41 +++++++++++++++++++
2 files changed, 69 insertions(+), 1 deletion(-)
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java
index 19cde94a63602..06d3bd33ca6ea 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java
@@ -297,7 +297,34 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s
// ignore unknown fields.
}
}
-
+ // Protobuf implementations can omit empty fields, such as body; for some message types, like RecordBatch,
+ // this will fail later as we still expect an empty buffer. In those cases only, fill in an empty buffer here -
+ // in other cases, like Schema, having an unexpected empty buffer will also cause failures.
+ // We don't fill in defaults for fields like header, for which there is no reasonable default, or for appMetadata
+ // or descriptor, which are intended to be empty in some cases.
+ if (header != null) {
+ switch (HeaderType.getHeader(header.headerType())) {
+ case SCHEMA:
+ // Ignore 0-length buffers in case a Protobuf implementation wrote it out
+ if (body != null && body.capacity() == 0) {
+ body.close();
+ body = null;
+ }
+ break;
+ case DICTIONARY_BATCH:
+ case RECORD_BATCH:
+ // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here
+ if (body == null) {
+ body = allocator.getEmpty();
+ }
+ break;
+ case NONE:
+ case TENSOR:
+ default:
+ // Do nothing
+ break;
+ }
+ }
return new ArrowMessage(descriptor, header, appMetadata, body);
} catch (Exception ioe) {
throw new RuntimeException(ioe);
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
index bae658230b5d1..0f6f2d3e023f6 100644
--- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
+++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -17,6 +17,9 @@
package org.apache.arrow.flight;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
@@ -39,6 +42,9 @@
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -48,6 +54,8 @@
import com.google.common.base.Charsets;
import com.google.protobuf.ByteString;
+import io.grpc.MethodDescriptor;
+
/**
* Test the operations of a basic flight service.
*/
@@ -317,6 +325,39 @@ private void test(BiConsumer consumer) throws Exc
}
}
+  /** ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields). */
+  @Test
+  public void testProtobufCompat() throws Exception {
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+         final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      final VectorUnloader unloader = new VectorUnloader(root);
+      root.setRowCount(0); // zero rows, so the record batch is expected to carry an empty body
+      final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+      final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+      try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), null, new IpcOption())) {
+        Assert.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
+        try (final InputStream serialized = marshaller.stream(message)) {
+          final byte[] buf = new byte[1024];
+          while (true) { // drain the marshalled stream into baos until end-of-stream
+            int read = serialized.read(buf);
+            if (read < 0) {
+              break;
+            }
+            baos.write(buf, 0, read);
+          }
+        }
+      }
+      final byte[] serializedMessage = baos.toByteArray();
+      final Flight.FlightData protobufData = Flight.FlightData.parseFrom(serializedMessage);
+      Assert.assertEquals(0, protobufData.getDataBody().size());
+      // Re-serialize through Protobuf (which may drop the empty body field); parsing should not throw
+      final ArrowRecordBatch rb =
+          marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray())).asRecordBatch();
+      Assert.assertEquals(0, rb.computeBodyLength());
+    }
+  }
+
/**
* An example FlightProducer for test purposes.
*/