Paired Review
nbauernfeind committed Oct 17, 2024
1 parent ebc3d46 commit 50dabff
Showing 6 changed files with 95 additions and 70 deletions.
1 change: 0 additions & 1 deletion extensions/barrage/build.gradle
@@ -23,7 +23,6 @@ dependencies {
implementation libs.arrow.vector
implementation libs.arrow.format
implementation project(path: ':extensions-source-support')
-implementation project(path: ':extensions-source-support')

compileOnly project(':util-immutables')
annotationProcessor libs.immutables.value
@@ -87,7 +87,7 @@ public static Schema parseArrowSchema(final BarrageProtoUtil.MessageInfo mi) {
return schema;
}

-public static long[] extractBufferInfo(@NotNull final RecordBatch batch) {
+public static PrimitiveIterator.OfLong extractBufferInfo(@NotNull final RecordBatch batch) {
final long[] bufferInfo = new long[batch.buffersLength()];
for (int bi = 0; bi < batch.buffersLength(); ++bi) {
int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset());
@@ -101,7 +101,7 @@ public static long[] extractBufferInfo(@NotNull final RecordBatch batch) {
}
bufferInfo[bi] = length;
}
-return bufferInfo;
+return Arrays.stream(bufferInfo).iterator();
}
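
A note on the signature change above: the Arrays.stream(bufferInfo).iterator() wrapping now lives inside extractBufferInfo itself, so each call site consumes buffer lengths directly. A minimal standalone sketch of the new consumption pattern follows; the class name and values are illustrative, not part of this commit.

    import java.util.Arrays;
    import java.util.PrimitiveIterator;

    public class BufferInfoSketch {
        // Mirrors the new return style: callers receive an OfLong iterator and
        // no longer wrap a long[] with Arrays.stream(...) themselves.
        static PrimitiveIterator.OfLong extractBufferInfo(final long[] bufferLengths) {
            return Arrays.stream(bufferLengths).iterator();
        }

        public static void main(String[] args) {
            final PrimitiveIterator.OfLong it = extractBufferInfo(new long[] {8, 16, 24});
            while (it.hasNext()) {
                System.out.println("buffer length: " + it.nextLong()); // prints 8, 16, 24
            }
        }
    }
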

@ScriptApi
@@ -113,7 +113,7 @@ public synchronized void setSchema(final ByteBuffer ipcMessage) {
throw new IllegalStateException("Conversion is complete; cannot process additional messages");
}
final BarrageProtoUtil.MessageInfo mi = parseArrowIpcMessage(ipcMessage);
-parseSchema(parseArrowSchema(mi));
+configureWithSchema(parseArrowSchema(mi));
}

@ScriptApi
@@ -168,13 +168,12 @@ public synchronized void onCompleted() throws InterruptedException {
completed = true;
}

-protected void parseSchema(final Schema schema) {
+protected void configureWithSchema(final Schema schema) {
if (resultTable != null) {
throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT, "Schema evolution not supported");
}

final BarrageUtil.ConvertedArrowSchema result = BarrageUtil.convertArrowSchema(schema);
-
resultTable = BarrageTable.make(null, result.tableDef, result.attributes, null);
resultTable.setFlat();

@@ -203,8 +202,7 @@ protected BarrageMessage createBarrageMessage(BarrageProtoUtil.MessageInfo mi, i
new FlatBufferIteratorAdapter<>(batch.nodesLength(),
i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));

-final long[] bufferInfo = extractBufferInfo(batch);
-final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();
+final PrimitiveIterator.OfLong bufferInfoIter = extractBufferInfo(batch);

msg.rowsRemoved = RowSetFactory.empty();
msg.shifted = RowSetShiftData.EMPTY;
@@ -72,7 +72,7 @@ private PythonTableDataService(@NotNull final PyObject pyTableDataService) {
}

/**
-* Get a Deephaven {@link Table} for the supplied name.
+* Get a Deephaven {@link Table} for the supplied {@link TableKey}.
*
* @param tableKey The table key
* @param live Whether the table should update as new data becomes available
@@ -90,11 +90,6 @@ public Table makeTable(@NotNull final TableKeyImpl tableKey, final boolean live)
live ? ExecutionContext.getContext().getUpdateGraph() : null);
}

-private static class SchemaPair {
-BarrageUtil.ConvertedArrowSchema tableSchema;
-BarrageUtil.ConvertedArrowSchema partitionSchema;
-}

/**
* This Backend impl marries the Python TableDataService with the Deephaven TableDataService. By performing the
* object translation here, we can keep the Python TableDataService implementation simple and focused on the Python
@@ -114,16 +109,23 @@ private BackendAccessor(
* @param tableKey the table key
* @return the schemas
*/
-public SchemaPair getTableSchema(
+public BarrageUtil.ConvertedArrowSchema[] getTableSchema(
@NotNull final TableKeyImpl tableKey) {
-final ByteBuffer[] schemas =
-pyTableDataService.call("_table_schema", tableKey.key).getObjectArrayValue(ByteBuffer.class);
-final SchemaPair result = new SchemaPair();
-result.tableSchema = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
-ArrowToTableConverter.parseArrowIpcMessage(schemas[0])));
-result.partitionSchema = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
-ArrowToTableConverter.parseArrowIpcMessage(schemas[1])));
-return result;
+final BarrageUtil.ConvertedArrowSchema[] schemas = new BarrageUtil.ConvertedArrowSchema[2];
+final Consumer<ByteBuffer[]> onRawSchemas = byteBuffers -> {
+if (byteBuffers.length != 2) {
+throw new IllegalArgumentException("Expected two Arrow IPC messages: found " + byteBuffers.length);
+}
+
+for (int ii = 0; ii < 2; ++ii) {
+schemas[ii] = BarrageUtil.convertArrowSchema(ArrowToTableConverter.parseArrowSchema(
+ArrowToTableConverter.parseArrowIpcMessage(byteBuffers[ii])));
+}
+};
+
+pyTableDataService.call("_table_schema", tableKey.key, onRawSchemas);
+
+return schemas;
}
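
The rewritten getTableSchema no longer pulls a return value back from Python; it hands a Consumer<ByteBuffer[]> across the boundary and lets _table_schema push exactly two Arrow IPC payloads (table schema, then partitioning-column schema). A self-contained sketch of that callback pattern, with a plain lambda standing in for the Python call; all names here are illustrative.

    import java.nio.ByteBuffer;
    import java.util.function.Consumer;

    public class SchemaCallbackSketch {
        public static void main(String[] args) {
            // The caller allocates the result slots and hands a callback across
            // the language boundary; the callee fills exactly two payloads.
            final ByteBuffer[] schemas = new ByteBuffer[2];
            final Consumer<ByteBuffer[]> onRawSchemas = byteBuffers -> {
                if (byteBuffers.length != 2) {
                    throw new IllegalArgumentException(
                            "Expected two Arrow IPC messages: found " + byteBuffers.length);
                }
                schemas[0] = byteBuffers[0]; // table schema
                schemas[1] = byteBuffers[1]; // partitioning-column schema
            };
            // Stand-in for pyTableDataService.call("_table_schema", key, onRawSchemas):
            onRawSchemas.accept(new ByteBuffer[] {ByteBuffer.allocate(0), ByteBuffer.allocate(0)});
            System.out.println("schemas received: " + schemas.length);
        }
    }
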

/**
@@ -136,9 +138,7 @@ public void getExistingPartitions(
@NotNull final TableKeyImpl tableKey,
@NotNull final Consumer<TableLocationKeyImpl> listener) {
final BiConsumer<TableLocationKeyImpl, ByteBuffer[]> convertingListener =
-(tableLocationKey, byteBuffers) -> {
-processNewPartition(listener, tableLocationKey, byteBuffers);
-};
+(tableLocationKey, byteBuffers) -> processNewPartition(listener, tableLocationKey, byteBuffers);

pyTableDataService.call("_existing_partitions", tableKey.key, convertingListener);
}
@@ -154,15 +154,11 @@ public SafeCloseable subscribeToNewPartitions(
@NotNull final TableKeyImpl tableKey,
@NotNull final Consumer<TableLocationKeyImpl> listener) {
final BiConsumer<TableLocationKeyImpl, ByteBuffer[]> convertingListener =
-(tableLocationKey, byteBuffers) -> {
-processNewPartition(listener, tableLocationKey, byteBuffers);
-};
+(tableLocationKey, byteBuffers) -> processNewPartition(listener, tableLocationKey, byteBuffers);

final PyObject cancellationCallback = pyTableDataService.call(
"_subscribe_to_new_partitions", tableKey.key, convertingListener);
-return () -> {
-cancellationCallback.call("__call__");
-};
+return () -> cancellationCallback.call("__call__");
}
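
Collapsing the block-bodied lambdas to expression lambdas leaves the behavior unchanged: the returned SafeCloseable still invokes the Python cancellation callable on close. A runnable sketch of that lifecycle, using AutoCloseable as a stand-in for Deephaven's SafeCloseable; names are illustrative.

    public class CancellationSketch {
        interface Cancellable {
            void cancel();
        }

        // Mirrors: return () -> cancellationCallback.call("__call__");
        static AutoCloseable subscribe(final Cancellable pythonCallback) {
            return pythonCallback::cancel;
        }

        public static void main(String[] args) throws Exception {
            try (AutoCloseable subscription = subscribe(() -> System.out.println("cancelled"))) {
                System.out.println("subscribed");
            } // close() invokes the cancellation callback exactly once
        }
    }
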

private void processNewPartition(
@@ -179,7 +175,7 @@ private void processNewPartition(
+ byteBuffers.length);
}

-final Map<String, Comparable<?>> partitionValues = new HashMap<>();
+final Map<String, Comparable<?>> partitionValues = new LinkedHashMap<>();
final Schema schema = ArrowToTableConverter.parseArrowSchema(
ArrowToTableConverter.parseArrowIpcMessage(byteBuffers[0]));
final BarrageUtil.ConvertedArrowSchema arrowSchema = BarrageUtil.convertArrowSchema(schema);
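
The HashMap-to-LinkedHashMap switch makes the iteration order of partitionValues follow insertion order, presumably to keep partition values aligned with the schema's column order; HashMap makes no ordering guarantee. A small sketch, with hypothetical column names:

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class PartitionOrderSketch {
        public static void main(String[] args) {
            // LinkedHashMap iterates in insertion order.
            final Map<String, Comparable<?>> partitionValues = new LinkedHashMap<>();
            partitionValues.put("Date", "2024-10-17"); // hypothetical partition columns
            partitionValues.put("Region", "EMEA");
            System.out.println(partitionValues.keySet()); // [Date, Region], in insertion order
        }
    }
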
@@ -206,12 +202,10 @@ private void processNewPartition(
new FlatBufferIteratorAdapter<>(batch.nodesLength(),
i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));

-final long[] bufferInfo = ArrowToTableConverter.extractBufferInfo(batch);
-final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();
+final PrimitiveIterator.OfLong bufferInfoIter = ArrowToTableConverter.extractBufferInfo(batch);

// populate the partition values
-final int numColumns = schema.fieldsLength();
-for (int ci = 0; ci < numColumns; ++ci) {
+for (int ci = 0; ci < schema.fieldsLength(); ++ci) {
try (final WritableChunk<Values> columnValues = readers.get(ci).readChunk(
fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0)) {

@@ -265,9 +259,7 @@ public SafeCloseable subscribeToPartitionSizeChanges(
final PyObject cancellationCallback = pyTableDataService.call(
"_subscribe_to_partition_size_changes", tableKey.key, tableLocationKey.locationKey, listener);

-return () -> {
-cancellationCallback.call("__call__");
-};
+return () -> cancellationCallback.call("__call__");
}

/**
@@ -281,19 +273,15 @@ public SafeCloseable subscribeToPartitionSizeChanges(
* @return the column values
*/
public List<WritableChunk<Values>> getColumnValues(
-TableKeyImpl tableKey,
-TableLocationKeyImpl tableLocationKey,
-ColumnDefinition<?> columnDefinition,
-long firstRowPosition,
-int minimumSize,
-int maximumSize) {
+@NotNull final TableKeyImpl tableKey,
+@NotNull final TableLocationKeyImpl tableLocationKey,
+@NotNull final ColumnDefinition<?> columnDefinition,
+final long firstRowPosition,
+final int minimumSize,
+final int maximumSize) {

final List<WritableChunk<Values>> resultChunks = new ArrayList<>();
final Consumer<ByteBuffer[]> onMessages = messages -> {
-if (messages.length == 0) {
-return;
-}
-
if (messages.length < 2) {
throw new IllegalArgumentException("Expected at least two Arrow IPC messages: found "
+ messages.length);
@@ -328,8 +316,7 @@ public List<WritableChunk<Values>> getColumnValues(
new FlatBufferIteratorAdapter<>(batch.nodesLength(),
i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i)));

-final long[] bufferInfo = ArrowToTableConverter.extractBufferInfo(batch);
-final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator();
+final PrimitiveIterator.OfLong bufferInfoIter = ArrowToTableConverter.extractBufferInfo(batch);

resultChunks.add(reader.readChunk(
fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0));
Expand All @@ -348,8 +335,6 @@ public List<WritableChunk<Values>> getColumnValues(
}
}

-
-
@Override
protected @NotNull TableLocationProvider makeTableLocationProvider(@NotNull final TableKey tableKey) {
if (!(tableKey instanceof TableKeyImpl)) {
@@ -383,6 +368,8 @@ public boolean equals(final Object other) {

@Override
public int hashCode() {
+// TODO NOCOMMIT @ryan: PyObject's hash is based on pointer location of object which would change if
+// two different Python objects have the same value.
return key.hashCode();
}

@@ -418,10 +405,10 @@ private class TableLocationProviderImpl extends AbstractTableLocationProvider {

private TableLocationProviderImpl(@NotNull final TableKeyImpl tableKey) {
super(tableKey, true);
-final SchemaPair tableAndPartitionColumnSchemas = backend.getTableSchema(tableKey);
+final BarrageUtil.ConvertedArrowSchema[] schemas = backend.getTableSchema(tableKey);

-final TableDefinition tableDef = tableAndPartitionColumnSchemas.tableSchema.tableDef;
-final TableDefinition partitionDef = tableAndPartitionColumnSchemas.partitionSchema.tableDef;
+final TableDefinition tableDef = schemas[0].tableDef;
+final TableDefinition partitionDef = schemas[1].tableDef;
final Map<String, ColumnDefinition<?>> columns = new LinkedHashMap<>(tableDef.numColumns());

for (final ColumnDefinition<?> column : tableDef.getColumns()) {
@@ -462,6 +449,8 @@ public void refresh() {
protected void activateUnderlyingDataSource() {
TableKeyImpl key = (TableKeyImpl) getKey();
final Subscription localSubscription = subscription = new Subscription();
+// TODO NOCOMMIT @ryan: should we let the python table service impl activate so that they may invoke the
+// callback immediately?
localSubscription.cancellationCallback = backend.subscribeToNewPartitions(key, tableLocationKey -> {
if (localSubscription != subscription) {
// we've been cancelled and/or replaced
@@ -531,6 +520,8 @@ public boolean equals(final Object other) {

@Override
public int hashCode() {
+// TODO NOCOMMIT @ryan: PyObject's hash is based on pointer location of object which would change if
+// two different Python objects have the same value.
return locationKey.hashCode();
}

@@ -540,8 +531,6 @@ public int compareTo(@NotNull final TableLocationKey other) {
throw new ClassCastException(String.format("Cannot compare %s to %s", getClass(), other.getClass()));
}
final TableLocationKeyImpl otherTableLocationKey = (TableLocationKeyImpl) other;
-// TODO NOCOMMIT @ryan: What exactly is supposed to happen if partition values are equal but these are
-// different locations?
return PartitionsComparator.INSTANCE.compare(partitions, otherTableLocationKey.partitions);
}

@@ -605,7 +594,6 @@ public void refresh() {

@Override
public @NotNull List<SortColumn> getSortedColumns() {
-// TODO NOCOMMIT @ryan: we may be able to fetch this from the metadata or table definition post conversion
return List.of();
}

@@ -745,8 +733,10 @@ public TableServiceGetRangeAdapter(@NotNull ColumnDefinition<?> columnDefinition
}

@Override
-public void readChunkPage(long firstRowPosition, int minimumSize,
-@NotNull WritableChunk<Values> destination) {
+public void readChunkPage(
+final long firstRowPosition,
+final int minimumSize,
+@NotNull final WritableChunk<Values> destination) {
final TableLocationImpl location = (TableLocationImpl) getTableLocation();
final TableKeyImpl key = (TableKeyImpl) location.getTableKey();

@@ -758,19 +748,23 @@ public void readChunkPage(long firstRowPosition, int minimumSize,

if (numRows < minimumSize) {
throw new TableDataException(String.format("Not enough data returned. Read %d rows but minimum "
+ "expected was %d. Short result from get_column_values(%s, %s, %s, %d, %d).",
+ "expected was %d. Result from get_column_values(%s, %s, %s, %d, %d).",
numRows, minimumSize, key.key, ((TableLocationKeyImpl) location.getKey()).locationKey,
columnDefinition.getName(), firstRowPosition, minimumSize));
}
+if (numRows > destination.capacity()) {
+throw new TableDataException(String.format("Too much data returned. Read %d rows but maximum "
++ "expected was %d. Result from get_column_values(%s, %s, %s, %d, %d).",
+numRows, destination.capacity(), key.key,
+((TableLocationKeyImpl) location.getKey()).locationKey, columnDefinition.getName(),
+firstRowPosition, minimumSize));
+}

int offset = 0;
for (final Chunk<Values> rbChunk : values) {
int length = Math.min(destination.capacity() - offset, rbChunk.size());
destination.copyFromChunk(rbChunk, 0, offset, length);
offset += length;
-if (offset >= destination.capacity()) {
-break;
-}
}
destination.setSize(offset);
}
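
Together with the existing minimum-size check, the new numRows > destination.capacity() guard bounds the result on both sides: after the two checks, numRows always lies in [minimumSize, destination.capacity()], which is why the copy loop can drop its overflow break. A standalone sketch of that bounds contract; names and values are illustrative.

    public class ReadBoundsSketch {
        // After both checks, numRows is guaranteed to lie in
        // [minimumSize, capacity], so the copy loop needs no break.
        static void checkBounds(final long numRows, final int minimumSize, final int capacity) {
            if (numRows < minimumSize) {
                throw new IllegalStateException("Not enough data returned: " + numRows + " < " + minimumSize);
            }
            if (numRows > capacity) {
                throw new IllegalStateException("Too much data returned: " + numRows + " > " + capacity);
            }
        }

        public static void main(String[] args) {
            checkBounds(100, 64, 128); // ok: 64 <= 100 <= 128
            checkBounds(32, 64, 128);  // throws: short read
        }
    }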