Skip to content

Commit 5202962

Browse files
committed
Merge commit 5202962 (2 parents: afc7ec5 + f0d5f16)

File tree

7 files changed

+439
-118
lines changed

7 files changed

+439
-118
lines changed

build.sbt

+4-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,10 @@ lazy val datasetSettings =
273273
mc("org.apache.spark.sql.reflection.package$ScalaSubtypeLock$"),
274274
mc("frameless.MapGroups"),
275275
mc(f"frameless.MapGroups$$"),
276-
dmm("frameless.functions.package.litAggr")
276+
dmm("frameless.functions.package.litAggr"),
277+
dmm("org.apache.spark.sql.FramelessInternals.column"),
278+
dmm("frameless.TypedEncoder.collectionEncoder"),
279+
dmm("frameless.TypedEncoder.setEncoder")
277280
)
278281
},
279282
coverageExcludedPackages := "frameless.reflection",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
package frameless

import frameless.TypedEncoder.CollectionConversion
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{
  CodegenContext,
  CodegenFallback,
  ExprCode
}
import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression }
import org.apache.spark.sql.types.{ DataType, ObjectType }

/**
 * Converts the runtime collection produced by `child` (an `F[Y]`) to the
 * target collection type `C[Y]` via the supplied [[CollectionConversion]].
 *
 * Needed because `MapObjects`, when evaluated in interpreted mode, yields a
 * `Seq` rather than the concrete collection type requested at compile time
 * (see #804). Code generation falls back to interpreted evaluation
 * ([[CodegenFallback]]).
 *
 * @param child      expression producing the source collection
 * @param conversion type-class instance performing the F[Y] => C[Y] conversion
 */
case class CollectionCaster[F[_], C[_], Y](
    child: Expression,
    conversion: CollectionConversion[F, C, Y])
    extends UnaryExpression
    with CodegenFallback {

  protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)

  override def eval(input: InternalRow): Any = {
    val evaluated = child.eval(input).asInstanceOf[Object]
    evaluated match {
      // Type-erased match: trusted by construction of the enclosing encoder.
      case col: F[Y] @unchecked => conversion.convert(col)
      case _                    => evaluated
    }
  }

  // The conversion changes only the JVM collection, not the Catalyst type.
  override def dataType: DataType = child.dataType
}

/**
 * Normalises the child collection to a `Seq` before serialisation: both Set
 * and Seq encodings convert through Seq first.
 *
 * NOTE(review, from original author comment): interpreted eval works, but a
 * plain codegen fallback does not — e.g. with ColumnTests.asCol and Vectors
 * the generated code still assumes the child is a Vector while child.eval
 * returns an X2 — hence the custom `doGenCode` below.
 */
case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression)
    extends UnaryExpression {

  protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)

  override def eval(input: InternalRow): Any = {
    val evaluated = child.eval(input).asInstanceOf[Object]
    evaluated match {
      case col: Set[Y] @unchecked => col.toSeq
      case _                      => evaluated
    }
  }

  /** Selects `isSet` when the child is statically a Set, `or` otherwise. */
  def toSeqOr[T](isSet: => T, or: => T): T =
    child.dataType match {
      case ObjectType(cls)
          if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
        isSet
      case _ => or // was `case t =>`: binding was never used
    }

  override def dataType: DataType =
    toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)

  override protected def doGenCode(
      ctx: CodegenContext,
      ev: ExprCode
    ): ExprCode =
    defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c"))
}

dataset/src/main/scala/frameless/TypedEncoder.scala

