@@ -1270,14 +1270,17 @@ class Column private[sql] (private[sql] val expr: proto.Expression) extends Logg

private[sql] object Column {

-def apply(name: String): Column = Column { builder =>
+def apply(name: String): Column = Column(name, None)
+
+def apply(name: String, planId: Option[Long]): Column = Column { builder =>
name match {
case "*" =>
builder.getUnresolvedStarBuilder
case _ if name.endsWith(".*") =>
builder.getUnresolvedStarBuilder.setUnparsedTarget(name)
case _ =>
-builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
+val attributeBuilder = builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
+planId.foreach(attributeBuilder.setPlanId)
}
}

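A minimal sketch of how the new overload is expected to behave, assuming the generated proto accessors match the setters used above (the literal 42L is illustrative only):

  val c = Column("a", Some(42L))
  assert(c.expr.getUnresolvedAttribute.getUnparsedIdentifier == "a")
  assert(c.expr.getUnresolvedAttribute.getPlanId == 42L)
  // Column("a") now delegates to Column("a", None), which leaves plan_id unset,
  // matching the old single-argument behavior.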
@@ -117,6 +117,8 @@ import org.apache.spark.util.Utils
*/
class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan: proto.Plan)
extends Serializable {
+// Make sure we don't forget to set plan id.
+assert(plan.getRoot.getCommon.hasPlanId)

override def toString: String = {
try {
@@ -873,7 +875,14 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
* @group untypedrel
* @since 3.4.0
*/
-def col(colName: String): Column = functions.col(colName)
+def col(colName: String): Column = {
+  val planId = if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
+    Option(plan.getRoot.getCommon.getPlanId)
+  } else {
+    None
+  }
+  Column.apply(colName, planId)
+}

/**
* Selects column based on the column name specified as a regex and returns it as [[Column]].
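A short sketch of the invariant this is meant to establish, assuming the plan_id accessors shown above: the column handed out by col carries the same plan id that the session stamped on this dataset's root relation.

  val df = spark.range(5)
  val c = df.col("id")
  // The reference produced by col points back at df's own plan.
  assert(c.expr.getUnresolvedAttribute.getPlanId == df.plan.getRoot.getCommon.getPlanId)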
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import java.io.Closeable
+import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._

@@ -48,7 +49,10 @@ import org.apache.spark.sql.connect.client.util.Cleaner
* .getOrCreate()
* }}}
*/
-class SparkSession(private val client: SparkConnectClient, private val cleaner: Cleaner)
+class SparkSession(
+    private val client: SparkConnectClient,
+    private val cleaner: Cleaner,
+    private val planIdGenerator: AtomicLong)
extends Serializable
with Closeable
with Logging {
@@ -183,6 +187,7 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
private[sql] def newDataset[T](f: proto.Relation.Builder => Unit): Dataset[T] = {
val builder = proto.Relation.newBuilder()
f(builder)
+builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
new Dataset[T](this, plan)
}
@@ -204,6 +209,15 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
client.execute(plan).asScala.foreach(_ => ())
}

+/**
+ * This resets the plan id generator so we can produce plans that are comparable.
+ *
+ * For testing only!
+ */
+private[sql] def resetPlanIdGenerator(): Unit = {
+  planIdGenerator.set(0)
+}

override def close(): Unit = {
client.shutdown()
allocator.close()
@@ -213,9 +227,11 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
// The minimal builder needed to create a spark session.
// TODO: implements all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends Logging {
+private val planIdGenerator = new AtomicLong

def builder(): Builder = new Builder()

-private lazy val cleaner = {
+private[sql] lazy val cleaner = {
val cleaner = new Cleaner
cleaner.start()
cleaner
@@ -238,7 +254,7 @@ object SparkSession extends Logging {
if (_client == null) {
_client = SparkConnectClient.builder().build()
}
-new SparkSession(_client, cleaner)
+new SparkSession(_client, cleaner, planIdGenerator)
}
}
}
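A rough illustration of what the shared generator buys us, assuming a session obtained through the builder shown above: every relation created through newDataset receives the next value from the counter, so two datasets created back to back never share a plan id.

  val a = spark.range(1)
  val b = spark.range(1)
  // Each root relation is stamped with its own id before the plan is built.
  assert(a.plan.getRoot.getCommon.hasPlanId && b.plan.getRoot.getCommon.hasPlanId)
  assert(a.plan.getRoot.getCommon.getPlanId != b.plan.getRoot.getCommon.getPlanId)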
@@ -27,7 +27,7 @@ import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics

import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
-import org.apache.spark.sql.functions.{aggregate, array, col, lit, sequence, shuffle, transform, udf}
+import org.apache.spark.sql.functions.{aggregate, array, col, lit, rand, sequence, shuffle, transform, udf}
import org.apache.spark.sql.types._

class ClientE2ETestSuite extends RemoteSparkSession {
@@ -399,4 +399,11 @@ class ClientE2ETestSuite extends RemoteSparkSession {
.getSeq[Int](0)
assert(result.toSet === Set(1, 2, 3, 74))
}

test("ambiguous joins") {
val left = spark.range(100).select(col("id"), rand(10).as("a"))
val right = spark.range(100).select(col("id"), rand(12).as("a"))
val joined = left.join(right, left("id") === right("id")).select(left("id"), right("a"))
assert(joined.schema.catalogString === "struct<id:bigint,a:double>")
}
}
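Both inputs of this test expose columns named id and a, so without plan ids the server could not tell left("id") from right("id"). A hedged sketch of the disambiguation the test relies on, assuming Dataset.apply delegates to col and the proto accessors match the setters used earlier in this change:

  val left = spark.range(10).select(col("id"), rand(1).as("a"))
  val right = spark.range(10).select(col("id"), rand(2).as("a"))
  // Same column name on both sides, but each reference carries its own dataset's
  // plan id, which is what lets the server resolve the join condition unambiguously.
  assert(left("id").expr.getUnresolvedAttribute.getPlanId !=
    right("id").expr.getUnresolvedAttribute.getPlanId)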
@@ -16,11 +16,11 @@
*/
package org.apache.spark.sql

-import scala.collection.JavaConverters._
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicLong

import io.grpc.Server
-import io.grpc.netty.NettyServerBuilder
-import java.util.concurrent.TimeUnit
+import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite

@@ -40,33 +40,27 @@ class DatasetSuite
private var service: DummySparkConnectService = _
private var ss: SparkSession = _

-private def getNewSparkSession(port: Int): SparkSession = {
-  assert(port != 0)
-  SparkSession
-    .builder()
-    .client(
-      SparkConnectClient
-        .builder()
-        .connectionString(s"sc://localhost:$port")
-        .build())
-    .build()
+private def newSparkSession(): SparkSession = {
+  val client = new SparkConnectClient(
+    proto.UserContext.newBuilder().build(),
+    InProcessChannelBuilder.forName(getClass.getName).directExecutor().build(),
+    "test")
+  new SparkSession(client, cleaner = SparkSession.cleaner, planIdGenerator = new AtomicLong)
}

private def startDummyServer(): Unit = {
service = new DummySparkConnectService()
-val sb = NettyServerBuilder
-  // Let server bind to any free port
-  .forPort(0)
+server = InProcessServerBuilder
+  .forName(getClass.getName)
.addService(service)

-server = sb.build
+  .build()
server.start()
}

override def beforeEach(): Unit = {
super.beforeEach()
startDummyServer()
-ss = getNewSparkSession(server.getPort)
+ss = newSparkSession()
}

override def afterEach(): Unit = {
@@ -76,47 +70,6 @@ class DatasetSuite
}
}

test("limit") {
val df = ss.newDataset(_ => ())
val builder = proto.Relation.newBuilder()
builder.getLimitBuilder.setInput(df.plan.getRoot).setLimit(10)

val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()
df.limit(10).analyze
val actualPlan = service.getAndClearLatestInputPlan()
assert(actualPlan.equals(expectedPlan))
}

test("select") {
val df = ss.newDataset(_ => ())

val builder = proto.Relation.newBuilder()
val dummyCols = Seq[Column](Column("a"), Column("b"))
builder.getProjectBuilder
.setInput(df.plan.getRoot)
.addAllExpressions(dummyCols.map(_.expr).asJava)
val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()

df.select(dummyCols: _*).analyze
val actualPlan = service.getAndClearLatestInputPlan()
assert(actualPlan.equals(expectedPlan))
}

test("filter") {
val df = ss.newDataset(_ => ())

val builder = proto.Relation.newBuilder()
val dummyCondition = Column.fn("dummy func", Column("a"))
builder.getFilterBuilder
.setInput(df.plan.getRoot)
.setCondition(dummyCondition.expr)
val expectedPlan = proto.Plan.newBuilder().setRoot(builder).build()

df.filter(dummyCondition).analyze
val actualPlan = service.getAndClearLatestInputPlan()
assert(actualPlan.equals(expectedPlan))
}

test("write") {
val df = ss.newDataset(_ => ()).limit(10)

@@ -18,13 +18,14 @@ package org.apache.spark.sql

import java.nio.file.{Files, Path}
import java.util.Collections
+import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import com.google.protobuf.util.JsonFormat
import io.grpc.inprocess.InProcessChannelBuilder
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.funsuite.{AnyFunSuite => ConnectFunSuite} // scalastyle:ignore funsuite

import org.apache.spark.connect.proto
@@ -55,7 +56,11 @@ import org.apache.spark.sql.types._
* `connector/connect/server` module
*/
// scalastyle:on
-class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll with Logging {
+class PlanGenerationTestSuite
+    extends ConnectFunSuite
+    with BeforeAndAfterAll
+    with BeforeAndAfterEach
+    with Logging {

// Borrowed from SparkFunSuite
private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
@@ -102,8 +107,12 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
val client = SparkConnectClient(
proto.UserContext.newBuilder().build(),
InProcessChannelBuilder.forName("/dev/null").build())
-val builder = SparkSession.builder().client(client)
-session = builder.build()
+session =
+  new SparkSession(client, cleaner = SparkSession.cleaner, planIdGenerator = new AtomicLong)
}

+override protected def beforeEach(): Unit = {
+  session.resetPlanIdGenerator()
+}

override protected def afterAll(): Unit = {
@@ -361,15 +370,17 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
}

test("apply") {
simple.select(simple.apply("a"))
val stable = simple
stable.select(stable("a"))
}

test("hint") {
simple.hint("coalesce", 100)
}

test("col") {
simple.select(simple.col("id"), simple.col("b"))
val stable = simple
stable.select(stable.col("id"), stable.col("b"))
}

test("colRegex") {
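Two assumptions worth spelling out for these golden-file changes: resetting the generator in beforeEach keeps the ids embedded in the generated plans independent of test order, and binding simple to a local val matters because simple appears to build a fresh dataset, and therefore a fresh plan id, on every access. A sketch of that assumed behaviour:

  val a = simple
  val b = simple
  // If each access to `simple` creates a new relation, the two datasets carry
  // different plan ids, so column references taken from separate accesses would
  // mix two ids inside one golden plan.
  assert(a.plan.getRoot.getCommon.getPlanId != b.plan.getRoot.getCommon.getPlanId)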