21 changes: 21 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -44,6 +44,7 @@ message Relation {
Sample sample = 12;
Offset offset = 13;
Deduplicate deduplicate = 14;
Range range = 15;

Unknown unknown = 999;
}
@@ -217,3 +218,23 @@ message Sample {
int64 seed = 1;
}
}

// Relation of type [[Range]] that generates a sequence of integers.
message Range {
// Optional. Default value = 0
int32 start = 1;
int32 end = 2;
Contributor:
end is not optional, but how do we know if the client forgets to set it? 0 is a valid end value as well.

Contributor Author:
Yeah, this becomes tricky. Ultimately we could wrap every such field in a message so we always know whether it is set, but that might complicate the entire proto too much. Let's have a discussion on that.

// Optional. Default value = 1
Step step = 3;
Contributor:
start, end, step should use int64 @amaliujia

Contributor Author:
Yes, let me follow up. I think I was looking at the Python-side API and confused myself on the types.

Contributor Author:
Updating in #38460.

// Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
// it is set, or 2) spark default parallelism.
NumPartitions num_partitions = 4;
Comment on lines +227 to +231
Contributor:
Is this really the best way to express the optionality?

Contributor Author (@amaliujia, Oct 31, 2022):
There are two dimensions to this:

  1. Required versus optional.
    A required field must be set; an optional field may or may not be set.

  2. Whether a field has a default value.
    A field can have a default value that applies when it is not set.

The second point builds on the first: when a field is not set, there may be a default value to fall back on.

In the special case where the proto default value is the same as the default value Spark uses, we don't need to differentiate the optionality. Otherwise we need this wrapper approach to tell set from unset, so that Spark's default values can be applied (unless we don't care about Spark's defaults).

Contributor Author:
To really answer your question: if we want to respect Spark's default values for those optional fields whose proto default values differ from Spark's, this is the only way to do it.

Contributor:
So in fewer words :) when num_partitions is a plain integer, its default value is 0 even when it was never set, and for scalar types we can't tell whether the field is present at all. Having to guess whether 0 is a valid value or just "not set" defeats the purpose.

Thanks for the additional color!


message Step {
int32 step = 1;
}

message NumPartitions {
int32 num_partitions = 1;
}
}
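A minimal Scala sketch of the presence issue discussed in the review threads above; it assumes the generated builder classes from this proto file and is illustrative only:

```scala
import org.apache.spark.connect.proto

// A plain int32 such as `end` carries no presence information: getEnd returns 0
// whether the client set it to 0 or never set it at all. Wrapping `step` in its
// own message exposes hasStep, so the server can fall back to Spark's default of 1.
val range = proto.Range.newBuilder().setEnd(0).build()
val end = range.getEnd                                      // 0 -- explicitly set or never set? we cannot tell
val step = if (range.hasStep) range.getStep.getStep else 1  // default applied only when truly unset
```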
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect
import scala.collection.JavaConverters._
import scala.language.implicitConversions

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto._
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.connect.proto.SetOperation.SetOpType
@@ -34,6 +35,8 @@ import org.apache.spark.sql.connect.planner.DataTypeProtoConverter

package object dsl {

class MockRemoteSession {}

object expressions { // scalastyle:ignore
implicit class DslString(val s: String) {
def protoAttr: Expression =
@@ -175,6 +178,28 @@ package object dsl {
}

object plans { // scalastyle:ignore
implicit class DslMockRemoteSession(val session: MockRemoteSession) {
def range(
start: Option[Int],
end: Int,
step: Option[Int],
numPartitions: Option[Int]): Relation = {
val range = proto.Range.newBuilder()
Contributor Author:
Note that I need to keep proto.Range here: Range itself is a built-in Scala class, so we need the proto. prefix to disambiguate in this special case (see the sketch below the range helper).

Contributor:
I've been explicitly requesting this a couple of times already as a coding style: always prefix the generated proto classes with the proto. prefix. I know it uses a little more horizontal space, but it always makes clear where a particular element comes from, which is tremendously useful because we constantly mix types from the Catalyst API and Spark Connect in the same code paths.

Contributor Author:
It makes sense for SparkConnectPlanner, where Catalyst and proto are both mixed together, and we are keeping the approach you are asking for there.

However, this is the Connect DSL, which only deals with protos; no Catalyst is included in this package.

Contributor:
As long as no Catalyst is in this package, this is good with me. Thanks for clarifying.

if (start.isDefined) {
range.setStart(start.get)
}
range.setEnd(end)
if (step.isDefined) {
range.setStep(proto.Range.Step.newBuilder().setStep(step.get))
}
if (numPartitions.isDefined) {
range.setNumPartitions(
proto.Range.NumPartitions.newBuilder().setNumPartitions(numPartitions.get))
}
Relation.newBuilder().setRange(range).build()
}
}
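A small standalone sketch of the Range name clash mentioned in the thread above (hypothetical snippet, not part of the PR):

```scala
import org.apache.spark.connect.proto

// Without the prefix, `Range` resolves to Scala's built-in scala.Range,
// not the generated protobuf class, so the qualified name is required here.
val scalaRange: scala.Range = Range(0, 10)                    // Scala's built-in Range
val protoRange = proto.Range.newBuilder().setEnd(10).build()  // Spark Connect proto Range
```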

implicit class DslLogicalPlan(val logicalPlan: Relation) {
def select(exprs: Expression*): Relation = {
Relation
@@ -68,6 +68,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
transformLocalRelation(rel.getLocalRelation, common)
case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -93,6 +94,22 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
transformRelation(rel.getInput))
}

private def transformRange(rel: proto.Range): LogicalPlan = {
val start = rel.getStart
val end = rel.getEnd
val step = if (rel.hasStep) {
rel.getStep.getStep
} else {
1
}
val numPartitions = if (rel.hasNumPartitions) {
rel.getNumPartitions.getNumPartitions
} else {
session.leafNodeDefaultParallelism
}
logical.Range(start, end, step, numPartitions)
}
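As an illustration of the defaulting above (a hypothetical walk-through, not code from the PR), a Range message with only end set resolves like this:

```scala
// Assumes the generated proto classes and an active SparkSession `session`.
val rel = proto.Range.newBuilder().setEnd(10).build()
// start is proto's default 0, step falls back to 1, and numPartitions falls back to
// session.leafNodeDefaultParallelism, so transformRange yields
// logical.Range(0, 10, 1, session.leafNodeDefaultParallelism) -- the plan behind spark.range(10).
```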

private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.internal.SQLConf
@@ -35,6 +36,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
* same as Spark dataframe's generated plan.
*/
class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
lazy val connect = new MockRemoteSession()

lazy val connectTestRelation =
createLocalRelationProto(
@@ -209,6 +211,15 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan8, sparkPlan8)
}

test("Test Range") {
comparePlans(connect.range(None, 10, None, None), spark.range(10).toDF())
comparePlans(connect.range(Some(2), 10, None, None), spark.range(2, 10).toDF())
comparePlans(connect.range(Some(2), 10, Some(10), None), spark.range(2, 10, 10).toDF())
comparePlans(
connect.range(Some(2), 10, Some(10), Some(100)),
spark.range(2, 10, 10, 100).toDF())
}

private def createLocalRelationProtoByQualifiedAttributes(
attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()
@@ -226,6 +237,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
connectAnalyzed
}

// Compares proto plan with DataFrame.
private def comparePlans(connectPlan: proto.Relation, sparkPlan: DataFrame): Unit = {
val connectAnalyzed = analyzePlan(transform(connectPlan))
comparePlans(connectAnalyzed, sparkPlan.queryExecution.analyzed, false)