diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index b23e5a7722004..586ae933057be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
+import java.text.SimpleDateFormat
+
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.Partition
@@ -31,6 +33,7 @@ import org.apache.spark.sql.types.StructType
  * Instructions on how to partition the table among workers.
  */
 private[sql] case class JDBCPartitioningInfo(
+    columnType: Int,
     column: String,
     lowerBound: Long,
     upperBound: Long,
@@ -79,13 +82,14 @@ private[sql] object JDBCRelation extends Logging {
     // Here we get a little roundoff, but that's (hopefully) OK.
     val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
     val column = partitioning.column
+    val columnType = partitioning.columnType
     var i: Int = 0
     var currentValue: Long = lowerBound
     val ans = new ArrayBuffer[Partition]()
     while (i < numPartitions) {
-      val lBound = if (i != 0) s"$column >= $currentValue" else null
+      val lBound = if (i != 0) s"$column >= ${getCurrentValue(columnType, currentValue)}" else null
       currentValue += stride
-      val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val uBound = if (i != numPartitions - 1) s"$column < ${getCurrentValue(columnType, currentValue)}" else null
       val whereClause =
         if (uBound == null) {
           lBound
@@ -99,6 +103,17 @@
     }
     ans.toArray
   }
+
+  // Quote DATE/TIMESTAMP boundaries as timestamp literals; other types use the raw value.
+  def getCurrentValue(columnType: Int, value: Long): String = {
+    if (columnType == java.sql.Types.DATE || columnType == java.sql.Types.TIMESTAMP) {
+      val ts = new java.sql.Timestamp(value)
+      val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+      "'" + sdf.format(ts) + "'"
+    } else {
+      value.toString
+    }
+  }
 }
 
 private[sql] case class JDBCRelation(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index cc506e51bd0c6..2106a71412d16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -45,13 +45,52 @@ class JdbcRelationProvider extends CreatableRelationProvider
       assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
         s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
           s"'$JDBC_NUM_PARTITIONS' are also required")
-      JDBCPartitioningInfo(
+      JDBCPartitioningInfo(resolvePartitionColumnType(parameters),
         partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
     }
     val parts = JDBCRelation.columnPartition(partitionInfo)
     JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
   }
 
+  // Resolve the JDBC type of the partition column from the table metadata (-1 if it fails).
+  def resolvePartitionColumnType(parameters: Map[String, String]): Int = {
+    val options = new JDBCOptions(parameters)
+
+    val conn = JdbcUtils.createConnectionFactory(options)()
+
+    val partitionColumn = options.partitionColumn
+    val table = options.table
+
+    var stmt: java.sql.PreparedStatement = null
+    var rs: java.sql.ResultSet = null
+    try {
+      val resolveSql = s"select * from ($table) resolveTable where 1=0"
+      stmt = conn.prepareStatement(resolveSql)
+      rs = stmt.executeQuery()
+      val rsmd = rs.getMetaData
+      var partitionColumnType = -1
+      for (i <- 0 until rsmd.getColumnCount) {
+        if (rsmd.getColumnName(i + 1).equals(partitionColumn)) {
+          partitionColumnType = rsmd.getColumnType(i + 1)
+        }
+      }
+      partitionColumnType
+    } catch {
+      case e: Exception =>
+        -1
+    } finally {
+      if (rs != null) {
+        rs.close()
+      }
+      if (stmt != null) {
+        stmt.close()
+      }
+      if (conn != null) {
+        conn.close()
+      }
+    }
+  }
+
   override def createRelation(
       sqlContext: SQLContext,
       mode: SaveMode,
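For context, a minimal sketch of how a reader might exercise the date/timestamp partition bounds this patch enables. The connection URL, table, and column names below are invented for illustration, and the bounds are still supplied as epoch milliseconds since `lowerBound`/`upperBound` remain `Long`-typed in this change:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: the database URL, table, and column names are assumptions.
val spark = SparkSession.builder().appName("jdbc-timestamp-partitioning").getOrCreate()

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/sales")   // hypothetical database
  .option("dbtable", "orders")                             // hypothetical table
  .option("partitionColumn", "order_ts")                   // a TIMESTAMP column
  .option("lowerBound", "1483228800000")                   // epoch millis for 2017-01-01 UTC
  .option("upperBound", "1514764800000")                   // epoch millis for 2018-01-01 UTC
  .option("numPartitions", "4")
  .load()

// With this patch, each partition filter quotes the boundary as a timestamp literal,
// e.g. order_ts >= '2017-01-01 00:00:00' (formatting is timezone-dependent),
// instead of comparing the column against the raw Long value.
```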