Skip to content

Commit

Permalink
Support generated keys from INSERT query
Browse files Browse the repository at this point in the history
Parse the 'generated_ids' array which is returned after successful
INSERT command has applied. This makes sense when a table primary key
has an autoincrement property. The driver always returns a predefined
result set with a single-column table (column name is 'GENERATED_KEYS')
where each row is one generated value.

Closes: #77
  • Loading branch information
nicktorwald committed Aug 8, 2019
1 parent fa0bc3f commit 45cfe17
Show file tree
Hide file tree
Showing 16 changed files with 259 additions and 69 deletions.
3 changes: 2 additions & 1 deletion src/main/java/org/tarantool/Key.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public enum Key implements Callable<Integer> {
SQL_BIND(0x41),
SQL_OPTIONS(0x42),
SQL_INFO(0x42),
SQL_ROW_COUNT(0);
SQL_ROW_COUNT(0x00),
SQL_INFO_AUTOINCREMENT_IDS(0x01);

int id;

Expand Down
12 changes: 11 additions & 1 deletion src/main/java/org/tarantool/SqlProtoUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.tarantool.protocol.TarantoolPacket;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -41,7 +42,7 @@ public static List<SQLMetaData> getSQLMetadata(TarantoolPacket pack) {
return values;
}

public static Long getSqlRowCount(TarantoolPacket pack) {
public static Long getSQLRowCount(TarantoolPacket pack) {
Map<Key, Object> info = (Map<Key, Object>) pack.getBody().get(Key.SQL_INFO.getId());
Number rowCount;
if (info != null && (rowCount = ((Number) info.get(Key.SQL_ROW_COUNT.getId()))) != null) {
Expand All @@ -50,6 +51,15 @@ public static Long getSqlRowCount(TarantoolPacket pack) {
return null;
}

public static List<Integer> getSQLAutoIncrementIds(TarantoolPacket pack) {
Map<Key, Object> info = (Map<Key, Object>) pack.getBody().get(Key.SQL_INFO.getId());
if (info != null) {
List<Integer> generatedIds = (List<Integer>) info.get(Key.SQL_INFO_AUTOINCREMENT_IDS.getId());
return generatedIds == null ? Collections.emptyList() : generatedIds;
}
return Collections.emptyList();
}

public static class SQLMetaData {
private String name;
private TarantoolSqlType type;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/tarantool/TarantoolClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ protected void complete(TarantoolPacket packet, TarantoolOp<?> future) {
}

protected void completeSql(TarantoolOp<?> future, TarantoolPacket pack) {
Long rowCount = SqlProtoUtils.getSqlRowCount(pack);
Long rowCount = SqlProtoUtils.getSQLRowCount(pack);
if (rowCount != null) {
((TarantoolOp) future).complete(rowCount);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/tarantool/TarantoolConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void close() {
@Override
public Long update(String sql, Object... bind) {
TarantoolPacket pack = sql(sql, bind);
return SqlProtoUtils.getSqlRowCount(pack);
return SqlProtoUtils.getSQLRowCount(pack);
}

@Override
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/org/tarantool/jdbc/SQLConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
* <p>
* Supports creating {@link Statement} and {@link PreparedStatement} instances
*/
public class SQLConnection implements Connection {
public class SQLConnection implements TarantoolConnection {

private static final int UNSET_HOLDABILITY = 0;
private static final String PING_QUERY = "SELECT 1";
Expand Down Expand Up @@ -148,10 +148,7 @@ public PreparedStatement prepareStatement(String sql,
public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
checkNotClosed();
JdbcConstants.checkGeneratedKeysConstant(autoGeneratedKeys);
if (autoGeneratedKeys != Statement.NO_GENERATED_KEYS) {
throw new SQLFeatureNotSupportedException();
}
return prepareStatement(sql);
return new SQLPreparedStatement(this, sql, autoGeneratedKeys);
}

@Override
Expand Down Expand Up @@ -527,14 +524,17 @@ public int getNetworkTimeout() throws SQLException {
return (int) client.getOperationTimeout();
}

protected SQLResultHolder execute(long timeout, SQLQueryHolder query) throws SQLException {
@Override
public SQLResultHolder execute(long timeout, SQLQueryHolder query) throws SQLException {
checkNotClosed();
return (useNetworkTimeout(timeout))
? executeWithNetworkTimeout(query)
: executeWithQueryTimeout(timeout, query);
}

protected SQLBatchResultHolder executeBatch(long timeout, List<SQLQueryHolder> queries) throws SQLException {
@Override
public SQLBatchResultHolder executeBatch(long timeout, List<SQLQueryHolder> queries)
throws SQLException {
checkNotClosed();
SQLTarantoolClientImpl.SQLRawOps sqlOps = client.sqlRawOps();
SQLBatchResultHolder batchResult = useNetworkTimeout(timeout)
Expand Down Expand Up @@ -810,10 +810,10 @@ SQLRawOps sqlRawOps() {

@Override
protected void completeSql(TarantoolOp<?> future, TarantoolPacket pack) {
Long rowCount = SqlProtoUtils.getSqlRowCount(pack);
Long rowCount = SqlProtoUtils.getSQLRowCount(pack);
SQLResultHolder result = (rowCount == null)
? SQLResultHolder.ofQuery(SqlProtoUtils.getSQLMetadata(pack), SqlProtoUtils.getSQLData(pack))
: SQLResultHolder.ofUpdate(rowCount.intValue());
: SQLResultHolder.ofUpdate(rowCount.intValue(), SqlProtoUtils.getSQLAutoIncrementIds(pack));
((TarantoolOp) future).complete(result);
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/tarantool/jdbc/SQLDatabaseMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ public boolean supportsMultipleOpenResults() throws SQLException {

@Override
public boolean supportsGetGeneratedKeys() throws SQLException {
return false;
return true;
}

@Override
Expand Down Expand Up @@ -1104,7 +1104,7 @@ private ResultSet asEmptyMetadataResultSet(List<TupleTwo<String, TarantoolSqlTyp

@Override
public boolean generatedKeyAlwaysReturned() throws SQLException {
return false;
return true;
}

@Override
Expand Down
12 changes: 7 additions & 5 deletions src/main/java/org/tarantool/jdbc/SQLPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ public class SQLPreparedStatement extends SQLStatement implements PreparedStatem

private final String sql;
private final Map<Integer, Object> parameters;

private final int autoGeneratedKeys;
private List<Map<Integer, Object>> batchParameters = new ArrayList<>();

public SQLPreparedStatement(SQLConnection connection, String sql) throws SQLException {
public SQLPreparedStatement(SQLConnection connection, String sql, int autoGeneratedKeys) throws SQLException {
super(connection);
this.sql = sql;
this.parameters = new HashMap<>();
this.autoGeneratedKeys = autoGeneratedKeys;
setPoolable(true);
}

Expand All @@ -52,13 +53,14 @@ public SQLPreparedStatement(SQLConnection connection,
super(connection, resultSetType, resultSetConcurrency, resultSetHoldability);
this.sql = sql;
this.parameters = new HashMap<>();
this.autoGeneratedKeys = NO_GENERATED_KEYS;
setPoolable(true);
}

@Override
public ResultSet executeQuery() throws SQLException {
checkNotClosed();
if (!executeInternal(sql, toParametersList(parameters))) {
if (!executeInternal(autoGeneratedKeys, sql, toParametersList(parameters))) {
throw new SQLException("No results were returned", SQLStates.NO_DATA.getSqlState());
}
return resultSet;
Expand All @@ -73,7 +75,7 @@ public ResultSet executeQuery(String sql) throws SQLException {
@Override
public int executeUpdate() throws SQLException {
checkNotClosed();
if (executeInternal(sql, toParametersList(parameters))) {
if (executeInternal(autoGeneratedKeys, sql, toParametersList(parameters))) {
throw new SQLException(
"Result was returned but nothing was expected",
SQLStates.TOO_MANY_RESULTS.getSqlState()
Expand Down Expand Up @@ -244,7 +246,7 @@ private void setParameter(int parameterIndex, Object value) throws SQLException
@Override
public boolean execute() throws SQLException {
checkNotClosed();
return executeInternal(sql, toParametersList(parameters));
return executeInternal(autoGeneratedKeys, sql, toParametersList(parameters));
}

@Override
Expand Down
17 changes: 13 additions & 4 deletions src/main/java/org/tarantool/jdbc/SQLResultHolder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,29 @@ public class SQLResultHolder {
private final List<SqlProtoUtils.SQLMetaData> sqlMetadata;
private final List<List<Object>> rows;
private final int updateCount;
private final List<Integer> generatedIds;

public SQLResultHolder(List<SqlProtoUtils.SQLMetaData> sqlMetadata, List<List<Object>> rows, int updateCount) {
public SQLResultHolder(List<SqlProtoUtils.SQLMetaData> sqlMetadata,
List<List<Object>> rows,
int updateCount,
List<Integer> generatedIds) {
this.sqlMetadata = sqlMetadata;
this.rows = rows;
this.updateCount = updateCount;
this.generatedIds = generatedIds;
}

public static SQLResultHolder ofQuery(final List<SqlProtoUtils.SQLMetaData> sqlMetadata,
final List<List<Object>> rows) {
return new SQLResultHolder(sqlMetadata, rows, NO_UPDATE_COUNT);
return new SQLResultHolder(sqlMetadata, rows, NO_UPDATE_COUNT, Collections.emptyList());
}

public static SQLResultHolder ofEmptyQuery() {
return ofQuery(Collections.emptyList(), Collections.emptyList());
}

public static SQLResultHolder ofUpdate(int updateCount) {
return new SQLResultHolder(null, null, updateCount);
public static SQLResultHolder ofUpdate(int updateCount, List<Integer> generatedIds) {
return new SQLResultHolder(null, null, updateCount, generatedIds);
}

public List<SqlProtoUtils.SQLMetaData> getSqlMetadata() {
Expand All @@ -48,6 +53,10 @@ public int getUpdateCount() {
return updateCount;
}

public List<Integer> getGeneratedIds() {
return generatedIds;
}

public boolean isQueryResult() {
return sqlMetadata != null && rows != null;
}
Expand Down
64 changes: 40 additions & 24 deletions src/main/java/org/tarantool/jdbc/SQLStatement.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.tarantool.jdbc;

import org.tarantool.SqlProtoUtils;
import org.tarantool.jdbc.type.TarantoolSqlType;
import org.tarantool.util.JdbcConstants;
import org.tarantool.util.SQLStates;

Expand All @@ -13,6 +15,7 @@
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -27,13 +30,17 @@
*/
public class SQLStatement implements TarantoolStatement {

protected final SQLConnection connection;
private static final String GENERATED_KEY_COLUMN_NAME = "GENERATED_KEY";

protected final TarantoolConnection connection;
private final SQLResultSet emptyGeneratedKeys;

/**
* Current result set / update count associated to this statement.
*/
protected SQLResultSet resultSet;
protected int updateCount;
protected SQLResultSet generatedKeys;

private List<String> batchQueries = new ArrayList<>();

Expand Down Expand Up @@ -61,10 +68,12 @@ public class SQLStatement implements TarantoolStatement {
private final AtomicBoolean isClosed = new AtomicBoolean(false);

protected SQLStatement(SQLConnection sqlConnection) throws SQLException {
this.connection = sqlConnection;
this.resultSetType = ResultSet.TYPE_FORWARD_ONLY;
this.resultSetConcurrency = ResultSet.CONCUR_READ_ONLY;
this.resultSetHoldability = sqlConnection.getHoldability();
this(
sqlConnection,
ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY,
sqlConnection.getHoldability()
);
}

protected SQLStatement(SQLConnection sqlConnection,
Expand All @@ -75,37 +84,34 @@ protected SQLStatement(SQLConnection sqlConnection,
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
this.resultSetHoldability = resultSetHoldability;
this.emptyGeneratedKeys = this.generatedKeys = executeGeneratedKeys(Collections.emptyList());
}

@Override
public ResultSet executeQuery(String sql) throws SQLException {
checkNotClosed();
if (!executeInternal(sql)) {
if (!executeInternal(NO_GENERATED_KEYS, sql)) {
throw new SQLException("No results were returned", SQLStates.NO_DATA.getSqlState());
}
return resultSet;
}

@Override
public int executeUpdate(String sql) throws SQLException {
checkNotClosed();
if (executeInternal(sql)) {
throw new SQLException(
"Result was returned but nothing was expected",
SQLStates.TOO_MANY_RESULTS.getSqlState()
);
}
return updateCount;
return executeUpdate(sql, NO_GENERATED_KEYS);
}

@Override
public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
checkNotClosed();
JdbcConstants.checkGeneratedKeysConstant(autoGeneratedKeys);
if (autoGeneratedKeys != Statement.NO_GENERATED_KEYS) {
throw new SQLFeatureNotSupportedException();
if (executeInternal(autoGeneratedKeys, sql)) {
throw new SQLException(
"Result was returned but nothing was expected",
SQLStates.TOO_MANY_RESULTS.getSqlState()
);
}
return executeUpdate(sql);
return updateCount;
}

@Override
Expand Down Expand Up @@ -195,17 +201,14 @@ public void setCursorName(String name) throws SQLException {
@Override
public boolean execute(String sql) throws SQLException {
checkNotClosed();
return executeInternal(sql);
return executeInternal(NO_GENERATED_KEYS, sql);
}

@Override
public boolean execute(String sql, int autoGeneratedKeys) throws SQLException {
checkNotClosed();
JdbcConstants.checkGeneratedKeysConstant(autoGeneratedKeys);
if (autoGeneratedKeys != Statement.NO_GENERATED_KEYS) {
throw new SQLFeatureNotSupportedException();
}
return execute(sql);
return executeInternal(autoGeneratedKeys, sql);
}

@Override
Expand Down Expand Up @@ -321,7 +324,7 @@ public Connection getConnection() throws SQLException {
@Override
public ResultSet getGeneratedKeys() throws SQLException {
checkNotClosed();
return new SQLResultSet(SQLResultHolder.ofEmptyQuery(), this);
return generatedKeys;
}

@Override
Expand Down Expand Up @@ -401,6 +404,7 @@ protected void discardLastResults() throws SQLException {
clearWarnings();
updateCount = -1;
resultSet = null;
generatedKeys = emptyGeneratedKeys;

if (lastResultSet != null) {
try {
Expand All @@ -419,7 +423,7 @@ protected void discardLastResults() throws SQLException {
*
* @return {@code true}, if the result is a ResultSet object;
*/
protected boolean executeInternal(String sql, Object... params) throws SQLException {
protected boolean executeInternal(int autoGeneratedKeys, String sql, Object... params) throws SQLException {
discardLastResults();
SQLResultHolder holder;
try {
Expand All @@ -433,6 +437,9 @@ protected boolean executeInternal(String sql, Object... params) throws SQLExcept
resultSet = new SQLResultSet(holder, this);
}
updateCount = holder.getUpdateCount();
if (autoGeneratedKeys == Statement.RETURN_GENERATED_KEYS) {
generatedKeys = executeGeneratedKeys(holder.getGeneratedIds());
}
return holder.isQueryResult();
}

Expand Down Expand Up @@ -474,4 +481,13 @@ protected void checkNotClosed() throws SQLException {
}
}

protected SQLResultSet executeGeneratedKeys(List<Integer> generatedKeys) throws SQLException {
SqlProtoUtils.SQLMetaData sqlMetaData =
new SqlProtoUtils.SQLMetaData(GENERATED_KEY_COLUMN_NAME, TarantoolSqlType.INTEGER);
List<List<Object>> rows = generatedKeys.stream()
.map(Collections::<Object>singletonList)
.collect(Collectors.toList());
return createResultSet(SQLResultHolder.ofQuery(Collections.singletonList(sqlMetaData), rows));
}

}
Loading

0 comments on commit 45cfe17

Please sign in to comment.