Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.datasources.jdbc

import java.text.SimpleDateFormat

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Partition
Expand All @@ -31,6 +33,7 @@ import org.apache.spark.sql.types.StructType
* Instructions on how to partition the table among workers.
*/
private[sql] case class JDBCPartitioningInfo(
columnType: Int,
column: String,
lowerBound: Long,
upperBound: Long,
Expand Down Expand Up @@ -79,13 +82,14 @@ private[sql] object JDBCRelation extends Logging {
// Here we get a little roundoff, but that's (hopefully) OK.
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
val column = partitioning.column
val columnType = partitioning.columnType
var i: Int = 0
var currentValue: Long = lowerBound
val ans = new ArrayBuffer[Partition]()
while (i < numPartitions) {
val lBound = if (i != 0) s"$column >= $currentValue" else null
val lBound = if (i != 0) s"$column >= ${getCurrentValue(columnType, currentValue)}" else null
currentValue += stride
val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
val uBound = if (i != numPartitions - 1) s"$column < ${getCurrentValue(columnType, currentValue)}" else null
val whereClause =
if (uBound == null) {
lBound
Expand All @@ -99,6 +103,16 @@ private[sql] object JDBCRelation extends Logging {
}
ans.toArray
}

/**
 * Renders a partition-boundary value as a SQL literal suitable for the
 * partition column's JDBC type.
 *
 * @param columnType the `java.sql.Types` constant of the partition column
 * @param value      the boundary value; interpreted as epoch milliseconds
 *                   for temporal column types, raw value otherwise
 * @return a quoted temporal literal for DATE/TIMESTAMP columns, or the
 *         plain decimal string for numeric columns
 */
def getCurrentValue(columnType: Int, value: Long): String = {
  columnType match {
    case java.sql.Types.TIMESTAMP =>
      val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
      "'" + sdf.format(new java.sql.Timestamp(value)) + "'"
    case java.sql.Types.DATE =>
      // A DATE literal must not carry a time-of-day component: strict
      // databases reject "DATE '2017-01-01 00:00:00'". The original code
      // reused the timestamp pattern here, which was a bug.
      val sdf = new SimpleDateFormat("yyyy-MM-dd")
      "'" + sdf.format(new java.sql.Date(value)) + "'"
    case _ =>
      // Numeric partition columns: emit the raw value, unquoted.
      value.toString
  }
}
}

private[sql] case class JDBCRelation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,51 @@ class JdbcRelationProvider extends CreatableRelationProvider
assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
s"'$JDBC_NUM_PARTITIONS' are also required")
JDBCPartitioningInfo(
JDBCPartitioningInfo(resolvePartitionColumnType(parameters),
partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
}

/**
 * Resolves the JDBC type of the configured partition column by running a
 * zero-row probe query (`where 1=0`) against the table and inspecting the
 * result-set metadata.
 *
 * NOTE(review): `JDBCRDD.resolveTable` may be a cleaner way to obtain the
 * column type; revisit before merging.
 *
 * @param parameters the JDBC datasource options map
 * @return the `java.sql.Types` constant of the partition column, or -1 if
 *         the column is not found or the probe fails
 */
def resolvePartitionColumnType(parameters: Map[String, String]): Int = {
  val options = new JDBCOptions(parameters)
  val conn = JdbcUtils.createConnectionFactory(options)()
  val partitionColumn = options.partitionColumn
  val table = options.table

  var stmt: java.sql.PreparedStatement = null
  var rs: java.sql.ResultSet = null
  try {
    // "where 1=0" fetches only metadata, never any rows.
    val resolveSql = s"select * from ($table) resolveTable where 1=0"
    stmt = conn.prepareStatement(resolveSql)
    rs = stmt.executeQuery()
    val rsmd = rs.getMetaData
    // JDBC column indexes are 1-based.
    (1 to rsmd.getColumnCount)
      .find(i => rsmd.getColumnName(i) == partitionColumn)
      .map(rsmd.getColumnType)
      .getOrElse(-1)
  } catch {
    // Narrowed from `Exception`: never swallow fatal errors. -1 is the
    // sentinel callers already use for "type unknown".
    case scala.util.control.NonFatal(_) =>
      -1
  } finally {
    if (rs != null) {
      rs.close()
    }
    if (stmt != null) {
      // BUG FIX: the original closed `rs` a second time here, leaking the
      // PreparedStatement.
      stmt.close()
    }
    if (conn != null) {
      conn.close()
    }
  }
}

override def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
Expand Down