From 46979fa4fc05fbf24fafbcad3af0c2b48ff6d5c8 Mon Sep 17 00:00:00 2001 From: Timothy Meehan Date: Thu, 7 Sep 2023 10:14:31 -0400 Subject: [PATCH] GH-37553: [Java] Allow FlightInfo#Schema to be nullable for long-running queries (#37528) With #36155, implementations of Flight RPC may not return quickly via a newly added pollFlightInfo function. Sometimes, the system implementing this function may not know the output schema for some time--for example, after a lengthy queue time has elapsed, or after planning. In proto3, fields may not be present, and requiring them is only a coding convention [1]. To support upcoming client integration work for pollFlightInfo, the schema field can be made optional so that it's not a requirement to populate the FlightInfo's schema on the first pollFlightInfo request. We can modify our client code to allow this field to be optional. This is already the case for the Go code. This changes the Java client code to allow the Schema to be null. A new `getSchemaOptional` method returns `Optional<Schema>`, which is a backwards compatible change. The existing method is deprecated, but will still return an empty schema if the schema is not present on the wire (as it did before). ### Rationale for this change With #36155, implementations of Flight RPC may not return quickly via a newly added pollFlightInfo function. Sometimes, the system implementing this function may not know the output schema for some time--for example, after a lengthy queue time has elapsed, or after planning. In proto3, fields may not be present, and requiring them is only a coding convention [1]. To support upcoming client integration work for pollFlightInfo, the schema field can be made optional so that it's not a requirement to populate the FlightInfo's schema on the first pollFlightInfo request. CC: `@ lidavidm` ### What changes are included in this PR? This changes the Java client code to allow the Schema to be null. 
`getSchema` is now deprecated and a new `getSchemaOptional` returns `Optional`, which is a backwards compatible change. ### Are these changes tested? Existing tests ensure serialization and deserialization continue to work. ### Are there any user-facing changes? The `getSchema` methods are now deprecated in favor of `getSchemaOptional`. * Closes: #37553 Authored-by: Tim Meehan Signed-off-by: David Li --- .../org/apache/arrow/flight/FlightInfo.java | 57 ++++++++++++------- .../apache/arrow/flight/FlightProducer.java | 7 ++- .../org/apache/arrow/flight/SchemaResult.java | 2 + .../arrow/flight/TestBasicOperation.java | 14 +++++ .../arrow/flight/TestFlightService.java | 32 +++++++++++ .../arrow/flight/TestMetadataVersion.java | 3 +- .../ArrowFlightJdbcFlightStreamResultSet.java | 2 +- .../apache/arrow/flight/TestFlightSql.java | 21 ++++--- 8 files changed, 107 insertions(+), 31 deletions(-) diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index 888c7293ea2c2..d871f89465c83 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -23,8 +23,10 @@ import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.arrow.flight.impl.Flight; @@ -36,7 +38,6 @@ import org.apache.arrow.vector.validate.MetadataV4UnionChecker; import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; -import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; /** @@ -93,10 +94,11 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, long records, boolean ordered, IpcOption option) { - 
Objects.requireNonNull(schema); Objects.requireNonNull(descriptor); Objects.requireNonNull(endpoints); - MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + if (schema != null) { + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + } this.schema = schema; this.descriptor = descriptor; this.endpoints = endpoints; @@ -114,8 +116,10 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List 0 ? MessageSerializer.deserializeSchema( - new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf)))) - : new Schema(ImmutableList.of()); + new ReadChannel( + Channels.newChannel( + new ByteBufferBackedInputStream(schemaBuf)))) + : null; } catch (IOException e) { throw new RuntimeException(e); } @@ -130,8 +134,17 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List getSchemaOptional() { + return Optional.ofNullable(schema); + } + + /** + * Returns the schema, or an empty schema if no schema is present. + * @deprecated Deprecated. Use {@link #getSchemaOptional()} instead. + */ + @Deprecated public Schema getSchema() { - return schema; + return schema != null ? schema : new Schema(Collections.emptyList()); } public long getBytes() { @@ -158,21 +171,25 @@ public boolean getOrdered() { * Converts to the protocol buffer representation. 
*/ Flight.FlightInfo toProtocol() { - // Encode schema in a Message payload - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - try { - MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option); - } catch (IOException e) { - throw new RuntimeException(e); + Flight.FlightInfo.Builder builder = Flight.FlightInfo.newBuilder() + .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList())) + .setFlightDescriptor(descriptor.toProtocol()) + .setTotalBytes(FlightInfo.this.bytes) + .setTotalRecords(records) + .setOrdered(ordered); + if (schema != null) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize( + new WriteChannel(Channels.newChannel(baos)), + schema, + option); + builder.setSchema(ByteString.copyFrom(baos.toByteArray())); + } catch (IOException e) { + throw new RuntimeException(e); + } } - return Flight.FlightInfo.newBuilder() - .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList())) - .setSchema(ByteString.copyFrom(baos.toByteArray())) - .setFlightDescriptor(descriptor.toProtocol()) - .setTotalBytes(FlightInfo.this.bytes) - .setTotalRecords(records) - .setOrdered(ordered) - .build(); + return builder.build(); } /** diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java index f2ae3db0b50d7..cdc29ae5de436 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -77,7 +77,12 @@ default PollInfo pollFlightInfo(CallContext context, FlightDescriptor descriptor */ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) { FlightInfo info = getFlightInfo(context, descriptor); - return new SchemaResult(info.getSchema()); + return new SchemaResult(info 
+ .getSchemaOptional() + .orElseThrow(() -> + CallStatus + .INVALID_ARGUMENT + .withDescription("No schema is present in FlightInfo").toRuntimeException())); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java index 8a5e7d9a434dc..8becb85b8d32c 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channels; +import java.util.Objects; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.vector.ipc.ReadChannel; @@ -52,6 +53,7 @@ public SchemaResult(Schema schema) { * Create a schema result with specific IPC options for serialization. */ public SchemaResult(Schema schema, IpcOption option) { + Objects.requireNonNull(schema); MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); this.schema = schema; this.option = option; diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 260ea4a0e3fed..238221f051a7e 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -27,6 +27,7 @@ import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; +import java.nio.channels.Channels; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -49,8 +50,10 @@ import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; import 
org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -556,6 +559,7 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { try { Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder() + .setSchema(schemaToByteString(new Schema(Collections.emptyList()))) .setFlightDescriptor(Flight.FlightDescriptor.newBuilder() .setType(DescriptorType.CMD) .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8))) @@ -568,6 +572,16 @@ public FlightInfo getFlightInfo(CallContext context, } } + private static ByteString schemaToByteString(Schema schema) + { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, IpcOption.DEFAULT); + return ByteString.copyFrom(baos.toByteArray()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @Override public void doAction(CallContext context, Action action, StreamListener listener) { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java index fb47a84164b88..691048fb03ed3 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -17,13 +17,20 @@ package org.apache.arrow.flight; +import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST; +import static org.apache.arrow.flight.Location.forGrpcInsecure; import static org.junit.jupiter.api.Assertions.fail; +import java.util.Collections; +import java.util.Optional; + import org.apache.arrow.flight.impl.Flight; import 
org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -122,4 +129,29 @@ public void onCompleted() { // fail() would have been called if an error happened during doGetCustom(), so this test passed. } + + @Test + public void supportsNullSchemas() throws Exception + { + final FlightProducer producer = new NoOpFlightProducer() { + @Override + public FlightInfo getFlightInfo(CallContext context, + FlightDescriptor descriptor) { + return new FlightInfo(null, descriptor, Collections.emptyList(), 0, 0); + } + }; + + try (final FlightServer s = + FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer).build().start(); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + FlightInfo flightInfo = client.getInfo(FlightDescriptor.path("test")); + Assertions.assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); + Assertions.assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); + + Exception e = Assertions.assertThrows( + FlightRuntimeException.class, + () -> client.getSchema(FlightDescriptor.path("test"))); + Assertions.assertEquals("No schema is present in FlightInfo", e.getMessage()); + } + } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java index d6efa4ff80058..0d3f7d4ff843f 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java @@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import 
java.util.Collections; +import java.util.Optional; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -72,7 +73,7 @@ public void testGetFlightInfoV4() throws Exception { try (final FlightServer server = startServer(optionV4); final FlightClient client = connect(server)) { final FlightInfo result = client.getInfo(FlightDescriptor.command(new byte[0])); - assertEquals(schema, result.getSchema()); + assertEquals(Optional.of(schema), result.getSchemaOptional()); } } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 4c01cb6e5813c..9a53f9fcafdd2 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -123,7 +123,7 @@ protected AvaticaResultSet execute() throws SQLException { final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); if (flightInfo != null) { - schema = flightInfo.getSchema(); + schema = flightInfo.getSchemaOptional().orElse(null); execute(flightInfo); } return this; diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index 4b73f3c35f477..6da915a8ffb14 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.stream.IntStream; import org.apache.arrow.flight.sql.FlightSqlClient; @@ -177,13 +178,15 @@ 
private static List> getNonConformingResultsForGetSqlInfo( @Test public void testGetTablesSchema() { final FlightInfo info = sqlClient.getTables(null, null, null, null, true); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA))); } @Test public void testGetTablesSchemaExcludeSchema() { final FlightInfo info = sqlClient.getTables(null, null, null, null, false); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); + MatcherAssert.assertThat( + info.getSchemaOptional(), + is(Optional.of(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA))); } @Test @@ -364,7 +367,7 @@ public void testSimplePreparedStatementSchema() throws Exception { }, () -> { final FlightInfo info = preparedStatement.execute(); - MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(SCHEMA_INT_TABLE))); } ); } @@ -477,7 +480,7 @@ public void testSimplePreparedStatementClosesProperly() { @Test public void testGetCatalogsSchema() { final FlightInfo info = sqlClient.getCatalogs(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA))); } @Test @@ -497,7 +500,9 @@ public void testGetCatalogsResults() throws Exception { @Test public void testGetTableTypesSchema() { final FlightInfo info = sqlClient.getTableTypes(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); + MatcherAssert.assertThat( + info.getSchemaOptional(), + is(Optional.of(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA))); } @Test @@ -526,7 +531,7 @@ public void testGetTableTypesResult() throws Exception { @Test public void testGetSchemasSchema() { 
final FlightInfo info = sqlClient.getSchemas(null, null); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA))); } @Test @@ -584,7 +589,7 @@ public void testGetPrimaryKey() { @Test public void testGetSqlInfoSchema() { final FlightInfo info = sqlClient.getSqlInfo(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA))); } @Test @@ -848,7 +853,7 @@ public void testGetCommandCrossReference() { @Test public void testCreateStatementSchema() throws Exception { final FlightInfo info = sqlClient.execute("SELECT * FROM intTable"); - MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(SCHEMA_INT_TABLE))); // Consume statement to close connection before cache eviction try (FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) {