diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
index 83df3be74708..3fb96ea28a8f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
@@ -74,7 +74,7 @@
  *
  */
 @InterfaceStability.Evolving
-public class DataSourceOptions {
+public class DataSourceOptions implements java.io.Serializable {
   private final Map<String, String> keyLowerCasedMap;

   private String toLowerCase(String key) {
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 1b37905543b4..6359ed3fecea 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,5 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 org.apache.spark.sql.execution.datasources.text.TextFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
+org.apache.spark.sql.execution.datasources.jdbc.jdbcv2.JDBCDataSourceV2
 org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
 org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 0bab3689e5d0..a72167e4f74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.jdbc.jdbcv2.JDBCOptionsV2
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -72,6 +73,10 @@ object JDBCRDD extends Logging {
     }
   }

+  def resolveTable(options: JDBCOptionsV2): StructType = {
+    resolveTable(options.jdbcOptionsV1)
+  }
+
   /**
    * Prune all but the specified columns from the specified Catalyst schema.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 433443007cfd..3699276e2dd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.execution.datasources.jdbc.jdbcv2.JDBCOptionsV2
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -64,6 +65,10 @@ object JdbcUtils extends Logging {
     }
   }

+  def createConnectionFactory(options: JDBCOptionsV2): () => Connection = {
+    createConnectionFactory(options.jdbcOptionsV1)
+  }
+
   /**
    * Returns true if the table already exists in the JDBC database.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCDataSourceV2.scala
new file mode 100644
index 000000000000..48695071561e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCDataSourceV2.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.jdbcv2
+
+import java.sql.{Connection, ResultSet}
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.jdbc._
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
+import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A DataSourceV2 implementation of the JDBC data source, registered under the short name
+ * "jdbcv2".
+ */
+class JDBCDataSourceV2 extends DataSourceV2 with ReadSupport with DataSourceRegister {
+
+  override def createReader(options: DataSourceOptions): DataSourceReader = {
+    val jDBCOptV2 = new JDBCOptionsV2(options)
+    new JDBCDataSourceReader(jDBCOptV2)
+  }
+
+  override def shortName(): String = "jdbcv2"
+}
+
+/**
+ * Reader that resolves the table schema up front and supports column pruning and
+ * filter push-down.
+ */
+class JDBCDataSourceReader(options: JDBCOptionsV2)
+  extends DataSourceReader with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
+
+  val fullSchema = JDBCRDD.resolveTable(options)
+  var requiredSchema = fullSchema
+  val schema = readSchema()
+  var pushedFiltersArray: Array[Filter] = Array.empty
+
+  override def readSchema(): StructType = {
+    requiredSchema
+  }
+
+  override def planInputPartitions(): util.List[InputPartition[Row]] = {
+    if (options.partitionColumn.isDefined) {
+      val partitionInfo = JDBCPartitioningInfo(
+        options.partitionColumn.get, options.lowerBound,
+        options.upperBound, options.numPartitions)
+      val parts = JDBCRelation.columnPartition(partitionInfo)
+
+      parts.map { p =>
+        new JDBCInputPartition(requiredSchema, options,
+          p.asInstanceOf[JDBCPartition], pushedFiltersArray): InputPartition[Row]
+      }.toList.asJava
+    } else {
+      List(new JDBCInputPartition(requiredSchema, options, JDBCPartition(null, 0),
+        pushedFiltersArray): InputPartition[Row]).asJava
+    }
+  }
+
+  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    // Filters the dialect cannot compile must be evaluated by Spark after the scan.
+    val postScanFilters =
+      filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(options.url)).isEmpty)
+    pushedFiltersArray = filters diff postScanFilters
+    postScanFilters
+  }
+
+  override def pushedFilters(): Array[Filter] = {
+    pushedFiltersArray
+  }
+
+  override def pruneColumns(requiredSchema: StructType): Unit = {
+    this.requiredSchema = requiredSchema
+  }
+}
+
+class JDBCInputPartition(
+    requiredSchema: StructType,
+    options: JDBCOptionsV2,
+    jDBCPartition: JDBCPartition,
+    pushedFiltersArray: Array[Filter])
+  extends InputPartition[Row] {
+
+  override def createPartitionReader(): InputPartitionReader[Row] =
+    new JDBCInputPartitionReader(requiredSchema, options, jDBCPartition, pushedFiltersArray)
+}
+
+/**
+ * Reads a single partition by issuing a SELECT restricted to the pruned columns, the
+ * pushed filters and the partition's WHERE clause.
+ */
+class JDBCInputPartitionReader(
+    requiredSchema: StructType,
+    options: JDBCOptionsV2,
+    jDBCPartition: JDBCPartition,
+    pushedFiltersArray: Array[Filter])
+  extends InputPartitionReader[Row] {
+
+  private val columnList: String = {
+    val sb = new StringBuilder()
+    requiredSchema.fieldNames.foreach(x => sb.append(",").append(x))
+    if (sb.isEmpty) "1" else sb.substring(1)
+  }
+
+  private val filterWhereClause: String =
+    pushedFiltersArray
+      .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(options.url)))
+      .map(p => s"($p)").mkString(" AND ")
+
+  private def getWhereClause(part: JDBCPartition): String = {
+    if (part.whereClause != null && filterWhereClause.length > 0) {
+      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
+    } else if (part.whereClause != null) {
+      "WHERE " + part.whereClause
+    } else if (filterWhereClause.length > 0) {
+      "WHERE " + filterWhereClause
+    } else {
+      ""
+    }
+  }
+
+  val myWhereClause = getWhereClause(jDBCPartition)
+  val conn: Connection = JdbcUtils.createConnectionFactory(options)()
+  val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause"
+  val stmt = conn.prepareStatement(sqlText,
+    ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+  stmt.setFetchSize(options.fetchSize)
+  stmt.setQueryTimeout(options.queryTimeout)
+  val rs = stmt.executeQuery()
+  val rowIterator = resultSetToRows(rs, requiredSchema)
+
+  override def next(): Boolean = rowIterator.hasNext
+
+  override def get(): Row = {
+    rowIterator.next()
+  }
+
+  override def close(): Unit = {
+    // Release the result set, statement and connection opened for this partition.
+    rs.close()
+    stmt.close()
+    conn.close()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCOptionsV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCOptionsV2.scala
new file mode 100644
index 000000000000..2778e0a631b6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCOptionsV2.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc.jdbcv2
+
+import java.util.{NoSuchElementException, Properties}
+import java.util.function.Supplier
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+
+/**
+ * Options for the JDBC data source V2, backed by a [[DataSourceOptions]] map.
+ */
+class JDBCOptionsV2(val options: DataSourceOptions) extends Serializable {
+
+  import JDBCOptionsV2._
+
+  def this() = this(new DataSourceOptions(Map.empty[String, String].asJava))
+
+  def this(parameters: Map[String, String]) = this(new DataSourceOptions(parameters.asJava))
+
+  // a JDBC URL
+  val url = options.get(JDBC_URL).orElseThrow(new Supplier[Throwable] {
+    override def get(): Throwable = new NoSuchElementException("no such url")
+  })
+
+  // TODO use dataSourceOptions.tableName()
+  val table = options.get(JDBC_TABLE_NAME).orElseThrow(new Supplier[Throwable] {
+    override def get(): Throwable = new NoSuchElementException("no such table")
+  })
+
+  // TODO register user specified driverClass
+  val driverClass = "org.h2.Driver"
+
+  // TODO propagate user-specified connection properties; currently empty.
+  val asConnectionProperties: Properties = {
+    val properties = new Properties()
+    properties
+  }
+
+  val partitionColumn = toScala(options.get(JDBC_PARTITION_COLUMN))
+  val lowerBound = options.getLong(JDBC_LOWER_BOUND, Long.MaxValue)
+  val upperBound = options.getLong(JDBC_UPPER_BOUND, Long.MinValue)
+  val numPartitions = options.getInt(JDBC_NUM_PARTITIONS, 1)
+  val fetchSize = options.getInt(JDBC_BATCH_FETCH_SIZE, 1000)
+  val queryTimeout = options.getInt(JDBC_QUERY_TIMEOUT, 60)
+  val predicates = toScala(options.get(JDBC_PREDICATES))
+
+  // Convert to the V1 JDBCOptions consumed by the shared JDBC utilities.
+  val jdbcOptionsV1 = new JDBCOptions(Map(
+    JDBCOptions.JDBC_URL -> url,
+    JDBCOptions.JDBC_TABLE_NAME -> table, // databaseName + "," + table
+    JDBCOptions.JDBC_DRIVER_CLASS -> driverClass))
+}
+
+object JDBCOptionsV2 {
+
+  val JDBC_URL = "url"
+  val JDBC_DATABASE_NAME = "database"
+  val JDBC_TABLE_NAME = "dbtable"
+  val JDBC_DRIVER_CLASS = "driver"
+  val JDBC_PARTITION_COLUMN = "partitionColumn"
+  val JDBC_LOWER_BOUND = "lowerBound"
+  val JDBC_UPPER_BOUND = "upperBound"
+  val JDBC_NUM_PARTITIONS = "numPartitions"
+  val JDBC_PREDICATES = "predicates"
+  val JDBC_QUERY_TIMEOUT = "queryTimeout"
+  val JDBC_BATCH_FETCH_SIZE = "fetchsize"
+  val JDBC_TRUNCATE = "truncate"
+  val JDBC_CREATE_TABLE_OPTIONS = "createTableOptions"
+  val JDBC_CREATE_TABLE_COLUMN_TYPES = "createTableColumnTypes"
+  val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = "customSchema"
+  val JDBC_BATCH_INSERT_SIZE = "batchsize"
+  val JDBC_TXN_ISOLATION_LEVEL = "isolationLevel"
+  val JDBC_SESSION_INIT_STATEMENT = "sessionInitStatement"
+
+  final def toScala[A](o: java.util.Optional[A]): Option[A] =
+    if (o.isPresent) Some(o.get) else None
+}
diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCDataSourceV2Suite.scala b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCDataSourceV2Suite.scala
new file mode 100644
index 000000000000..0d75ef73bd9b
--- /dev/null
+++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/jdbc/jdbcv2/JDBCDataSourceV2Suite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package jdbcv2
+
+import java.sql.DriverManager
+import java.util.Properties
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.execution.datasources.jdbc.jdbcv2.JDBCDataSourceReader
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.test.SharedSQLContext
+
+class JDBCDataSourceV2Suite extends QueryTest
+  with BeforeAndAfter with SharedSQLContext {
+
+  import testImplicits._
+
+  val url = "jdbc:h2:mem:testdb0"
+  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
+  var conn: java.sql.Connection = null
+
+  before {
+    val properties = new Properties()
+    properties.setProperty("user", "testUser")
+    properties.setProperty("password", "testPass")
+    properties.setProperty("rowId", "false")
+
+    conn = DriverManager.getConnection(url, properties)
+    conn.prepareStatement("create schema test").executeUpdate()
+    conn.prepareStatement(
+      "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
+    conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()
+    conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate()
+    conn.prepareStatement(
+      "insert into test.people values ('joe ''foo'' \"bar\"', 3)").executeUpdate()
+    conn.commit()
+  }
+
+  after {
+    conn.close()
+  }
+
+  def getReader(query: DataFrame): JDBCDataSourceReader = {
+    query.queryExecution.executedPlan.collect {
+      case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JDBCDataSourceReader]
+    }.head
+  }
+
+  test("JDBCDataSourceV2 Implementation") {
+    val df = spark.read.format("jdbcv2")
+      .option("url", urlWithUserAndPass)
+      .option("dbtable", "TEST.PEOPLE")
+      .option("partitionColumn", "THEID")
+      .option("lowerBound", 0)
+      .option("upperBound", 3)
+      .option("numPartitions", 3)
+      .load()
+
+    val expectedDF = Seq(("fred", 1), ("mary", 2), ("joe 'foo' \"bar\"", 3)).toDF("NAME", "THEID")
+    checkAnswer(df, expectedDF)
+    assert(getReader(df).planInputPartitions().size === 3)
+
+    val df2 = df.select("NAME").filter("THEID = 1")
+    val expectedDF2 = Seq("fred").toDF("NAME")
+    checkAnswer(df2, expectedDF2)
+
+    val reader = getReader(df2)
+    assert(reader.pushedFilters().flatMap(_.references).toSet === Set("THEID"))
+    assert(reader.pushedFiltersArray.flatMap(_.references).toSet === Set("THEID"))
+    assert(reader.requiredSchema.fieldNames === Seq("NAME"))
+  }
+}
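For reference, a minimal usage sketch of the new source as exercised by JDBCDataSourceV2Suite above. The H2 URL and TEST.PEOPLE table are the test fixtures from this patch, not part of the API; any JDBC endpoint reachable from the driver would be read the same way, and an active `spark` session is assumed.

// Read through the DataSourceV2 path registered under the short name "jdbcv2".
val df = spark.read.format("jdbcv2")
  .option("url", "jdbc:h2:mem:testdb0;user=testUser;password=testPass")
  .option("dbtable", "TEST.PEOPLE")
  .option("partitionColumn", "THEID")   // optional: enables column-based partitioning
  .option("lowerBound", 0)
  .option("upperBound", 3)
  .option("numPartitions", 3)
  .load()

// Column pruning and filter push-down are handled by JDBCDataSourceReader.
df.select("NAME").filter("THEID = 1").show()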