Commit 120a350

liancheng authored and mateiz committed

[SPARK-4613][Core] Java API for JdbcRDD
This PR introduces a set of Java APIs for using `JdbcRDD`:

1. Trait (interface) `JdbcRDD.ConnectionFactory`: equivalent to the `getConnection: () => Connection` parameter of the `JdbcRDD` constructor.
2. Two overloaded versions of `JdbcRDD.create`: used to create a `JavaRDD` that wraps a `JdbcRDD`.

Author: Cheng Lian <lian@databricks.com>

Closes #3478 from liancheng/japi-jdbc-rdd and squashes the following commits:

9a54625 [Cheng Lian] Only shuts down a single DB rather than the whole Derby driver
d4cedc5 [Cheng Lian] Moves Java JdbcRDD test case to a separate test suite
ffcdf2e [Cheng Lian] Java API for JdbcRDD
1 parent 84376d3 commit 120a350
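
For context, a minimal usage sketch of the new Java API (not part of this commit; the JDBC URL and the BOOKS table are hypothetical placeholders, while JdbcRDD.ConnectionFactory and both JdbcRDD.create overloads are exactly what this PR adds):

// Sketch only: BOOKS table and Derby path are made up for illustration.
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.rdd.JdbcRDD;

public class JdbcRDDExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JdbcRDDExample");

    // Overload 1: supply a mapRow function and get a JavaRDD<String>.
    JavaRDD<String> titles = JdbcRDD.create(
      sc,
      new JdbcRDD.ConnectionFactory() {
        @Override
        public Connection getConnection() throws Exception {
          return DriverManager.getConnection("jdbc:derby:target/ExampleDb");
        }
      },
      "SELECT TITLE FROM BOOKS WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      new Function<ResultSet, String>() {
        @Override
        public String call(ResultSet rs) throws Exception {
          return rs.getString(1);
        }
      });

    // Overload 2: omit mapRow and each row becomes an Object[].
    JavaRDD<Object[]> rows = JdbcRDD.create(
      sc,
      new JdbcRDD.ConnectionFactory() {
        @Override
        public Connection getConnection() throws Exception {
          return DriverManager.getConnection("jdbc:derby:target/ExampleDb");
        }
      },
      "SELECT ID, TITLE FROM BOOKS WHERE ? <= ID AND ID <= ?",
      1, 100, 3);

    System.out.println(titles.count() + " titles, " + rows.count() + " rows");
    sc.stop();
  }
}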

File tree

3 files changed: +204 −5 lines changed


core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala

Lines changed: 82 additions & 2 deletions
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
 
 private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
   override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
   def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
   }
-}
 
+  trait ConnectionFactory extends Serializable {
+    @throws[Exception]
+    def getConnection: Connection
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+   * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
+   *   The default maps a ResultSet to an array of Object.
+   */
+  def create[T](
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int,
+      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+    val jdbcRDD = new JdbcRDD[T](
+      sc.sc,
+      () => connectionFactory.getConnection,
+      sql,
+      lowerBound,
+      upperBound,
+      numPartitions,
+      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+    new JavaRDD[T](jdbcRDD)(fakeClassTag)
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+   * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   */
+  def create(
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): JavaRDD[Array[Object]] = {
+
+    val mapRow = new JFunction[ResultSet, Array[Object]] {
+      override def call(resultSet: ResultSet): Array[Object] = {
+        resultSetToObjectArray(resultSet)
+      }
+    }
+
+    create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+  }
+}
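
The inclusive-bounds partitioning described in the scaladoc above is worth spelling out. Below is a small, illustrative sketch of that split, assuming the stride arithmetic used by JdbcRDD's partitioner (the real logic lives in JdbcRDD.getPartitions, which this diff does not touch, so treat the details as an approximation of the documented behavior):

// Reproduces the documented example: lowerBound=1, upperBound=20,
// numPartitions=2 yields (1, 10) and (11, 20).
public class PartitionBoundsSketch {
  public static void main(String[] args) {
    long lowerBound = 1, upperBound = 20;
    int numPartitions = 2;
    long length = 1 + upperBound - lowerBound;
    for (int i = 0; i < numPartitions; i++) {
      long start = lowerBound + ((i * length) / numPartitions);
      long end = lowerBound + (((i + 1) * length) / numPartitions) - 1;
      // Each (start, end) pair is bound to the two "?" placeholders of the query.
      System.out.println("Partition " + i + ": (" + start + ", " + end + ")");
    }
  }
}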
core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark;
+
+import java.io.Serializable;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.rdd.JdbcRDD;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaJdbcRDDSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() throws ClassNotFoundException, SQLException {
+    sc = new JavaSparkContext("local", "JavaAPISuite");
+
+    Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
+    Connection connection =
+      DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true");
+
+    try {
+      Statement create = connection.createStatement();
+      create.execute(
+        "CREATE TABLE FOO(" +
+        "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
+        "DATA INTEGER)");
+      create.close();
+
+      PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
+      for (int i = 1; i <= 100; i++) {
+        insert.setInt(1, i * 2);
+        insert.executeUpdate();
+      }
+      insert.close();
+    } catch (SQLException e) {
+      // If table doesn't exist...
+      if (e.getSQLState().compareTo("X0Y32") != 0) {
+        throw e;
+      }
+    } finally {
+      connection.close();
+    }
+  }
+
+  @After
+  public void tearDown() throws SQLException {
+    try {
+      DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;shutdown=true");
+    } catch (SQLException e) {
+      // Throw if not normal single database shutdown
+      // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
+      if (e.getSQLState().compareTo("08006") != 0) {
+        throw e;
+      }
+    }
+
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void testJavaJdbcRDD() throws Exception {
+    JavaRDD<Integer> rdd = JdbcRDD.create(
+      sc,
+      new JdbcRDD.ConnectionFactory() {
+        @Override
+        public Connection getConnection() throws SQLException {
+          return DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb");
+        }
+      },
+      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
+      1, 100, 1,
+      new Function<ResultSet, Integer>() {
+        @Override
+        public Integer call(ResultSet r) throws Exception {
+          return r.getInt(1);
+        }
+      }
+    ).cache();
+
+    Assert.assertEquals(100, rdd.count());
+    Assert.assertEquals(
+      Integer.valueOf(10100),
+      rdd.reduce(new Function2<Integer, Integer, Integer>() {
+        @Override
+        public Integer call(Integer i1, Integer i2) {
+          return i1 + i2;
+        }
+      }));
+  }
+}

core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala

Lines changed: 4 additions & 3 deletions
@@ -65,10 +65,11 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
 
   after {
     try {
-      DriverManager.getConnection("jdbc:derby:;shutdown=true")
+      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
     } catch {
-      case se: SQLException if se.getSQLState == "XJ015" =>
-      // normal shutdown
+      case se: SQLException if se.getSQLState == "08006" =>
+        // Normal single database shutdown
+        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
}
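
The hunk above (commit 9a54625 in the squash list) switches the teardown from shutting down the whole Derby engine, which Derby reports with SQLState XJ015, to shutting down only the suite's own database, which it reports with SQLState 08006. A hedged sketch of the distinction, with a placeholder database path (per the Derby reference manual, a successful shutdown is always surfaced as an SQLException):

//   whole engine:    "jdbc:derby:;shutdown=true"       -> SQLState "XJ015"
//   single database: "jdbc:derby:<db>;shutdown=true"   -> SQLState "08006"
import java.sql.DriverManager;
import java.sql.SQLException;

public class DerbyShutdownSketch {
  static void shutdownSingleDb(String dbPath) throws SQLException {
    try {
      DriverManager.getConnection("jdbc:derby:" + dbPath + ";shutdown=true");
    } catch (SQLException e) {
      // 08006 means this one database shut down normally; anything else is a real error.
      if (!"08006".equals(e.getSQLState())) {
        throw e;
      }
    }
  }
}

Shutting down only the suite's database means other embedded Derby users in the same JVM (e.g. other test suites) are left untouched.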
