diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0970b9807167..0728be922384 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -133,6 +133,31 @@ class JdbcRDD[T: ClassTag](
   }
 }
 
+/**
+ * An RDD that pages through a SQL query with OFFSET/FETCH-style placeholders rather
+ * than an id range. The query must contain exactly two '?' placeholders: the first is
+ * bound to each partition's row offset, the second to the page size.
+ */
+class LimitJdbcRDD[T: ClassTag](
+    sc: SparkContext,
+    getConnection: () => Connection,
+    sql: String,
+    lowerBound: Long,
+    pageSize: Long,
+    numPartitions: Int,
+    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
+  extends JdbcRDD[T](sc, getConnection, sql, lowerBound, pageSize, numPartitions, mapRow) {
+
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { i =>
+      // Partition i reads pageSize rows starting at its row offset; compute() binds
+      // the partition's lower bound to the first '?' and its upper bound to the second.
+      val start = lowerBound + i * pageSize
+      new JdbcPartition(i, start, pageSize)
+    }.toArray
+  }
+}
+
 object JdbcRDD {
   def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
index 05013fbc49b8..5c8b89d82ac6 100644
--- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
@@ -84,6 +84,20 @@ class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkCont
     assert(rdd.reduce(_ + _) === 10100)
   }
 
+  test("limit functionality") {
+    sc = new SparkContext("local", "test")
+    // Ten partitions of ten rows each cover all 100 rows in FOO.
+    val rdd = new LimitJdbcRDD(
+      sc,
+      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
+      "SELECT DATA FROM FOO ORDER BY ID OFFSET ? ROWS FETCH NEXT ? ROWS ONLY",
+      0, 10, 10,
+      (r: ResultSet) => r.getInt(1)).cache()
+
+    assert(rdd.count === 100)
+    assert(rdd.reduce(_ + _) === 10100)
+  }
+
   test("large id overflow") {
     sc = new SparkContext("local", "test")
     val rdd = new JdbcRDD(
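
For reviewers, a minimal sketch of how the new class would be driven from user code once this patch is applied. The Derby URL, the `EMPLOYEES` table, and the `LimitJdbcRDDExample` object are illustrative assumptions, not part of the patch; only `LimitJdbcRDD` itself comes from the change above.

```scala
import java.sql.{DriverManager, ResultSet}

import org.apache.spark.SparkContext
import org.apache.spark.rdd.LimitJdbcRDD

object LimitJdbcRDDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "LimitJdbcRDDExample")

    // The query carries exactly two '?' placeholders: the first receives each
    // partition's row offset, the second the page size. ORDER BY keeps the
    // paging deterministic across partitions.
    val rdd = new LimitJdbcRDD(
      sc,
      () => DriverManager.getConnection("jdbc:derby:target/ExampleDb"),
      "SELECT NAME FROM EMPLOYEES ORDER BY ID OFFSET ? ROWS FETCH NEXT ? ROWS ONLY",
      lowerBound = 0,     // row offset of the first partition
      pageSize = 1000,    // rows fetched by each partition
      numPartitions = 4,  // partitions read offsets 0, 1000, 2000, 3000
      mapRow = (r: ResultSet) => r.getString(1))

    println(s"read ${rdd.count()} rows")
    sc.stop()
  }
}
```

Unlike the parent `JdbcRDD`, which splits an id range `[lowerBound, upperBound]` across partitions, each `LimitJdbcRDD` partition issues the same query with a different offset, so it also works against tables whose keys are not dense or numeric. The trade-off is that OFFSET paging forces the database to skip rows, which gets progressively slower for later partitions.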