Skip to content

GH-153: Allow arbitrary parameter binding on JDBC driver #462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public ExecuteResult execute(
}

new AvaticaParameterBinder(
preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator())
preparedStatement, typedValues, ((ArrowFlightConnection) connection).getBufferAllocator())
.bind(typedValues);

if (statementHandle.signature == null) {
Expand Down Expand Up @@ -144,11 +144,13 @@ public ExecuteBatchResult executeBatch(
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}

final AvaticaParameterBinder binder =
new AvaticaParameterBinder(
preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator());
for (int i = 0; i < parameterValuesList.size(); i++) {
binder.bind(parameterValuesList.get(i), i);
if (parameterValuesList.size() > 0) {
final AvaticaParameterBinder binder =
new AvaticaParameterBinder(
preparedStatement, parameterValuesList.get(0), ((ArrowFlightConnection) connection).getBufferAllocator());
for (int i = 0; i < parameterValuesList.size(); i++) {
binder.bind(parameterValuesList.get(i), i);
}
}

// Update query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
*/
package org.apache.arrow.driver.jdbc.utils;

import java.sql.Types;
import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement;
import org.apache.arrow.driver.jdbc.converter.impl.BinaryAvaticaParameterConverter;
import org.apache.arrow.driver.jdbc.converter.impl.BoolAvaticaParameterConverter;
Expand All @@ -42,9 +45,18 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.remote.TypedValue;

import static org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE;
import static org.apache.calcite.avatica.ColumnMetaData.Rep.*;

/**
* Convert Avatica PreparedStatement parameters from a list of TypedValue to Arrow and bind them to
* the VectorSchemaRoot representing the PreparedStatement parameters.
Expand All @@ -59,13 +71,15 @@ public class AvaticaParameterBinder {
* Instantiate a new AvaticaParameterBinder.
*
* @param preparedStatement The PreparedStatement to bind parameters to.
* @param bufferAllocator The BufferAllocator to use for allocating memory.
* @param bufferAllocator The BufferAllocator to use for allocating memory.
*/
public AvaticaParameterBinder(
PreparedStatement preparedStatement, BufferAllocator bufferAllocator) {
PreparedStatement preparedStatement, List<TypedValue> typedValues, BufferAllocator bufferAllocator) {
this.parameters =
VectorSchemaRoot.create(preparedStatement.getParameterSchema(), bufferAllocator);
VectorSchemaRoot.create(makeSchema(typedValues), bufferAllocator);
this.preparedStatement = preparedStatement;


}

/**
Expand All @@ -77,19 +91,90 @@ public void bind(List<TypedValue> typedValues) {
bind(typedValues, 0);
}

private ArrowType getArrowTypeFromTypedValue(TypedValue typedValue) {
switch (typedValue.type) {
case PRIMITIVE_BOOLEAN:
case BOOLEAN:
return new ArrowType.Bool();

case PRIMITIVE_BYTE:
case BYTE:
return new ArrowType.Int(8, true);

case PRIMITIVE_CHAR:
case CHARACTER:

case STRING:
return new ArrowType.Utf8();

case PRIMITIVE_SHORT:
case SHORT:
return new ArrowType.Int(16, true);

case PRIMITIVE_INT:
case INTEGER:
return new ArrowType.Int(32, true);

case PRIMITIVE_LONG:
case LONG:
return new ArrowType.Int(64, true);

case PRIMITIVE_FLOAT:
case FLOAT:
return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);

case PRIMITIVE_DOUBLE:
case DOUBLE:
return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);

case JAVA_SQL_TIME:
return new ArrowType.Time(TimeUnit.MILLISECOND, 32);

case JAVA_SQL_TIMESTAMP:
// TODO: figure out TZ
return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null);

case JAVA_SQL_DATE:
case JAVA_UTIL_DATE:
return new ArrowType.Date(DateUnit.DAY);

case BYTE_STRING:
return new ArrowType.Binary();

case NUMBER:
return new ArrowType.Decimal(38, 0, 128);

case ARRAY:
return new ArrowType.List();

case MULTISET:
case STRUCT:
return new ArrowType.Struct();

case OBJECT:
// TODO: figure out how to handle Object. I imagine java.time objects end up here
default:
throw new UnsupportedOperationException("Unsupported TypedValue type: " + typedValue.type);
}
}

public Schema makeSchema(List<TypedValue> typedValues) {
final List<Field> parameterFields = new ArrayList<>(typedValues.size());
for (int i = 0; i < typedValues.size(); i++) {
ArrowType arrowType = getArrowTypeFromTypedValue(typedValues.get(i));
FieldType fieldType = new FieldType(false, arrowType, null);
parameterFields.add(new Field(null, fieldType, null));
}
return new Schema(parameterFields);
}

/**
* Bind the given Avatica values to the prepared statement at the given index.
*
* @param typedValues The parameter values.
* @param index index for parameter.
* @param index index for parameter.
*/
public void bind(List<TypedValue> typedValues, int index) {
if (preparedStatement.getParameterSchema().getFields().size() != typedValues.size()) {
throw new IllegalStateException(
String.format(
"Prepared statement has %s parameters, but only received %s",
preparedStatement.getParameterSchema().getFields().size(), typedValues.size()));
}

for (int i = 0; i < typedValues.size(); i++) {
bind(parameters.getVector(i), typedValues.get(i), index);
Expand All @@ -104,9 +189,9 @@ public void bind(List<TypedValue> typedValues, int index) {
/**
* Bind a TypedValue to the given index on the FieldVector.
*
* @param vector FieldVector to bind to.
* @param vector FieldVector to bind to.
* @param typedValue TypedValue to bind to the vector.
* @param index Vector index to bind the value at.
* @param index Vector index to bind the value at.
*/
private void bind(FieldVector vector, TypedValue typedValue, int index) {
try {
Expand Down Expand Up @@ -144,8 +229,8 @@ public static class BinderVisitor implements ArrowType.ArrowTypeVisitor<Boolean>
* Instantiate a new BinderVisitor.
*
* @param vector FieldVector to bind values to.
* @param value TypedValue to bind.
* @param index Vector index (0-based) to bind the value to.
* @param value TypedValue to bind.
* @param index Vector index (0-based) to bind the value to.
*/
public BinderVisitor(FieldVector vector, TypedValue value, int index) {
this.vector = vector;
Expand Down