Commit fb1c109

typelevel#804 - encoding for Set derivatives as well - test build
1 parent ee38804 commit fb1c109

3 files changed: +112 -44 lines changed
@@ -1,24 +1,52 @@
 package frameless

-import frameless.TypedEncoder.SeqConversion
+import frameless.TypedEncoder.CollectionConversion
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, ObjectType}

-case class CollectionCaster[C[_]](child: Expression, conversion: SeqConversion[C]) extends UnaryExpression with CodegenFallback {
+case class CollectionCaster[F[_], C[_], Y](child: Expression, conversion: CollectionConversion[F, C, Y]) extends UnaryExpression with CodegenFallback {
   protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)

   override def eval(input: InternalRow): Any = {
     val o = child.eval(input).asInstanceOf[Object]
     o match {
-      case seq: scala.collection.Seq[_] =>
-        conversion.convertSeq(seq)
-      case set: scala.collection.Set[_] =>
-        o
+      case col: F[Y] @unchecked =>
+        conversion.convert(col)
       case _ => o
     }
   }

   override def dataType: DataType = child.dataType
 }
+
+case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) extends UnaryExpression {
+  protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
+
+  // eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good
+  override def eval(input: InternalRow): Any = {
+    val o = child.eval(input).asInstanceOf[Object]
+    o match {
+      case col: Set[Y] @unchecked =>
+        col.toSeq
+      case _ => o
+    }
+  }
+
+  def toSeqOr[T](isSet: => T, or: => T): T =
+    child.dataType match {
+      case ObjectType(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+        isSet
+      case t => or
+    }
+
+  override def dataType: DataType =
+    toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(ctx, ev, c =>
+      toSeqOr(s"$c.toSeq()", s"$c")
+    )
+}
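Illustration (not part of this commit; hypothetical names): at run time the two casters boil down to a Set-to-Seq normalisation before the element-wise MapObjects step, plus a conversion of the Seq that interpreted-mode MapObjects returns back into the collection the encoder promised. A dependency-free sketch of that round trip, using Scala 2.12 collection syntax, could look like this:

// Sketch only: mirrors what SeqCaster.eval and CollectionCaster.eval do, without any Spark types.
object CasterSketch {

  // like SeqCaster.eval: widen any Set to a Seq so the element-wise mapping step sees one shape
  def normaliseToSeq(value: Any): Any = value match {
    case s: scala.collection.Set[_] => s.toSeq
    case other                      => other
  }

  // like CollectionCaster.eval: rebuild the collection the encoder actually promised;
  // the plain function stands in for the CollectionConversion type class added in TypedEncoder.scala
  def castBack[Y, C](seq: Seq[Y])(convert: Seq[Y] => C): C = convert(seq)

  def main(args: Array[String]): Unit = {
    val mapped  = normaliseToSeq(Set(3, 1, 2)).asInstanceOf[Seq[Int]].map(_ * 10)
    val rebuilt = castBack(mapped)(_.to[scala.collection.immutable.TreeSet])
    println(rebuilt) // TreeSet(10, 20, 30)
  }
}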

dataset/src/main/scala/frameless/TypedEncoder.scala

+58 -33

@@ -1,31 +1,24 @@
 package frameless

 import java.math.BigInteger
-
 import java.util.Date
-
-import java.time.{ Duration, Instant, Period, LocalDate }
-
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.sql.Timestamp
-
 import scala.reflect.ClassTag
-
 import org.apache.spark.sql.FramelessInternals
 import org.apache.spark.sql.FramelessInternals.UserDefinedType
-import org.apache.spark.sql.{ reflection => ScalaReflection }
+import org.apache.spark.sql.{reflection => ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{
-  ArrayBasedMapData,
-  DateTimeUtils,
-  GenericArrayData
-}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-
 import shapeless._
 import shapeless.ops.hlist.IsHCons

+import scala.collection.generic.CanBuildFrom
+import scala.collection.immutable.TreeSet
+
 abstract class TypedEncoder[T](
     implicit
     val classTag: ClassTag[T])

@@ -501,27 +494,57 @@ object TypedEncoder {
     override def toString: String = s"arrayEncoder($jvmRepr)"
   }

-  trait SeqConversion[C[_]] extends Serializable {
-    def convertSeq[Y](c: Seq[Y]): C[Y]
+  /**
+   * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used in compilation
+   *
+   * This type class offers extensible conversion for more specific types. By default Seq, List and Vector are supported.
+   *
+   * @tparam C
+   */
+  trait CollectionConversion[F[_], C[_], Y] extends Serializable {
+    def convert(c: F[Y]): C[Y]
   }

-  object SeqConversion {
-    implicit val seqToSeq = new SeqConversion[Seq] {
-      override def convertSeq[Y](c: Seq[Y]): Seq[Y] = c
+  object CollectionConversion {
+    implicit def seqToSeq[Y](implicit cbf: CanBuildFrom[Nothing, Y, Seq[Y]]) = new CollectionConversion[Seq, Seq, Y] {
+      override def convert(c: Seq[Y]): Seq[Y] = c
+    }
+    implicit def seqToVector[Y](implicit cbf: CanBuildFrom[Nothing, Y, Vector[Y]]) = new CollectionConversion[Seq, Vector, Y] {
+      override def convert(c: Seq[Y]): Vector[Y] = c.toVector
+    }
+    implicit def seqToList[Y](implicit cbf: CanBuildFrom[Nothing, Y, List[Y]]) = new CollectionConversion[Seq, List, Y] {
+      override def convert(c: Seq[Y]): List[Y] = c.toList
     }
-    implicit val seqToVector = new SeqConversion[Vector] {
-      override def convertSeq[Y](c: Seq[Y]): Vector[Y] = c.toVector
+    implicit def setToSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, Set[Y]]) = new CollectionConversion[Set, Set, Y] {
+      override def convert(c: Set[Y]): Set[Y] = c
     }
-    implicit val seqToList = new SeqConversion[List] {
-      override def convertSeq[Y](c: Seq[Y]): List[Y] = c.toList
+    implicit def setToTreeSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, TreeSet[Y]]) = new CollectionConversion[Set, TreeSet, Y] {
+      override def convert(c: Set[Y]): TreeSet[Y] = c.to[TreeSet]
     }
   }

-  implicit def collectionEncoder[C[X] <: Seq[X], T](
+  implicit def seqEncoder[C[X] <: Seq[X], T](
+    implicit
+    i0: Lazy[RecordFieldEncoder[T]],
+    i1: ClassTag[C[T]],
+    i2: CollectionConversion[Seq, C, T],
+    i3: CanBuildFrom[Nothing, T, C[T]]
+  ) = collectionEncoder[Seq, C, T]
+
+  implicit def setEncoder[C[X] <: Set[X], T](
+    implicit
+    i0: Lazy[RecordFieldEncoder[T]],
+    i1: ClassTag[C[T]],
+    i2: CollectionConversion[Set, C, T],
+    i3: CanBuildFrom[Nothing, T, C[T]]
+  ) = collectionEncoder[Set, C, T]
+
+  def collectionEncoder[O[_], C[X], T](
     implicit
     i0: Lazy[RecordFieldEncoder[T]],
     i1: ClassTag[C[T]],
-    i2: SeqConversion[C]
+    i2: CollectionConversion[O, C, T],
+    i3: CanBuildFrom[Nothing, T, C[T]]
   ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
     private lazy val encodeT = i0.value.encoder

@@ -538,20 +561,20 @@ object TypedEncoder {
       if (ScalaReflection.isNativeType(enc.jvmRepr)) {
         NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
       } else {
-        MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
+        // converts to Seq, both Set and Seq handling must convert to Seq first
+        MapObjects(enc.toCatalyst, SeqCaster(path), enc.jvmRepr, encodeT.nullable)
       }
     }

     def fromCatalyst(path: Expression): Expression =
-      CollectionCaster(
+      CollectionCaster[O, C, T](
         MapObjects(
           i0.value.fromCatalyst,
           path,
           encodeT.catalystRepr,
           encodeT.nullable,
-          Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
-        )
-        , implicitly[SeqConversion[C]])
+          Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
+        ), implicitly[CollectionConversion[O, C, T]]) // This will convert Seq to the appropriate C[_] when eval'ing.

     override def toString: String = s"collectionEncoder($jvmRepr)"
   }

@@ -561,16 +584,18 @@ object TypedEncoder {
    * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
    * @tparam T the element type of the set.
    * @return a `TypedEncoder` instance for `Set[T]`.
-   */
-  implicit def setEncoder[T](
+
+  implicit def setEncoder[C[X] <: Seq[X], T](
     implicit
     i1: shapeless.Lazy[RecordFieldEncoder[T]],
-    i2: ClassTag[Set[T]]
+    i2: ClassTag[Set[T]],
+    i3: CollectionConversion[Set, C, T],
+    i4: CanBuildFrom[Nothing, T, C[T]]
   ): TypedEncoder[Set[T]] = {
     implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)

     TypedEncoder.usingInjection
-  }
+  }*/

   /**
    * @tparam A the key type
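Note (not part of this commit): the inline comments above describe the two paths the change reconciles: under codegen MapObjects builds the target collection C[_] directly from the runtime class hint, while under interpreted evaluation it returns a Seq that CollectionCaster converts with the matching CollectionConversion. Assuming this branch of frameless on Scala 2.12 (hence the CanBuildFrom evidence), a small usage sketch of the new type class and encoder derivations might look like:

import scala.collection.immutable.TreeSet
import frameless.TypedEncoder
import frameless.TypedEncoder.CollectionConversion

object CollectionConversionSketch {
  // the conversion the interpreted path applies to the Seq produced by MapObjects
  val toTree = implicitly[CollectionConversion[Set, TreeSet, Int]] // expected to resolve to setToTreeSet
  val ordered: TreeSet[Int] = toTree.convert(Set(3, 1, 2))         // TreeSet(1, 2, 3)

  // with seqEncoder/setEncoder in implicit scope, encoders for the concrete
  // collection types are summoned the same way the new tests do it
  val vectorEnc  = implicitly[TypedEncoder[Vector[Int]]]  // via seqEncoder + seqToVector
  val treeSetEnc = implicitly[TypedEncoder[TreeSet[Int]]] // via setEncoder + setToTreeSet
}

Supporting another Seq or Set derivative should then only require a further CollectionConversion instance (plus the usual CanBuildFrom evidence) in implicit scope.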

dataset/src/test/scala/frameless/EncoderTests.scala

+18 -3

@@ -1,7 +1,6 @@
 package frameless

-import scala.collection.immutable.Set
-
+import scala.collection.immutable.{Set, TreeSet}
 import org.scalatest.matchers.should.Matchers

 object EncoderTests {

@@ -12,6 +11,8 @@ object EncoderTests {
   case class PeriodRow(p: java.time.Period)

   case class VectorOfObject(a: Vector[X1[Int]])
+
+  case class TreeSetOfObjects(a: TreeSet[X1[Int]])
 }

 class EncoderTests extends TypedDatasetSuite with Matchers {

@@ -36,7 +37,7 @@ class EncoderTests extends TypedDatasetSuite with Matchers {
   }

   test("It should encode a Vector of Objects") {
-    forceInterpreted {
+    evalCodeGens {
       implicit val e = implicitly[TypedEncoder[VectorOfObject]]
       implicit val te = TypedExpressionEncoder[VectorOfObject]
       implicit val xe = implicitly[TypedEncoder[X1[VectorOfObject]]]

@@ -48,4 +49,18 @@
       ds.head.a.a shouldBe v
     }
   }
+
+  test("It should encode a TreeSet of Objects") {
+    evalCodeGens {
+      implicit val e = implicitly[TypedEncoder[TreeSetOfObjects]]
+      implicit val te = TypedExpressionEncoder[TreeSetOfObjects]
+      implicit val xe = implicitly[TypedEncoder[X1[TreeSetOfObjects]]]
+      implicit val xte = TypedExpressionEncoder[X1[TreeSetOfObjects]]
+      val v = (1 to 20).map(X1(_)).to[TreeSet]
+      val ds = {
+        sqlContext.createDataset(Seq(X1[TreeSetOfObjects](TreeSetOfObjects(v))))
+      }
+      ds.head.a.a shouldBe v
+    }
+  }
 }