+81-28
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import java.util.Date
55
import java.time.{ Duration, Instant, LocalDate, Period }
66
import java.sql.Timestamp
77
import scala.reflect.ClassTag
8-
import FramelessInternals.UserDefinedType
98
import org.apache.spark.sql.catalyst.expressions.{
109
Expression,
1110
UnsafeArrayData,
@@ -18,7 +17,6 @@ import org.apache.spark.sql.catalyst.util.{
1817
}
1918
import org.apache.spark.sql.types._
2019
import org.apache.spark.unsafe.types.UTF8String
21-
2220
import shapeless._
2321
import shapeless.ops.hlist.IsHCons
2422
import com.sparkutils.shim.expressions.{
@@ -34,6 +32,8 @@ import org.apache.spark.sql.shim.{
3432
Invoke5 => Invoke
3533
}
3634

35+
import scala.collection.immutable.{ ListSet, TreeSet }
36+
3737
abstract class TypedEncoder[T](
3838
implicit
3939
val classTag: ClassTag[T])
@@ -509,10 +509,70 @@ object TypedEncoder {
509509
override def toString: String = s"arrayEncoder($jvmRepr)"
510510
}
511511

512-
implicit def collectionEncoder[C[X] <: Seq[X], T](
512+
/**
 * Per #804 - when MapObjects is used in interpreted mode the type returned is
 * Seq, not the derived type used in compilation.
 *
 * This type class offers extensible conversion for more specific types. By
 * default Seq, List and Vector for Seq's and Set, TreeSet and ListSet are
 * supported.
 *
 * @tparam F the collection family actually produced at runtime (Seq or Set)
 * @tparam C the specific target collection type to convert to
 * @tparam Y the element type
 */
trait CollectionConversion[F[_], C[_], Y] extends Serializable {

  /** Converts the runtime collection to the target collection type. */
  def convert(c: F[Y]): C[Y]
}

object CollectionConversion {

  // NOTE: all instances carry explicit result types; an inferred type here
  // would leak the anonymous-class type and is rejected under -Xsource:3.

  /** Identity conversion for Seq. */
  implicit def seqToSeq[Y]: CollectionConversion[Seq, Seq, Y] =
    new CollectionConversion[Seq, Seq, Y] {
      override def convert(c: Seq[Y]): Seq[Y] = c
    }

  /** Converts a runtime Seq to a Vector. */
  implicit def seqToVector[Y]: CollectionConversion[Seq, Vector, Y] =
    new CollectionConversion[Seq, Vector, Y] {
      override def convert(c: Seq[Y]): Vector[Y] = c.toVector
    }

  /** Converts a runtime Seq to a List. */
  implicit def seqToList[Y]: CollectionConversion[Seq, List, Y] =
    new CollectionConversion[Seq, List, Y] {
      override def convert(c: Seq[Y]): List[Y] = c.toList
    }

  /** Identity conversion for Set. */
  implicit def setToSet[Y]: CollectionConversion[Set, Set, Y] =
    new CollectionConversion[Set, Set, Y] {
      override def convert(c: Set[Y]): Set[Y] = c
    }

  /** Converts a runtime Set to a TreeSet; requires an element ordering. */
  implicit def setToTreeSet[Y](
      implicit
      ordering: Ordering[Y]
    ): CollectionConversion[Set, TreeSet, Y] =
    new CollectionConversion[Set, TreeSet, Y] {

      override def convert(c: Set[Y]): TreeSet[Y] =
        TreeSet.newBuilder.++=(c).result()
    }

  /** Converts a runtime Set to a ListSet. */
  implicit def setToListSet[Y]: CollectionConversion[Set, ListSet, Y] =
    new CollectionConversion[Set, ListSet, Y] {

      override def convert(c: Set[Y]): ListSet[Y] =
        ListSet.newBuilder.++=(c).result()
    }
}
556+
557+
/**
 * Encoder for Seq-like collections (Seq, List, Vector, ...), selected via the
 * [[CollectionConversion]] available for the concrete collection type C.
 * Explicit result type: public implicit defs must not rely on inference.
 */
implicit def seqEncoder[C[X] <: Seq[X], T](
    implicit
    i0: Lazy[RecordFieldEncoder[T]],
    i1: ClassTag[C[T]],
    i2: CollectionConversion[Seq, C, T]
  ): TypedEncoder[C[T]] = collectionEncoder[Seq, C, T]

/**
 * Encoder for Set-like collections (Set, TreeSet, ListSet, ...), selected via
 * the [[CollectionConversion]] available for the concrete collection type C.
 * Explicit result type: public implicit defs must not rely on inference.
 */
implicit def setEncoder[C[X] <: Set[X], T](
    implicit
    i0: Lazy[RecordFieldEncoder[T]],
    i1: ClassTag[C[T]],
    i2: CollectionConversion[Set, C, T]
  ): TypedEncoder[C[T]] = collectionEncoder[Set, C, T]
570+
571+
def collectionEncoder[O[_], C[X], T](
572+
implicit
573+
i0: Lazy[RecordFieldEncoder[T]],
574+
i1: ClassTag[C[T]],
575+
i2: CollectionConversion[O, C, T]
516576
): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
517577
private lazy val encodeT = i0.value.encoder
518578

@@ -529,38 +589,31 @@ object TypedEncoder {
529589
if (ScalaReflection.isNativeType(enc.jvmRepr)) {
530590
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
531591
} else {
532-
MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
592+
// converts to Seq, both Set and Seq handling must convert to Seq first
593+
MapObjects(
594+
enc.toCatalyst,
595+
SeqCaster(path),
596+
enc.jvmRepr,
597+
encodeT.nullable
598+
)
533599
}
534600
}
535601

536602
def fromCatalyst(path: Expression): Expression =
537-
MapObjects(
538-
i0.value.fromCatalyst,
539-
path,
540-
encodeT.catalystRepr,
541-
encodeT.nullable,
542-
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
543-
)
603+
CollectionCaster[O, C, T](
604+
MapObjects(
605+
i0.value.fromCatalyst,
606+
path,
607+
encodeT.catalystRepr,
608+
encodeT.nullable,
609+
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
610+
),
611+
implicitly[CollectionConversion[O, C, T]]
612+
) // This will convert Seq to the appropriate C[_] when eval'ing.
544613

545614
override def toString: String = s"collectionEncoder($jvmRepr)"
546615
}
547616

548-
/**
549-
* @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set.
550-
* @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
551-
* @tparam T the element type of the set.
552-
* @return a `TypedEncoder` instance for `Set[T]`.
553-
*/
554-
implicit def setEncoder[T](
555-
implicit
556-
i1: shapeless.Lazy[RecordFieldEncoder[T]],
557-
i2: ClassTag[Set[T]]
558-
): TypedEncoder[Set[T]] = {
559-
implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)
560-
561-
TypedEncoder.usingInjection
562-
}
563-
564617
/**
565618
* @tparam A the key type
566619
* @tparam B the value type

dataset/src/main/scala/frameless/functions/Udf.scala

+46-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ package frameless
22
package functions
33

44
import org.apache.spark.sql.catalyst.InternalRow
5-
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression}
5+
import org.apache.spark.sql.catalyst.expressions.{
6+
Expression,
7+
LeafExpression,
8+
NonSQLExpression
9+
}
610
import org.apache.spark.sql.catalyst.expressions.codegen._
711
import Block._
812
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -67,8 +71,17 @@ trait Udf {
6771
) => TypedColumn[T, R] = {
6872
case us =>
6973
val scalaUdf =
70-
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
71-
s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3]))
74+
FramelessUdf(
75+
f,
76+
us.toList[UntypedExpression[T]],
77+
TypedEncoder[R],
78+
s =>
79+
f(
80+
s.head.asInstanceOf[A1],
81+
s(1).asInstanceOf[A2],
82+
s(2).asInstanceOf[A3]
83+
)
84+
)
7285
new TypedColumn[T, R](scalaUdf)
7386
}
7487

@@ -81,8 +94,18 @@ trait Udf {
8194
def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
8295
case us =>
8396
val scalaUdf =
84-
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
85-
s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3], s(1).asInstanceOf[A4]))
97+
FramelessUdf(
98+
f,
99+
us.toList[UntypedExpression[T]],
100+
TypedEncoder[R],
101+
s =>
102+
f(
103+
s.head.asInstanceOf[A1],
104+
s(1).asInstanceOf[A2],
105+
s(2).asInstanceOf[A3],
106+
s(3).asInstanceOf[A4]
107+
)
108+
)
86109
new TypedColumn[T, R](scalaUdf)
87110
}
88111

@@ -95,8 +118,19 @@ trait Udf {
95118
def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
96119
case us =>
97120
val scalaUdf =
98-
FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R],
99-
s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2], s(1).asInstanceOf[A3], s(1).asInstanceOf[A4], s(1).asInstanceOf[A5]))
121+
FramelessUdf(
122+
f,
123+
us.toList[UntypedExpression[T]],
124+
TypedEncoder[R],
125+
s =>
126+
f(
127+
s.head.asInstanceOf[A1],
128+
s(1).asInstanceOf[A2],
129+
s(2).asInstanceOf[A3],
130+
s(3).asInstanceOf[A4],
131+
s(4).asInstanceOf[A5]
132+
)
133+
)
100134
new TypedColumn[T, R](scalaUdf)
101135
}
102136
}
@@ -119,7 +153,8 @@ case class FramelessUdf[T, R](
119153

120154
override def toString: String = s"FramelessUdf(${children.mkString(", ")})"
121155

122-
lazy val typedEnc = TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]]
156+
lazy val typedEnc =
157+
TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]]
123158

124159
def eval(input: InternalRow): Any = {
125160
val jvmTypes = children.map(_.eval(input))
@@ -130,11 +165,10 @@ case class FramelessUdf[T, R](
130165
val retval =
131166
if (returnCatalyst == null)
132167
null
168+
else if (typedEnc.isSerializedAsStructForTopLevel)
169+
returnCatalyst
133170
else
134-
if (typedEnc.isSerializedAsStructForTopLevel)
135-
returnCatalyst
136-
else
137-
returnCatalyst.get(0, dataType)
171+
returnCatalyst.get(0, dataType)
138172

139173
retval
140174
}

0 commit comments

Comments
 (0)