Skip to content

Commit

Permalink
apacheGH-37553: [Java] Allow FlightInfo#Schema to be nullable for lon…
Browse files Browse the repository at this point in the history
…g-running queries

With apache#36155, implementations of Flight RPC may not return quickly via a newly added pollFlightInfo function. Sometimes, the system implementing this function may not know the output schema for some time--for example, after a lengthy queue time has elapsed, or after planning.

In proto3, fields may not be present, and it's a coding convention to require them. To support upcoming client integration work for pollFlightInfo, the schema field can be made optional so that it's not a requirement to populate the FlightInfo's schema on the first pollFlightInfo request.

We can modify our client code to allow this field to be optional. This is already the case for the Go code.

This changes the Java client code to allow the Schema to be null.  `getSchema` methods now return `Optional<Schema>`, which is a backwards incompatible change.
  • Loading branch information
tdcmeehan committed Sep 6, 2023
1 parent faa7cf6 commit f0197f4
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.arrow.flight.impl.Flight;
Expand All @@ -36,7 +38,6 @@
import org.apache.arrow.vector.validate.MetadataV4UnionChecker;

import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;

/**
Expand Down Expand Up @@ -93,10 +94,11 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoin
*/
public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes,
long records, boolean ordered, IpcOption option) {
Objects.requireNonNull(schema);
Objects.requireNonNull(descriptor);
Objects.requireNonNull(endpoints);
MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
if (schema != null) {
MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
}
this.schema = schema;
this.descriptor = descriptor;
this.endpoints = endpoints;
Expand All @@ -114,8 +116,10 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoin
final ByteBuffer schemaBuf = pbFlightInfo.getSchema().asReadOnlyByteBuffer();
schema = pbFlightInfo.getSchema().size() > 0 ?
MessageSerializer.deserializeSchema(
new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf))))
: new Schema(ImmutableList.of());
new ReadChannel(
Channels.newChannel(
new ByteBufferBackedInputStream(schemaBuf))))
: null;
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -130,8 +134,17 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoin
option = IpcOption.DEFAULT;
}

public Optional<Schema> getSchemaOptional() {
return Optional.ofNullable(schema);
}

/**
* Returns the schema, or null if no schema is present.
* @deprecated Deprecated. Use {@link #getSchemaOptional()} instead.
*/
@Deprecated
public Schema getSchema() {
return schema;
return schema != null ? schema : new Schema(Collections.emptyList());
}

