diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index 888c7293ea2c2..d871f89465c83 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -23,8 +23,10 @@ import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.arrow.flight.impl.Flight; @@ -36,7 +38,6 @@ import org.apache.arrow.vector.validate.MetadataV4UnionChecker; import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; -import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; /** @@ -93,10 +94,11 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, long records, boolean ordered, IpcOption option) { - Objects.requireNonNull(schema); Objects.requireNonNull(descriptor); Objects.requireNonNull(endpoints); - MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + if (schema != null) { + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + } this.schema = schema; this.descriptor = descriptor; this.endpoints = endpoints; @@ -114,8 +116,10 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List 0 ? MessageSerializer.deserializeSchema( - new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf)))) - : new Schema(ImmutableList.of()); + new ReadChannel( + Channels.newChannel( + new ByteBufferBackedInputStream(schemaBuf)))) + : null; } catch (IOException e) { throw new RuntimeException(e); } @@ -130,8 +134,17 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List getSchemaOptional() { + return Optional.ofNullable(schema); + } + + /** + * Returns the schema, or an empty schema if no schema is present. + * @deprecated Deprecated. Use {@link #getSchemaOptional()} instead. + */ + @Deprecated public Schema getSchema() { - return schema; + return schema != null ? schema : new Schema(Collections.emptyList()); } public long getBytes() { @@ -158,21 +171,25 @@ public boolean getOrdered() { * Converts to the protocol buffer representation. */ Flight.FlightInfo toProtocol() { - // Encode schema in a Message payload - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - try { - MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option); - } catch (IOException e) { - throw new RuntimeException(e); + Flight.FlightInfo.Builder builder = Flight.FlightInfo.newBuilder() + .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList())) + .setFlightDescriptor(descriptor.toProtocol()) + .setTotalBytes(FlightInfo.this.bytes) + .setTotalRecords(records) + .setOrdered(ordered); + if (schema != null) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize( + new WriteChannel(Channels.newChannel(baos)), + schema, + option); + builder.setSchema(ByteString.copyFrom(baos.toByteArray())); + } catch (IOException e) { + throw new RuntimeException(e); + } } - return Flight.FlightInfo.newBuilder() - .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList())) - .setSchema(ByteString.copyFrom(baos.toByteArray())) - .setFlightDescriptor(descriptor.toProtocol()) - .setTotalBytes(FlightInfo.this.bytes) - .setTotalRecords(records) - .setOrdered(ordered) - .build(); + return builder.build(); } /** diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java index f2ae3db0b50d7..cdc29ae5de436 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -77,7 +77,12 @@ default PollInfo pollFlightInfo(CallContext context, FlightDescriptor descriptor */ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) { FlightInfo info = getFlightInfo(context, descriptor); - return new SchemaResult(info.getSchema()); + return new SchemaResult(info + .getSchemaOptional() + .orElseThrow(() -> + CallStatus + .INVALID_ARGUMENT + .withDescription("No schema is present in FlightInfo").toRuntimeException())); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java index 8a5e7d9a434dc..8becb85b8d32c 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channels; +import java.util.Objects; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.vector.ipc.ReadChannel; @@ -52,6 +53,7 @@ public SchemaResult(Schema schema) { * Create a schema result with specific IPC options for serialization. */ public SchemaResult(Schema schema, IpcOption option) { + Objects.requireNonNull(schema); MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); this.schema = schema; this.option = option; diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 260ea4a0e3fed..238221f051a7e 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -27,6 +27,7 @@ import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; +import java.nio.channels.Channels; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -49,8 +50,10 @@ import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -556,6 +559,7 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { try { Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder() + .setSchema(schemaToByteString(new Schema(Collections.emptyList()))) .setFlightDescriptor(Flight.FlightDescriptor.newBuilder() .setType(DescriptorType.CMD) .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8))) @@ -568,6 +572,16 @@ public FlightInfo getFlightInfo(CallContext context, } } + private static ByteString schemaToByteString(Schema schema) + { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, IpcOption.DEFAULT); + return ByteString.copyFrom(baos.toByteArray()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @Override public void doAction(CallContext context, Action action, StreamListener listener) { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java index fb47a84164b88..691048fb03ed3 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -17,13 +17,20 @@ package org.apache.arrow.flight; +import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST; +import static org.apache.arrow.flight.Location.forGrpcInsecure; import static org.junit.jupiter.api.Assertions.fail; +import java.util.Collections; +import java.util.Optional; + import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -122,4 +129,29 @@ public void onCompleted() { // fail() would have been called if an error happened during doGetCustom(), so this test passed. } + + @Test + public void supportsNullSchemas() throws Exception + { + final FlightProducer producer = new NoOpFlightProducer() { + @Override + public FlightInfo getFlightInfo(CallContext context, + FlightDescriptor descriptor) { + return new FlightInfo(null, descriptor, Collections.emptyList(), 0, 0); + } + }; + + try (final FlightServer s = + FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer).build().start(); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + FlightInfo flightInfo = client.getInfo(FlightDescriptor.path("test")); + Assertions.assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); + Assertions.assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); + + Exception e = Assertions.assertThrows( + FlightRuntimeException.class, + () -> client.getSchema(FlightDescriptor.path("test"))); + Assertions.assertEquals("No schema is present in FlightInfo", e.getMessage()); + } + } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java index d6efa4ff80058..0d3f7d4ff843f 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java @@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collections; +import java.util.Optional; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -72,7 +73,7 @@ public void testGetFlightInfoV4() throws Exception { try (final FlightServer server = startServer(optionV4); final FlightClient client = connect(server)) { final FlightInfo result = client.getInfo(FlightDescriptor.command(new byte[0])); - assertEquals(schema, result.getSchema()); + assertEquals(Optional.of(schema), result.getSchemaOptional()); } } diff --git a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 4c01cb6e5813c..9a53f9fcafdd2 100644 --- a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -123,7 +123,7 @@ protected AvaticaResultSet execute() throws SQLException { final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); if (flightInfo != null) { - schema = flightInfo.getSchema(); + schema = flightInfo.getSchemaOptional().orElse(null); execute(flightInfo); } return this; diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index 4b73f3c35f477..6da915a8ffb14 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.stream.IntStream; import org.apache.arrow.flight.sql.FlightSqlClient; @@ -177,13 +178,15 @@ private static List> getNonConformingResultsForGetSqlInfo( @Test public void testGetTablesSchema() { final FlightInfo info = sqlClient.getTables(null, null, null, null, true); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA))); } @Test public void testGetTablesSchemaExcludeSchema() { final FlightInfo info = sqlClient.getTables(null, null, null, null, false); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); + MatcherAssert.assertThat( + info.getSchemaOptional(), + is(Optional.of(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA))); } @Test @@ -364,7 +367,7 @@ public void testSimplePreparedStatementSchema() throws Exception { }, () -> { final FlightInfo info = preparedStatement.execute(); - MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(SCHEMA_INT_TABLE))); } ); } @@ -477,7 +480,7 @@ public void testSimplePreparedStatementClosesProperly() { @Test public void testGetCatalogsSchema() { final FlightInfo info = sqlClient.getCatalogs(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA))); } @Test @@ -497,7 +500,9 @@ public void testGetCatalogsResults() throws Exception { @Test public void testGetTableTypesSchema() { final FlightInfo info = sqlClient.getTableTypes(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); + MatcherAssert.assertThat( + info.getSchemaOptional(), + is(Optional.of(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA))); } @Test @@ -526,7 +531,7 @@ public void testGetTableTypesResult() throws Exception { @Test public void testGetSchemasSchema() { final FlightInfo info = sqlClient.getSchemas(null, null); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA))); } @Test @@ -584,7 +589,7 @@ public void testGetPrimaryKey() { @Test public void testGetSqlInfoSchema() { final FlightInfo info = sqlClient.getSqlInfo(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA))); } @Test @@ -848,7 +853,7 @@ public void testGetCommandCrossReference() { @Test public void testCreateStatementSchema() throws Exception { final FlightInfo info = sqlClient.execute("SELECT * FROM intTable"); - MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + MatcherAssert.assertThat(info.getSchemaOptional(), is(Optional.of(SCHEMA_INT_TABLE))); // Consume statement to close connection before cache eviction try (FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) {