Skip to content

Commit

Permalink
Add a proper implementation for wasNull method
Browse files Browse the repository at this point in the history
Closes: #179
  • Loading branch information
nicktorwald committed May 20, 2019
1 parent fb368d2 commit 238e330
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 28 deletions.
73 changes: 48 additions & 25 deletions src/main/java/org/tarantool/jdbc/SQLResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class SQLResultSet implements ResultSet {
private final SQLResultSetMetaData metaData;

private Map<String, Integer> columnByNameLookups;
private boolean lastColumnWasNull;

private final Statement statement;
private final int maxRows;
Expand Down Expand Up @@ -86,6 +87,24 @@ public List<Object> getCurrentRow() throws SQLException {
return iterator.getItem();
}

protected Object getRaw(int columnIndex) throws SQLException {
checkNotClosed();
metaData.checkColumnIndex(columnIndex);
List<Object> row = getCurrentRow();
Object value = row.get(columnIndex - 1);
lastColumnWasNull = (value == null);
return value;
}

protected Number getNumber(int columnIndex) throws SQLException {
Number raw = (Number) getRaw(columnIndex);
return raw == null ? 0 : raw;
}

protected Number getNullableNumber(int columnIndex) throws SQLException {
return (Number) getRaw(columnIndex);
}

@Override
public void close() throws SQLException {
if (isClosed.compareAndSet(false, true)) {
Expand All @@ -95,14 +114,8 @@ public void close() throws SQLException {

@Override
public boolean wasNull() throws SQLException {
return false;
}

protected Object getRaw(int columnIndex) throws SQLException {
checkNotClosed();
metaData.checkColumnIndex(columnIndex);
List<Object> row = getCurrentRow();
return row.get(columnIndex - 1);
return lastColumnWasNull;
}

@Override
Expand Down Expand Up @@ -156,11 +169,6 @@ public int getInt(String columnLabel) throws SQLException {
return getInt(findColumn(columnLabel));
}

private Number getNumber(int columnIndex) throws SQLException {
Number raw = (Number) getRaw(columnIndex);
return raw == null ? 0 : raw;
}

@Override
public long getLong(int columnIndex) throws SQLException {
return (getNumber(columnIndex)).longValue();
Expand Down Expand Up @@ -193,13 +201,17 @@ public double getDouble(String columnLabel) throws SQLException {

@Override
public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException {
BigDecimal bigDecimal = new BigDecimal(getString(columnIndex));
String raw = getString(columnIndex);
if (raw == null) {
return null;
}
BigDecimal bigDecimal = new BigDecimal(raw);
return scale > -1 ? bigDecimal.setScale(scale) : bigDecimal;
}

@Override
public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException {
return getBigDecimal(findColumn(columnLabel));
return getBigDecimal(findColumn(columnLabel), scale);
}

@Override
Expand All @@ -224,7 +236,8 @@ public byte[] getBytes(String columnLabel) throws SQLException {

@Override
public Date getDate(int columnIndex) throws SQLException {
return new java.sql.Date(getLong(columnIndex));
Number time = getNullableNumber(columnIndex);
return time == null ? null : new java.sql.Date(time.longValue());
}

@Override
Expand All @@ -244,7 +257,8 @@ public Date getDate(String columnLabel, Calendar cal) throws SQLException {

@Override
public Time getTime(int columnIndex) throws SQLException {
return new java.sql.Time(getLong(columnIndex));
Number time = getNullableNumber(columnIndex);
return time == null ? null : new java.sql.Time(time.longValue());
}

@Override
Expand All @@ -264,7 +278,8 @@ public Time getTime(String columnLabel, Calendar cal) throws SQLException {

@Override
public Timestamp getTimestamp(int columnIndex) throws SQLException {
return new java.sql.Timestamp(getLong(columnIndex));
Number time = getNullableNumber(columnIndex);
return time == null ? null : new java.sql.Timestamp(time.longValue());
}

@Override
Expand All @@ -285,7 +300,8 @@ public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLExcept

@Override
public InputStream getAsciiStream(int columnIndex) throws SQLException {
return new ByteArrayInputStream(getString(columnIndex).getBytes(Charset.forName("ASCII")));
String string = getString(columnIndex);
return string == null ? null : new ByteArrayInputStream(string.getBytes(Charset.forName("ASCII")));
}

@Override
Expand All @@ -295,7 +311,8 @@ public InputStream getAsciiStream(String columnLabel) throws SQLException {

@Override
public InputStream getUnicodeStream(int columnIndex) throws SQLException {
return new ByteArrayInputStream(getString(columnIndex).getBytes(Charset.forName("UTF-8")));
String string = getString(columnIndex);
return string == null ? null : new ByteArrayInputStream(string.getBytes(Charset.forName("UTF-16")));
}

@Override
Expand All @@ -305,7 +322,8 @@ public InputStream getUnicodeStream(String columnLabel) throws SQLException {

@Override
public InputStream getBinaryStream(int columnIndex) throws SQLException {
return new ByteArrayInputStream(getBytes(columnIndex));
byte[] bytes = getBytes(columnIndex);
return bytes == null ? null : new ByteArrayInputStream(bytes);
}

@Override
Expand All @@ -315,12 +333,13 @@ public InputStream getBinaryStream(String columnLabel) throws SQLException {

@Override
public Reader getCharacterStream(int columnIndex) throws SQLException {
return new StringReader(getString(columnIndex));
String value = getString(columnIndex);
return value == null ? null : new StringReader(value);
}

@Override
public Reader getCharacterStream(String columnLabel) throws SQLException {
return new StringReader(getString(columnLabel));
return getCharacterStream(findColumn(columnLabel));
}

@Override
Expand All @@ -330,7 +349,7 @@ public Object getObject(int columnIndex) throws SQLException {

@Override
public Object getObject(String columnLabel) throws SQLException {
return getRaw(findColumn(columnLabel));
return getObject(findColumn(columnLabel));
}

@Override
Expand All @@ -345,12 +364,16 @@ public Object getObject(String columnLabel, Map<String, Class<?>> map) throws SQ

@Override
public <T> T getObject(int columnIndex, Class<T> type) throws SQLException {
return type.cast(getRaw(columnIndex));
try {
return type.cast(getRaw(columnIndex));
} catch (Exception e) {
throw new SQLNonTransientException(e);
}
}

@Override
public <T> T getObject(String columnLabel, Class<T> type) throws SQLException {
return type.cast(getRaw(findColumn(columnLabel)));
return getObject(findColumn(columnLabel), type);
}

@Override
Expand Down
5 changes: 3 additions & 2 deletions src/test/java/org/tarantool/jdbc/AbstractJdbcIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ public abstract class AbstractJdbcIT {
"CREATE TABLE test(id INT PRIMARY KEY, val VARCHAR(100))",
"INSERT INTO test VALUES (1, 'one'), (2, 'two'), (3, 'three')",
"CREATE TABLE test_compound(id1 INT, id2 INT, val VARCHAR(100), PRIMARY KEY (id2, id1))",
"CREATE TABLE test_nulls(id INT PRIMARY KEY, val VARCHAR(100))",
"INSERT INTO test_nulls VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, NULL), (5, NULL), (6, NULL)",
"CREATE TABLE test_nulls(id INT PRIMARY KEY, val VARCHAR(100), dig INTEGER, bin SCALAR)",
"INSERT INTO test_nulls VALUES (1, 'a', 10, 'aa'), (2, 'b', 20, 'bb'), (3, 'c', 30, 'cc'), " +
"(4, NULL, NULL, NULL), (5, NULL, NULL, NULL), (6, NULL, NULL, NULL)",
getCreateTableSQL("test_types", TntSqlType.values())
};

Expand Down
53 changes: 52 additions & 1 deletion src/test/java/org/tarantool/jdbc/JdbcResultSetIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public void testNullsSortingAsc() throws SQLException {

@Test
public void testNullsSortingDesc() throws SQLException {
ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_nulls ORDER BY val DESC");
ResultSet resultSet = stmt.executeQuery("SELECT id, dig FROM test_nulls ORDER BY val DESC");
for (int i = 0; i < 3; i++) {
assertTrue(resultSet.next());
assertNotNull(resultSet.getString(2));
Expand All @@ -220,6 +220,57 @@ public void testNullsSortingDesc() throws SQLException {
assertFalse(resultSet.next());
}

@Test
void testObjectWasNullColumn() throws SQLException {
ResultSet resultSet = stmt.executeQuery("SELECT id, dig FROM test_nulls WHERE val IS NULL");
resultSet.next();

resultSet.getInt(1);
assertFalse(resultSet.wasNull());
assertNull(resultSet.getString(2));
assertTrue(resultSet.wasNull());
}

@Test
void testBinaryWasNullColumn() throws SQLException {
ResultSet resultSet = stmt.executeQuery("SELECT id, bin FROM test_nulls WHERE bin IS NULL");
resultSet.next();

resultSet.getInt(1);
assertFalse(resultSet.wasNull());
assertNull(resultSet.getString(2));
assertTrue(resultSet.wasNull());
assertNull(resultSet.getAsciiStream(2));
assertTrue(resultSet.wasNull());
assertNull(resultSet.getBinaryStream(2));
assertTrue(resultSet.wasNull());
assertNull(resultSet.getUnicodeStream(2));
assertTrue(resultSet.wasNull());
assertNull(resultSet.getCharacterStream(2));
assertTrue(resultSet.wasNull());
}

@Test
void testNumberWasNullColumn() throws SQLException {
ResultSet resultSet = stmt.executeQuery("SELECT id, dig FROM test_nulls WHERE dig IS NULL");
resultSet.next();

resultSet.getInt(1);
assertFalse(resultSet.wasNull());
assertEquals(0, resultSet.getInt(2));
assertTrue(resultSet.wasNull());
assertEquals(0, resultSet.getShort(2));
assertTrue(resultSet.wasNull());
assertEquals(0, resultSet.getByte(2));
assertTrue(resultSet.wasNull());
assertEquals(0, resultSet.getLong(2));
assertTrue(resultSet.wasNull());
assertEquals(0, resultSet.getDouble(2));
assertTrue(resultSet.wasNull());
assertEquals(0, resultSet.getFloat(2));
assertTrue(resultSet.wasNull());
}

@Test
public void testFindUniqueColumnLabels() throws SQLException {
ResultSet resultSet = stmt.executeQuery("SELECT id as f1, val as f2 FROM test");
Expand Down

0 comments on commit 238e330

Please sign in to comment.