diff --git a/src/main/java/org/tarantool/jdbc/SQLResultSet.java b/src/main/java/org/tarantool/jdbc/SQLResultSet.java index 85d36844..f707da7b 100644 --- a/src/main/java/org/tarantool/jdbc/SQLResultSet.java +++ b/src/main/java/org/tarantool/jdbc/SQLResultSet.java @@ -42,6 +42,7 @@ public class SQLResultSet implements ResultSet { private final SQLResultSetMetaData metaData; private Map columnByNameLookups; + private boolean lastColumnWasNull; private final Statement statement; private final int maxRows; @@ -86,6 +87,24 @@ public List getCurrentRow() throws SQLException { return iterator.getItem(); } + protected Object getRaw(int columnIndex) throws SQLException { + checkNotClosed(); + metaData.checkColumnIndex(columnIndex); + List row = getCurrentRow(); + Object value = row.get(columnIndex - 1); + lastColumnWasNull = (value == null); + return value; + } + + protected Number getNumber(int columnIndex) throws SQLException { + Number raw = (Number) getRaw(columnIndex); + return raw == null ? 0 : raw; + } + + protected Number getNullableNumber(int columnIndex) throws SQLException { + return (Number) getRaw(columnIndex); + } + @Override public void close() throws SQLException { if (isClosed.compareAndSet(false, true)) { @@ -95,14 +114,8 @@ public void close() throws SQLException { @Override public boolean wasNull() throws SQLException { - return false; - } - - protected Object getRaw(int columnIndex) throws SQLException { checkNotClosed(); - metaData.checkColumnIndex(columnIndex); - List row = getCurrentRow(); - return row.get(columnIndex - 1); + return lastColumnWasNull; } @Override @@ -156,11 +169,6 @@ public int getInt(String columnLabel) throws SQLException { return getInt(findColumn(columnLabel)); } - private Number getNumber(int columnIndex) throws SQLException { - Number raw = (Number) getRaw(columnIndex); - return raw == null ? 0 : raw; - } - @Override public long getLong(int columnIndex) throws SQLException { return (getNumber(columnIndex)).longValue(); @@ -193,13 +201,17 @@ public double getDouble(String columnLabel) throws SQLException { @Override public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException { - BigDecimal bigDecimal = new BigDecimal(getString(columnIndex)); + String raw = getString(columnIndex); + if (raw == null) { + return null; + } + BigDecimal bigDecimal = new BigDecimal(raw); return scale > -1 ? bigDecimal.setScale(scale) : bigDecimal; } @Override public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLException { - return getBigDecimal(findColumn(columnLabel)); + return getBigDecimal(findColumn(columnLabel), scale); } @Override @@ -224,7 +236,8 @@ public byte[] getBytes(String columnLabel) throws SQLException { @Override public Date getDate(int columnIndex) throws SQLException { - return new java.sql.Date(getLong(columnIndex)); + Number time = getNullableNumber(columnIndex); + return time == null ? null : new java.sql.Date(time.longValue()); } @Override @@ -244,7 +257,8 @@ public Date getDate(String columnLabel, Calendar cal) throws SQLException { @Override public Time getTime(int columnIndex) throws SQLException { - return new java.sql.Time(getLong(columnIndex)); + Number time = getNullableNumber(columnIndex); + return time == null ? null : new java.sql.Time(time.longValue()); } @Override @@ -264,7 +278,8 @@ public Time getTime(String columnLabel, Calendar cal) throws SQLException { @Override public Timestamp getTimestamp(int columnIndex) throws SQLException { - return new java.sql.Timestamp(getLong(columnIndex)); + Number time = getNullableNumber(columnIndex); + return time == null ? null : new java.sql.Timestamp(time.longValue()); } @Override @@ -285,7 +300,8 @@ public Timestamp getTimestamp(String columnLabel, Calendar cal) throws SQLExcept @Override public InputStream getAsciiStream(int columnIndex) throws SQLException { - return new ByteArrayInputStream(getString(columnIndex).getBytes(Charset.forName("ASCII"))); + String string = getString(columnIndex); + return string == null ? null : new ByteArrayInputStream(string.getBytes(Charset.forName("ASCII"))); } @Override @@ -306,7 +322,8 @@ public InputStream getUnicodeStream(String columnLabel) throws SQLException { @Override public InputStream getBinaryStream(int columnIndex) throws SQLException { - return new ByteArrayInputStream(getBytes(columnIndex)); + byte[] bytes = getBytes(columnIndex); + return bytes == null ? null : new ByteArrayInputStream(bytes); } @Override @@ -316,12 +333,13 @@ public InputStream getBinaryStream(String columnLabel) throws SQLException { @Override public Reader getCharacterStream(int columnIndex) throws SQLException { - return new StringReader(getString(columnIndex)); + String value = getString(columnIndex); + return value == null ? null : new StringReader(value); } @Override public Reader getCharacterStream(String columnLabel) throws SQLException { - return new StringReader(getString(columnLabel)); + return getCharacterStream(findColumn(columnLabel)); } @Override @@ -331,7 +349,7 @@ public Object getObject(int columnIndex) throws SQLException { @Override public Object getObject(String columnLabel) throws SQLException { - return getRaw(findColumn(columnLabel)); + return getObject(findColumn(columnLabel)); } @Override @@ -346,12 +364,16 @@ public Object getObject(String columnLabel, Map> map) throws SQ @Override public T getObject(int columnIndex, Class type) throws SQLException { - return type.cast(getRaw(columnIndex)); + try { + return type.cast(getRaw(columnIndex)); + } catch (Exception e) { + throw new SQLNonTransientException(e); + } } @Override public T getObject(String columnLabel, Class type) throws SQLException { - return type.cast(getRaw(findColumn(columnLabel))); + return getObject(findColumn(columnLabel), type); } @Override diff --git a/src/test/java/org/tarantool/jdbc/AbstractJdbcIT.java b/src/test/java/org/tarantool/jdbc/AbstractJdbcIT.java index dc7be4c5..c8576c86 100644 --- a/src/test/java/org/tarantool/jdbc/AbstractJdbcIT.java +++ b/src/test/java/org/tarantool/jdbc/AbstractJdbcIT.java @@ -36,8 +36,9 @@ public abstract class AbstractJdbcIT { "CREATE TABLE test(id INT PRIMARY KEY, val VARCHAR(100))", "INSERT INTO test VALUES (1, 'one'), (2, 'two'), (3, 'three')", "CREATE TABLE test_compound(id1 INT, id2 INT, val VARCHAR(100), PRIMARY KEY (id2, id1))", - "CREATE TABLE test_nulls(id INT PRIMARY KEY, val VARCHAR(100))", - "INSERT INTO test_nulls VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, NULL), (5, NULL), (6, NULL)", + "CREATE TABLE test_nulls(id INT PRIMARY KEY, val VARCHAR(100), dig INTEGER, bin SCALAR)", + "INSERT INTO test_nulls VALUES (1, 'a', 10, 'aa'), (2, 'b', 20, 'bb'), (3, 'c', 30, 'cc'), " + + "(4, NULL, NULL, NULL), (5, NULL, NULL, NULL), (6, NULL, NULL, NULL)", getCreateTableSQL("test_types", TntSqlType.values()) }; diff --git a/src/test/java/org/tarantool/jdbc/JdbcResultSetIT.java b/src/test/java/org/tarantool/jdbc/JdbcResultSetIT.java index 944c4198..2f355aae 100644 --- a/src/test/java/org/tarantool/jdbc/JdbcResultSetIT.java +++ b/src/test/java/org/tarantool/jdbc/JdbcResultSetIT.java @@ -208,7 +208,7 @@ public void testNullsSortingAsc() throws SQLException { @Test public void testNullsSortingDesc() throws SQLException { - ResultSet resultSet = stmt.executeQuery("SELECT * FROM test_nulls ORDER BY val DESC"); + ResultSet resultSet = stmt.executeQuery("SELECT id, dig FROM test_nulls ORDER BY val DESC"); for (int i = 0; i < 3; i++) { assertTrue(resultSet.next()); assertNotNull(resultSet.getString(2)); @@ -220,6 +220,57 @@ public void testNullsSortingDesc() throws SQLException { assertFalse(resultSet.next()); } + @Test + void testObjectWasNullColumn() throws SQLException { + ResultSet resultSet = stmt.executeQuery("SELECT id, dig FROM test_nulls WHERE val IS NULL"); + resultSet.next(); + + resultSet.getInt(1); + assertFalse(resultSet.wasNull()); + assertNull(resultSet.getString(2)); + assertTrue(resultSet.wasNull()); + } + + @Test + void testBinaryWasNullColumn() throws SQLException { + ResultSet resultSet = stmt.executeQuery("SELECT id, bin FROM test_nulls WHERE bin IS NULL"); + resultSet.next(); + + resultSet.getInt(1); + assertFalse(resultSet.wasNull()); + assertNull(resultSet.getString(2)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getAsciiStream(2)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getBinaryStream(2)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getUnicodeStream(2)); + assertTrue(resultSet.wasNull()); + assertNull(resultSet.getCharacterStream(2)); + assertTrue(resultSet.wasNull()); + } + + @Test + void testNumberWasNullColumn() throws SQLException { + ResultSet resultSet = stmt.executeQuery("SELECT id, dig FROM test_nulls WHERE dig IS NULL"); + resultSet.next(); + + resultSet.getInt(1); + assertFalse(resultSet.wasNull()); + assertEquals(0, resultSet.getInt(2)); + assertTrue(resultSet.wasNull()); + assertEquals(0, resultSet.getShort(2)); + assertTrue(resultSet.wasNull()); + assertEquals(0, resultSet.getByte(2)); + assertTrue(resultSet.wasNull()); + assertEquals(0, resultSet.getLong(2)); + assertTrue(resultSet.wasNull()); + assertEquals(0, resultSet.getDouble(2)); + assertTrue(resultSet.wasNull()); + assertEquals(0, resultSet.getFloat(2)); + assertTrue(resultSet.wasNull()); + } + @Test public void testFindUniqueColumnLabels() throws SQLException { ResultSet resultSet = stmt.executeQuery("SELECT id as f1, val as f2 FROM test");