Commit ffcdf2e

Java API for JdbcRDD
1 parent bf1a6aa commit ffcdf2e

2 files changed: 165 additions, 5 deletions


core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala

Lines changed: 82 additions & 2 deletions
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
 
 private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
   override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
   def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
   }
-}
 
+  trait ConnectionFactory extends Serializable {
+    @throws[Exception]
+    def getConnection: Connection
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+   * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
+   *   The default maps a ResultSet to an array of Object.
+   */
+  def create[T](
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int,
+      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+    val jdbcRDD = new JdbcRDD[T](
+      sc.sc,
+      () => connectionFactory.getConnection,
+      sql,
+      lowerBound,
+      upperBound,
+      numPartitions,
+      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+    new JavaRDD[T](jdbcRDD)(fakeClassTag)
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+   * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   */
+  def create(
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): JavaRDD[Array[Object]] = {
+
+    val mapRow = new JFunction[ResultSet, Array[Object]] {
+      override def call(resultSet: ResultSet): Array[Object] = {
+        resultSetToObjectArray(resultSet)
+      }
+    }
+
+    create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+  }
+}
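
For reference, here is a minimal sketch of how a Java caller might use the second `create` overload added above, the one that omits `mapRow` and returns a `JavaRDD<Object[]>` built via `resultSetToObjectArray`. This example is not part of the commit: the `JdbcRDDExample` class, the connection URL, and the `EMPLOYEES` table are hypothetical, and the typed overload is already exercised by `testJavaJdbcRDD` in the next file. Note that the `fakeClassTag` plumbing above is what lets Java callers skip supplying a Scala `ClassTag`.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.JdbcRDD;

public class JdbcRDDExample {
  public static void main(String[] args) throws Exception {
    JavaSparkContext sc = new JavaSparkContext("local", "JdbcRDDExample");

    // Hypothetical database and table; any JDBC driver on the classpath works.
    JavaRDD<Object[]> rows = JdbcRDD.create(
      sc,
      new JdbcRDD.ConnectionFactory() {
        @Override
        public Connection getConnection() throws SQLException {
          return DriverManager.getConnection("jdbc:derby:target/ExampleDb");
        }
      },
      "SELECT ID, NAME FROM EMPLOYEES WHERE ? <= ID AND ID <= ?",
      1, 1000, 10);  // inclusive bounds 1..1000, split across 10 partitions

    // Without a mapRow function, each element is one row as an Object[]
    // (e.g. {ID, NAME}), produced by JdbcRDD.resultSetToObjectArray.
    System.out.println(rows.count());

    sc.stop();
  }
}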

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 83 additions & 3 deletions
@@ -18,13 +18,18 @@
 package org.apache.spark;
 
 import java.io.*;
-import java.nio.channels.FileChannel;
-import java.nio.ByteBuffer;
 import java.net.URI;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
 import java.util.*;
 import java.util.concurrent.*;
 
-import org.apache.spark.input.PortableDataStream;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -51,8 +56,10 @@
 import org.apache.spark.api.java.*;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.input.PortableDataStream;
 import org.apache.spark.partial.BoundedDouble;
 import org.apache.spark.partial.PartialResult;
+import org.apache.spark.rdd.JdbcRDD;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.util.StatCounter;
 
@@ -1508,4 +1515,77 @@ public void testRegisterKryoClasses() {
         conf.get("spark.kryo.classesToRegister"));
   }
 
+
+  private void setUpJdbc() throws Exception {
+    Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
+    Connection connection =
+      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true");
+
+    try {
+      Statement create = connection.createStatement();
+      create.execute(
+        "CREATE TABLE FOO(" +
+        "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
+        "DATA INTEGER)");
+      create.close();
+
+      PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
+      for (int i = 1; i <= 100; i++) {
+        insert.setInt(1, i * 2);
+        insert.executeUpdate();
+      }
+    } catch (SQLException e) {
+      // SQLState X0Y32 means the table already exists; reuse it and rethrow anything else.
+      if (e.getSQLState().compareTo("X0Y32") != 0) {
+        throw e;
+      }
+    } finally {
+      connection.close();
+    }
+  }
+
+  private void tearDownJdbc() throws SQLException {
+    try {
+      DriverManager.getConnection("jdbc:derby:;shutdown=true");
+    } catch(SQLException e) {
+      if (e.getSQLState().compareTo("XJ015") != 0) {
+        throw e;
+      }
+    }
+  }
+
+  @Test
+  public void testJavaJdbcRDD() throws Exception {
+    setUpJdbc();
+
+    try {
+      JavaRDD<Integer> rdd = JdbcRDD.create(
+        sc,
+        new JdbcRDD.ConnectionFactory() {
+          @Override
+          public Connection getConnection() throws SQLException {
+            return DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb");
+          }
+        },
+        "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
+        1, 100, 3,
+        new Function<ResultSet, Integer>() {
+          @Override
+          public Integer call(ResultSet r) throws Exception {
+            return r.getInt(1);
+          }
+        }
+      ).cache();
+
+      Assert.assertEquals(rdd.count(), 100);
+      Assert.assertEquals(rdd.reduce(new Function2<Integer, Integer, Integer>() {
+        @Override
+        public Integer call(Integer i1, Integer i2) {
+          return i1 + i2;
+        }
+      }), Integer.valueOf(10100));
+    } finally {
+      tearDownJdbc();
+    }
+  }
 }
