diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 018a009fbda6..1a3f88945be2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -26,6 +26,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -136,7 +137,7 @@ private[sql] object JDBCRDD extends Logging {
         val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
         val metadata = new MetadataBuilder().putString("name", columnName)
         val columnType =
-          dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+          dialect.getCatalystType(dataType, typeName, fieldSize, fieldScale, metadata).getOrElse(
             getCatalystType(dataType, fieldSize, fieldScale, isSigned))
         fields(i) = StructField(columnName, columnType, nullable, metadata.build())
         i = i + 1
@@ -324,12 +325,13 @@ private[sql] class JDBCRDD(
   case object StringConversion extends JDBCConversion
   case object TimestampConversion extends JDBCConversion
   case object BinaryConversion extends JDBCConversion
+  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
 
   /**
-   * Maps a StructType to a type tag list.
+   * Maps a StructField and its associated DataType to a type tag.
    */
-  def getConversions(schema: StructType): Array[JDBCConversion] = {
-    schema.fields.map(sf => sf.dataType match {
+  def getConversion(sf: StructField, dataType: DataType): JDBCConversion = {
+    dataType match {
       case BooleanType => BooleanConversion
       case DateType => DateConversion
       case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
@@ -341,8 +343,16 @@ private[sql] class JDBCRDD(
       case StringType => StringConversion
       case TimestampType => TimestampConversion
       case BinaryType => BinaryConversion
+      case ArrayType(d, _) => ArrayConversion(getConversion(sf, d))
       case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
-    }).toArray
+    }
+  }
+
+  /**
+   * Maps a StructType to a type tag list.
+   */
+  def getConversions(schema: StructType): Array[JDBCConversion] = {
+    schema.fields.map(sf => getConversion(sf, sf.dataType))
   }
 
   /**
@@ -376,6 +386,12 @@ private[sql] class JDBCRDD(
     val conversions = getConversions(schema)
     val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
 
+    def convert_date(dateVal: java.sql.Date): Int = DateTimeUtils.fromJavaDate(dateVal)
+    def convert_decimal(decimal: java.math.BigDecimal, p: Int, s: Int): Decimal =
+      Decimal(decimal, p, s)
+    def convert_timestamp(ts: java.sql.Timestamp): SQLTimestamp =
+      DateTimeUtils.fromJavaTimestamp(ts)
+
     def getNext(): InternalRow = {
       if (rs.next()) {
         var i = 0
@@ -387,7 +401,7 @@ private[sql] class JDBCRDD(
               // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
               val dateVal = rs.getDate(pos)
               if (dateVal != null) {
-                mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
+                mutableRow.setInt(i, convert_date(dateVal))
               } else {
                 mutableRow.update(i, null)
               }
@@ -404,7 +418,7 @@ private[sql] class JDBCRDD(
               if (decimalVal == null) {
                 mutableRow.update(i, null)
               } else {
-                mutableRow.update(i, Decimal(decimalVal, p, s))
+                mutableRow.update(i, convert_decimal(decimalVal, p, s))
               }
             case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
             case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
@@ -415,21 +429,43 @@ private[sql] class JDBCRDD(
            case TimestampConversion =>
              val t = rs.getTimestamp(pos)
              if (t != null) {
-              mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
+              mutableRow.setLong(i, convert_timestamp(t))
              } else {
                mutableRow.update(i, null)
              }
            case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-           case BinaryLongConversion => {
+           case BinaryLongConversion =>
              val bytes = rs.getBytes(pos)
              var ans = 0L
              var j = 0
              while (j < bytes.size) {
                ans = 256 * ans + (255 & bytes(j))
-               j = j + 1;
+               j = j + 1
              }
              mutableRow.setLong(i, ans)
-           }
+
+           case ArrayConversion(BinaryLongConversion) =>
+             throw new IllegalArgumentException(s"Unsupported array element conversion $i")
+           case ArrayConversion(subConvert) =>
+             val a = rs.getArray(pos)
+             if (a != null) {
+               val x = a.getArray
+               val genericArrayData = new GenericArrayData(subConvert match {
+                 case TimestampConversion =>
+                   x.asInstanceOf[Array[java.sql.Timestamp]].map(convert_timestamp)
+                 case StringConversion =>
+                   x.asInstanceOf[Array[java.lang.String]].map(UTF8String.fromString)
+                 case DateConversion => x.asInstanceOf[Array[java.sql.Date]].map(convert_date)
+                 case DecimalConversion(p, s) =>
+                   x.asInstanceOf[Array[java.math.BigDecimal]].map(convert_decimal(_, p, s))
+                 case ArrayConversion(_) =>
+                   throw new IllegalArgumentException("Nested arrays unsupported")
+                 case _ => x.asInstanceOf[Array[Any]]
+               })
+               mutableRow.update(i, genericArrayData)
+             } else {
+               mutableRow.update(i, null)
+             }
          }
          if (rs.wasNull) mutableRow.setNullAt(i)
          i = i + 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index f89d55b20e21..9d7bf03422c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 
 import scala.util.Try
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -92,7 +92,8 @@ object JdbcUtils extends Logging {
       iterator: Iterator[Row],
       rddSchema: StructType,
       nullTypes: Array[Int],
-      batchSize: Int): Iterator[Byte] = {
+      batchSize: Int,
+      dialect: JdbcDialect): Iterator[Byte] = {
     val conn = getConnection()
     var committed = false
     try {
@@ -121,6 +122,24 @@ object JdbcUtils extends Logging {
                 case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
                 case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
                 case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
+                case ArrayType(elemType, _) =>
+                  val elemDataBaseType = dialect.getJDBCType(elemType)
+                    .map(_.databaseTypeDefinition)
+                    .getOrElse(
+                      dialect.getCommonJDBCType(elemType).map(_.databaseTypeDefinition).getOrElse(
+                        throw new IllegalArgumentException(
+                          s"Can't determine array element type for $elemType in field $i")
+                      ))
+                  val array: Array[AnyRef] = elemType match {
+                    case _: ArrayType =>
+                      throw new IllegalArgumentException(
+                        s"Nested array writes to JDBC are not supported for field $i")
+                    case BinaryType => row.getSeq[Array[Byte]](i).toArray
+                    case TimestampType => row.getSeq[java.sql.Timestamp](i).toArray
+                    case DateType => row.getSeq[java.sql.Date](i).toArray
+                    case _ => row.getSeq[AnyRef](i).toArray
+                  }
+                  stmt.setArray(i + 1, conn.createArrayOf(elemDataBaseType, array))
                 case _ => throw new IllegalArgumentException(
                   s"Can't translate non-null value for field $i")
               }
@@ -171,21 +190,9 @@ object JdbcUtils extends Logging {
       val name = field.name
       val typ: String =
         dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
-          field.dataType match {
-            case IntegerType => "INTEGER"
-            case LongType => "BIGINT"
-            case DoubleType => "DOUBLE PRECISION"
-            case FloatType => "REAL"
-            case ShortType => "INTEGER"
-            case ByteType => "BYTE"
-            case BooleanType => "BIT(1)"
-            case StringType => "TEXT"
-            case BinaryType => "BLOB"
-            case TimestampType => "TIMESTAMP"
-            case DateType => "DATE"
-            case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
-            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-          })
+          dialect.getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+            throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+          ))
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
     }}
@@ -203,30 +210,18 @@ object JdbcUtils extends Logging {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
-        field.dataType match {
-          case IntegerType => java.sql.Types.INTEGER
-          case LongType => java.sql.Types.BIGINT
-          case DoubleType => java.sql.Types.DOUBLE
-          case FloatType => java.sql.Types.REAL
-          case ShortType => java.sql.Types.INTEGER
-          case ByteType => java.sql.Types.INTEGER
-          case BooleanType => java.sql.Types.BIT
-          case StringType => java.sql.Types.CLOB
-          case BinaryType => java.sql.Types.BLOB
-          case TimestampType => java.sql.Types.TIMESTAMP
-          case DateType => java.sql.Types.DATE
-          case t: DecimalType => java.sql.Types.DECIMAL
-          case _ => throw new IllegalArgumentException(
+        dialect.getCommonJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
+          throw new IllegalArgumentException(
             s"Can't translate null value for field $field")
-        })
-    }
+        ))
+    }
     val rddSchema = df.schema
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 84f68e779c38..7403b4498282 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -26,8 +26,8 @@ private object DerbyDialect extends JdbcDialect {
 
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")
 
-  override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.REAL) Option(FloatType) else None
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 14bfea4e3e28..c102158ed6e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -65,12 +65,13 @@ abstract class JdbcDialect {
    * @param sqlType The sql type (see java.sql.Types)
    * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
    * @param size The size of the type.
+   * @param scale The scale of the type. Generally used for decimal types.
    * @param md Result metadata associated with this type.
    * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
    *         or null if the default type mapping should be used.
    */
-  def getCatalystType(
-    sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
+  def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+    md: MetadataBuilder): Option[DataType] = None
 
   /**
    * Retrieve the jdbc / sql type for a given datatype.
    * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
    * @return The sql type
@@ -79,6 +80,30 @@ abstract class JdbcDialect {
    */
   def getJDBCType(dt: DataType): Option[JdbcType] = None
 
+  /**
+   * Retrieve standard jdbc types.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
+
   /**
    * Quotes the identifier. This is used to put quotes around the identifier in case the column
    * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 3eb722b070d5..41cabf63534a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -24,8 +24,8 @@ private object MsSqlServerDialect extends JdbcDialect {
 
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver")
 
-  override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
     if (typeName.contains("datetimeoffset")) {
       // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients
       Option(StringType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index da413ed1f08b..4bf329c757ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -26,8 +26,8 @@ private case object MySQLDialect extends JdbcDialect {
 
   override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
 
-  override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
       // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
       // byte arrays instead of longs.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 4165c382689f..828ad55bbbe3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -26,8 +26,8 @@ private case object OracleDialect extends JdbcDialect {
 
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
 
-  override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
     // Handle NUMBER fields that have no precision/scale in special way
     // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale
     // For more details, please see
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index e701a7fcd9e1..bb34578f794b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -26,8 +26,8 @@ private object PostgresDialect extends JdbcDialect {
 
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
 
-  override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+  override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
       Option(BinaryType)
     } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
@@ -38,6 +38,24 @@ private object PostgresDialect extends JdbcDialect {
       Option(StringType)
     } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
       Option(StringType)
+    } else if (sqlType == Types.OTHER && typeName.equals("uuid")) {
+      Some(StringType)
+    } else if (sqlType == Types.ARRAY) {
+      typeName match {
+        case "_bit" | "_bool" => Option(ArrayType(BooleanType))
+        case "_int2" => Option(ArrayType(ShortType))
+        case "_int4" => Option(ArrayType(IntegerType))
+        case "_int8" | "_oid" => Option(ArrayType(LongType))
+        case "_float4" => Option(ArrayType(FloatType))
+        case "_money" | "_float8" => Option(ArrayType(DoubleType))
+        case "_text" | "_varchar" | "_char" | "_bpchar" | "_name" => Option(ArrayType(StringType))
+        case "_bytea" => Option(ArrayType(BinaryType))
+        case "_timestamp" | "_timestamptz" | "_time" | "_timetz" => Option(ArrayType(TimestampType))
+        case "_date" => Option(ArrayType(DateType))
+        case "_numeric" if size != 0 || scale != 0 => Option(ArrayType(DecimalType(size, scale)))
+        case "_numeric" => Option(ArrayType(DecimalType.SYSTEM_DEFAULT))
+        case _ => throw new IllegalArgumentException(s"Unhandled postgres array type $typeName")
+      }
     } else None
   }
 
@@ -45,6 +64,13 @@ private object PostgresDialect extends JdbcDialect {
     case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
     case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
     case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+    case ArrayType(t, _) =>
+      val subtype = getJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
+        getCommonJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
+          throw new IllegalArgumentException(s"Unexpected JDBC array subtype $t")
+        )
+      )
+      Some(JdbcType(s"$subtype[]", java.sql.Types.ARRAY))
     case _ => None
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d530b1a469ce..bc91d2bdadf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -40,8 +40,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
 
   val testH2Dialect = new JdbcDialect {
     override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
-    override def getCatalystType(
-        sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+    override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+        md: MetadataBuilder): Option[DataType] =
       Some(StringType)
   }
 
@@ -437,7 +437,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     val agg = new AggregatedDialect(List(new JdbcDialect {
       override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
       override def getCatalystType(
-          sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+          sqlType: Int, typeName: String, size: Int, scale: Int,
+          md: MetadataBuilder): Option[DataType] =
         if (sqlType % 2 == 0) {
           Some(LongType)
         } else {
@@ -446,8 +447,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     }, testH2Dialect))
     assert(agg.canHandle("jdbc:h2:xxx"))
     assert(!agg.canHandle("jdbc:h2"))
-    assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
-    assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
+    assert(agg.getCatalystType(0, "", 1, 0, null) === Some(LongType))
+    assert(agg.getCatalystType(1, "", 1, 0, null) === Some(StringType))
   }
 
   test("DB2Dialect type mapping") {
@@ -458,8 +459,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
 
   test("PostgresDialect type mapping") {
     val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
-    assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType))
-    assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType))
+    assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, 0, null) === Some(StringType))
+    assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, 0, null) === Some(StringType))
   }
 
   test("DerbyDialect jdbc type mapping") {
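Note: a minimal usage sketch of what this patch enables (not part of the diff above; the Postgres URL, the table "people", and its text[] column "tags" are hypothetical names for illustration):

    // Reading: PostgresDialect maps the "_text" array type to ArrayType(StringType), and
    // JDBCRDD converts each java.sql.Array into a GenericArrayData, so the column surfaces
    // as array<string> in the DataFrame schema.
    val props = new java.util.Properties()
    val df = sqlContext.read.jdbc("jdbc:postgresql://127.0.0.1/db", "people", props)
    df.printSchema() // shows tags: array (element: string)

    // Writing: savePartition binds ArrayType columns through conn.createArrayOf, and
    // schemaString/getJDBCType emit e.g. TEXT[] in the generated CREATE TABLE statement.
    df.write.jdbc("jdbc:postgresql://127.0.0.1/db", "people_copy", props)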