diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index efd7ca74c796..f1cd3343b792 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -59,7 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
       """
        |INSERT INTO numbers VALUES (
        |0,
-       |255, 32767, 2147483647, 9223372036854775807,
+       |127, 32767, 2147483647, 9223372036854775807,
        |123456789012345.123456789012345, 123456789012345.123456789012345,
        |123456789012345.123456789012345,
        |123, 12345.12,
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     val types = row.toSeq.map(x => x.getClass.toString)
     assert(types.length == 12)
     assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(1).equals("class java.lang.Byte"))
     assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(10).equals("class java.math.BigDecimal"))
     assert(types(11).equals("class java.math.BigDecimal"))
     assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
+    assert(row.getByte(1) == 127)
     assert(row.getShort(2) == 32767)
     assert(row.getInt(3) == 2147483647)
     assert(row.getLong(4) == 9223372036854775807L)
@@ -202,4 +202,46 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
     df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
   }
+
+  test("SPARK-29644: Write tables with ShortType") {
+    import testImplicits._
+    val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32767.toShort).toDF("a")
+    val tablename = "shorttable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(0.toByte, 1.toByte, 38.toByte, 100.toByte, 127.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
+  }
 }
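Context for the 255 -> 127 change above: SQL Server's TINYINT is unsigned, with range 0..255, while java.lang.Byte is signed, with range -128..127. Once java.sql.Types.TINYINT maps to ByteType, 255 can no longer round-trip through a Java byte, so the seed data and assertions use 127 instead, and the ByteType write test sticks to non-negative values that fit both ranges. A minimal sketch of the narrowing behaviour (plain Scala, not part of the patch):

    // 255 fits SQL Server's unsigned TINYINT but not a signed Java byte;
    // a narrowing conversion silently wraps around.
    val wrapped: Byte = 255.toByte // -1
    val max: Byte = 127.toByte     // largest TINYINT value that survives the round trip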
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index 9cd5c4ec41a5..5b08093d930b 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -82,7 +82,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types.length == 9)
     assert(types(0).equals("class java.lang.Boolean"))
     assert(types(1).equals("class java.lang.Long"))
-    assert(types(2).equals("class java.lang.Integer"))
+    assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Integer"))
     assert(types(5).equals("class java.lang.Long"))
@@ -91,7 +91,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(8).equals("class java.lang.Double"))
     assert(rows(0).getBoolean(0) == false)
     assert(rows(0).getLong(1) == 0x225)
-    assert(rows(0).getInt(2) == 17)
+    assert(rows(0).getShort(2) == 17)
     assert(rows(0).getInt(3) == 77777)
     assert(rows(0).getInt(4) == 123456789)
     assert(rows(0).getLong(5) == 123456789012345L)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index c0f628ff0410..f19778f6f05f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -165,8 +165,8 @@ object JdbcUtils extends Logging {
       case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
       case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
       case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
-      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
-      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT))
       case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
       case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
       case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
@@ -230,7 +230,7 @@ object JdbcUtils extends Logging {
      case java.sql.Types.REF => StringType
      case java.sql.Types.REF_CURSOR => null
      case java.sql.Types.ROWID => LongType
-     case java.sql.Types.SMALLINT => IntegerType
+     case java.sql.Types.SMALLINT => ShortType
      case java.sql.Types.SQLXML => StringType
      case java.sql.Types.STRUCT => StringType
      case java.sql.Types.TIME => TimestampType
@@ -239,7 +239,7 @@ object JdbcUtils extends Logging {
      case java.sql.Types.TIMESTAMP => TimestampType
      case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
                                   => null
-     case java.sql.Types.TINYINT => IntegerType
+     case java.sql.Types.TINYINT => ByteType
      case java.sql.Types.VARBINARY => BinaryType
      case java.sql.Types.VARCHAR => StringType
      case _ =>
@@ -541,11 +541,11 @@ object JdbcUtils extends Logging {
 
     case ShortType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        stmt.setInt(pos + 1, row.getShort(pos))
+        stmt.setShort(pos + 1, row.getShort(pos))
 
     case ByteType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        stmt.setInt(pos + 1, row.getByte(pos))
+        stmt.setByte(pos + 1, row.getByte(pos))
 
     case BooleanType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
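The JdbcUtils hunks above are the core of the patch: getCommonJDBCType now emits SMALLINT/TINYINT DDL for ShortType/ByteType (the previous "BYTE" is not a valid column type on most databases), getCatalystType maps java.sql.Types.SMALLINT/TINYINT back to ShortType/ByteType instead of widening to IntegerType, and the setters call setShort/setByte so the driver sees the declared width. A minimal round-trip sketch under assumptions not in the patch (a local SparkSession, an in-memory H2 database with the H2 driver on the classpath, and an illustrative table name "t"):

    import java.util.Properties
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"
    // With this patch the column is created as SMALLINT rather than INTEGER...
    Seq(1.toShort, 2.toShort).toDF("a").write.jdbc(url, "t", new Properties)
    // ...and java.sql.Types.SMALLINT reads back as ShortType, not IntegerType.
    spark.read.jdbc(url, "t", new Properties).printSchema()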
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 2dcedc3fc1cc..348c1a749c97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -578,8 +578,8 @@ class JDBCSuite extends QueryTest
     assert(rows.length === 1)
     assert(rows(0).getInt(0) === 1)
     assert(rows(0).getBoolean(1) === false)
-    assert(rows(0).getInt(2) === 3)
-    assert(rows(0).getInt(3) === 4)
+    assert(rows(0).getByte(2) === 3.toByte)
+    assert(rows(0).getShort(3) === 4.toShort)
     assert(rows(0).getLong(4) === 1234567890123L)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index b751ec2de482..e8155f42d369 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -543,4 +543,46 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     }.getMessage
     assert(errMsg.contains("Statement was canceled or the session timed out"))
   }
+
+  test("SPARK-29644: Write tables with ShortType") {
+    import testImplicits._
+    val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32767.toShort).toDF("a")
+    val tablename = "shorttable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-128.toByte, 0.toByte, 1.toByte, 38.toByte, 127.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
+  }
 }
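Note that getCommonJDBCType is only the fallback: a JdbcDialect can still override the mapping per database via getJDBCType/getCatalystType, which is how a database whose TINYINT cannot hold negative values could keep ByteType on a wider column. A hedged sketch of such an override (the dialect and its URL prefix are hypothetical, not part of this patch; the JdbcDialect API itself is real):

    import java.sql.Types
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
    import org.apache.spark.sql.types._

    // Hypothetical dialect for a database with an unsigned TINYINT:
    // store ByteType in a SMALLINT column so negative bytes fit.
    object NoSignedTinyIntDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:notiny")
      override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
        case ByteType => Some(JdbcType("SMALLINT", Types.SMALLINT))
        case _ => None // defer to getCommonJDBCType
      }
    }

    JdbcDialects.registerDialect(NoSignedTinyIntDialect)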