diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 44c653462fa28..5cf68b78da147 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -33,6 +33,7 @@
     <sbt.project.name>connect-client-jvm</sbt.project.name>
     <guava.version>31.0.1-jre</guava.version>
+    <mima.version>1.1.0</mima.version>
@@ -92,6 +93,13 @@
       <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.typesafe</groupId>
+      <artifactId>mima-core_${scala.binary.version}</artifactId>
+      <version>${mima.version}</version>
+      <scope>test</scope>
+    </dependency>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index f25d579d5c303..35ea76e5d988f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 import scala.collection.JavaConverters._

 import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Column.fn
 import org.apache.spark.sql.connect.client.unsupported
 import org.apache.spark.sql.functions.lit
@@ -44,7 +45,7 @@ import org.apache.spark.sql.functions.lit
  *
  * @since 3.4.0
  */
-class Column private[sql] (private[sql] val expr: proto.Expression) {
+class Column private[sql] (private[sql] val expr: proto.Expression) extends Logging {

   /**
    * Sum of this expression and another expression.
@@ -80,7 +81,7 @@ class Column private[sql] (private[sql] val expr: proto.Expression) {
   }
 }

-object Column {
+private[sql] object Column {

   def apply(name: String): Column = Column { builder =>
     name match {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6891b2f5bed94..51b734d1daa39 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,7 +21,8 @@ import scala.collection.JavaConverters._
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.SparkResult

-class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
+class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan: proto.Plan)
+    extends Serializable {

   /**
    * Selects a set of column based expressions.
@@ -33,7 +34,7 @@ class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def select(cols: Column*): Dataset = session.newDataset { builder =>
+  def select(cols: Column*): DataFrame = session.newDataset { builder =>
     builder.getProjectBuilder
       .setInput(plan.getRoot)
       .addAllExpressions(cols.map(_.expr).asJava)
@@ -50,7 +51,7 @@
    * @group typedrel
    * @since 3.4.0
    */
-  def filter(condition: Column): Dataset = session.newDataset { builder =>
+  def filter(condition: Column): Dataset[T] = session.newDataset { builder =>
     builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
   }

@@ -62,7 +63,7 @@ class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
    * @group typedrel
    * @since 3.4.0
    */
-  def limit(n: Int): Dataset = session.newDataset { builder =>
+  def limit(n: Int): Dataset[T] = session.newDataset { builder =>
     builder.getLimitBuilder
       .setInput(plan.getRoot)
       .setLimit(n)
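Note: for orientation, the sketch below shows the relation nesting that `session.newDataset { builder => ... }` produces for a `range(...).limit(...)` chain. It only uses proto builder calls that appear in this patch (`proto.Relation`, `getRangeBuilder`, `getLimitBuilder`, `proto.Plan`); the `setEnd`/`setStep` setters are assumed to exist next to the `setStart` call shown further down.

```scala
import org.apache.spark.connect.proto

// Leaf relation: Range(start = 0, end = 10, step = 1).
// setEnd/setStep are assumed from the Range relation's fields.
val range = proto.Relation.newBuilder()
range.getRangeBuilder.setStart(0L).setEnd(10L).setStep(1L)

// Limit wraps the range relation as its input, mirroring Dataset.limit above.
val limit = proto.Relation.newBuilder()
limit.getLimitBuilder.setInput(range.build()).setLimit(2)

// The plan a Dataset carries and the client eventually sends to the server.
val plan = proto.Plan.newBuilder().setRoot(limit).build()
```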
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 0c4f702ca34f9..eca5658e33df9 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -16,9 +16,12 @@
  */
 package org.apache.spark.sql

+import java.io.Closeable
+
 import org.apache.arrow.memory.RootAllocator

 import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.util.Cleaner

@@ -43,7 +46,9 @@ import org.apache.spark.sql.connect.client.util.Cleaner
  * }}}
  */
 class SparkSession(private val client: SparkConnectClient, private val cleaner: Cleaner)
-    extends AutoCloseable {
+    extends Serializable
+    with Closeable
+    with Logging {

   private[this] val allocator = new RootAllocator()

@@ -53,7 +58,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    *
    * @since 3.4.0
    */
-  def sql(query: String): Dataset = newDataset { builder =>
+  def sql(query: String): DataFrame = newDataset { builder =>
     builder.setSql(proto.SQL.newBuilder().setQuery(query))
   }

@@ -63,7 +68,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
   *
   * @since 3.4.0
   */
-  def range(end: Long): Dataset = range(0, end)
+  def range(end: Long): Dataset[java.lang.Long] = range(0, end)

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
@@ -71,7 +76,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
   *
   * @since 3.4.0
   */
-  def range(start: Long, end: Long): Dataset = {
+  def range(start: Long, end: Long): Dataset[java.lang.Long] = {
     range(start, end, step = 1)
   }

@@ -81,7 +86,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
   *
   * @since 3.4.0
   */
-  def range(start: Long, end: Long, step: Long): Dataset = {
+  def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
     range(start, end, step, None)
   }

@@ -91,11 +96,15 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
   *
   * @since 3.4.0
   */
-  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = {
+  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
     range(start, end, step, Option(numPartitions))
   }

-  private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = {
+  private def range(
+      start: Long,
+      end: Long,
+      step: Long,
+      numPartitions: Option[Int]): Dataset[java.lang.Long] = {
     newDataset { builder =>
       val rangeBuilder = builder.getRangeBuilder
         .setStart(start)
@@ -105,11 +114,11 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
     }
   }

-  private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
+  private[sql] def newDataset[T](f: proto.Relation.Builder => Unit): Dataset[T] = {
     val builder = proto.Relation.newBuilder()
     f(builder)
     val plan = proto.Plan.newBuilder().setRoot(builder).build()
-    new Dataset(this, plan)
+    new Dataset[T](this, plan)
   }

   private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse =
@@ -130,7 +139,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
 // The minimal builder needed to create a spark session.
 // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]]
-object SparkSession {
+object SparkSession extends Logging {

   def builder(): Builder = new Builder()

   private lazy val cleaner = {
@@ -139,7 +148,7 @@ object SparkSession {
     cleaner
   }

-  class Builder() {
+  class Builder() extends Logging {
     private var _client = SparkConnectClient.builder().build()

     def client(client: SparkConnectClient): Builder = {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
new file mode 100644
index 0000000000000..ada94b76fcbcd
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+package object sql {
+  type DataFrame = Dataset[Row]
+}
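Note: with the alias above and the typed signatures from the `Dataset` and `SparkSession` changes, client code reads the same way as the existing `spark-sql` API. A minimal sketch, assuming an already-constructed Spark Connect `SparkSession`:

```scala
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

def example(spark: SparkSession): Unit = {
  // Untyped results come back as DataFrame, i.e. Dataset[Row].
  val df: DataFrame = spark.sql("select 1 as id")

  // range/limit preserve the element type of the Dataset.
  val longs: Dataset[java.lang.Long] = spark.range(0, 10, 2).limit(3)
}
```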
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
new file mode 100644
index 0000000000000..21eed56ee787f
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.File
+import java.net.URLClassLoader
+import java.util.regex.Pattern
+
+import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.lib.MiMaLib
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+
+import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
+
+/**
+ * This test checks the binary compatibility of the connect client API against the spark SQL API
+ * using MiMa. We did not write this check as an SBT build rule because a build rule cannot
+ * provide the same level of freedom as a test. With a test we can:
+ *   1. Specify any two jars to run the compatibility check.
+ *   1. Easily make the test automatically pick up all new methods added while the client is being
+ *      built.
+ *
+ * The test requires the following artifacts to be built before running:
+ * {{{
+ *     spark-sql
+ *     spark-connect-client-jvm
+ * }}}
+ * To build the above artifacts, use e.g. `sbt package` or `mvn clean install -DskipTests`.
+ *
+ * When debugging this test, any change to the client API requires rebuilding the client jar
+ * before running the test again. An example workflow with SBT:
+ *   1. The compatibility test reports an unexpected client API change.
+ *   1. Fix the offending client API.
+ *   1. Rebuild the client jar: `sbt package`
+ *   1. Run the test again: `sbt "testOnly
+ *      org.apache.spark.sql.connect.client.CompatibilitySuite"`
+ */
+class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
+
+  private lazy val clientJar: File =
+    findJar(
+      "connector/connect/client/jvm",
+      "spark-connect-client-jvm-assembly",
+      "spark-connect-client-jvm")
+
+  private lazy val sqlJar: File = findJar("sql/core", "spark-sql", "spark-sql")
+
+  /**
+   * MiMa takes an old jar (sql jar) and a new jar (client jar) as inputs and then reports all
+   * incompatibilities found in the new jar. The incompatibility result is then filtered using
+   * include and exclude rules. Include rules are first applied to find all client classes that
+   * need to be checked. Then exclude rules are applied to filter out all unsupported methods in
+   * the client classes.
+   */
+  test("compatibility MiMa tests") {
+    val mima = new MiMaLib(Seq(clientJar, sqlJar))
+    val allProblems = mima.collectProblems(sqlJar, clientJar, List.empty)
+    val includedRules = Seq(
+      IncludeByName("org.apache.spark.sql.Column"),
+      IncludeByName("org.apache.spark.sql.Column$"),
+      IncludeByName("org.apache.spark.sql.Dataset"),
+      // TODO(SPARK-42175) Add the Dataset object definition
+      // IncludeByName("org.apache.spark.sql.Dataset$"),
+      IncludeByName("org.apache.spark.sql.DataFrame"),
+      IncludeByName("org.apache.spark.sql.SparkSession"),
+      IncludeByName("org.apache.spark.sql.SparkSession$")) ++ includeImplementedMethods(clientJar)
+    val excludeRules = Seq(
+      // Filter unsupported rules:
+      // Two sql overloads are marked experimental in the API and are skipped in the client.
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
+      // Skip all shaded dependencies in the client.
+      ProblemFilters.exclude[Problem]("org.sparkproject.*"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.connect.proto.*"))
+    val problems = allProblems
+      .filter { p =>
+        includedRules.exists(rule => rule(p))
+      }
+      .filter { p =>
+        excludeRules.forall(rule => rule(p))
+      }
+
+    if (problems.nonEmpty) {
+      fail(
+        s"\nComparing client jar: $clientJar\nand sql jar: $sqlJar\n" +
+          problems.map(p => p.description("client")).mkString("\n"))
+    }
+  }
+
+  test("compatibility API tests: Dataset") {
+    val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray)
+    val sqlClassLoader: URLClassLoader = new URLClassLoader(Seq(sqlJar.toURI.toURL).toArray)
+
+    val clientClass = clientClassLoader.loadClass("org.apache.spark.sql.Dataset")
+    val sqlClass = sqlClassLoader.loadClass("org.apache.spark.sql.Dataset")
+
+    val newMethods = clientClass.getMethods
+    val oldMethods = sqlClass.getMethods
+
+    // For now we simply check that the new methods are a subset of the old methods.
+    newMethods
+      .map(m => m.toString)
+      .foreach(method => {
+        assert(oldMethods.map(m => m.toString).contains(method))
+      })
+  }
+
+  /**
+   * Find all methods that are implemented in the client jar. Once all major methods are
+   * implemented, we can switch to including all methods under a class using ".*", e.g.
+   * "org.apache.spark.sql.Dataset.*".
+   */
+  private def includeImplementedMethods(clientJar: File): Seq[IncludeByName] = {
+    val clsNames = Seq(
+      "org.apache.spark.sql.Column",
+      // TODO(SPARK-42175) Add all overloading methods. Temporarily mute the compatibility check
+      // for the Dataset methods, as too many overloads are missing.
+      // "org.apache.spark.sql.Dataset",
+      "org.apache.spark.sql.SparkSession")
+
+    val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray)
+    clsNames
+      .flatMap { clsName =>
+        val cls = clientClassLoader.loadClass(clsName)
+        // all distinct method names
+        cls.getMethods.map(m => s"$clsName.${m.getName}").toSet
+      }
+      .map { fullName =>
+        IncludeByName(fullName)
+      }
+  }
+
+  private case class IncludeByName(name: String) extends ProblemFilter {
+    private[this] val pattern =
+      Pattern.compile(name.split("\\*", -1).map(Pattern.quote).mkString(".*"))
+
+    override def apply(problem: Problem): Boolean = {
+      pattern.matcher(problem.matchName.getOrElse("")).matches
+    }
+  }
+}
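Note: the include/exclude filtering above hinges on `IncludeByName` turning a rule string into a regex: `*` segments become `.*` and everything else is quoted literally. A standalone sketch of that matching logic:

```scala
import java.util.regex.Pattern

// Same construction as IncludeByName.pattern above.
def matchesRule(rule: String, fullName: String): Boolean = {
  val pattern = Pattern.compile(rule.split("\\*", -1).map(Pattern.quote).mkString(".*"))
  pattern.matcher(fullName).matches
}

matchesRule("org.apache.spark.sql.Dataset", "org.apache.spark.sql.Dataset")          // true
matchesRule("org.apache.spark.sql.Dataset.*", "org.apache.spark.sql.Dataset.filter") // true
matchesRule("org.apache.spark.sql.Dataset", "org.apache.spark.sql.DatasetHolder")    // false
```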
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
new file mode 100644
index 0000000000000..f0ae4cad67968
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.util
+
+import java.io.File
+
+import org.scalatest.Assertions.fail
+
+object IntegrationTestUtils {
+
+  // System properties used for testing and debugging
+  private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
+
+  private[connect] lazy val sparkHome: String = {
+    if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
+      fail("spark.test.home or SPARK_HOME is not set.")
+    }
+    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+  }
+  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
+
+  // Log server start/stop debug info to the console
+  // scalastyle:off println
+  private[connect] def debug(msg: String): Unit = if (isDebug) println(msg)
+  // scalastyle:on println
+  private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
+
+  /**
+   * Finds a jar in the Spark project artifacts. It requires the project to be built first (e.g.
+   * `sbt package` or `mvn clean install -DskipTests`) so that this method can find the jar in
+   * the target folders.
+   *
+   * @return
+   *   the jar
+   */
+  private[sql] def findJar(path: String, sbtName: String, mvnName: String): File = {
+    val targetDir = new File(new File(sparkHome, path), "target")
+    assert(
+      targetDir.exists(),
+      s"Failed to locate the target folder: '${targetDir.getCanonicalPath}'. " +
+        s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " +
+        "Make sure the Spark project jars have been built (e.g. using sbt package) " +
+        "and the env variable `SPARK_HOME` is set correctly.")
+    val jars = recursiveListFiles(targetDir).filter { f =>
+      // SBT jar
+      (f.getParentFile.getName.startsWith("scala-") &&
+        f.getName.startsWith(sbtName) && f.getName.endsWith(".jar")) ||
+      // Maven jar
+      (f.getParent.endsWith("target") &&
+        f.getName.startsWith(mvnName) &&
+        f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}.jar"))
+    }
+    // It is possible we found more than one: one built by Maven and another by SBT
+    assert(jars.nonEmpty, s"Failed to find the jar inside folder: ${targetDir.getCanonicalPath}")
+    debug("Using jar: " + jars(0).getCanonicalPath)
+    jars(0) // return the first jar found
+  }
+
+  private def recursiveListFiles(f: File): Array[File] = {
+    val these = f.listFiles
+    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
+  }
+}
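Note: both suites resolve prebuilt artifacts through this helper in the same way. A hypothetical call site is sketched below; it must live under the `org.apache.spark.sql` package because `findJar` is `private[sql]`, and it requires `SPARK_HOME` or `spark.test.home` to point at a built Spark checkout. `ArtifactLocator` is an illustrative name, not part of this patch.

```scala
package org.apache.spark.sql.connect.client

import java.io.File

import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._

object ArtifactLocator { // hypothetical helper for illustration only
  // Resolves either the SBT assembly or the Maven jar under connector/connect/client/jvm/target.
  lazy val clientJar: File = findJar(
    "connector/connect/client/jvm",
    "spark-connect-client-jvm-assembly",
    "spark-connect-client-jvm")
}
```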
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 552799d52297c..2d9c218b2fbe0 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -21,16 +21,17 @@ import java.util.concurrent.TimeUnit

 import scala.io.Source

-import org.scalatest.Assertions.fail
 import org.scalatest.BeforeAndAfterAll

 import sys.process._

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
 import org.apache.spark.sql.connect.common.config.ConnectCommon

 /**
  * An util class to start a local spark connect server in a different process for local E2E tests.
+ * Before running the tests, the spark connect artifact needs to be built, e.g. using `sbt package`.
  * It is designed to start the server once but shared by all tests. It is equivalent to use the
  * following command to start the connect server via command line:
  *
@@ -45,22 +46,6 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon
  * print the server process output in the console to debug server start stop problems.
  */
 object SparkConnectServerUtils {
-  // System properties used for testing and debugging
-  private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
-
-  protected lazy val sparkHome: String = {
-    if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
-      fail("spark.test.home or SPARK_HOME is not set.")
-    }
-    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
-  }
-  private val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
-
-  // Log server start stop debug info into console
-  // scalastyle:off println
-  private[connect] def debug(msg: String): Unit = if (isDebug) println(msg)
-  // scalastyle:on println
-  private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()

   // Server port
   private[connect] val port = ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
@@ -72,7 +57,10 @@ object SparkConnectServerUtils {

   private lazy val sparkConnect: Process = {
     debug("Starting the Spark Connect Server...")
-    val jar = findSparkConnectJar
+    val jar = findJar(
+      "connector/connect/server",
+      "spark-connect-assembly",
+      "spark-connect").getCanonicalPath
     val builder = Process(
       Seq(
         "bin/spark-submit",
@@ -118,37 +106,6 @@ object SparkConnectServerUtils {
     debug(s"Spark Connect Server is stopped with exit code: $code")
     code
   }
-
-  private def findSparkConnectJar: String = {
-    val target = "connector/connect/server/target"
-    val parentDir = new File(sparkHome, target)
-    assert(
-      parentDir.exists(),
-      s"Fail to locate the spark connect server target folder: '${parentDir.getCanonicalPath}'. " +
-        s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " +
-        "Make sure the spark connect server jar has been built " +
-        "and the env variable `SPARK_HOME` is set correctly.")
-    val jars = recursiveListFiles(parentDir).filter { f =>
-      // SBT jar
-      (f.getParentFile.getName.startsWith("scala-") &&
-      f.getName.startsWith("spark-connect-assembly") && f.getName.endsWith(".jar")) ||
-      // Maven Jar
-      (f.getParent.endsWith("target") &&
-      f.getName.startsWith("spark-connect") &&
-      f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}.jar"))
-    }
-    // It is possible we found more than one: one built by maven, and another by SBT
-    assert(
-      jars.nonEmpty,
-      s"Failed to find the `spark-connect` jar inside folder: ${parentDir.getCanonicalPath}")
-    debug("Using jar: " + jars(0).getCanonicalPath)
-    jars(0).getCanonicalPath // return the first one
-  }
-
-  def recursiveListFiles(f: File): Array[File] = {
-    val these = f.listFiles
-    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
-  }
 }

 trait RemoteSparkSession