Skip to content

Commit 0435c3a

Browse files
committed
typelevel#804 - encoding for Set derivatives as well - test build
1 parent ae8b69a commit 0435c3a

File tree

4 files changed

+145
-80
lines changed

4 files changed

+145
-80
lines changed

dataset/src/main/scala/frameless/CollectionCaster.scala

+29-14
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@ package frameless
22

33
import frameless.TypedEncoder.CollectionConversion
44
import org.apache.spark.sql.catalyst.InternalRow
5-
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
6-
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
7-
import org.apache.spark.sql.types.{DataType, ObjectType}
5+
import org.apache.spark.sql.catalyst.expressions.codegen.{
6+
CodegenContext,
7+
CodegenFallback,
8+
ExprCode
9+
}
10+
import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression }
11+
import org.apache.spark.sql.types.{ DataType, ObjectType }
12+
13+
case class CollectionCaster[F[_], C[_], Y](
14+
child: Expression,
15+
conversion: CollectionConversion[F, C, Y])
16+
extends UnaryExpression
17+
with CodegenFallback {
818

9-
case class CollectionCaster[F[_],C[_],Y](child: Expression, conversion: CollectionConversion[F,C,Y]) extends UnaryExpression with CodegenFallback {
10-
protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
19+
protected def withNewChildInternal(newChild: Expression): Expression =
20+
copy(child = newChild)
1121

1222
override def eval(input: InternalRow): Any = {
1323
val o = child.eval(input).asInstanceOf[Object]
@@ -21,32 +31,37 @@ case class CollectionCaster[F[_],C[_],Y](child: Expression, conversion: Collecti
2131
override def dataType: DataType = child.dataType
2232
}
2333

24-
case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) extends UnaryExpression {
25-
protected def withNewChildInternal(newChild: Expression): Expression = copy(child = newChild)
34+
case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression)
35+
extends UnaryExpression {
36+
37+
protected def withNewChildInternal(newChild: Expression): Expression =
38+
copy(child = newChild)
2639

2740
// eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good
2841
override def eval(input: InternalRow): Any = {
2942
val o = child.eval(input).asInstanceOf[Object]
3043
o match {
31-
case col: Set[Y]@unchecked =>
44+
case col: Set[Y] @unchecked =>
3245
col.toSeq
3346
case _ => o
3447
}
3548
}
3649

3750
def toSeqOr[T](isSet: => T, or: => T): T =
3851
child.dataType match {
39-
case ObjectType(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
52+
case ObjectType(cls)
53+
if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
4054
isSet
4155
case t => or
4256
}
4357

4458
override def dataType: DataType =
4559
toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)
4660

47-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
48-
defineCodeGen(ctx, ev, c =>
49-
toSeqOr(s"$c.toSeq()", s"$c")
50-
)
61+
override protected def doGenCode(
62+
ctx: CodegenContext,
63+
ev: ExprCode
64+
): ExprCode =
65+
defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toSeq()", s"$c"))
5166

52-
}
67+
}

dataset/src/main/scala/frameless/TypedEncoder.scala

+82-42
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,27 @@ package frameless
22

33
import java.math.BigInteger
44
import java.util.Date
5-
import java.time.{Duration, Instant, LocalDate, Period}
5+
import java.time.{ Duration, Instant, LocalDate, Period }
66
import java.sql.Timestamp
77
import scala.reflect.ClassTag
88
import org.apache.spark.sql.FramelessInternals
99
import org.apache.spark.sql.FramelessInternals.UserDefinedType
10-
import org.apache.spark.sql.{reflection => ScalaReflection}
10+
import org.apache.spark.sql.{ reflection => ScalaReflection }
1111
import org.apache.spark.sql.catalyst.expressions._
1212
import org.apache.spark.sql.catalyst.expressions.objects._
13-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
13+
import org.apache.spark.sql.catalyst.util.{
14+
ArrayBasedMapData,
15+
DateTimeUtils,
16+
GenericArrayData
17+
}
1418
import org.apache.spark.sql.types._
1519
import org.apache.spark.unsafe.types.UTF8String
1620
import shapeless._
1721
import shapeless.ops.hlist.IsHCons
1822

1923
import scala.collection.generic.CanBuildFrom
2024
import scala.collection.immutable.HashSet.HashTrieSet
21-
import scala.collection.immutable.{ListSet, TreeSet}
25+
import scala.collection.immutable.{ ListSet, TreeSet }
2226

2327
abstract class TypedEncoder[T](
2428
implicit
@@ -507,44 +511,72 @@ object TypedEncoder {
507511
}
508512

509513
object CollectionConversion {
510-
implicit def seqToSeq[Y](implicit cbf: CanBuildFrom[Nothing, Y, Seq[Y]]) = new CollectionConversion[Seq, Seq, Y] {
514+
515+
implicit def seqToSeq[Y](
516+
implicit
517+
cbf: CanBuildFrom[Nothing, Y, Seq[Y]]
518+
) = new CollectionConversion[Seq, Seq, Y] {
511519
override def convert(c: Seq[Y]): Seq[Y] = c
512520
}
513-
implicit def seqToVector[Y](implicit cbf: CanBuildFrom[Nothing, Y, Vector[Y]]) = new CollectionConversion[Seq, Vector, Y] {
521+
522+
implicit def seqToVector[Y](
523+
implicit
524+
cbf: CanBuildFrom[Nothing, Y, Vector[Y]]
525+
) = new CollectionConversion[Seq, Vector, Y] {
514526
override def convert(c: Seq[Y]): Vector[Y] = c.toVector
515527
}
516-
implicit def seqToList[Y](implicit cbf: CanBuildFrom[Nothing, Y, List[Y]]) = new CollectionConversion[Seq, List, Y] {
528+
529+
implicit def seqToList[Y](
530+
implicit
531+
cbf: CanBuildFrom[Nothing, Y, List[Y]]
532+
) = new CollectionConversion[Seq, List, Y] {
517533
override def convert(c: Seq[Y]): List[Y] = c.toList
518534
}
519-
implicit def setToSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, Set[Y]]) = new CollectionConversion[Set, Set, Y] {
535+
536+
implicit def setToSet[Y](
537+
implicit
538+
cbf: CanBuildFrom[Nothing, Y, Set[Y]]
539+
) = new CollectionConversion[Set, Set, Y] {
520540
override def convert(c: Set[Y]): Set[Y] = c
521541
}
522-
implicit def setToTreeSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, TreeSet[Y]]) = new CollectionConversion[Set, TreeSet, Y] {
542+
543+
implicit def setToTreeSet[Y](
544+
implicit
545+
cbf: CanBuildFrom[Nothing, Y, TreeSet[Y]]
546+
) = new CollectionConversion[Set, TreeSet, Y] {
523547
override def convert(c: Set[Y]): TreeSet[Y] = c.to[TreeSet]
524548
}
525-
implicit def setToListSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, ListSet[Y]]) = new CollectionConversion[Set, ListSet, Y] {
549+
550+
implicit def setToListSet[Y](
551+
implicit
552+
cbf: CanBuildFrom[Nothing, Y, ListSet[Y]]
553+
) = new CollectionConversion[Set, ListSet, Y] {
526554
override def convert(c: Set[Y]): ListSet[Y] = c.to[ListSet]
527555
}
528-
implicit def setToTrieSet[Y](implicit cbf: CanBuildFrom[Nothing, Y, HashTrieSet[Y]]) = new CollectionConversion[Set, HashTrieSet, Y] {
556+
557+
implicit def setToTrieSet[Y](
558+
implicit
559+
cbf: CanBuildFrom[Nothing, Y, HashTrieSet[Y]]
560+
) = new CollectionConversion[Set, HashTrieSet, Y] {
529561
override def convert(c: Set[Y]): HashTrieSet[Y] = c.to[HashTrieSet]
530562
}
531563
}
532564

533565
implicit def seqEncoder[C[X] <: Seq[X], T](
534-
implicit
535-
i0: Lazy[RecordFieldEncoder[T]],
536-
i1: ClassTag[C[T]],
537-
i2: CollectionConversion[Seq, C, T],
538-
i3: CanBuildFrom[Nothing, T, C[T]]
539-
) = collectionEncoder[Seq, C, T]
566+
implicit
567+
i0: Lazy[RecordFieldEncoder[T]],
568+
i1: ClassTag[C[T]],
569+
i2: CollectionConversion[Seq, C, T],
570+
i3: CanBuildFrom[Nothing, T, C[T]]
571+
) = collectionEncoder[Seq, C, T]
540572

541573
implicit def setEncoder[C[X] <: Set[X], T](
542-
implicit
543-
i0: Lazy[RecordFieldEncoder[T]],
544-
i1: ClassTag[C[T]],
545-
i2: CollectionConversion[Set, C, T],
546-
i3: CanBuildFrom[Nothing, T, C[T]]
547-
) = collectionEncoder[Set, C, T]
574+
implicit
575+
i0: Lazy[RecordFieldEncoder[T]],
576+
i1: ClassTag[C[T]],
577+
i2: CollectionConversion[Set, C, T],
578+
i3: CanBuildFrom[Nothing, T, C[T]]
579+
) = collectionEncoder[Set, C, T]
548580

549581
def collectionEncoder[O[_], C[X], T](
550582
implicit
@@ -569,19 +601,26 @@ object TypedEncoder {
569601
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
570602
} else {
571603
// converts to Seq, both Set and Seq handling must convert to Seq first
572-
MapObjects(enc.toCatalyst, SeqCaster(path), enc.jvmRepr, encodeT.nullable)
604+
MapObjects(
605+
enc.toCatalyst,
606+
SeqCaster(path),
607+
enc.jvmRepr,
608+
encodeT.nullable
609+
)
573610
}
574611
}
575612

576613
def fromCatalyst(path: Expression): Expression =
577614
CollectionCaster[O, C, T](
578615
MapObjects(
579-
i0.value.fromCatalyst,
580-
path,
581-
encodeT.catalystRepr,
582-
encodeT.nullable,
583-
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
584-
), implicitly[CollectionConversion[O,C,T]]) // This will convert Seq to the appropriate C[_] when eval'ing.
616+
i0.value.fromCatalyst,
617+
path,
618+
encodeT.catalystRepr,
619+
encodeT.nullable,
620+
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
621+
),
622+
implicitly[CollectionConversion[O, C, T]]
623+
) // This will convert Seq to the appropriate C[_] when eval'ing.
585624