public long getBytes() {
Expand All @@ -158,21 +171,25 @@ public boolean getOrdered() {
* Converts to the protocol buffer representation.
*/
Flight.FlightInfo toProtocol() {
// Encode schema in a Message payload
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option);
} catch (IOException e) {
throw new RuntimeException(e);
Flight.FlightInfo.Builder builder = Flight.FlightInfo.newBuilder()
.addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList()))
.setFlightDescriptor(descriptor.toProtocol())
.setTotalBytes(FlightInfo.this.bytes)
.setTotalRecords(records)
.setOrdered(ordered);
if (schema != null) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
MessageSerializer.serialize(
new WriteChannel(Channels.newChannel(baos)),
schema,
option);
builder.setSchema(ByteString.copyFrom(baos.toByteArray()));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return Flight.FlightInfo.newBuilder()
.addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList()))
.setSchema(ByteString.copyFrom(baos.toByteArray()))
.setFlightDescriptor(descriptor.toProtocol())
.setTotalBytes(FlightInfo.this.bytes)
.setTotalRecords(records)
.setOrdered(ordered)
.build();
return builder.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ default PollInfo pollFlightInfo(CallContext context, FlightDescriptor descriptor
*/
default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) {
FlightInfo info = getFlightInfo(context, descriptor);
return new SchemaResult(info.getSchema());
return new SchemaResult(info
.getSchemaOptional()
.orElseThrow(() ->
CallStatus
.INVALID_ARGUMENT
.withDescription("No schema is present in FlightInfo").toRuntimeException()));
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.util.Objects;

import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.vector.ipc.ReadChannel;
Expand Down Expand Up @@ -52,6 +53,7 @@ public SchemaResult(Schema schema) {
* Create a schema result with specific IPC options for serialization.
*/
public SchemaResult(Schema schema, IpcOption option) {
Objects.requireNonNull(schema);
MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
this.schema = schema;
this.option = option;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -49,8 +50,10 @@
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
Expand Down Expand Up @@ -556,6 +559,7 @@ public FlightInfo getFlightInfo(CallContext context,
FlightDescriptor descriptor) {
try {
Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder()
.setSchema(schemaToByteString(new Schema(Collections.emptyList())))
.setFlightDescriptor(Flight.FlightDescriptor.newBuilder()
.setType(DescriptorType.CMD)
.setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
Expand All @@ -568,6 +572,16 @@ public FlightInfo getFlightInfo(CallContext context,
}
}

private static ByteString schemaToByteString(Schema schema)
{
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, IpcOption.DEFAULT);
return ByteString.copyFrom(baos.toByteArray());
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
public void doAction(CallContext context, Action action,
StreamListener<Result> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,20 @@

package org.apache.arrow.flight;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.fail;

import java.util.Collections;
import java.util.Optional;

import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -122,4 +129,31 @@ public void onCompleted() {

// fail() would have been called if an error happened during doGetCustom(), so this test passed.
}

@Test
public void supportsNullSchemas() throws Exception
{
final FlightProducer producer = new NoOpFlightProducer() {
@Override
public FlightInfo getFlightInfo(CallContext context,
FlightDescriptor descriptor) {
return new FlightInfo(null, descriptor, Collections.emptyList(), 0, 0);
}
};

try (final BufferAllocator a = new RootAllocator();
final FlightServer s = FlightServer.builder(a, forGrpcInsecure(LOCALHOST, 0), producer).build().start()) {

try (FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
FlightInfo flightInfo = client.getInfo(FlightDescriptor.path("test"));
Assertions.assertEquals(Optional.empty(), flightInfo.getSchemaOptional());
Assertions.assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema());

Exception e = Assertions.assertThrows(
FlightRuntimeException.class,
() -> client.getSchema(FlightDescriptor.path("test")));
Assertions.assertEquals("No schema is present in FlightInfo", e.getMessage());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.Optional;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand Down Expand Up @@ -72,7 +73,7 @@ public void testGetFlightInfoV4() throws Exception {
try (final FlightServer server = startServer(optionV4);
final FlightClient client = connect(server)) {
final FlightInfo result = client.getInfo(FlightDescriptor.command(new byte[0]));
assertEquals(schema, result.getSchema());
assertEquals(Optional.of(schema), result.getSchemaOptional());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ protected AvaticaResultSet execute() throws SQLException {
final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery();

if (flightInfo != null) {
schema = flightInfo.getSchema();
schema = flightInfo.getSchemaOptional().orElse(null);
execute(flightInfo);
}
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;

import org.apache.arrow.flight.sql.FlightSqlClient;
Expand Down Expand Up @@ -177,13 +178,15 @@ private static List<List<String>> getNonConformingResultsForGetSqlInfo(
@Test
public void testGetTablesSchema() {
final FlightInfo info = sqlClient.getTables(null, null, null, null, true);
MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA));
MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)));
}

@Test
public void testGetTablesSchemaExcludeSchema() {
final FlightInfo info = sqlClient.getTables(null, null, null, null, false);
MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA));
MatcherAssert.assertThat(
info.getSchemaOptional(),
is(Optional.of(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)));
}

@Test
Expand Down Expand Up @@ -364,7 +367,7 @@ public void testSimplePreparedStatementSchema() throws Exception {
},
() -> {
final FlightInfo info = preparedStatement.execute();
MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE));
MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(SCHEMA_INT_TABLE)));
}
);
}
Expand Down Expand Up @@ -477,7 +480,7 @@ public void testSimplePreparedStatementClosesProperly() {
@Test
public void testGetCatalogsSchema() {
final FlightInfo info = sqlClient.getCatalogs();
MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA));
MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)));
}

@Test
Expand All @@ -497,7 +500,9 @@ public void testGetCatalogsResults() throws Exception {
@Test
public void testGetTableTypesSchema() {
final FlightInfo info = sqlClient.getTableTypes();
MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA));
MatcherAssert.assertThat(
info.getSchemaOptional(),
is(Optional.of(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)));
}

@Test
Expand Down Expand Up @@ -526,7 +531,7 @@ public void testGetTableTypesResult() throws Exception {
@Test
public void testGetSchemasSchema() {
final FlightInfo info = sqlClient.getSchemas(null, null);
MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA));
MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)));
}

@Test
Expand Down Expand Up @@ -584,7 +589,7 @@ public void testGetPrimaryKey() {
@Test
public void testGetSqlInfoSchema() {
final FlightInfo info = sqlClient.getSqlInfo();
MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA));
MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)));
}

@Test
Expand Down Expand Up @@ -848,7 +853,7 @@ public void testGetCommandCrossReference() {
@Test
public void testCreateStatementSchema() throws Exception {
final FlightInfo info = sqlClient.execute("SELECT * FROM intTable");
MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE));
MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(SCHEMA_INT_TABLE)));

// Consume statement to close connection before cache eviction
try (FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public static Schema deserializeSchema(Message schemaMessage) {
public static Schema deserializeSchema(ReadChannel in) throws IOException {
MessageMetadataResult result = readMessage(in);
if (result == null) {
throw new IOException("Unexpected end of input when reading Schema");
return null;
}
if (result.getMessage().headerType() != MessageHeader.Schema) {
throw new IOException("Expected schema but header was " + result.getMessage().headerType());
Expand Down

0 comments on commit f0197f4

Please sign in to comment.