From bcaa5997a43780e07d09d65c5f4abb7d0f3b87b1 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 18 Dec 2020 14:44:58 -0500 Subject: [PATCH] ARROW-10962: [FlightRPC][Java] fill in empty body buffer if needed --- .../org/apache/arrow/flight/ArrowMessage.java | 29 ++++++++++++- .../arrow/flight/TestBasicOperation.java | 41 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 19cde94a63602..06d3bd33ca6ea 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -297,7 +297,34 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s // ignore unknown fields. } } - + // Protobuf implementations can omit empty fields, such as body; for some message types, like RecordBatch, + // this will fail later as we still expect an empty buffer. In those cases only, fill in an empty buffer here - + // in other cases, like Schema, having an unexpected empty buffer will also cause failures. + // We don't fill in defaults for fields like header, for which there is no reasonable default, or for appMetadata + // or descriptor, which are intended to be empty in some cases. 
+ if (header != null) { + switch (HeaderType.getHeader(header.headerType())) { + case SCHEMA: + // Ignore 0-length buffers in case a Protobuf implementation wrote it out + if (body != null && body.capacity() == 0) { + body.close(); + body = null; + } + break; + case DICTIONARY_BATCH: + case RECORD_BATCH: + // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here + if (body == null) { + body = allocator.getEmpty(); + } + break; + case NONE: + case TENSOR: + default: + // Do nothing + break; + } + } return new ArrowMessage(descriptor, header, appMetadata, body); } catch (Exception ioe) { throw new RuntimeException(ioe); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index bae658230b5d1..0f6f2d3e023f6 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -17,6 +17,9 @@ package org.apache.arrow.flight; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; @@ -39,6 +42,9 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -48,6 +54,8 @@ import com.google.common.base.Charsets; import com.google.protobuf.ByteString; +import io.grpc.MethodDescriptor; + /** * Test the operations of a basic flight service. 
*/
@@ -317,6 +325,39 @@ private void test(BiConsumer consumer) throws Exc
     }
   }
 
+  /** ARROW-10939: accept FlightData messages generated by Protobuf (which can omit empty fields). */
+  @Test
+  public void testProtobufCompat() throws Exception {
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      final VectorUnloader unloader = new VectorUnloader(root);
+      // Serialize a zero-row record batch with Flight's own marshaller
+      root.setRowCount(0);
+      final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+      final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+      try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), null, new IpcOption())) {
+        Assert.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
+        try (final InputStream serialized = marshaller.stream(message)) {
+          final byte[] buf = new byte[1024];
+          int read;
+          while ((read = serialized.read(buf)) >= 0) {
+            baos.write(buf, 0, read);
+          }
+        }
+      }
+      // Re-encode through the Protobuf-generated FlightData class and confirm the body is empty
+      final byte[] serializedMessage = baos.toByteArray();
+      final Flight.FlightData protobufData = Flight.FlightData.parseFrom(serializedMessage);
+      Assert.assertEquals(0, protobufData.getDataBody().size());
+      // Should not throw; the parsed message and batch are closed once checked
+      try (final ArrowMessage parsed = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray()));
+          final ArrowRecordBatch rb = parsed.asRecordBatch()) {
+        Assert.assertEquals(0, rb.computeBodyLength());
+      }
+    }
+  }
+
   /**
    * An example FlightProducer for test purposes.
    */