diff --git a/extensions/barrage/build.gradle b/extensions/barrage/build.gradle
index b996971ef63..f57cdcadbeb 100644
--- a/extensions/barrage/build.gradle
+++ b/extensions/barrage/build.gradle
@@ -23,7 +23,6 @@ dependencies {
     implementation libs.arrow.vector
     implementation libs.arrow.format
     implementation project(path: ':extensions-source-support')
-    implementation project(path: ':extensions-source-support')
 
     compileOnly project(':util-immutables')
     annotationProcessor libs.immutables.value
diff --git a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java
index 804d59468f9..2c8388ad9d1 100644
--- a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java
+++ b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java
@@ -87,7 +87,7 @@ public static Schema parseArrowSchema(final BarrageProtoUtil.MessageInfo mi) {
         return schema;
     }
 
-    public static long[] extractBufferInfo(@NotNull final RecordBatch batch) {
+    public static PrimitiveIterator.OfLong extractBufferInfo(@NotNull final RecordBatch batch) {
         final long[] bufferInfo = new long[batch.buffersLength()];
         for (int bi = 0; bi < batch.buffersLength(); ++bi) {
             int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset());
@@ -101,7 +101,7 @@ public static long[] extractBufferInfo(@NotNull final RecordBatch batch) {
             }
             bufferInfo[bi] = length;
         }
-        return bufferInfo;
+        return Arrays.stream(bufferInfo).iterator();
     }
 
     @ScriptApi
@@ -113,7 +113,7 @@ public synchronized void setSchema(final ByteBuffer ipcMessage) {
             throw new IllegalStateException("Conversion is complete; cannot process additional messages");
         }
         final BarrageProtoUtil.MessageInfo mi = parseArrowIpcMessage(ipcMessage);
-        parseSchema(parseArrowSchema(mi));
+        configureWithSchema(parseArrowSchema(mi));
     }
 
     @ScriptApi
@@ -168,13 +168,12 @@ public synchronized void onCompleted() throws InterruptedException {
         completed = true;
     }
 
-    protected void parseSchema(final Schema schema) {
+    protected void configureWithSchema(final Schema schema) {
         if (resultTable != null) {
             throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT, "Schema evolution not supported");
         }
 
         final BarrageUtil.ConvertedArrowSchema result = BarrageUtil.convertArrowSchema(schema);
-
         resultTable = BarrageTable.make(null, result.tableDef, result.attributes, null);
         resultTable.setFlat();
 
@@ -203,8 +202,7 @@ protected BarrageMessage createBarrageMessage(BarrageProtoUtil.MessageInfo mi, i
                 new FlatBufferIteratorAdapter<>(batch.nodesLength(),
                         i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));
 
-        final long[] bufferInfo = extractBufferInfo(batch);
-        final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();
+        final PrimitiveIterator.OfLong bufferInfoIter = extractBufferInfo(batch);
 
         msg.rowsRemoved = RowSetFactory.empty();
         msg.shifted = RowSetShiftData.EMPTY;
diff --git a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java
index d7a14c92f2b..9ab3f45d97a 100644
--- a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java
+++ b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java
@@ -72,7 +72,7 @@ private PythonTableDataService(@NotNull final PyObject pyTableDataService) {
     }
 
     /**
-     * Get a Deephaven {@link Table} for the supplied name.
+     * Get a Deephaven {@link Table} for the supplied {@link TableKey}.
      *
      * @param tableKey The table key
      * @param live Whether the table should update as new data becomes available
@@ -90,11 +90,6 @@ public Table makeTable(@NotNull final TableKeyImpl tableKey, final boolean live)
                 live ? ExecutionContext.getContext().getUpdateGraph() : null);
     }
 
-    private static class SchemaPair {
-        BarrageUtil.ConvertedArrowSchema tableSchema;
-        BarrageUtil.ConvertedArrowSchema partitionSchema;
-    }
-
    /**
     * This Backend impl marries the Python TableDataService with the Deephaven TableDataService. By performing the
     * object translation here, we can keep the Python TableDataService implementation simple and focused on the Python
@@ -114,16 +109,23 @@ private BackendAccessor(
         * @param tableKey the table key
         * @return the schemas
         */
-        public SchemaPair getTableSchema(
+        public BarrageUtil.ConvertedArrowSchema[] getTableSchema(
                @NotNull final TableKeyImpl tableKey) {
-            final ByteBuffer[] schemas =
-                    pyTableDataService.call("_table_schema", tableKey.key).getObjectArrayValue(ByteBuffer.class);
-            final SchemaPair result = new SchemaPair();
-            result.tableSchema = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
-                    ArrowToTableConverter.parseArrowIpcMessage(schemas[0])));
-            result.partitionSchema = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
-                    ArrowToTableConverter.parseArrowIpcMessage(schemas[1])));
-            return result;
+            final BarrageUtil.ConvertedArrowSchema[] schemas = new BarrageUtil.ConvertedArrowSchema[2];
+            final Consumer<ByteBuffer[]> onRawSchemas = byteBuffers -> {
+                if (byteBuffers.length != 2) {
+                    throw new IllegalArgumentException("Expected two Arrow IPC messages: found " + byteBuffers.length);
+                }
+
+                for (int ii = 0; ii < 2; ++ii) {
+                    schemas[ii] = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
+                            ArrowToTableConverter.parseArrowIpcMessage(byteBuffers[ii])));
+                }
+            };
+
+            pyTableDataService.call("_table_schema", tableKey.key, onRawSchemas);
+
+            return schemas;
         }
 
        /**
@@ -136,9 +138,7 @@ public void getExistingPartitions(
                @NotNull final TableKeyImpl tableKey,
                @NotNull final Consumer<TableLocationKeyImpl> listener) {
            final BiConsumer<TableLocationKeyImpl, ByteBuffer[]> convertingListener =
-                    (tableLocationKey, byteBuffers) -> {
-                        processNewPartition(listener, tableLocationKey, byteBuffers);
-                    };
+                    (tableLocationKey, byteBuffers) -> processNewPartition(listener, tableLocationKey, byteBuffers);
 
            pyTableDataService.call("_existing_partitions", tableKey.key, convertingListener);
        }
 
        /**
@@ -154,15 +154,11 @@ public SafeCloseable subscribeToNewPartitions(
                @NotNull final TableKeyImpl tableKey,
                @NotNull final Consumer<TableLocationKeyImpl> listener) {
            final BiConsumer<TableLocationKeyImpl, ByteBuffer[]> convertingListener =
-                    (tableLocationKey, byteBuffers) -> {
-                        processNewPartition(listener, tableLocationKey, byteBuffers);
-                    };
+                    (tableLocationKey, byteBuffers) -> processNewPartition(listener, tableLocationKey, byteBuffers);
 
            final PyObject cancellationCallback = pyTableDataService.call(
                    "_subscribe_to_new_partitions", tableKey.key, convertingListener);
-            return () -> {
-                cancellationCallback.call("__call__");
-            };
+            return () -> cancellationCallback.call("__call__");
        }
 
        private void processNewPartition(
@@ -179,7 +175,7 @@ private void processNewPartition(
                                + byteBuffers.length);
            }
 
-            final Map<String, Comparable<?>> partitionValues = new HashMap<>();
+            final Map<String, Comparable<?>> partitionValues = new LinkedHashMap<>();
            final Schema schema = ArrowToTableConverter.parseArrowSchema(
                    ArrowToTableConverter.parseArrowIpcMessage(byteBuffers[0]));
            final BarrageUtil.ConvertedArrowSchema arrowSchema = BarrageUtil.convertArrowSchema(schema);
@@ -206,12 +202,10 @@ private void processNewPartition(
                            new FlatBufferIteratorAdapter<>(batch.nodesLength(),
                                    i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));
 
-                    final long[] bufferInfo = ArrowToTableConverter.extractBufferInfo(batch);
-                    final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();
+                    final PrimitiveIterator.OfLong bufferInfoIter = ArrowToTableConverter.extractBufferInfo(batch);
 
                    // populate the partition values
-                    final int numColumns = schema.fieldsLength();
-                    for (int ci = 0; ci < numColumns; ++ci) {
+                    for (int ci = 0; ci < schema.fieldsLength(); ++ci) {
                        try (final WritableChunk<Values> columnValues = readers.get(ci).readChunk(
                                fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0)) {
 
@@ -265,9 +259,7 @@ public SafeCloseable subscribeToPartitionSizeChanges(
            final PyObject cancellationCallback = pyTableDataService.call(
                    "_subscribe_to_partition_size_changes", tableKey.key, tableLocationKey.locationKey, listener);
 
-            return () -> {
-                cancellationCallback.call("__call__");
-            };
+            return () -> cancellationCallback.call("__call__");
        }
 
        /**
@@ -281,19 +273,15 @@ public SafeCloseable subscribeToPartitionSizeChanges(
         * @return the column values
         */
        public List<WritableChunk<Values>> getColumnValues(
-                TableKeyImpl tableKey,
-                TableLocationKeyImpl tableLocationKey,
-                ColumnDefinition<?> columnDefinition,
-                long firstRowPosition,
-                int minimumSize,
-                int maximumSize) {
+                @NotNull final TableKeyImpl tableKey,
+                @NotNull final TableLocationKeyImpl tableLocationKey,
+                @NotNull final ColumnDefinition<?> columnDefinition,
+                final long firstRowPosition,
+                final int minimumSize,
+                final int maximumSize) {
            final List<WritableChunk<Values>> resultChunks = new ArrayList<>();
            final Consumer<ByteBuffer[]> onMessages = messages -> {
-                if (messages.length == 0) {
-                    return;
-                }
-
                if (messages.length < 2) {
                    throw new IllegalArgumentException("Expected at least two Arrow IPC messages: found "
                            + messages.length);
@@ -328,8 +316,7 @@ public List<WritableChunk<Values>> getColumnValues(
                            new FlatBufferIteratorAdapter<>(batch.nodesLength(),
                                    i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));
 
-                    final long[] bufferInfo = ArrowToTableConverter.extractBufferInfo(batch);
-                    final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();
+                    final PrimitiveIterator.OfLong bufferInfoIter = ArrowToTableConverter.extractBufferInfo(batch);
 
                    resultChunks.add(reader.readChunk(
                            fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0));
@@ -348,8 +335,6 @@ public List<WritableChunk<Values>> getColumnValues(
        }
    }
 
-
-
    @Override
    protected @NotNull TableLocationProvider makeTableLocationProvider(@NotNull final TableKey tableKey) {
        if (!(tableKey instanceof TableKeyImpl)) {
@@ -383,6 +368,8 @@ public boolean equals(final Object other) {
 
        @Override
        public int hashCode() {
+            // TODO NOCOMMIT @ryan: PyObject's hash is based on the pointer location of the object, so two different
+            // Python objects with the same value would hash differently.
            return key.hashCode();
        }
 
@@ -418,10 +405,10 @@ private class TableLocationProviderImpl extends AbstractTableLocationProvider {
        private TableLocationProviderImpl(@NotNull final TableKeyImpl tableKey) {
            super(tableKey, true);
-            final SchemaPair tableAndPartitionColumnSchemas = backend.getTableSchema(tableKey);
+            final BarrageUtil.ConvertedArrowSchema[] schemas = backend.getTableSchema(tableKey);
 
-            final TableDefinition tableDef = tableAndPartitionColumnSchemas.tableSchema.tableDef;
-            final TableDefinition partitionDef = tableAndPartitionColumnSchemas.partitionSchema.tableDef;
+            final TableDefinition tableDef = schemas[0].tableDef;
+            final TableDefinition partitionDef = schemas[1].tableDef;
            final Map<String, ColumnDefinition<?>> columns = new LinkedHashMap<>(tableDef.numColumns());
 
            for (final ColumnDefinition<?> column : tableDef.getColumns()) {
@@ -462,6 +449,8 @@ public void refresh() {
        protected void activateUnderlyingDataSource() {
            TableKeyImpl key = (TableKeyImpl) getKey();
            final Subscription localSubscription = subscription = new Subscription();
+            // TODO NOCOMMIT @ryan: should we let the python table service impl activate so that it may invoke the
+            // callback immediately?
            localSubscription.cancellationCallback = backend.subscribeToNewPartitions(key, tableLocationKey -> {
                if (localSubscription != subscription) {
                    // we've been cancelled and/or replaced
@@ -531,6 +520,8 @@ public boolean equals(final Object other) {
 
        @Override
        public int hashCode() {
+            // TODO NOCOMMIT @ryan: PyObject's hash is based on the pointer location of the object, so two different
+            // Python objects with the same value would hash differently.
            return locationKey.hashCode();
        }
 
@@ -540,8 +531,6 @@ public int compareTo(@NotNull final TableLocationKey other) {
                throw new ClassCastException(String.format("Cannot compare %s to %s", getClass(), other.getClass()));
            }
            final TableLocationKeyImpl otherTableLocationKey = (TableLocationKeyImpl) other;
-            // TODO NOCOMMIT @ryan: What exactly is supposed to happen if partition values are equal but these are
-            // different locations?
            return PartitionsComparator.INSTANCE.compare(partitions, otherTableLocationKey.partitions);
        }
 
@@ -605,7 +594,6 @@ public void refresh() {
 
        @Override
        public @NotNull List<SortColumn> getSortedColumns() {
-            // TODO NOCOMMIT @ryan: we may be able to fetch this from the metadata or table definition post conversion
            return List.of();
        }
 
@@ -745,8 +733,10 @@ public TableServiceGetRangeAdapter(@NotNull ColumnDefinition<?> columnDefinition
        }
 
        @Override
-        public void readChunkPage(long firstRowPosition, int minimumSize,
-                @NotNull WritableChunk<Values> destination) {
+        public void readChunkPage(
+                final long firstRowPosition,
+                final int minimumSize,
+                @NotNull final WritableChunk<Values> destination) {
            final TableLocationImpl location = (TableLocationImpl) getTableLocation();
            final TableKeyImpl key = (TableKeyImpl) location.getTableKey();
 
@@ -758,19 +748,23 @@ public void readChunkPage(long firstRowPosition, int minimumSize,
 
            if (numRows < minimumSize) {
                throw new TableDataException(String.format("Not enough data returned. Read %d rows but minimum "
-                        + "expected was %d. Short result from get_column_values(%s, %s, %s, %d, %d).",
+                        + "expected was %d. Result from get_column_values(%s, %s, %s, %d, %d).",
                        numRows, minimumSize, key.key,
                        ((TableLocationKeyImpl) location.getKey()).locationKey, columnDefinition.getName(),
                        firstRowPosition, minimumSize));
            }
+            if (numRows > destination.capacity()) {
+                throw new TableDataException(String.format("Too much data returned. Read %d rows but maximum "
+                        + "expected was %d. Result from get_column_values(%s, %s, %s, %d, %d).",
+                        numRows, destination.capacity(), key.key,
+                        ((TableLocationKeyImpl) location.getKey()).locationKey, columnDefinition.getName(),
+                        firstRowPosition, minimumSize));
+            }
 
            int offset = 0;
            for (final Chunk<Values> rbChunk : values) {
                int length = Math.min(destination.capacity() - offset, rbChunk.size());
                destination.copyFromChunk(rbChunk, 0, offset, length);
                offset += length;
-                if (offset >= destination.capacity()) {
-                    break;
-                }
            }
            destination.setSize(offset);
        }
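The accessor above now receives the schemas through a `Consumer<ByteBuffer[]>` rather than as a return value, and it rejects anything other than exactly two Arrow IPC messages (table schema first, then partition-column schema). A minimal sketch of the payload the Python side must hand to that callback, using only the `jpy`/`pyarrow` calls already present in this change; `serialize_schemas` is an illustrative helper name, not part of the module:

```python
from typing import Optional

import jpy
import pyarrow as pa


def serialize_schemas(pt_schema: pa.Schema, pc_schema: Optional[pa.Schema]) -> jpy.JType:
    """Builds the two-message payload BackendAccessor.getTableSchema expects:
    index 0 is the table schema, index 1 the partition-column schema. A missing
    partition schema is replaced by an empty schema rather than omitted, so the
    Java side always sees exactly two messages."""
    pc_schema = pc_schema if pc_schema is not None else pa.schema([])
    return jpy.array("java.nio.ByteBuffer", [
        jpy.byte_buffer(pt_schema.serialize()),  # message 0: table schema
        jpy.byte_buffer(pc_schema.serialize()),  # message 1: partition schema
    ])
```

Pushing the buffers through a callback (instead of returning them) lets the Java side validate the array length eagerly and keeps the JPy call surface uniform with the other `_existing_partitions`/`_column_values` callback-style hooks below.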
diff --git a/py/server/deephaven/experimental/partitioned_table_service.py b/py/server/deephaven/experimental/partitioned_table_service.py
index 7556ce9f1b6..f18c1f26356 100644
--- a/py/server/deephaven/experimental/partitioned_table_service.py
+++ b/py/server/deephaven/experimental/partitioned_table_service.py
@@ -53,7 +53,7 @@ class PartitionedTableServiceBackend(ABC):
     def table_schema(self, table_key: TableKey) -> Tuple[pa.Schema, Optional[pa.Schema]]:
         """ Returns the table schema and optionally the schema for the partition columns for the table with the given
         table key.
-        The table schema is not required to include the partition columns defined in the partition schema. THe
+        The table schema is not required to include the partition columns defined in the partition schema. The
         partition columns are limited to primitive types and strings.
 
         Args:
@@ -74,6 +74,8 @@ def existing_partitions(self, table_key: TableKey,
         The table should have a single row for the particular partition location key provided in the 1st argument,
         with the values for the partition columns in the row.
 
+        TODO JF: This is invoked for tables created when make_table's `live` is False.
+
         Args:
             table_key (TableKey): the table key
             callback (Callable[[PartitionedTableLocationKey, Optional[pa.Table]], None]): the callback function
@@ -90,6 +92,11 @@ def subscribe_to_new_partitions(self, table_key: TableKey,
         have a single row for the particular partition location key provided in the 1st argument, with the values for
         the partition columns in the row.
 
+        TODO JF: This is invoked for tables created when make_table's `live` is True.
+        TODO: add comment if test_make_live_table_observe_subscription_cancellations demonstrates that the subscription
+            needs to callback for any existing partitions, too (or if existing_partitions will also be invoked when
+            live == True)
+
         The return value is a function that can be called to unsubscribe from the new partitions.
 
         Args:
@@ -104,6 +111,8 @@ def partition_size(self, table_key: TableKey, table_location_key: PartitionedTab
         """ Provides a callback for the backend service to pass the size of the partition with the given table key
         and partition location key. The callback should be called with the size of the partition in number of rows.
 
+        TODO JF: This is invoked for tables created when make_table's `live` is False.
+
         Args:
             table_key (TableKey): the table key
             table_location_key (PartitionedTableLocationKey): the partition location key
@@ -118,6 +127,10 @@ def subscribe_to_partition_size_changes(self, table_key: TableKey, table_locatio
         table key and partition location key. The callback should be called with the size of the partition in number
         of rows.
 
+        TODO JF: This is invoked for tables created when make_table's `live` is True.
+            This callback cannot be invoked until after this method has returned.
+            This callback must be invoked with the initial size of the partition.
+
         The return value is a function that can be called to unsubscribe from the partition size changes.
 
         Args:
@@ -190,12 +203,13 @@ def make_table(self, table_key: TableKey, *, live: bool) -> Table:
         except Exception as e:
             raise DHError(e, message=f"failed to make a table for the key {table_key.key}") from e
 
-    def _table_schema(self, table_key: TableKey) -> jpy.JType:
+    def _table_schema(self, table_key: TableKey, callback: jpy.JType) -> None:
         """ Returns the table schema and the partition schema for the table with the given table key as two serialized
         byte buffers.
 
         Args:
             table_key (TableKey): the table key
+            TODO JF: make good doc ;P
 
         Returns:
             jpy.JType: an array of two serialized byte buffers
@@ -206,7 +220,7 @@ def _table_schema(self, table_key: TableKey) -> jpy.JType:
         pc_schema = pc_schema if pc_schema is not None else pa.schema([])
         j_pt_schema_bb = jpy.byte_buffer(pt_schema.serialize())
         j_pc_schema_bb = jpy.byte_buffer(pc_schema.serialize())
-        return jpy.array("java.nio.ByteBuffer", [j_pt_schema_bb, j_pc_schema_bb])
+        callback.accept(jpy.array("java.nio.ByteBuffer", [j_pt_schema_bb, j_pc_schema_bb]))
 
     def _existing_partitions(self, table_key: TableKey, callback: jpy.JType) -> None:
         """ Provides the existing partitions for the table with the given table key to the table service in the engine.
@@ -302,6 +316,9 @@ def _column_values(self, table_key: TableKey, table_location_key: PartitionedTab
             partition column values
         """
         pt_table = self._backend.column_values(table_key, table_location_key, col, offset, min_rows, max_rows)
+        if len(pt_table) < min_rows or len(pt_table) > max_rows:
+            raise ValueError("The number of rows in the pyarrow table for column values must be in the range of "
+                             f"{min_rows} to {max_rows}")
         bb_list = [jpy.byte_buffer(rb.serialize()) for rb in pt_table.to_batches()]
         bb_list.insert(0, jpy.byte_buffer(pt_table.schema.serialize()))
         callback.accept(jpy.array("java.nio.ByteBuffer", bb_list))
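For orientation, here is a minimal in-memory backend satisfying the ABC above. The class name, the single static partition, and the `PartitionedTableLocationKey("part-0")` constructor argument are assumptions made purely for illustration; only the method names and signatures come from the module itself:

```python
from typing import Callable, Optional, Tuple

import pyarrow as pa

from deephaven.experimental.partitioned_table_service import (
    PartitionedTableServiceBackend, PartitionedTableLocationKey, TableKey)


class InMemoryBackend(PartitionedTableServiceBackend):
    """Hypothetical static backend: one partition, no partition columns, no live updates."""

    def __init__(self, data: pa.Table):
        self._data = data
        self._location_key = PartitionedTableLocationKey("part-0")  # assumed constructor

    def table_schema(self, table_key: TableKey) -> Tuple[pa.Schema, Optional[pa.Schema]]:
        return self._data.schema, None  # no partition columns

    def existing_partitions(self, table_key, callback) -> None:
        callback(self._location_key, None)  # no partition-column values to report

    def subscribe_to_new_partitions(self, table_key, callback) -> Callable[[], None]:
        return lambda: None  # static data: no new partitions will ever arrive

    def partition_size(self, table_key, table_location_key, callback) -> None:
        callback(self._data.num_rows)

    def subscribe_to_partition_size_changes(self, table_key, table_location_key,
                                            callback) -> Callable[[], None]:
        callback(self._data.num_rows)  # report the initial size, per the docstring above
        return lambda: None

    def column_values(self, table_key, table_location_key, col, offset,
                      min_rows, max_rows) -> pa.Table:
        # honor the [min_rows, max_rows] row-count contract now enforced by _column_values
        return self._data.select([col]).slice(offset, max_rows)
```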
diff --git a/py/server/tests/test_partitioned_table_service.py b/py/server/tests/test_partitioned_table_service.py
index ddb8763f29f..f8f7bf47f75 100644
--- a/py/server/tests/test_partitioned_table_service.py
+++ b/py/server/tests/test_partitioned_table_service.py
@@ -81,11 +81,14 @@ def subscribe_to_new_partitions(self, table_key: TableKey, callback) -> Callable
         if table_key.key != "test":
             return lambda: None
 
+        # TODO for test: count the number of opened subscriptions
+
         exec_ctx = get_exec_ctx()
         th = threading.Thread(target=self._th_new_partitions, args=(table_key, exec_ctx, callback))
         th.start()
 
         def _cancellation_callback():
+            # TODO for test: count the number of cancellations
             self._sub_new_partition_cancelled = True
 
         return _cancellation_callback
@@ -121,6 +124,7 @@ def subscribe_to_partition_size_changes(self, table_key: TableKey,
         th = threading.Thread(target=self._th_partition_size_changes, args=(table_key, table_location_key, callback))
         th.start()
 
+        # TODO count the number of total subscriptions and the number of total cancellations
         def _cancellation_callback():
             self._partitions_size_subscriptions[table_location_key] = False
 
@@ -206,6 +210,19 @@ def test_make_live_table_with_partition_schema_ops(self):
         # t doesn't have the partitioning columns
         self.assertEqual(t.columns, self.test_table.columns)
 
+    def test_make_live_table_observe_subscription_cancellations(self):
+        # coalesce the PartitionAwareSourceTable under a liveness scope
+        # count the number of new partition subscriptions
+        # count the number of partition size subscriptions
+        # close the liveness scope
+        # assert subscriptions are all closed
+        pass
+
+    def test_make_live_table_ensure_initial_partitions_exist(self):
+        # disable new partition subscriptions
+        # coalesce the PartitionAwareSourceTable
+        # ensure that all existing partitions were added to the table
+        pass
 
 if __name__ == '__main__':
     unittest.main()
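The two stubbed tests above only outline their steps as comments. A hedged sketch of how the cancellation test might be fleshed out, assuming the test backend grows `subscriptions_opened`/`subscriptions_cancelled` counters (per the TODOs) and that the fixture exposes the service wrapper as `self.data_service` — all of those names are hypothetical:

```python
from deephaven.liveness_scope import liveness_scope


def test_make_live_table_observe_subscription_cancellations(self):
    # assumes TestBackend counts opened/cancelled subscriptions, per the TODOs above
    with liveness_scope():
        t = self.data_service.make_table(TableKey("test"), live=True)
        t.coalesce()  # coalescing opens the new-partition and partition-size subscriptions
        self.assertGreater(self.backend.subscriptions_opened, 0)
    # dropping the last reference when the scope closes must cancel every subscription
    self.assertEqual(self.backend.subscriptions_opened, self.backend.subscriptions_cancelled)
```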
diff --git a/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java b/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java
index bec7137a705..56f225ea3c7 100644
--- a/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java
+++ b/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java
@@ -216,7 +216,7 @@ public void onNext(final InputStream request) {
             }
 
             if (mi.header.headerType() == MessageHeader.Schema) {
-                parseSchema(parseArrowSchema(mi));
+                configureWithSchema(parseArrowSchema(mi));
                 return;
             }
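For completeness, the wire framing that `getColumnValues` now insists on (at least two IPC messages, schema first) is exactly what `_column_values` produces. Restated as a standalone sketch — `frame_column_values` is an illustrative name, not part of the module:

```python
import jpy
import pyarrow as pa


def frame_column_values(pt_table: pa.Table) -> jpy.JType:
    # message 0: the serialized schema; messages 1..n: one per record batch
    bb_list = [jpy.byte_buffer(rb.serialize()) for rb in pt_table.to_batches()]
    bb_list.insert(0, jpy.byte_buffer(pt_table.schema.serialize()))
    return jpy.array("java.nio.ByteBuffer", bb_list)
```

On the Java side, each record-batch message is then decoded via `extractBufferInfo`, which after this change hands the buffer offsets and lengths back as a `PrimitiveIterator.OfLong` directly, instead of every call site materializing a `long[]` and wrapping it with `Arrays.stream(...).iterator()` itself.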