From 8be57f3b4bd0b1112cf871d9426beda6919b3bb7 Mon Sep 17 00:00:00 2001 From: rusher Date: Mon, 21 Oct 2024 18:39:36 +0200 Subject: [PATCH] [CONJ-1205] permit setObject with ARRAY dataType --- .../mariadb/jdbc/BasePreparedStatement.java | 17 ++++ .../mariadb/jdbc/plugin/array/FloatArray.java | 6 +- .../codec/FloatArrayCodecTest.java | 98 +++++++++++++++++++ 3 files changed, 120 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/mariadb/jdbc/BasePreparedStatement.java b/src/main/java/org/mariadb/jdbc/BasePreparedStatement.java index bd3be5f37..484531eca 100644 --- a/src/main/java/org/mariadb/jdbc/BasePreparedStatement.java +++ b/src/main/java/org/mariadb/jdbc/BasePreparedStatement.java @@ -1131,6 +1131,23 @@ private void setInternalObject( // in case of not corresponding data type, converting switch (targetSqlType) { case Types.ARRAY: + if (obj instanceof float[]) { + parameters.set( + parameterIndex - 1, new Parameter<>(FloatArrayCodec.INSTANCE, (float[]) obj)); + return; + } else if (obj instanceof Float[]) { + parameters.set( + parameterIndex - 1, new Parameter<>(FloatObjectArrayCodec.INSTANCE, (Float[]) obj)); + return; + } else if (obj instanceof FloatArray) { + parameters.set( + parameterIndex - 1, + new Parameter<>(FloatArrayCodec.INSTANCE, (float[]) ((FloatArray) obj).getArray())); + return; + } + throw exceptionFactory() + .notSupported( + String.format("ARRAY Type not supported for %s", obj.getClass().getName())); case Types.DATALINK: case Types.JAVA_OBJECT: case Types.REF: diff --git a/src/main/java/org/mariadb/jdbc/plugin/array/FloatArray.java b/src/main/java/org/mariadb/jdbc/plugin/array/FloatArray.java index ddb9a8802..5b7a75265 100644 --- a/src/main/java/org/mariadb/jdbc/plugin/array/FloatArray.java +++ b/src/main/java/org/mariadb/jdbc/plugin/array/FloatArray.java @@ -82,7 +82,11 @@ public ResultSet getResultSet(Map> map) throws SQLException { public ResultSet getResultSet(long index, int count) throws SQLException { byte[][] rows = new byte[count][]; for (int i = 0; i < count; i++) { - rows[i] = Float.toString(this.val[(int) index - 1 + i]).getBytes(StandardCharsets.US_ASCII); + byte[] val = + Float.toString(this.val[(int) index - 1 + i]).getBytes(StandardCharsets.US_ASCII); + rows[i] = new byte[val.length + 1]; + rows[i][0] = (byte) val.length; + System.arraycopy(val, 0, rows[i], 1, val.length); } return new CompleteResult( diff --git a/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java b/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java index ae60f3b0b..06bcf135f 100644 --- a/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java +++ b/src/test/java/org/mariadb/jdbc/integration/codec/FloatArrayCodecTest.java @@ -6,6 +6,7 @@ import static org.junit.jupiter.api.Assertions.*; import java.sql.*; +import java.util.HashMap; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -106,6 +107,103 @@ private void floatArrayArrayObj(org.mariadb.jdbc.Connection con) throws SQLExcep } } + @Test + public void floatArrayArrayObjWithType() throws SQLException { + Statement stmt = sharedConn.createStatement(); + stmt.execute("TRUNCATE TABLE BinaryCodec"); + floatArrayArrayObjWithType(sharedConn); + stmt.execute("TRUNCATE TABLE BinaryCodec"); + floatArrayArrayObjWithType(sharedConnBinary); + } + + private void floatArrayArrayObjWithType(org.mariadb.jdbc.Connection con) throws SQLException { + float[] val = new float[] {1, 2, 3}; + float[] val2 = new float[] {4, 5}; + float[] val3 = new float[] {7, 8, 9, 10}; + + Array valArray = con.createArrayOf("float", val); + + try (PreparedStatement prep = + con.prepareStatement("INSERT INTO BinaryCodec(t0, t1) VALUES (?, ?)")) { + prep.setInt(1, 1); + prep.setObject(2, valArray, Types.ARRAY); + prep.execute(); + + prep.setInt(1, 2); + prep.setObject(2, val2, Types.ARRAY); + prep.execute(); + + prep.setInt(1, 3); + prep.setObject(2, val3, Types.ARRAY); + prep.execute(); + } + + try (PreparedStatement prep = + con.prepareStatement( + "SELECT * FROM BinaryCodec", + ResultSet.TYPE_SCROLL_INSENSITIVE, + ResultSet.CONCUR_UPDATABLE)) { + ResultSet rs = prep.executeQuery(); + assertTrue(rs.next()); + float[] res = (float[]) rs.getArray(2).getArray(); + assertArrayEquals(val, res); + Array arr = rs.getArray(2); + assertArrayEquals(val, (float[]) arr.getArray(1, 3)); + assertArrayEquals(new float[] {2, 3}, (float[]) arr.getArray(2, 2)); + assertArrayEquals(new float[] {1, 2}, (float[]) arr.getArray(1, 2)); + assertThrowsContains( + SQLException.class, + () -> arr.getArray(0, 2), + "Wrong index position. Is 0 but must be in 1-3 range"); + assertThrowsContains( + SQLException.class, + () -> arr.getArray(2, 20), + "Count value is too big. Count is 20 but cannot be > to 2"); + assertEquals("float[]", arr.getBaseTypeName()); + assertEquals(Types.FLOAT, arr.getBaseType()); + assertThrowsContains( + SQLException.class, + () -> arr.getArray(new HashMap<>()), + "getArray(Map> map) is not supported"); + assertThrowsContains( + SQLException.class, + () -> arr.getArray(1, 2, new HashMap<>()), + "getArray(long index, int count, Map> map) is not supported"); + + ResultSet rss = arr.getResultSet(); + assertTrue(rss.next()); + assertEquals(1, rss.getFloat(1)); + assertTrue(rss.next()); + assertEquals(2, rss.getFloat(1)); + assertTrue(rss.next()); + assertEquals(3, rss.getFloat(1)); + assertFalse(rss.next()); + + rss = arr.getResultSet(2, 2); + assertTrue(rss.next()); + assertEquals(2, rss.getFloat(1)); + assertTrue(rss.next()); + assertEquals(3, rss.getFloat(1)); + assertFalse(rss.next()); + assertThrowsContains( + SQLException.class, + () -> arr.getResultSet(new HashMap<>()), + "getResultSet(Map> map) is not supported"); + assertThrowsContains( + SQLException.class, + () -> arr.getResultSet(1, 2, new HashMap<>()), + "getResultSet(long index, int count, Map> map) is not supported"); + arr.free(); + assertTrue(rs.next()); + float[] res2 = rs.getObject(2, float[].class); + assertArrayEquals(val2, res2); + + assertTrue(rs.next()); + float[] res3 = rs.getObject(2, float[].class); + assertArrayEquals(val3, res3); + } + } + @Test public void floatArrayObjArray() throws SQLException { Statement stmt = sharedConn.createStatement();