Skip to content

Commit

Permalink
apacheGH-37720: [Java][FlightSQL] Implement stateless prepared statem…
Browse files Browse the repository at this point in the history
…ents

Part fixed caching of statementContext
  • Loading branch information
stevelorddremio committed May 10, 2024
1 parent 2bf9e19 commit 42991f1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@

package org.apache.arrow.adapter.jdbc;

import static java.nio.charset.StandardCharsets.UTF_8;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
Expand All @@ -43,7 +38,6 @@ public class JdbcParameterBinder {
private final ColumnBinder[] binders;
private final int[] parameterIndices;
private int nextRowIndex;
private byte[] bindersAsByteArray;

/**
* Create a new parameter binder.
Expand All @@ -57,8 +51,7 @@ private JdbcParameterBinder(
final PreparedStatement statement,
final VectorSchemaRoot root,
final ColumnBinder[] binders,
int[] parameterIndices,
byte[] bindersAsByteArray) {
int[] parameterIndices) {
Preconditions.checkArgument(
binders.length == parameterIndices.length,
"Number of column binders (%s) must equal number of parameter indices (%s)",
Expand All @@ -68,7 +61,6 @@ private JdbcParameterBinder(
this.binders = binders;
this.parameterIndices = parameterIndices;
this.nextRowIndex = 0;
this.bindersAsByteArray = bindersAsByteArray;
}

/**
Expand Down Expand Up @@ -145,7 +137,7 @@ public Builder bind(int parameterIndex, ColumnBinder binder) {
}

/** Build the binder. */
public JdbcParameterBinder build() throws IOException {
public JdbcParameterBinder build() {
ColumnBinder[] binders = new ColumnBinder[bindings.size()];
int[] parameterIndices = new int[bindings.size()];
int index = 0;
Expand All @@ -154,20 +146,7 @@ public JdbcParameterBinder build() throws IOException {
parameterIndices[index] = entry.getKey();
index++;
}

// Convert parameters to byte array
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
try (ObjectOutputStream outObject = new ObjectOutputStream(outStream)) {
outObject.writeObject(bindings.toString().getBytes(UTF_8));
outObject.flush();
}

// return new JdbcParameterBinder(statement, root, binders, parameterIndices, outStream.toByteArray());
return new JdbcParameterBinder(statement, root, binders, parameterIndices, outStream.toByteArray());
return new JdbcParameterBinder(statement, root, binders, parameterIndices);
}
}

public byte[] getBindersAsByteArray() {
return bindersAsByteArray;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
Expand Down Expand Up @@ -81,7 +80,6 @@
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.JsonStringHashMap;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand All @@ -100,7 +98,7 @@ void afterEach() {
}

@Test
void bindOrder() throws SQLException, IOException {
void bindOrder() throws SQLException {
final Schema schema =
new Schema(
Arrays.asList(
Expand Down Expand Up @@ -161,7 +159,7 @@ void bindOrder() throws SQLException, IOException {
}

@Test
void customBinder() throws SQLException, IOException {
void customBinder() throws SQLException {
final Schema schema =
new Schema(Collections.singletonList(
Field.nullable("ints0", new ArrowType.Int(32, true))));
Expand Down Expand Up @@ -564,7 +562,7 @@ <T, V extends FieldVector> void testSimpleType(ArrowType arrowType, int jdbcType
try (final MockPreparedStatement statement = new MockPreparedStatement();
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
final JdbcParameterBinder binder =
JdbcParameterBinder.builder(statement, root).bindAll().build();
JdbcParameterBinder.builder(statement, root).bindAll().build();
assertThat(binder.next()).isFalse();

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -607,8 +605,6 @@ <T, V extends FieldVector> void testSimpleType(ArrowType arrowType, int jdbcType
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}

// Non-nullable (since some types have a specialized binder)
Expand Down Expand Up @@ -651,8 +647,6 @@ <T, V extends FieldVector> void testSimpleType(ArrowType arrowType, int jdbcType
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}
}

Expand All @@ -666,10 +660,11 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
try (final MockPreparedStatement statement = new MockPreparedStatement();
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
final JdbcParameterBinder binder =
JdbcParameterBinder.builder(statement, root).bindAll().build();
JdbcParameterBinder.builder(statement, root).bindAll().build();
assertThat(binder.next()).isFalse();

@SuppressWarnings("unchecked") final V vector = (V) root.getVector(0);
@SuppressWarnings("unchecked")
final V vector = (V) root.getVector(0);
final ColumnBinder columnBinder = ColumnBinder.forVector(vector);
assertThat(columnBinder.getJdbcType()).isEqualTo(jdbcType);

Expand Down Expand Up @@ -708,8 +703,6 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}

// Non-nullable (since some types have a specialized binder)
Expand Down Expand Up @@ -755,8 +748,6 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}
}

Expand Down Expand Up @@ -816,8 +807,6 @@ <T, V extends FieldVector> void testMapType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1).toString());
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}

// Non-nullable (since some types have a specialized binder)
Expand Down Expand Up @@ -865,8 +854,6 @@ <T, V extends FieldVector> void testMapType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
} catch (IOException e) {
Assertions.fail("Unexpected binding error.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.impl.UnionListWriter;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.Types.MinorType;
Expand Down Expand Up @@ -897,9 +898,6 @@ public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate
} catch (SQLException e) {
ackStream.onError(CallStatus.INTERNAL.withDescription("Failed to execute update: " + e).toRuntimeException());
return;
} catch (IOException e) {
ackStream.onError(CallStatus.INTERNAL.withDescription("Failed to execute update: " + e).toRuntimeException());
return;
}
ackStream.onCompleted();
};
Expand All @@ -923,6 +921,31 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co
while (binder.next()) {
// Do not execute() - will be done in a getStream call
}
final ByteArrayOutputStream out = new ByteArrayOutputStream();
try (
ArrowFileWriter writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
) {
writer.start();
writer.writeBatch();
}
if (out.size() > 0) {
final DoPutPreparedStatementResult doPutPreparedStatementResult =
DoPutPreparedStatementResult.newBuilder()
.setPreparedStatementHandle(ByteString.copyFrom(ByteBuffer.wrap(out.toByteArray())))
.build();

// Update prepared statement cache by storing with new handle and remove old entry.
preparedStatementLoadingCache.put(doPutPreparedStatementResult.getPreparedStatementHandle(),
statementContext);
// TODO: If we invalidate old cached entry here this invalidates the statement, which is not what is needed.
// We need to re-cache the statementContext with a new key.
// preparedStatementLoadingCache.invalidate(command.getPreparedStatementHandle());

try (final ArrowBuf buffer = rootAllocator.buffer(doPutPreparedStatementResult.getSerializedSize())) {
buffer.writeBytes(doPutPreparedStatementResult.toByteArray());
ackStream.onNext(PutResult.metadata(buffer));
}
}
}

} catch (SQLException e) {
Expand All @@ -939,17 +962,6 @@ public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery co
return;
}

if (binder != null && binder.getBindersAsByteArray() != null) {
final byte[] byteArray = binder.getBindersAsByteArray();
final DoPutPreparedStatementResult build =
DoPutPreparedStatementResult.newBuilder()
.setPreparedStatementHandle(ByteString.copyFrom(ByteBuffer.wrap(byteArray))).build();

try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) {
buffer.writeBytes(build.toByteArray());
ackStream.onNext(PutResult.metadata(buffer));
}
}
ackStream.onCompleted();
};
}
Expand Down

0 comments on commit 42991f1

Please sign in to comment.