@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLTimestamp
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -136,7 +137,7 @@ private[sql] object JDBCRDD extends Logging {
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
val columnType =
-dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+dialect.getCatalystType(dataType, typeName, fieldSize, fieldScale, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
@@ -324,12 +325,13 @@ private[sql] class JDBCRDD(
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
+case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion

/**
- * Maps a StructType to a type tag list.
+ * Maps a StructField and its associated DataType to a type tag.
*/
-def getConversions(schema: StructType): Array[JDBCConversion] = {
-schema.fields.map(sf => sf.dataType match {
+def getConversion(sf: StructField, dataType: DataType): JDBCConversion = {
+dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
@@ -341,8 +343,16 @@
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
+case ArrayType(d, _) => ArrayConversion(getConversion(sf, d))
case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
-}).toArray
+}
}

+/**
+ * Maps a StructType to a type tag list.
+ */
+def getConversions(schema: StructType): Array[JDBCConversion] = {
+schema.fields.map(sf => getConversion(sf, sf.dataType))
+}
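For illustration (not part of the patch), here is how a schema maps to conversion tags under the new `getConversion`/`getConversions` split, assuming a hypothetical three-column schema:

```scala
// Sketch: conversion tags produced for a hypothetical schema.
val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("tags", ArrayType(StringType)),
  StructField("price", DecimalType(10, 2))))

// getConversions(schema) would return:
//   Array(LongConversion, ArrayConversion(StringConversion), DecimalConversion(10, 2))
```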

/**
@@ -376,6 +386,10 @@
val conversions = getConversions(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

+def convert_date(dateVal: java.sql.Date): Int = DateTimeUtils.fromJavaDate(dateVal)
+def convert_decimal(decimal: java.math.BigDecimal, p: Int, s: Int): Decimal = Decimal(decimal, p, s)
+def convert_timestamp(ts: java.sql.Timestamp): SQLTimestamp = DateTimeUtils.fromJavaTimestamp(ts)

def getNext(): InternalRow = {
if (rs.next()) {
var i = 0
@@ -387,7 +401,7 @@
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos)
if (dateVal != null) {
-mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
+mutableRow.setInt(i, convert_date(dateVal))
} else {
mutableRow.update(i, null)
}
@@ -404,7 +418,7 @@
if (decimalVal == null) {
mutableRow.update(i, null)
} else {
-mutableRow.update(i, Decimal(decimalVal, p, s))
+mutableRow.update(i, convert_decimal(decimalVal, p, s))
}
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
@@ -415,21 +429,39 @@
case TimestampConversion =>
val t = rs.getTimestamp(pos)
if (t != null) {
-mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
+mutableRow.setLong(i, convert_timestamp(t))
} else {
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-case BinaryLongConversion => {
+case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
-j = j + 1;
+j = j + 1
}
mutableRow.setLong(i, ans)
-}

+case ArrayConversion(BinaryLongConversion) => throw new IllegalArgumentException(s"Unsupported array element conversion $i")
+case ArrayConversion(subConvert) =>
+val a = rs.getArray(pos)
+if (a != null) {
+val x = a.getArray
+val genericArrayData = new GenericArrayData(subConvert match {
+case TimestampConversion => x.asInstanceOf[Array[java.sql.Timestamp]].map(convert_timestamp)
+case StringConversion => x.asInstanceOf[Array[java.lang.String]].map(UTF8String.fromString)
+case DateConversion => x.asInstanceOf[Array[java.sql.Date]].map(convert_date)
+case DecimalConversion(p, s) => x.asInstanceOf[Array[java.math.BigDecimal]].map(convert_decimal(_, p, s))
+case ArrayConversion(_) => throw new IllegalArgumentException("Nested arrays unsupported")
+case _ => x.asInstanceOf[Array[Any]]
+})
+mutableRow.update(i, genericArrayData)
+} else {
+mutableRow.update(i, null)
+}

}
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
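As a standalone sketch of the read path above (plain JDBC, outside Spark; the database URL, table, and column names are hypothetical):

```scala
import java.sql.DriverManager

// A text[] column arrives as a java.sql.Array; getArray yields the boxed element
// array, which the conversion code above casts and converts element by element.
val conn = DriverManager.getConnection("jdbc:postgresql://localhost/test")
val rs = conn.createStatement().executeQuery("SELECT tags FROM posts")
while (rs.next()) {
  val a = rs.getArray(1)
  if (a != null) {
    val elems = a.getArray.asInstanceOf[Array[String]]
    println(elems.mkString(",")) // JDBCRDD instead wraps these in a GenericArrayData
  }
}
rs.close(); conn.close()
```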
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try

import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

@@ -92,7 +92,8 @@ object JdbcUtils extends Logging {
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
-batchSize: Int): Iterator[Byte] = {
+batchSize: Int,
+dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
var committed = false
try {
@@ -121,6 +122,24 @@
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
Review comment (Member): ISTM we need to check if input types are valid for target databases in advance, e.g., in JdbcUtils#saveTable. JdbcUtils#savePartition should simply write the input data as the given typed data.

Reply (Member, PR author): If the particular dialect does not support these types, saveTable should throw an exception when building the nullTypes array.

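A sketch of the behavior described in the reply (the field type is hypothetical; not from the patch): the write fails while saveTable builds nullTypes, before savePartition opens any connection.

```scala
// Assuming a field type that neither the dialect nor getCommonJDBCType can map:
val unsupported = StructField("v", CalendarIntervalType)
// dialect.getJDBCType(unsupported.dataType)       => None
// dialect.getCommonJDBCType(unsupported.dataType) => None
// saveTable => IllegalArgumentException: Can't translate null value for field ...
```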
+case ArrayType(elemType, _) =>
+val elemDataBaseType = dialect.getJDBCType(elemType)
+.map(_.databaseTypeDefinition)
+.getOrElse(
+dialect.getCommonJDBCType(elemType).map(_.databaseTypeDefinition).getOrElse(
+throw new IllegalArgumentException(
+s"Can't determine array element type for $elemType in field $i")
+))
+val array: Array[AnyRef] = elemType match {
+case _: ArrayType =>
+throw new IllegalArgumentException(
+s"Nested array writes to JDBC are not supported for field $i")
+case BinaryType => row.getSeq[Array[Byte]](i).toArray
+case TimestampType => row.getSeq[java.sql.Timestamp](i).toArray
+case DateType => row.getSeq[java.sql.Date](i).toArray
+case _ => row.getSeq[AnyRef](i).toArray
+}
+stmt.setArray(i + 1, conn.createArrayOf(elemDataBaseType, array))
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
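For reference, the same write path as plain JDBC (a sketch; the connection and table are hypothetical, and "TEXT" assumes the Postgres element mapping shown later in this diff):

```scala
// createArrayOf takes the database's element type name and a boxed object array,
// mirroring what savePartition does with elemDataBaseType above.
val stmt = conn.prepareStatement("INSERT INTO posts (tags) VALUES (?)")
val tags: Array[AnyRef] = Array("spark", "jdbc")
stmt.setArray(1, conn.createArrayOf("TEXT", tags))
stmt.executeUpdate()
stmt.close()
```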
@@ -171,21 +190,9 @@ object JdbcUtils extends Logging {
val name = field.name
val typ: String =
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
-field.dataType match {
-case IntegerType => "INTEGER"
-case LongType => "BIGINT"
-case DoubleType => "DOUBLE PRECISION"
-case FloatType => "REAL"
-case ShortType => "INTEGER"
-case ByteType => "BYTE"
-case BooleanType => "BIT(1)"
-case StringType => "TEXT"
-case BinaryType => "BLOB"
-case TimestampType => "TIMESTAMP"
-case DateType => "DATE"
-case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
-case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-})
+dialect.getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+))
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -203,30 +210,18 @@
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
-field.dataType match {
-case IntegerType => java.sql.Types.INTEGER
-case LongType => java.sql.Types.BIGINT
-case DoubleType => java.sql.Types.DOUBLE
-case FloatType => java.sql.Types.REAL
-case ShortType => java.sql.Types.INTEGER
-case ByteType => java.sql.Types.INTEGER
-case BooleanType => java.sql.Types.BIT
-case StringType => java.sql.Types.CLOB
-case BinaryType => java.sql.Types.BLOB
-case TimestampType => java.sql.Types.TIMESTAMP
-case DateType => java.sql.Types.DATE
-case t: DecimalType => java.sql.Types.DECIMAL
-case _ => throw new IllegalArgumentException(
+dialect.getCommonJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
+throw new IllegalArgumentException(
s"Can't translate null value for field $field")
-})
+))
}

val rddSchema = df.schema
val driver: String = DriverRegistry.getDriverClassName(url)
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
-savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
}
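Usage sketch (hypothetical URL and table): the batch size is read from the connection properties, defaulting to 1000 as above.

```scala
val props = new java.util.Properties()
props.setProperty("batchsize", "5000") // overrides the default of 1000
df.write.jdbc("jdbc:postgresql://localhost/test", "posts", props)
```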

@@ -26,8 +26,8 @@ private object DerbyDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")

-override def getCatalystType(
-sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.REAL) Option(FloatType) else None
}

@@ -65,12 +65,13 @@ abstract class JdbcDialect {
* @param sqlType The sql type (see java.sql.Types)
* @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
* @param size The size of the type.
+ * @param scale The scale of the type. Generally used for decimal types.
* @param md Result metadata associated with this type.
* @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
* or null if the default type mapping should be used.
*/
-def getCatalystType(
-sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
+def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+md: MetadataBuilder): Option[DataType] = None

/**
* Retrieve the jdbc / sql type for a given datatype.
@@ -79,6 +80,30 @@
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None

+/**
+ * Retrieve standard jdbc types.
+ * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+ * @return The default JdbcType for this DataType
+ */
+def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+dt match {
+case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+case t: DecimalType => Option(
+JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+case _ => None
+}
+}
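To illustrate the intended contract (a sketch with a hypothetical dialect): callers try getJDBCType first and fall back to getCommonJDBCType, so a dialect only needs to override the types it treats specially.

```scala
private object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // Override one mapping; everything else resolves via the shared defaults
  // because callers fall back to getCommonJDBCType on None.
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("VARCHAR(4000)", java.sql.Types.VARCHAR))
    case _ => None
  }
}
```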

/**
* Quotes the identifier. This is used to put quotes around the identifier in case the column
* name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
@@ -24,8 +24,8 @@ private object MsSqlServerDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver")

-override def getCatalystType(
-sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+md: MetadataBuilder): Option[DataType] = {
if (typeName.contains("datetimeoffset")) {
// String is recommended by Microsoft SQL Server for datetimeoffset types in non-MS clients
Option(StringType)
@@ -26,8 +26,8 @@ private case object MySQLDialect extends JdbcDialect {

override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")

-override def getCatalystType(
-sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
// This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
// byte arrays instead of longs.
@@ -26,8 +26,8 @@ private case object OracleDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")

-override def getCatalystType(
-sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+md: MetadataBuilder): Option[DataType] = {
// Handle NUMBER fields that have no precision/scale in a special way
// because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
// For more details, please see
@@ -26,8 +26,8 @@ private object PostgresDialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

-override def getCatalystType(
-sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+override def getCatalystType(sqlType: Int, typeName: String, size: Int, scale: Int,
+md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
Option(BinaryType)
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
@@ -38,13 +38,39 @@ private object PostgresDialect extends JdbcDialect {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("uuid")) {
Some(StringType)
} else if (sqlType == Types.ARRAY) {
typeName match {
case "_bit" | "_bool" => Option(ArrayType(BooleanType))
case "_int2" => Option(ArrayType(ShortType))
case "_int4" => Option(ArrayType(IntegerType))
case "_int8" | "_oid" => Option(ArrayType(LongType))
case "_float4" => Option(ArrayType(FloatType))
case "_money" | "_float8" => Option(ArrayType(DoubleType))
case "_text" | "_varchar" | "_char" | "_bpchar" | "_name" => Option(ArrayType(StringType))
case "_bytea" => Option(ArrayType(BinaryType))
case "_timestamp" | "_timestamptz" | "_time" | "_timetz" => Option(ArrayType(TimestampType))
case "_date" => Option(ArrayType(DateType))
case "_numeric"
if size != 0 || scale != 0 => Option(ArrayType(DecimalType(size, scale)))
case "_numeric" => Option(ArrayType(DecimalType.SYSTEM_DEFAULT))
case _ => throw new IllegalArgumentException(s"Unhandled postgres array type $typeName")
}
} else None
}
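End to end, the mapping above would surface as follows (a sketch; the database and its text[] column are hypothetical, using the Spark 1.x read API):

```scala
val df = sqlContext.read.jdbc(
  "jdbc:postgresql://localhost/test", "posts", new java.util.Properties)
df.printSchema()
// root
//  |-- tags: array (nullable = true)
//  |    |-- element: string (containsNull = true)
```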

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+case ArrayType(t, _) =>
+val subtype = getJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
+getCommonJDBCType(t).map(_.databaseTypeDefinition).getOrElse(
+throw new IllegalArgumentException(s"Unexpected JDBC array subtype $t")
+)
+)
+Some(JdbcType(s"$subtype[]", java.sql.Types.ARRAY))
case _ => None
}
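In the write direction, the ArrayType case above appends `[]` to the resolved element type (illustrative values):

```scala
// StringType is overridden above to "TEXT"; IntegerType falls through to
// getCommonJDBCType, which yields "INTEGER".
// PostgresDialect.getJDBCType(ArrayType(StringType))
//   => Some(JdbcType("TEXT[]", java.sql.Types.ARRAY))
// PostgresDialect.getJDBCType(ArrayType(IntegerType))
//   => Some(JdbcType("INTEGER[]", java.sql.Types.ARRAY))
```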
