@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.Connection
import java.util.Properties

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Literal, If}
import org.apache.spark.tags.DockerTest

@DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+ "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+ "c10 integer[], c11 text[])").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+ """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
}

test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 10)
assert(types(0).equals("class java.lang.String"))
assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Double"))
assert(types(3).equals("class java.lang.Long"))
assert(types(4).equals("class java.lang.Boolean"))
assert(types(5).equals("class [B"))
assert(types(6).equals("class [B"))
assert(types(7).equals("class java.lang.Boolean"))
assert(types(8).equals("class java.lang.String"))
assert(types(9).equals("class java.lang.String"))
val types = rows(0).toSeq.map(x => x.getClass)
assert(types.length == 12)
assert(classOf[String].isAssignableFrom(types(0)))
assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
assert(classOf[String].isAssignableFrom(types(8)))
assert(classOf[String].isAssignableFrom(types(9)))
assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getBoolean(7) == true)
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
assert(rows(0).getSeq(10) == Seq(1, 2))
assert(rows(0).getSeq(11) == Seq("a", null, "b"))
}

test("Basic write test") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test only that it doesn't crash.
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test write null values.
df.select(df.queryExecution.analyzed.output.map { a =>
Column(If(Literal(true), Literal(null), a)).as(a.name)
Review comment (Contributor): Is this just Literal.create(null, a.dataType)?
Reply (Author): ah yeah, we can simplify this.

}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
}
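
Following up on the review thread above, a minimal sketch of the suggested simplification (hypothetical, not part of this patch; it assumes Literal.create from org.apache.spark.sql.catalyst.expressions is imported in place of If):

// Sketch only: build a typed null literal per output column instead of
// wrapping each attribute in If(Literal(true), Literal(null), a).
df.select(df.queryExecution.analyzed.output.map { a =>
  Column(Literal.create(null, a.dataType)).as(a.name)
}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)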
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -324,25 +324,27 @@ private[sql] class JDBCRDD(
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion

/**
* Maps a StructType to a type tag list.
*/
def getConversions(schema: StructType): Array[JDBCConversion] = {
schema.fields.map(sf => sf.dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType =>
if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
}).toArray
def getConversions(schema: StructType): Array[JDBCConversion] =
schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))

private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
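
As an illustration of the mapping above (not from the patch; getConversions is an instance method of JDBCRDD, so this is a behavioral sketch rather than standalone code):

// Sketch only: conversion tags produced for a simple schema, with the new
// ArrayConversion wrapping the element type's conversion.
val conversions = getConversions(StructType(Seq(
  StructField("name", StringType),
  StructField("tags", ArrayType(StringType)))))
// => Array(StringConversion, ArrayConversion(StringConversion))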

/**
@@ -420,16 +422,44 @@ private[sql] class JDBCRDD(
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion => {
case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1;
j = j + 1
}
mutableRow.setLong(i, ans)
}
case ArrayConversion(elementConversion) =>
val array = rs.getArray(pos).getArray
if (array != null) {
val data = elementConversion match {
case TimestampConversion =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringConversion =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateConversion =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case DecimalConversion(p, s) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
}
case BinaryLongConversion =>
throw new IllegalArgumentException(s"Unsupported array element conversion $i")
case _: ArrayConversion =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => array.asInstanceOf[Array[Any]]
}
mutableRow.update(i, new GenericArrayData(data))
} else {
mutableRow.update(i, null)
}
}
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
@@ -488,4 +518,12 @@ private[sql] class JDBCRDD(
nextValue
}
}

private def nullSafeConvert[T](input: T, f: T => Any): Any = {
if (input == null) {
null
} else {
f(input)
}
}
}
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try

import org.apache.spark.Logging
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

@@ -72,6 +72,35 @@ object JdbcUtils extends Logging {
conn.prepareStatement(sql.toString())
}

/**
* Retrieve standard jdbc types.
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The default JdbcType for this DataType
*/
def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
dt match {
case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
case t: DecimalType => Option(
JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
case _ => None
}
}
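
For reference, a short usage sketch of the fallback mapping (illustrative only, not part of the patch):

// Sketch only: common JDBC type lookups and the unsupported case.
import org.apache.spark.sql.types.{LongType, NullType, StringType}
val bigint = JdbcUtils.getCommonJDBCType(LongType)   // Some(JdbcType("BIGINT", java.sql.Types.BIGINT))
val text   = JdbcUtils.getCommonJDBCType(StringType) // Some(JdbcType("TEXT", java.sql.Types.CLOB))
val none   = JdbcUtils.getCommonJDBCType(NullType)   // None: getJdbcType below then throws IllegalArgumentException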

private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}

/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction in order to avoid repeatedly inserting
@@ -92,7 +121,8 @@ object JdbcUtils extends Logging {
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
batchSize: Int): Iterator[Byte] = {
batchSize: Int,
dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
var committed = false
try {
@@ -121,6 +151,11 @@ object JdbcUtils extends Logging {
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case ArrayType(et, _) =>
val array = conn.createArrayOf(
getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
row.getSeq[AnyRef](i).toArray)
stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
@@ -169,23 +204,7 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
val typ: String =
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
case DoubleType => "DOUBLE PRECISION"
case FloatType => "REAL"
case ShortType => "INTEGER"
case ByteType => "BYTE"
case BooleanType => "BIT(1)"
case StringType => "TEXT"
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
})
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -202,31 +221,15 @@ object JdbcUtils extends Logging {
properties: Properties = new Properties()) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case DoubleType => java.sql.Types.DOUBLE
case FloatType => java.sql.Types.REAL
case ShortType => java.sql.Types.INTEGER
case ByteType => java.sql.Types.INTEGER
case BooleanType => java.sql.Types.BIT
case StringType => java.sql.Types.CLOB
case BinaryType => java.sql.Types.BLOB
case TimestampType => java.sql.Types.TIMESTAMP
case DateType => java.sql.Types.DATE
case t: DecimalType => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
})
getJdbcType(field.dataType, dialect).jdbcNullType
}

val rddSchema = df.schema
val driver: String = DriverRegistry.getDriverClassName(url)
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
}

@@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
* for the given Catalyst type.
*/
@DeveloperApi
abstract class JdbcDialect {
abstract class JdbcDialect extends Serializable {
/**
* Check if this dialect instance can handle a certain jdbc url.
* @param url the jdbc url.
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc

import java.sql.Types

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._


@@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect {
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
Option(BinaryType)
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("inet")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("json")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
Option(StringType)
Some(BinaryType)
} else if (sqlType == Types.OTHER) {
toCatalystType(typeName).filter(_ == StringType)
} else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
toCatalystType(typeName.drop(1)).map(ArrayType(_))
} else None
}

// TODO: support more type names.
private def toCatalystType(typeName: String): Option[DataType] = typeName match {
case "bool" => Some(BooleanType)
case "bit" => Some(BinaryType)
case "int2" => Some(ShortType)
case "int4" => Some(IntegerType)
case "int8" | "oid" => Some(LongType)
case "float4" => Some(FloatType)
case "money" | "float8" => Some(DoubleType)
case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
Some(StringType)
case "bytea" => Some(BinaryType)
case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
case "date" => Some(DateType)
case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
case _ => None
}
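
To make the read-side mapping concrete, a hypothetical example (not from the patch; Postgres reports array columns with a leading underscore in the type name, and PostgresDialect is package-private, so this is a behavioral sketch):

// Sketch only: resolving column types during schema inference.
// integer[] arrives as java.sql.Types.ARRAY with typeName "_int4": the leading
// '_' is dropped and the element name is mapped through toCatalystType.
PostgresDialect.getCatalystType(java.sql.Types.ARRAY, "_int4", 0, new MetadataBuilder())
// => Some(ArrayType(IntegerType))
PostgresDialect.getCatalystType(java.sql.Types.OTHER, "inet", 0, new MetadataBuilder())
// => Some(StringType)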

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
case StringType => Some(JdbcType("TEXT", Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
getJDBCType(et).map(_.databaseTypeDefinition)
.orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
.map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
case _ => None
}
