Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 10 additions & 30 deletions connector/connect/src/main/protobuf/spark/connect/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,134 +55,114 @@ message DataType {
uint32 user_defined_type_reference = 31;
}

// Three-state nullability marker referenced by the type messages in this file.
enum Nullability {
// Zero value; nullability has not been specified.
NULLABILITY_UNSPECIFIED = 0;
// NOTE(review): presumably "values may be null" — confirm with consumers.
NULLABILITY_NULLABLE = 1;
// NOTE(review): presumably "null is not allowed" — confirm with consumers.
NULLABILITY_REQUIRED = 2;
}

// Boolean type. Carries only a type-variation reference and a nullability marker.
message Boolean {
// NOTE(review): what the reference points at is not visible in this file — confirm.
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// `I8` type. NOTE(review): presumably an 8-bit signed integer — confirm.
message I8 {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// `I16` type. NOTE(review): presumably a 16-bit signed integer — confirm.
message I16 {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// `I32` type. The planner's proto-to-Catalyst conversion maps this kind to IntegerType.
message I32 {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// `I64` type. NOTE(review): presumably a 64-bit signed integer — confirm.
message I64 {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// `FP32` type. NOTE(review): presumably a 32-bit floating-point number — confirm.
message FP32 {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// `FP64` type. NOTE(review): presumably a 64-bit floating-point number — confirm.
message FP64 {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// String type. The planner's proto-to-Catalyst conversion maps this kind to StringType.
message String {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// Binary type. NOTE(review): presumably variable-length bytes (see FixedBinary for the
// length-parameterized variant) — confirm.
message Binary {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// Timestamp type. NOTE(review): presumably without time zone, given the separate
// TimestampTZ message — confirm.
message Timestamp {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// Date type. Carries only a type-variation reference and a nullability marker.
message Date {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// Time type. NOTE(review): presumably a time of day without a date part — confirm.
message Time {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// TimestampTZ type. NOTE(review): presumably a timestamp with time zone — confirm.
message TimestampTZ {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// IntervalYear type. NOTE(review): presumably a year(-month) interval — confirm.
message IntervalYear {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// IntervalDay type. NOTE(review): presumably a day(-time) interval — confirm.
message IntervalDay {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// UUID type. Carries only a type-variation reference and a nullability marker.
message UUID {
uint32 type_variation_reference = 1;
Nullability nullability = 2;
}

// Start compound types.
// Character type parameterized by `length`.
// NOTE(review): presumably a fixed-length string; the unit of `length` is not visible here.
message FixedChar {
int32 length = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}

// Character type parameterized by `length`.
// NOTE(review): presumably `length` is an upper bound for a variable-length string — confirm.
message VarChar {
int32 length = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}

// Binary type parameterized by `length`.
// NOTE(review): presumably a fixed number of bytes — confirm.
message FixedBinary {
int32 length = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}

// Decimal type parameterized by `scale` and `precision`.
// Note the declaration order: scale is field 1, precision is field 2.
message Decimal {
int32 scale = 1;
int32 precision = 2;
uint32 type_variation_reference = 3;
Nullability nullability = 4;
}

// One field of a Struct: its type, name, nullable flag and string-to-string metadata.
message StructField {
DataType type = 1;
string name = 2;
// NOTE(review): the planner's struct conversion does not read this flag yet
// (it carries a "TODO: handle nullability").
bool nullable = 3;
map<string, string> metadata = 4;
}

// Struct type: a list of fields, each described by a StructField (type, name, nullable
// flag, metadata).
message Struct {
// Field number 1 must be declared exactly once. The bare `repeated DataType types = 1;`
// entry reused the same number and is dropped: every StructField already carries its
// own type, and the Scala DSL/planner only use `fields`.
repeated StructField fields = 1;
uint32 type_variation_reference = 2;
Nullability nullability = 3;
}

// List type: a single element type plus a flag describing the elements' nullability.
message List {
// Element type. NOTE(review): the field name repeats the message name `DataType`; kept
// unchanged because renaming it would change the generated accessors.
DataType DataType = 1;
uint32 type_variation_reference = 2;
// Field number 3 must be declared exactly once. The `Nullability nullability = 3;` entry
// reused the same number and is dropped in favour of this container-level flag, which
// mirrors StructField.nullable.
bool element_nullable = 3;
}

// Map type: key and value types plus a flag describing the values' nullability.
message Map {
DataType key = 1;
DataType value = 2;
uint32 type_variation_reference = 3;
// Field number 4 must be declared exactly once. The `Nullability nullability = 4;` entry
// reused the same number and is dropped in favour of this container-level flag, which
// mirrors StructField.nullable.
bool value_nullable = 4;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,32 @@ package object dsl {
.addAllParts(identifier.asJava)
.build())
.build()

/**
 * Builds a struct-typed qualified attribute named after the enclosing string `s`.
 *
 * Each attribute in `attrs` contributes one struct field that copies the attribute's name
 * and type; no other struct-field properties are set.
 */
def struct(
    attrs: proto.Expression.QualifiedAttribute*): proto.Expression.QualifiedAttribute = {
  val structTypeBuilder = proto.DataType.Struct.newBuilder()
  attrs.foreach { attr =>
    structTypeBuilder.addFields(
      proto.DataType.StructField.newBuilder()
        .setName(attr.getName)
        .setType(attr.getType))
  }
  val structDataType = proto.DataType.newBuilder().setStruct(structTypeBuilder)
  proto.Expression.QualifiedAttribute.newBuilder()
    .setName(s)
    .setType(structDataType)
    .build()
}

/** Creates a qualified attribute whose type is the proto `I32` data type. */
def int: proto.Expression.QualifiedAttribute = {
  val i32Type = proto.DataType.newBuilder().setI32(proto.DataType.I32.newBuilder()).build()
  protoQualifiedAttrWithType(i32Type)
}

/** Builds a qualified attribute named after the enclosing string `s` with the given type. */
private def protoQualifiedAttrWithType(
    dataType: proto.DataType): proto.Expression.QualifiedAttribute = {
  val attrBuilder = proto.Expression.QualifiedAttribute.newBuilder()
  attrBuilder.setName(s)
  attrBuilder.setType(dataType)
  attrBuilder.build()
}
}

implicit class DslExpression(val expr: proto.Expression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

package org.apache.spark.sql.connect.planner

import scala.collection.convert.ImplicitConversions._

import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

/**
* This object offers methods to convert between Connect proto types and Catalyst types.
Expand All @@ -29,11 +31,19 @@ object DataTypeProtoConverter {
t.getKindCase match {
case proto.DataType.KindCase.I32 => IntegerType
case proto.DataType.KindCase.STRING => StringType
case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
}
}

/**
 * Converts a proto struct type into a Catalyst [[StructType]].
 *
 * Every proto field becomes a [[StructField]] with the same name and a type obtained via
 * `toCatalystType`. The proto `nullable` flag is not read here.
 */
private def convertProtoDataTypeToCatalyst(t: proto.DataType.Struct): StructType = {
  // TODO: handle nullability
  val catalystFields = for (field <- t.getFieldsList) yield {
    StructField(field.getName, toCatalystType(field.getType))
  }
  StructType(catalystFields.toList)
}

def toConnectProtoType(t: DataType): proto.DataType = {
t match {
case IntegerType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
Expand Down Expand Up @@ -114,7 +115,26 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
import org.apache.spark.sql.connect.dsl.plans._
transform(connectTestRelation.as("target_table"))
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary change

val sparkPlan = sparkTestRelation.as("target_table")
comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
}

// Checks that a proto LocalRelation holding a struct-typed attribute is transformed into a
// plan that compares equal to the LocalRelation built directly with the `$"..."` syntax.
test("Test StructType in LocalRelation") {
val connectPlan = {
// Imported inside the braces only, so the Connect DSL is in scope just while the proto
// relation is being built.
import org.apache.spark.sql.connect.dsl.expressions._
transform(createLocalRelationProtoByQualifiedAttributes(Seq("a".struct("id".int))))
}
// Expected plan, built outside the scope of the Connect DSL import above.
val sparkPlan = LocalRelation($"a".struct($"id".int))
// NOTE(review): the trailing `false` is presumably a check-analysis switch — confirm
// against comparePlans' signature.
comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
}

/**
 * Wraps the given qualified attributes, in order, into a proto LocalRelation and returns
 * it as a proto Relation.
 */
private def createLocalRelationProtoByQualifiedAttributes(
    attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
  val localRelation = attrs
    .foldLeft(proto.LocalRelation.newBuilder())((builder, attr) => builder.addAttributes(attr))
    .build()
  proto.Relation.newBuilder().setLocalRelation(localRelation).build()
}
}
Loading