diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index efd7ca74c796..604d2ce55d62 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     val types = row.toSeq.map(x => x.getClass.toString)
     assert(types.length == 12)
     assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(1).equals("class java.lang.Byte"))
     assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(10).equals("class java.math.BigDecimal"))
     assert(types(11).equals("class java.math.BigDecimal"))
     assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
+    assert(row.getByte(1) == 255.toByte)
     assert(row.getShort(2) == 32767)
     assert(row.getInt(3) == 2147483647)
     assert(row.getLong(4) == 9223372036854775807L)
@@ -202,4 +202,25 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
     df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
   }
+
+  test("Write tables with BYTETYPE") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 86a27b5afc25..5c4a38641f62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -550,7 +550,7 @@ object JdbcUtils extends Logging {
 
     case ByteType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        stmt.setInt(pos + 1, row.getByte(pos))
+        stmt.setByte(pos + 1, row.getByte(pos))
 
     case BooleanType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 805f73dee141..d770cb19fb49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -33,6 +33,7 @@ private object MsSqlServerDialect extends JdbcDialect {
       sqlType match {
         case java.sql.Types.SMALLINT => Some(ShortType)
         case java.sql.Types.REAL => Some(FloatType)
+        case java.sql.Types.TINYINT => Some(ByteType)
         case _ => None
       }
     }
@@ -44,6 +45,7 @@ private object MsSqlServerDialect extends JdbcDialect {
     case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT))
     case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY))
     case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
+    case ByteType => Some(JdbcType("TINYINT", java.sql.Types.TINYINT))
     case _ => None
   }
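
Reviewer note on the `255.toByte` assertion above: SQL Server's TINYINT is an unsigned type (0 to 255), while Spark's ByteType is a signed two's-complement byte (-128 to 127), so a stored 255 reads back as -1. Below is a minimal sketch of that wraparound, using only the Scala standard library; the object name is made up for illustration and is not part of this change.

// Sketch: why the integration test compares row.getByte(1) against 255.toByte.
// SQL Server TINYINT holds 0..255; Spark's ByteType is a signed 8-bit value,
// so the bit pattern 0xFF comes back as -1 when narrowed to Byte.
object TinyIntWraparound extends App {
  val tinyIntMax = 255                // largest value SQL Server's TINYINT can hold
  val narrowed: Byte = tinyIntMax.toByte
  println(narrowed)                   // prints -1 (0xFF reinterpreted as signed)
  assert(narrowed == 255.toByte)      // the same comparison the test makes
  println(narrowed & 0xFF)            // prints 255: recovering the unsigned value
}

The same reasoning explains 128.toByte in the new "Write tables with BYTETYPE" test: it narrows to -128, so the test deliberately exercises values at and beyond the signed-byte boundary.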