84 changes: 82 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag

-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}

private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
-}

trait ConnectionFactory extends Serializable {
@throws[Exception]
def getConnection: Connection
}

/**
* Create an RDD that executes an SQL query on a JDBC connection and reads results.
* For a usage example, see the test case JavaJdbcRDDSuite.testJavaJdbcRDD.
*
* @param connectionFactory a factory that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
* E.g. "select title, author from books where ? <= id and id <= ?"
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
* @param numPartitions the number of partitions.
* Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
* the query would be executed twice, once with (1, 10) and once with (11, 20)
* @param mapRow a function from a ResultSet to a single row of the desired result type(s).
* This should only call getInt, getString, etc.; the RDD takes care of calling next.
* The default maps a ResultSet to an array of Object.
*/
def create[T](
sc: JavaSparkContext,
connectionFactory: ConnectionFactory,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {

val jdbcRDD = new JdbcRDD[T](
sc.sc,
() => connectionFactory.getConnection,
sql,
lowerBound,
upperBound,
numPartitions,
(resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)

new JavaRDD[T](jdbcRDD)(fakeClassTag)
}

/**
* Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
* converted into an `Object` array. For a usage example, see the test case
* JavaJdbcRDDSuite.testJavaJdbcRDD.
*
* @param connectionFactory a factory that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
* The query must contain two ? placeholders for parameters used to partition the results.
* E.g. "select title, author from books where ? <= id and id <= ?"
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
* @param numPartitions the number of partitions.
* Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
* the query would be executed twice, once with (1, 10) and once with (11, 20)
*/
def create(
sc: JavaSparkContext,
connectionFactory: ConnectionFactory,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int): JavaRDD[Array[Object]] = {

val mapRow = new JFunction[ResultSet, Array[Object]] {
override def call(resultSet: ResultSet): Array[Object] = {
resultSetToObjectArray(resultSet)
}
}

create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
}
}
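
The partitioning example in the scaladoc ((1, 20) with numPartitions = 2 running the query once with (1, 10) and once with (11, 20)) follows from an even split of the inclusive bound range. A minimal sketch of that arithmetic in Java (splitBounds is a hypothetical helper for illustration, not part of this patch, and the upstream implementation may compute the bounds differently):

// Sketch: split the inclusive range [lower, upper] into numPartitions
// contiguous inclusive sub-ranges, as the create() scaladoc describes.
static long[][] splitBounds(long lower, long upper, int numPartitions) {
  long length = upper - lower + 1;
  long[][] bounds = new long[numPartitions][2];
  for (int i = 0; i < numPartitions; i++) {
    bounds[i][0] = lower + i * length / numPartitions;           // inclusive start
    bounds[i][1] = lower + (i + 1) * length / numPartitions - 1; // inclusive end
  }
  return bounds;
}
// splitBounds(1, 20, 2) yields {1, 10} and {11, 20}, matching the scaladoc example.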
118 changes: 118 additions & 0 deletions core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.rdd.JdbcRDD;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class JavaJdbcRDDSuite implements Serializable {
private transient JavaSparkContext sc;

@Before
public void setUp() throws ClassNotFoundException, SQLException {
sc = new JavaSparkContext("local", "JavaAPISuite");

Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
Connection connection =
DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true");

try {
Statement create = connection.createStatement();
create.execute(
"CREATE TABLE FOO(" +
"ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
"DATA INTEGER)");
create.close();

PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
for (int i = 1; i <= 100; i++) {
insert.setInt(1, i * 2);
insert.executeUpdate();
}
insert.close();
} catch (SQLException e) {
// Ignore the error if the table already exists (Derby SQLState X0Y32)
if (e.getSQLState().compareTo("X0Y32") != 0) {
throw e;
}
} finally {
connection.close();
}
}

@After
public void tearDown() throws SQLException {
try {
DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;shutdown=true");
} catch (SQLException e) {
// Throw if not normal single database shutdown
// https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
if (e.getSQLState().compareTo("08006") != 0) {
throw e;
}
}

sc.stop();
sc = null;
}

@Test
public void testJavaJdbcRDD() throws Exception {
JavaRDD<Integer> rdd = JdbcRDD.create(
sc,
new JdbcRDD.ConnectionFactory() {
@Override
public Connection getConnection() throws SQLException {
return DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb");
}
},
"SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
1, 100, 1,
new Function<ResultSet, Integer>() {
@Override
public Integer call(ResultSet r) throws Exception {
return r.getInt(1);
}
}
).cache();

Assert.assertEquals(100, rdd.count());
Assert.assertEquals(
Integer.valueOf(10100),
rdd.reduce(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer i1, Integer i2) {
return i1 + i2;
}
}));
}
}
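
Because ConnectionFactory and Function each declare a single abstract method, the same call can be written with Java 8 lambdas. A sketch assuming the Derby database set up above (imports as in JavaJdbcRDDSuite; a lambda targeting a Serializable interface is itself serializable):

// Sketch: the testJavaJdbcRDD call rewritten with Java 8 lambdas.
JavaRDD<Integer> rdd = JdbcRDD.create(
  sc,
  () -> DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"),
  "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
  1, 100, 1,
  r -> r.getInt(1));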
7 changes: 4 additions & 3 deletions core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
@@ -65,10 +65,11 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

   after {
     try {
-      DriverManager.getConnection("jdbc:derby:;shutdown=true")
+      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
     } catch {
-      case se: SQLException if se.getSQLState == "XJ015" =>
-        // normal shutdown
+      case se: SQLException if se.getSQLState == "08006" =>
+        // Normal single database shutdown
+        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
     }
   }
 }
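
The switch from SQLState XJ015 to 08006 tracks the new connection URL: per the Derby reference linked above, XJ015 signals a successful shutdown of the whole Derby system (jdbc:derby:;shutdown=true), while 08006 signals a successful shutdown of a single database. A sketch of the pattern in Java (shutdownDerbyDatabase is a hypothetical helper, not part of this patch; imports: java.sql.DriverManager, java.sql.SQLException):

// Sketch: Derby reports a *successful* shutdown by throwing SQLException.
// SQLState 08006 = single database shut down; anything else is a real error.
static void shutdownDerbyDatabase(String db) throws SQLException {
  try {
    DriverManager.getConnection("jdbc:derby:" + db + ";shutdown=true");
  } catch (SQLException e) {
    if (!"08006".equals(e.getSQLState())) {
      throw e;
    }
  }
}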