diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 19cde94a63602..06d3bd33ca6ea 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -297,7 +297,34 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s // ignore unknown fields. } } - + // Protobuf implementations can omit empty fields, such as body; for some message types, like RecordBatch, + // this will fail later as we still expect an empty buffer. In those cases only, fill in an empty buffer here - + // in other cases, like Schema, having an unexpected empty buffer will also cause failures. + // We don't fill in defaults for fields like header, for which there is no reasonable default, or for appMetadata + // or descriptor, which are intended to be empty in some cases. 
+ if (header != null) { + switch (HeaderType.getHeader(header.headerType())) { + case SCHEMA: + // Ignore 0-length buffers in case a Protobuf implementation wrote it out + if (body != null && body.capacity() == 0) { + body.close(); + body = null; + } + break; + case DICTIONARY_BATCH: + case RECORD_BATCH: + // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here + if (body == null) { + body = allocator.getEmpty(); + } + break; + case NONE: + case TENSOR: + default: + // Do nothing + break; + } + } return new ArrowMessage(descriptor, header, appMetadata, body); } catch (Exception ioe) { throw new RuntimeException(ioe); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index bae658230b5d1..daf911d23dfd5 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -17,6 +17,10 @@ package org.apache.arrow.flight; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; @@ -33,12 +37,16 @@ import org.apache.arrow.flight.FlightClient.ClientStreamListener; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType; +import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import 
org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -48,6 +56,8 @@ import com.google.common.base.Charsets; import com.google.protobuf.ByteString; +import io.grpc.MethodDescriptor; + /** * Test the operations of a basic flight service. */ @@ -317,6 +327,106 @@ private void test(BiConsumer consumer) throws Exc } }
+  /**
+   * Serializes an {@link ArrowMessage} with the given marshaller and re-parses the bytes as a Protobuf
+   * {@link Flight.FlightData}, so that a test can edit individual fields before parsing it back.
+   *
+   * @param marshaller the marshaller used to serialize {@code message}
+   * @param message the message to convert
+   * @return the serialized message, parsed as a Protobuf message
+   * @throws IOException if the serialized stream cannot be read or parsed
+   */
+  private Flight.FlightData arrowMessageToProtobuf(
+      MethodDescriptor.Marshaller<ArrowMessage> marshaller, ArrowMessage message) throws IOException {
+    final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try (final InputStream serialized = marshaller.stream(message)) {
+      final byte[] buf = new byte[1024];
+      // Drain the whole stream; read() signals end-of-stream with a negative count
+      int read;
+      while ((read = serialized.read(buf)) >= 0) {
+        baos.write(buf, 0, read);
+      }
+    }
+    return Flight.FlightData.parseFrom(baos.toByteArray());
+  }
+
+  /**
+   * ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields).
+   *
+   * <p>Clears the body of a zero-row record batch message and checks that the parsed message still
+   * exposes a single empty body buffer and converts to a record batch.
+   */
+  @Test
+  public void testProtobufRecordBatchCompatibility() throws Exception {
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+         final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      final VectorUnloader unloader = new VectorUnloader(root);
+      root.setRowCount(0);
+      final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+      try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), null, new IpcOption())) {
+        Assert.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
+        // Should have at least one empty body buffer (there may be multiple for e.g. data and validity)
+        Iterator<ArrowBuf> iterator = message.getBufs().iterator();
+        Assert.assertTrue(iterator.hasNext());
+        while (iterator.hasNext()) {
+          Assert.assertEquals(0, iterator.next().capacity());
+        }
+        // Drop the body field, as a Protobuf implementation may do for an empty value
+        final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message)
+            .toBuilder()
+            .clearDataBody()
+            .build();
+        Assert.assertEquals(0, protobufData.getDataBody().size());
+        final InputStream protobufStream = new ByteArrayInputStream(protobufData.toByteArray());
+        try (final ArrowMessage parsedMessage = marshaller.parse(protobufStream)) {
+          // Should have an empty body buffer
+          Iterator<ArrowBuf> parsedIterator = parsedMessage.getBufs().iterator();
+          Assert.assertTrue(parsedIterator.hasNext());
+          Assert.assertEquals(0, parsedIterator.next().capacity());
+          // Should have only one (the parser synthesizes exactly one);
+          // in the case of empty buffers, this is equivalent
+          Assert.assertFalse(parsedIterator.hasNext());
+          // Should not throw
+          final ArrowRecordBatch rb = parsedMessage.asRecordBatch();
+          Assert.assertEquals(0, rb.computeBodyLength());
+        }
+      }
+    }
+  }
+
+  /**
+   * ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields).
+   *
+   * <p>Sets the body of a schema message to an empty value and checks that the parsed message has no
+   * body buffers and still converts to a schema.
+   */
+  @Test
+  public void testProtobufSchemaCompatibility() throws Exception {
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+      Flight.FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]).toProtocol();
+      try (final ArrowMessage message = new ArrowMessage(descriptor, schema, new IpcOption())) {
+        Assert.assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType());
+        // Should have no body buffers
+        Assert.assertFalse(message.getBufs().iterator().hasNext());
+        final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message)
+            .toBuilder()
+            .setDataBody(ByteString.EMPTY)
+            .build();
+        Assert.assertEquals(0, protobufData.getDataBody().size());
+        final InputStream protobufStream = new ByteArrayInputStream(protobufData.toByteArray());
+        try (final ArrowMessage parsedMessage = marshaller.parse(protobufStream)) {
+          // Should have no body buffers
+          Assert.assertFalse(parsedMessage.getBufs().iterator().hasNext());
+          // Should not throw
+          parsedMessage.asSchema();
+        }
+      }
+    }
+  }
+
 /** * An example FlightProducer for test purposes. */