@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
 
 private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
   override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
   def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
     Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
   }
-}
 
+  trait ConnectionFactory extends Serializable {
+    @throws[Exception]
+    def getConnection: Connection
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+   * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
+   *   The default maps a ResultSet to an array of Object.
+   */
+  def create[T](
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int,
+      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+    val jdbcRDD = new JdbcRDD[T](
+      sc.sc,
+      () => connectionFactory.getConnection,
+      sql,
+      lowerBound,
+      upperBound,
+      numPartitions,
+      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+    new JavaRDD[T](jdbcRDD)(fakeClassTag)
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+   * converted into an `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   */
+  def create(
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): JavaRDD[Array[Object]] = {
+
+    val mapRow = new JFunction[ResultSet, Array[Object]] {
+      override def call(resultSet: ResultSet): Array[Object] = {
+        resultSetToObjectArray(resultSet)
+      }
+    }
+
+    create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+  }
+}
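
For reference, a minimal sketch of how a Java caller might exercise the factory method this patch adds. The JDBC URL, the FOO table and its ID/DATA columns, and the local master setting are hypothetical placeholders for illustration only; the patch's actual end-to-end coverage is the test case referenced in the scaladoc above.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.rdd.JdbcRDD;

public class JdbcRDDExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JdbcRDDExample");

    // ConnectionFactory is Serializable, so the anonymous class below can be
    // shipped to executors; the RDD closes each connection once a partition
    // has been fully read.
    JavaRDD<Integer> rdd = JdbcRDD.create(
      sc,
      new JdbcRDD.ConnectionFactory() {
        public Connection getConnection() throws Exception {
          // Hypothetical in-memory Derby database; any JDBC URL works here.
          return DriverManager.getConnection("jdbc:derby:memory:testdb");
        }
      },
      // The two ? placeholders receive each partition's slice of [1, 100].
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      new Function<ResultSet, Integer>() {
        public Integer call(ResultSet rs) throws Exception {
          // Read only the current row; the RDD itself advances the cursor.
          return rs.getInt(1);
        }
      });

    System.out.println("rows: " + rdd.count());
    sc.stop();
  }
}

Passing fakeClassTag on the Scala side is what lets Java callers avoid supplying a ClassTag: it is just an AnyRef tag, which is sufficient here because the Java API erases element types anyway.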