586625
override def toString: String = s"collectionEncoder($jvmRepr)"
587626
}
@@ -591,18 +630,19 @@ object TypedEncoder {
591630
* @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
592631
* @tparam T the element type of the set.
593632
* @return a `TypedEncoder` instance for `Set[T]`.
594-
595-
implicit def setEncoder[C[X] <: Seq[X], T](
596-
implicit
597-
i1: shapeless.Lazy[RecordFieldEncoder[T]],
598-
i2: ClassTag[Set[T]],
599-
i3: CollectionConversion[Set, C, T],
600-
i4: CanBuildFrom[Nothing, T, C[T]]
601-
): TypedEncoder[Set[T]] = {
602-
implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)
603-
604-
TypedEncoder.usingInjection
605-
}*/
633+
*
634+
* implicit def setEncoder[C[X] <: Seq[X], T](
635+
* implicit
636+
* i1: shapeless.Lazy[RecordFieldEncoder[T]],
637+
* i2: ClassTag[Set[T]],
638+
* i3: CollectionConversion[Set, C, T],
639+
* i4: CanBuildFrom[Nothing, T, C[T]]
640+
* ): TypedEncoder[Set[T]] = {
641+
* implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)
642+
*
643+
* TypedEncoder.usingInjection
644+
* }
645+
*/
606646

607647
/**
608648
* @tparam A the key type

dataset/src/test/scala/frameless/EncoderTests.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package frameless
22

3-
import scala.collection.immutable.{Set, TreeSet}
3+
import scala.collection.immutable.{ Set, TreeSet }
44
import org.scalatest.matchers.should.Matchers
55

66
object EncoderTests {

0 commit comments

Comments (0)